Generic Callbacks in Rust

In a recent project, we built an ingestion pipeline in Rust that subscribes to an SQS queue and runs asynchronous tasks in response to the messages it receives. Our goal was to use Rust generics so that the subscriber owns both the payload schema and the pipeline implementation details. This keeps the SQS subscription logic separate from the subscriber's responsibilities, which gives a cleaner separation of concerns. It also lets us test each component independently and in isolation.
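
Here is a minimal, std-only sketch of the testing benefit. It assumes a hypothetical Order payload and a hypothetical handle_order handler, neither of which comes from the original pipeline. The handler is just an async function returning Result<(), Box<dyn Error + Send + Sync>>, so a unit test can drive it directly with no queue and no AWS client. The small block_on helper only keeps the example dependency-free. A real test would normally use an async runtime's harness such as #[tokio::test].

```rust
use std::error::Error;
use std::future::Future;
use std::pin::pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

type BoxError = Box<dyn Error + Send + Sync>;

// Hypothetical payload type: the subscriber, not the queue code, owns its schema.
#[derive(Debug)]
pub struct Order {
    pub id: u32,
    pub quantity: u32,
}

// Hypothetical handler: business logic only, with no SQS or serde types involved.
pub async fn handle_order(order: Order) -> Result<(), BoxError> {
    if order.quantity == 0 {
        return Err(format!("order {} has zero quantity", order.id).into());
    }
    Ok(())
}

// A waker that does nothing, so block_on can poll futures without a runtime.
unsafe fn noop_clone(_: *const ()) -> RawWaker {
    noop_raw_waker()
}
unsafe fn noop(_: *const ()) {}
static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);

fn noop_raw_waker() -> RawWaker {
    RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
}

// Busy-polls until the future is ready. This is sufficient here only because
// handle_order never returns Poll::Pending.
fn block_on<F: Future>(fut: F) -> F::Output {
    // SAFETY: the vtable functions ignore the data pointer and never dereference it.
    let waker = unsafe { Waker::from_raw(noop_raw_waker()) };
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
            return output;
        }
    }
}
```

Because handle_order has exactly the shape a generic callback expects, the same function can later be passed to the queue-facing code unchanged.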

Here’s what we came up with:

use std::future::Future;

use aws_sdk_sqs::error::SdkError;
use aws_sdk_sqs::operation::delete_message::DeleteMessageError;
use aws_sdk_sqs::operation::receive_message::ReceiveMessageError;
use aws_sdk_sqs::operation::send_message::SendMessageError;
use aws_sdk_sqs::Client;
use serde::de::DeserializeOwned;
use serde_json::Error as JsonError;
use thiserror::Error;

#[derive(Debug, Error)]
pub enum QueueError {
  // The SDK's send() returns SdkError<OperationError>, so wrap that type.
  #[error("Error receiving message from SQS queue: {0}")]
  ReceiveError(#[from] SdkError<ReceiveMessageError>),

  #[error("Error deleting message from SQS queue: {0}")]
  DeleteError(#[from] SdkError<DeleteMessageError>),

  #[error("Error sending message to SQS queue: {0}")]
  SendError(#[from] SdkError<SendMessageError>),

  // #[from] lets `?` convert serde_json errors (used by publish below).
  #[error("Error serializing/deserializing JSON: {0}")]
  JsonSerializationError(#[from] JsonError),

  #[error("Callback failed: {0}")]
  CallbackError(#[source] Box<dyn std::error::Error + Send + Sync>),
}

pub async fn subscribe<T, F, Fut>(client: &Client, queue: &str, callback: F) -> Result<(), QueueError>
where
  T: DeserializeOwned,
  F: Fn(T) -> Fut + Send + Sync,
  Fut: Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send,
{
  // Poll forever; the function only returns when an error occurs.
  loop {
    let response = match client.receive_message().queue_url(queue).send().await {
      Ok(response) => response,
      Err(e) => {
        log::error!("Error receiving messages from SQS queue: {}, error: {:?}", queue, e);
        return Err(QueueError::ReceiveError(e));
      }
    };

    for message in response.messages.unwrap_or_default() {
      log::info!("Received SQS Message: {:?}, from Queue: {}", message, queue);

      // A message without a body or receipt handle can be neither decoded nor deleted.
      let (Some(body), Some(receipt_handle)) = (message.body, message.receipt_handle) else {
        log::warn!("Skipping SQS message without a body or receipt handle on queue: {}", queue);
        continue;
      };

      // Decode the JSON body into the subscriber-defined payload type.
      let payload = match serde_json::from_str::<T>(&body) {
        Ok(payload) => payload,
        Err(e) => {
          log::error!("Error decoding message from SQS queue: {}, error: {:?}", queue, e);
          return Err(QueueError::JsonSerializationError(e));
        }
      };

      // Hand the payload to the subscriber and stop on the first failure.
      if let Err(e) = callback(payload).await {
        log::error!("Callback failed for SQS Queue: {}, error: {:?}", queue, e);
        return Err(QueueError::CallbackError(e));
      }
      log::info!("Callback completed successfully for SQS queue: {}", queue);

      // Delete only after the callback succeeds; an undeleted message is redelivered
      // once its visibility timeout expires.
      if let Err(e) = client
        .delete_message()
        .queue_url(queue)
        .receipt_handle(receipt_handle)
        .send()
        .await
      {
        log::error!("Failed to delete message from queue: {}, error: {:?}", queue, e);
        return Err(QueueError::DeleteError(e));
      }
      log::info!("Message deleted from SQS queue: {}", queue);
    }
  }
}

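The subscribe function is generic over three type parameters:

- T, the payload type. It must implement DeserializeOwned, so each JSON message body can be deserialized into an owned value that is passed to the callback by value.
- F, the callback. It is called once per message with the decoded T and returns a future. The Send + Sync bounds allow it to be shared safely when subscribe runs on a multi-threaded async runtime.
- Fut, the future returned by the callback. It must be Send and resolve to Result<(), Box<dyn std::error::Error + Send + Sync>>, so a subscriber can report any error type. subscribe wraps that error in QueueError::CallbackError.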
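
To make the T, F and Fut bounds concrete without the AWS SDK or serde, here is a minimal std-only sketch with the same generic shape. Everything in it is an illustrative assumption: FromBody is a hypothetical stand-in for DeserializeOwned, Order and drain are hypothetical, an in-memory VecDeque plays the role of the queue, and block_on is a toy executor. The Send + Sync bounds are dropped because the sketch is single-threaded. As in subscribe, a message is removed (the "delete") only after its callback succeeds, and processing stops at the first error.

```rust
use std::collections::VecDeque;
use std::error::Error;
use std::future::Future;
use std::pin::pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

type BoxError = Box<dyn Error + Send + Sync>;

// Hypothetical stand-in for serde's DeserializeOwned: build an owned T from a raw body.
trait FromBody: Sized {
    fn from_body(body: &str) -> Result<Self, BoxError>;
}

// Hypothetical payload type chosen by the subscriber.
#[derive(Debug, PartialEq)]
struct Order {
    id: u32,
}

impl FromBody for Order {
    fn from_body(body: &str) -> Result<Self, BoxError> {
        Ok(Order { id: body.trim().parse()? })
    }
}

// Same generic shape as subscribe: payload T, callback F, and the future Fut it returns.
async fn drain<T, F, Fut>(queue: &mut VecDeque<String>, callback: F) -> Result<usize, BoxError>
where
    T: FromBody,
    F: Fn(T) -> Fut,
    Fut: Future<Output = Result<(), BoxError>>,
{
    let mut handled = 0;
    loop {
        // Peek rather than pop, so a message stays queued until it has been handled.
        let payload = match queue.front() {
            Some(body) => T::from_body(body)?,
            None => break,
        };
        callback(payload).await?;
        queue.pop_front(); // the "delete" step, reached only after the callback succeeded
        handled += 1;
    }
    Ok(handled)
}

// Toy executor: a no-op waker plus busy polling. This is fine here because none of
// these futures ever returns Poll::Pending. Use a real runtime in practice.
unsafe fn noop_clone(_: *const ()) -> RawWaker {
    noop_raw_waker()
}
unsafe fn noop(_: *const ()) {}
static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);

fn noop_raw_waker() -> RawWaker {
    RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
}

fn block_on<F: Future>(fut: F) -> F::Output {
    // SAFETY: the vtable functions ignore the data pointer and never dereference it.
    let waker = unsafe { Waker::from_raw(noop_raw_waker()) };
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
            return output;
        }
    }
}
```

Here the turbofish drain::<Order, _, _> names T explicitly. When the callback is a named async fn such as async fn handle(order: Order), the compiler can infer T from its parameter type instead.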

We took a similar approach for the publisher:

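use aws_sdk_sqs::Client;
use serde::Serialize;

// Assumes QueueError (defined alongside subscribe) is in scope.

// No bound on the struct itself: derive(Serialize) adds `T: Serialize` to its impl.
#[derive(Debug, Serialize)]
pub struct Message<T> {
  pub body: T,
}

pub async fn publish<T>(client: &Client, queue: &str, message: &Message<T>) -> Result<(), QueueError>
where
  T: Serialize,
{
  // Only the body is serialized. `?` needs From<serde_json::Error> for QueueError.
  let body = serde_json::to_string(&message.body)?;
  client
    .send_message()
    .queue_url(queue)
    .message_body(body)
    .send()
    .await?; // needs From<SdkError<SendMessageError>> for QueueError
  Ok(())
}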

The publish function is generic over a single type, T, which must implement Serialize. The payload travels in the body field of the Message wrapper. Only that body is serialized to a JSON string, which is then passed to the SQS SendMessage API through send_message.
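
As a minimal std-only sketch of the same idea, the following shows one generic function accepting different payload types while sending only the serialized body. ToBody is a hypothetical stand-in for serde::Serialize. OrderCreated and Heartbeat are hypothetical payloads. A VecDeque stands in for the SQS queue. The function is kept synchronous for brevity.

```rust
use std::collections::VecDeque;

// Hypothetical stand-in for serde::Serialize, so the sketch needs only std.
trait ToBody {
    fn to_body(&self) -> String;
}

// Mirrors Message<T> from the post: a wrapper whose `body` is the only part that is sent.
struct Message<T> {
    body: T,
}

// Two hypothetical payload types, to show one generic function serving both.
struct OrderCreated {
    id: u32,
}

struct Heartbeat {
    seq: u64,
}

impl ToBody for OrderCreated {
    fn to_body(&self) -> String {
        // Hand-written JSON; serde_json::to_string on a derived Serialize gives the same shape.
        format!("{{\"id\":{}}}", self.id)
    }
}

impl ToBody for Heartbeat {
    fn to_body(&self) -> String {
        format!("{{\"seq\":{}}}", self.seq)
    }
}

// Generic over T like publish. push_back plays the role of send_message.
fn publish<T: ToBody>(queue: &mut VecDeque<String>, message: &Message<T>) {
    queue.push_back(message.body.to_body());
}
```

Each call is monomorphized for its concrete T, so supporting a new message type only requires implementing the serialization trait for it. The publishing code does not change.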
