《动手学深度学习》之自定义层
参考自定义层import torchfrom torch import nn#含模型参数的自定义层# ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表,# 使用的时候可以用索引来访问某个参数,# 另外也可以使用append和extend在列表后面新增参数。class MyListDense(nn.Module):def __ini...
·
参考自定义层
import torch
from torch import nn
#含模型参数的自定义层
# ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表,
# 使用的时候可以用索引来访问某个参数,
# 另外也可以使用append和extend在列表后面新增参数。
class MyListDense(nn.Module):
def __init__(self):
super(MyListDense, self).__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])
self.params.append(nn.Parameter(torch.randn(4, 1)))
def forward(self, x):
for i in range(len(self.params)):
x = torch.mm(x, self.params[i])
return x
net = MyListDense()
print(net)
# ParameterDict接收一个Parameter实例的字典作为输入然后得到一个参数字典,
# 然后可以按照字典的规则使用了。
# 例如使用update()新增参数,使用keys()返回所有键值,使用items()返回所有键值对等等
class MyDictDense(nn.Module):
def __init__(self):
super(MyDictDense, self).__init__()
self.params = nn.ParameterDict({
'linear1': nn.Parameter(torch.randn(4, 4)),
'linear2': nn.Parameter(torch.randn(4, 1))
})
self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增
def forward(self, x, choice='linear1'):
return torch.mm(x, self.params[choice])
net = MyDictDense()
print(net)
# 根据传入的键值来进行不同的前向传播
x = torch.ones(1, 4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))
print('--------------------------------')
net = nn.Sequential(
MyDictDense(),
MyListDense(),
)
print(net)
print(net(x))
可以通过Module类自定义神经网络中的层,从而可以被重复调用
DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐
所有评论(0)