use crate::{
config::*,
job::Job,
logging::{configure_logger, SharedJobId},
mcai_worker::{McaiWorker, McaiWorkerDescription},
message_exchange::{
message::{OrderMessage, ResponseMessage},
ExternalExchange, ExternalLocalExchange, LocalExchange, RabbitmqExchange,
},
worker::{docker, SystemInformation, WorkerConfiguration},
Processor,
};
use futures_executor::LocalPool;
use log::kv::Value;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use std::{
fs,
str::FromStr,
sync::{Arc, Mutex},
thread, time,
};
/// Function to start a worker
pub fn start_worker<
P: DeserializeOwned + JsonSchema,
D: McaiWorkerDescription,
W: 'static + McaiWorker
,
>(
mut worker: W,
) where
W: Sync + Send,
{
let amqp_queue = get_amqp_queue();
let instance_id = docker::get_instance_id();
let container_id = instance_id.clone();
let (config, shared_job_id) = configure_logger(&container_id, &get_amqp_queue());
log4rs::init_config(config).unwrap();
let worker_configuration = WorkerConfiguration::new(&amqp_queue, &worker, &instance_id);
if let Err(configuration_error) = worker_configuration {
log::error!("{:?}", configuration_error);
return;
}
let worker_configuration = worker_configuration.unwrap();
log::info!(
"Worker: {}, version: {} (MCAI Worker SDK {})",
worker_configuration.get_worker_name(),
worker_configuration.get_worker_version(),
worker_configuration.get_sdk_version(),
);
if let Ok(enabled) = std::env::var("DESCRIBE") {
if enabled == "1" || bool::from_str(&enabled.to_lowercase()).unwrap_or(false) {
match serde_json::to_string_pretty(&worker_configuration) {
Ok(serialized_configuration) => {
println!("{serialized_configuration}");
return;
}
Err(error) => log::error!("Could not serialize worker configuration: {:?}", error),
}
}
}
if let Err(message) = worker.init() {
log::error!("{:?}", message);
return;
}
let shared_worker = Arc::new(Mutex::new(worker));
if let Some(e) = SystemInformation::enable_accounting_on_gpu().err() {
log::warn!("Could not enable accounting on GPU {:?}", e);
}
log::info!("Worker initialized, ready to receive jobs");
if let Some(source_orders) = get_source_orders() {
log::warn!("Worker will process source orders");
handle_source_orders(
shared_worker,
shared_job_id,
worker_configuration,
source_orders,
);
} else {
handle_remote_orders(shared_worker, shared_job_id, worker_configuration)
}
}
pub fn handle_remote_orders<
P: DeserializeOwned + JsonSchema,
D: McaiWorkerDescription,
W: 'static + McaiWorker
,
>(
shared_worker: Arc>,
shared_job_id: SharedJobId,
worker_configuration: WorkerConfiguration,
) where
W: Sync + Send,
{
loop {
let mut executor = LocalPool::new();
executor.run_until(async {
let exchange = RabbitmqExchange::new(&worker_configuration).await.unwrap();
let exchange = Arc::new(Mutex::new(exchange));
let processor = Processor::new(exchange, worker_configuration.clone());
processor
.run(shared_worker.clone(), shared_job_id.clone())
.unwrap();
});
let sleep_duration = time::Duration::new(1, 0);
thread::sleep(sleep_duration);
log::info!("Reconnection...");
}
}
/// Process orders read from local files instead of a remote broker.
///
/// A [`Processor`] is spawned on an in-memory [`LocalExchange`]. After its first response has
/// been checked, each path in `source_orders` is read and handed to [`handle_source_order`],
/// one after the other.
///
/// # Panics
///
/// Panics if the first response received is not `WorkerCreated`, or if a source order file
/// cannot be read (the panic message names the offending path).
pub fn handle_source_orders<
    P: DeserializeOwned + JsonSchema,
    D: McaiWorkerDescription,
    W: 'static + McaiWorker<P, D>,
>(
    shared_worker: Arc<Mutex<W>>,
    shared_job_id: SharedJobId,
    worker_configuration: WorkerConfiguration,
    source_orders: Vec<String>,
) where
    W: Sync + Send,
{
    let (internal_exchange, mut external_exchange) = LocalExchange::create();
    let shared_internal_exchange = Arc::new(Mutex::new(internal_exchange));
    let processor_job_id = shared_job_id.clone();

    // The spawned task owns its inputs, so they are moved rather than cloned.
    async_std::task::spawn(async move {
        let processor = Processor::new(shared_internal_exchange, worker_configuration);
        processor.run(shared_worker, processor_job_id).unwrap();
    });

    // The first response is expected to be `WorkerCreated` before any order is sent.
    if let Ok(message) = external_exchange.next_response() {
        log::warn!("{:?}", message);
        let Some(ResponseMessage::WorkerCreated(_)) = message else {
            panic!("Bad message received, expected Worker created");
        };
    }

    for source_order in &source_orders {
        log::info!("Start to process order: {:?}", source_order);
        let message_data = fs::read_to_string(source_order).unwrap_or_else(|error| {
            panic!("Could not read source order {source_order:?}: {error}")
        });
        handle_source_order(shared_job_id.clone(), &mut external_exchange, message_data);
    }
}
pub fn handle_source_order(
shared_job_id: SharedJobId,
external_exchange: &mut ExternalLocalExchange,
message: String,
) {
let job = Job::new(&message).unwrap();
*shared_job_id.lock().unwrap() = Some(job.job_id);
log::debug!("received message: {:?}", job);
{
external_exchange
.send_order(OrderMessage::InitProcess(job.clone()))
.unwrap();
if let Ok(message) = external_exchange.next_response() {
log::info!(json = Value::from_serde(&message); "{:?}", message);
let Some(ResponseMessage::WorkerInitialized(_)) = message else {
panic!("Bad message received");
};
}
external_exchange
.send_order(OrderMessage::StartProcess(job))
.unwrap();
}
loop {
if let Ok(message) = external_exchange.next_response() {
log::info!(json = Value::from_serde(&message); "{:?}", message);
match message {
Some(ResponseMessage::Completed(_)) | Some(ResponseMessage::Error(_)) => {
*shared_job_id.lock().unwrap() = None;
break;
}
_ => {}
}
}
}
}