一行代码计算模型的参数量和FLOPs【亲测有效】
安装对应的库文件pip install ptflops代码(get_model_complexity_info)import torchvision.models as modelsfrom ptflops import get_model_complexity_infonet = models.vgg16() #可以为自己搭建的模型flops, params = get_model_comple
·
安装对应的库文件
pip install ptflops
代码(get_model_complexity_info)
import torchvision.models as models
from ptflops import get_model_complexity_info
net = models.vgg16() #可以为自己搭建的模型
flops, params = get_model_complexity_info(model, (3,512,512), as_strings=True, print_per_layer_stat=True) #(3,512,512)输入图片的尺寸
print("Flops: {}".format(flops))
print("Params: " + params)
DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐

所有评论(0)