安装对应的库文件

pip install ptflops

代码(get_model_complexity_info)

import torchvision.models as models
from ptflops import get_model_complexity_info
net = models.vgg16() #可以为自己搭建的模型
flops, params = get_model_complexity_info(model, (3,512,512), as_strings=True, print_per_layer_stat=True)  #(3,512,512)输入图片的尺寸
print("Flops: {}".format(flops))
print("Params: " + params)
Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐