一、手动实现

from d2l import torch as d2l
import torch


def get_params(vocab_size, num_hiddens, device):
    """Create the trainable parameters of a from-scratch GRU model.

    Args:
        vocab_size: Used as both the input and the output dimension.
        num_hiddens: Number of hidden units.
        device: Device on which every parameter tensor is allocated.

    Returns:
        A list of 11 tensors with ``requires_grad`` enabled, ordered as
        ``[W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]``.
    """
    num_inputs = num_outputs = vocab_size

    def gaussian(shape):
        # Weights start as standard-normal samples scaled down by 0.01.
        return torch.randn(size=shape, device=device) * 0.01

    def gate_params():
        # One (input weight, hidden weight, zero bias) triple.
        w_input = gaussian((num_inputs, num_hiddens))
        w_hidden = gaussian((num_hiddens, num_hiddens))
        bias = torch.zeros(num_hiddens, device=device)
        return w_input, w_hidden, bias

    params = [
        *gate_params(),  # update gate
        *gate_params(),  # reset gate
        *gate_params(),  # candidate hidden state
    ]

    # Parameters that map the hidden state to the output.
    params.append(gaussian((num_hiddens, num_outputs)))
    params.append(torch.zeros(num_outputs, device=device))

    for tensor in params:
        tensor.requires_grad_(True)
    return params

def init_gru_state(batch_size, num_hiddens, device):
    """Return the initial GRU state as a one-element tuple.

    The single element is an all-zero tensor of shape
    ``(batch_size, num_hiddens)`` allocated on ``device``.
    """
    initial_hidden = torch.zeros((batch_size, num_hiddens), device=device)
    return (initial_hidden,)

def gru(inputs, state, params):
    """Run a GRU over a sequence and project every hidden state to an output.

    Args:
        inputs: Iterable of per-step input tensors; each one is multiplied
            by the input weight matrices (presumably shaped
            ``(batch, vocab_size)`` -- confirm against the caller).
        state: One-element tuple holding the current hidden state.
        params: The 11 tensors in the order
            ``W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q``.

    Returns:
        A pair ``(outputs, (H,))``: the per-step outputs concatenated along
        dimension 0, and the final hidden state wrapped in a tuple.
    """
    (W_xz, W_hz, b_z,
     W_xr, W_hr, b_r,
     W_xh, W_hh, b_h,
     W_hq, b_q) = params
    (hidden,) = state

    step_outputs = []
    for step_input in inputs:
        update_gate = torch.sigmoid(step_input @ W_xz + hidden @ W_hz + b_z)
        reset_gate = torch.sigmoid(step_input @ W_xr + hidden @ W_hr + b_r)

        # The reset gate scales the previous state inside the candidate.
        candidate = torch.tanh(
            step_input @ W_xh + (reset_gate * hidden) @ W_hh + b_h
        )
        # The update gate blends the previous state with the candidate.
        hidden = update_gate * hidden + (1 - update_gate) * candidate
        step_outputs.append(hidden @ W_hq + b_q)

    stacked = torch.cat(step_outputs, dim=0)
    return stacked, (hidden,)

# Training setup for the from-scratch GRU model.
# Hyperparameters: minibatch size, time steps per sequence, hidden units.
batch_size, num_steps = 32, 35
num_hiddens = 512
# Use the GPU when one is available; otherwise fall back to the CPU so the
# script still runs on machines without CUDA instead of raising an error.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
lr, num_epochs = 1, 500
# Wire the hand-written parameter initializer, state initializer and forward
# function into d2l's scratch RNN wrapper.
net = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)

d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

二、简洁实现

# Concise GRU implementation built on the framework's nn.GRU layer.
# Reuses vocab, num_hiddens, device, train_iter, lr and num_epochs that were
# defined earlier in this file.
from torch import nn
vocab_size = len(vocab)

# The GRU layer takes vocab_size input features and has num_hiddens hidden units.
gru_layer = nn.GRU(vocab_size, num_hiddens)

# Wrap the recurrent layer in d2l's RNNModel and move it to the chosen device.
net = d2l.RNNModel(gru_layer, vocab_size)
net.to(device)

# Train with the same data iterator and hyperparameters as the scratch model.
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐