Source code for workflow.states

"""
Action classes to be called when receiving specific messages.

To add an action for a specific queue, add a StateAction subclass named after
the queue: capitalized (first letter upper-case, the rest lower-case, e.g.
REDUCTION.REQUEST -> Reduction_request), with periods replaced by underscores.
"""

import importlib
import inspect
import json
import logging
import re

from .database import transactions
from .settings import CATALOG_DATA_READY, POSTPROCESS_ERROR, REDUCTION_CATALOG_DATA_READY, REDUCTION_DATA_READY
from .state_utilities import logged_action


class StateAction:
    """
    Base class for processing messages

    An instance is callable with ``(headers, message)``. When created with
    ``use_db_task=True`` it first asks the database for a task definition
    and, if one is found, runs it; otherwise it dispatches to a class in
    this module whose name is derived from the destination queue.
    """

    _send_connection = None

    # Dotted path "module_name.ClassName" (any number of package levels).
    # Every segment must be non-empty so that leading, trailing or doubled
    # dots are rejected before the path reaches importlib, where they would
    # be treated as a relative import.
    _CLASS_PATH_RE = re.compile(r"[a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)+")

    def __init__(self, connection=None, use_db_task=False):
        """
        Initialization

        :param connection: AMQ connection to use to send messages
        :param use_db_task: if True, a task definition will be looked for
            in the DB when executing the action
        """
        self._user_db_task = use_db_task
        self._send_connection = connection

    def _call_default_task(self, headers, message):
        """
        Find a default task for the given message header

        The queue name is taken from ``headers["destination"]``, stripped of
        the ``/queue/`` prefix, periods are replaced by underscores and the
        result is capitalized. If a module-level name matches, it is
        instantiated with the current connection and called with the message.
        Nothing happens when no such name exists.

        :param headers: message headers
        :param message: JSON-encoded message content
        """
        # Convert the message queue name into a class name
        queue_name = headers["destination"].replace("/queue/", "")
        class_name = queue_name.replace(".", "_").capitalize()

        # Find a custom action for this message
        module_names = globals()
        if class_name in module_names:
            action_cls = module_names[class_name]
            action_cls(connection=self._send_connection)(headers, message)

    def _get_class_from_path(self, class_path: str):
        """
        Returns the class given by the class path

        :param class_path: the class, e.g. "module_name.ClassName"
        :return: class or None if the path is malformed, the module cannot
            be imported, or the attribute is missing or not a class
        """
        # check that the string is in the format "package_name.module_name.class_name"
        if not isinstance(class_path, str) or not self._CLASS_PATH_RE.fullmatch(class_path):
            logging.error("task_class %s does not match pattern module_name.ClassName", class_path)
            return None

        module_name, class_name = class_path.rsplit(".", 1)

        # try importing the class
        try:
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
            if not inspect.isclass(cls):
                raise ValueError
            return cls
        except (ImportError, AttributeError, ValueError):
            # ImportError also covers ModuleNotFoundError and modules that
            # exist but fail while being imported.
            logging.error("task_class %s cannot be imported", class_path)
            return None

    def _call_db_task(self, task_data, headers, message):
        """
        Execute a task definition retrieved from the database

        The definition may contain a ``task_class`` entry (dotted path of an
        action class to instantiate and call) and a ``task_queues`` entry
        (queue names the message is forwarded to).

        :param task_data: JSON-encoded task definition
        :param headers: message headers
        :param message: JSON-encoded message content
        """
        task_def = json.loads(task_data)
        if not isinstance(task_def, dict):
            logging.error("Task definition is not a JSON object: %r", task_def)
            return

        task_class = task_def.get("task_class")
        if isinstance(task_class, str) and task_class.strip():
            action_cls = self._get_class_from_path(task_class)
            if action_cls is not None:
                try:
                    action_cls(connection=self._send_connection)(headers, message)
                except Exception:
                    # A failing custom task is logged but must not prevent
                    # the message from being forwarded to the task queues.
                    logging.exception("Task [%s] failed:", headers["destination"])
        elif task_class is not None and not isinstance(task_class, str):
            logging.error("task_class %r is not a string", task_class)

        # A missing or null "task_queues" entry means there is nothing to forward.
        for item in task_def.get("task_queues") or ():
            self.send(destination="/queue/%s" % item, message=message, persistent="true")

    @logged_action
    def __call__(self, headers, message):
        """
        Called to process a message

        :param headers: message headers
        :param message: JSON-encoded message content
        """
        # Find task definition in DB if available
        if self._user_db_task:
            task_data = transactions.get_task(headers, message)
            if task_data is not None:
                self._call_db_task(task_data, headers, message)
                return

        # If we made it here we need to use default tasks
        self._call_default_task(headers, message)

    def send(self, destination, message, persistent="true"):
        """
        Send a message to a queue

        A status entry is recorded for every call. Without a connection
        nothing is sent; instead the entry is recorded against the
        POSTPROCESS_ERROR queue with an ``error`` field added to the message.

        :param destination: name of the queue
        :param message: message content
        :param persistent: value forwarded as the ``persistent`` argument of
            the connection's ``send``
        """
        logging.debug("Send: %s", destination)
        if self._send_connection is not None:
            self._send_connection.send(destination, message, persistent=persistent)
            headers = {"destination": destination, "message-id": ""}
        else:
            logging.error("No AMQ connection to send to %s", destination)
            headers = {"destination": "/queue/%s" % POSTPROCESS_ERROR, "message-id": ""}
            data_dict = json.loads(message)
            data_dict["error"] = "No AMQ connection: Could not send to %s" % destination
            message = json.dumps(data_dict)
        transactions.add_status_entry(headers, message)
class Postprocess_data_ready(StateAction):
    """
    Default action for POSTPROCESS.DATA_READY messages

    Forwards the incoming message to the catalog and reduction
    "data ready" queues.
    """

    def __call__(self, headers, message):
        """
        Called to process a message

        :param headers: message headers (not used by this action)
        :param message: JSON-encoded message content, forwarded unchanged
        """
        # Tell the workers to start processing: catalog first, then reduction
        for queue_name in (CATALOG_DATA_READY, REDUCTION_DATA_READY):
            target = "/queue/%s" % queue_name
            self.send(destination=target, message=message, persistent="true")
class Reduction_request(StateAction):
    """
    Default action for REDUCTION.REQUEST messages

    Forwards the incoming message to the reduction "data ready" queue.
    """

    def __call__(self, headers, message):
        """
        Called to process a message

        :param headers: message headers (not used by this action)
        :param message: JSON-encoded message content, forwarded unchanged
        """
        # Tell the workers to start reduction
        target = "/queue/%s" % REDUCTION_DATA_READY
        self.send(destination=target, message=message, persistent="true")
class Catalog_request(StateAction):
    """
    Default action for CATALOG.REQUEST messages

    Forwards the incoming message to the catalog "data ready" queue.
    """

    def __call__(self, headers, message):
        """
        Called to process a message

        :param headers: message headers (not used by this action)
        :param message: JSON-encoded message content, forwarded unchanged
        """
        # Tell the workers to start cataloging
        target = "/queue/%s" % CATALOG_DATA_READY
        self.send(destination=target, message=message, persistent="true")
class Reduction_complete(StateAction):
    """
    Default action for REDUCTION.COMPLETE messages

    Forwards the incoming message to the reduction-catalog "data ready" queue.
    """

    def __call__(self, headers, message):
        """
        Called to process a message

        :param headers: message headers (not used by this action)
        :param message: JSON-encoded message content, forwarded unchanged
        """
        # Tell the workers to catalog the reduction output
        target = "/queue/%s" % REDUCTION_CATALOG_DATA_READY
        self.send(destination=target, message=message, persistent="true")