Crates.io | learnwell |
lib.rs | learnwell |
version | 0.2.1 |
source | src |
created_at | 2023-01-16 10:26:22.923764 |
updated_at | 2023-04-02 16:18:20.864407 |
description | Framework for reinforcement learning |
homepage | |
repository | https://github.com/griccardos/learnwell/ |
max_upload_size | |
id | 760076 |
size | 120,354 |
Easy reinforcement learning framework, allowing you to quickly create Environments and test them.
- Aims to be simple
- Minimal external dependencies
- Framework to create your own implementations
- Implementation examples
This project is in alpha. Use at your own risk.
See the taxi example and walk through its comments:
```
cargo run --release --example taxi
```
You can also run the following examples:
- hike - runs with display
- taxi
- mouse
- mouseimage - DQN
- taxiimage - DQN, runs with display

Imports:
```rust
use learnwell::{
    runner::Runner,
    agent::qlearning::QLearning,
    environment::{Environment, EnvironmentDisplay},
    strategy::decliningrandom::DecliningRandom,
};
```
We then ask the `Runner` to run the agent for a given number of epochs. It has two modes:
- `Runner::run` for normal operation
- `Runner::run_with_display` to create a window and display an image that is updated as it runs

For example:
```rust
let epochs = 400;
Runner::run(
    QLearning::new(0.1, 0.98, DecliningRandom::new(epochs, 0.01)), // Agent
    TaxiEnvironment::default(),                                     // Environment
    epochs,                                                         // epochs
);
```
or
```rust
let epochs = 700_000;
Runner::run_with_display(
    QLearning::new(0.2, 0.99, DecliningRandom::new(epochs, 0.005)), // Agent
    Hike::new(),                                                     // Environment
    epochs,                                                          // epochs
    10,                                                              // frames per second to refresh the image
);
```
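`DecliningRandom::new(epochs, 0.01)` suggests an exploration strategy whose randomness shrinks over the run. As a hedged illustration only, here is an epsilon schedule that falls linearly from 1.0 to a floor, plus an epsilon-greedy decision. The linear shape, the 1.0 starting value, and the names `DecliningEpsilon`, `epsilon` and `should_explore` are assumptions, not learnwell's implementation.

```rust
/// Hypothetical declining exploration schedule (not learnwell's `DecliningRandom`):
/// epsilon falls linearly from 1.0 at epoch 0 to `min_epsilon` at `total_epochs`
/// and stays at that floor afterwards.
struct DecliningEpsilon {
    total_epochs: usize,
    min_epsilon: f64,
}

impl DecliningEpsilon {
    fn new(total_epochs: usize, min_epsilon: f64) -> Self {
        Self { total_epochs, min_epsilon }
    }

    /// Probability of taking a random (exploratory) action at `epoch`.
    fn epsilon(&self, epoch: usize) -> f64 {
        if self.total_epochs == 0 {
            return self.min_epsilon;
        }
        let progress = (epoch as f64 / self.total_epochs as f64).min(1.0);
        (1.0 - progress).max(self.min_epsilon)
    }

    /// Epsilon-greedy decision: explore when a uniform sample in [0, 1)
    /// (supplied by the caller, e.g. from an RNG) falls below epsilon.
    fn should_explore(&self, epoch: usize, uniform_sample: f64) -> bool {
        uniform_sample < self.epsilon(epoch)
    }
}

fn main() {
    let schedule = DecliningEpsilon::new(400, 0.01);
    for epoch in [0, 100, 200, 399, 400, 1000] {
        println!("epoch {epoch:>4}: epsilon = {:.3}", schedule.epsilon(epoch));
    }
    // Halfway through training epsilon is 0.5: a sample of 0.7 exploits, 0.3 explores.
    println!("{}", schedule.should_explore(200, 0.7));
    println!("{}", schedule.should_explore(200, 0.3));
}
```

Taking the random sample as a parameter keeps the sketch free of external crates; in a real agent it would come from a random number generator.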
- Environment - the game/scenario we want to learn
- Agent - what interacts with the environment; this is the learning algorithm (e.g. QLearning)
- State (struct) - what we base our actions on
- Action (normally an enum) - the actions we perform

The Environment struct implements the `Environment<S,A>` trait, which depends on the State and Action types. The Environment struct should hold the state, because we will refer to it later.
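To make the State/Action/Environment split concrete, here is a self-contained sketch of an environment generic over `S` and `A`. The trait `SketchEnvironment` and its methods (`state`, `actions`, `step`, `is_done`, `reset`) are hypothetical stand-ins, not the signature of learnwell's `Environment<S,A>` trait; the toy `Corridor` environment holds its own state, as recommended above.

```rust
use std::hash::Hash;

/// Hypothetical environment trait, generic over a State `S` and an Action `A`.
/// NOT learnwell's `Environment` trait; method names are illustrative only.
trait SketchEnvironment<S: Hash + Eq + Clone, A: Hash + Eq + Clone> {
    fn state(&self) -> S;
    fn actions(&self) -> Vec<A>;
    /// Apply an action, update the held state, and return the reward.
    fn step(&mut self, action: &A) -> f64;
    fn is_done(&self) -> bool;
    fn reset(&mut self);
}

#[derive(Hash, Eq, PartialEq, Clone, Debug)]
struct CorridorState {
    pos: i32,
}

#[derive(Hash, Eq, PartialEq, Clone, Debug)]
enum CorridorAction {
    Left,
    Right,
}

/// Toy 1-D corridor: start at 0, reach `goal`; the environment holds the state.
struct Corridor {
    state: CorridorState,
    goal: i32,
}

impl SketchEnvironment<CorridorState, CorridorAction> for Corridor {
    fn state(&self) -> CorridorState {
        self.state.clone()
    }
    fn actions(&self) -> Vec<CorridorAction> {
        vec![CorridorAction::Left, CorridorAction::Right]
    }
    fn step(&mut self, action: &CorridorAction) -> f64 {
        match action {
            CorridorAction::Left => self.state.pos = (self.state.pos - 1).max(0),
            CorridorAction::Right => self.state.pos = (self.state.pos + 1).min(self.goal),
        }
        // -1 per move encourages short paths; +10 for reaching the goal.
        if self.state.pos == self.goal { 10.0 } else { -1.0 }
    }
    fn is_done(&self) -> bool {
        self.state.pos == self.goal
    }
    fn reset(&mut self) {
        self.state = CorridorState { pos: 0 };
    }
}

fn main() {
    let mut env = Corridor { state: CorridorState { pos: 0 }, goal: 3 };
    let mut total = 0.0;
    while !env.is_done() {
        total += env.step(&CorridorAction::Right);
    }
    println!("reached {:?} with total reward {}", env.state(), total);
    env.reset();
}
```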
Note that we derive `Hash`, `Eq`, `PartialEq` and `Clone` for both the State and the Action:
```rust
// `Point` is the example's grid-coordinate type (not shown here); it must also
// derive Hash, Eq, PartialEq and Clone for the derives below to compile.
#[derive(Hash, Eq, PartialEq, Clone)]
pub struct TaxiState {
    taxi: Point,
    dropoff: Point,
    passenger: Point,
    in_taxi: bool,
}

#[derive(Clone, Hash, PartialEq, Eq)]
pub enum TaxiAction {
    Up,
    Down,
    Left,
    Right,
    Dropoff,
    Pickup,
}

pub struct TaxiEnvironment {
    state: TaxiState, // the actual state that gets saved in the Q-table
    found: usize,     // just a helper; you may track other items in the environment too
}
```
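Deriving `Hash`, `Eq` and `Clone` matters because a tabular agent typically stores one value per (state, action) pair in a hash map. The sketch below is a hypothetical Q-table (`QTable`, `get`, `max_q` and `update` are illustrative names, not learnwell's internals) applying the standard Q-learning update Q(s,a) ← Q(s,a) + α[r + γ·max_a′ Q(s′,a′) − Q(s,a)]. Reading the first two arguments of `QLearning::new` as α (learning rate) and γ (discount factor) is an assumption.

```rust
use std::collections::HashMap;
use std::hash::Hash;

/// Hypothetical tabular Q-value store (not learnwell's internals).
/// Keys are (state, action) pairs, so both types need Hash + Eq + Clone.
struct QTable<S, A> {
    values: HashMap<(S, A), f64>,
    alpha: f64, // learning rate (assumed meaning)
    gamma: f64, // discount factor (assumed meaning)
}

impl<S: Hash + Eq + Clone, A: Hash + Eq + Clone> QTable<S, A> {
    fn new(alpha: f64, gamma: f64) -> Self {
        Self { values: HashMap::new(), alpha, gamma }
    }

    /// Q-value for a pair; pairs never seen before default to 0.0.
    fn get(&self, state: &S, action: &A) -> f64 {
        self.values
            .get(&(state.clone(), action.clone()))
            .copied()
            .unwrap_or(0.0)
    }

    /// Best Q-value available from `state`; 0.0 when `actions` is empty (terminal state).
    fn max_q(&self, state: &S, actions: &[A]) -> f64 {
        if actions.is_empty() {
            return 0.0;
        }
        actions
            .iter()
            .map(|a| self.get(state, a))
            .fold(f64::NEG_INFINITY, f64::max)
    }

    /// Q(s,a) <- Q(s,a) + alpha * (reward + gamma * max_a' Q(s',a') - Q(s,a))
    fn update(&mut self, state: &S, action: &A, reward: f64, next_state: &S, next_actions: &[A]) {
        let current = self.get(state, action);
        let target = reward + self.gamma * self.max_q(next_state, next_actions);
        self.values
            .insert((state.clone(), action.clone()), current + self.alpha * (target - current));
    }
}

#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug)]
enum Move {
    Left,
    Right,
}

fn main() {
    // Hypothetical integer states, using the alpha/gamma values from the taxi example.
    let mut q: QTable<i32, Move> = QTable::new(0.1, 0.98);
    // Transition into a terminal state (no next actions): the target is just the reward.
    q.update(&0, &Move::Right, 10.0, &2, &[]);
    // Transition into state 0: the target bootstraps from the best Q(0, a).
    q.update(&1, &Move::Left, -1.0, &0, &[Move::Left, Move::Right]);
    println!("Q(0, Right) = {:.4}", q.get(&0, &Move::Right));
    println!("Q(1, Left)  = {:.4}", q.get(&1, &Move::Left));
}
```

Cloning the key on every lookup keeps the sketch simple; this is one reason small, cheaply cloneable State and Action types work well.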