Crates.io | mpnet-rs |
lib.rs | mpnet-rs |
version | 0.1.3 |
source | src |
created_at | 2024-02-09 11:06:43.320614 |
updated_at | 2024-04-02 05:50:58.469943 |
description | This is a translation of MPNet from PyTorch into Rust Candle. |
homepage | |
repository | https://github.com/NewBornRustacean/mpnet-rs |
max_upload_size | |
id | 1133831 |
size | 44,237 |
This is a translation of MPNet from PyTorch into Rust Candle.
use mpnet_rs::mpnet::load_model;
let (model, tokenizer, pooler) = load_model("/path/to/model/and/tokenizer").unwrap();
This is how to compute embeddings and cosine similarity between sentences:
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{VarBuilder, Module};
use mpnet_rs::mpnet::{MPNetEmbeddings, MPNetConfig, create_position_ids_from_input_ids, cumsum, load_model, get_embeddings, normalize_l2, PoolingConfig, MPNetPooler};
/// Example: load an MPNet checkpoint, embed a batch of sentences, and print
/// the top-5 most similar sentence pairs ranked by cosine similarity.
///
/// # Errors
/// Propagates any `candle_core` error raised while slicing the embedding
/// tensor or extracting scalar values.
fn test_get_embeddings() -> Result<()> {
    let path_to_checkpoints_folder = "D:/RustWorkspace/checkpoints/AI-Growth-Lab_PatentSBERTa".to_string();
    // `tokenizer` is only ever borrowed immutably below, so `mut` is unnecessary.
    let (model, tokenizer, pooler) = load_model(path_to_checkpoints_folder).unwrap();

    let sentences = vec![
        "an invention that targets GLP-1",
        "new chemical that targets glucagon like peptide-1 ",
        "de novo chemical that targets GLP-1",
        "invention about GLP-1 receptor",
        "new chemical synthesis for glp-1 inhibitors",
        "It feels like I'm in America",
        "It's rainy. all day long.",
    ];
    let n_sentences = sentences.len();

    let embeddings = get_embeddings(&model, &tokenizer, Some(&pooler), &sentences).unwrap();
    let l2norm_embeds = normalize_l2(&embeddings).unwrap();
    println!("pooled embeddings {:?}", l2norm_embeds.shape());

    // Pairwise cosine similarity over the upper triangle (i < j), so each
    // pair is scored exactly once. Dividing by the norms keeps the formula
    // correct even if the inputs were not perfectly unit-length.
    let mut similarities = vec![];
    for i in 0..n_sentences {
        let e_i = l2norm_embeds.get(i)?;
        for j in (i + 1)..n_sentences {
            let e_j = l2norm_embeds.get(j)?;
            let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
            let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
            let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
            let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
            similarities.push((cosine_similarity, i, j))
        }
    }

    // Sort descending by score (total_cmp gives a total order over f32).
    // `take(5)` prints at most five pairs and, unlike `similarities[..5]`,
    // cannot panic when fewer than five pairs exist.
    similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
    for &(score, i, j) in similarities.iter().take(5) {
        println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
    }
    Ok(())
}