use assert_matches::assert_matches; use futures::future::{join_all, ready}; use futures::prelude::*; use lrcall::client::{self}; use lrcall::context; use lrcall::server::incoming::Incoming; use lrcall::server::{BaseChannel, Channel}; use lrcall::transport::channel; use std::time::{Duration, Instant}; use tokio::join; #[lrcall_macro::service] trait Service { async fn add(x: i32, y: i32) -> i32; async fn hey(name: String) -> String; } #[derive(Clone)] struct Server; impl Service for Server { async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { x + y } async fn hey(self, _: context::Context, name: String) -> String { format!("Hey, {name}.") } } #[tokio::test] async fn sequential() { let (tx, rx) = lrcall::transport::channel::unbounded(); let client = client::new(client::Config::default(), tx).spawn(); let channel = BaseChannel::with_defaults(rx); tokio::spawn(channel.execute(lrcall::server::serve(|_, i: u32| async move { Ok(i + 1) })).for_each(|response| response)); assert_eq!(client.call(context::rpc_current(), 1).await.unwrap(), 2); } #[tokio::test] async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { #[lrcall_macro::service] trait Loop { async fn r#loop(); } #[derive(Clone)] struct LoopServer; impl Loop for LoopServer { async fn r#loop(self, _: context::Context) { loop { futures::pending!(); } } } let _ = tracing_subscriber::fmt::try_init(); let (tx, rx) = channel::unbounded(); // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. tokio::spawn(async move { let client = LoopClient::::rpc_client((client::Config::default(), tx).into()); let mut ctx = context::rpc_current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); let _ = client.r#loop(ctx).await; }); let mut requests = BaseChannel::with_defaults(rx).requests(); // Reading a request should trigger the request being registered with BaseChannel. let first_request = requests.next().await.unwrap()?; // Dropping the channel should trigger cleanup of outstanding requests. drop(requests); // In-flight requests should be aborted by channel cleanup. // The first and only request sent by the client is `loop`, which is an infinite loop // on the server side, so if cleanup was not triggered, this line should hang indefinitely. first_request.execute(LoopServer.serve()).await; Ok(()) } #[cfg(all(feature = "serde-transport", feature = "tcp"))] #[tokio::test] async fn serde_tcp() -> anyhow::Result<()> { use lrcall::serde_transport; use tokio_serde::formats::Json; let _ = tracing_subscriber::fmt::try_init(); let transport = lrcall::serde_transport::tcp::listen("localhost:56789", Json::default).await?; let addr = transport.local_addr(); tokio::spawn( transport .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) .for_each(spawn), ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; let client = ServiceClient::::rpc_client(ServiceChannel::spawn(client::Config::default(), transport)); assert_matches!(client.add(context::rpc_current(), 1, 2).await, Ok(3)); assert_matches!( client.hey(context::rpc_current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); Ok(()) } #[cfg(all(feature = "serde-transport", feature = "unix", unix))] #[tokio::test] async fn serde_uds() -> anyhow::Result<()> { use lrcall::serde_transport; use tokio_serde::formats::Json; let _ = tracing_subscriber::fmt::try_init(); let sock = lrcall::serde_transport::unix::TempPathBuf::with_random("uds"); let transport = lrcall::serde_transport::unix::listen(&sock, Json::default).await?; tokio::spawn( transport .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) .for_each(spawn), ); let tx = serde_transport::unix::connect(&sock, Json::default).await?; let client = ServiceClient::::rpc_client(ServiceChannel::spawn(client::Config::default(), tx)); // Save results using socket so we can clean the socket even if our test assertions fail let res1 = client.add(context::rpc_current(), 1, 2).await; let res2 = client.hey(context::rpc_current(), "Tim".to_string()).await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); Ok(()) } #[tokio::test] async fn conrpc_current() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); let (tx, rx) = channel::unbounded(); tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) .for_each(spawn), ); let client = ServiceClient::::rpc_client(ServiceChannel::spawn(client::Config::default(), tx)); let req1 = client.add(context::rpc_current(), 1, 2); let req2 = client.add(context::rpc_current(), 3, 4); let req3 = client.hey(context::rpc_current(), "Tim".to_string()); assert_matches!(req1.await, Ok(3)); assert_matches!(req2.await, Ok(7)); assert_matches!(req3.await, Ok(ref s) if s == "Hey, Tim."); Ok(()) } #[tokio::test] async fn concurrent_join() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); let (tx, rx) = channel::unbounded(); tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) .for_each(spawn), ); let client = ServiceClient::::rpc_client(ServiceChannel::spawn(client::Config::default(), tx)); let req1 = client.add(context::rpc_current(), 1, 2); let req2 = client.add(context::rpc_current(), 3, 4); let req3 = client.hey(context::rpc_current(), "Tim".to_string()); let (resp1, resp2, resp3) = join!(req1, req2, req3); assert_matches!(resp1, Ok(3)); assert_matches!(resp2, Ok(7)); assert_matches!(resp3, Ok(ref s) if s == "Hey, Tim."); Ok(()) } #[cfg(test)] async fn spawn(fut: impl Future + Send + 'static) { tokio::spawn(fut); } #[tokio::test] async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); let (tx, rx) = channel::unbounded(); tokio::spawn(BaseChannel::with_defaults(rx).execute(Server.serve()).for_each(spawn)); let client = ServiceClient::::rpc_client(ServiceChannel::spawn(client::Config::default(), tx)); let req1 = client.add(context::rpc_current(), 1, 2); let req2 = client.add(context::rpc_current(), 3, 4); let responses = join_all(vec![req1, req2]).await; assert_matches!(responses[0], Ok(3)); assert_matches!(responses[1], Ok(7)); Ok(()) } #[tokio::test] async fn counter() -> anyhow::Result<()> { #[lrcall::service] trait Counter { async fn count() -> u32; } struct CountService(u32); impl Counter for &mut CountService { async fn count(self, _: context::Context) -> u32 { self.0 += 1; self.0 } } let (tx, rx) = channel::unbounded(); tokio::spawn(async { let mut requests = BaseChannel::with_defaults(rx).requests(); let mut counter = CountService(0); while let Some(Ok(request)) = requests.next().await { request.execute(counter.serve()).await; } }); let client = CounterClient::::rpc_client(CounterChannel::spawn(client::Config::default(), tx)); assert_matches!(client.count(context::rpc_current()).await, Ok(1)); assert_matches!(client.count(context::rpc_current()).await, Ok(2)); Ok(()) }