use anyhow::Result; use clap::{Parser, ValueEnum}; use itertools::Itertools; use ltp::perceptron::SerdeCWSModel; use ltp::{Algorithm, CWSDefinition as Definition, Codec, Format, ModelSerde, PaMode, Trainer}; use std::collections::HashMap; use std::fs::File; use std::io::{BufRead, BufReader, Write}; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] enum Args { Train(Train), Eval(Eval), Predict(Predict), } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] enum AlgorithmArg { Ap, Pa, PaI, PaII, } #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Train { #[clap(long, value_parser, default_value_t = 10)] epoch: usize, #[clap(short, long, value_parser, default_value_t = true)] shuffle: bool, #[clap(value_enum, value_parser, default_value_t = AlgorithmArg::Ap)] algorithm: AlgorithmArg, #[clap(long, value_parser, default_value_t = 8)] ap_threads: usize, #[clap(long, value_parser, default_value_t = 0.5)] pa_margin: f64, #[clap(long, value_parser, default_value_t = 8)] eval_threads: usize, // 模型压缩参数 #[clap(short, long, value_parser, default_value_t = true)] compress: bool, #[clap(long, value_parser, default_value_t = 0.3)] ratio: f64, #[clap(long, value_parser, default_value_t = 1e-3)] threshold: f64, // 数据集 #[clap(short, long)] train: String, #[clap(short, long)] eval: String, // 模型保存 #[clap(short, long, default_value = "model.bin")] model: String, } #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Eval { #[clap(short, long, value_parser, default_value_t = true)] verbose: bool, #[clap(short, long, value_parser, default_value_t = 8)] threads: usize, #[clap(short, long)] eval: String, // 模型保存 #[clap(short, long, default_value = "model.bin")] model: String, } #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Predict { #[clap(short, long, value_parser, default_value_t = true)] verbose: bool, #[clap(short, long, value_parser, default_value_t = 8)] threads: usize, #[clap(short, long)] input: String, #[clap(short, long)] output: String, // 模型保存 #[clap(short, long, default_value = "model.bin")] model: String, } fn main() -> Result<()> { let args = Args::parse(); match args { Args::Train(mode) => { let algorithm = match mode.algorithm { AlgorithmArg::Ap => Algorithm::AP(mode.ap_threads), AlgorithmArg::Pa => Algorithm::PA(PaMode::Pa), AlgorithmArg::PaI => Algorithm::PA(PaMode::PaI(mode.pa_margin)), AlgorithmArg::PaII => Algorithm::PA(PaMode::PaII(mode.pa_margin)), }; let trainer = Trainer::new() .definition(Definition::default()) .epoch(mode.epoch) .algorithm(algorithm) .shuffle(mode.shuffle) .compress(mode.compress) .ratio(mode.ratio) .threshold(mode.threshold) .train_file(mode.train)? .eval_file(mode.eval)? .display(); let model = trainer.build::, Vec<_>>()?; let file = File::create(&mode.model)?; let format = if mode.model.ends_with(".json") { Format::JSON } else { Format::AVRO(Codec::Deflate) }; model.save(file, format)?; } Args::Eval(mode) => { let file = File::open(&mode.model)?; let format = if mode.model.ends_with(".json") { Format::JSON } else { Format::AVRO(Codec::Deflate) }; let model: SerdeCWSModel = ModelSerde::load(file, format)?; let start = std::time::Instant::now(); let trainer = Trainer::new() .definition(model.definition.clone()) .verbose(mode.verbose) .eval_threads(mode.threads) .eval_file(mode.eval)?; let (p, r, f1) = trainer.evaluate(&model)?; let duration = start.elapsed().as_millis(); println!("[{duration}ms] precision: {p}, recall: {r}, f1: {f1}",); } Args::Predict(mode) => { let file = File::open(&mode.model)?; let format = if mode.model.ends_with(".json") { Format::JSON } else { Format::AVRO(Codec::Deflate) }; let model: SerdeCWSModel = ModelSerde::load(file, format)?; let file = File::open(mode.input)?; let lines = BufReader::new(file).lines(); let datasets = lines.flatten().filter(|s| !s.is_empty()).collect_vec(); let start = std::time::Instant::now(); let result: Result>> = datasets .iter() .map(|sentence| model.predict(sentence)) .collect(); let duration = start.elapsed(); println!("{}ms", duration.as_millis()); let mut file = File::create(mode.output)?; result?.iter().for_each(|sentence| { writeln!(file, "{}", sentence.join(" ")).expect("Write Failed!"); }); } } Ok(()) }