use crate::errors::AsyncQueueError; use crate::runnable::RunnableTask; use crate::task::{Task, TaskHash, TaskId, TaskType}; use async_trait::async_trait; use diesel::result::Error::QueryBuilderError; use diesel_async::scoped_futures::ScopedFutureExt; use diesel_async::AsyncConnection; use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool}; /// This trait defines operations for an asynchronous queue. /// The trait can be implemented for different storage backends. /// For now, the trait is only implemented for PostgreSQL. More backends are planned to be implemented in the future. #[async_trait] pub trait Queueable: Send { /// Pull pending tasks from the queue to execute them. /// /// This method returns one task of the `task_type` type. If `task_type` is `None` it will try to /// fetch a task of the type `common`. The returned task is marked as running and must be executed. async fn pull_next_task( &mut self, kind: Option, ) -> Result, AsyncQueueError>; /// Enqueue a task to the queue, The task will be executed as soon as possible by the worker of the same type /// created by an AsyncWorkerPool. async fn create_task(&mut self, task: &dyn RunnableTask) -> Result; /// Retrieve a task by its `id`. async fn find_task_by_id(&mut self, id: TaskId) -> Result; /// Update the state of a task to failed and set an error_message. async fn set_task_failed( &mut self, id: TaskId, error_message: &str, ) -> Result; /// Update the state of a task to done. async fn set_task_done(&mut self, id: TaskId) -> Result; /// Update the state of a task to inform that it's still in progress. async fn keep_task_alive(&mut self, id: TaskId) -> Result<(), AsyncQueueError>; /// Remove a task by its id. async fn remove_task(&mut self, id: TaskId) -> Result; /// The method will remove all tasks from the queue async fn remove_all_tasks(&mut self) -> Result; /// Remove all tasks that are scheduled in the future. async fn remove_all_scheduled_tasks(&mut self) -> Result; /// Remove a task by its metadata (struct fields values) async fn remove_task_by_hash(&mut self, task_hash: TaskHash) -> Result; /// Removes all tasks that have the specified `task_type`. async fn remove_tasks_type(&mut self, task_type: TaskType) -> Result; async fn schedule_task_retry( &mut self, id: TaskId, backoff_seconds: u32, error: &str, ) -> Result; } /// An async queue that is used to manipulate tasks, it uses PostgreSQL as storage. #[derive(Debug, Clone)] pub struct PgAsyncQueue { pool: Pool, } impl PgAsyncQueue { pub fn new(pool: Pool) -> Self { PgAsyncQueue { pool } } } #[async_trait] impl Queueable for PgAsyncQueue { async fn pull_next_task( &mut self, task_type: Option, ) -> Result, AsyncQueueError> { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; connection .transaction::, AsyncQueueError, _>(|conn| { async move { let Some(pending_task) = Task::fetch_next_pending(conn, task_type.unwrap_or_default()).await else { return Ok(None); }; Task::set_running(conn, pending_task).await.map(Some) } .scope_boxed() }) .await } async fn create_task(&mut self, runnable: &dyn RunnableTask) -> Result { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; Ok(Task::insert(&mut connection, runnable).await?) } async fn find_task_by_id(&mut self, id: TaskId) -> Result { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; Task::find_by_id(&mut connection, id).await } async fn set_task_failed( &mut self, id: TaskId, error_message: &str, ) -> Result { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; Task::fail_with_message(&mut connection, id, error_message).await } async fn set_task_done(&mut self, id: TaskId) -> Result { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; Task::set_done(&mut connection, id).await } async fn keep_task_alive(&mut self, id: TaskId) -> Result<(), AsyncQueueError> { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; connection .transaction::<(), AsyncQueueError, _>(|conn| { async move { let task = Task::find_by_id(conn, id).await?; Task::set_running(conn, task).await?; Ok(()) } .scope_boxed() }) .await } async fn remove_task(&mut self, id: TaskId) -> Result { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; let result = Task::remove(&mut connection, id).await?; Ok(result) } async fn remove_all_tasks(&mut self) -> Result { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; Task::remove_all(&mut connection).await } async fn remove_all_scheduled_tasks(&mut self) -> Result { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; let result = Task::remove_all_scheduled(&mut connection).await?; Ok(result) } async fn remove_task_by_hash(&mut self, task_hash: TaskHash) -> Result { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; Task::remove_by_hash(&mut connection, task_hash).await } async fn remove_tasks_type(&mut self, task_type: TaskType) -> Result { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; let result = Task::remove_by_type(&mut connection, task_type).await?; Ok(result) } async fn schedule_task_retry( &mut self, id: TaskId, backoff_seconds: u32, error: &str, ) -> Result { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; let task = Task::schedule_retry(&mut connection, id, backoff_seconds, error).await?; Ok(task) } } #[cfg(test)] mod async_queue_tests { use super::*; use crate::task::TaskState; use crate::Scheduled; use async_trait::async_trait; use chrono::DateTime; use chrono::Utc; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::AsyncPgConnection; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize)] struct AsyncTask { pub number: u16, } #[typetag::serde] #[async_trait] impl RunnableTask for AsyncTask { async fn run( &self, _queueable: &mut dyn Queueable, ) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { Ok(()) } } #[derive(Serialize, Deserialize)] struct AsyncUniqTask { pub number: u16, } #[typetag::serde] #[async_trait] impl RunnableTask for AsyncUniqTask { async fn run( &self, _queueable: &mut dyn Queueable, ) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { Ok(()) } fn uniq(&self) -> Option { TaskHash::default_for_task(self).ok() } } #[derive(Serialize, Deserialize)] struct AsyncTaskSchedule { pub number: u16, pub datetime: String, } #[typetag::serde] #[async_trait] impl RunnableTask for AsyncTaskSchedule { async fn run( &self, _queueable: &mut dyn Queueable, ) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { Ok(()) } fn cron(&self) -> Option { let datetime = self.datetime.parse::>().ok()?; Some(Scheduled::ScheduleOnce(datetime)) } } #[tokio::test] async fn insert_task_creates_new_task() { let pool = pool().await; let mut test = PgAsyncQueue::new(pool); let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); test.remove_all_tasks().await.unwrap(); } #[tokio::test] async fn update_task_state_test() { let pool = pool().await; let mut test = PgAsyncQueue::new(pool); let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); let id = task.id; assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); let finished_task = test.set_task_done(task.id).await.unwrap(); assert_eq!(id, finished_task.id); assert_eq!(TaskState::Done, finished_task.state()); test.remove_all_tasks().await.unwrap(); } #[tokio::test] async fn failed_task_query_test() { let pool = pool().await; let mut test = PgAsyncQueue::new(pool); let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); let id = task.id; assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); let failed_task = test.set_task_failed(task.id, "Some error").await.unwrap(); assert_eq!(id, failed_task.id); assert_eq!(Some("Some error"), failed_task.error_message.as_deref()); assert_eq!(TaskState::Failed, failed_task.state()); test.remove_all_tasks().await.unwrap(); } #[tokio::test] async fn remove_all_tasks_test() { let pool = pool().await; let mut test = PgAsyncQueue::new(pool); let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); assert_eq!(Some(2), number); assert_eq!(Some("AsyncTask"), type_task); let result = test.remove_all_tasks().await.unwrap(); assert_eq!(2, result); } // #[tokio::test] // async fn schedule_task_test() { // let pool = pool().await; // let mut test = PgAsyncQueue::new(pool); // // let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0); // // let task = &AsyncTaskSchedule { // number: 1, // datetime: datetime.to_string(), // }; // // let task = test.schedule_task(task).await.unwrap(); // // let metadata = task.payload.as_object().unwrap(); // let number = metadata["number"].as_u64(); // let type_task = metadata["type"].as_str(); // // assert_eq!(Some(1), number); // assert_eq!(Some("AsyncTaskSchedule"), type_task); // assert_eq!(task.scheduled_at, datetime); // // test.remove_all_tasks().await.unwrap(); // } // // #[tokio::test] // async fn remove_all_scheduled_tasks_test() { // let pool = pool().await; // let mut test = PgAsyncQueue::new(pool); // // let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0); // // let task1 = &AsyncTaskSchedule { // number: 1, // datetime: datetime.to_string(), // }; // // let task2 = &AsyncTaskSchedule { // number: 2, // datetime: datetime.to_string(), // }; // // test.schedule_task(task1).await.unwrap(); // test.schedule_task(task2).await.unwrap(); // // let number = test.remove_all_scheduled_tasks().await.unwrap(); // // assert_eq!(2, number); // // test.remove_all_tasks().await.unwrap(); // } #[tokio::test] async fn pull_next_task_test() { let pool = pool().await; let mut test = PgAsyncQueue::new(pool); let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); assert_eq!(Some(2), number); assert_eq!(Some("AsyncTask"), type_task); let task = test.pull_next_task(None).await.unwrap().unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); let task = test.pull_next_task(None).await.unwrap().unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); assert_eq!(Some(2), number); assert_eq!(Some("AsyncTask"), type_task); test.remove_all_tasks().await.unwrap(); } #[tokio::test] async fn remove_tasks_type_test() { let pool = pool().await; let mut test = PgAsyncQueue::new(pool); let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); assert_eq!(Some(2), number); assert_eq!(Some("AsyncTask"), type_task); let result = test .remove_tasks_type(TaskType::from("nonexistentType")) .await .unwrap(); assert_eq!(0, result); let result = test.remove_tasks_type(TaskType::default()).await.unwrap(); assert_eq!(2, result); test.remove_all_tasks().await.unwrap(); } #[tokio::test] async fn remove_tasks_by_metadata() { let pool = pool().await; let mut test = PgAsyncQueue::new(pool); let task = test .create_task(&AsyncUniqTask { number: 1 }) .await .unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); assert_eq!(Some(1), number); assert_eq!(Some("AsyncUniqTask"), type_task); let task = test .create_task(&AsyncUniqTask { number: 2 }) .await .unwrap(); let metadata = task.payload.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); assert_eq!(Some(2), number); assert_eq!(Some("AsyncUniqTask"), type_task); let result = test .remove_task_by_hash(AsyncUniqTask { number: 0 }.uniq().unwrap()) .await .unwrap(); assert!(!result, "Should **not** remove task"); let result = test .remove_task_by_hash(AsyncUniqTask { number: 1 }.uniq().unwrap()) .await .unwrap(); assert!(result, "Should remove task"); test.remove_all_tasks().await.unwrap(); } async fn pool() -> Pool { let manager = AsyncDieselConnectionManager::::new( option_env!("DATABASE_URL").expect("DATABASE_URL must be set"), ); Pool::builder() .max_size(1) .min_idle(Some(1)) .build(manager) .await .unwrap() } }