| Crates.io | decision_transformer_dfdx |
| lib.rs | decision_transformer_dfdx |
| version | 0.2.0 |
| created_at | 2025-01-01 02:08:37.191753+00 |
| updated_at | 2025-01-01 02:08:37.191753+00 |
| description | A fast, extensible implementation of Decision Transformers in Rust using dfdx. Based on the paper Decision Transformer: Reinforcement Learning via Sequence Modeling. |
| homepage | |
| repository | https://github.com/JYudelson1/decision_transformer_dfdx/tree/main |
| max_upload_size | |
| id | 1500530 |
| size | 70,343 |
A fast, extensible implementation of Decision Transformers in Rust using dfdx. Based on the paper Decision Transformer: Reinforcement Learning via Sequence Modeling.
This crate provides a framework for implementing Decision Transformers in Rust. Decision Transformers frame reinforcement learning as a sequence prediction problem, allowing for more efficient learning from demonstration data. This implementation is built on top of dfdx for efficient tensor operations and automatic differentiation.
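Concretely, the Decision Transformer paper models each trajectory as a sequence of (return-to-go, state, action) tokens, where the return-to-go at a timestep is the sum of the rewards still to be collected; at inference time you request a target return and the model predicts actions consistent with reaching it. The helper below is purely illustrative and is not a function exported by this crate; it only shows how returns-to-go relate to per-step rewards.

// Illustrative only (not part of this crate's API): compute returns-to-go,
// i.e. the total reward collected from each timestep onward.
fn returns_to_go(rewards: &[f32]) -> Vec<f32> {
    let mut rtg = vec![0.0; rewards.len()];
    let mut running = 0.0;
    for t in (0..rewards.len()).rev() {
        running += rewards[t];
        rtg[t] = running;
    }
    rtg
}
// e.g. rewards [0.0, 0.0, 1.0] give returns-to-go [1.0, 1.0, 1.0]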
cargo add decision-transformer-dfdx
To implement a Decision Transformer for your own environment:
- Define a DTModelConfig for your transformer architecture
- Implement the DTState trait for your environment (required)
- Implement GetOfflineData if you want to train from demonstrations
- Implement HumanEvaluatable if you want to visualize/evaluate the environment

For a complete working example, check out the Snake Game Implementation.
The framework is built around three main traits:
pub trait DTState<E: Dtype, D: Device<E>, Config: DTModelConfig> {
    type Action: Clone;

    const STATE_SIZE: usize;  // Total number of floats needed to represent the state
    const ACTION_SIZE: usize; // Total number of possible actions

    // Required methods
    fn new_random<R: rand::Rng + ?Sized>(rng: &mut R) -> Self;
    fn apply_action(&mut self, action: Self::Action);
    fn get_reward(&self, action: Self::Action) -> f32;
    fn to_tensor(&self) -> Tensor<(Const<{ Self::STATE_SIZE }>,), E, D>;
    fn action_to_index(action: &Self::Action) -> usize;
    fn index_to_action(action: usize) -> Self::Action;

    // Provided methods
    fn action_to_tensor(action: &Self::Action) -> Tensor<(Const<{ Self::ACTION_SIZE }>,), E, D>;
    fn build_model() -> DTModelWrapper<E, D, Config, Self>;
}
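As a sketch of what an implementation can look like (this is an invented example, not code from the crate or its snake game; LineWorld, Move, the choice of E = f32 and D = Cpu, and the import paths are all assumptions, and the const-generic signatures may require a nightly toolchain), consider a toy environment where an agent on a line of ten cells tries to reach the rightmost cell:

use dfdx::prelude::*;
use decision_transformer_dfdx::{DTModelConfig, DTState};

// Hypothetical toy environment: an agent on a line of 10 cells tries to reach cell 9.
#[derive(Clone)]
pub struct LineWorld {
    pos: usize,
}

#[derive(Clone)]
pub enum Move {
    Left,
    Right,
}

impl<C: DTModelConfig> DTState<f32, Cpu, C> for LineWorld {
    type Action = Move;

    const STATE_SIZE: usize = 10; // one-hot encoding of the agent's cell
    const ACTION_SIZE: usize = 2; // Left or Right

    fn new_random<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
        Self { pos: rng.gen_range(0..9) } // start anywhere except the goal cell
    }

    fn apply_action(&mut self, action: Move) {
        match action {
            Move::Left => self.pos = self.pos.saturating_sub(1),
            Move::Right => self.pos = (self.pos + 1).min(9),
        }
    }

    fn get_reward(&self, action: Move) -> f32 {
        // 1.0 for the step that reaches the goal, 0.0 otherwise
        match action {
            Move::Right if self.pos + 1 == 9 => 1.0,
            _ => 0.0,
        }
    }

    fn to_tensor(&self) -> Tensor<(Const<10>,), f32, Cpu> {
        let mut v = vec![0.0; 10];
        v[self.pos] = 1.0;
        // A fresh Cpu device per call keeps the sketch short; a real
        // implementation would probably store or share one device.
        Cpu::default().tensor_from_vec(v, (Const::<10>,))
    }

    fn action_to_index(action: &Move) -> usize {
        match action {
            Move::Left => 0,
            Move::Right => 1,
        }
    }

    fn index_to_action(index: usize) -> Move {
        if index == 0 { Move::Left } else { Move::Right }
    }
}

Note that only the required methods are written out; action_to_tensor and build_model come for free as provided methods.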
pub trait GetOfflineData<E: Dtype, D: Device<E>, Config: DTModelConfig>: DTState<E, D, Config> {
    // Required method
    fn play_one_game<R: rand::Rng + ?Sized>(rng: &mut R) -> (Vec<Self>, Vec<Self::Action>);

    // Provided method
    fn get_batch<const B: usize, R: rand::Rng + ?Sized>(
        rng: &mut R,
        cap_from_game: Option<usize>,
    ) -> (BatchedInput<B, { Self::STATE_SIZE }, { Self::ACTION_SIZE }, E, D, Config>, [Self::Action; B]);
}
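Continuing the hypothetical LineWorld sketch from above (again illustrative, with assumed import paths), GetOfflineData only needs play_one_game; get_batch is then provided for free. Here the "demonstrations" are just a right-biased random walk, which is enough to show the shape of the data:

use dfdx::prelude::*;
use decision_transformer_dfdx::{DTModelConfig, DTState, GetOfflineData};

impl<C: DTModelConfig> GetOfflineData<f32, Cpu, C> for LineWorld {
    fn play_one_game<R: rand::Rng + ?Sized>(rng: &mut R) -> (Vec<Self>, Vec<Move>) {
        let mut state = <Self as DTState<f32, Cpu, C>>::new_random(rng);
        let mut states = Vec::new();
        let mut actions = Vec::new();

        // A crude "demonstrator": mostly walk right, so games end quickly.
        while state.pos != 9 {
            let action = if rng.gen_bool(0.8) { Move::Right } else { Move::Left };
            states.push(state.clone());
            actions.push(action.clone());
            <Self as DTState<f32, Cpu, C>>::apply_action(&mut state, action);
        }
        (states, actions)
    }
}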
pub trait HumanEvaluatable<E: Dtype, D: Device<E>, Config: DTModelConfig>: DTState<E, D, Config> {
    // All methods required
    fn print(&self);                        // Print the current state
    fn print_action(action: &Self::Action); // Print a given action
    fn is_still_playing(&self) -> bool;     // Check if episode is ongoing
}
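A matching HumanEvaluatable sketch for the same hypothetical LineWorld, handy for watching episodes play out in the terminal:

use dfdx::prelude::*;
use decision_transformer_dfdx::{DTModelConfig, HumanEvaluatable};

impl<C: DTModelConfig> HumanEvaluatable<f32, Cpu, C> for LineWorld {
    fn print(&self) {
        // Render the line, marking the agent with 'A' and the goal with 'G'
        let row: String = (0..10)
            .map(|i| if i == self.pos { 'A' } else if i == 9 { 'G' } else { '.' })
            .collect();
        println!("{row}");
    }

    fn print_action(action: &Move) {
        match action {
            Move::Left => println!("<- left"),
            Move::Right => println!("right ->"),
        }
    }

    fn is_still_playing(&self) -> bool {
        self.pos != 9 // the episode ends once the goal cell is reached
    }
}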
The DTModelConfig trait allows you to configure the transformer architecture:
pub trait DTModelConfig {
    const NUM_ATTENTION_HEADS: usize;  // Number of attention heads
    const HIDDEN_SIZE: usize;          // Size of hidden layers
    const MLP_INNER: usize;            // Size of inner MLP layer (typically 4*HIDDEN_SIZE)
    const SEQ_LEN: usize;              // Length of sequence to consider
    const MAX_EPISODES_IN_GAME: usize; // Maximum episodes in a game
    const NUM_LAYERS: usize;           // Number of transformer layers
}
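A concrete configuration is just a unit struct with these constants filled in. The values below are illustrative rather than tuned recommendations, and the import path is assumed:

use decision_transformer_dfdx::DTModelConfig;

pub struct SmallConfig;

impl DTModelConfig for SmallConfig {
    const NUM_ATTENTION_HEADS: usize = 4;
    const HIDDEN_SIZE: usize = 64;
    const MLP_INNER: usize = 256; // 4 * HIDDEN_SIZE
    const SEQ_LEN: usize = 10;
    const MAX_EPISODES_IN_GAME: usize = 500;
    const NUM_LAYERS: usize = 3;
}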
Train your model using pre-collected demonstration data:
let mut model = MyEnvironment::build_model();
// config holds the dfdx Adam hyperparameters (learning rate, betas, etc.)
let config = AdamConfig::default();
let mut optimizer = Adam::new(&model.0, config);
// Get a batch of demonstration data
let (batch, actions) = MyEnvironment::get_batch::<1024, _>(&mut rng, Some(256));
// Train on the batch
let loss = model.train_on_batch(batch.clone(), actions, &mut optimizer);
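A single batch is rarely enough; in practice you would loop, sampling a fresh batch each step. The loop below is a hypothetical pattern (batch size, step count, and logging cadence are arbitrary, and printing assumes the returned loss value implements Debug):

for step in 0..1_000 {
    // Sample a fresh batch of demonstration data each step
    let (batch, actions) = MyEnvironment::get_batch::<64, _>(&mut rng, Some(256));
    let loss = model.train_on_batch(batch, actions, &mut optimizer);
    if step % 100 == 0 {
        println!("step {step}: loss = {loss:?}");
    }
}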
Train your model through self-play:
let temp = 0.5; // Temperature for exploration
let desired_reward = 5.0; // Target reward
// Train through self-play
let loss = model.online_learn::<100, _>(
    temp,
    desired_reward,
    &mut optimizer,
    &mut rng,
    Some(256), // Optional cap on episodes per game
);
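Online learning is likewise usually run for many rounds. One common pattern, not prescribed by the crate, is to repeat self-play while gradually lowering the exploration temperature:

for round in 0..20 {
    // Anneal the temperature so later rounds exploit more and explore less
    let temp = 0.5 * 0.9_f32.powi(round);
    let loss = model.online_learn::<100, _>(temp, desired_reward, &mut optimizer, &mut rng, Some(256));
    println!("round {round}: loss = {loss:?}");
}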
The repository includes a complete example implementing a snake game environment with the Decision Transformer framework, serving as a reference for how the traits described above come together in practice. Check out the snake game implementation for a complete working example.
Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.