| Crates.io | axonml-serialize |
| lib.rs | axonml-serialize |
| version | 0.2.4 |
| created_at | 2026-01-24 14:02:06.859857+00 |
| updated_at | 2026-01-25 22:33:41.072772+00 |
| description | Model serialization for Axonml ML framework |
| homepage | |
| repository | https://github.com/automatanexus/ferrite |
| max_upload_size | |
| id | 2066727 |
| size | 77,410 |
axonml-serialize provides model serialization functionality for the AxonML machine learning framework. It supports saving and loading trained models, including state dictionaries, training checkpoints, and format conversion utilities for interoperability with PyTorch and ONNX.
| Module | Description |
|---|---|
| `state_dict` | `StateDict` and `TensorData` for storing model parameters by name |
| `checkpoint` | `Checkpoint` and `TrainingState` for saving and resuming training sessions |
| `format` | `Format` enum and detection utilities for the supported serialization formats |
| `convert` | Conversion utilities for PyTorch and ONNX interoperability |
Add the dependency to your Cargo.toml:
[dependencies]
axonml-serialize = "0.2.4"
use axonml_serialize::{save_model, load_state_dict};
use axonml_nn::Linear;
// Save a model (format detected from extension)
let model = Linear::new(10, 5);
save_model(&model, "model.axonml")?; // Binary format
save_model(&model, "model.json")?; // JSON format
// Load state dictionary
let state_dict = load_state_dict("model.axonml")?;
println!("Parameters: {}", state_dict.total_params());
println!("Size: {} bytes", state_dict.size_bytes());
use axonml_serialize::{StateDict, TensorData};
// Create a state dictionary
let mut state_dict = StateDict::new();
let weights = TensorData {
shape: vec![10, 5],
values: vec![0.0; 50],
};
state_dict.insert("linear.weight".to_string(), weights);
let bias = TensorData {
shape: vec![5],
values: vec![0.0; 5],
};
state_dict.insert("linear.bias".to_string(), bias);
// Query the state dictionary
assert!(state_dict.contains("linear.weight"));
println!("{}", state_dict.summary());
// Filter by prefix
let linear_params = state_dict.filter_prefix("linear.");
// Strip prefix from keys
let stripped = state_dict.strip_prefix("linear.");
assert!(stripped.contains("weight"));
use axonml_serialize::{Checkpoint, TrainingState, save_checkpoint, load_checkpoint};
// Track training state
let mut training_state = TrainingState::new();
training_state.record_loss(0.5);
training_state.record_loss(0.3);
training_state.update_best("loss", 0.3, false); // lower is better
training_state.next_epoch();
training_state.next_step();
// Create checkpoint with builder pattern
let checkpoint = Checkpoint::builder()
.model_state(model_state_dict)
.optimizer_state(optimizer_state_dict)
.training_state(training_state)
.epoch(10)
.global_step(5000)
.config("learning_rate", "0.001")
.config("batch_size", "32")
.build();
// Save and load checkpoints
save_checkpoint(&checkpoint, "checkpoint.ckpt")?;
let loaded = load_checkpoint("checkpoint.ckpt")?;
println!("Resuming from epoch {}", loaded.epoch());
println!("Best metric: {:?}", loaded.best_metric());
use axonml_serialize::{detect_format, detect_format_from_bytes, Format};
// Detect from file extension
let format = detect_format("model.json");
assert_eq!(format, Format::Json);
let format = detect_format("model.safetensors");
assert_eq!(format, Format::SafeTensors);
// Detect from file contents
let bytes = b"{\"key\": \"value\"}";
let format = detect_format_from_bytes(bytes);
assert_eq!(format, Some(Format::Json));
// Format properties
assert!(Format::Axonml.is_binary());
assert!(!Format::Json.is_binary());
use axonml_serialize::{from_pytorch_key, convert_from_pytorch, transpose_linear_weights};
// Convert PyTorch key naming to AxonML
let key = from_pytorch_key("module.layer1.weight");
assert_eq!(key, "layer1.weight");
// Convert entire state dictionary
let axonml_dict = convert_from_pytorch(&pytorch_dict);
// Transpose linear weights if needed (PyTorch uses [out, in])
let transposed = transpose_linear_weights(&weight_data);
use axonml_serialize::{to_onnx_shape, from_onnx_shape, OnnxOpType};
// Convert to an ONNX shape, prepending a dynamic (-1) batch dimension
let onnx_shape = to_onnx_shape(&[3, 224, 224], true);
assert_eq!(onnx_shape, vec![-1, 3, 224, 224]);
// Convert from an ONNX shape, replacing each dynamic (-1) dimension with the given default (here 1)
let shape = from_onnx_shape(&[-1, 3, 224, 224], 1);
assert_eq!(shape, vec![1, 3, 224, 224]);
// ONNX operator type mapping
let op = OnnxOpType::from_str("Relu");
assert_eq!(op.as_str(), "Relu");
use axonml_serialize::StateDict;
let mut state_dict = StateDict::new();
state_dict.set_metadata("framework_version", "0.1.0");
state_dict.set_metadata("model_architecture", "ResNet50");
if let Some(version) = state_dict.get_metadata("framework_version") {
println!("Saved with version: {}", version);
}
Run the test suite:
cargo test -p axonml-serialize
Licensed under either of:
at your option.