三行代码计算模型参数量:
total_params = sum(p.numel() for p in model.parameters())
trained_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total Parameters: {}. Trained Parameters: {}".format(total_params, trained_params))
一般加在模型优化器之前或之后的位置
所有评论(0)