use backoff::{retry, ExponentialBackoff}; use pretty_assertions::assert_eq; use reinfer_client::{ resources::dataset::DatasetFlag, Dataset, EntityDef, EntityName, LabelDef, LabelDefPretrained, LabelDefPretrainedId, LabelGroup, LabelGroupName, LabelName, MoonFormFieldDef, Source, }; use serde_json::json; use uuid::Uuid; use crate::{TestCli, TestSource}; pub struct TestDataset { full_name: String, sep_index: usize, } impl TestDataset { pub fn new() -> Self { let cli = TestCli::get(); let user = TestCli::project(); let full_name = format!("{}/test-dataset-{}", user, Uuid::new_v4()); let sep_index = user.len(); let output = cli.run(["create", "dataset", &full_name]); assert!(output.contains(&full_name)); Self { full_name, sep_index, } } pub fn new_args(args: &[&str]) -> Self { let cli = TestCli::get(); let user = TestCli::project(); let full_name = format!("{}/test-dataset-{}", user, Uuid::new_v4()); let sep_index = user.len(); let output = cli.run(["create", "dataset", &full_name].iter().chain(args)); assert!(output.contains(&full_name)); Self { full_name, sep_index, } } pub fn identifier(&self) -> &str { &self.full_name } pub fn owner(&self) -> &str { &self.full_name[..self.sep_index] } pub fn name(&self) -> &str { &self.full_name[self.sep_index + 1..] } } impl Drop for TestDataset { fn drop(&mut self) { let delete_dataset_command = || { TestCli::get() .run_and_result(["delete", "dataset", self.identifier()]) .map_err(backoff::Error::transient) }; retry(ExponentialBackoff::default(), delete_dataset_command).unwrap(); } } #[test] fn test_test_dataset() { let cli = TestCli::get(); let dataset = TestDataset::new(); let identifier = dataset.identifier().to_owned(); let output = cli.run(["get", "datasets"]); assert!(output.contains(&identifier)); drop(dataset); // RAII TestDataset; should automatically clean up the temporary dataset on drop. let output = cli.run(["get", "datasets"]); assert!(!output.contains(&identifier)); } #[test] fn test_list_multiple_datasets() { let cli = TestCli::get(); let dataset1 = TestDataset::new(); let dataset2 = TestDataset::new(); let output = cli.run(["get", "datasets"]); assert!(output.contains(dataset1.identifier())); assert!(output.contains(dataset2.identifier())); let output = cli.run(["get", "datasets", dataset1.identifier()]); assert!(output.contains(dataset1.identifier())); assert!(!output.contains(dataset2.identifier())); let output = cli.run(["get", "datasets", dataset2.identifier()]); assert!(!output.contains(dataset1.identifier())); assert!(output.contains(dataset2.identifier())); } #[test] fn test_create_update_dataset_custom() { let cli = TestCli::get(); let dataset = TestDataset::new_args(&[ "--title=some title", "--description=some description", "--has-sentiment=true", "--entity-defs", &json!( [ { "name": "ent", "title": "A magic tree", "inherits_from": [], "trainable": false, } ] ) .to_string(), "--label-defs", &json!( [ { "name": "bar", }, { "name": "foo", "instructions": "Long label description", "external_id": "ext id", "title": "Title Me", "pretrained": { "id": "0000000000000001", "name": "Autogenerated", }, "moon_form": [{"name": "luna", "kind": "ent", "type": "string", "required": false}], } ] ) .to_string(), ]); /// A subset of source fields that we can easily check for equality accross #[derive(PartialEq, Eq, Debug)] struct DatasetInfo { owner: String, name: String, title: String, description: String, has_sentiment: bool, source_ids: Vec, entity_defs: Vec, label_defs: Vec, label_groups: Vec, } impl From for DatasetInfo { fn from(dataset: Dataset) -> DatasetInfo { DatasetInfo { owner: dataset.owner.0, name: dataset.name.0, title: dataset.title, description: dataset.description, has_sentiment: dataset.has_sentiment, source_ids: dataset.source_ids.into_iter().map(|id| id.0).collect(), entity_defs: dataset.entity_defs.into_iter().map(Into::into).collect(), label_defs: dataset.label_defs, label_groups: dataset.label_groups, } } } /// A subset of fields that we can easily check for equality accross #[derive(PartialEq, Eq, Debug)] struct EntityDefInfo { pub color: u32, pub name: EntityName, pub title: String, pub trainable: bool, } impl From for EntityDefInfo { fn from(value: EntityDef) -> Self { let EntityDef { color, name, title, trainable, .. } = value; Self { color, name, title, trainable, } } } let get_dataset_info = || -> DatasetInfo { let output = cli.run(["--output=json", "get", "datasets", dataset.identifier()]); serde_json::from_str::(&output).unwrap().into() }; let mut expected_dataset_info = DatasetInfo { owner: dataset.owner().to_owned(), name: dataset.name().to_owned(), title: "some title".to_owned(), description: "some description".to_owned(), has_sentiment: true, source_ids: vec![], entity_defs: vec![EntityDefInfo { color: 0, name: EntityName("ent".to_owned()), title: "A magic tree".to_owned(), trainable: false, }], label_defs: vec![ LabelDef { name: LabelName("bar".to_owned()), instructions: "".to_owned(), external_id: None, pretrained: None, title: "".to_owned(), moon_form: None, }, LabelDef { name: LabelName("foo".to_owned()), instructions: "Long label description".to_owned(), external_id: Some("ext id".to_owned()), pretrained: Some(LabelDefPretrained { id: LabelDefPretrainedId("0000000000000001".to_owned()), name: LabelName("Autogenerated".to_owned()), }), title: "Title Me".to_owned(), moon_form: Some(vec![MoonFormFieldDef { name: "luna".to_owned(), kind: "ent".to_owned(), }]), }, ], label_groups: vec![LabelGroup { name: LabelGroupName("default".to_owned()), label_defs: vec![ LabelDef { name: LabelName("bar".to_owned()), instructions: "".to_owned(), external_id: None, pretrained: None, title: "".to_owned(), moon_form: None, }, LabelDef { name: LabelName("foo".to_owned()), instructions: "Long label description".to_owned(), external_id: Some("ext id".to_owned()), pretrained: Some(LabelDefPretrained { id: LabelDefPretrainedId("0000000000000001".to_owned()), name: LabelName("Autogenerated".to_owned()), }), title: "Title Me".to_owned(), moon_form: Some(vec![MoonFormFieldDef { name: "luna".to_owned(), kind: "ent".to_owned(), }]), }, ], }], }; assert_eq!(get_dataset_info(), expected_dataset_info); // Partial update cli.run([ "update", "dataset", "--title=updated title", dataset.identifier(), ]); "updated title".clone_into(&mut expected_dataset_info.title); assert_eq!(get_dataset_info(), expected_dataset_info); // Should be able to update all fields let test_source = TestSource::new(); let source = test_source.get(); cli.run([ "update", "dataset", "--title=updated title", "--description=updated description", &format!("--source={}", source.id.0), dataset.identifier(), ]); "updated title".clone_into(&mut expected_dataset_info.title); "updated description".clone_into(&mut expected_dataset_info.description); expected_dataset_info.source_ids = vec![source.id.0]; assert_eq!(get_dataset_info(), expected_dataset_info); // An empty update should be fine, including leaving source ids untouched cli.run(["update", "dataset", dataset.identifier()]); assert_eq!(get_dataset_info(), expected_dataset_info); // Setting the sources flag with no ids should clear sources cli.run(["update", "dataset", dataset.identifier(), "--source"]); expected_dataset_info.source_ids = vec![]; assert_eq!(get_dataset_info(), expected_dataset_info); } #[test] fn test_create_dataset_with_source() { let cli = TestCli::get(); let source = TestSource::new(); let dataset = TestDataset::new_args(&[&format!("--source={}", source.identifier())]); let output = cli.run(["--output=json", "get", "datasets", dataset.identifier()]); let dataset_info: Dataset = serde_json::from_str(output.trim()).unwrap(); assert_eq!(&dataset_info.owner.0, dataset.owner()); assert_eq!(&dataset_info.name.0, dataset.name()); assert_eq!(dataset_info.source_ids.len(), 1); let source_output = cli.run([ "--output=json", "get", "sources", &dataset_info.source_ids.first().unwrap().0, ]); let source_info: Source = serde_json::from_str(source_output.trim()).unwrap(); assert_eq!(&source_info.owner.0, source.owner()); assert_eq!(&source_info.name.0, source.name()); } #[test] fn test_create_dataset_with_gen_ai() { let cli = TestCli::get(); // Run with false ellm Flag let dataset_gen_ai_false = TestDataset::new_args(&[&format!("--gen-ai={}", false)]); let output = cli.run([ "--output=json", "get", "datasets", dataset_gen_ai_false.identifier(), ]); let dataset_gen_ai_false_info: Dataset = serde_json::from_str(output.trim()).unwrap(); assert_eq!( &dataset_gen_ai_false_info.owner.0, dataset_gen_ai_false.owner() ); assert_eq!( &dataset_gen_ai_false_info.name.0, dataset_gen_ai_false.name() ); assert!(!dataset_gen_ai_false_info .dataset_flags .contains(&DatasetFlag::Gpt4)); // Run with true gen_ai Flag let dataset_gen_ai = TestDataset::new_args(&[&format!("--gen-ai={}", true)]); let output = cli.run([ "--output=json", "get", "datasets", dataset_gen_ai.identifier(), ]); let dataset_info: Dataset = serde_json::from_str(output.trim()).unwrap(); assert_eq!(&dataset_info.owner.0, dataset_gen_ai.owner()); assert_eq!(&dataset_info.name.0, dataset_gen_ai.name()); assert!(dataset_info.dataset_flags.contains(&DatasetFlag::Gpt4)); } #[test] fn test_create_dataset_with_zero_shot() { let cli = TestCli::get(); // Run with false ellm Flag let dataset_zero_shot_false = TestDataset::new_args(&[&format!("--zero-shot={}", false)]); let output = cli.run([ "--output=json", "get", "datasets", dataset_zero_shot_false.identifier(), ]); let dataset_zero_shot_false_info: Dataset = serde_json::from_str(output.trim()).unwrap(); assert_eq!( &dataset_zero_shot_false_info.owner.0, dataset_zero_shot_false.owner() ); assert_eq!( &dataset_zero_shot_false_info.name.0, dataset_zero_shot_false.name() ); assert!(!dataset_zero_shot_false_info .dataset_flags .contains(&DatasetFlag::ZeroShotLabels)); // Run with true zero_shot Flag let dataset_zero_shot = TestDataset::new_args(&[&format!("--zero-shot={}", true), "--gen-ai=true"]); let output = cli.run([ "--output=json", "get", "datasets", dataset_zero_shot.identifier(), ]); let dataset_info: Dataset = serde_json::from_str(output.trim()).unwrap(); assert_eq!(&dataset_info.owner.0, dataset_zero_shot.owner()); assert_eq!(&dataset_info.name.0, dataset_zero_shot.name()); assert!(dataset_info .dataset_flags .contains(&DatasetFlag::ZeroShotLabels)); } #[test] fn test_create_dataset_with_external_llm() { let cli = TestCli::get(); // Run with false ellm Flag let dataset_ellm_false = TestDataset::new_args(&[&format!("--external-llm={}", false)]); let output = cli.run([ "--output=json", "get", "datasets", dataset_ellm_false.identifier(), ]); let dataset_ellm_false_info: Dataset = serde_json::from_str(output.trim()).unwrap(); assert_eq!(&dataset_ellm_false_info.owner.0, dataset_ellm_false.owner()); assert_eq!(&dataset_ellm_false_info.name.0, dataset_ellm_false.name()); assert!(!dataset_ellm_false_info .dataset_flags .contains(&DatasetFlag::ExternalMoonLlm)); // Run with true ellm Flag let dataset_ellm = TestDataset::new_args(&[&format!("--external-llm={}", true), "--gen-ai=true"]); let output = cli.run([ "--output=json", "get", "datasets", dataset_ellm.identifier(), ]); let dataset_info: Dataset = serde_json::from_str(output.trim()).unwrap(); assert_eq!(&dataset_info.owner.0, dataset_ellm.owner()); assert_eq!(&dataset_info.name.0, dataset_ellm.name()); assert!(dataset_info .dataset_flags .contains(&DatasetFlag::ExternalMoonLlm)); } #[test] fn test_create_dataset_with_no_flags() { let cli = TestCli::get(); // Run with no QoS Flag let dataset_qos_none = TestDataset::new_args(&[]); let output = cli.run([ "--output=json", "get", "datasets", dataset_qos_none.identifier(), ]); let dataset_qos_none_info: Dataset = serde_json::from_str(output.trim()).unwrap(); assert_eq!(&dataset_qos_none_info.owner.0, dataset_qos_none.owner()); assert_eq!(&dataset_qos_none_info.name.0, dataset_qos_none.name()); assert!(!dataset_qos_none_info .dataset_flags .contains(&DatasetFlag::Qos)); assert!(!dataset_qos_none_info .dataset_flags .contains(&DatasetFlag::ExternalMoonLlm)); assert!(!dataset_qos_none_info .dataset_flags .contains(&DatasetFlag::Gpt4)); assert!(!dataset_qos_none_info .dataset_flags .contains(&DatasetFlag::ZeroShotLabels)); } #[test] fn test_create_dataset_with_qos() { let cli = TestCli::get(); // Run with false QoS Flag let dataset_qos_false = TestDataset::new_args(&[&format!("--qos={}", false)]); let output = cli.run([ "--output=json", "get", "datasets", dataset_qos_false.identifier(), ]); let dataset_qos_false_info: Dataset = serde_json::from_str(output.trim()).unwrap(); assert_eq!(&dataset_qos_false_info.owner.0, dataset_qos_false.owner()); assert_eq!(&dataset_qos_false_info.name.0, dataset_qos_false.name()); assert!(!dataset_qos_false_info .dataset_flags .contains(&DatasetFlag::Qos)); // Run with true QoS Flag let dataset_qos = TestDataset::new_args(&[&format!("--qos={}", true)]); let output = cli.run(["--output=json", "get", "datasets", dataset_qos.identifier()]); let dataset_info: Dataset = serde_json::from_str(output.trim()).unwrap(); assert_eq!(&dataset_info.owner.0, dataset_qos.owner()); assert_eq!(&dataset_info.name.0, dataset_qos.name()); assert!(dataset_info.dataset_flags.contains(&DatasetFlag::Qos)); } #[test] fn test_create_dataset_requires_owner() { let cli = TestCli::get(); let output = cli .command() .args(["create", "dataset", "dataset-without-owner"]) .output() .unwrap(); assert!(!output.status.success()); } #[test] fn test_create_dataset_model_family() { let cli = TestCli::get(); let dataset = TestDataset::new_args(&["--model-family==german"]); let output = cli.run(["--output=json", "get", "datasets", dataset.identifier()]); let dataset_info: Dataset = serde_json::from_str(output.trim()).unwrap(); assert_eq!(&dataset_info.owner.0, dataset.owner()); assert_eq!(&dataset_info.name.0, dataset.name()); assert_eq!(&dataset_info.model_family.0, "german"); } #[test] fn test_create_dataset_wrong_model_family() { let cli = TestCli::get(); let output = cli .command() .args([ "create", "dataset", "--model-family==non-existent-family", &format!("{}/test-dataset-{}", TestCli::project(), Uuid::new_v4()), ]) .output() .unwrap(); assert!(!output.status.success()); assert!(String::from_utf8_lossy(&output.stderr) .contains("API request failed with 400 Bad Request: Invalid request - Unsupported model family: non-existent-family")) } #[test] fn test_create_dataset_copy_annotations() { let cli = TestCli::get(); let dataset1 = TestDataset::new(); let dataset1_output = cli.run(["--output=json", "get", "datasets", dataset1.identifier()]); let dataset1_info: Dataset = serde_json::from_str(dataset1_output.trim()).unwrap(); let output = cli .command() .args([ "create", "dataset", &format!("--copy-annotations-from={}", dataset1_info.id.0), &format!("{}/test-dataset-{}", TestCli::project(), Uuid::new_v4()), ]) .output() .unwrap(); assert!(output.status.success()); }