【深度学习】【入门】神经网络的基本骨架nn.Module
PyTorch中构建神经网络的核心是继承nn.Module基类,它封装了网络层定义、参数管理和前向传播逻辑。示例展示了如何创建简单模型:定义类继承nn.Module,在__init__中初始化网络层,在forward方法中实现前向传播。调用父类构造函数和执行前向计算是必要操作。最后通过实例演示了一个简单的加法网络,输入1.0经过网络处理后输出2.0。nn.Module的主要优势在于自动处理梯度计算
·
1.nn.Module的核心概念
nn.Module是PyTorch中构造神经网络的基类,所有自定义网络结构必须继承该类
nn.Module封装了网络层的定义、参数管理及前向传播逻辑
与普通Python类的区别在于自动梯度计算和参数优化机制
2.简单构建神经网络模型
官网:

代码:
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Model():是自己定义的类名,继承nn.Module,自定义的神经网络模型都需要继承nn.Module类
super()._init_():调用父类的构造函数,确保父类的初始化逻辑被执行,是必要的操作
forward():定义了神经网络的向前传播逻辑
3. 实例和结果
代码:
import torch
from torch import nn
class Module(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
output = input + 1
return output
module = Module()
x = torch.tensor(1.0)
output = module(x)
print(output)
结果:
DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐

所有评论(0)