参考:candle-book

定义线性层,没有梯度相关的东西,只能做前向计算

use candle_core::{Device, Result, Tensor};

#[derive(Debug)]
struct LinearOwn {
    weight: Tensor, // 权重
    bias: Tensor, // 偏置
}

impl LinearOwn {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = x.matmul(&self.weight)?; // 前向传播, 矩阵乘法 x: (batch_size, in_features), weight: (in_features, out_features)
        x.broadcast_add(&self.bias) 
    }
}

struct Model {
    first: LinearOwn,
    second: LinearOwn,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}

fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
    let bias = Tensor::rand(0f32, 1.0, (100, ), &device)?;
    let first = LinearOwn {weight, bias};
    let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
    let bias = Tensor::rand(0f32, 1.0, (10, ), &device)?;
    let second = LinearOwn {weight, bias};
    let model = Model {first, second};
    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐