| Crates.io | prefrontal |
| lib.rs | prefrontal |
| version | 0.1.0 |
| created_at | 2025-01-28 01:04:50.439849+00 |
| updated_at | 2025-01-28 01:04:50.439849+00 |
| description | A blazing fast text classifier for real-time agent routing, built in Rust |
| homepage | https://axar.ai |
| repository | https://github.com/axar-ai/prefrontal |
| max_upload_size | |
| id | 1533138 |
| size | 202,799 |
A blazing fast text classifier for real-time agent routing, built in Rust. Prefrontal provides a simple yet powerful interface for routing conversations to the right agent or department based on text content, powered by transformer-based models.
To use Prefrontal in your project, you need:
Build Dependencies
Ubuntu/Debian:
sudo apt-get install cmake pkg-config libssl-dev
macOS:
brew install cmake pkg-config openssl
Runtime Requirements
Note: ONNX Runtime v2.0.0-rc.9 is automatically downloaded and managed by the crate.
On first use, Prefrontal will automatically download the required model files from HuggingFace:
https://huggingface.co/axar-ai/minilm/resolve/main/model.onnx
https://huggingface.co/axar-ai/minilm/resolve/main/tokenizer.json
The files are cached locally (see the Model Cache Location section). You can control the download behavior using the ModelManager:
let manager = ModelManager::new_default()?;
let model = BuiltinModel::MiniLM;
// Check if model is already downloaded
if !manager.is_model_downloaded(model) {
println!("Downloading model...");
manager.download_model(model).await?;
}
Add this to your Cargo.toml:
[dependencies]
prefrontal = "0.1.0"
use prefrontal::{Classifier, BuiltinModel, ClassDefinition, ModelManager};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Download the model if not already present
let manager = ModelManager::new_default()?;
let model = BuiltinModel::MiniLM;
if !manager.is_model_downloaded(model) {
println!("Downloading model...");
manager.download_model(model).await?;
}
// Initialize the classifier with built-in MiniLM model
let classifier = Classifier::builder()
.with_model(model)?
.add_class(
ClassDefinition::new(
"technical_support",
"Technical issues and product troubleshooting"
).with_examples(vec![
"my app keeps crashing",
"can't connect to the server",
"integration setup help"
])
)?
.add_class(
ClassDefinition::new(
"billing",
"Billing and subscription inquiries"
).with_examples(vec![
"update credit card",
"cancel subscription",
"billing cycle question"
])
)?
.build()?;
// Route incoming message
let (department, scores) = classifier.predict("I need help setting up the API integration")?;
println!("Route to: {}", department);
println!("Confidence scores: {:?}", scores);
Ok(())
}
MiniLM is a small, efficient model optimized for text classification:
Embedding Size: 384 dimensions
Max Sequence Length: 256 tokens
Model Size: ~85MB
The ONNX runtime configuration is optional. If not specified, the classifier uses a default configuration optimized for most use cases:
// Using default configuration (recommended for most cases)
let classifier = Classifier::builder()
.with_model(BuiltinModel::MiniLM)?
.build()?;
// Or with custom runtime configuration
use prefrontal::{Classifier, RuntimeConfig};
use ort::session::builder::GraphOptimizationLevel;
let config = RuntimeConfig {
inter_threads: 4, // Number of threads for parallel model execution (0 = auto)
intra_threads: 2, // Number of threads for parallel computation within nodes (0 = auto)
optimization_level: GraphOptimizationLevel::Level3, // Maximum optimization
};
let classifier = Classifier::builder()
.with_runtime_config(config) // Optional: customize ONNX runtime behavior
.with_model(BuiltinModel::MiniLM)?
.build()?;
The default configuration (RuntimeConfig::default()) is suitable for most workloads. You can customize the following settings if needed:
inter_threads: Number of threads for parallel model execution
intra_threads: Number of threads for parallel computation within nodes
optimization_level: The level of graph optimization to apply
You can use your own ONNX models:
let classifier = Classifier::builder()
.with_custom_model(
"path/to/model.onnx",
"path/to/tokenizer.json",
Some(512) // Optional: custom sequence length
)?
.add_class(
ClassDefinition::new(
"custom_class",
"Description of the custom class"
).with_examples(vec!["example text"])
)?
.build()?;
The library includes a model management system that handles downloading and verifying models:
use prefrontal::{ModelManager, BuiltinModel};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let manager = ModelManager::new_default()?;
let model = BuiltinModel::MiniLM;
// Check if model is downloaded
if !manager.is_model_downloaded(model) {
println!("Downloading model...");
manager.download_model(model).await?;
}
// Verify model integrity
if !manager.verify_model(model)? {
println!("Model verification failed, redownloading...");
manager.download_model(model).await?;
}
Ok(())
}
Models are stored in one of the following locations, in order of precedence:
1. The PREFRONTAL_CACHE environment variable, if set
2. Linux: ~/.cache/prefrontal/
3. macOS: ~/Library/Caches/prefrontal/
4. Windows: %LOCALAPPDATA%\prefrontal\Cache\
5. Fallback: a .prefrontal directory in the current working directory

You can override the default location:
# Set a custom cache directory
export PREFRONTAL_CACHE=/path/to/your/cache
# Or when running your application
PREFRONTAL_CACHE=/path/to/your/cache cargo run
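If your application needs to report where models will be stored (for diagnostics or disk-space checks), you can mirror the documented lookup order with the standard library alone. resolve_cache_dir is a hypothetical helper for illustration, not part of the prefrontal API:
use std::env;
use std::path::PathBuf;

// Hypothetical helper mirroring the documented precedence; not part of the prefrontal API.
fn resolve_cache_dir() -> PathBuf {
    if let Ok(dir) = env::var("PREFRONTAL_CACHE") {
        return PathBuf::from(dir);
    }
    #[cfg(target_os = "linux")]
    if let Ok(home) = env::var("HOME") {
        return PathBuf::from(home).join(".cache/prefrontal");
    }
    #[cfg(target_os = "macos")]
    if let Ok(home) = env::var("HOME") {
        return PathBuf::from(home).join("Library/Caches/prefrontal");
    }
    #[cfg(target_os = "windows")]
    if let Ok(local) = env::var("LOCALAPPDATA") {
        return PathBuf::from(local).join("prefrontal").join("Cache");
    }
    PathBuf::from(".prefrontal")
}

fn main() {
    println!("Models will be cached in: {}", resolve_cache_dir().display());
}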
Classes (departments or routing destinations) are defined with labels, descriptions, and optional examples:
// With examples for few-shot classification
let class = ClassDefinition::new(
"technical_support",
"Technical issues and product troubleshooting"
).with_examples(vec![
"integration problems",
"api errors"
]);
// Without examples for zero-shot classification
let class = ClassDefinition::new(
"billing",
"Payment and subscription related inquiries"
);
// Get routing configuration
let info = classifier.info();
println!("Number of departments: {}", info.num_classes);
Each class requires a unique label and a description; examples are optional and enable few-shot classification.
The library provides detailed error types:
pub enum ClassifierError {
TokenizerError(String), // Tokenizer-related errors
ModelError(String), // ONNX model errors
BuildError(String), // Construction errors
PredictionError(String), // Prediction-time errors
ValidationError(String), // Input validation errors
DownloadError(String), // Model download errors
IoError(String), // File system errors
}
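For example, a caller can branch on these variants when routing fails. A minimal sketch, assuming ClassifierError is exported at the crate root, predict returns Result<_, ClassifierError>, and the enum derives Debug:
use prefrontal::ClassifierError;

match classifier.predict("I need help with my invoice") {
    Ok((class, scores)) => println!("Route to: {} ({:?})", class, scores),
    Err(ClassifierError::ValidationError(msg)) => eprintln!("Rejected input: {}", msg),
    Err(ClassifierError::ModelError(msg)) => eprintln!("ONNX model failure: {}", msg),
    // Fall back to a generic handler for the remaining variants
    Err(other) => eprintln!("Classification failed: {:?}", other),
}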
The library uses the log crate for logging. Enable debug logging for detailed information:
use prefrontal::init_logger;
fn main() {
init_logger(); // Initialize with default configuration
// Or configure env_logger directly for more control
}
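As a sketch of the second option, assuming you add env_logger to your own dependencies (it is not pulled in for you):
fn main() {
    // Honor RUST_LOG if set (e.g. RUST_LOG=prefrontal=debug), otherwise default to "info"
    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    // ... build and use the classifier as usual
}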
The classifier is thread-safe and can be shared across threads using Arc:
use std::sync::Arc;
use std::thread;
let classifier = Arc::new(classifier);
let mut handles = vec![];
for _ in 0..3 {
let classifier = Arc::clone(&classifier);
handles.push(thread::spawn(move || {
classifier.predict("test text").unwrap();
}));
}
for handle in handles {
handle.join().unwrap();
}
Benchmarks were run on a MacBook Pro M1 with 16 GB RAM:
| Operation | Text Length | Time |
|---|---|---|
| Tokenization | Short (< 50 chars) | ~23.2 µs |
| Tokenization | Medium (~100 chars) | ~51.5 µs |
| Tokenization | Long (~200 chars) | ~107.0 µs |
| Operation | Scenario | Time |
|---|---|---|
| Embedding | Single text | ~11.1 ms |
| Classification | Single class | ~11.1 ms |
| Classification | 10 classes | ~11.1 ms |
| Build | 10 classes | ~450 ms |
Key performance characteristics: classification latency is dominated by embedding inference (~11 ms per prediction) and is essentially independent of the number of classes, while tokenization cost scales with text length and stays in the tens of microseconds.
Run the benchmarks yourself:
cargo bench
If you want to contribute to Prefrontal or build it from source:
git clone https://github.com/axar-ai/prefrontal.git
cd prefrontal
Install dependencies as described in the System Requirements section above.
Build and test:
cargo build
cargo test
# Run all tests
cargo test
# Run specific test
cargo test test_name
# Run benchmarks
cargo bench
This project is licensed under the Apache License, Version 2.0 - see the LICENSE file for details.