use serde_json::value::Value; use std::cell::Cell; use std::collections::HashMap; use std::fs; use std::path::PathBuf; use asl::asl::execution::{ ExecutionStatus, ExecutionStatusDiscriminants, StateMachineExecutionError, }; use asl::asl::execution::{StateExecutionHandler, StateExecutionOutput}; use asl::asl::state_machine::StateMachine; use similar_asserts::assert_eq; struct TestStateExecutionHandler { resource_name_to_output: HashMap, times_called: HashMap>, } impl TestStateExecutionHandler { fn new() -> TestStateExecutionHandler { TestStateExecutionHandler { resource_name_to_output: hash_map![], times_called: hash_map![], } } fn with_map( resource_name_to_output: HashMap, ) -> TestStateExecutionHandler { let map = resource_name_to_output .keys() .map(|k| (k.to_string(), Cell::new(0))) .collect(); TestStateExecutionHandler { resource_name_to_output, times_called: map, } } } #[derive(Error, Debug)] enum MyTaskExecutionError { #[error("{0}")] ForwardedError(String), } impl StateExecutionHandler for TestStateExecutionHandler { type TaskExecutionError = MyTaskExecutionError; fn execute_task( &self, resource: &str, input: &Value, _credentials: Option<&Value>, ) -> Result, Self::TaskExecutionError> { let option = self.resource_name_to_output.get(resource); match option { None => Ok(Some(input.clone())), // resource is not mapped, so just forwards the input Some(desired_outputs) => { let times_called = self.times_called.get(resource).unwrap(); let index = times_called.get(); let desired_output = match desired_outputs { TaskResults::Repeat(o) => o, TaskResults::List(vec) => vec.get(index).unwrap_or_else(|| { panic!( "Task called {} times, but there are only {} expected results.", index + 1, vec.len() ) }), }; times_called.set(index + 1); match desired_output { TaskBehavior::Output(val) => Ok(Some(val.to_owned())), // resource is mapped, so returns the desired output TaskBehavior::Error(err) => { Err(MyTaskExecutionError::ForwardedError(err.clone())) } } } } } fn wait(&self, _seconds: f64) { //nop on purpose (sleeping a thread is bad for tests). } } use asl::asl::execution::ExecutionStatus::FinishedWithFailure; use asl::asl::states::error_handling::StateMachineExecutionPredefinedErrors; use asl::asl::types::execution::{EmptyContext, StateMachineContext}; use itertools::Itertools; use map_macro::hash_map; use rstest::*; use serde_with::serde_derive::Deserialize; use testresult::TestResult; use thiserror::Error; use wildmatch::WildMatch; #[rstest] fn execute_hello_world_succeed_state() -> TestResult { let definition = include_str!("test-data/hello-world-succeed-state.json"); let state_machine = StateMachine::parse(definition)?; let input = serde_json::from_str( r#" "Hello world" "#, )?; let val = Value::from("Hello world"); let mut execution = state_machine.start(&input, TestStateExecutionHandler::new(), EmptyContext {}); assert_eq!(ExecutionStatus::Executing, execution.status); // Advance state let state_output = execution.next(); assert_eq!( state_output, Some(StateExecutionOutput { status: ExecutionStatus::Executing, state_name: Some("Hello World".to_string()), result: Some(val.clone()) }) ); assert_eq!(ExecutionStatus::Executing, execution.status); // Advance state let state_output = execution.next(); assert_eq!( state_output, Some(StateExecutionOutput { status: ExecutionStatus::FinishedWithSuccess(Some(val.clone())), state_name: Some("Succeed State".to_string()), result: Some(val.clone()) }) ); assert_eq!( ExecutionStatus::FinishedWithSuccess(Some(val.clone())), execution.status ); // Iterator is exhausted assert_eq!(None, execution.next()); assert_eq!( ExecutionStatus::FinishedWithSuccess(Some(val.clone())), execution.status ); Ok(()) } #[rstest] fn execute_hello_world_fail_state() -> TestResult { let definition = include_str!("test-data/hello-world-fail-state.json"); let state_machine = StateMachine::parse(definition)?; let val = serde_json::from_str( r#" "Hello world" "#, )?; let mut execution = state_machine.start(&val, TestStateExecutionHandler::new(), EmptyContext {}); assert_eq!(ExecutionStatus::Executing, execution.status); // Advance state let state_output = execution.next(); assert_eq!( state_output, Some(StateExecutionOutput { status: ExecutionStatus::Executing, state_name: Some("Hello World".to_string()), result: Some(val.clone()) }) ); assert_eq!(ExecutionStatus::Executing, execution.status); // Advance state let state_output = execution.next(); let expected_status = with_error_and_cause("ErrorA", "Kaiju attack"); assert_eq!( state_output, Some(StateExecutionOutput { status: expected_status.clone(), state_name: Some("Fail State".to_string()), result: None, }) ); assert_eq!(expected_status, execution.status); // Iterator is exhausted assert_eq!(None, execution.next()); assert_eq!(expected_status, execution.status); Ok(()) } pub fn with_error_and_cause(error: &str, cause: &str) -> ExecutionStatus { FinishedWithFailure(StateMachineExecutionError { error: StateMachineExecutionPredefinedErrors::Custom(error.to_string()), cause: Some(String::from(cause)), }) } pub fn with_success_and_output(output: &str) -> ExecutionStatus { let val = serde_json::from_str(output).expect("Invalid json specified"); ExecutionStatus::FinishedWithSuccess(val) } #[derive(Deserialize, Clone, PartialEq, Debug)] #[serde(rename_all = "snake_case")] enum ExpectedFinalStatus { Output(Value), #[serde(rename_all = "PascalCase")] Error { error: StateMachineExecutionPredefinedErrors, cause: Option, }, } /// Controls what the Task will do in the test cases. /// If it finds an `Output` key, then it will forward the output and *succeed* the task /// If it finds an `Error` key, then it will error with the string provided and *fail* the task. #[derive(Deserialize, Clone, PartialEq, Debug)] #[serde(rename_all = "snake_case")] enum TaskBehavior { Output(Value), Error(String), } #[derive(Deserialize, Clone, PartialEq, Debug)] #[serde(untagged)] enum TaskResults { Repeat(TaskBehavior), List(Vec), } #[derive(Deserialize, Debug)] struct ExpectedExecution { input: Value, #[serde(flatten)] final_status: ExpectedFinalStatus, states: Vec, task_behavior: Option>, context: Option>, } impl From for ExecutionStatus { fn from(value: ExpectedFinalStatus) -> Self { match value { ExpectedFinalStatus::Output(val) => ExecutionStatus::FinishedWithSuccess(Some(val)), ExpectedFinalStatus::Error { error, cause } => { FinishedWithFailure(StateMachineExecutionError { error, cause }) } } } } #[derive(Debug)] pub struct MapContext { map: HashMap, current_state_name: String, } impl MapContext { fn new(map: HashMap) -> Self { MapContext { map, current_state_name: "".to_string(), } } } impl StateMachineContext for MapContext { fn as_value(&self) -> Value { self.map .get(&self.current_state_name) .cloned() .unwrap_or(Value::Null) } fn transition_to_state(&mut self, state: &str) { self.current_state_name = String::from(state); } } #[rstest] fn execute_all( #[files("**/test-data/expected-executions-valid-cases/valid-*.json5")] // TODO: Support Map state #[exclude("valid-map.*")] // TODO: Support Parallel state #[exclude("valid-parallel.*")] #[exclude("valid-parameters-resultSelector.*")] // TODO: Support Intrinsic Functions #[exclude("valid-fail-paths\\.json.*")] #[exclude("valid-intrinsic-functions.*")] // TODO: Support negative index #[exclude("valid-pass-negativeIndex.*")] // TODO: Support retries #[exclude("valid-retry-failure.*")] path: PathBuf, ) -> TestResult { let all_expected_executions: Vec = json5::from_str(&fs::read_to_string(&path)?)?; let state_machine_definition_filename = path.with_extension("json"); let state_machine_definition_filename = state_machine_definition_filename.file_name().unwrap(); let definition = fs::read_to_string(format!( "tests/test-data/asl-validator/{}", state_machine_definition_filename.to_str().unwrap() ))?; let state_machine = StateMachine::parse(&definition)?; dbg!("Parsed state machine:", &state_machine); // The loop is for each test case contained within the JSON file for (i, execution_expected_input) in all_expected_executions.iter().enumerate() { let input = &execution_expected_input.input; let map = execution_expected_input .task_behavior .clone() .unwrap_or(hash_map![]); let handler = TestStateExecutionHandler::with_map(map); let context = MapContext::new( execution_expected_input .context .clone() .unwrap_or(HashMap::new()), ); dbg!("State machine execution", i); dbg!("Input:", input); dbg!("Context:", &context); let execution = state_machine.start(input, handler, context); // Execute the state machine and collect all steps so they can be compared let execution_steps = execution.collect_vec(); for step in &execution_steps { println!("{:?}", step); } // Compare the states seen by the execution let actual_states = &execution_steps .iter() .map(|e| e.state_name.as_ref().unwrap_or(&String::new()).clone()) .collect_vec(); let expected_states = &execution_expected_input.states; assert_eq!( expected_states, actual_states, "States are different for test case {i}." ); // Compare final status let expected_status = &ExecutionStatus::from(execution_expected_input.final_status.clone()); let actual_status = &execution_steps.last().unwrap().status; assert_eq!( ExecutionStatusDiscriminants::from(expected_status), ExecutionStatusDiscriminants::from(actual_status), "Final execution status are different for test case {i}." ); // When it's a failure, then we treat the "Cause" as a wildcard so that error messages can // be easily matched. if let FinishedWithFailure(expected) = expected_status { // TODO: use if-let chains, but needs https://github.com/rust-lang/rust/issues/53667 first if let FinishedWithFailure(actual) = actual_status { assert_eq!(expected.error, actual.error); assert_eq!(expected.cause.is_some(), actual.cause.is_some()); if expected.cause.is_some() { let expected_cause = expected.cause.as_ref().unwrap(); let actual_cause = actual.cause.as_ref().unwrap(); let matches = WildMatch::new(&expected_cause).matches(&actual_cause); assert!( matches, "Expected wildcard expression for Cause '{}' to match the actual Cause: {}", expected_cause, actual_cause ); } } else { panic!("wtf"); } } else { assert_eq!( expected_status, actual_status, "Status are different for test case {i}." ); } } Ok(()) }