| Crates.io | rlkit |
| lib.rs | rlkit |
| version | 0.0.3 |
| created_at | 2025-10-11 09:18:05.48277+00 |
| updated_at | 2025-10-13 05:26:36.735742+00 |
| description | A deep reinforcement learning library based on Rust and Candle, providing complete implementations of Q-Learning and DQN algorithms, supporting custom environments, various policy choices, and flexible training configurations. Future support will include more reinforcement learning algorithms, such as DDPG, PPO, A2C, etc. |
| homepage | |
| repository | https://github.com/Hifive55555/rlkit |
| max_upload_size | |
| id | 1878021 |
| size | 229,473 |
A deep reinforcement learning library built on Rust and Candle. It provides complete implementations of the Q-Learning and DQN algorithms and supports custom environments, multiple action-selection policies, and flexible training configuration. More reinforcement learning algorithms, such as DDPG, PPO, and A2C, are planned for future releases.
cuda: Enables CUDA GPU computation support
src/
├── algs/ # Reinforcement learning algorithm implementations
│ ├── dqn.rs # DQN algorithm
│ ├── q_learning.rs # Q-Learning algorithm
│ └── mod.rs # Algorithm interfaces and common components
├── network.rs # Neural network implementation (for DQN)
├── policies.rs # Action selection policies
├── replay_buffer.rs # Experience replay buffer
├── types.rs # Core type definitions (state, action, environment interface, etc.)
└── utils.rs # Utility functions
Add this library to your Rust project:
[dependencies]
rlkit = { version = "0.0.3", features = ["cuda"] }
All reinforcement learning environments must implement the EnvTrait interface:
use rlkit::types::{EnvTrait, Status, Reward, Action};
struct MyEnv {
    // Environment state fields
}

impl EnvTrait<u16, u16> for MyEnv {
    fn step(&mut self, state: &Status<u16>, action: &Action<u16>) -> (Status<u16>, Reward, bool) {
        // Execute the action and return the next state, the reward, and whether the episode is done
        todo!()
    }

    fn reset(&mut self) -> Status<u16> {
        // Reset the environment and return the initial state
        todo!()
    }

    fn action_space(&self) -> &[u16] {
        // Return the dimension information of the action space
        todo!()
    }

    fn state_space(&self) -> &[u16] {
        // Return the dimension information of the state space
        todo!()
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
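Before wiring an environment into an algorithm, it can be exercised on its own. The short sketch below uses only the trait methods declared above plus a hypothetical MyEnv::new() constructor (the same constructor assumed in the training examples below).
// Minimal sanity check of a custom environment (sketch; MyEnv::new() is hypothetical)
let mut env = MyEnv::new();

// action_space() / state_space() report the dimensions of each space
println!("state space dims:  {:?}", env.state_space());
println!("action space dims: {:?}", env.action_space());

// reset() must produce a valid initial state before the first step() call
let _initial_state = env.reset();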
use rlkit::algs::{QLearning, TrainArgs};
use rlkit::policies::EpsilonGreedy;
use rlkit::types::EnvTrait;
// Create the environment
let mut env = MyEnv::new();
// Create a Q-Learning algorithm instance
let mut q_learning = QLearning::new(&env, 10000).unwrap();
// Create an ε-greedy policy
let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);
// Configure training parameters
let train_args = TrainArgs {
    epochs: 1000,
    max_steps: 200,
    batch_size: 64,
    learning_rate: 0.1,
    gamma: 0.99,
    ..Default::default()
};
// Train the model
q_learning.train(&mut env, &mut policy, train_args).unwrap();
// Get an action using the trained model
let state = env.reset();
let action = q_learning.get_action(&state, &mut policy).unwrap();
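After training, the same get_action and step calls can drive a plain evaluation rollout. The sketch below reuses only the methods shown above; the 200-step cap is an arbitrary choice for this sketch.
// Roll out one episode with the trained Q-table (sketch)
let mut state = env.reset();
for _ in 0..200 {
    let action = q_learning.get_action(&state, &mut policy).unwrap();
    let (next_state, _reward, done) = env.step(&state, &action);
    // _reward can be accumulated here; its concrete type is defined in rlkit::types
    state = next_state;
    if done {
        break;
    }
}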
use candle_core::Device;
use rlkit::algs::{DQN, TrainArgs};
use rlkit::algs::dqn::DNQStateMode;
use rlkit::policies::EpsilonGreedy;
// Create the environment
let mut env = MyEnv::new();
// Select the computation device (CPU or GPU)
let device = Device::Cpu;
// Or use GPU: let device = Device::new_cuda(0).unwrap();
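// A convenience alternative (sketch, using candle_core's Device::cuda_if_available,
// which falls back to the CPU when no usable CUDA device is present):
// let device = Device::cuda_if_available(0).unwrap();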
// Create a DQN algorithm instance
let mut dqn = DQN::new(
    &env,
    10000,                // Replay buffer capacity
    &[128, 64, 16],       // Hidden layer structure
    DNQStateMode::OneHot, // State encoding mode
    &device,
).unwrap();
// Create a policy
let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);
// Configure training parameters
let train_args = TrainArgs {
    epochs: 1000,
    max_steps: 200,
    batch_size: 32,
    learning_rate: 1e-3,
    gamma: 0.99,
    update_freq: 5,
    update_interval: 100,
};
// Train the model
dqn.train(&mut env, &mut policy, train_args).unwrap();
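Inference with a trained DQN presumably mirrors the Q-Learning example. The sketch below assumes DQN exposes the same get_action method shown for QLearning above; that is an assumption, so check rlkit::algs for the exact interface.
// Greedy rollout with the trained network (sketch; get_action on DQN is assumed)
let mut state = env.reset();
for _ in 0..200 {
    let action = dqn.get_action(&state, &mut policy).unwrap();
    let (next_state, _reward, done) = env.step(&state, &action);
    state = next_state;
    if done {
        break;
    }
}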
The library provides multiple action selection policies:
use rlkit::policies::{PolicyConfig, EpsilonGreedy, Boltzmann};
// ε-greedy policy
let mut epsilon_greedy = EpsilonGreedy::new(1.0, 0.01, 0.995);
// Boltzmann policy
let mut boltzmann = Boltzmann::new(1.0, 0.1, 0.99);
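For both policies, the three constructor arguments appear to describe a decay schedule (initial value, floor, multiplicative decay factor); this reading is inferred from the values used above rather than documented here. A quick sketch of what such a schedule looks like:
// Sketch of an exponential decay schedule: start at 1.0, multiply by 0.995
// per update, and never drop below 0.01 (the argument order is an assumption).
let (start, floor, decay): (f64, f64, f64) = (1.0, 0.01, 0.995);
let value_at = |t: i32| (start * decay.powi(t)).max(floor);
assert!((value_at(0) - 1.0).abs() < 1e-9);
println!("epsilon after 100 updates: {:.3}", value_at(100)); // ~0.606
// The schedule reaches its 0.01 floor after roughly 920 updates.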
The library includes two complete examples:
Grid World - Located in examples/grid_world-example/
Catch Rabbit - Located in examples/catch_rabbit.rs
cd examples/grid_world-example
cargo run
Then follow the prompts to select the algorithm to use (1: Q-Learning, 2: DQN).
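The catch_rabbit example sits directly under examples/, so it should be runnable with Cargo's standard example runner (assuming it is registered as a normal Cargo example):
cargo run --example catch_rabbit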
To implement a custom environment, implement the EnvTrait interface described above. Detailed examples can be found in examples/grid_world-example/env.rs and examples/catch_rabbit.rs.
The TrainArgs struct includes the following parameters:
epochs: Number of training epochs
max_steps: Maximum number of steps per epoch
batch_size: Number of samples per batch update
learning_rate: Learning rate
gamma: Discount factor
update_freq: Network update frequency (update every N steps)
update_interval: Target network update interval (only used in DQN)
This project is licensed under the MIT License. For more details, see the LICENSE file.