-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGPT в Rust
32 lines (24 loc) · 1.01 KB
/
GPT в Rust
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
use tch::nn::{Embedding, TransformerConfig, TransformerEncoderLayer, TransformerEncoder};
use tch::{Tensor, Device};
/// Builds an embedding layer and a stack of transformer encoder layers, runs a
/// single forward pass over a dummy batch of token ids, and prints the shape of
/// the result.
///
/// NOTE(review): `Embedding::new`, `TransformerConfig`, `TransformerEncoderLayer`
/// and `TransformerEncoder` are called exactly as in the original code; confirm
/// that these items and signatures exist in the `tch` version this project pins.
fn main() {
    // Model hyper-parameters.
    let d_model = 512;
    let nhead = 8;
    let num_layers = 6;
    let dim_feedforward = 2048;
    let dropout = 0.1;

    // Dummy-batch dimensions and vocabulary size (previously inline literals).
    let batch_size: i64 = 1;
    let seq_len: i64 = 10;
    let vocab_size = 10000;

    // Example input: integer token ids of shape [batch, seq]. An embedding
    // lookup indexes rows of its weight table, so the input must be an integer
    // tensor without a feature dimension; the embedding itself adds `d_model`.
    // All-zero ids are valid indices because 0 < vocab_size.
    let token_ids = Tensor::zeros(
        &[batch_size, seq_len],
        (tch::Kind::Int64, Device::cuda_if_available()),
    );

    // Embedding layer mapping token ids to `d_model`-sized vectors.
    // NOTE(review): the layer is built with default options while the input may
    // live on CUDA; confirm both end up on the same device.
    let embedding = Embedding::new(vocab_size, d_model, Default::default());

    // Shared configuration for every encoder layer.
    let config = TransformerConfig::new(d_model, nhead, dim_feedforward, dropout, num_layers);

    // One encoder layer per requested depth.
    let layers: Vec<_> = (0..num_layers)
        .map(|_| TransformerEncoderLayer::new(&config))
        .collect();

    // Full encoder assembled from the individual layers.
    let encoder = TransformerEncoder::new(&config, &layers, Default::default());

    // Forward pass: ids -> embeddings -> encoder. The `true` flag is passed
    // through unchanged from the original (presumably "training mode", which
    // would keep dropout active — confirm if deterministic output is wanted).
    let output = token_ids.apply(&embedding).apply_t(&encoder, true);

    // Presumably [batch_size, seq_len, d_model]; printed for inspection.
    println!("Output shape: {:?}", output.size());
}