use futures::{select, FutureExt}; use krossbar_common_rpc::{request::Body, rpc::Rpc}; use tokio::{io::AsyncWriteExt, net::UnixStream}; const ENDPOINT_NAME: &str = "test_function"; #[tokio::test] async fn test_simple_call() { let _ = pretty_env_logger::formatted_builder() .filter_level(log::LevelFilter::Debug) .try_init(); let (stream1, stream2) = UnixStream::pair().unwrap(); let mut rpc1 = Rpc::new(stream1, "rpc1"); let mut rpc2 = Rpc::new(stream2, "rpc2"); let call = rpc1.call::(ENDPOINT_NAME, &42).await.unwrap(); // Poll the stream to receive the request let mut request: krossbar_common_rpc::request::RpcRequest = rpc2.poll().await.unwrap(); assert_eq!(request.endpoint(), ENDPOINT_NAME); if let Some(Body::Call(bson)) = request.take_body() { let request_body: u32 = bson::from_bson(bson).unwrap(); assert_eq!(request_body, 42); assert_eq!(request.peer_name(), "rpc2"); } else { panic!("Invalid message type") } assert!(request.respond(Ok(420)).await); select! { response = call.fuse() => { assert_eq!(response.unwrap(), 420); }, _ = rpc1.poll().fuse() => {} } } #[tokio::test] async fn test_call_reconnect() { let (stream1, stream2) = UnixStream::pair().unwrap(); let mut rpc1 = Rpc::new(stream1, "rpc"); { let mut rpc2 = Rpc::new(stream2, "rpc"); let call = rpc1.call::(ENDPOINT_NAME, &42).await.unwrap(); // Poll the stream to receive the request let mut request = rpc2.poll().await.unwrap(); if let Some(Body::Call(bson)) = request.take_body() { let request_body: u32 = bson::from_bson(bson).unwrap(); assert_eq!(request_body, 42); } else { panic!("Invalid message type") } assert!(request.respond(Ok(420)).await); select! { response = call.fuse() => { assert_eq!(response.unwrap(), 420); }, _ = rpc1.poll().fuse() => {} } } let (stream1, stream3) = UnixStream::pair().unwrap(); { rpc1.on_reconnected(Rpc::new(stream1, "rpc")).await; let mut rpc3 = Rpc::new(stream3, "rpc"); let call = rpc1.call::(ENDPOINT_NAME, &42).await.unwrap(); // Poll the stream to receive the request let mut request = rpc3.poll().await.unwrap(); assert_eq!(request.endpoint(), ENDPOINT_NAME); if let Some(Body::Call(bson)) = request.take_body() { let request_body: u32 = bson::from_bson(bson).unwrap(); assert_eq!(request_body, 42); } else { panic!("Invalid message type") } assert!(request.respond(Ok(421)).await); select! { response = call.fuse() => { assert_eq!(response.unwrap(), 421); }, _ = rpc1.poll().fuse() => {} } } } #[tokio::test] async fn test_bson_param_error() { let (stream1, stream2) = UnixStream::pair().unwrap(); let rpc1 = Rpc::new(stream1, "rpc"); let _ = Rpc::new(stream2, "rpc"); // Try to send u64::MAX, which BSON doesn't support. It has only i64 let call = rpc1.call::(ENDPOINT_NAME, &u64::MAX).await; assert!(matches!( call, Err(krossbar_common_rpc::Error::ParamsTypeError(_)) )) } #[tokio::test] async fn test_result_type_error() { let (stream1, stream2) = UnixStream::pair().unwrap(); let mut rpc1 = Rpc::new(stream1, "rpc"); let mut rpc2 = Rpc::new(stream2, "rpc"); let call = rpc1.call::(ENDPOINT_NAME, &42).await.unwrap(); // Poll the stream to receive the request let request = rpc2.poll().await.unwrap(); assert!(request.respond(Ok(420)).await); select! { response = call.fuse() => { assert!(matches!( response, Err(krossbar_common_rpc::Error::ResultTypeError(_)) )) }, _ = rpc1.poll().fuse() => {} } } #[tokio::test] async fn test_client_disconnected_error() { let (mut stream1, mut stream2) = UnixStream::pair().unwrap(); stream1.shutdown().await.unwrap(); stream2.shutdown().await.unwrap(); let rpc1 = Rpc::new(stream1, "rpc"); // Try to send u64::MAX, which BSON doesn't support. It has only i64 let call = rpc1.call::(ENDPOINT_NAME, &42).await; assert!(matches!( call, Err(krossbar_common_rpc::Error::PeerDisconnected) )) } #[tokio::test] async fn test_user_error() { let _ = pretty_env_logger::formatted_builder() .filter_level(log::LevelFilter::Debug) .try_init(); let (stream1, stream2) = UnixStream::pair().unwrap(); let mut rpc1 = Rpc::new(stream1, "rpc"); let mut rpc2 = Rpc::new(stream2, "rpc"); let call = rpc1.call::(ENDPOINT_NAME, &42).await.unwrap(); // Poll the stream to receive the request let mut request = rpc2.poll().await.unwrap(); assert_eq!(request.endpoint(), ENDPOINT_NAME); // Test formatting here assert_eq!( format!("{request:?}"), "RpcRequest { message_id: 1, endpoint: \"test_function\", body: Some(Call(Int64(42))) }" ); if let Some(Body::Call(bson)) = request.take_body() { let request_body: u32 = bson::from_bson(bson).unwrap(); assert_eq!(request_body, 42); } else { panic!("Invalid message type") } assert!( request .respond::(Err(krossbar_common_rpc::Error::ClientError( "Test error".to_owned() ))) .await ); select! { response = call.fuse() => { assert!(matches!(response, Err(krossbar_common_rpc::Error::ClientError(_)))); }, _ = rpc1.poll().fuse() => {} } }