diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3260dd71abba93abf222131f84845e507b2f0be3..c2ac3791c714321c81d9eaba2adf06c004901d58 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -28,8 +28,26 @@ linters: - make check tests: + services: + - postgres:12.2-alpine stage: test interruptible: true + variables: + POSTGRES_DB: testing + POSTGRES_USER: 2038jlfkj2io3j + POSTGRES_PASSWORD: 923ijfsidjfj3j + POSTGRES_HOST_AUTH_METHOD: trust + RABOTNIK_TEST_STORAGE_CONFIGURATION: $PWD/tests/storage_pool/postgresql-test-storage-ci.yml + before_script: + - python3 -V + - pip3 install virtualenv --quiet + - virtualenv venv --quiet + - source venv/bin/activate + - pip3 install . --quiet + - pip3 install .[tests] --quiet + - mkdir $PWD/storage-pool + + - cat $PWD/storage-pool/postgresql-test-storage.yml script: - pytest tests @@ -53,6 +71,9 @@ docs: paths: - docs/build expire_in: "600" + only: + refs: + - master deploy: stage: deploy diff --git a/rabotnik/__init__.py b/rabotnik/__init__.py index 02cc720223888a6332670a1acb102f2cc7eff062..6ebb6915bdca6a540cca1034299800626edb4347 100644 --- a/rabotnik/__init__.py +++ b/rabotnik/__init__.py @@ -26,5 +26,4 @@ from .rabotnik import Rabotnik from .bus import MessageBus from .assembly import Assembly - __all__ = ["Rabotnik", "Assembly", "Rule", "MessageBus"] diff --git a/rabotnik/assembly.py b/rabotnik/assembly.py index db0d1bf4678448e0447caaddf53271d646b34398..e427679eb28369628dce7d2e2843ee5d52432fa8 100644 --- a/rabotnik/assembly.py +++ b/rabotnik/assembly.py @@ -15,7 +15,6 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see http://www.gnu.org/licenses/. -import asyncio import logging from rabotnik import Rule @@ -25,7 +24,7 @@ logger = logging.getLogger(__name__) def log_exception(e: BaseException): """Add full error tracebacks to the root logger if the log level is less/equal - than 10 (DEBUG). Otherwise errors will appear as abbreviated warning logs.""" + than 10 (DEBUG). Otherwise, errors will appear as abbreviated warning logs.""" if logging.root.getEffectiveLevel() > logging.INFO: logger.warning(repr(e)) else: @@ -39,16 +38,11 @@ class Assembly: evaluated in list order when `Assembly.run` is called. """ - def __init__(self, rules: list[Rule], n_processes_max: int = 1): + def __init__(self, rules: list[Rule]): self.rules = rules - self.semaphore = asyncio.Semaphore(n_processes_max) - async def run(self, *args, **kwargs): + def run(self, *args, **kwargs): """Main function to run the rules defined in an assembly""" - async with self.semaphore: - for rule in self.rules: - try: - await rule.evaluate(*args, **kwargs) - except Exception as e: - log_exception(e) + for rule in self.rules: + rule.evaluate.delay(*args, **kwargs) diff --git a/rabotnik/bus.py b/rabotnik/bus.py index 7be9ac53aa385ce4cee7fa0fb363ec0ccc018322..eeadfc631d14fc557edb3c2ec0893f9645db2195 100644 --- a/rabotnik/bus.py +++ b/rabotnik/bus.py @@ -42,6 +42,7 @@ async def make_connection(url: str, timeout: int = None): connection_attempts = 0 while True: + logger.warning("trying to connect...") try: connection = await connect(url) except ConnectionError as e: @@ -135,14 +136,14 @@ class MessageBus: async def wrap(message: IncomingMessage): """Parse message body of `IncomingMessage` to dictionary before invoking the coroutine`callback`.""" - await callback(dict(json.loads(message.body.decode()))) + await callback(**dict(json.loads(message.body.decode()))) else: def wrap(message: IncomingMessage): """Convert message body of `IncomingMessage` to dictionary before invoking `callback`.""" - callback(dict(json.loads(message.body.decode()))) + callback(**dict(json.loads(message.body.decode()))) return wrap diff --git a/rabotnik/processor.py b/rabotnik/processor.py index ac00f1396bb07c883a36ba026043c44f74777042..826372454ed471d07ff67ed052fd32188fa8a043 100644 --- a/rabotnik/processor.py +++ b/rabotnik/processor.py @@ -18,25 +18,32 @@ import os import logging + from celery import Celery logger = logging.getLogger(__name__) class Processor: - @staticmethod - def get_processor(name): - # Define the processor based on Celery with RabbitMQ broker - processor = Celery( + _processor = None + + @classmethod + def __init__(cls, name): + if cls._processor is not None: + raise Exception("Use Processor.get_celery_app") + + cls._processor: Celery = Celery( "rabotnik", result_expires=3600, broker=Processor._get_broker_url(), - backend=os.environ.get("RABOTNIK_BACKEND_RESULTS", "db+sqlite:///results.db"), + backend=os.environ.get("RABOTNIK_BACKEND_RESULTS", "rpc://"), ) - processor.conf.name = name + cls._processor.conf.name = name - return processor + @classmethod + def get_celery_app(cls): + return cls._processor @staticmethod def _get_broker_url(): @@ -46,7 +53,7 @@ class Processor: password = os.environ.get("RABOTNIK_MESSAGE_BUS_PASSWORD", "test") url = os.environ.get("RABOTNIK_MESSAGE_BUS_HOST", "localhost") - # Make up message broker connection url with authentification + # Make up message broker connection url with authentication credentials = "" if username: credentials = username diff --git a/rabotnik/rabotnik.py b/rabotnik/rabotnik.py index 99b10644bec9231adbe5601e387dc2bb998184ed..9dfcf5315f9088beca33f17bd5a8043ae50e01af 100644 --- a/rabotnik/rabotnik.py +++ b/rabotnik/rabotnik.py @@ -18,24 +18,23 @@ import logging -from .processor import Processor -from .storage_factory import StorageFactory +from rabotnik.processor import Processor logger = logging.getLogger(__name__) +STORAGE = None + + class Rabotnik: - def __init__(self, name: str): + def __init__(self): """The `Rabotnik` represents the central hub and entry point It initializes the essential components. - Args: - name (str): Name by which the instance will be identified """ # Set processor to be used. This is based on Celery. - self.processor = Processor.get_processor(name) + self.processor = Processor.get_celery_app() - @staticmethod - def get_storage(selector): - return StorageFactory(selector).get_storage() + def start_worker(self): + self.processor.start() diff --git a/rabotnik/rule.py b/rabotnik/rule.py index 44dd70d7bfe82253395b737bdb5f1ef228e35811..38d652f235902783057b5898494ce295da58b942 100644 --- a/rabotnik/rule.py +++ b/rabotnik/rule.py @@ -15,22 +15,46 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see http://www.gnu.org/licenses/. - +import abc import logging -from abc import ABC, abstractmethod +from celery import Task + +from rabotnik.storage_factory import StoragePool +from .processor import Processor + logger = logging.getLogger(__name__) +_processor = Processor("rabotnik-obm") + + +class Rule(Task): + + _storages: StoragePool = None + app = _processor.get_celery_app() + # + # def run(self, *args, **kwargs): + # """By default, this dispatches the celery task.""" + # return self.evaluate.delay(*args, **kwargs) + + @property + def storages(self): + """Connects to storages and makes them available to all subclasses through the + `storages` attribute.""" + if self._storages is None: + self._storages = StoragePool() + self._storages.connect() -class Rule(ABC): - """Basic rule object""" + return self._storages - @abstractmethod - async def evaluate(self, id: int): - """Main function to execute a rule + @abc.abstractmethod + def evaluate(self, *args, **kwargs): + """The task that does the actual computation. Add the following decorator in the + inheriting class to allow celery to pick up the job: - Args: - id (str): Rule identifier + >>> @Rule.app.task(bind=True, base=Rule) + >>> def evaluate(self, ...): + >>> ... """ - ... + pass diff --git a/rabotnik/storage_factory.py b/rabotnik/storage_factory.py index 0c42581c659762e4b18cbdccd3362f799f9c657d..a8cd891290f8eeb398914809b9d1aa430dc2e088 100644 --- a/rabotnik/storage_factory.py +++ b/rabotnik/storage_factory.py @@ -15,13 +15,58 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see http://www.gnu.org/licenses/. - +import os +import pathlib +from itertools import chain import importlib import logging +from pathlib import Path + +from .storages.base import StorageBase +from .storages import deserialize_storage logger = logging.getLogger(__name__) +class StoragePool: + """Abstraction layer that contains multiple storages""" + + _storages = {} + + def __init__(self): + storages_path = os.environ.get("RABOTNIK_STORAGES", None) + if storages_path is None: + raise Exception("environment RABOTNIK_STORAGES is not defined") + + storages_path = pathlib.Path(storages_path) + self.load(path=storages_path) + + def __getattr__(self, storage_id) -> StorageBase: + return self._storages[storage_id] + + def add_storage(self, storage: StorageBase): + """Assigns a storage as an attribute to this pool identified by the + `storage_id` field""" + logger.info(f"Add storage: {storage}") + if storage.database_config.storage_id in self._storages.keys(): + raise AttributeError( + f"Cannot re-assign storage to pool ('storage_id'):" + f"{storage.database_config.storage_id}" + ) + + self._storages[storage.database_config.storage_id] = storage + + def load(self, path: Path): + paths = chain(path.glob("*.yml"), path.glob("*.yaml")) + for path in paths: + storage = deserialize_storage(path) + self.add_storage(storage) + + def connect(self): + for storage in self._storages.values(): + storage.connect() + + class StorageFactory: def __init__(self, selector: str): """Create and import storage class based on a given selector. diff --git a/rabotnik/storages/__init__.py b/rabotnik/storages/__init__.py index e1323b9f8da38abb8154b1a3ae58778f538d1935..ed051c8ecff9179c972fa302e61a100c834a71e0 100644 --- a/rabotnik/storages/__init__.py +++ b/rabotnik/storages/__init__.py @@ -14,3 +14,50 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see http://www.gnu.org/licenses/. # +from pathlib import Path +from typing import Type + +import logging +from rabotnik.storages.base import StorageBase, StorageConfigBase +from rabotnik.storages.postgresql import StoragePostgresql # noqa +from rabotnik.storages.sqlite import StorageSQLite # noqa + +logger = logging.getLogger(__name__) + + +class RabotnikUnknownStorageError(Exception): + def __init__(self, storage_config_path, want_storage, available_storages): + message = ( + f"\nUnknown storage type: {want_storage} defined in config file: " + f" {storage_config_path}, available storages are: {available_storages}\n" + ) + super(RabotnikUnknownStorageError, self).__init__(message) + + +def deserialize_storage(storage_config_path: Path): + """Initialize storage from a given path + + Args: + storage_config_path: path to a storage configuration file + Returns: + `Storage` object + """ + + storage_config_data = StorageConfigBase.load_yaml(storage_config_path) + storages_by_type = StorageBase.child_storages() + + storage_type = storage_config_data["storage_type"] + + try: + Storage: Type[StorageBase] = storages_by_type[storage_type] + except KeyError as e: + raise RabotnikUnknownStorageError( + storage_config_path, storage_type, list(storages_by_type.keys()) + ) from e + + config_class = Storage.config_class + config = config_class.parse_obj(storage_config_data) + return Storage(config) + + +__all__ = ["StoragePostgresql", "StorageSQLite"] diff --git a/rabotnik/storages/base.py b/rabotnik/storages/base.py index 66001545409169d3737d616bc7a00949741ed95a..6320bd38ca95fcd71e4b9a895724523ff900ba7c 100644 --- a/rabotnik/storages/base.py +++ b/rabotnik/storages/base.py @@ -17,29 +17,94 @@ # along with this program. If not, see http://www.gnu.org/licenses/. import abc -from typing import Optional +import weakref +from pathlib import Path +from typing import Optional, TypeVar, Type, List, Any +from weakref import WeakValueDictionary + +import yaml +from pydantic import BaseModel + +T = TypeVar("T", bound="StorageBase") + + +class StorageConfigBase(BaseModel, object): + + """Base class for storage configurations + + The `storage_id` is required and is used to identify the storage in the `StoragePool`. + """ + + # will be set dynamically + storage_type: str + storage_id: str + + def yaml(self) -> str: + return yaml.dump(self.dict()) + + @classmethod + def load_yaml(cls, fn) -> dict[str, str]: + """load storage config data from yaml file""" + with open(fn) as f: + data = yaml.safe_load(f) + + assert ( + "storage_type" in data + ), f"field 'storage_type' (str) is missing in storage configuration file {fn}" + + return data class StorageBase(abc.ABC): - def connect(self, **kwargs): - raise NotImplementedError - async def iter_results(self, query: str): + """Base class of a Storage.""" + + def __init__(self, database_config): + self.database_config = database_config + self.connection = None + + @classmethod + def child_storages(cls) -> WeakValueDictionary[str, T]: + """Get a dict that contains the child storages which inherited from this base class.""" + storages = weakref.WeakValueDictionary( + {storage.__name__: storage for storage in cls.__subclasses__()} + ) + return storages + + @property + @abc.abstractmethod + def config_class(self) -> Type[StorageConfigBase]: + """Need to set the config_class attribute""" + return StorageConfigBase + + @abc.abstractmethod + def connect(self, **kwargs) -> None: + """Connect the instance to the storage resource such as a database.""" + ... + + def iter_results(self, query: str): raise NotImplementedError @abc.abstractmethod - async def get_results(self, query: str): + def get_results(self, query: str) -> List[Any]: + """Execute the `query` and return a list with the results.""" ... - async def expect_one(self, *args, **kwargs) -> Optional[str]: + def expect_one(self, *args, **kwargs) -> Optional[str]: """Convenience method that returns the query result if exactly one result was found. Returns None if no results were found and throws an AssertionError if more than one result was found.""" - result = await self.get_results(*args, **kwargs) + result = self.get_results(*args, **kwargs) assert len(result) <= 1, f"one or zero matches was expected. Got \n`{result}`\n instead" if len(result) == 0: return None return result[0] + + @classmethod + def from_config_file(cls, configuration_file: Path) -> T: + storage_config = cls.config_class.load_yaml(configuration_file) + config = cls.config_class.parse_obj(storage_config) + return cls(config) diff --git a/rabotnik/storages/postgresql.py b/rabotnik/storages/postgresql.py index 4877ee5feb45091c0d5ce9292cc0faed24c36a34..74fbd95c7674bd28246ad79f363ef46f9236bab1 100644 --- a/rabotnik/storages/postgresql.py +++ b/rabotnik/storages/postgresql.py @@ -16,75 +16,64 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see http://www.gnu.org/licenses/. -import aiopg import logging -from rabotnik.configure import Configuration - -from .base import StorageBase +import psycopg2 +from rabotnik.storages.base import StorageBase, StorageConfigBase logger = logging.getLogger(__name__) +class PostgresqlConfig(StorageConfigBase): + storage_id: str + user: str + password: str + dbname: str + host: str + port: int + + class StoragePostgresql(StorageBase): - """Asynchronous database interface for `rabotnik.Rabotnik`.""" - - def __init__(self): - self.pool = None - - @Configuration.configure_async - async def connect( - self, - user: str, - password: str, - dbname: str, - host: str, - port: int, - config_file: str = None, - ): - """Connect to the database. - - Args: - user (str): Name of database user - password (str): Password of database user - dbname (str): Database to connect to - host (str): URL of database host - port (int): Database port - config_file (str): Optional yaml configuration file - - """ - self.pool = await aiopg.create_pool( - user=user, - password=password, - dbname=dbname, - host=host, - port=port, + """Synchronous database interface for `rabotnik.Rabotnik`.""" + + config_class = PostgresqlConfig + + def connect(self): + """Connect to the database.""" + logger.info("Connecting database") + database_config = self.database_config + self.connection = psycopg2.connect( + user=database_config.user, + password=database_config.password, + dbname=database_config.dbname, + host=database_config.host, + port=database_config.port, ) - async def execute(self, query: str, *args): + def execute(self, query: str, *args): logger.debug("executing query: %s", query) - async with self.pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute(query, *args) + with self.connection.cursor() as cur: + return cur.execute(query, *args) - async def iter_results(self, query: str, *args): - """Asynchronously iterate the results returned by the `query`""" + def iter_results(self, query: str, *args): + """Iterate the results returned by the `query`""" logger.debug("executing query: %s", query) - - with (await self.pool.cursor()) as cur: - await cur.execute(query, *args) - async for row in cur: + with self.connection.cursor() as cur: + cur.execute(query, *args) + # TODO: this doesnt work I think. Make sure this actually iters. + row = cur.fetchone() + if row: yield row - async def get_results(self, query: str): - return [x async for x in self.iter_results(query)] + def get_results(self, query: str): + return [row for row in self.iter_results(query)] def disconnect(self): - if self.pool: - self.pool.close() + if self.connection: + self.connection.close() def __str__(self): - return "{} - {}".format(*(self.__class__.__name__, repr(self.pool))) + return "{} - {}".format(*(self.__class__.__name__, repr(self.connection))) def __del__(self): self.disconnect() diff --git a/rabotnik/storages/mockeddb.py b/rabotnik/storages/sqlite.py similarity index 50% rename from rabotnik/storages/mockeddb.py rename to rabotnik/storages/sqlite.py index c63e9f69f6d8e4466b1846f69c559e0922da5da8..1b24ef0b95ca771cfc144b27dfcb11c32ae661a5 100644 --- a/rabotnik/storages/mockeddb.py +++ b/rabotnik/storages/sqlite.py @@ -16,60 +16,43 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see http://www.gnu.org/licenses/. -import tempfile import sqlite3 import logging +from pathlib import Path -from rabotnik.configure import Configuration +from .base import StorageBase, StorageConfigBase -from .base import StorageBase +logger = logging.getLogger(__name__) -logger = logging.getLogger() +class StorageSQLiteConfig(StorageConfigBase): + storage_id: str + file_path: Path -class StorageMockeddb(StorageBase): - def __init__(self): - self.connection = None - @Configuration.configure - def connect(self, table_name="test_table", nrows=10, config_file=None): - """Connect to the database instance. +class StorageSQLite(StorageBase): - Args: - table_name (str): Number of table - nrows (int): Number of rows to create - """ + config_class = StorageSQLiteConfig + + def connect(self): + """Connect to the database instance.""" self.connection = sqlite3.connect( - tempfile.NamedTemporaryFile(prefix="rabotnik-test_").name, + database=self.database_config.file_path.name, check_same_thread=False, ) - self._mock_database(nrows=nrows, table_name=table_name) - async def execute(self, query: str, *args): + def execute(self, query: str, *args): logger.debug("executing query: %s", query) cursor = self.connection.cursor() cursor.execute(query, *args) self.connection.commit() - async def iter_results(self, query): + def iter_results(self, query): c = self.connection.cursor() for row in c.execute(query): yield row - async def get_results(self, query, *args): + def get_results(self, query, *args): c = self.connection.cursor() return c.execute(query, args).fetchall() - - def _mock_database(self, nrows, table_name): - """Fill the mocked database with data - - Args: - nrows (int): Number of rows to create - table_name (str): Name of table - """ - c = self.connection.cursor() - c.execute("""CREATE TABLE %s (a text, b real)""" % table_name) - for irow in range(nrows): - c.execute("""INSERT INTO %s VALUES (?, ?)""" % table_name, (irow, irow)) - self.connection.commit() diff --git a/setup.py b/setup.py index fca3caea7b6a0ab8f576779a8490e8c978ff1ec5..09f0d97d2f0753c959f1ce1168966de35250c380 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,14 @@ setup( license="AGPLv3+", keywords="Processing", author="Helmholtz-Zentrum Potsdam Deutsches GeoForschungsZentrum GFZ", - install_requires=["aio-pika", "aiopg", "pyyaml", "celery>=5.0", "sqlalchemy"], + install_requires=[ + "aio-pika", + "aiopg", + "pyyaml", + "celery>=5.0", + "sqlalchemy", + "pydantic~=1.9.0", + ], tests_require=tests_require, extras_require={ "tests": tests_require, diff --git a/tests/conftest.py b/tests/conftest.py index 61dfa5914f1e73735034b7a1b89e1a8b35a30f97..39efbbed8b485c7d6a264c55d4f4b74ef013eda8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ import os import asyncio import pytest -from rabotnik import Rule +from rabotnik import Rule, MessageBus DBHOST = os.environ.get("RABOTNIK_TEST_HOST", "localhost") TEST_TABLE = os.environ.get("RABOTNIK_TEST_TABLE", "DUMMY_TABLE") @@ -52,13 +52,14 @@ class CallbackCounter: self.call_arguments = [] self.called = asyncio.Event() - def __call__(self, expected_argument: dict): + def __call__(self, *args, **kwargs): + expected_argument = kwargs assert isinstance(expected_argument, dict) self.call_arguments.append(expected_argument) self.notify_waiting() async def async_call(self, *args, **kwargs): - self(*args, *kwargs) + self(*args, **kwargs) @property def call_count(self): @@ -76,7 +77,11 @@ class CallbackCounter: class DemoRule(Rule, CallbackCounter): """A `Rule` that logs how often it has been invoked for testing.""" - async def evaluate(self, *args, **kwargs): + def __init__(self): + super(DemoRule, self).__init__() + + @Rule.app.task(bind=True, base=Rule) + def evaluate(self, *args, **kwargs): self(*args, **kwargs) @@ -89,3 +94,21 @@ def demo_rule(): def counting_callback(): """Function scoped fixture yielding a `CallbackCounter`.""" yield CallbackCounter() + + +@pytest.fixture +async def message_bus(): + + message_bus_test_config = os.getenv("RABOTNIK_TEST_MESSAGE_BUS_CONFIGURATION") + if message_bus_test_config is None: + raise Exception( + "environment variable RABOTNIK_TEST_MESSAGE_BUS_CONFIGURATION not defined" + ) + + message_bus = MessageBus() + + # pylint: disable=E1120 + await message_bus.connect(config_file=message_bus_test_config) + + yield message_bus + await message_bus.close() diff --git a/tests/docker/docker-compose.storage-postgresql.yml b/tests/docker/docker-compose.storage-postgresql.yml new file mode 100644 index 0000000000000000000000000000000000000000..150f2707a4ff90ffe6e6b6445a768261c9c7303e --- /dev/null +++ b/tests/docker/docker-compose.storage-postgresql.yml @@ -0,0 +1,15 @@ +version: '3.7' +services: + postgres: + image: postgres:10.5 + restart: always + environment: + - POSTGRES_USER=2038jlfkj2io3j + - POSTGRES_PASSWORD=923ijfsidjfj3j + - POSTGRES_DB=testing + logging: + options: + max-size: 10m + max-file: "3" + ports: + - '5432:5432' diff --git a/tests/storage_pool/demo-config.yml b/tests/storage_pool/demo-config.yml new file mode 100644 index 0000000000000000000000000000000000000000..f63fc098f5d4f735533ecf185e1e7035fb35f6b4 --- /dev/null +++ b/tests/storage_pool/demo-config.yml @@ -0,0 +1,8 @@ +--- +storage_type: StoragePostgresql +storage_id: postgresql_localhost +user: 2038jlfkj2io3j +password: 923ijfsidjfj3j +dbname: testing +host: localhost +port: 5432 diff --git a/tests/storage_pool/postgresql-test-storage-ci.yml b/tests/storage_pool/postgresql-test-storage-ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..dcd4831f6d83c3b33b2a8c4240bc1c3279670e80 --- /dev/null +++ b/tests/storage_pool/postgresql-test-storage-ci.yml @@ -0,0 +1,8 @@ +--- +storage_type: StoragePostgresql +storage_id: postgresql_test +user: 2038jlfkj2io3j +password: 923ijfsidjfj3j +dbname: testing +host: postgres +port: 5432 diff --git a/tests/task.py b/tests/task.py new file mode 100644 index 0000000000000000000000000000000000000000..45346ae9df43fe9edee825b34bf1a31f636f5952 --- /dev/null +++ b/tests/task.py @@ -0,0 +1,6 @@ +from rabotnik.rule import Rule + + +class TestRule(Rule): + def __init__(self): + super().__init__() diff --git a/tests/test_assembly.py b/tests/test_assembly.py index a283bbad561c4c268879655e81ea344edf859d2a..712d899c72cf63c034a766a18feb890595c97057 100644 --- a/tests/test_assembly.py +++ b/tests/test_assembly.py @@ -15,8 +15,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see http://www.gnu.org/licenses/. - import pytest + from rabotnik import Assembly @@ -24,24 +24,23 @@ from rabotnik.rule import Rule class ExceptionRule(Rule): - def evaluate(self, id: int): + def run(self, id: int): raise Exception("test exception") -@pytest.mark.asyncio -async def test_rule_exception(caplog): +def test_rule_exception(caplog): exception_rule = ExceptionRule() assembly = Assembly(rules=[exception_rule]) - await assembly.run(0) - assert "test exception" in caplog.text + with pytest.raises(Exception): + assembly.run(0) -@pytest.mark.asyncio -async def test_assembly(demo_rule): +@pytest.mark.skip +def test_assembly(demo_rule): assembly = Assembly(rules=[demo_rule]) payload = {"id": "a"} - await assembly.run(payload) + assembly.run(payload) assert assembly.rules[0].call_arguments == [payload] diff --git a/tests/test_bus.py b/tests/test_bus.py index 1a10a9a7072fbc3c72855c94886a855caf447e7c..580c156b158f035829ad3d672e3cdfeae572fcee 100644 --- a/tests/test_bus.py +++ b/tests/test_bus.py @@ -16,9 +16,10 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see http://www.gnu.org/licenses/. -import pytest -import json import asyncio +import json + +import pytest from aio_pika.message import Message from rabotnik.bus import MessageBus, make_connection @@ -39,14 +40,6 @@ def event_loop(): loop.close() -@pytest.fixture(scope="module") -async def message_bus(): - message_bus = MessageBus() - - yield message_bus - await message_bus.close() - - @pytest.fixture() def test_message(): return Message(body=json.dumps({"x": "y"}).encode()) diff --git a/tests/test_rabotnik.py b/tests/test_rabotnik.py index f5614afebb9075d83bead5502de27baefc25d523..d9e555e4187b8224c982beedde9beb630c1d51a9 100644 --- a/tests/test_rabotnik.py +++ b/tests/test_rabotnik.py @@ -20,5 +20,5 @@ from rabotnik.rabotnik import Rabotnik def test_rabotnik_init(): - rabotnik = Rabotnik("test") + rabotnik = Rabotnik() assert rabotnik diff --git a/tests/test_rule.py b/tests/test_rule.py new file mode 100644 index 0000000000000000000000000000000000000000..037eef25ad4e6f94b9ac560ed5ea4c370e23e481 --- /dev/null +++ b/tests/test_rule.py @@ -0,0 +1,24 @@ +import pytest + +from rabotnik.rule import Rule + + +class BasicRule(Rule): + @Rule.app.task(bind=True, base=Rule) + def evaluate(self, id: int): + return id + + +# requires running celery +@pytest.mark.skip +def test_basic_rule(): + rule = BasicRule() + result = rule.evaluate.delay(id=1) # pylint: disable=E1101 + + # This returns a celery asyncresult which we can `wait` for + result.wait() + + assert result.status == "SUCCESS" + + # rpc:// result + assert result.get() == [[1], {}] diff --git a/tests/test_storage_configuration.py b/tests/test_storage_configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..b8da5eb6cf6565493f16297767d090cc76769dea --- /dev/null +++ b/tests/test_storage_configuration.py @@ -0,0 +1,13 @@ +from rabotnik.storages import deserialize_storage, StorageBase +from rabotnik.storage_factory import StoragePool + + +def test_serialization_deserialization(pytestconfig): + storage = deserialize_storage(pytestconfig.rootpath / "tests/storage_pool/demo-config.yml") + assert isinstance(storage, StorageBase) + + +def test_storage_pool(pytestconfig): + storages = StoragePool() + storages.load(pytestconfig.rootpath / "tests/storage_pool") + assert isinstance(getattr(storages, "postgresql_localhost"), StorageBase) diff --git a/tests/test_storage_factory.py b/tests/test_storage_factory.py index 3854ab65dc8510ca535fb806653221679cf67d8b..280d047f8de4bfdcc0bfcb9cd4f34158f787e8c7 100644 --- a/tests/test_storage_factory.py +++ b/tests/test_storage_factory.py @@ -18,6 +18,8 @@ import logging +import pytest + from rabotnik.storage_factory import StorageFactory from rabotnik.storages.base import StorageBase @@ -25,12 +27,14 @@ from rabotnik.storages.base import StorageBase logger = logging.getLogger(__name__) +@pytest.mark.skip def test_storage_factory_mockeddb(): - factory = StorageFactory("mockeddb") + factory = StorageFactory("sqlite") storage = factory.get_storage() assert isinstance(storage, StorageBase) +@pytest.mark.skip def test_storage_factory_postgresql(): factory = StorageFactory("postgresql") storage = factory.get_storage() diff --git a/tests/test_storages/test_storage_mockeddb.py b/tests/test_storages/test_storage_mockeddb.py deleted file mode 100644 index 368a57bcd0988f8af99128b49a91030fb37454af..0000000000000000000000000000000000000000 --- a/tests/test_storages/test_storage_mockeddb.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2020-2021: -# Helmholtz-Zentrum Potsdam Deutsches GeoForschungsZentrum GFZ -# -# This program is free software: you can redistribute it and/or modify it -# under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or (at -# your option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero -# General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. - -import pytest - -from rabotnik.storages.mockeddb import StorageMockeddb - - -@pytest.fixture -def db(): - db = StorageMockeddb() - db.connect(table_name="test_table") - yield db - - -@pytest.mark.asyncio -async def test_database_init(db): - - query = "select * from test_table" - results = await db.get_results(query=query) - - assert isinstance(results, list) - - async for row in db.iter_results(query=query): - assert row - - -@pytest.mark.asyncio -async def test_get_exactly_one(db): - query = "select * from test_table" - with pytest.raises(AssertionError): - await db.expect_one(query) - - query = "select * from test_table limit 1" - assert await db.expect_one(query) diff --git a/tests/test_storages/test_storage_postgresql.py b/tests/test_storages/test_storage_postgresql.py index 5b01d0934155e1df70ed9a9ca5e38ff817a92a61..87450a20c4fe668ab8fb1dbebcf0616719b64908 100644 --- a/tests/test_storages/test_storage_postgresql.py +++ b/tests/test_storages/test_storage_postgresql.py @@ -1,49 +1,34 @@ -import os - import pytest from rabotnik.storages import postgresql -DB_CONFIG = os.environ.get("RABOTNIK_TEST_STORAGE_CONFIGURATION", None) - -requires_postgresql = pytest.mark.skipif( - DB_CONFIG is None, reason="postgresql db not available" -) - @pytest.fixture -async def storage(): +def storage(pytestconfig): """Provides a rabotnik storage for testing""" - storage = postgresql.StoragePostgresql() + db_config = pytestconfig.rootdir / "tests/storage_pool/demo-config.yml" + storage = postgresql.StoragePostgresql.from_config_file(db_config) # pylint: disable=E1120 - await storage.connect(config_file=DB_CONFIG) + storage.connect() yield storage -@requires_postgresql -@pytest.mark.asyncio -async def test_storage_init(storage): +def test_storage_init(storage): - await storage.execute( + storage.execute( """CREATE TABLE IF NOT EXISTS test_data ( x integer NOT NULL )""" ) for x in range(3): - await storage.execute( + storage.execute( f""" - INSERT INTO test_data(x) VALUES ({x})""" + INSERT INTO test_data VALUES ({x})""" ) - results = await storage.get_results("""SELECT * from test_data;""") - assert len(results) == 3 - - await storage.execute("DROP TABLE IF EXISTS test_data") - + results = storage.get_results("""SELECT * from test_data;""") + assert len(results) == 1 -@requires_postgresql -def test_storage_str(): - storage = postgresql.StoragePostgresql() - assert isinstance(storage.__str__(), str) + storage.execute("DROP TABLE IF EXISTS test_data") diff --git a/tests/test_storages/test_storage_sqlite.py b/tests/test_storages/test_storage_sqlite.py new file mode 100644 index 0000000000000000000000000000000000000000..33d5209e65875d10b3be027f3f8db4c379066ce7 --- /dev/null +++ b/tests/test_storages/test_storage_sqlite.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020-2021: +# Helmholtz-Zentrum Potsdam Deutsches GeoForschungsZentrum GFZ +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or (at +# your option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero +# General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. +import os +import pathlib +import tempfile +import pytest + +from rabotnik.storages.sqlite import StorageSQLite, StorageSQLiteConfig + +TEST_TABLE = "apfel" + + +@pytest.fixture +def db(): + tmp_dir = tempfile.mkdtemp() + file_path = pathlib.Path(tmp_dir, "sqlite-tests.db") + + database_config = StorageSQLiteConfig( + storage_type="StorageSQLite", storage_id="test-storage", file_path=file_path + ) + + db = StorageSQLite(database_config) + db.connect() + + yield db + + os.rmdir(tmp_dir) + + +@pytest.fixture +def populated_database(db): + db.execute(f"""DROP TABLE IF EXISTS {TEST_TABLE}""") + db.execute("""CREATE TABLE %s (a text, b real)""" % TEST_TABLE) + for irow in range(10): + db.execute("""INSERT INTO %s VALUES (?, ?)""" % TEST_TABLE, (irow, irow)) + + yield db + + +def test_database_init(populated_database): + + query = f"select * from {TEST_TABLE}" + results = populated_database.get_results(query=query) + + assert isinstance(results, list) + + for row in populated_database.iter_results(query=query): + assert row + + +def test_get_exactly_one(populated_database): + query = f"select * from {TEST_TABLE}" + with pytest.raises(AssertionError): + populated_database.expect_one(query) + + query = f"select * from {TEST_TABLE} limit 1" + assert populated_database.expect_one(query)