三行代码计算模型参数量:

total_params = sum(p.numel() for p in model.parameters())
trained_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total Parameters: {}. Trained Parameters: {}".format(total_params, trained_params))

一般加在模型优化器之前或之后的位置

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐