| Crates.io | llm_router |
| lib.rs | llm_router |
| version | 0.1.0 |
| created_at | 2025-04-06 16:28:51.150512+00 |
| updated_at | 2025-04-06 16:28:51.150512+00 |
| description | A high-performance router and load balancer for LLM APIs like ChatGPT |
| homepage | |
| repository | https://github.com/yourusername/llm_router |
| max_upload_size | |
| id | 1623056 |
| size | 175,898 |
A high-performance, Rust-based load balancer and router specifically designed for Large Language Model (LLM) APIs. It intelligently distributes requests across multiple backend LLM API instances based on configurable strategies, health checks, and model capabilities.
Each backend instance has an ID, a base URL (e.g., https://api.openai.com/v1), a health status, and a set of associated models. Per-instance request counts drive the LoadBased strategy; the RequestTracker utility simplifies tracking them. Add the dependency to your Cargo.toml:
[dependencies]
llm_router_core = "0.1.0"
# Add other dependencies for your application (e.g., tokio, reqwest, axum)
tokio = { version = "1", features = ["full"] }
reqwest = { version = "0.11", features = ["json"] }
axum = "0.7"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
Then configure a Router and select an instance in your application:
use llm_router_core::{
Router, RequestTracker,
types::{ModelCapability, ModelInstanceConfig, RoutingStrategy},
};
use std::time::Duration;
use std::sync::Arc; // Required for sharing Router
#[tokio::main]
async fn main() {
// Create the router configuration
let router_config = Router::builder()
.strategy(RoutingStrategy::RoundRobin)
.instance_with_models(
"openai_instance1",
"https://api.openai.com/v1", // Replace with your actual endpoint if different
vec![
ModelInstanceConfig {
model_name: "gpt-4".to_string(),
capabilities: vec![ModelCapability::Chat],
},
ModelInstanceConfig {
model_name: "text-embedding-ada-002".to_string(),
capabilities: vec![ModelCapability::Embedding],
},
]
)
.instance_with_models(
"openai_instance2",
"https://api.openai.com/v1", // Replace with your actual endpoint if different
vec![
ModelInstanceConfig {
model_name: "gpt-3.5-turbo".to_string(),
capabilities: vec![ModelCapability::Chat],
}
]
)
.health_check_path("/health") // Optional: Define if your API has a health endpoint
.health_check_interval(Duration::from_secs(30))
.instance_timeout_duration(Duration::from_secs(60)) // Timeout unhealthy instances for 60s
.build();
// Wrap the router in an Arc for sharing across threads/tasks
let router = Arc::new(router_config);
// --- Example: Selecting an instance ---
let model_name = "gpt-4";
let capability = ModelCapability::Chat;
match router.select_instance_for_model(model_name, capability).await {
Ok(instance) => {
println!(
"Selected instance for '{}' ({}): {} ({})",
model_name, capability, instance.id, instance.base_url
);
// Use RequestTracker for automatic request counting (especially with LoadBased strategy)
let _tracker = RequestTracker::new(Arc::clone(&router), instance.id.clone());
// ---> Place your API call logic here <---
// Example: Construct the URL
let api_url = format!("{}/chat/completions", instance.base_url);
println!("Constructed API URL: {}", api_url);
// Use your HTTP client (e.g., reqwest) to send the request to api_url
// Remember to handle potential errors from the API call itself
// If an error occurs, consider calling router.timeout_instance(&instance.id).await
}
Err(e) => eprintln!(
"Error selecting instance for '{}' ({}): {}",
model_name, capability, e
),
}
}
Most LLM APIs require authentication (e.g., API keys). The llm-router itself doesn't handle authentication headers directly during routing or health checks by default. You need to manage authentication in your application's HTTP client when making the actual API calls after selecting an instance.
If your health checks require authentication, you can provide a pre-configured reqwest::Client to the Router::builder.
use llm_router_core::{Router, types::{ModelInstanceConfig, ModelCapability, RoutingStrategy}};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use std::time::Duration;
use std::sync::Arc;
async fn setup_router_with_auth() -> Result<Arc<Router>, Box<dyn std::error::Error>> {
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
// Configure a reqwest client with default auth headers (useful for health checks)
let mut headers = HeaderMap::new();
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {}", api_key))?,
);
let client = reqwest::Client::builder()
.default_headers(headers)
.timeout(Duration::from_secs(10)) // Set a timeout for HTTP requests
.build()?;
let router_config = Router::builder()
.strategy(RoutingStrategy::LoadBased)
.instance_with_models(
"authed_instance",
"https://api.openai.com/v1",
vec![ModelInstanceConfig {
model_name: "gpt-4".to_string(),
capabilities: vec![ModelCapability::Chat],
}],
)
.http_client(client) // Provide the pre-configured client
.health_check_interval(Duration::from_secs(60))
.build();
Ok(Arc::new(router_config))
}
// Remember: When making the actual API call *after* selecting an instance,
// you still need to ensure your request includes the necessary authentication.
// The client passed to the builder is primarily for health checks.
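For completeness, here is a minimal sketch of what that authenticated call might look like once an instance has been selected. The /chat/completions path, the payload shape, and the bearer-token header are assumptions based on the OpenAI API; adapt them to your backend.
use llm_router_core::{Router, RequestTracker, types::ModelCapability};
use std::sync::Arc;

async fn call_selected_instance(
    router: Arc<Router>,
    client: &reqwest::Client,
    api_key: &str,
) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
    // Pick a healthy instance for the model/capability we need
    let instance = router
        .select_instance_for_model("gpt-4", ModelCapability::Chat)
        .await?;
    // Track the request so LoadBased routing stays accurate
    let _tracker = RequestTracker::new(Arc::clone(&router), instance.id.clone());
    // Authentication is applied here, on the actual API call
    let response = client
        .post(format!("{}/chat/completions", instance.base_url))
        .bearer_auth(api_key)
        .json(&serde_json::json!({
            "model": "gpt-4",
            "messages": [{ "role": "user", "content": "Hello!" }]
        }))
        .send()
        .await?;
    // Surface backend errors; on failure you might also call
    // router.timeout_instance(&instance.id).await as shown elsewhere in this README.
    let body = response.error_for_status()?.json::<serde_json::Value>().await?;
    Ok(body)
}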
The router supports two strategies, set via Router::builder().strategy(...):
Round Robin (RoutingStrategy::RoundRobin): cycles through healthy instances in order, spreading requests evenly.
Load Based (RoutingStrategy::LoadBased): routes each request to the healthy instance with the fewest active requests. This requires request counts to be kept up to date, either via RequestTracker (recommended) or by manually calling increment_request_count and decrement_request_count, to be effective.
You can modify the router's instance pool after it has been built. This is useful for scaling or maintenance.
# use llm_router_core::{Router, types::{ModelCapability, ModelInstanceConfig, RoutingStrategy}};
# use std::sync::Arc;
# use std::time::Duration;
#
# async fn example(router: Arc<Router>) -> Result<(), Box<dyn std::error::Error>> {
// Add a new instance dynamically
router.add_instance_with_models(
"new_instance",
"https://api.another-provider.com/v1",
vec![
ModelInstanceConfig {
model_name: "claude-3".to_string(),
capabilities: vec![ModelCapability::Chat],
}
]
).await?;
println!("Added new_instance");
// Add a new model/capability to an existing instance
router.add_model_to_instance(
"new_instance",
"claude-3-opus".to_string(),
vec![ModelCapability::Chat]
).await?;
println!("Added claude-3-opus to new_instance");
// Remove an instance
match router.remove_instance("openai_instance1").await {
Ok(_) => println!("Removed openai_instance1"),
Err(e) => eprintln!("Failed to remove instance: {}", e),
}
// Get status of all instances
let instances_status = router.get_instances().await;
println!("
Current Instance Status:");
for instance_info in instances_status {
println!(
"- ID: {}, URL: {}, Status: {:?}, Active Requests: {}, Models: {:?}",
instance_info.id,
instance_info.base_url,
instance_info.status,
instance_info.active_requests, // Only relevant for LoadBased
instance_info.models.keys().collect::<Vec<_>>()
);
}
# Ok(())
# }
The primary way to get a suitable backend URL is by requesting an instance for a specific model and capability.
# use llm_router_core::{Router, RequestTracker, types::{ModelCapability, RoutingStrategy}};
# use std::sync::Arc;
# use std::time::Duration;
#
# async fn example(router: Arc<Router>) -> Result<(), Box<dyn std::error::Error>> {
let model_name = "gpt-3.5-turbo";
let capability = ModelCapability::Chat;
match router.select_instance_for_model(model_name, capability).await {
Ok(instance) => {
println!("Selected instance for {} ({}): {} ({})", model_name, capability, instance.id, instance.base_url);
// Use RequestTracker to ensure load balancing works correctly if using LoadBased strategy
let _tracker = RequestTracker::new(router.clone(), instance.id.clone());
// Now, make the API call to instance.base_url using your HTTP client...
}
Err(e) => {
eprintln!("Could not find a healthy instance for {} ({}): {}", model_name, capability, e);
// Handle the error (e.g., return an error response to the user)
}
}
# Ok(())
# }
Alternatively, if you don't need a specific model and just want the next instance according to the strategy:
# use llm_router_core::{Router, RequestTracker, types::RoutingStrategy};
# use std::sync::Arc;
# use std::time::Duration;
#
# async fn example(router: Arc<Router>) -> Result<(), Box<dyn std::error::Error>> {
match router.select_next_instance().await {
Ok(instance) => {
println!("Selected next instance (any model/capability): {} ({})", instance.id, instance.base_url);
let _tracker = RequestTracker::new(router.clone(), instance.id.clone());
// Make API call... (Be aware this instance might not support the specific model you need)
}
Err(e) => {
eprintln!("Could not select next instance: {}", e);
}
}
# Ok(())
# }
RequestTracker (Important for Load Balancing)
When using the LoadBased strategy, the router needs to know how many requests are currently in flight to each instance. The RequestTracker utility handles this automatically using RAII (Resource Acquisition Is Initialization).
# use llm_router_core::{Router, RequestTracker, types::{ModelCapability, RoutingStrategy}};
# use std::sync::Arc;
# use std::time::Duration;
# async fn make_api_call(url: &str) -> Result<(), &'static str> { /* ... */ Ok(()) }
#
# async fn example(router: Arc<Router>) -> Result<(), Box<dyn std::error::Error>> {
let instance = router.select_instance_for_model("gpt-4", ModelCapability::Chat).await?;
// Create the tracker immediately after selecting the instance
let tracker = RequestTracker::new(router.clone(), instance.id.clone());
// Perform the API call or other work associated with this instance
println!("Making API call to {}", instance.base_url);
match make_api_call(&instance.base_url).await {
Ok(_) => println!("API call successful"),
Err(e) => {
eprintln!("API call failed: {}", e);
// If the call fails, consider putting the instance in timeout
router.timeout_instance(&instance.id).await?;
println!("Instance {} put into timeout due to error.", instance.id);
}
}
// When `tracker` goes out of scope here (end of function, or earlier block),
// it automatically decrements the request count for the instance.
println!("Request finished, tracker dropped.");
# Ok(())
# }
If you don't use RequestTracker, you must manually call router.increment_request_count(&instance.id) before the request and router.decrement_request_count(&instance.id) after the request (including in error cases) for LoadBased routing to function correctly. RequestTracker is strongly recommended.
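If you do choose the manual route, the pattern looks roughly like this. This is a sketch only: the calls mirror the method names mentioned above, and their exact signatures (sync vs. async, return values) are not shown in this README, so adjust as needed.
# use llm_router_core::{Router, types::ModelCapability};
# use std::sync::Arc;
# async fn make_api_call(url: &str) -> Result<(), &'static str> { Ok(()) }
#
# async fn example(router: Arc<Router>) -> Result<(), Box<dyn std::error::Error>> {
let instance = router.select_instance_for_model("gpt-4", ModelCapability::Chat).await?;
// Register the in-flight request before calling the backend
// (signature assumed from the method name; add .await / error handling if required)
router.increment_request_count(&instance.id);
// Capture the result so the decrement below runs on success *and* failure
let result = make_api_call(&instance.base_url).await;
// Release the slot once the request is finished, whatever the outcome
router.decrement_request_count(&instance.id);
if let Err(e) = result {
    eprintln!("API call failed: {}", e);
}
# Ok(())
# }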
If an API call to a selected instance fails, you might want to temporarily mark that instance as unhealthy to prevent routing further requests to it for a while.
# use llm_router_core::{Router, RequestTracker, types::{ModelCapability, RoutingStrategy}};
# use std::sync::Arc;
# use std::time::Duration;
# async fn make_api_call(url: &str) -> Result<(), &'static str> { Err("Simulated API Error") }
#
# async fn example(router: Arc<Router>) -> Result<(), Box<dyn std::error::Error>> {
let instance = router.select_instance_for_model("gpt-4", ModelCapability::Chat).await?;
let tracker = RequestTracker::new(router.clone(), instance.id.clone());
match make_api_call(&instance.base_url).await {
Ok(_) => { /* Process success */ },
Err(api_error) => {
eprintln!("API Error from instance {}: {}", instance.id, api_error);
// Put the instance in timeout
match router.timeout_instance(&instance.id).await {
Ok(_) => println!("Instance {} placed in timeout.", instance.id),
Err(e) => eprintln!("Error placing instance {} in timeout: {}", instance.id, e),
}
// Return an appropriate error to the caller
return Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "API call failed")));
}
}
# Ok(())
# }
The instance will remain in the Timeout state for the duration specified by instance_timeout_duration in the builder, after which the health checker will attempt to bring it back online.
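If you want to observe an instance coming back online, one rough pattern (using only get_instances(), shown earlier) is to poll the instance list and log each instance's status; the polling interval and iteration count here are arbitrary.
# use llm_router_core::Router;
# use std::sync::Arc;
# use std::time::Duration;
#
# async fn watch_recovery(router: Arc<Router>, instance_id: &str) {
// Log the instance's status periodically; after instance_timeout_duration
// elapses, the health checker should move it out of the Timeout state again.
for _ in 0..12 {
    let instances = router.get_instances().await;
    if let Some(info) = instances.iter().find(|i| i.id == instance_id) {
        println!("Instance {} status: {:?}", info.id, info.status);
    }
    tokio::time::sleep(Duration::from_secs(10)).await;
}
# }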
Configure health checks using the builder:
# use llm_router_core::{Router, types::RoutingStrategy};
# use std::time::Duration;
#
let router = Router::builder()
// ... other configurations ...
.health_check_path("/health") // The endpoint path for the health check (e.g., GET <base_url>/health)
.health_check_interval(Duration::from_secs(15)) // Check health every 15 seconds
.health_check_timeout(Duration::from_secs(5)) // Timeout for the health check request itself (5 seconds)
.instance_timeout_duration(Duration::from_secs(60)) // How long an instance stays in Timeout state (60 seconds)
.build();
If health_check_path is not set, instances are initially considered Healthy and only move to Timeout if timeout_instance is called. When a path is set, the health checker periodically sends a GET request to <instance.base_url><health_check_path>: a 2xx status code marks the instance as Healthy, while any other status or a timeout marks it as Unhealthy. Instances in the Timeout state are not checked again until the timeout duration expires.
Here's how to integrate llm-router into an Axum web server to act as a proxy/gateway to your LLM backends.
use axum::{
extract::{Json, State},
http::{StatusCode, Uri},
response::{IntoResponse, Response},
routing::post,
Router as AxumRouter,
};
use llm_router_core::{
Router, RequestTracker,
types::{ModelCapability, ModelInstanceConfig, RoutingStrategy, RouterError},
};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
// Define your request/response structures (matching the target LLM API)
#[derive(Serialize, Deserialize, Debug)]
struct ChatRequest {
model: String,
messages: Vec<serde_json::Value>, // Example structure
// ... other fields like temperature, max_tokens, etc.
}
#[derive(Serialize, Deserialize, Debug)]
struct ChatResponse {
id: String,
object: String,
created: u64,
model: String,
choices: Vec<serde_json::Value>,
usage: serde_json::Value,
}
// Application state
struct AppState {
router: Arc<Router>,
http_client: Client, // Use a shared reqwest client
}
#[tokio::main]
async fn main() {
// --- Router Configuration ---
let router_config = Router::builder()
.strategy(RoutingStrategy::LoadBased) // Example: Use LoadBased
.instance_with_models(
"openai_1",
"https://api.openai.com/v1", // Replace with actual URL
vec![
ModelInstanceConfig::new("gpt-4", vec![ModelCapability::Chat]),
ModelInstanceConfig::new("gpt-3.5-turbo", vec![ModelCapability::Chat]),
ModelInstanceConfig::new("text-embedding-ada-002", vec![ModelCapability::Embedding]),
],
)
.instance_with_models(
"openai_2", // Perhaps using a different key or region
"https://api.openai.com/v1", // Replace with actual URL
vec![
ModelInstanceConfig::new("gpt-4", vec![ModelCapability::Chat]),
],
)
.health_check_path("/v1/models") // OpenAI's model list endpoint can serve as a basic health check
.health_check_interval(Duration::from_secs(60))
.instance_timeout_duration(Duration::from_secs(120))
.build();
let shared_router = Arc::new(router_config);
let shared_http_client = Client::new(); // Create a single reqwest client
let app_state = Arc::new(AppState {
router: shared_router,
http_client: shared_http_client,
});
// --- Axum Setup ---
let app = AxumRouter::new()
.route("/v1/chat/completions", post(chat_completions_handler))
// Add other routes for embeddings, completions etc.
.with_state(app_state);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
println!("🚀 LLM Router Gateway listening on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}
async fn chat_completions_handler(
State(state): State<Arc<AppState>>,
Json(payload): Json<ChatRequest>,
) -> Result<impl IntoResponse, AppError> {
println!("Received chat request for model: {}", payload.model);
// 1. Select an instance capable of handling the request
let instance = state
.router
.select_instance_for_model(&payload.model, ModelCapability::Chat)
.await?; // Use ? to convert RouterError into AppError
println!("Selected instance: {} ({})", instance.id, instance.base_url);
// 2. Use RequestTracker for load balancing (if using LoadBased) and timeout handling
let _tracker = RequestTracker::new(state.router.clone(), instance.id.clone());
// 3. Construct the target URL
let target_url_str = format!("{}/chat/completions", instance.base_url); // base_url already includes /v1 in this configuration
let target_url = target_url_str.parse::<Uri>().map_err(|_| {
AppError::Internal(format!("Failed to parse target URL: {}", target_url_str))
})?;
// --- Authentication ---
// IMPORTANT: Add authentication headers here. Get the key from secure storage/config.
let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
AppError::Internal("OPENAI_API_KEY environment variable not set".to_string())
})?;
let auth_header_value = format!("Bearer {}", api_key);
// 4. Proxy the request using the shared http_client
let response = state.http_client
.post(target_url.to_string())
.header("Authorization", auth_header_value) // Add the auth header
.json(&payload) // Forward the original payload
.send()
.await
.map_err(|e| AppError::BackendError(instance.id.clone(), e.to_string()))?;
// 5. Handle the response from the backend
let backend_status = response.status();
// Capture the Content-Type header before the body is consumed below
let backend_content_type = response.headers().get(reqwest::header::CONTENT_TYPE).cloned();
let response_bytes = response.bytes().await.map_err(|e| AppError::BackendError(instance.id.clone(), e.to_string()))?;
if !backend_status.is_success() {
eprintln!(
"Backend error from instance {}: Status: {}, Body: {:?}",
instance.id,
backend_status,
String::from_utf8_lossy(&response_bytes)
);
// If the backend failed, put the instance into timeout
let _ = state.router.timeout_instance(&instance.id).await; // Ignore error during timeout
return Err(AppError::BackendError(
instance.id.clone(),
format!("Status: {}, Body: {:?}", backend_status, String::from_utf8_lossy(&response_bytes)),
));
}
// 6. Forward the successful response (potentially deserialize/re-serialize if needed)
// Here, we forward the raw bytes and original status code/headers
let mut response_builder = Response::builder().status(backend_status);
// Copy relevant headers if necessary (e.g., Content-Type)
if let Some(content_type) = backend_content_type {
response_builder = response_builder.header(reqwest::header::CONTENT_TYPE, content_type);
}
let response = response_builder
.body(axum::body::Body::from(response_bytes))
.map_err(|e| AppError::Internal(format!("Failed to build response: {}", e)))?;
Ok(response)
}
// Custom Error type for Axum handler
enum AppError {
RouterError(RouterError),
BackendError(String, String), // instance_id, error message
Internal(String),
}
impl From<RouterError> for AppError {
fn from(err: RouterError) -> Self {
AppError::RouterError(err)
}
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AppError::RouterError(e) => {
eprintln!("Router error: {}", e);
// Handle specific RouterErrors differently if needed
match e {
RouterError::NoHealthyInstances(_) => (
StatusCode::SERVICE_UNAVAILABLE,
format!("No healthy backend instances available: {}", e),
),
_ => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Internal routing error: {}", e),
),
}
},
AppError::BackendError(instance_id, msg) => {
eprintln!("Backend error from instance {}: {}", instance_id, msg);
(
StatusCode::BAD_GATEWAY, // 502 suggests an issue with the upstream server
format!("Error from backend instance {}: {}", instance_id, msg),
)
},
AppError::Internal(msg) => {
eprintln!("Internal server error: {}", msg);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Internal server error: {}", msg),
)
}
};
(status, Json(serde_json::json!({ "error": error_message }))).into_response()
}
}
To run this Axum example:
1. Add the required dependencies to your Cargo.toml (axum, tokio, reqwest, serde, serde_json, llm_router_core).
2. Save the code above (e.g., as src/main.rs).
3. Set the required environment variables (e.g., OPENAI_API_KEY).
4. Start the gateway with cargo run.
5. Send a POST request (e.g., with curl or Postman) to http://127.0.0.1:3000/v1/chat/completions with a JSON body matching the ChatRequest structure; a minimal Rust client is sketched below.
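For step 5, a small standalone Rust client (hypothetical, equivalent to a curl POST; assumes tokio, reqwest, and serde_json are available) could look like this:
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    // The body mirrors the ChatRequest structure from the gateway example above
    let response = client
        .post("http://127.0.0.1:3000/v1/chat/completions")
        .json(&json!({
            "model": "gpt-3.5-turbo",
            "messages": [{ "role": "user", "content": "Say hello" }]
        }))
        .send()
        .await?;
    println!("Status: {}", response.status());
    println!("Body: {}", response.text().await?);
    Ok(())
}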
The crate includes benchmarks to measure the performance of the instance selection logic.
1. Run the Benchmarks:
Execute the standard Cargo benchmark command. This will run the selection logic repeatedly for different numbers of instances and routing strategies.
cargo bench
This command compiles the code in release mode with benchmarking enabled and runs the defined benchmark functions. The output will show time-per-iteration results, but it's easier to analyze with the reporter.
2. Generate the Report:
After running cargo bench, the raw results are typically stored in target/criterion/. A helper binary is provided to parse these results and generate a user-friendly report and a plot.
cargo run --bin bench_reporter
This command runs the bench_reporter binary located in src/bin/bench_reporter.rs. It parses the raw results produced by cargo bench, prints a summary report to the console, and saves a scaling plot as benchmark_scaling.png in the project's root directory.
Example Output:
--- LLM Router Benchmark Report ---
Found result for RoundRobin/10: 1745.87 ns
Found result for RoundRobin/25: 3960.23 ns
... (more results) ...
Found result for LoadBased/100: 14648.41 ns
--- Summary Table ---
+------------+---------------+---------------------------+
| Strategy | Instances (N) | Median Time per Selection |
+------------+---------------+---------------------------+
| RoundRobin | 10 | 1.75 µs |
| LoadBased | 10 | 1.75 µs |
... (more rows) ...
| RoundRobin | 100 | 15.15 µs |
| LoadBased | 100 | 14.65 µs |
+------------+---------------+---------------------------+
--- Performance Scaling Plot (RoundRobin) ---
Time (Median)
N=10 | 1.75 µs |
...
N=100 | 15.15 µs |
+------------------------------------------+
Instances (N) -->
... (LoadBased Plot) ...
--- Plot saved to benchmark_scaling.png ---
The benchmark_scaling.png file provides a visual comparison of how selection time grows with the number of backend instances for both routing strategies, illustrating the minimal overhead the router adds.
RoundRobin vs. LoadBased:
RoundRobin has slightly lower overhead since it doesn't need to check active request counts. LoadBased provides better load distribution if backend performance varies, potentially leading to more consistent end-to-end latency, at the cost of slightly higher selection overhead (though still in microseconds).
License: MIT