Generic Worker Pool (#55)

* generic async worker pool !!

* cfg tests
This commit is contained in:
Pmarquez 2022-08-02 14:32:58 +00:00 committed by GitHub
parent 549f5a1c4b
commit 140b19e6e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 30 deletions

View file

@ -21,7 +21,7 @@ async fn main() {
queue.connect(NoTls).await.unwrap();
log::info!("Queue connected...");
let mut pool = AsyncWorkerPool::builder()
let mut pool: AsyncWorkerPool<AsyncQueue<NoTls>> = AsyncWorkerPool::builder()
.number_of_workers(max_pool_size)
.queue(queue.clone())
.build();

View file

@ -10,9 +10,12 @@ use std::time::Duration;
use typed_builder::TypedBuilder;
#[derive(TypedBuilder)]
pub struct AsyncWorker<'a> {
pub struct AsyncWorker<AQueue>
where
AQueue: AsyncQueueable + Clone + Sync + 'static,
{
#[builder(setter(into))]
pub queue: &'a mut dyn AsyncQueueable,
pub queue: AQueue,
#[builder(default=DEFAULT_TASK_TYPE.to_string(), setter(into))]
pub task_type: String,
#[builder(default, setter(into))]
@ -21,7 +24,10 @@ pub struct AsyncWorker<'a> {
pub retention_mode: RetentionMode,
}
impl<'a> AsyncWorker<'a> {
impl<AQueue> AsyncWorker<AQueue>
where
AQueue: AsyncQueueable + Clone + Sync + 'static,
{
pub async fn run(&mut self, task: Task) -> Result<(), Error> {
let result = self.execute_task(task).await;
self.finalize_task(result).await
@ -31,7 +37,7 @@ impl<'a> AsyncWorker<'a> {
let actual_task: Box<dyn AsyncRunnable> =
serde_json::from_value(task.metadata.clone()).unwrap();
let task_result = actual_task.run(self.queue).await;
let task_result = actual_task.run(&mut self.queue).await;
match task_result {
Ok(()) => Ok(task),
Err(error) => Err((task, error.description)),
@ -104,8 +110,81 @@ impl<'a> AsyncWorker<'a> {
};
}
}
}
#[cfg(test)]
#[cfg(test)]
#[derive(TypedBuilder)]
pub struct AsyncWorkerTest<'a> {
#[builder(setter(into))]
pub queue: &'a mut dyn AsyncQueueable,
#[builder(default=DEFAULT_TASK_TYPE.to_string(), setter(into))]
pub task_type: String,
#[builder(default, setter(into))]
pub sleep_params: SleepParams,
#[builder(default, setter(into))]
pub retention_mode: RetentionMode,
}
#[cfg(test)]
impl<'a> AsyncWorkerTest<'a> {
pub async fn run(&mut self, task: Task) -> Result<(), Error> {
let result = self.execute_task(task).await;
self.finalize_task(result).await
}
async fn execute_task(&mut self, task: Task) -> Result<Task, (Task, String)> {
let actual_task: Box<dyn AsyncRunnable> =
serde_json::from_value(task.metadata.clone()).unwrap();
let task_result = actual_task.run(self.queue).await;
match task_result {
Ok(()) => Ok(task),
Err(error) => Err((task, error.description)),
}
}
async fn finalize_task(&mut self, result: Result<Task, (Task, String)>) -> Result<(), Error> {
match self.retention_mode {
RetentionMode::KeepAll => match result {
Ok(task) => {
self.queue
.update_task_state(task, FangTaskState::Finished)
.await?;
Ok(())
}
Err((task, error)) => {
self.queue.fail_task(task, &error).await?;
Ok(())
}
},
RetentionMode::RemoveAll => match result {
Ok(task) => {
self.queue.remove_task(task).await?;
Ok(())
}
Err((task, _error)) => {
self.queue.remove_task(task).await?;
Ok(())
}
},
RetentionMode::RemoveFinished => match result {
Ok(task) => {
self.queue.remove_task(task).await?;
Ok(())
}
Err((task, error)) => {
self.queue.fail_task(task, &error).await?;
Ok(())
}
},
}
}
pub async fn sleep(&mut self) {
self.sleep_params.maybe_increase_sleep_period();
tokio::time::sleep(Duration::from_secs(self.sleep_params.sleep_period)).await;
}
pub async fn run_tasks_until_none(&mut self) -> Result<(), Error> {
loop {
match self
@ -132,7 +211,7 @@ impl<'a> AsyncWorker<'a> {
#[cfg(test)]
mod async_worker_tests {
use super::AsyncWorker;
use super::AsyncWorkerTest;
use crate::asynk::async_queue::AsyncQueueTest;
use crate::asynk::async_queue::AsyncQueueable;
use crate::asynk::async_queue::FangTaskState;
@ -215,7 +294,7 @@ mod async_worker_tests {
let task = insert_task(&mut test, &WorkerAsyncTask { number: 1 }).await;
let id = task.id;
let mut worker = AsyncWorker::builder()
let mut worker = AsyncWorkerTest::builder()
.queue(&mut test as &mut dyn AsyncQueueable)
.retention_mode(RetentionMode::KeepAll)
.build();
@ -237,7 +316,7 @@ mod async_worker_tests {
let task = insert_task(&mut test, &AsyncFailedTask { number: 1 }).await;
let id = task.id;
let mut worker = AsyncWorker::builder()
let mut worker = AsyncWorkerTest::builder()
.queue(&mut test as &mut dyn AsyncQueueable)
.retention_mode(RetentionMode::KeepAll)
.build();
@ -269,7 +348,7 @@ mod async_worker_tests {
let id12 = task12.id;
let id2 = task2.id;
let mut worker = AsyncWorker::builder()
let mut worker = AsyncWorkerTest::builder()
.queue(&mut test as &mut dyn AsyncQueueable)
.task_type("type1".to_string())
.retention_mode(RetentionMode::KeepAll)
@ -304,7 +383,7 @@ mod async_worker_tests {
let _id12 = task12.id;
let id2 = task2.id;
let mut worker = AsyncWorker::builder()
let mut worker = AsyncWorkerTest::builder()
.queue(&mut test as &mut dyn AsyncQueueable)
.task_type("type1".to_string())
.build();

View file

@ -1,26 +1,19 @@
use crate::asynk::async_queue::AsyncQueue;
use crate::asynk::async_queue::AsyncQueueable;
use crate::asynk::async_worker::AsyncWorker;
use crate::asynk::Error;
use crate::{RetentionMode, SleepParams};
use async_recursion::async_recursion;
use bb8_postgres::tokio_postgres::tls::MakeTlsConnect;
use bb8_postgres::tokio_postgres::tls::TlsConnect;
use bb8_postgres::tokio_postgres::Socket;
use log::error;
use std::time::Duration;
use typed_builder::TypedBuilder;
#[derive(TypedBuilder, Clone)]
pub struct AsyncWorkerPool<Tls>
pub struct AsyncWorkerPool<AQueue>
where
Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
<Tls as MakeTlsConnect<Socket>>::Stream: Send + Sync,
<Tls as MakeTlsConnect<Socket>>::TlsConnect: Send,
<<Tls as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
AQueue: AsyncQueueable + Clone + Sync + 'static,
{
#[builder(setter(into))]
pub queue: AsyncQueue<Tls>,
pub queue: AQueue,
#[builder(default, setter(into))]
pub sleep_params: SleepParams,
#[builder(default, setter(into))]
@ -39,12 +32,9 @@ pub struct WorkerParams {
pub task_type: Option<String>,
}
impl<Tls> AsyncWorkerPool<Tls>
impl<AQueue> AsyncWorkerPool<AQueue>
where
Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
<Tls as MakeTlsConnect<Socket>>::Stream: Send + Sync,
<Tls as MakeTlsConnect<Socket>>::TlsConnect: Send,
<<Tls as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
AQueue: AsyncQueueable + Clone + Sync + 'static,
{
pub async fn start(&mut self) {
for _idx in 0..self.number_of_workers {
@ -60,7 +50,7 @@ where
#[async_recursion]
pub async fn supervise_worker(
queue: AsyncQueue<Tls>,
queue: AQueue,
sleep_params: SleepParams,
retention_mode: RetentionMode,
) -> Result<(), Error> {
@ -82,12 +72,12 @@ where
}
pub async fn run_worker(
mut queue: AsyncQueue<Tls>,
queue: AQueue,
sleep_params: SleepParams,
retention_mode: RetentionMode,
) -> Result<(), Error> {
let mut worker = AsyncWorker::builder()
.queue(&mut queue as &mut dyn AsyncQueueable)
let mut worker: AsyncWorker<AQueue> = AsyncWorker::builder()
.queue(queue)
.sleep_params(sleep_params)
.retention_mode(retention_mode)
.build();