rust-candle学习笔记2-使用Tensor定义线性层,只能前向计算
rust-candle学习笔记2-使用Tensor定义线性层,只能前向计算
·
参考:candle-book
定义线性层,没有梯度相关的东西,只能做前向计算
use candle_core::{Device, Result, Tensor};
#[derive(Debug)]
struct LinearOwn {
weight: Tensor, // 权重
bias: Tensor, // 偏置
}
impl LinearOwn {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.matmul(&self.weight)?; // 前向传播, 矩阵乘法 x: (batch_size, in_features), weight: (in_features, out_features)
x.broadcast_add(&self.bias)
}
}
struct Model {
first: LinearOwn,
second: LinearOwn,
}
impl Model {
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = self.first.forward(image)?;
let x = x.relu()?;
self.second.forward(&x)
}
}
fn main() -> Result<()> {
let device = Device::cuda_if_available(0)?;
let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let bias = Tensor::rand(0f32, 1.0, (100, ), &device)?;
let first = LinearOwn {weight, bias};
let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let bias = Tensor::rand(0f32, 1.0, (10, ), &device)?;
let second = LinearOwn {weight, bias};
let model = Model {first, second};
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}
DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐


所有评论(0)