1.nn.Module的核心概念

  nn.Module是PyTorch中构造神经网络的基类,所有自定义网络结构必须继承该类

  nn.Module封装了网络层的定义、参数管理及前向传播逻辑

  与普通Python类的区别在于自动梯度计算和参数优化机制

2.简单构建神经网络模型

官网:

代码:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Model():是自己定义的类名,继承nn.Module,自定义的神经网络模型都需要继承nn.Module类

super()._init_():调用父类的构造函数,确保父类的初始化逻辑被执行,是必要的操作

forward():定义了神经网络的向前传播逻辑

3. 实例和结果

代码:

import torch
from torch import nn


class Module(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        output = input + 1
        return output

module = Module()
x = torch.tensor(1.0)
output = module(x)
print(output)

结果:

 

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐