pytorch中使用torchviz可视化某网络或loss函数计算图后,计算图节点的理解
torchviz输出网络结构计算图,计算图节点的理解
一、安装graphviz之后,添加环境变量,可以用torchviz输出网络结构计算图
from torchviz import make_dot
make_dot(loss).view()
二、.backward() 方法
在 PyTorch 中,当在计算图上定义一系列操作后调用 .backward() 方法时,PyTorch 会为每个操作生成一个名为 Backward 的节点。这些名称是 PyTorch 为了跟踪反向传播的操作而生成的,是 PyTorch 内部生成的。
具体而言,名称的格式为 [OperatorName]Backward[Number],其中 OperatorName 表示操作的名称,Number 表示该操作在图中的位置。例如,SubBackward0 表示在图中第一个计算的减法操作;MmBackward0 表示第一个计算的乘法操作;UnsqueezeBackward0 表示第一个计算的 unsqueeze 操作。
除了上面的操作名称外,还有其他的节点类型。这取决于在图中定义的操作类型。对于每种操作,PyTorch 都会生成对应的反向传播节点,以支持自动求导。
三、计算图常见的节点类型名称:
-
AddBackward:表示加法操作。
-
SubBackward:表示减法操作。
-
MulBackward:表示乘法操作。
-
DivBackward:表示除法操作。
-
ExpBackward:表示指数运算。
-
LogBackward:表示对数运算。
-
MatmulBackward:表示矩阵乘法。
-
UnsqueezeBackward:表示添加一维的操作。
-
SqueezeBackward:表示删除一维的操作。
-
Conv2dBackward:表示卷积操作。
-
MaxPool2dBackward:表示最大池化操作。
-
StackBackward 表示堆叠操作,即将一组张量堆叠在一起,形成一个新的张量。
-
TBackward 表示转置操作,即对一个张量进行转置。
-
SelectBackward 表示选择操作,即选择张量的一部分作为新的张量。在 PyTorch 中,这可以通过索引(indexing)实现。
下图是自己定义简单MLP的torchviz可视化计算图(截取)


DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐
所有评论(0)