Compare commits

...

21 commits
0.1.0 ... main

Author SHA1 Message Date
Rafael Caricio 912127857f
Update readme 2023-03-22 11:00:14 +01:00
Rafael Caricio f198d3983a
Update logo with turbofish 2023-03-22 10:52:45 +01:00
Rafael Caricio 7a9eddf9e4
Merge pull request #1 from rafaelcaricio/logo-test
Introduce the Backie mascot and logo design
2023-03-21 16:39:29 +01:00
Rafael Caricio 5c43dacb5b
Add Backie logo, so cute!!1 :) 2023-03-21 15:39:46 +01:00
Rafael Caricio 617bd71bd1
Update readme with execution diagram 2023-03-18 16:26:12 +01:00
Rafael Caricio ed6a784e02
Release backie version 0.6.0 2023-03-14 16:41:46 +01:00
Rafael Caricio e1a8eeb7de
Use Queue type directly 2023-03-14 15:06:22 +01:00
Rafael Caricio c93d38de01
Release backie version 0.5.0 2023-03-13 17:59:15 +01:00
Rafael Caricio aa1144e54f
Allow definition of custom error type 2023-03-13 17:46:59 +01:00
Rafael Caricio 2b42a27b72
Release backie version 0.4.0 2023-03-13 14:11:58 +01:00
Rafael Caricio 64e2315999
Review Queue signature 2023-03-13 14:11:19 +01:00
Rafael Caricio 253a82fecf
Release backie version 0.3.0 2023-03-13 13:19:59 +01:00
Rafael Caricio 979294296e
Configure release.toml 2023-03-13 13:17:16 +01:00
Rafael Caricio c99486eaa6
Make TaskStore trait object safe 2023-03-13 13:08:54 +01:00
Rafael Caricio c07781a79b
Maximum of 5 keywords per crate 2023-03-12 18:42:12 +01:00
Rafael Caricio 042de9261f
Release 0.2.0 2023-03-12 18:38:54 +01:00
Rafael Caricio 716eeae4b1
Handle tasks that panic 2023-03-12 18:33:00 +01:00
Rafael Caricio 10e01390b8
Allow customization of the pulling interval per queue 2023-03-12 17:15:40 +01:00
Rafael Caricio 82e6ef6dac
Remove submodules 2023-03-12 15:55:27 +01:00
Rafael Caricio 0f0a9c2238
Tasks are let run until completion 2023-03-12 15:52:13 +01:00
Rafael Caricio 2964dc2b88
Wait all workers to stop gracefully 2023-03-12 00:18:15 +01:00
16 changed files with 772 additions and 338 deletions

1
.gitignore vendored
View file

@ -2,3 +2,4 @@
Cargo.lock
docs/content/docs/CHANGELOG.md
docs/content/docs/README.md
.DS_Store

3
.gitmodules vendored
View file

@ -1,3 +0,0 @@
[submodule "docs/themes/adidoks"]
path = docs/themes/adidoks
url = https://github.com/aaranxu/adidoks.git

View file

@ -1,25 +1,22 @@
[package]
name = "backie"
version = "0.1.0"
version = "0.6.0"
authors = [
"Rafael Caricio <rafael@caricio.com>",
]
description = "Async persistent background task processing for Rust applications with Tokio and PostgreSQL."
repository = "https://code.caric.io/rafaelcaricio/backie"
description = "Background task processing for Rust applications with Tokio, Diesel, and PostgreSQL."
keywords = ["async", "background", "task", "jobs", "queue"]
repository = "https://github.com/rafaelcaricio/backie"
edition = "2021"
license = "MIT"
readme = "README.md"
rust-version = "1.67"
[lib]
doctest = false
[dependencies]
chrono = "0.4"
log = "0.4"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
anyhow = "1"
thiserror = "1"
uuid = { version = "1.1", features = ["v4", "serde"] }
async-trait = "0.1"

View file

@ -1,4 +1,4 @@
# Backie 🚲
<p align="center"><img src="logo.png" alt="Backie" width="400"></p>
Async persistent background task processing for Rust applications with Tokio. Queue asynchronous tasks
to be processed by workers. It's designed to be easy to use and horizontally scalable. It uses Postgres as
@ -17,19 +17,47 @@ Backie started as a fork of
Here are some of the Backie's key features:
- Async workers: Workers are started as [Tokio](https://tokio.rs/) tasks
- Application context: Tasks can access an shared user-provided application context
- Single-purpose workers: Tasks are stored together but workers are configured to execute only tasks of a specific queue
- Retries: Tasks are retried with a custom backoff mode
- Graceful shutdown: provide a future to gracefully shutdown the workers, on-the-fly tasks are not interrupted
- Recovery of unfinished tasks: Tasks that were not finished are retried on the next worker start
- Unique tasks: Tasks are not duplicated in the queue if they provide a unique hash
- **Guaranteed execution**: at least one execution of a task
- **Async workers**: Workers are started as [Tokio](https://tokio.rs/) tasks
- **Application context**: Tasks can access an shared user-provided application context
- **Single-purpose workers**: Tasks are stored together but workers are configured to execute only tasks of a specific queue
- **Retries**: Tasks are retried with a custom backoff mode
- **Graceful shutdown**: provide a future to gracefully shutdown the workers, on-the-fly tasks are not interrupted
- **Recovery of unfinished tasks**: Tasks that were not finished are retried on the next worker start
- **Unique tasks**: Tasks are not duplicated in the queue if they provide a unique hash
## Other planned features
- Task timeout: Tasks are retried if they are not completed in time
- Scheduling of tasks: Tasks can be scheduled to be executed at a specific time
## Task execution protocol
The following diagram shows the protocol used to execute tasks:
```mermaid
stateDiagram-v2
[*] --> Ready
Ready --> Running: Task is picked up by a worker
Running --> Done: Task is finished
Running --> Failed: Task failed
Failed --> Ready: Task is retried
Failed --> [*]: Task is not retried anymore, max retries reached
Done --> [*]
```
When a task goes from `Running` to `Failed` it is retried. The number of retries is controlled by the
[`BackgroundTask::MAX_RETRIES`] attribute. The default implementation uses `3` retries.
## Safety
This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust.
## Minimum supported Rust version
Backie's MSRV is 1.68.
## Installation
1. Add this to your `Cargo.toml`
@ -52,8 +80,6 @@ diesel-async = { version = "0.2", features = ["postgres", "bb8"] }
Those dependencies are required to use the `#[async_trait]` and `#[derive(Serialize, Deserialize)]` attributes
in your task definitions and to connect to the Postgres database.
*Supports rustc 1.68+*
2. Create the `backie_tasks` table in the Postgres database. The migration can be found in [the migrations directory](https://github.com/rafaelcaricio/backie/blob/master/migrations/2023-03-06-151907_create_backie_tasks/up.sql).
## Usage
@ -67,6 +93,9 @@ the whole application. This attribute is critical for reconstructing the task ba
The [`BackgroundTask::AppData`] can be used to argument the task with your application specific contextual information.
This is useful for example to pass a database connection pool to the task or other application configuration.
The [`BackgroundTask::Error`] is the error type that will be returned by the [`BackgroundTask::run`] method. You can
use this to define your own error type for your tasks.
The [`BackgroundTask::run`] method is where you define the behaviour of your background task execution. This method
will be called by the task queue workers.
@ -84,8 +113,9 @@ pub struct MyTask {
impl BackgroundTask for MyTask {
const TASK_NAME: &'static str = "my_task_unique_name";
type AppData = ();
type Error = ();
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error> {
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> {
// Do something
Ok(())
}
@ -98,44 +128,23 @@ First, we need to create a [`TaskStore`] trait instance. This is the object resp
tasks from a database. Backie currently only supports Postgres as a storage backend via the provided
[`PgTaskStore`]. You can implement other storage backends by implementing the [`TaskStore`] trait.
```rust
let connection_url = "postgres://postgres:password@localhost/backie";
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(connection_url);
let pool = Pool::builder()
.max_size(3)
.build(manager)
.await
.unwrap();
let task_store = PgTaskStore::new(pool);
```
Then, we can use the `task_store` to start a worker pool using the [`WorkerPool`]. The [`WorkerPool`] is responsible
for starting the workers and managing their lifecycle.
```rust
// Register the task types I want to use and start the worker pool
let (_, queue) = WorkerPool::new(task_store, |_|())
.register_task_type::<MyTask>()
.configure_queue("default", 1, RetentionMode::default())
.start(futures::future::pending::<()>())
.await
.unwrap();
```
With that, we are defining that we want to execute instances of `MyTask` and that the `default` queue should
have 1 worker running using the default [`RetentionMode`] (remove from the database only successfully finished tasks).
We also defined in the `start` method that the worker pool should run forever.
A full example of starting a worker pool can be found in the [examples directory](https://github.com/rafaelcaricio/backie/blob/main/examples/simple_worker/src/main.rs).
### Queueing tasks
After stating the workers we get an instance of [`Queue`] which we can use to enqueue tasks:
After stating the workers, we get an instance of [`Queue`] which we can use to enqueue tasks. It is also possible
to directly create a [`Queue`] instance from with a [`TaskStore`] instance.
```rust
let task = MyTask { info: "Hello world!".to_string() };
queue.enqueue(task).await.unwrap();
```
This will enqueue the task and whenever a worker is available it will start processing. Workers don't need to be
started before enqueuing tasks. Workers don't need to be in the same process as the queue as long as the workers have
access to the same underlying storage system. This enables horizontal scaling of the workers.
## License
This project is licensed under the [MIT license][license].
## Contributing
@ -145,7 +154,7 @@ queue.enqueue(task).await.unwrap();
4. Push to the branch (`git push origin my-new-feature`)
5. Create a new Pull Request
## Thanks to related crates authors
## Acknowledgements
I would like to thank the authors of the [Fang](https://github.com/ayrat555/fang) and [background_job](https://git.asonix.dog/asonix/background-jobs.git) crates which were the main inspiration for this project.

View file

@ -6,7 +6,7 @@ edition = "2021"
[dependencies]
backie = { path = "../../" }
anyhow = "1"
env_logger = "0.9.0"
env_logger = "0.10"
log = "0.4.0"
tokio = { version = "1", features = ["full"] }
diesel-async = { version = "0.2", features = ["postgres", "bb8"] }

View file

@ -1,97 +0,0 @@
use async_trait::async_trait;
use backie::{BackgroundTask, CurrentTask};
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Clone, Debug)]
pub struct MyApplicationContext {
app_name: String,
}
impl MyApplicationContext {
pub fn new(app_name: &str) -> Self {
Self {
app_name: app_name.to_string(),
}
}
}
#[derive(Serialize, Deserialize)]
pub struct MyTask {
pub number: u16,
}
impl MyTask {
pub fn new(number: u16) -> Self {
Self { number }
}
}
#[derive(Serialize, Deserialize)]
pub struct MyFailingTask {
pub number: u16,
}
impl MyFailingTask {
pub fn new(number: u16) -> Self {
Self { number }
}
}
#[async_trait]
impl BackgroundTask for MyTask {
const TASK_NAME: &'static str = "my_task";
type AppData = MyApplicationContext;
async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), anyhow::Error> {
// let new_task = MyTask::new(self.number + 1);
// queue
// .insert_task(&new_task)
// .await
// .unwrap();
log::info!(
"[{}] Hello from {}! the current number is {}",
task.id(),
ctx.app_name,
self.number
);
tokio::time::sleep(Duration::from_secs(3)).await;
log::info!("[{}] done..", task.id());
Ok(())
}
}
#[async_trait]
impl BackgroundTask for MyFailingTask {
const TASK_NAME: &'static str = "my_failing_task";
type AppData = MyApplicationContext;
async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), anyhow::Error> {
// let new_task = MyFailingTask::new(self.number + 1);
// queue
// .insert_task(&new_task)
// .await
// .unwrap();
// task.id();
// task.keep_alive().await?;
// task.previous_error();
// task.retry_count();
log::info!("[{}] the current number is {}", task.id(), self.number);
tokio::time::sleep(Duration::from_secs(3)).await;
log::info!("[{}] done..", task.id());
//
// let b = true;
//
// if b {
// panic!("Hello!");
// } else {
// Ok(())
// }
Ok(())
}
}

View file

@ -1,9 +1,97 @@
use backie::{PgTaskStore, RetentionMode, WorkerPool};
use async_trait::async_trait;
use backie::{BackgroundTask, CurrentTask};
use backie::{PgTaskStore, Queue, WorkerPool};
use diesel_async::pg::AsyncPgConnection;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use simple_worker::MyApplicationContext;
use simple_worker::MyFailingTask;
use simple_worker::MyTask;
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Clone, Debug)]
pub struct MyApplicationContext {
app_name: String,
}
impl MyApplicationContext {
pub fn new(app_name: &str) -> Self {
Self {
app_name: app_name.to_string(),
}
}
}
#[derive(Serialize, Deserialize)]
pub struct MyTask {
pub number: u16,
}
impl MyTask {
pub fn new(number: u16) -> Self {
Self { number }
}
}
#[async_trait]
impl BackgroundTask for MyTask {
const TASK_NAME: &'static str = "my_task";
type AppData = MyApplicationContext;
type Error = anyhow::Error;
async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), Self::Error> {
// let new_task = MyTask::new(self.number + 1);
// queue
// .insert_task(&new_task)
// .await
// .unwrap();
log::info!(
"[{}] Hello from {}! the current number is {}",
task.id(),
ctx.app_name,
self.number
);
tokio::time::sleep(Duration::from_secs(3)).await;
log::info!("[{}] done..", task.id());
Ok(())
}
}
#[derive(Serialize, Deserialize)]
pub struct MyFailingTask {
pub number: u16,
}
impl MyFailingTask {
pub fn new(number: u16) -> Self {
Self { number }
}
}
#[async_trait]
impl BackgroundTask for MyFailingTask {
const TASK_NAME: &'static str = "my_failing_task";
type AppData = MyApplicationContext;
type Error = anyhow::Error;
async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), Self::Error> {
// let new_task = MyFailingTask::new(self.number + 1);
// queue
// .insert_task(&new_task)
// .await
// .unwrap();
// task.id();
// task.keep_alive().await?;
// task.previous_error();
// task.retry_count();
log::info!("[{}] the current number is {}", task.id(), self.number);
tokio::time::sleep(Duration::from_secs(3)).await;
log::info!("[{}] done..", task.id());
Ok(())
}
}
#[tokio::main]
async fn main() {
@ -30,15 +118,16 @@ async fn main() {
let my_app_context = MyApplicationContext::new("Backie Example App");
// Register the task types I want to use and start the worker pool
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
.register_task_type::<MyTask>()
.register_task_type::<MyFailingTask>()
.configure_queue("default", 3, RetentionMode::RemoveDone)
.start(async move {
let _ = rx.changed().await;
})
.await
.unwrap();
let join_handle =
WorkerPool::new(task_store.clone(), move || my_app_context.clone())
.register_task_type::<MyTask>()
.register_task_type::<MyFailingTask>()
.configure_queue("default".into())
.start(async move {
let _ = rx.changed().await;
})
.await
.unwrap();
log::info!("Workers started ...");
@ -46,6 +135,7 @@ async fn main() {
let task2 = MyTask::new(20_000);
let task3 = MyFailingTask::new(50_000);
let queue = Queue::new(task_store);
queue.enqueue(task1).await.unwrap();
queue.enqueue(task2).await.unwrap();
queue.enqueue(task3).await.unwrap();

BIN
logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 131 KiB

23
release.toml Normal file
View file

@ -0,0 +1,23 @@
allow-branch = [
"*",
"!HEAD",
]
sign-commit = true
sign-tag = true
push-remote = "origin"
release = true
publish = true
verify = true
owners = []
push = true
push-options = []
consolidate-commits = false
pre-release-commit-message = "Release {{crate_name}} version {{version}}"
pre-release-replacements = []
tag-message = "Release {{version}}"
tag-name = "{{version}}"
tag = true
enable-features = []
enable-all-features = false
dependent-version = "upgrade"
metadata = "optional"

46
src/catch_unwind.rs Normal file
View file

@ -0,0 +1,46 @@
use crate::worker::TaskExecError;
use futures::future::BoxFuture;
use futures::FutureExt;
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
pub(crate) struct CatchUnwindFuture<F: Future + Send + 'static> {
inner: BoxFuture<'static, F::Output>,
}
impl<F: Future + Send + 'static> CatchUnwindFuture<F> {
pub fn create(f: F) -> CatchUnwindFuture<F> {
Self { inner: f.boxed() }
}
}
impl<F: Future + Send + 'static> Future for CatchUnwindFuture<F> {
type Output = Result<F::Output, TaskExecError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = &mut self.inner;
match catch_unwind(move || inner.poll_unpin(cx)) {
Ok(Poll::Pending) => Poll::Pending,
Ok(Poll::Ready(value)) => Poll::Ready(Ok(value)),
Err(cause) => Poll::Ready(Err(cause)),
}
}
}
fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> Result<R, TaskExecError> {
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
Ok(res) => Ok(res),
Err(cause) => match cause.downcast_ref::<&'static str>() {
None => match cause.downcast_ref::<String>() {
None => Err(TaskExecError::Panicked(
"Sorry, unknown panic message".to_string(),
)),
Some(message) => Err(TaskExecError::Panicked(message.to_string())),
},
Some(message) => Err(TaskExecError::Panicked(message.to_string())),
},
}
}

View file

@ -5,7 +5,7 @@
/// All possible options for retaining tasks in the db after their execution.
///
/// The default mode is [`RetentionMode::RemoveAll`]
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
pub enum RetentionMode {
/// Keep all tasks
KeepAll,
@ -26,10 +26,11 @@ impl Default for RetentionMode {
pub use queue::Queue;
pub use runnable::BackgroundTask;
pub use store::{PgTaskStore, TaskStore};
pub use task::{CurrentTask, Task, TaskId, TaskState};
pub use task::{CurrentTask, NewTask, Task, TaskId, TaskState};
pub use worker::Worker;
pub use worker_pool::WorkerPool;
pub use worker_pool::{QueueConfig, WorkerPool};
mod catch_unwind;
pub mod errors;
mod queries;
mod queue;

View file

@ -2,22 +2,20 @@ use crate::errors::BackieError;
use crate::runnable::BackgroundTask;
use crate::store::TaskStore;
use crate::task::NewTask;
use std::sync::Arc;
use std::time::Duration;
#[derive(Clone)]
pub struct Queue<S>
where
S: TaskStore,
{
task_store: Arc<S>,
task_store: S,
}
impl<S> Queue<S>
where
S: TaskStore,
{
pub fn new(task_store: Arc<S>) -> Self {
pub fn new(task_store: S) -> Self {
Queue { task_store }
}
@ -25,9 +23,21 @@ where
where
BT: BackgroundTask,
{
// TODO: Add option to specify the timeout of a task
self.task_store
.create_task(NewTask::new(background_task, Duration::from_secs(10))?)
.await?;
Ok(())
}
}
impl<S> Clone for Queue<S>
where
S: TaskStore + Clone,
{
fn clone(&self) -> Self {
Self {
task_store: self.task_store.clone(),
}
}
}

View file

@ -1,6 +1,7 @@
use crate::task::{CurrentTask, TaskHash};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, ser::Serialize};
use std::fmt::Debug;
/// The [`BackgroundTask`] trait is used to define the behaviour of a task. You must implement this
/// trait for all tasks you want to execute.
@ -17,7 +18,7 @@ use serde::{de::DeserializeOwned, ser::Serialize};
///
///
/// # Example
/// ```rust
/// ```
/// use async_trait::async_trait;
/// use backie::{BackgroundTask, CurrentTask};
/// use serde::{Deserialize, Serialize};
@ -25,11 +26,13 @@ use serde::{de::DeserializeOwned, ser::Serialize};
/// #[derive(Serialize, Deserialize)]
/// pub struct MyTask {}
///
/// #[async_trait]
/// impl BackgroundTask for MyTask {
/// const TASK_NAME: &'static str = "my_task_unique_name";
/// type AppData = ();
/// type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
///
/// async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error> {
/// async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> {
/// // Do something
/// Ok(())
/// }
@ -50,14 +53,17 @@ pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static {
/// Number of retries for tasks.
///
/// By default, it is set to 5.
const MAX_RETRIES: i32 = 5;
/// By default, it is set to 3.
const MAX_RETRIES: i32 = 3;
/// The application data provided to this task at runtime.
type AppData: Clone + Send + 'static;
/// An application custom error type.
type Error: Debug + Send + 'static;
/// Execute the task. This method should define its logic
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error>;
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error>;
/// If set to true, no new tasks with the same metadata will be inserted
/// By default it is set to false.

View file

@ -106,7 +106,7 @@ pub mod test_store {
#[derive(Default, Clone)]
pub struct MemoryTaskStore {
tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
pub tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
}
#[async_trait::async_trait]
@ -197,7 +197,7 @@ pub mod test_store {
}
#[async_trait::async_trait]
pub trait TaskStore: Clone + Send + Sync + 'static {
pub trait TaskStore: Send + Sync + 'static {
async fn pull_next_task(
&self,
queue_name: &str,
@ -213,3 +213,15 @@ pub trait TaskStore: Clone + Send + Sync + 'static {
error: &str,
) -> Result<Task, AsyncQueueError>;
}
#[cfg(test)]
mod tests {
use super::*;
use crate::store::test_store::MemoryTaskStore;
#[test]
fn task_store_trait_is_object_safe() {
let store = MemoryTaskStore::default();
let _object = &store as &dyn TaskStore;
}
}

View file

@ -1,3 +1,4 @@
use crate::catch_unwind::CatchUnwindFuture;
use crate::errors::{AsyncQueueError, BackieError};
use crate::runnable::BackgroundTask;
use crate::store::TaskStore;
@ -9,7 +10,7 @@ use std::collections::BTreeMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use thiserror::Error;
use std::time::Duration;
pub type ExecuteTaskFn<AppData> = Arc<
dyn Fn(
@ -23,13 +24,16 @@ pub type ExecuteTaskFn<AppData> = Arc<
pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Send + Sync>;
#[derive(Debug, Error)]
#[derive(Debug, thiserror::Error)]
pub enum TaskExecError {
#[error("Task execution failed: {0}")]
ExecutionFailed(#[from] anyhow::Error),
#[error("Task deserialization failed: {0}")]
TaskDeserializationFailed(#[from] serde_json::Error),
#[error("Task execution failed: {0}")]
ExecutionFailed(String),
#[error("Task panicked with: {0}")]
Panicked(String),
}
pub(crate) fn runnable<BT>(
@ -42,8 +46,10 @@ where
{
Box::pin(async move {
let background_task: BT = serde_json::from_value(payload)?;
background_task.run(task_info, app_context).await?;
Ok(())
match background_task.run(task_info, app_context).await {
Ok(_) => Ok(()),
Err(err) => Err(TaskExecError::ExecutionFailed(format!("{:?}", err))),
}
})
}
@ -51,14 +57,16 @@ where
pub struct Worker<AppData, S>
where
AppData: Clone + Send + 'static,
S: TaskStore,
S: TaskStore + Clone,
{
store: Arc<S>,
store: S,
queue_name: String,
retention_mode: RetentionMode,
pull_interval: Duration,
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
app_data_fn: StateFn<AppData>,
@ -70,12 +78,13 @@ where
impl<AppData, S> Worker<AppData, S>
where
AppData: Clone + Send + 'static,
S: TaskStore,
S: TaskStore + Clone,
{
pub(crate) fn new(
store: Arc<S>,
store: S,
queue_name: String,
retention_mode: RetentionMode,
pull_interval: Duration,
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
app_data_fn: StateFn<AppData>,
shutdown: Option<tokio::sync::watch::Receiver<()>>,
@ -84,6 +93,7 @@ where
store,
queue_name,
retention_mode,
pull_interval,
task_registry,
app_data_fn,
shutdown,
@ -120,11 +130,11 @@ where
log::info!("Shutting down worker");
return Ok(());
}
_ = tokio::time::sleep(std::time::Duration::from_secs(1)).fuse() => {}
_ = tokio::time::sleep(self.pull_interval).fuse() => {}
}
}
None => {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
tokio::time::sleep(self.pull_interval).await;
}
};
}
@ -139,9 +149,18 @@ where
.get(&task.task_name)
.ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
// TODO: catch panics
let result: Result<(), TaskExecError> =
runnable_task_caller(task_info, task.payload.clone(), (self.app_data_fn)()).await;
// catch panics
let result: Result<(), TaskExecError> = CatchUnwindFuture::create({
let task_payload = task.payload.clone();
let app_data = (self.app_data_fn)();
let runnable_task_caller = runnable_task_caller.clone();
async move { runnable_task_caller(task_info, task_payload, app_data).await }
})
.await
.and_then(|result| {
result?;
Ok(())
});
match &result {
Ok(_) => self.finalize_task(task, result).await?,
@ -154,7 +173,9 @@ where
task.id,
backoff_seconds
);
let error_message = format!("{}", error);
self.store
.schedule_task_retry(task.id, backoff_seconds, &error_message)
.await?;
@ -231,8 +252,9 @@ mod async_worker_tests {
impl BackgroundTask for WorkerAsyncTask {
const TASK_NAME: &'static str = "WorkerAsyncTask";
type AppData = ();
type Error = ();
async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), ()> {
Ok(())
}
}
@ -246,8 +268,9 @@ mod async_worker_tests {
impl BackgroundTask for WorkerAsyncTaskSchedule {
const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule";
type AppData = ();
type Error = ();
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), ()> {
Ok(())
}
@ -265,11 +288,12 @@ mod async_worker_tests {
impl BackgroundTask for AsyncFailedTask {
const TASK_NAME: &'static str = "AsyncFailedTask";
type AppData = ();
type Error = TaskError;
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), TaskError> {
let message = format!("number {} is wrong :(", self.number);
Err(TaskError::Custom(message).into())
Err(TaskError::Custom(message))
}
fn max_retries(&self) -> i32 {
@ -284,9 +308,10 @@ mod async_worker_tests {
impl BackgroundTask for AsyncRetryTask {
const TASK_NAME: &'static str = "AsyncRetryTask";
type AppData = ();
type Error = TaskError;
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
Err(TaskError::SomethingWrong.into())
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
Err(TaskError::SomethingWrong)
}
}
@ -297,8 +322,9 @@ mod async_worker_tests {
impl BackgroundTask for AsyncTaskType1 {
const TASK_NAME: &'static str = "AsyncTaskType1";
type AppData = ();
type Error = ();
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
Ok(())
}
}
@ -310,8 +336,9 @@ mod async_worker_tests {
impl BackgroundTask for AsyncTaskType2 {
const TASK_NAME: &'static str = "AsyncTaskType2";
type AppData = ();
type Error = ();
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), ()> {
Ok(())
}
}

View file

@ -1,26 +1,24 @@
use crate::errors::BackieError;
use crate::queue::Queue;
use crate::runnable::BackgroundTask;
use crate::store::TaskStore;
use crate::worker::{runnable, ExecuteTaskFn};
use crate::worker::{StateFn, Worker};
use crate::RetentionMode;
use futures::future::join_all;
use std::collections::BTreeMap;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tokio::task::JoinHandle;
#[derive(Clone)]
pub struct WorkerPool<AppData, S>
where
AppData: Clone + Send + 'static,
S: TaskStore,
S: TaskStore + Clone,
{
/// Storage of tasks.
task_store: Arc<S>,
/// Queue used to spawn tasks.
queue: Queue<S>,
task_store: S,
/// Make possible to load the application data.
///
@ -36,28 +34,21 @@ where
queue_tasks: BTreeMap<String, Vec<String>>,
/// Number of workers that will be spawned per queue.
worker_queues: BTreeMap<String, (RetentionMode, u32)>,
worker_queues: BTreeMap<String, QueueConfig>,
}
impl<AppData, S> WorkerPool<AppData, S>
where
AppData: Clone + Send + 'static,
S: TaskStore,
S: TaskStore + Clone,
{
/// Create a new worker pool.
pub fn new<A>(task_store: S, application_data_fn: A) -> Self
where
A: Fn(Queue<S>) -> AppData + Send + Sync + 'static,
A: Fn() -> AppData + Send + Sync + 'static,
{
let queue_store = Arc::new(task_store);
let queue = Queue::new(queue_store.clone());
let application_data_fn = {
let queue = queue.clone();
move || application_data_fn(queue.clone())
};
Self {
task_store: queue_store,
queue,
task_store,
application_data_fn: Arc::new(application_data_fn),
task_registry: BTreeMap::new(),
queue_tasks: BTreeMap::new(),
@ -79,21 +70,12 @@ where
self
}
pub fn configure_queue(
mut self,
queue_name: impl ToString,
num_workers: u32,
retention_mode: RetentionMode,
) -> Self {
self.worker_queues
.insert(queue_name.to_string(), (retention_mode, num_workers));
pub fn configure_queue(mut self, config: QueueConfig) -> Self {
self.worker_queues.insert(config.name.clone(), config);
self
}
pub async fn start<F>(
self,
graceful_shutdown: F,
) -> Result<(JoinHandle<()>, Queue<S>), BackieError>
pub async fn start<F>(self, graceful_shutdown: F) -> Result<JoinHandle<()>, BackieError>
where
F: Future<Output = ()> + Send + 'static,
{
@ -106,39 +88,125 @@ where
let (tx, rx) = tokio::sync::watch::channel(());
let mut worker_handles = Vec::new();
// Spawn all individual workers per queue
for (queue_name, (retention_mode, num_workers)) in self.worker_queues.iter() {
for idx in 0..*num_workers {
for (queue_name, queue_config) in self.worker_queues.iter() {
for idx in 0..queue_config.num_workers {
let mut worker: Worker<AppData, S> = Worker::new(
self.task_store.clone(),
queue_name.clone(),
*retention_mode,
queue_config.retention_mode,
queue_config.pull_interval,
self.task_registry.clone(),
self.application_data_fn.clone(),
Some(rx.clone()),
);
let worker_name = format!("worker-{queue_name}-{idx}");
// TODO: grab the join handle for every worker for graceful shutdown
tokio::spawn(async move {
// grabs the join handle for every worker for graceful shutdown
let join_handle = tokio::spawn(async move {
match worker.run_tasks().await {
Ok(()) => log::info!("Worker {worker_name} stopped successfully"),
Err(err) => log::error!("Worker {worker_name} stopped due to error: {err}"),
}
});
worker_handles.push(join_handle);
}
}
Ok((
tokio::spawn(async move {
graceful_shutdown.await;
if let Err(err) = tx.send(()) {
log::warn!("Failed to send shutdown signal to worker pool: {}", err);
Ok(tokio::spawn(async move {
graceful_shutdown.await;
if let Err(err) = tx.send(()) {
log::warn!("Failed to send shutdown signal to worker pool: {}", err);
} else {
// Wait for all workers to finish processing
let results = join_all(worker_handles)
.await
.into_iter()
.filter(Result::is_err)
.map(Result::unwrap_err)
.collect::<Vec<_>>();
if !results.is_empty() {
log::error!("Worker pool stopped with errors: {:?}", results);
} else {
log::info!("Worker pool stopped gracefully");
}
}),
self.queue,
))
}
}))
}
}
/// Configuration for a queue.
///
/// This is used to configure the number of workers, the retention mode, and the pulling interval
/// for a queue.
///
/// # Examples
///
/// Example of configuring a queue with all options:
/// ```
/// # use backie::QueueConfig;
/// # use backie::RetentionMode;
/// # use std::time::Duration;
/// let config = QueueConfig::new("default")
/// .num_workers(5)
/// .retention_mode(RetentionMode::KeepAll)
/// .pull_interval(Duration::from_secs(1));
/// ```
/// Example of queue configuration with default options:
/// ```
/// # use backie::QueueConfig;
/// let config = QueueConfig::new("default");
/// // Also possible to use the `From` trait:
/// let config: QueueConfig = "default".into();
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct QueueConfig {
name: String,
num_workers: u32,
retention_mode: RetentionMode,
pull_interval: Duration,
}
impl QueueConfig {
/// Create a new queue configuration.
pub fn new(name: impl ToString) -> Self {
Self {
name: name.to_string(),
num_workers: 1,
retention_mode: RetentionMode::default(),
pull_interval: Duration::from_secs(1),
}
}
/// Set the number of workers for this queue.
pub fn num_workers(mut self, num_workers: u32) -> Self {
self.num_workers = num_workers;
self
}
/// Set the retention mode for this queue.
pub fn retention_mode(mut self, retention_mode: RetentionMode) -> Self {
self.retention_mode = retention_mode;
self
}
/// Set the pull interval for this queue.
///
/// This is the interval at which the queue will be checking for new tasks by calling
/// the backend storage.
pub fn pull_interval(mut self, pull_interval: Duration) -> Self {
self.pull_interval = pull_interval;
self
}
}
impl<S> From<S> for QueueConfig
where
S: ToString,
{
fn from(name: S) -> Self {
Self::new(name.to_string())
}
}
@ -148,13 +216,16 @@ mod tests {
use crate::store::test_store::MemoryTaskStore;
use crate::store::PgTaskStore;
use crate::task::CurrentTask;
use crate::Queue;
use async_trait::async_trait;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection;
use futures::FutureExt;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::Mutex;
#[derive(Clone, Debug)]
struct ApplicationContext {
pub struct ApplicationContext {
app_name: String,
}
@ -175,17 +246,50 @@ mod tests {
person: String,
}
/// This tests that one can customize the task parameters for the application.
#[async_trait]
impl BackgroundTask for GreetingTask {
const TASK_NAME: &'static str = "my_task";
trait MyAppTask {
const TASK_NAME: &'static str;
const QUEUE: &'static str = "default";
async fn run(
&self,
task_info: CurrentTask,
app_context: ApplicationContext,
) -> Result<(), ()>;
}
#[async_trait]
impl<T> BackgroundTask for T
where
T: MyAppTask + serde::de::DeserializeOwned + serde::ser::Serialize + Sync + Send + 'static,
{
const TASK_NAME: &'static str = T::TASK_NAME;
const QUEUE: &'static str = T::QUEUE;
type AppData = ApplicationContext;
type Error = ();
async fn run(
&self,
task_info: CurrentTask,
app_context: Self::AppData,
) -> Result<(), anyhow::Error> {
) -> Result<(), Self::Error> {
self.run(task_info, app_context).await
}
}
#[async_trait]
impl MyAppTask for GreetingTask {
const TASK_NAME: &'static str = "my_task";
async fn run(
&self,
task_info: CurrentTask,
app_context: ApplicationContext,
) -> Result<(), ()> {
println!(
"[{}] Hello {}! I'm {}.",
task_info.id(),
@ -206,12 +310,9 @@ mod tests {
const QUEUE: &'static str = "other_queue";
type AppData = ApplicationContext;
type Error = ();
async fn run(
&self,
task: CurrentTask,
context: Self::AppData,
) -> Result<(), anyhow::Error> {
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> {
println!(
"[{}] Other task with {}!",
task.id(),
@ -221,41 +322,11 @@ mod tests {
}
}
#[derive(Clone)]
struct NotifyFinishedContext {
tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct NotifyFinished;
#[async_trait]
impl BackgroundTask for NotifyFinished {
const TASK_NAME: &'static str = "notify_finished";
type AppData = NotifyFinishedContext;
async fn run(
&self,
task: CurrentTask,
context: Self::AppData,
) -> Result<(), anyhow::Error> {
match context.tx.lock().await.take() {
None => println!("Cannot notify, already done that!"),
Some(tx) => {
tx.send(()).unwrap();
println!("[{}] Notify finished did it's job!", task.id())
}
};
Ok(())
}
}
#[tokio::test]
async fn validate_all_registered_tasks_queues_are_configured() {
let my_app_context = ApplicationContext::new();
let result = WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
let result = WorkerPool::new(memory_store(), move || my_app_context.clone())
.register_task_type::<GreetingTask>()
.start(futures::future::ready(()))
.await;
@ -273,14 +344,16 @@ mod tests {
async fn test_worker_pool_with_task() {
let my_app_context = ApplicationContext::new();
let (join_handle, queue) =
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>()
.configure_queue(GreetingTask::QUEUE, 1, RetentionMode::RemoveDone)
.start(futures::future::ready(()))
.await
.unwrap();
let task_store = memory_store();
let join_handle = WorkerPool::new(task_store.clone(), move || my_app_context.clone())
.register_task_type::<GreetingTask>()
.configure_queue(<GreetingTask as MyAppTask>::QUEUE.into())
.start(futures::future::ready(()))
.await
.unwrap();
let queue = Queue::new(task_store);
queue
.enqueue(GreetingTask {
person: "Rafael".to_string(),
@ -295,16 +368,17 @@ mod tests {
async fn test_worker_pool_with_multiple_task_types() {
let my_app_context = ApplicationContext::new();
let (join_handle, queue) =
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>()
.register_task_type::<OtherTask>()
.configure_queue("default", 1, RetentionMode::default())
.configure_queue("other_queue", 1, RetentionMode::default())
.start(futures::future::ready(()))
.await
.unwrap();
let task_store = memory_store();
let join_handle = WorkerPool::new(task_store.clone(), move || my_app_context.clone())
.register_task_type::<GreetingTask>()
.register_task_type::<OtherTask>()
.configure_queue("default".into())
.configure_queue("other_queue".into())
.start(futures::future::ready(()))
.await
.unwrap();
let queue = Queue::new(task_store.clone());
queue
.enqueue(GreetingTask {
person: "Rafael".to_string(),
@ -319,23 +393,56 @@ mod tests {
#[tokio::test]
async fn test_worker_pool_stop_after_task_execute() {
#[derive(Clone)]
struct NotifyFinishedContext {
/// Used to notify the task ran
notify_finished: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
}
/// A task that notifies the test that it ran
#[derive(serde::Serialize, serde::Deserialize)]
struct NotifyFinished;
#[async_trait]
impl BackgroundTask for NotifyFinished {
const TASK_NAME: &'static str = "notify_finished";
type AppData = NotifyFinishedContext;
type Error = ();
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), ()> {
// Notify the test that the task ran
match context.notify_finished.lock().await.take() {
None => println!("Cannot notify, already done that!"),
Some(tx) => {
tx.send(()).unwrap();
println!("[{}] Notify finished did it's job!", task.id())
}
};
Ok(())
}
}
let (tx, rx) = tokio::sync::oneshot::channel();
let my_app_context = NotifyFinishedContext {
tx: Arc::new(Mutex::new(Some(tx))),
notify_finished: Arc::new(Mutex::new(Some(tx))),
};
let (join_handle, queue) =
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
.register_task_type::<NotifyFinished>()
.configure_queue("default", 1, RetentionMode::default())
.start(async move {
rx.await.unwrap();
println!("Worker pool got notified to stop");
})
.await
.unwrap();
let memory_store = memory_store();
let join_handle = WorkerPool::new(memory_store.clone(), move || my_app_context.clone())
.register_task_type::<NotifyFinished>()
.configure_queue("default".into())
.start(async move {
rx.await.unwrap();
println!("Worker pool got notified to stop");
})
.await
.unwrap();
let queue = Queue::new(memory_store);
// Notifies the worker pool to stop after the task is executed
queue.enqueue(NotifyFinished).await.unwrap();
@ -347,6 +454,44 @@ mod tests {
#[tokio::test]
async fn test_worker_pool_try_to_run_unknown_task() {
#[derive(Clone)]
struct NotifyUnknownRanContext {
/// Notify that application should stop
should_stop: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
/// Used to mark if the unknown task ran
unknown_task_ran: Arc<AtomicBool>,
}
/// A task that notifies the test that it ran
#[derive(serde::Serialize, serde::Deserialize)]
struct NotifyStopDuringRun;
#[async_trait]
impl BackgroundTask for NotifyStopDuringRun {
const TASK_NAME: &'static str = "notify_finished";
type AppData = NotifyUnknownRanContext;
type Error = ();
async fn run(
&self,
task: CurrentTask,
context: Self::AppData,
) -> Result<(), Self::Error> {
// Notify the test that the task ran
match context.should_stop.lock().await.take() {
None => println!("Cannot notify, already done that!"),
Some(tx) => {
tx.send(()).unwrap();
println!("[{}] Notify finished did it's job!", task.id())
}
};
Ok(())
}
}
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct UnknownTask;
@ -354,46 +499,211 @@ mod tests {
impl BackgroundTask for UnknownTask {
const TASK_NAME: &'static str = "unknown_task";
type AppData = NotifyFinishedContext;
type AppData = NotifyUnknownRanContext;
async fn run(
&self,
task: CurrentTask,
_context: Self::AppData,
) -> Result<(), anyhow::Error> {
type Error = ();
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), ()> {
println!("[{}] Unknown task ran!", task.id());
context.unknown_task_ran.store(true, Ordering::Relaxed);
Ok(())
}
}
let (tx, rx) = tokio::sync::oneshot::channel();
let my_app_context = NotifyFinishedContext {
tx: Arc::new(Mutex::new(Some(tx))),
let my_app_context = NotifyUnknownRanContext {
should_stop: Arc::new(Mutex::new(Some(tx))),
unknown_task_ran: Arc::new(AtomicBool::new(false)),
};
let task_store = memory_store().await;
let task_store = memory_store();
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
.register_task_type::<NotifyFinished>()
.configure_queue("default", 1, RetentionMode::default())
.start(async move {
rx.await.unwrap();
println!("Worker pool got notified to stop");
})
.await
.unwrap();
let join_handle = WorkerPool::new(task_store.clone(), {
let my_app_context = my_app_context.clone();
move || my_app_context.clone()
})
.register_task_type::<NotifyStopDuringRun>()
.configure_queue("default".into())
.start(async move {
rx.await.unwrap();
println!("Worker pool got notified to stop");
})
.await
.unwrap();
let queue = Queue::new(task_store);
// Enqueue a task that is not registered
queue.enqueue(UnknownTask).await.unwrap();
// Notifies the worker pool to stop for this test
queue.enqueue(NotifyFinished).await.unwrap();
queue.enqueue(NotifyStopDuringRun).await.unwrap();
join_handle.await.unwrap();
assert!(
!my_app_context.unknown_task_ran.load(Ordering::Relaxed),
"Unknown task ran but it is not registered in the worker pool!"
);
}
async fn memory_store() -> MemoryTaskStore {
#[tokio::test]
async fn task_can_panic_and_not_affect_worker() {
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct BrokenTask;
#[async_trait]
impl BackgroundTask for BrokenTask {
const TASK_NAME: &'static str = "panic_me";
type AppData = ();
type Error = ();
async fn run(&self, _task: CurrentTask, _context: Self::AppData) -> Result<(), ()> {
panic!("Oh no!");
}
}
let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
let task_store = memory_store();
let worker_pool_finished = WorkerPool::new(task_store.clone(), || ())
.register_task_type::<BrokenTask>()
.configure_queue("default".into())
.start(async move {
should_stop.await.unwrap();
})
.await
.unwrap();
let queue = Queue::new(task_store.clone());
// Enqueue a task that will panic
queue.enqueue(BrokenTask).await.unwrap();
notify_stop_worker_pool.send(()).unwrap();
worker_pool_finished.await.unwrap();
let raw_task = task_store
.tasks
.lock()
.await
.first_entry()
.unwrap()
.remove();
assert_eq!(
serde_json::to_string(&raw_task.error_info.unwrap()).unwrap(),
"{\"error\":\"Task panicked with: Oh no!\"}"
);
}
/// This test will make sure that the worker pool will only stop after all workers are done.
/// We create a KeepAliveTask that will keep running until we notify it to stop.
/// We stop the worker pool and make sure that the KeepAliveTask is still running.
/// Then we notify the KeepAliveTask to stop and make sure that the worker pool stops.
#[tokio::test]
async fn tasks_only_stop_running_when_finished() {
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
enum PingPongGame {
Ping,
Pong,
StopThisNow,
}
#[derive(Clone)]
struct PlayerContext {
/// Used to communicate with the running task
pong_tx: Arc<tokio::sync::mpsc::Sender<PingPongGame>>,
ping_rx: Arc<Mutex<tokio::sync::mpsc::Receiver<PingPongGame>>>,
}
/// Task that will respond to the ping pong game and keep alive as long as we need
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct KeepAliveTask;
#[async_trait]
impl BackgroundTask for KeepAliveTask {
const TASK_NAME: &'static str = "keep_alive_task";
type AppData = PlayerContext;
type Error = ();
async fn run(
&self,
_task: CurrentTask,
context: Self::AppData,
) -> Result<(), Self::Error> {
loop {
let msg = context.ping_rx.lock().await.recv().await.unwrap();
match msg {
PingPongGame::Ping => {
println!("Pong!");
context.pong_tx.send(PingPongGame::Pong).await.unwrap();
}
PingPongGame::Pong => {
context.pong_tx.send(PingPongGame::Ping).await.unwrap();
}
PingPongGame::StopThisNow => {
println!("Got stop signal, stopping the ping pong game now!");
break;
}
}
}
Ok(())
}
}
let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
let (pong_tx, mut pong_rx) = tokio::sync::mpsc::channel(1);
let (ping_tx, ping_rx) = tokio::sync::mpsc::channel(1);
let player_context = PlayerContext {
pong_tx: Arc::new(pong_tx),
ping_rx: Arc::new(Mutex::new(ping_rx)),
};
let task_store = memory_store();
let worker_pool_finished = WorkerPool::new(task_store.clone(), {
let player_context = player_context.clone();
move || player_context.clone()
})
.register_task_type::<KeepAliveTask>()
.configure_queue("default".into())
.start(async move {
should_stop.await.unwrap();
println!("Worker pool got notified to stop");
})
.await
.unwrap();
let queue = Queue::new(task_store);
queue.enqueue(KeepAliveTask).await.unwrap();
// Make sure task is running
println!("Ping!");
ping_tx.send(PingPongGame::Ping).await.unwrap();
assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong);
// Notify to stop the worker pool
notify_stop_worker_pool.send(()).unwrap();
// Make sure task is still running
println!("Ping!");
ping_tx.send(PingPongGame::Ping).await.unwrap();
assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong);
// is_none() means that the worker pool is still waiting for tasks to finish, which is what we want!
assert!(
worker_pool_finished.now_or_never().is_none(),
"Worker pool finished before task stopped!"
);
// Notify to stop the task, which will stop the worker pool
ping_tx.send(PingPongGame::StopThisNow).await.unwrap();
}
fn memory_store() -> MemoryTaskStore {
MemoryTaskStore::default()
}
@ -402,13 +712,15 @@ mod tests {
async fn test_worker_pool_with_pg_store() {
let my_app_context = ApplicationContext::new();
let (join_handle, _queue) =
WorkerPool::new(pg_task_store().await, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>()
.configure_queue(GreetingTask::QUEUE, 1, RetentionMode::RemoveDone)
.start(futures::future::ready(()))
.await
.unwrap();
let join_handle = WorkerPool::new(pg_task_store().await, move || my_app_context.clone())
.register_task_type::<GreetingTask>()
.configure_queue(
QueueConfig::new(<GreetingTask as MyAppTask>::QUEUE)
.retention_mode(RetentionMode::RemoveDone),
)
.start(futures::future::ready(()))
.await
.unwrap();
join_handle.await.unwrap();
}