diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..2b01626 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,32 @@ +name: CI + +on: + push: + branches: + - main + pull_request: + branches: + - '*' + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.12",] + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + - name: Install + run: python -m pip install --upgrade pip && pip install .[dev] -c constraints + - name: Lint + run: ruff check --output-format=github ./src ./tests + - name: Test + run: pytest --rootdir= ./tests --doctest-modules --junitxml=junit/test-results.xml + + \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cd2572f --- /dev/null +++ b/.gitignore @@ -0,0 +1,83 @@ +## macOS +.DS_Store +.AppleDouble +.LSOverride +._* +.fseventsd +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +## Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +*.manifest +*.spec +pip-log.txt +pip-delete-this-directory.txt +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ +*.mo +*.pot +.ipynb_checkpoints +.python-version +.env +.venv +.dev_venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +.ropeproject +/site +.mypy_cache/ +.dmypy.json +dmypy.json +.venv/ +[Bb]in +[Ii]nclude +[Ll]ib +[Ll]ib64 +[Ll]ocal +[Ss]cripts +pyvenv.cfg +pip-selfcheck.json + +## Editors +.vscode/ +.idea/ +local_cache/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..6d1c2a8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Miquido + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..5027b93
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,78 @@
+SHELL := sh
+.ONESHELL:
+.SHELLFLAGS := -eu -c
+.DELETE_ON_ERROR:
+
+SOURCES_PATH := src
+TESTS_PATH := tests
+
+# load environment config from .env if able
+-include .env
+
+ifndef PYTHON_ALIAS
+	PYTHON_ALIAS := python
+endif
+
+ifndef INSTALL_OPTIONS
+	INSTALL_OPTIONS := .[dev]
+endif
+
+ifndef UV_VERSION
+	UV_VERSION := 0.4.22
+endif
+
+.PHONY: venv sync lock update format lint test release
+
+# Set up a virtual environment for local development.
+venv:
+	@echo '# Preparing development environment...'
+	@echo '...preparing git hooks...'
+	@cp -n ./config/pre-push ./.git/hooks/pre-push || :
+	@echo '...installing uv...'
+	@curl -LsSf https://github.com/astral-sh/uv/releases/download/$(UV_VERSION)/uv-installer.sh | sh
+	@echo '...preparing venv...'
+	@$(PYTHON_ALIAS) -m venv .venv --prompt="VENV[DEV]" --clear --upgrade-deps
+	@. ./.venv/bin/activate && pip install --upgrade pip && uv pip install --editable $(INSTALL_OPTIONS) --constraint constraints
+	@echo '...development environment ready! Activate venv using `. ./.venv/bin/activate`.'
+
+# Sync environment with uv based on constraints
+sync:
+	@echo '# Synchronizing dependencies...'
+	@$(if $(findstring $(UV_VERSION), $(shell uv --version | head -n1 | cut -d" " -f2)), , @echo '...updating uv...' && curl -LsSf https://github.com/astral-sh/uv/releases/download/$(UV_VERSION)/uv-installer.sh | sh)
+	@uv pip install --editable $(INSTALL_OPTIONS) --constraint constraints
+	@echo '...finished!'
+
+# Generate a set of locked dependencies from pyproject.toml
+lock:
+	@echo '# Locking dependencies...'
+	@uv pip compile pyproject.toml -o constraints --all-extras
+	@echo '...finished!'
+
+# Update and lock dependencies from pyproject.toml
+update:
+	@echo '# Updating dependencies...'
+	@$(if $(shell printf '%s\n%s\n' "$(UV_VERSION)" "$$(uv --version | head -n1 | cut -d' ' -f2)" | sort -V | head -n1 | grep -q "$(UV_VERSION)"), , @echo '...updating uv...' && curl -LsSf https://github.com/astral-sh/uv/releases/download/$(UV_VERSION)/uv-installer.sh | sh)
+	# @$(if $(findstring $(UV_VERSION), $(shell uv --version | head -n1 | cut -d" " -f2)), , @echo '...updating uv...' && curl -LsSf https://github.com/astral-sh/uv/releases/download/$(UV_VERSION)/uv-installer.sh | sh)
+	@uv --no-cache pip compile pyproject.toml -o constraints --all-extras --upgrade
+	@uv pip install --editable $(INSTALL_OPTIONS) --constraint constraints
+	@echo '...finished!'
+
+# Run formatter.
+format:
+	@ruff check --quiet --fix $(SOURCES_PATH) $(TESTS_PATH)
+	@ruff format --quiet $(SOURCES_PATH) $(TESTS_PATH)
+
+# Run linters and code checks.
+lint:
+	@bandit -r $(SOURCES_PATH)
+	@ruff check $(SOURCES_PATH) $(TESTS_PATH)
+	@pyright --project ./
+
+# Run the test suite.
+test:
+	@$(PYTHON_ALIAS) -B -m pytest -vv --cov=$(SOURCES_PATH) --rootdir=$(TESTS_PATH)
+
+release: lint test
+	@echo '# Preparing release...'
+	@python -m build && python -m twine upload --skip-existing dist/*
+	@echo '...finished!'
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..34e943f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,41 @@
+# πŸš— haiway πŸš• 🚚 πŸš™
+
+haiway is a framework that helps you build a better project codebase by leveraging concepts of structured concurrency and functional programming.
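+
+A minimal sketch of the core idea, based on the `ctx` and `Structure` APIs defined in this repository (the `ApiClient` state and its field are hypothetical, and the sketch assumes `Structure` supports typed fields with default values):
+
+```python
+import asyncio
+
+from haiway import Structure, ctx
+
+
+class ApiClient(Structure):
+    base_url: str = "https://api.example.com"
+
+
+async def fetch() -> None:
+    # state is resolved from the nearest scope context
+    client: ApiClient = ctx.state(ApiClient)
+    ctx.log_info("fetching from %s", client.base_url)
+
+
+async def main() -> None:
+    # the scope propagates state and collects metrics for everything inside
+    async with ctx.scope("app", ApiClient(base_url="https://example.org")):
+        await fetch()
+
+
+asyncio.run(main())
+```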
+
+## πŸ–₯️ Install
+
+With pip:
+
+```bash
+pip install haiway
+```
+
+## πŸ‘· Contributing
+
+As an open-source project in a rapidly evolving field, we welcome all contributions. Whether you can add a new feature, enhance our infrastructure, or improve our documentation, your input is valuable to us.
+
+We welcome any feedback and suggestions! Feel free to open an issue or pull request.
+
+## βš–οΈ License
+
+MIT License
+
+Copyright (c) 2024 Miquido
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/config/pre-push b/config/pre-push
new file mode 100755
index 0000000..a096fa4
--- /dev/null
+++ b/config/pre-push
@@ -0,0 +1,57 @@
+#!/bin/sh
+
+if git rev-parse --verify HEAD >/dev/null 2>&1
+then
+    against=HEAD
+else
+    against=$(git hash-object -t tree /dev/null)
+fi
+
+remote="$1"
+url="$2"
+
+zero=$(git hash-object --stdin </dev/null | tr '[0-9a-f]' '0')
+
+while read local_ref local_oid remote_ref remote_oid
+do
+    if test "$local_oid" = "$zero"
+    then
+        # branch deleted, nothing to check
+        :
+    else
+        if test "$remote_oid" = "$zero"
+        then
+            # new branch, examine all commits
+            range="$local_oid"
+        else
+            # update to existing branch, examine new commits
+            range="$remote_oid..$local_oid"
+        fi
+
+        # check for WIP commits
+        commit=$(git rev-list -n 1 --grep '^WIP' "$range")
+        if test -n "$commit"
+        then
+            echo >&2 "Found WIP commit in $local_ref, not pushing"
+            exit 1
+        fi
+    fi
+done
+
+. ./.venv/bin/activate
+
+make lint
+
+if test $? != 0
+then
+    cat <<\EOF
+
+Error: Linting failed.
+
+Ensure project quality and make all linting rules pass!
+
+EOF
+    exit 1
+fi
+
+exec git diff-index --check --cached $against --
diff --git a/constraints b/constraints
new file mode 100644
index 0000000..3c15e4f
--- /dev/null
+++ b/constraints
@@ -0,0 +1,43 @@
+# This file was autogenerated by uv via the following command:
+#    uv --no-cache pip compile pyproject.toml -o constraints --all-extras
+bandit==1.7.10
+    # via haiway (pyproject.toml)
+coverage==7.6.3
+    # via pytest-cov
+iniconfig==2.0.0
+    # via pytest
+markdown-it-py==3.0.0
+    # via rich
+mdurl==0.1.2
+    # via markdown-it-py
+nodeenv==1.9.1
+    # via pyright
+packaging==24.1
+    # via pytest
+pbr==6.1.0
+    # via stevedore
+pluggy==1.5.0
+    # via pytest
+pygments==2.18.0
+    # via rich
+pyright==1.1.384
+    # via haiway (pyproject.toml)
+pytest==7.4.4
+    # via
+    #   haiway (pyproject.toml)
+    #   pytest-asyncio
+    #   pytest-cov
+pytest-asyncio==0.23.8
+    # via haiway (pyproject.toml)
+pytest-cov==4.1.0
+    # via haiway (pyproject.toml)
+pyyaml==6.0.2
+    # via bandit
+rich==13.9.2
+    # via bandit
+ruff==0.5.7
+    # via haiway (pyproject.toml)
+stevedore==5.3.0
+    # via bandit
+typing-extensions==4.12.2
+    # via pyright
diff --git a/guidelines/packages.md b/guidelines/packages.md
new file mode 100644
index 0000000..78efe37
--- /dev/null
+++ b/guidelines/packages.md
@@ -0,0 +1,345 @@
+## Organizing packages
+
+haiway is a framework designed to help developers organize their code, manage state propagation, and handle dependencies effectively.
While the framework does not strictly enforce its proposed package structure, adhering to these guidelines can significantly enhance the maintainability and scalability of your projects. + +The core philosophy behind haiway's package organization is to create a clear separation of concerns, allowing developers to build modular and easily extensible applications. By following these principles, you'll be able to create software that is not only easier to understand and maintain but also more resilient to changes and growth over time. + +### Package structure + +In software development, especially in large-scale projects, proper organization is crucial. It helps developers navigate the codebase, understand the relationships between different components, and make changes with confidence. haiway's package organization strategy is designed to address these needs by providing a clear structure that scales well with project complexity. + +haiway defines five distinct package types, each serving a specific purpose in the overall architecture of your application. Package types are organized by their high level role in building application layers from the most basic and common elements to the most specific and complex functionalities to finally form an application entrypoint. + +Here is a high level overview of the project packages structure which will be explained in detail below. + +``` +src/project +β”‚ +β”œβ”€β”€ ... +β”‚ +β”œβ”€β”€ entrypoint_a/ # application entrypoints +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ __main__.py +β”‚ └── ... +β”œβ”€β”€ entrypoint_b/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ __main__.py +β”‚ └── ... +β”‚ +β”œβ”€β”€ features/ # high level functionalities +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ feature_a/ +β”‚ β”‚ β”œβ”€β”€ __init__.py +β”‚ β”‚ └── ... +β”‚ └── feature_b/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ └── ... +β”‚ +β”œβ”€β”€ solutions/ # low level functionalities +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ solution_a/ +β”‚ β”‚ β”œβ”€β”€ __init__.py +β”‚ β”‚ └── ... +β”‚ └── solution_b/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ └── ... +β”‚ +β”œβ”€β”€ integrations/ # third party services integrations +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ integration_a/ +β”‚ β”‚ β”œβ”€β”€ __init__.py +β”‚ β”‚ └── ... +β”‚ └── integration_b/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ └── ... +β”‚ +└── commons/ # common utilities and language extensions + β”œβ”€β”€ __init__.py + └── ... +``` + +#### Entrypoints + +Entrypoint packages serve as the starting points for your application. They define how your application is invoked and interacted with from the outside world. Examples of entrypoints include command-line interfaces (CLIs), HTTP servers, or even graphical user interfaces (GUIs). + +``` +src/project +β”‚ +β”œβ”€β”€ entrypoint_a/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ __main__.py +β”‚ └── ... +β”œβ”€β”€ entrypoint_b/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ __main__.py +β”‚ └── ... +└── ... +``` + +Entrypoints are top-level packages within your project's source directory. Your project can have multiple entrypoints, allowing for various ways to interact with your application. No other packages should depend on entrypoint packages. They are the outermost layer of your application architecture. Each entrypoint should be isolated in its own package, promoting a clear separation between different ways of invoking your application. 
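+
+As a sketch, a hypothetical entrypoint's `__main__.py` could open the root scope and delegate to feature packages (all names here are illustrative; `ctx.scope` is the scope API defined in this repository):
+
+```python
+# src/project/entrypoint_a/__main__.py — hypothetical sketch
+import asyncio
+
+from haiway import ctx
+
+
+async def main() -> None:
+    # the entrypoint owns the root scope; features run inside it
+    async with ctx.scope("entrypoint_a"):
+        ctx.log_info("entrypoint_a started")
+        # ... invoke feature-level functions here ...
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
+```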
+
+By keeping entrypoints separate, you maintain flexibility in how your application can be used while ensuring that the core functionality remains independent of any specific interface.
+
+#### Features
+
+Feature packages encapsulate the highest-level functions provided by your application. They represent the main capabilities or services that your application offers to its users. Examples of features could be user registration, chat handling, or data processing pipelines.
+
+```
+src/project
+β”‚
+β”œβ”€β”€ ...
+β”‚
+└── features/
+    β”œβ”€β”€ __init__.py
+    β”œβ”€β”€ feature_a/
+    β”‚   β”œβ”€β”€ __init__.py
+    β”‚   └── ...
+    └── feature_b/
+        β”œβ”€β”€ __init__.py
+        └── ...
+```
+
+Feature packages are designed to be consumed by multiple entrypoints, allowing the same functionality to be accessed through different interfaces. All feature packages should be located within a top-level "features" package in your project's source directory. The top-level "features" package itself should not export any symbols. It serves purely as an organizational container. Feature packages should focus on high-level business logic and orchestration of lower-level components. However, they should not depend directly on any integrations, prioritizing the use of solution packages instead.
+
+By organizing your core application capabilities into feature packages, you create a clear delineation of what your application does, making it easier to understand, extend, and maintain the overall functionality.
+
+#### Solutions
+
+Solution packages provide smaller, more focused utilities and partial functionalities. They serve as the building blocks for your features, offering reusable components that can be combined to create more complex behaviors. While features implement complete, complex functionalities, solutions aim to be simple, single-purpose helpers on top of which numerous features can be built. Examples of solutions include storage mechanisms, user management, or encryption helpers.
+
+```
+src/project
+β”‚
+β”œβ”€β”€ ...
+β”‚
+└── solutions/
+    β”œβ”€β”€ __init__.py
+    β”œβ”€β”€ solution_a/
+    β”‚   β”œβ”€β”€ __init__.py
+    β”‚   └── ...
+    └── solution_b/
+        β”œβ”€β”€ __init__.py
+        └── ...
+```
+
+Solution packages deliver low-level functionalities that are common across multiple features. Solution packages cannot depend on any feature or entrypoint packages, maintaining a clear hierarchical structure. All solution packages should be located within a top-level "solutions" package in your project's source directory. Like the features package, the top-level "solutions" package itself should not export any symbols. Solutions should be project-specific; they abstract away direct integrations with third parties and implement the algorithms laying the foundations for features.
+
+By breaking down common functionalities into solution packages, you promote code reuse and maintain a clear separation between high-level features and their underlying implementations.
+
+#### Integrations
+
+Integration packages are responsible for implementing connections to third-party services, external APIs, or system resources. They serve as the bridge between your application and the outside world. Examples of integrations may be API clients or database connectors.
+
+```
+src/project
+β”‚
+β”œβ”€β”€ ...
+β”‚
+└── integrations/
+    β”œβ”€β”€ __init__.py
+    β”œβ”€β”€ integration_a/
+    β”‚   β”œβ”€β”€ __init__.py
+    β”‚   └── ...
+    └── integration_b/
+        β”œβ”€β”€ __init__.py
+        └── ...
+```
+
+Each integration package should focus on a single integration or external service. They should not depend on other packages except for the commons package. All integration packages should be located within a top-level "integrations" package in your project's source directory. The top-level "integrations" package, like features and solutions, should not export any symbols.
+
+By isolating integrations in their own packages, you make it easier to manage external dependencies, update or replace integrations, and maintain a clear boundary between your application's core logic and its interactions with external systems.
+
+#### Commons
+
+The commons package is a special package that provides shared utilities, extensions, and helper functions used throughout your application. It serves as a foundation for all other packages and may be used to resolve circular dependencies caused by type imports in some cases.
+
+```
+src/project
+β”‚
+β”œβ”€β”€ ...
+β”‚
+└── commons/
+    β”œβ”€β”€ __init__.py
+    └── ...
+```
+
+The commons package cannot depend on any other package in your application. It should contain only truly common and widely used functionalities. Take care not to overload the commons package with too many responsibilities.
+
+The commons package helps reduce code duplication and provides a centralized location for shared utilities, promoting consistency across your application.
+
+### Internal Package Structure
+
+To maintain consistency and improve code organization, haiway recommends a specific internal structure for packages. This structure varies slightly depending on the package type, but generally follows a similar pattern.
+
+#### Structure for Features and Solutions Packages
+
+Features and solutions packages should adhere to the following internal structure:
+
+```
+solution_or_feature/
+β”‚
+β”œβ”€β”€ __init__.py
+β”œβ”€β”€ calls.py
+β”œβ”€β”€ config.py
+β”œβ”€β”€ state.py
+β”œβ”€β”€ ...
+└── types.py
+```
+
+`__init__.py`: This file is responsible for exporting the package's public symbols. It's crucial to only export what is intended to be used outside the package. Anything not exported is considered internal and should not be accessed from outside the package. The `__init__.py` file should not import any internal or private elements of the package, especially direct implementations. Typically it exports the contents of calls.py, state.py, and types.py.
+
+`types.py`: This file contains definitions for data types, interfaces, and errors used within the package. It should not depend on any other file within the package and should not contain any logicβ€”only type declarations. Types defined here can be partially or fully exported to allow for type annotations and checks in other parts of your application.
+
+`config.py`: This file holds configuration and constants used by the package, including any relevant environment variables. It should only depend on types.py and not on any other module within the package.
+
+`state.py`: This file contains state declarations for the package, used for dependency and data injection. It can use types.py, config.py, and optionally default implementations from other modules. State types should be exported to allow defining and updating implementations and contextual data or configuration. The state should provide default values and/or factory methods for easy configuration.
+
+`calls.py`: This file defines the public interface functions that utilize the package's state and allow access to the package's functionalities. These functions should be exported within `__init__.py`.
+
+Other: Any additional files needed for internal implementation details. These files should be treated as internal and not exported.
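+
+For illustration, here is a minimal sketch of such a package. The `notes` package and all of its names are hypothetical, and the sketch assumes `Structure` (defined outside this excerpt) supports typed fields with default values:
+
+```python
+# solutions/notes/types.py — hypothetical example
+from typing import NamedTuple
+
+
+class Note(NamedTuple):
+    text: str
+
+
+# solutions/notes/state.py
+from collections.abc import Callable
+
+from haiway import Structure
+
+
+class NotesStorage(Structure):
+    # a default no-op implementation keeps the state usable without configuration
+    store: Callable[[Note], None] = lambda note: None
+
+
+# solutions/notes/calls.py
+from haiway import ctx
+
+
+def store_note(note: Note) -> None:
+    # resolve the current NotesStorage state and delegate to its implementation
+    ctx.state(NotesStorage).store(note)
+```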
+
+#### Structure for Integration Packages
+
+Integration packages follow a slightly different structure:
+
+```
+integration/
+β”‚
+β”œβ”€β”€ __init__.py
+β”œβ”€β”€ config.py
+β”œβ”€β”€ client.py
+β”œβ”€β”€ ...
+└── types.py
+```
+
+`__init__.py`: Similar to feature and solution packages, this file exports the integration's public symbols.
+
+`types.py`: Contains data types, interfaces, and errors specific to the integration.
+
+`config.py`: Holds configuration and constants for the integration, including relevant environment variables.
+
+`client.py`: Defines the integration client, which should be exported as public and provide all the functionalities of the integration.
+
+Other: These may include sets of mixins providing parts of the implementation, separated by functionality or topic, as well as a session base type for managing connections to services. All mixins should be merged within the client type to provide the full functionality of the integration.
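+
+A minimal sketch of such a client follows; the `weatherapi` integration and all of its names are hypothetical:
+
+```python
+# integrations/weatherapi/config.py — hypothetical example
+WEATHERAPI_URL: str = "https://api.example.com"  # normally resolved from environment variables
+
+
+# integrations/weatherapi/types.py
+class WeatherAPIError(Exception):
+    pass
+
+
+# integrations/weatherapi/client.py
+class WeatherAPIClient:
+    """Public client exposing all of the integration's functionality."""
+
+    def __init__(self, url: str = WEATHERAPI_URL) -> None:
+        self._url: str = url
+
+    async def forecast(self, city: str) -> dict[str, str]:
+        # a real implementation would call the external service here
+        raise WeatherAPIError(f"forecast({city!r}) not implemented in this sketch")
+```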
+ +``` +src/project +β”‚ +β”œβ”€β”€ package_group/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ └── package_merged/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ package_a/ +β”‚ β”‚ β”œβ”€β”€ __init__.py +β”‚ β”‚ └── ... +β”‚ β”œβ”€β”€ package_b/ +β”‚ β”‚ β”œβ”€β”€ __init__.py +β”‚ β”‚ └── ... +β”‚ └── ... +└── ... +``` + +`package_group`: represents the broader category (e.g., features, solutions) where the packages are located. + +`package_merged`: represents the package merging conflicting packages, it should represent the merged functionalities designated functionality. It should export all public symbols from encapsulated packages. + +`package_a`, `package_b`: are the original packages that had circular dependencies between them + +By placing linked packages within the common package, you create a new scope that can resolve the circular dependency issues. The `__init__.py` file in package_merged can then expose a unified interface, managing the interactions between package_a and package_b internally and exposing all of required symbols. + +#### Shared packages + +When the contained packages strategy can't be applied due to multiple dependencies spread across multiple packages, you can create an additional, shared package within the same package group. This shared package declares all required interfaces. + +``` +src/project +β”‚ +β”œβ”€β”€ package_group/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ package_a/ +β”‚ β”‚ β”œβ”€β”€ __init__.py +β”‚ β”‚ └── ... +β”‚ β”œβ”€β”€ package_b/ +β”‚ β”‚ β”œβ”€β”€ __init__.py +β”‚ β”‚ └── ... +β”‚ └── package_shared/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ └── ... +└── ... +``` + +`package_group`: remains the broader category where the packages are located as previously. + +`package_a`, `package_b`: are kept separate within the group as normal + +`package_shared`: is introduced as a new package that contains shared interfaces and types to resolve circular dependencies between conflicting packages. + +The shared package acts as an intermediary, defining interfaces that both package_a and package_b can depend on. This breaks the direct circular dependency between them. The shared package should only contain interface definitions and types, not implementations. + +### Best Practices + +To make the most of haiway's package organization strategy, consider the following best practices: + +- Package Focus: Ensure that each package contains only modules (files) associated with its specific functionality. If you find a package growing too large or handling multiple concerns, consider breaking it down into smaller, more focused packages. + +- Clear Public/Internal Separation: Maintain a clear distinction between public and internal elements of your packages. Only export what is necessary for other parts of your application to use. This helps prevent unintended dependencies and makes it easier to refactor internal implementations without affecting other parts of your codebase. + +- Avoid Circular Dependencies: Be vigilant about preventing circular dependencies between packages. This can lead to complex and hard-to-maintain code. If you find yourself needing to create a circular dependency, it's often a sign that your package boundaries need to be reconsidered. + +- Use Type Annotations: Leverage type annotations throughout your codebase. This not only improves readability but also helps catch potential errors early in the development process. Strict type checking is strongly recommended for each project using haiway framework. 
+
+- State Configuration: When defining states in state.py, provide default values and factory methods. This makes it easier to configure and use your packages in different contexts.
+
+- Consistent Naming: Use consistent naming conventions across your packages. This helps developers quickly understand the purpose and content of different files and modules.
+
+- Project Specific Rules: Some projects may benefit from additional rules applied to their codebase. You may introduce additional requirements, but keep them consistent across the project and ensure that all project members know and apply them correctly.
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..1a4fa55
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,68 @@
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "haiway"
+description = "Framework for dependency injection and state management within structured concurrency model."
+version = "0.1.0"
+readme = "README.md"
+maintainers = [
+    { name = "Kacper KaliΕ„ski", email = "kacper.kalinski@miquido.com" },
+]
+requires-python = ">=3.12"
+classifiers = [
+    "License :: OSI Approved :: MIT License",
+    "Intended Audience :: Developers",
+    "Programming Language :: Python",
+    "Typing :: Typed",
+    "Topic :: Software Development",
+    "Topic :: Software Development :: Libraries :: Application Frameworks",
+]
+license = { file = "LICENSE" }
+dependencies = []
+
+[project.urls]
+Homepage = "https://miquido.com"
+Repository = "https://github.com/miquido/haiway.git"
+
+[project.optional-dependencies]
+dev = [
+    "ruff~=0.5.0",
+    "pyright~=1.1",
+    "bandit~=1.7",
+    "pytest~=7.4",
+    "pytest-cov~=4.1",
+    "pytest-asyncio~=0.23.0",
+]
+
+[tool.ruff]
+target-version = "py312"
+line-length = 100
+extend-exclude = [".venv", ".git", ".cache"]
+lint.select = ["E", "F", "A", "I", "B", "PL", "W", "C", "RUF", "UP"]
+lint.ignore = []
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["F401", "E402"]
+"./tests/*.py" = ["PLR2004"]
+
+[tool.pyright]
+pythonVersion = "3.12"
+venvPath = "."
+venv = ".venv" +include = ["./src"] +exclude = ["**/node_modules", "**/__pycache__"] +ignore = [] +stubPath = "./stubs" +reportMissingImports = true +reportMissingTypeStubs = true +typeCheckingMode = "strict" +userFileIndexingLimit = -1 +useLibraryCodeForTypes = true + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/src/haiway/__init__.py b/src/haiway/__init__.py new file mode 100644 index 0000000..7847227 --- /dev/null +++ b/src/haiway/__init__.py @@ -0,0 +1,75 @@ +from haiway.context import ( + Dependencies, + Dependency, + MissingContext, + MissingDependency, + MissingState, + ScopeMetrics, + ctx, +) +from haiway.helpers import ( + asynchronous, + auto_retry, + cached, + throttle, + with_timeout, +) +from haiway.state import Structure +from haiway.types import ( + MISSING, + Missing, + frozenlist, + is_missing, + not_missing, + when_missing, +) +from haiway.utils import ( + AsyncQueue, + always, + async_always, + async_noop, + freeze, + getenv_bool, + getenv_float, + getenv_int, + getenv_str, + load_env, + mimic_function, + noop, + setup_logging, +) + +__all__ = [ + "always", + "async_always", + "async_noop", + "asynchronous", + "AsyncQueue", + "auto_retry", + "cached", + "ctx", + "Dependencies", + "Dependency", + "freeze", + "frozenlist", + "getenv_bool", + "getenv_float", + "getenv_int", + "getenv_str", + "is_missing", + "load_env", + "mimic_function", + "Missing", + "MISSING", + "MissingContext", + "MissingDependency", + "MissingState", + "noop", + "not_missing", + "ScopeMetrics", + "setup_logging", + "Structure", + "throttle", + "when_missing", + "with_timeout", +] diff --git a/src/haiway/context/__init__.py b/src/haiway/context/__init__.py new file mode 100644 index 0000000..95045eb --- /dev/null +++ b/src/haiway/context/__init__.py @@ -0,0 +1,14 @@ +from haiway.context.access import ctx +from haiway.context.dependencies import Dependencies, Dependency +from haiway.context.metrics import ScopeMetrics +from haiway.context.types import MissingContext, MissingDependency, MissingState + +__all__ = [ + "ctx", + "Dependencies", + "Dependency", + "MissingContext", + "MissingDependency", + "MissingState", + "ScopeMetrics", +] diff --git a/src/haiway/context/access.py b/src/haiway/context/access.py new file mode 100644 index 0000000..128c7ab --- /dev/null +++ b/src/haiway/context/access.py @@ -0,0 +1,416 @@ +from asyncio import ( + Task, + current_task, +) +from collections.abc import ( + Callable, + Coroutine, +) +from logging import Logger +from types import TracebackType +from typing import Any, final + +from haiway.context.dependencies import Dependencies, Dependency +from haiway.context.metrics import MetricsContext, ScopeMetrics +from haiway.context.state import StateContext +from haiway.context.tasks import TaskGroupContext +from haiway.state import Structure +from haiway.utils import freeze + +__all__ = [ + "ctx", +] + + +@final +class ScopeContext: + def __init__( + self, + task_group: TaskGroupContext, + state: StateContext, + metrics: MetricsContext, + completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None, + ) -> None: + self._task_group: TaskGroupContext = task_group + self._state: StateContext = state + self._metrics: MetricsContext = metrics + self._completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None = completion + + freeze(self) + + def __enter__(self) -> None: + assert self._completion is None, "Can't enter synchronous context with completion" # nosec: B101 + + 
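+        # enter nested contexts: state first, then metrics; __exit__ unwinds them in reverse order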
self._state.__enter__() + self._metrics.__enter__() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self._metrics.__exit__( + exc_type=exc_type, + exc_val=exc_val, + exc_tb=exc_tb, + ) + + self._state.__exit__( + exc_type=exc_type, + exc_val=exc_val, + exc_tb=exc_tb, + ) + + async def __aenter__(self) -> None: + self._state.__enter__() + self._metrics.__enter__() + await self._task_group.__aenter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self._task_group.__aexit__( + exc_type=exc_type, + exc_val=exc_val, + exc_tb=exc_tb, + ) + + self._metrics.__exit__( + exc_type=exc_type, + exc_val=exc_val, + exc_tb=exc_tb, + ) + + self._state.__exit__( + exc_type=exc_type, + exc_val=exc_val, + exc_tb=exc_tb, + ) + + if completion := self._completion: + await completion(self._metrics._metrics) # pyright: ignore[reportPrivateUsage] + + +@final +class ctx: + @staticmethod + def scope( + name: str, + /, + *state: Structure, + logger: Logger | None = None, + trace_id: str | None = None, + completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None = None, + ) -> ScopeContext: + """ + Access scope context with given parameters. When called within an existing context\ + it becomes nested with current context as its predecessor. + + Parameters + ---------- + name: Value + name of the scope context + + *state: Structure + state propagated within the scope context, will be merged with current if any\ + by replacing current with provided on conflict + + logger: Logger | None + logger used within the scope context, when not provided current logger will be used\ + if any, otherwise the logger with the scope name will be requested. + + trace_id: str | None = None + tracing identifier included in logs produced within the scope context, when not\ + provided current identifier will be used if any, otherwise it random id will\ + be generated + + completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None = None + completion callback called on exit from the scope granting access to finished\ + scope metrics. Completion is called outside of the context when its metrics is\ + already finished. Make sure to avoid any long operations within the completion. + + Returns + ------- + ScopeContext + context object intended to enter context manager with it + """ + + return ScopeContext( + task_group=TaskGroupContext(), + metrics=MetricsContext.scope( + name, + logger=logger, + trace_id=trace_id, + ), + state=StateContext.updated(state), + completion=completion, + ) + + @staticmethod + def updated( + *state: Structure, + ) -> StateContext: + """ + Update scope context with given state. When called within an existing context\ + it becomes nested with current context as its predecessor. + + Parameters + ---------- + *state: Structure + state propagated within the updated scope context, will be merged with current if any\ + by replacing current with provided on conflict + + Returns + ------- + StateContext + state part of context object intended to enter context manager with it + """ + + return StateContext.updated(state) + + @staticmethod + def spawn[Result, **Arguments]( + function: Callable[Arguments, Coroutine[None, None, Result]], + /, + *args: Arguments.args, + **kwargs: Arguments.kwargs, + ) -> Task[Result]: + """ + Spawn an async task within current scope context task group. 
When called outside of context\ + it will spawn detached task instead. + + Parameters + ---------- + function: Callable[Arguments, Coroutine[None, None, Result]] + function to be called within the task group + + *args: Arguments.args + positional arguments passed to function call + + **kwargs: Arguments.kwargs + keyword arguments passed to function call + + Returns + ------- + Task[Result] + task for tracking function execution and result + """ + + return TaskGroupContext.run(function, *args, **kwargs) + + @staticmethod + def cancel() -> None: + """ + Cancel current asyncio task + """ + + if task := current_task(): + task.cancel() + + else: + raise RuntimeError("Attempting to cancel context out of asyncio task") + + @staticmethod + async def dependency[DependencyType: Dependency]( + dependency: type[DependencyType], + /, + ) -> DependencyType: + """ + Access current dependency by its type. + + Parameters + ---------- + dependency: type[DependencyType] + type of requested dependency + + Returns + ------- + DependencyType + resolved dependency instance + """ + + return await Dependencies.dependency(dependency) + + @staticmethod + def state[StateType: Structure]( + state: type[StateType], + /, + default: StateType | None = None, + ) -> StateType: + """ + Access current scope context state by its type. If there is no matching state defined\ + default value will be created if able, an exception will raise otherwise. + + Parameters + ---------- + state: type[StateType] + type of requested state + + Returns + ------- + StateType + resolved state instance + """ + return StateContext.current( + state, + default=default, + ) + + @staticmethod + def record[Metric: Structure]( + metric: Metric, + /, + merge: Callable[[Metric, Metric], Metric] = lambda lhs, rhs: rhs, + ) -> None: + """ + Record metric within current scope context. + + Parameters + ---------- + metric: MetricType + value of metric to be recorded + + merge: Callable[[MetricType, MetricType], MetricType] = lambda lhs, rhs: rhs + merge method used on to resolve conflicts when a metric of the same type\ + was already recorded. When not provided value will be override current if any. + + Returns + ------- + None + """ + + MetricsContext.record( + metric, + merge=merge, + ) + + @staticmethod + def log_error( + message: str, + /, + *args: Any, + exception: BaseException | None = None, + ) -> None: + """ + Log using ERROR level within current scope context. When there is no current scope\ + root logger will be used without additional details. + + Parameters + ---------- + message: str + message to be written to log + + *args: Any + message format arguments + + exception: BaseException | None = None + exception associated with log, when provided full stack trace will be recorded + + Returns + ------- + None + """ + + MetricsContext.log_error( + message, + *args, + exception=exception, + ) + + @staticmethod + def log_warning( + message: str, + /, + *args: Any, + exception: Exception | None = None, + ) -> None: + """ + Log using WARNING level within current scope context. When there is no current scope\ + root logger will be used without additional details. 
+ + Parameters + ---------- + message: str + message to be written to log + + *args: Any + message format arguments + + exception: BaseException | None = None + exception associated with log, when provided full stack trace will be recorded + + Returns + ------- + None + """ + + MetricsContext.log_warning( + message, + *args, + exception=exception, + ) + + @staticmethod + def log_info( + message: str, + /, + *args: Any, + ) -> None: + """ + Log using INFO level within current scope context. When there is no current scope\ + root logger will be used without additional details. + + Parameters + ---------- + message: str + message to be written to log + + *args: Any + message format arguments + + Returns + ------- + None + """ + + MetricsContext.log_info( + message, + *args, + ) + + @staticmethod + def log_debug( + message: str, + /, + *args: Any, + exception: Exception | None = None, + ) -> None: + """ + Log using DEBUG level within current scope context. When there is no current scope\ + root logger will be used without additional details. + + Parameters + ---------- + message: str + message to be written to log + + *args: Any + message format arguments + + exception: BaseException | None = None + exception associated with log, when provided full stack trace will be recorded + + Returns + ------- + None + """ + + MetricsContext.log_debug( + message, + *args, + exception=exception, + ) diff --git a/src/haiway/context/dependencies.py b/src/haiway/context/dependencies.py new file mode 100644 index 0000000..5fdd526 --- /dev/null +++ b/src/haiway/context/dependencies.py @@ -0,0 +1,61 @@ +from abc import ABC, abstractmethod +from asyncio import Lock, gather, shield +from typing import ClassVar, Self, cast, final + +__all__ = [ + "Dependencies", + "Dependency", +] + + +class Dependency(ABC): + @classmethod + @abstractmethod + async def prepare(cls) -> Self: ... 
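+    # `prepare` is awaited lazily by `Dependencies.dependency()` to construct the shared instance;
+    # `dispose` releases it when `Dependencies` disposes or replaces the dependency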
+ + async def dispose(self) -> None: # noqa: B027 + pass + + +@final +class Dependencies: + _lock: ClassVar[Lock] = Lock() + _dependencies: ClassVar[dict[type[Dependency], Dependency]] = {} + + def __init__(self) -> None: + raise NotImplementedError("Can't instantiate Dependencies") + + @classmethod + async def dependency[Requested: Dependency]( + cls, + dependency: type[Requested], + /, + ) -> Requested: + async with cls._lock: + if dependency not in cls._dependencies: + cls._dependencies[dependency] = await dependency.prepare() + + return cast(Requested, cls._dependencies[dependency]) + + @classmethod + async def register( + cls, + dependency: Dependency, + /, + ) -> None: + async with cls._lock: + if current := cls._dependencies.get(dependency.__class__): + await current.dispose() + + cls._dependencies[dependency.__class__] = dependency + + @classmethod + async def dispose(cls) -> None: + async with cls._lock: + await shield( + gather( + *[dependency.dispose() for dependency in cls._dependencies.values()], + return_exceptions=False, + ) + ) + cls._dependencies.clear() diff --git a/src/haiway/context/metrics.py b/src/haiway/context/metrics.py new file mode 100644 index 0000000..a445792 --- /dev/null +++ b/src/haiway/context/metrics.py @@ -0,0 +1,329 @@ +from asyncio import Future, gather, get_event_loop +from collections.abc import Callable +from contextvars import ContextVar, Token +from copy import copy +from itertools import chain +from logging import DEBUG, ERROR, INFO, WARNING, Logger, getLogger +from time import monotonic +from types import TracebackType +from typing import Any, Self, cast, final, overload +from uuid import uuid4 + +from haiway.state import Structure +from haiway.utils import freeze + +__all__ = [ + "ScopeMetrics", + "MetricsContext", +] + + +@final +class ScopeMetrics: + def __init__( + self, + *, + trace_id: str | None, + scope: str, + logger: Logger | None, + ) -> None: + self.trace_id: str = trace_id or uuid4().hex + self._label: str = f"{self.trace_id}|{scope}" if scope else self.trace_id + self._logger: Logger = logger or getLogger(name=scope) + self._metrics: dict[type[Structure], Structure] = {} + self._nested: list[ScopeMetrics] = [] + self._timestamp: float = monotonic() + self._completed: Future[float] = get_event_loop().create_future() + + freeze(self) + + def __del__(self) -> None: + self._complete() # ensure completion on deinit + + def __str__(self) -> str: + return self._label + + def metrics( + self, + *, + merge: Callable[[Structure, Structure], Structure] = lambda lhs, rhs: lhs, + ) -> list[Structure]: + metrics: dict[type[Structure], Structure] = copy(self._metrics) + for metric in chain.from_iterable(nested.metrics(merge=merge) for nested in self._nested): + metric_type: type[Structure] = type(metric) + if current := metrics.get(metric_type): + metrics[metric_type] = merge(current, metric) + + else: + metrics[metric_type] = metric + + return list(metrics.values()) + + @overload + def read[Metric: Structure]( + self, + metric: type[Metric], + /, + ) -> Metric | None: ... + + @overload + def read[Metric: Structure]( + self, + metric: type[Metric], + /, + default: Metric, + ) -> Metric: ... 
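+    # single implementation backing both `read` overloads declared above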
+ + def read[Metric: Structure]( + self, + metric: type[Metric], + /, + default: Metric | None = None, + ) -> Metric | None: + return cast(Metric | None, self._metrics.get(metric, default)) + + def record[Metric: Structure]( + self, + metric: Metric, + /, + *, + merge: Callable[[Metric, Metric], Metric] = lambda lhs, rhs: rhs, + ) -> None: + assert not self._completed.done(), "Can't record using completed metrics scope" # nosec: B101 + metric_type: type[Metric] = type(metric) + if current := self._metrics.get(metric_type): + self._metrics[metric_type] = merge(cast(Metric, current), metric) + + else: + self._metrics[metric_type] = metric + + @property + def completed(self) -> bool: + return self._completed.done() and all(nested.completed for nested in self._nested) + + @property + def time(self) -> float: + if self._completed.done(): + return self._completed.result() + + else: + return monotonic() - self._timestamp + + async def wait(self) -> None: + await gather( + self._completed, + *[nested.wait() for nested in self._nested], + return_exceptions=False, + ) + + def _complete(self) -> None: + if self._completed.done(): + return # already completed + + self._completed.set_result(monotonic() - self._timestamp) + + def scope( + self, + name: str, + /, + ) -> Self: + nested: Self = self.__class__( + scope=name, + logger=self._logger, + trace_id=self.trace_id, + ) + self._nested.append(nested) + return nested + + def log( + self, + level: int, + message: str, + /, + *args: Any, + exception: BaseException | None = None, + ) -> None: + self._logger.log( + level, + f"[{self}] {message}", + *args, + exc_info=exception, + ) + + +@final +class MetricsContext: + _context = ContextVar[ScopeMetrics]("MetricsContext") + + @classmethod + def scope( + cls, + name: str, + /, + *, + trace_id: str | None = None, + logger: Logger | None = None, + ) -> Self: + try: + context: ScopeMetrics = cls._context.get() + if trace_id is None or context.trace_id == trace_id: + return cls(context.scope(name)) + + else: + return cls( + ScopeMetrics( + trace_id=trace_id, + scope=name, + logger=logger or context._logger, # pyright: ignore[reportPrivateUsage] + ) + ) + except LookupError: # create metrics scope when missing yet + return cls( + ScopeMetrics( + trace_id=trace_id, + scope=name, + logger=logger, + ) + ) + + @classmethod + def record[Metric: Structure]( + cls, + metric: Metric, + /, + *, + merge: Callable[[Metric, Metric], Metric] = lambda lhs, rhs: rhs, + ) -> None: + try: # catch exceptions - we don't wan't to blow up on metrics + cls._context.get().record(metric, merge=merge) + + except Exception as exc: + cls.log_error( + "Failed to record metric: %s", + type(metric).__qualname__, + exception=exc, + ) + + # - LOGS - + + @classmethod + def log_error( + cls, + message: str, + /, + *args: Any, + exception: BaseException | None = None, + ) -> None: + try: + cls._context.get().log( + ERROR, + message, + *args, + exception=exception, + ) + + except LookupError: + getLogger().log( + ERROR, + message, + *args, + exc_info=exception, + ) + + @classmethod + def log_warning( + cls, + message: str, + /, + *args: Any, + exception: Exception | None = None, + ) -> None: + try: + cls._context.get().log( + WARNING, + message, + *args, + exception=exception, + ) + + except LookupError: + getLogger().log( + WARNING, + message, + *args, + exc_info=exception, + ) + + @classmethod + def log_info( + cls, + message: str, + /, + *args: Any, + ) -> None: + try: + cls._context.get().log( + INFO, + message, + *args, + ) + + except 
LookupError: + getLogger().log( + INFO, + message, + *args, + ) + + @classmethod + def log_debug( + cls, + message: str, + /, + *args: Any, + exception: Exception | None = None, + ) -> None: + try: + cls._context.get().log( + DEBUG, + message, + *args, + exception=exception, + ) + + except LookupError: + getLogger().log( + DEBUG, + message, + *args, + exc_info=exception, + ) + + def __init__( + self, + metrics: ScopeMetrics, + ) -> None: + self._metrics: ScopeMetrics = metrics + self._token: Token[ScopeMetrics] | None = None + self._started: float | None = None + self._finished: float | None = None + + def __enter__(self) -> None: + assert ( # nosec: B101 + self._token is None and self._started is None + ), "MetricsContext reentrance is not allowed" + self._token = MetricsContext._context.set(self._metrics) + self._started = monotonic() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + assert ( # nosec: B101 + self._token is not None and self._started is not None and self._finished is None + ), "Unbalanced MetricsContext context enter/exit" + self._finished = monotonic() + MetricsContext._context.reset(self._token) + self._token = None diff --git a/src/haiway/context/state.py b/src/haiway/context/state.py new file mode 100644 index 0000000..8d77319 --- /dev/null +++ b/src/haiway/context/state.py @@ -0,0 +1,115 @@ +from collections.abc import Iterable +from contextvars import ContextVar, Token +from types import TracebackType +from typing import Self, cast, final + +from haiway.context.types import MissingContext, MissingState +from haiway.state import Structure +from haiway.utils import freeze + +__all__ = [ + "ScopeState", + "StateContext", +] + + +@final +class ScopeState: + def __init__( + self, + state: Iterable[Structure], + ) -> None: + self._state: dict[type[Structure], Structure] = { + type(element): element for element in state + } + freeze(self) + + def state[State: Structure]( + self, + state: type[State], + /, + default: State | None = None, + ) -> State: + if state in self._state: + return cast(State, self._state[state]) + + elif default is not None: + return default + + else: + try: + initialized: State = state() + self._state[state] = initialized + return initialized + + except Exception as exc: + raise MissingState( + f"{state.__qualname__} is not defined in current scope" + " and failed to provide a default value" + ) from exc + + def updated( + self, + state: Iterable[Structure], + ) -> Self: + if state: + return self.__class__( + [ + *self._state.values(), + *state, + ] + ) + + else: + return self + + +@final +class StateContext: + _context = ContextVar[ScopeState]("StateContext") + + @classmethod + def current[State: Structure]( + cls, + state: type[State], + /, + default: State | None = None, + ) -> State: + try: + return cls._context.get().state(state, default=default) + + except LookupError as exc: + raise MissingContext("StateContext requested but not defined!") from exc + + @classmethod + def updated( + cls, + state: Iterable[Structure], + /, + ) -> Self: + try: + return cls(state=cls._context.get().updated(state=state)) + + except LookupError: # create new context as a fallback + return cls(state=ScopeState(state)) + + def __init__( + self, + state: ScopeState, + ) -> None: + self._state: ScopeState = state + self._token: Token[ScopeState] | None = None + + def __enter__(self) -> None: + assert self._token is None, "StateContext reentrance is not allowed" # nosec: B101 
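+        # keep the reset token so __exit__ can restore the previously active state context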
+ self._token = StateContext._context.set(self._state) + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + assert self._token is not None, "Unbalanced StateContext context exit" # nosec: B101 + StateContext._context.reset(self._token) + self._token = None diff --git a/src/haiway/context/tasks.py b/src/haiway/context/tasks.py new file mode 100644 index 0000000..cef3186 --- /dev/null +++ b/src/haiway/context/tasks.py @@ -0,0 +1,65 @@ +from asyncio import Task, TaskGroup, get_event_loop +from collections.abc import Callable, Coroutine +from contextvars import ContextVar, Token, copy_context +from types import TracebackType +from typing import final + +__all__ = [ + "TaskGroupContext", +] + + +@final +class TaskGroupContext: + _context = ContextVar[TaskGroup]("TaskGroupContext") + + @classmethod + def run[Result, **Arguments]( + cls, + function: Callable[Arguments, Coroutine[None, None, Result]], + /, + *args: Arguments.args, + **kwargs: Arguments.kwargs, + ) -> Task[Result]: + try: + return cls._context.get().create_task( + function(*args, **kwargs), + context=copy_context(), + ) + + except LookupError: # spawn task out of group as a fallback + return get_event_loop().create_task( + function(*args, **kwargs), + context=copy_context(), + ) + + def __init__( + self, + ) -> None: + self._group: TaskGroup = TaskGroup() + self._token: Token[TaskGroup] | None = None + + async def __aenter__(self) -> None: + assert self._token is None, "TaskGroupContext reentrance is not allowed" # nosec: B101 + await self._group.__aenter__() + self._token = TaskGroupContext._context.set(self._group) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + assert self._token is not None, "Unbalanced TaskGroupContext context exit" # nosec: B101 + TaskGroupContext._context.reset(self._token) + self._token = None + + try: + await self._group.__aexit__( + et=exc_type, + exc=exc_val, + tb=exc_tb, + ) + + except BaseException: + pass # silence TaskGroup exceptions, if there was exception already we will get it diff --git a/src/haiway/context/types.py b/src/haiway/context/types.py new file mode 100644 index 0000000..0084cf2 --- /dev/null +++ b/src/haiway/context/types.py @@ -0,0 +1,17 @@ +__all__ = [ + "MissingContext", + "MissingDependency", + "MissingState", +] + + +class MissingContext(Exception): + pass + + +class MissingDependency(Exception): + pass + + +class MissingState(Exception): + pass diff --git a/src/haiway/helpers/__init__.py b/src/haiway/helpers/__init__.py new file mode 100644 index 0000000..5af43d4 --- /dev/null +++ b/src/haiway/helpers/__init__.py @@ -0,0 +1,13 @@ +from haiway.helpers.asynchronous import asynchronous +from haiway.helpers.cache import cached +from haiway.helpers.retry import auto_retry +from haiway.helpers.throttling import throttle +from haiway.helpers.timeout import with_timeout + +__all__ = [ + "asynchronous", + "auto_retry", + "cached", + "throttle", + "with_timeout", +] diff --git a/src/haiway/helpers/asynchronous.py b/src/haiway/helpers/asynchronous.py new file mode 100644 index 0000000..ff32929 --- /dev/null +++ b/src/haiway/helpers/asynchronous.py @@ -0,0 +1,226 @@ +from asyncio import AbstractEventLoop, get_running_loop, iscoroutinefunction +from collections.abc import Callable, Coroutine +from concurrent.futures import Executor +from contextvars import Context, copy_context +from functools 
import partial +from typing import Any, Literal, cast, overload + +from haiway.types.missing import MISSING, Missing, not_missing + +__all__ = [ + "asynchronous", +] + + +@overload +def asynchronous[**Args, Result]() -> ( + Callable[ + [Callable[Args, Result]], + Callable[Args, Coroutine[None, None, Result]], + ] +): ... + + +@overload +def asynchronous[**Args, Result]( + *, + loop: AbstractEventLoop | None = None, + executor: Executor | Literal["default"], +) -> Callable[ + [Callable[Args, Result]], + Callable[Args, Coroutine[None, None, Result]], +]: ... + + +@overload +def asynchronous[**Args, Result]( + function: Callable[Args, Result], + /, +) -> Callable[Args, Coroutine[None, None, Result]]: ... + + +def asynchronous[**Args, Result]( + function: Callable[Args, Result] | None = None, + /, + loop: AbstractEventLoop | None = None, + executor: Executor | Literal["default"] | Missing = MISSING, +) -> ( + Callable[ + [Callable[Args, Result]], + Callable[Args, Coroutine[None, None, Result]], + ] + | Callable[Args, Coroutine[None, None, Result]] +): + """\ + Wrapper for a sync function to convert it to an async function. \ + When specified an executor, it can be used to wrap long running or blocking synchronous \ + operations within coroutines system. + + Parameters + ---------- + function: Callable[Args, Result] + function to be wrapped as running in loop executor. + loop: AbstractEventLoop | None + loop used to call the function. When None was provided the loop currently running while \ + executing the function will be used. Default is None. + executor: Executor | Literal["default"] | Missing + executor used to run the function. Specifying "default" uses a default loop executor. + When not provided (Missing) no executor will be used \ + (function will by just wrapped as an async function without any executor) + + Returns + ------- + Callable[_Args, _Result] + function wrapped to async using loop executor. 
+ """ + + def wrap( + wrapped: Callable[Args, Result], + ) -> Callable[Args, Coroutine[None, None, Result]]: + assert not iscoroutinefunction(wrapped), "Cannot wrap async function in executor" # nosec: B101 + + if not_missing(executor): + return _ExecutorWrapper( + wrapped, + loop=loop, + executor=cast(Executor | None, None if executor == "default" else executor), + ) + + else: + + async def wrapper( + *args: Args.args, + **kwargs: Args.kwargs, + ) -> Result: + return wrapped( + *args, + **kwargs, + ) + + _mimic_async(wrapped, within=wrapper) + return wrapper + + if function := function: + return wrap(wrapped=function) + + else: + return wrap + + +class _ExecutorWrapper[**Args, Result]: + def __init__( + self, + function: Callable[Args, Result], + /, + loop: AbstractEventLoop | None, + executor: Executor | None, + ) -> None: + self._function: Callable[Args, Result] = function + self._loop: AbstractEventLoop | None = loop + self._executor: Executor | None = executor + + # mimic function attributes if able + _mimic_async(function, within=self) + + async def __call__( + self, + *args: Args.args, + **kwargs: Args.kwargs, + ) -> Result: + context: Context = copy_context() + return await (self._loop or get_running_loop()).run_in_executor( + self._executor, + context.run, + partial(self._function, *args, **kwargs), + ) + + def __get__( + self, + instance: object, + owner: type | None = None, + /, + ) -> Callable[Args, Coroutine[None, None, Result]]: + if owner is None: + return self + + else: + return _mimic_async( + self._function, + within=partial( + self.__method_call__, + instance, + ), + ) + + async def __method_call__( + self, + __method_self: object, + *args: Args.args, + **kwargs: Args.kwargs, + ) -> Result: + return await (self._loop or get_running_loop()).run_in_executor( + self._executor, + partial(self._function, __method_self, *args, **kwargs), + ) + + +def _mimic_async[**Args, Result]( + function: Callable[Args, Result], + /, + within: Callable[..., Coroutine[None, None, Result]], +) -> Callable[Args, Coroutine[None, None, Result]]: + try: + annotations: Any = getattr( # noqa: B009 + function, + "__annotations__", + ) + setattr( # noqa: B010 + within, + "__annotations__", + { + **annotations, + "return": Coroutine[None, None, annotations.get("return", Any)], + }, + ) + + except AttributeError: + pass + + for attribute in ( + "__module__", + "__name__", + "__qualname__", + "__doc__", + "__type_params__", + "__defaults__", + "__kwdefaults__", + "__globals__", + ): + try: + setattr( + within, + attribute, + getattr( + function, + attribute, + ), + ) + + except AttributeError: + pass + try: + within.__dict__.update(function.__dict__) + + except AttributeError: + pass + + setattr( # noqa: B010 - mimic functools.wraps behavior for correct signature checks + within, + "__wrapped__", + function, + ) + + return cast( + Callable[Args, Coroutine[None, None, Result]], + within, + ) diff --git a/src/haiway/helpers/cache.py b/src/haiway/helpers/cache.py new file mode 100644 index 0000000..95fd814 --- /dev/null +++ b/src/haiway/helpers/cache.py @@ -0,0 +1,326 @@ +from asyncio import AbstractEventLoop, Task, get_running_loop, iscoroutinefunction, shield +from collections import OrderedDict +from collections.abc import Callable, Coroutine, Hashable +from functools import _make_key, partial # pyright: ignore[reportPrivateUsage] +from time import monotonic +from typing import NamedTuple, cast, overload +from weakref import ref + +from haiway.utils.mimic import mimic_function + +__all__ = [ + "cached", 
+] + + +@overload +def cached[**Args, Result]( + function: Callable[Args, Result], + /, +) -> Callable[Args, Result]: ... + + +@overload +def cached[**Args, Result]( + *, + limit: int = 1, + expiration: float | None = None, +) -> Callable[[Callable[Args, Result]], Callable[Args, Result]]: ... + + +def cached[**Args, Result]( + function: Callable[Args, Result] | None = None, + *, + limit: int = 1, + expiration: float | None = None, +) -> Callable[[Callable[Args, Result]], Callable[Args, Result]] | Callable[Args, Result]: + """\ + Simple lru function result cache with optional expire time. \ + Works for both sync and async functions. \ + It is not allowed to be used on class methods. \ + This wrapper is not thread safe. + + Parameters + ---------- + function: Callable[_Args, _Result] + function to wrap in cache, either sync or async + limit: int + limit of cache entries to keep, default is 1 + expiration: float | None + entries expiration time in seconds, default is None (not expiring) + + Returns + ------- + Callable[[Callable[_Args, _Result]], Callable[_Args, _Result]] | Callable[_Args, _Result] + provided function wrapped in cache + """ + + def _wrap(function: Callable[Args, Result]) -> Callable[Args, Result]: + if iscoroutinefunction(function): + return cast( + Callable[Args, Result], + _AsyncCache( + function, + limit=limit, + expiration=expiration, + ), + ) + + else: + return cast( + Callable[Args, Result], + _SyncCache( + function, + limit=limit, + expiration=expiration, + ), + ) + + if function := function: + return _wrap(function) + + else: + return _wrap + + +class _CacheEntry[Entry](NamedTuple): + value: Entry + expire: float | None + + +class _SyncCache[**Args, Result]: + def __init__( + self, + function: Callable[Args, Result], + /, + limit: int, + expiration: float | None, + ) -> None: + self._function: Callable[Args, Result] = function + self._cached: OrderedDict[Hashable, _CacheEntry[Result]] = OrderedDict() + self._limit: int = limit + if expiration := expiration: + + def next_expire_time() -> float | None: + return monotonic() + expiration + + else: + + def next_expire_time() -> float | None: + return None + + self._next_expire_time: Callable[[], float | None] = next_expire_time + + # mimic function attributes if able + mimic_function(function, within=self) + + def __get__( + self, + instance: object | None, + owner: type | None = None, + /, + ) -> Callable[Args, Result]: + if owner is None or instance is None: + return self + + else: + return mimic_function( + self._function, + within=partial( + self.__method_call__, + instance, + ), + ) + + def __call__( + self, + *args: Args.args, + **kwargs: Args.kwargs, + ) -> Result: + key: Hashable = _make_key( + args=args, + kwds=kwargs, + typed=True, + ) + + match self._cached.get(key): + case None: + pass + + case entry: + if (expire := entry[1]) and expire < monotonic(): + # if still running let it complete if able + del self._cached[key] # continue the same way as if empty + + else: + self._cached.move_to_end(key) + return entry[0] + + result: Result = self._function(*args, **kwargs) + self._cached[key] = _CacheEntry( + value=result, + expire=self._next_expire_time(), + ) + + if len(self._cached) > self._limit: + # if still running let it complete if able + self._cached.popitem(last=False) + + return result + + def __method_call__( + self, + __method_self: object, + *args: Args.args, + **kwargs: Args.kwargs, + ) -> Result: + key: Hashable = _make_key( + args=(ref(__method_self), *args), + kwds=kwargs, + typed=True, + ) + + match 
self._cached.get(key): + case None: + pass + + case entry: + if (expire := entry[1]) and expire < monotonic(): + # if still running let it complete if able + del self._cached[key] # continue the same way as if empty + + else: + self._cached.move_to_end(key) + return entry[0] + + result: Result = self._function(__method_self, *args, **kwargs) # pyright: ignore[reportUnknownVariableType, reportCallIssue] + self._cached[key] = _CacheEntry( + value=result, # pyright: ignore[reportUnknownArgumentType] + expire=self._next_expire_time(), + ) + if len(self._cached) > self._limit: + # if still running let it complete if able + self._cached.popitem(last=False) + + return result # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] + + +class _AsyncCache[**Args, Result]: + def __init__( + self, + function: Callable[Args, Coroutine[None, None, Result]], + /, + limit: int, + expiration: float | None, + ) -> None: + self._function: Callable[Args, Coroutine[None, None, Result]] = function + self._cached: OrderedDict[Hashable, _CacheEntry[Task[Result]]] = OrderedDict() + self._limit: int = limit + if expiration := expiration: + + def next_expire_time() -> float | None: + return monotonic() + expiration + + else: + + def next_expire_time() -> float | None: + return None + + self._next_expire_time: Callable[[], float | None] = next_expire_time + + # mimic function attributes if able + mimic_function(function, within=self) + + def __get__( + self, + instance: object | None, + owner: type | None = None, + /, + ) -> Callable[Args, Coroutine[None, None, Result]]: + if owner is None or instance is None: + return self + + else: + return mimic_function( + self._function, + within=partial( + self.__method_call__, + instance, + ), + ) + + async def __call__( + self, + *args: Args.args, + **kwargs: Args.kwargs, + ) -> Result: + loop: AbstractEventLoop = get_running_loop() + key: Hashable = _make_key( + args=args, + kwds=kwargs, + typed=True, + ) + + match self._cached.get(key): + case None: + pass + + case entry: + if (expire := entry[1]) and expire < monotonic(): + # if still running let it complete if able + del self._cached[key] # continue the same way as if empty + + else: + self._cached.move_to_end(key) + return await shield(entry[0]) + + task: Task[Result] = loop.create_task(self._function(*args, **kwargs)) # pyright: ignore[reportCallIssue] + self._cached[key] = _CacheEntry( + value=task, + expire=self._next_expire_time(), + ) + if len(self._cached) > self._limit: + # if still running let it complete if able + self._cached.popitem(last=False) + + return await shield(task) + + async def __method_call__( + self, + __method_self: object, + *args: Args.args, + **kwargs: Args.kwargs, + ) -> Result: + loop: AbstractEventLoop = get_running_loop() + key: Hashable = _make_key( + args=(ref(__method_self), *args), + kwds=kwargs, + typed=True, + ) + + match self._cached.get(key): + case None: + pass + + case entry: + if (expire := entry[1]) and expire < monotonic(): + # if still running let it complete if able + del self._cached[key] # continue the same way as if empty + + else: + self._cached.move_to_end(key) + return await shield(entry[0]) + + task: Task[Result] = loop.create_task( + self._function(__method_self, *args, **kwargs), # pyright: ignore[reportCallIssue, reportUnknownArgumentType] + ) + self._cached[key] = _CacheEntry( + value=task, + expire=self._next_expire_time(), + ) + + if len(self._cached) > self._limit: + # if still running let it complete if able + self._cached.popitem(last=False) + 
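+        # shield the cached task so a cancelled awaiter does not cancel
+        # the shared computation for other callers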
+ return await shield(task) diff --git a/src/haiway/helpers/retry.py b/src/haiway/helpers/retry.py new file mode 100644 index 0000000..218fe14 --- /dev/null +++ b/src/haiway/helpers/retry.py @@ -0,0 +1,210 @@ +from asyncio import CancelledError, iscoroutinefunction, sleep +from collections.abc import Callable, Coroutine +from typing import cast, overload + +from haiway.context import ctx +from haiway.utils import mimic_function + +__all__ = [ + "auto_retry", +] + + +@overload +def auto_retry[**Args, Result]( + function: Callable[Args, Result], + /, +) -> Callable[Args, Result]: + """\ + Function wrapper retrying the wrapped function again on fail. \ + Works for both sync and async functions. \ + It is not allowed to be used on class methods. \ + This wrapper is not thread safe. + + Parameters + ---------- + function: Callable[_Args_T, _Result_T] + function to wrap in auto retry, either sync or async. + + Returns + ------- + Callable[_Args_T, _Result_T] + provided function wrapped in auto retry with default configuration. + """ + + +@overload +def auto_retry[**Args, Result]( + *, + limit: int = 1, + delay: Callable[[int, Exception], float] | float | None = None, + catching: set[type[Exception]] | tuple[type[Exception], ...] | type[Exception] = Exception, +) -> Callable[[Callable[Args, Result]], Callable[Args, Result]]: + """\ + Function wrapper retrying the wrapped function again on fail. \ + Works for both sync and async functions. \ + It is not allowed to be used on class methods. \ + This wrapper is not thread safe. + + Parameters + ---------- + limit: int + limit of retries, default is 1 + delay: Callable[[int, Exception], float] | float | None + retry delay time in seconds, either concrete value or a function producing it, \ + default is None (no delay) + catching: set[type[Exception]] | type[Exception] | None + Exception types that are triggering auto retry. Retry will trigger only when \ + exceptions of matching types (including subclasses) will occur. CancelledError \ + will be always propagated even if specified explicitly. + Default is Exception - all subclasses of Exception will be handled. + + Returns + ------- + Callable[[Callable[_Args_T, _Result_T]], Callable[_Args_T, _Result_T]] + function wrapper for adding auto retry + """ + + +def auto_retry[**Args, Result]( + function: Callable[Args, Result] | None = None, + *, + limit: int = 1, + delay: Callable[[int, Exception], float] | float | None = None, + catching: set[type[Exception]] | tuple[type[Exception], ...] | type[Exception] = Exception, +) -> Callable[[Callable[Args, Result]], Callable[Args, Result]] | Callable[Args, Result]: + """\ + Function wrapper retrying the wrapped function again on fail. \ + Works for both sync and async functions. \ + It is not allowed to be used on class methods. \ + This wrapper is not thread safe. + + Parameters + ---------- + function: Callable[_Args_T, _Result_T] + function to wrap in auto retry, either sync or async. + limit: int + limit of retries, default is 1 + delay: Callable[[int, Exception], float] | float | None + retry delay time in seconds, either concrete value or a function producing it, \ + default is None (no delay) + catching: set[type[Exception]] | type[Exception] | None + Exception types that are triggering auto retry. Retry will trigger only when \ + exceptions of matching types (including subclasses) will occur. CancelledError \ + will be always propagated even if specified explicitly. + Default is Exception - all subclasses of Exception will be handled. 
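+
+    Example
+    -------
+    A minimal sketch (the flaky `fetch` coroutine below is hypothetical)::
+
+        @auto_retry(limit=3, delay=lambda attempt, _: float(2**attempt))
+        async def fetch(url: str) -> bytes: ...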
+ + Returns + ------- + Callable[[Callable[_Args_T, _Result_T]], Callable[_Args_T, _Result_T]] | \ + Callable[_Args_T, _Result_T] + function wrapper for adding auto retry or a wrapped function + """ + + def _wrap( + function: Callable[Args, Result], + /, + ) -> Callable[Args, Result]: + if iscoroutinefunction(function): + return cast( + Callable[Args, Result], + _wrap_async( + function, + limit=limit, + delay=delay, + catching=catching if isinstance(catching, set | tuple) else {catching}, + ), + ) + else: + assert delay is None, "Delay is not supported in sync wrapper" # nosec: B101 + return _wrap_sync( + function, + limit=limit, + catching=catching if isinstance(catching, set | tuple) else {catching}, + ) + + if function := function: + return _wrap(function) + else: + return _wrap + + +def _wrap_sync[**Args, Result]( + function: Callable[Args, Result], + *, + limit: int, + catching: set[type[Exception]] | tuple[type[Exception], ...], +) -> Callable[Args, Result]: + assert limit > 0, "Limit has to be greater than zero" # nosec: B101 + + @mimic_function(function) + def wrapped( + *args: Args.args, + **kwargs: Args.kwargs, + ) -> Result: + attempt: int = 0 + while True: + try: + return function(*args, **kwargs) + except CancelledError as exc: + raise exc + + except Exception as exc: + if attempt < limit and any(isinstance(exc, exception) for exception in catching): + attempt += 1 + ctx.log_error( + "Attempting to retry %s which failed due to an error: %s", + function.__name__, + exc, + ) + + else: + raise exc + + return wrapped + + +def _wrap_async[**Args, Result]( + function: Callable[Args, Coroutine[None, None, Result]], + *, + limit: int, + delay: Callable[[int, Exception], float] | float | None, + catching: set[type[Exception]] | tuple[type[Exception], ...], +) -> Callable[Args, Coroutine[None, None, Result]]: + assert limit > 0, "Limit has to be greater than zero" # nosec: B101 + + @mimic_function(function) + async def wrapped( + *args: Args.args, + **kwargs: Args.kwargs, + ) -> Result: + attempt: int = 0 + while True: + try: + return await function(*args, **kwargs) + except CancelledError as exc: + raise exc + + except Exception as exc: + if attempt < limit and any(isinstance(exc, exception) for exception in catching): + attempt += 1 + ctx.log_error( + "Attempting to retry %s which failed due to an error", + function.__name__, + exception=exc, + ) + + match delay: + case None: + continue + + case float(strict): + await sleep(delay=strict) + + case make_delay: # type: Callable[[], float] + await sleep(delay=make_delay(attempt, exc)) # pyright: ignore[reportCallIssue, reportUnknownArgumentType] + + else: + raise exc + + return wrapped diff --git a/src/haiway/helpers/throttling.py b/src/haiway/helpers/throttling.py new file mode 100644 index 0000000..d550276 --- /dev/null +++ b/src/haiway/helpers/throttling.py @@ -0,0 +1,133 @@ +from asyncio import ( + Lock, + iscoroutinefunction, + sleep, +) +from collections import deque +from collections.abc import Callable, Coroutine +from datetime import timedelta +from time import monotonic +from typing import cast, overload + +from haiway.utils.mimic import mimic_function + +__all__ = [ + "throttle", +] + + +@overload +def throttle[**Args, Result]( + function: Callable[Args, Coroutine[None, None, Result]], + /, +) -> Callable[Args, Coroutine[None, None, Result]]: ... 
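+# Usage sketch (illustrative, the `fetch` coroutine below is hypothetical):
+#
+#     @throttle(limit=10, period=60.0)
+#     async def fetch(url: str) -> bytes: ...
+#
+# At most ten calls may start within any 60 second window; excess callers wait.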
+
+
+@overload
+def throttle[**Args, Result](
+    *,
+    limit: int = 1,
+    period: timedelta | float = 1,
+) -> Callable[
+    [Callable[Args, Coroutine[None, None, Result]]], Callable[Args, Coroutine[None, None, Result]]
+]: ...
+
+
+def throttle[**Args, Result](
+    function: Callable[Args, Coroutine[None, None, Result]] | None = None,
+    *,
+    limit: int = 1,
+    period: timedelta | float = 1,
+) -> (
+    Callable[
+        [Callable[Args, Coroutine[None, None, Result]]],
+        Callable[Args, Coroutine[None, None, Result]],
+    ]
+    | Callable[Args, Coroutine[None, None, Result]]
+):
+    """\
+    Throttle for function calls with a custom limit and period time. \
+    Works only for async functions by waiting the desired time before execution. \
+    It is not allowed to be used on class or instance methods. \
+    This wrapper is not thread safe.
+
+    Parameters
+    ----------
+    function: Callable[Args, Coroutine[None, None, Result]]
+        function to wrap in throttle
+    limit: int
+        limit of executions within the given period, default is 1
+    period: timedelta | float
+        period time (in seconds by default) during which the limit resets, default is 1 second
+
+    Returns
+    -------
+    Callable[[Callable[Args, Coroutine[None, None, Result]]], Callable[Args, Coroutine[None, None, Result]]] \
+        | Callable[Args, Coroutine[None, None, Result]]
+        provided function wrapped in throttle
+    """  # noqa: E501
+
+    def _wrap(
+        function: Callable[Args, Coroutine[None, None, Result]],
+    ) -> Callable[Args, Coroutine[None, None, Result]]:
+        assert iscoroutinefunction(function)  # nosec: B101
+        return cast(
+            Callable[Args, Coroutine[None, None, Result]],
+            _AsyncThrottle(
+                function,
+                limit=limit,
+                period=period,
+            ),
+        )
+
+    if function := function:
+        return _wrap(function)
+
+    else:
+        return _wrap
+
+
+class _AsyncThrottle[**Args, Result]:
+    def __init__(
+        self,
+        function: Callable[Args, Coroutine[None, None, Result]],
+        /,
+        limit: int,
+        period: timedelta | float,
+    ) -> None:
+        self._function: Callable[Args, Coroutine[None, None, Result]] = function
+        self._entries: deque[float] = deque()
+        self._lock: Lock = Lock()
+        self._limit: int = limit
+        self._period: float
+        match period:
+            case timedelta() as delta:
+                self._period = delta.total_seconds()
+
+            case period_seconds:
+                self._period = period_seconds
+
+        # mimic function attributes if able
+        mimic_function(function, within=self)
+
+    async def __call__(
+        self,
+        *args: Args.args,
+        **kwargs: Args.kwargs,
+    ) -> Result:
+        async with self._lock:
+            time_now: float = monotonic()
+            while self._entries:  # cleanup old entries
+                if self._entries[0] + self._period <= time_now:
+                    self._entries.popleft()
+
+                else:
+                    break
+
+            if len(self._entries) >= self._limit:
+                # wait until the oldest entry leaves the period window
+                await sleep(self._entries[0] + self._period - time_now)
+
+            self._entries.append(monotonic())
+
+        return await self._function(*args, **kwargs)
diff --git a/src/haiway/helpers/timeout.py b/src/haiway/helpers/timeout.py
new file mode 100644
index 0000000..16c1c72
--- /dev/null
+++ b/src/haiway/helpers/timeout.py
@@ -0,0 +1,112 @@
+from asyncio import AbstractEventLoop, Future, Task, TimerHandle, get_running_loop
+from collections.abc import Callable, Coroutine
+
+from haiway.utils.mimic import mimic_function
+
+__all__ = [
+    "with_timeout",
+]
+
+
+def with_timeout[**Args, Result](
+    timeout: float,
+    /,
+) -> Callable[
+    [Callable[Args, Coroutine[None, None, Result]]],
+    Callable[Args, Coroutine[None, None, Result]],
+]:
+    """\
+    Timeout wrapper for a function call. \
+    If the timeout passes before the wrapped function returns, its execution will be \
+    cancelled and a TimeoutError will be raised. Make sure that the wrapped \
+    function handles cancellation properly.
+    This wrapper is not thread safe.
+
+    Parameters
+    ----------
+    timeout: float
+        timeout time in seconds
+
+    Returns
+    -------
+    Callable[[Callable[_Args, _Result]], Callable[_Args, _Result]]
+        function wrapper adding timeout
+    """
+
+    def _wrap(
+        function: Callable[Args, Coroutine[None, None, Result]],
+    ) -> Callable[Args, Coroutine[None, None, Result]]:
+        return _AsyncTimeout(
+            function,
+            timeout=timeout,
+        )
+
+    return _wrap
+
+
+class _AsyncTimeout[**Args, Result]:
+    def __init__(
+        self,
+        function: Callable[Args, Coroutine[None, None, Result]],
+        /,
+        timeout: float,
+    ) -> None:
+        self._function: Callable[Args, Coroutine[None, None, Result]] = function
+        self._timeout: float = timeout
+
+        # mimic function attributes if able
+        mimic_function(function, within=self)
+
+    async def __call__(
+        self,
+        *args: Args.args,
+        **kwargs: Args.kwargs,
+    ) -> Result:
+        loop: AbstractEventLoop = get_running_loop()
+        future: Future[Result] = loop.create_future()
+        task: Task[Result] = loop.create_task(
+            self._function(
+                *args,
+                **kwargs,
+            ),
+        )
+
+        def on_timeout(
+            future: Future[Result],
+        ) -> None:
+            if future.done():
+                return  # ignore if already finished
+
+            # completing the result future will also complete the task (see on_result)
+            future.set_exception(TimeoutError())
+
+        timeout_handle: TimerHandle = loop.call_later(
+            self._timeout,
+            on_timeout,
+            future,
+        )
+
+        def on_completion(
+            task: Task[Result],
+        ) -> None:
+            timeout_handle.cancel()  # at this stage we no longer need timeout to trigger
+
+            if future.done():
+                return  # ignore if already finished
+
+            try:
+                future.set_result(task.result())
+
+            except Exception as exc:
+                future.set_exception(exc)
+
+        task.add_done_callback(on_completion)
+
+        def on_result(
+            future: Future[Result],
+        ) -> None:
+            task.cancel()  # when the result future completes make sure the task also completes
+
+        future.add_done_callback(on_result)
+
+        return await future
diff --git a/src/haiway/py.typed b/src/haiway/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/src/haiway/state/__init__.py b/src/haiway/state/__init__.py
new file mode 100644
index 0000000..16bf5a4
--- /dev/null
+++ b/src/haiway/state/__init__.py
@@ -0,0 +1,8 @@
+from haiway.state.attributes import AttributeAnnotation, attribute_annotations
+from haiway.state.structure import Structure
+
+__all__ = [
+    "attribute_annotations",
+    "AttributeAnnotation",
+    "Structure",
+]
diff --git a/src/haiway/state/attributes.py b/src/haiway/state/attributes.py
new file mode 100644
index 0000000..74f4b77
--- /dev/null
+++ b/src/haiway/state/attributes.py
@@ -0,0 +1,360 @@
+import sys
+import types
+import typing
+from collections.abc import Mapping
+from types import NoneType, UnionType
+from typing import (
+    Any,
+    ClassVar,
+    ForwardRef,
+    Generic,
+    Literal,
+    TypeAliasType,
+    TypeVar,
+    get_args,
+    get_origin,
+    get_type_hints,
+)
+
+__all__ = [
+    "attribute_annotations",
+    "AttributeAnnotation",
+]
+
+
+class AttributeAnnotation:
+    def __init__(
+        self,
+        *,
+        origin: Any,
+        arguments: list[Any],
+    ) -> None:
+        self.origin: Any = origin
+        self.arguments: list[Any] = arguments
+
+    def __eq__(
+        self,
+        other: Any,
+    ) -> bool:
+        return self is other or (
+            isinstance(other, self.__class__)
+            and self.origin == other.origin
+            and self.arguments == 
other.arguments + ) + + +def attribute_annotations( + cls: type[Any], + /, + type_parameters: dict[str, Any] | None = None, +) -> dict[str, AttributeAnnotation]: + type_parameters = type_parameters or {} + + self_annotation = AttributeAnnotation( + origin=cls, + arguments=[], # ignore self arguments here, Structure will have them resolved at this stage + ) + localns: dict[str, Any] = {cls.__name__: cls} + recursion_guard: dict[Any, AttributeAnnotation] = {cls: self_annotation} + attributes: dict[str, AttributeAnnotation] = {} + + for key, annotation in get_type_hints(cls, localns=localns).items(): + # do not include ClassVars, private or dunder items + if ((get_origin(annotation) or annotation) is ClassVar) or key.startswith("_"): + continue + + attributes[key] = _resolve_attribute_annotation( + annotation, + self_annotation=self_annotation, + type_parameters=type_parameters, + module=cls.__module__, + localns=localns, + recursion_guard=recursion_guard, + ) + + return attributes + + +def _resolve_attribute_annotation( # noqa: C901, PLR0911, PLR0912, PLR0913 + annotation: Any, + /, + self_annotation: AttributeAnnotation | None, + type_parameters: dict[str, Any], + module: str, + localns: dict[str, Any], + recursion_guard: Mapping[Any, AttributeAnnotation], # TODO: verify recursion! +) -> AttributeAnnotation: + # resolve annotation directly if able + match annotation: + # None + case types.NoneType | types.NoneType(): + return AttributeAnnotation( + origin=NoneType, + arguments=[], + ) + + # forward reference through string + case str() as forward_ref: + return _resolve_attribute_annotation( + ForwardRef(forward_ref, module=module)._evaluate( + globalns=None, + localns=localns, + recursive_guard=frozenset(), + ), + self_annotation=self_annotation, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, # we might need to update it somehow? + ) + + # forward reference directly + case typing.ForwardRef() as reference: + return _resolve_attribute_annotation( + reference._evaluate( + globalns=None, + localns=localns, + recursive_guard=frozenset(), + ), + self_annotation=self_annotation, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, # we might need to update it somehow? 
+ ) + + # generic alias aka parametrized type + case types.GenericAlias() as generic_alias: + match get_origin(generic_alias): + # check for an alias with parameters + case typing.TypeAliasType() as alias: # pyright: ignore[reportUnnecessaryComparison] + type_alias: AttributeAnnotation = AttributeAnnotation( + origin=TypeAliasType, + arguments=[], + ) + resolved: AttributeAnnotation = _resolve_attribute_annotation( + alias.__value__, + self_annotation=None, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ) + type_alias.origin = resolved.origin + type_alias.arguments = resolved.arguments + return type_alias + + # check if we can resolve it as generic + case parametrized if issubclass(parametrized, Generic): + parametrized_type: Any = parametrized.__class_getitem__( # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] + *( + type_parameters.get( + arg.__name__, + arg.__bound__ or Any, + ) + if isinstance(arg, TypeVar) + else arg + for arg in get_args(generic_alias) + ) + ) + + match parametrized_type: + # verify if we got any specific type or generic alias again + case types.GenericAlias(): + return AttributeAnnotation( + origin=parametrized, + arguments=[ + _resolve_attribute_annotation( + argument, + self_annotation=self_annotation, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ) + for argument in get_args(generic_alias) + ], + ) + + # use resolved type if it is not an alias again + case _: + return AttributeAnnotation( + origin=parametrized_type, + arguments=[], + ) + + # anything else - try to resolve a concrete type or use as is + case origin: + return AttributeAnnotation( + origin=origin, + arguments=[ + _resolve_attribute_annotation( + argument, + self_annotation=self_annotation, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ) + for argument in get_args(generic_alias) + ], + ) + + # type alias + case typing.TypeAliasType() as alias: + type_alias: AttributeAnnotation = AttributeAnnotation( + origin=TypeAliasType, + arguments=[], + ) + resolved: AttributeAnnotation = _resolve_attribute_annotation( + alias.__value__, + self_annotation=None, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ) + type_alias.origin = resolved.origin + type_alias.arguments = resolved.arguments + return type_alias + + # type parameter + case typing.TypeVar(): + return _resolve_attribute_annotation( + # try to resolve it from current parameters if able + type_parameters.get( + annotation.__name__, + # use bound as default or Any otherwise + annotation.__bound__ or Any, + ), + self_annotation=None, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ) + + case typing.ParamSpec(): + sys.stderr.write( + "ParamSpec is not supported for attribute annotations," + " ignoring with Any type - it might incorrectly validate types\n" + ) + return AttributeAnnotation( + origin=Any, + arguments=[], + ) + + case typing.TypeVarTuple(): + sys.stderr.write( + "TypeVarTuple is not supported for attribute annotations," + " ignoring with Any type - it might incorrectly validate types\n" + ) + return AttributeAnnotation( + origin=Any, + arguments=[], + ) + + case _: + pass # proceed to resolving based on origin + + # resolve based on origin if any + match get_origin(annotation) or annotation: + case types.UnionType | 
typing.Union: + return AttributeAnnotation( + origin=UnionType, # pyright: ignore[reportArgumentType] + arguments=[ + recursion_guard.get( + argument, + _resolve_attribute_annotation( + argument, + self_annotation=self_annotation, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ), + ) + for argument in get_args(annotation) + ], + ) + + case typing.Callable: # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] + return AttributeAnnotation( + origin=typing.Callable, + arguments=[ + _resolve_attribute_annotation( + argument, + self_annotation=self_annotation, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ) + for argument in get_args(annotation) + ], + ) + + case typing.Self: # pyright: ignore[reportUnknownMemberType] + if not self_annotation: + sys.stderr.write( + "Unresolved Self attribute annotation," + " ignoring with Any type - it might incorrectly validate types\n" + ) + return AttributeAnnotation( + origin=Any, + arguments=[], + ) + + return self_annotation + + # unwrap from irrelevant type wrappers + case typing.Annotated | typing.Final | typing.Required | typing.NotRequired: + return _resolve_attribute_annotation( + get_args(annotation)[0], + self_annotation=self_annotation, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ) + + case typing.Optional: # optional is a Union[Value, None] + return AttributeAnnotation( + origin=UnionType, # pyright: ignore[reportArgumentType] + arguments=[ + _resolve_attribute_annotation( + get_args(annotation)[0], + self_annotation=self_annotation, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ), + AttributeAnnotation( + origin=NoneType, + arguments=[], + ), + ], + ) + + case typing.Literal: + return AttributeAnnotation( + origin=Literal, + arguments=list(get_args(annotation)), + ) + + case other: # finally use whatever there was + return AttributeAnnotation( + origin=other, + arguments=[ + _resolve_attribute_annotation( + argument, + self_annotation=self_annotation, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ) + for argument in get_args(other) + ], + ) diff --git a/src/haiway/state/structure.py b/src/haiway/state/structure.py new file mode 100644 index 0000000..dc5f14f --- /dev/null +++ b/src/haiway/state/structure.py @@ -0,0 +1,254 @@ +from collections.abc import Callable +from copy import deepcopy +from types import GenericAlias +from typing import ( + Any, + ClassVar, + Generic, + Self, + TypeVar, + cast, + dataclass_transform, + final, + get_origin, +) +from weakref import WeakValueDictionary + +from haiway.state.attributes import AttributeAnnotation, attribute_annotations +from haiway.state.validation import attribute_type_validator +from haiway.types.missing import MISSING, Missing + +__all__ = [ + "Structure", +] + + +@final +class StructureAttribute[Value]: + def __init__( + self, + annotation: AttributeAnnotation, + default: Value | Missing, + validator: Callable[[Any], Value], + ) -> None: + self.annotation: AttributeAnnotation = annotation + self.default: Value | Missing = default + self.validator: Callable[[Any], Value] = validator + + def validated( + self, + value: Any | Missing, + /, + ) -> Value: + return self.validator(self.default if value is MISSING else value) + + +@dataclass_transform( + kw_only_default=True, + 
frozen_default=True, + field_specifiers=(), +) +class StructureMeta(type): + def __new__( + cls, + /, + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + type_parameters: dict[str, Any] | None = None, + **kwargs: Any, + ) -> Any: + structure_type = type.__new__( + cls, + name, + bases, + namespace, + **kwargs, + ) + + attributes: dict[str, StructureAttribute[Any]] = {} + + if bases: # handle base class + for key, annotation in attribute_annotations( + structure_type, + type_parameters=type_parameters, + ).items(): + # do not include ClassVars and dunder items + if ((get_origin(annotation) or annotation) is ClassVar) or key.startswith("__"): + continue + + attributes[key] = StructureAttribute( + annotation=annotation, + default=getattr(structure_type, key, MISSING), + validator=attribute_type_validator(annotation), + ) + + structure_type.__ATTRIBUTES__ = attributes # pyright: ignore[reportAttributeAccessIssue] + structure_type.__slots__ = frozenset(attributes.keys()) # pyright: ignore[reportAttributeAccessIssue] + structure_type.__match_args__ = structure_type.__slots__ # pyright: ignore[reportAttributeAccessIssue] + + return structure_type + + +_types_cache: WeakValueDictionary[ + tuple[ + Any, + tuple[Any, ...], + ], + Any, +] = WeakValueDictionary() + + +class Structure(metaclass=StructureMeta): + """ + Base class for immutable data structures. + """ + + __ATTRIBUTES__: ClassVar[dict[str, StructureAttribute[Any]]] + + def __class_getitem__( + cls, + type_argument: tuple[type[Any], ...] | type[Any], + ) -> type[Self]: + assert Generic in cls.__bases__, "Can't specialize non generic type!" # nosec: B101 + + type_arguments: tuple[type[Any], ...] + match type_argument: + case [*arguments]: + type_arguments = tuple(arguments) + + case argument: + type_arguments = (argument,) + + if any(isinstance(argument, TypeVar) for argument in type_arguments): # pyright: ignore[reportUnnecessaryIsInstance] + # if we got unfinished type treat it as an alias instead of resolving + return cast(type[Self], GenericAlias(cls, type_arguments)) + + assert len(type_arguments) == len( # nosec: B101 + cls.__type_params__ + ), "Type arguments count has to match type parameters count" + + if cached := _types_cache.get((cls, type_arguments)): + return cached + + type_parameters: dict[str, Any] = { + parameter.__name__: argument + for (parameter, argument) in zip( + cls.__type_params__ or (), + type_arguments or (), + strict=False, + ) + } + + parameter_names: str = ",".join( + getattr( + argument, + "__name__", + str(argument), + ) + for argument in type_arguments + ) + name: str = f"{cls.__name__}[{parameter_names}]" + bases: tuple[type[Self]] = (cls,) + + parametrized_type: type[Self] = StructureMeta.__new__( + cls.__class__, + name=name, + bases=bases, + namespace={"__module__": cls.__module__}, + type_parameters=type_parameters, + ) + _types_cache[(cls, type_arguments)] = parametrized_type + return parametrized_type + + def __init__( + self, + **kwargs: Any, + ) -> None: + for name, attribute in self.__ATTRIBUTES__.items(): + object.__setattr__( + self, # pyright: ignore[reportUnknownArgumentType] + name, + attribute.validated( + kwargs.get( + name, + MISSING, + ), + ), + ) + + def updated( + self, + **kwargs: Any, + ) -> Self: + return self.__replace__(**kwargs) + + def as_dict(self) -> dict[str, Any]: + return vars(self) + + def __str__(self) -> str: + attributes: str = ", ".join([f"{key}: {value}" for key, value in vars(self).items()]) + return f"{self.__class__.__name__}({attributes})" + + 
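+    # Usage sketch (illustrative, names hypothetical):
+    #
+    #     class User(Structure):
+    #         name: str
+    #         age: int = 0
+    #
+    #     user = User(name="Ann")      # attributes are validated on init
+    #     older = user.updated(age=1)  # immutable - returns a new instance
+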
def __repr__(self) -> str: + return str(self) + + def __eq__( + self, + other: Any, + ) -> bool: + if not issubclass(other.__class__, self.__class__): + return False + + return all( + getattr(self, key, MISSING) == getattr(other, key, MISSING) + for key in self.__ATTRIBUTES__.keys() + ) + + def __setattr__( + self, + name: str, + value: Any, + ) -> Any: + raise AttributeError( + f"Can't modify immutable structure {self.__class__.__qualname__}," + f" attribute - '{name}' cannot be modified" + ) + + def __delattr__( + self, + name: str, + ) -> None: + raise AttributeError( + f"Can't modify immutable structure {self.__class__.__qualname__}," + f" attribute - '{name}' cannot be deleted" + ) + + def __copy__(self) -> Self: + return self.__class__(**vars(self)) + + def __deepcopy__( + self, + memo: dict[int, Any] | None, + ) -> Self: + copy: Self = self.__class__( + **{ + key: deepcopy( + value, + memo, + ) + for key, value in vars(self).items() + } + ) + return copy + + def __replace__( + self, + **kwargs: Any, + ) -> Self: + return self.__class__( + **{ + **vars(self), + **kwargs, + } + ) diff --git a/src/haiway/state/validation.py b/src/haiway/state/validation.py new file mode 100644 index 0000000..251784e --- /dev/null +++ b/src/haiway/state/validation.py @@ -0,0 +1,125 @@ +import types +import typing +from collections.abc import Callable, Sequence +from typing import Any + +from haiway import types as _types +from haiway.state.attributes import AttributeAnnotation + +__all__ = [ + "attribute_type_validator", +] + + +def attribute_type_validator( + annotation: AttributeAnnotation, + /, +) -> Callable[[Any], Any]: + match annotation.origin: + case types.NoneType: + return _none_validator + + case _types.Missing: + return _missing_validator + + case types.UnionType: + return _prepare_union_validator(annotation.arguments) + + case typing.Literal: + return _prepare_literal_validator(annotation.arguments) + + case typing.Any: + return _any_validator + + case type() as other_type: + return _prepare_type_validator(other_type) + + case other: + raise TypeError(f"Unsupported type annotation: {other}") + + +def _none_validator( + value: Any, +) -> Any: + match value: + case None: + return None + + case _: + raise TypeError(f"Type '{type(value)}' is not matching expected type 'None'") + + +def _missing_validator( + value: Any, +) -> Any: + match value: + case _types.Missing(): + return _types.MISSING + + case _: + raise TypeError(f"Type '{type(value)}' is not matching expected type 'Missing'") + + +def _any_validator( + value: Any, +) -> Any: + return value # any is always valid + + +def _prepare_union_validator( + elements: Sequence[AttributeAnnotation], + /, +) -> Callable[[Any], Any]: + validators: list[Callable[[Any], Any]] = [ + attribute_type_validator(alternative) for alternative in elements + ] + + def union_validator( + value: Any, + ) -> Any: + errors: list[Exception] = [] + for validator in validators: + try: + return validator(value) + + except Exception as exc: + errors.append(exc) + + raise ExceptionGroup("Multiple errors", errors) + + return union_validator + + +def _prepare_literal_validator( + elements: Sequence[Any], + /, +) -> Callable[[Any], Any]: + def literal_validator( + value: Any, + ) -> Any: + if value in elements: + return value + + else: + raise ValueError(f"Value '{value}' is not matching expected '{elements}'") + + return literal_validator + + +def _prepare_type_validator( + validated_type: type[Any], + /, +) -> Callable[[Any], Any]: + def type_validator( + value: Any, + ) 
-> Any: + match value: + case value if isinstance(value, validated_type): + return value + + case _: + raise TypeError( + f"Type '{type(value)}' is not matching expected type '{validated_type}'" + ) + + return type_validator diff --git a/src/haiway/types/__init__.py b/src/haiway/types/__init__.py new file mode 100644 index 0000000..c810568 --- /dev/null +++ b/src/haiway/types/__init__.py @@ -0,0 +1,11 @@ +from haiway.types.frozen import frozenlist +from haiway.types.missing import MISSING, Missing, is_missing, not_missing, when_missing + +__all__ = [ + "frozenlist", + "is_missing", + "Missing", + "MISSING", + "not_missing", + "when_missing", +] diff --git a/src/haiway/types/frozen.py b/src/haiway/types/frozen.py new file mode 100644 index 0000000..116f57a --- /dev/null +++ b/src/haiway/types/frozen.py @@ -0,0 +1,5 @@ +__all__ = [ + "frozenlist", +] + +type frozenlist[Value] = tuple[Value, ...] diff --git a/src/haiway/types/missing.py b/src/haiway/types/missing.py new file mode 100644 index 0000000..5d6371f --- /dev/null +++ b/src/haiway/types/missing.py @@ -0,0 +1,91 @@ +from typing import Any, Final, TypeGuard, cast, final + +__all__ = [ + "MISSING", + "Missing", + "is_missing", + "not_missing", + "when_missing", +] + + +class MissingType(type): + _instance: Any = None + + def __call__(cls) -> Any: + if cls._instance is None: + cls._instance = super().__call__() + return cls._instance + + else: + return cls._instance + + +@final +class Missing(metaclass=MissingType): + """ + Type representing absence of a value. Use MISSING constant for its value. + """ + + def __bool__(self) -> bool: + return False + + def __eq__( + self, + value: object, + ) -> bool: + return value is MISSING + + def __str__(self) -> str: + return "MISSING" + + def __repr__(self) -> str: + return "MISSING" + + def __getattribute__( + self, + name: str, + ) -> Any: + raise RuntimeError("Missing has no attributes") + + def __setattr__( + self, + __name: str, + __value: Any, + ) -> None: + raise RuntimeError("Missing can't be modified") + + def __delattr__( + self, + __name: str, + ) -> None: + raise RuntimeError("Missing can't be modified") + + +MISSING: Final[Missing] = Missing() + + +def is_missing( + check: Any | Missing, + /, +) -> TypeGuard[Missing]: + return check is MISSING + + +def not_missing[Value]( + check: Value | Missing, + /, +) -> TypeGuard[Value]: + return check is not MISSING + + +def when_missing[Value]( + check: Value | Missing, + /, + value: Value, +) -> Value: + if check is MISSING: + return value + + else: + return cast(Value, check) diff --git a/src/haiway/utils/__init__.py b/src/haiway/utils/__init__.py new file mode 100644 index 0000000..6b0dd45 --- /dev/null +++ b/src/haiway/utils/__init__.py @@ -0,0 +1,23 @@ +from haiway.utils.always import always, async_always +from haiway.utils.env import getenv_bool, getenv_float, getenv_int, getenv_str, load_env +from haiway.utils.immutable import freeze +from haiway.utils.logs import setup_logging +from haiway.utils.mimic import mimic_function +from haiway.utils.noop import async_noop, noop +from haiway.utils.queue import AsyncQueue + +__all__ = [ + "always", + "async_always", + "async_noop", + "AsyncQueue", + "freeze", + "getenv_bool", + "getenv_float", + "getenv_int", + "getenv_str", + "load_env", + "mimic_function", + "noop", + "setup_logging", +] diff --git a/src/haiway/utils/always.py b/src/haiway/utils/always.py new file mode 100644 index 0000000..cd61a6d --- /dev/null +++ b/src/haiway/utils/always.py @@ -0,0 +1,61 @@ +from collections.abc import 
Callable, Coroutine +from typing import Any + +__all__ = [ + "always", + "async_always", +] + + +def always[Value]( + value: Value, + /, +) -> Callable[..., Value]: + """ + Factory method creating functions returning always the same value. + + Parameters + ---------- + value: Value + value to be always returned from prepared function + + Returns + ------- + Callable[..., Value] + function ignoring arguments and always returning the provided value. + """ + + def always_value( + *args: Any, + **kwargs: Any, + ) -> Value: + return value + + return always_value + + +def async_always[Value]( + value: Value, + /, +) -> Callable[..., Coroutine[None, None, Value]]: + """ + Factory method creating async functions returning always the same value. + + Parameters + ---------- + value: Value + value to be always returned from prepared function + + Returns + ------- + Callable[..., Coroutine[None, None, Value]] + async function ignoring arguments and always returning the provided value. + """ + + async def always_value( + *args: Any, + **kwargs: Any, + ) -> Value: + return value + + return always_value diff --git a/src/haiway/utils/env.py b/src/haiway/utils/env.py new file mode 100644 index 0000000..2de303f --- /dev/null +++ b/src/haiway/utils/env.py @@ -0,0 +1,164 @@ +from os import environ, getenv +from typing import overload + +__all__ = [ + "getenv_bool", + "getenv_int", + "getenv_float", + "getenv_str", + "load_env", +] + + +@overload +def getenv_bool( + key: str, + /, +) -> bool | None: ... + + +@overload +def getenv_bool( + key: str, + /, + default: bool, +) -> bool: ... + + +def getenv_bool( + key: str, + /, + default: bool | None = None, +) -> bool | None: + if value := getenv(key=key): + return value.lower() in ("true", "1", "t") + else: + return default + + +@overload +def getenv_int( + key: str, + /, +) -> int | None: ... + + +@overload +def getenv_int( + key: str, + /, + default: int, +) -> int: ... + + +def getenv_int( + key: str, + /, + default: int | None = None, +) -> int | None: + if value := getenv(key=key): + return int(value) + + else: + return default + + +@overload +def getenv_float( + key: str, + /, +) -> float | None: ... + + +@overload +def getenv_float( + key: str, + /, + default: float, +) -> float: ... + + +def getenv_float( + key: str, + /, + default: float | None = None, +) -> float | None: + if value := getenv(key=key): + return float(value) + + else: + return default + + +@overload +def getenv_str( + key: str, + /, +) -> str | None: ... + + +@overload +def getenv_str( + key: str, + /, + default: str, +) -> str: ... + + +def getenv_str( + key: str, + /, + default: str | None = None, +) -> str | None: + if value := getenv(key=key): + return value + else: + return default + + +def load_env( + path: str | None = None, + override: bool = True, +) -> None: + """\ + Minimalist implementation of environment variables file loader. \ + When the file is not available configuration won't be loaded. 
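+
+    An example (illustrative) of a supported file::
+
+        # comment line
+        API_KEY=secret
+        DEBUG=1
+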
+ Allows only subset of formatting: + - lines starting with '#' are ignored + - other comments are not allowed + - each element is in a new line + - each element must be a `key=value` pair without whitespaces or additional characters + - keys without values are ignored + + Parameters + ---------- + path: str + custom path to load environment variables, default is '.env' + override: bool + override existing variables on conflict if True, otherwise keep existing + """ + + try: + with open(file=path or ".env") as file: + for line in file.readlines(): + if line.startswith("#"): + continue # ignore commented + + idx: int # find where key ends + for element in enumerate(line): + if element[1] == "=": + idx: int = element[0] + break + else: # ignore keys without assignment + continue + + if idx >= len(line): + continue # ignore keys without values + + key: str = line[0:idx] + value: str = line[idx + 1 :].strip() + if value and (override or key not in environ): + environ[key] = value + + except FileNotFoundError: + pass # ignore loading if no .env available diff --git a/src/haiway/utils/immutable.py b/src/haiway/utils/immutable.py new file mode 100644 index 0000000..df5000a --- /dev/null +++ b/src/haiway/utils/immutable.py @@ -0,0 +1,28 @@ +from typing import Any + +__all__ = [ + "freeze", +] + + +def freeze( + instance: object, + /, +) -> None: + """ + Freeze object instance by replacing __delattr__ and __setattr__ to raising Exceptions. + """ + + def frozen_set( + __name: str, + __value: Any, + ) -> None: + raise RuntimeError(f"{instance.__class__.__qualname__} is frozen and can't be modified") + + def frozen_del( + __name: str, + ) -> None: + raise RuntimeError(f"{instance.__class__.__qualname__} is frozen and can't be modified") + + instance.__delattr__ = frozen_del + instance.__setattr__ = frozen_set diff --git a/src/haiway/utils/logs.py b/src/haiway/utils/logs.py new file mode 100644 index 0000000..5259922 --- /dev/null +++ b/src/haiway/utils/logs.py @@ -0,0 +1,57 @@ +from logging.config import dictConfig + +from haiway.utils.env import getenv_bool + +__all__ = [ + "setup_logging", +] + + +def setup_logging( + *loggers: str, + debug: bool = getenv_bool("DEBUG_LOGGING", __debug__), +) -> None: + """\ + Setup logging configuration and prepare specified loggers. + + Parameters + ---------- + *loggers: str + names of additional loggers to configure. 
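+
+    Example (illustrative, logger names hypothetical)::
+
+        setup_logging("app", "db")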
+ + NOTE: this function should be run only once on application start + """ + + dictConfig( + config={ + "version": 1, + "disable_existing_loggers": True, + "formatters": { + "standard": { + "format": "%(asctime)s [%(levelname)-4s] [%(name)s] %(message)s", + "datefmt": "%d/%b/%Y:%H:%M:%S +0000", + }, + }, + "handlers": { + "console": { + "level": "DEBUG" if debug else "INFO", + "formatter": "standard", + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + }, + }, + "loggers": { + name: { + "handlers": ["console"], + "level": "DEBUG" if debug else "INFO", + "propagate": False, + } + for name in loggers + }, + "root": { # root logger + "handlers": ["console"], + "level": "DEBUG" if debug else "INFO", + "propagate": False, + }, + }, + ) diff --git a/src/haiway/utils/mimic.py b/src/haiway/utils/mimic.py new file mode 100644 index 0000000..b3ce75c --- /dev/null +++ b/src/haiway/utils/mimic.py @@ -0,0 +1,77 @@ +from collections.abc import Callable +from typing import Any, cast, overload + +__all__ = [ + "mimic_function", +] + + +@overload +def mimic_function[**Args, Result]( + function: Callable[Args, Result], + /, + within: Callable[..., Any], +) -> Callable[Args, Result]: ... + + +@overload +def mimic_function[**Args, Result]( + function: Callable[Args, Result], + /, +) -> Callable[[Callable[..., Any]], Callable[Args, Result]]: ... + + +def mimic_function[**Args, Result]( + function: Callable[Args, Result], + /, + within: Callable[..., Result] | None = None, +) -> Callable[[Callable[..., Result]], Callable[Args, Result]] | Callable[Args, Result]: + def mimic( + target: Callable[..., Result], + ) -> Callable[Args, Result]: + # mimic function attributes if able + for attribute in ( + "__module__", + "__name__", + "__qualname__", + "__doc__", + "__annotations__", + "__type_params__", + "__defaults__", + "__kwdefaults__", + "__globals__", + ): + try: + setattr( + target, + attribute, + getattr( + function, + attribute, + ), + ) + + except AttributeError: + pass + try: + target.__dict__.update(function.__dict__) + + except AttributeError: + pass + + setattr( # noqa: B010 - mimic functools.wraps behavior for correct signature checks + target, + "__wrapped__", + function, + ) + + return cast( + Callable[Args, Result], + target, + ) + + if target := within: + return mimic(target) + + else: + return mimic diff --git a/src/haiway/utils/noop.py b/src/haiway/utils/noop.py new file mode 100644 index 0000000..21c2dd7 --- /dev/null +++ b/src/haiway/utils/noop.py @@ -0,0 +1,24 @@ +from typing import Any + +__all__ = [ + "async_noop", + "noop", +] + + +def noop( + *args: Any, + **kwargs: Any, +) -> None: + """ + Placeholder function doing nothing (no operation). + """ + + +async def async_noop( + *args: Any, + **kwargs: Any, +) -> None: + """ + Placeholder async function doing nothing (no operation). + """ diff --git a/src/haiway/utils/queue.py b/src/haiway/utils/queue.py new file mode 100644 index 0000000..00d46ff --- /dev/null +++ b/src/haiway/utils/queue.py @@ -0,0 +1,89 @@ +from asyncio import AbstractEventLoop, CancelledError, Future, get_running_loop +from collections import deque +from collections.abc import AsyncIterator +from typing import Self + +from haiway.utils.immutable import freeze + +__all__ = [ + "AsyncQueue", +] + + +class AsyncQueue[Element](AsyncIterator[Element]): + """ + Asynchronous queue supporting iteration and finishing. + Cannot be concurrently consumed by multiple readers. 
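+
+    Usage sketch (illustrative)::
+
+        queue: AsyncQueue[int] = AsyncQueue()  # requires a running event loop
+        queue.enqueue(1, 2, 3)
+        queue.finish()  # ends iteration once buffered elements are drained
+        # consume: `async for element in queue: ...`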
+ """ + + def __init__( + self, + loop: AbstractEventLoop | None = None, + ) -> None: + self._loop: AbstractEventLoop = loop or get_running_loop() + self._queue: deque[Element] = deque() + self._waiting: Future[Element] | None = None + self._finish_reason: BaseException | None = None + + freeze(self) + + def __del__(self) -> None: + self.finish() + + @property + def finished(self) -> bool: + return self._finish_reason is not None + + def enqueue( + self, + element: Element, + /, + *elements: Element, + ) -> None: + if self.finished: + raise RuntimeError("AsyncQueue is already finished") + + if self._waiting is not None and not self._waiting.done(): + self._waiting.set_result(element) + + else: + self._queue.append(element) + + self._queue.extend(elements) + + def finish( + self, + exception: BaseException | None = None, + ) -> None: + if self.finished: + return # already finished, ignore + + self._finish_reason = exception or StopAsyncIteration() + + if self._waiting is not None and not self._waiting.done(): + self._waiting.set_exception(self._finish_reason) + + def cancel(self) -> None: + self.finish(exception=CancelledError()) + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> Element: + assert self._waiting is None, "Only a single queue iterator is supported!" # nosec: B101 + + if self._queue: # check the queue, let it finish + return self._queue.popleft() + + if self._finish_reason is not None: # check if is finished + raise self._finish_reason + + try: + # create a new future to wait for next + self._waiting = self._loop.create_future() + # wait for the result + return await self._waiting + + finally: + # cleanup + self._waiting = None diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_async_queue.py b/tests/test_async_queue.py new file mode 100644 index 0000000..d4503ed --- /dev/null +++ b/tests/test_async_queue.py @@ -0,0 +1,110 @@ +from asyncio import CancelledError + +from haiway import AsyncQueue, ctx +from pytest import mark, raises + + +class FakeException(Exception): + pass + + +@mark.asyncio +async def test_fails_when_stream_fails(): + stream: AsyncQueue[int] = AsyncQueue() + stream.enqueue(0) + stream.finish(exception=FakeException()) + elements: int = 0 + with raises(FakeException): + async for _ in stream: + elements += 1 + + assert elements == 1 + + +@mark.asyncio +async def test_cancels_when_iteration_cancels(): + stream: AsyncQueue[int] = AsyncQueue() + elements: int = 0 + with raises(CancelledError): + ctx.cancel() + async for _ in stream: + elements += 1 + + assert elements == 0 + + +@mark.asyncio +async def test_ends_when_stream_ends(): + stream: AsyncQueue[int] = AsyncQueue() + stream.finish() + elements: int = 0 + async for _ in stream: + elements += 1 + + assert elements == 0 + + +@mark.asyncio +async def test_buffers_values_when_not_reading(): + stream: AsyncQueue[int] = AsyncQueue() + stream.enqueue(0) + stream.enqueue(1) + stream.enqueue(2) + stream.enqueue(3) + stream.finish() + elements: int = 0 + + async for _ in stream: + elements += 1 + + assert elements == 4 + + +@mark.asyncio +async def test_delivers_buffer_when_streaming_fails(): + stream: AsyncQueue[int] = AsyncQueue() + stream.enqueue(0) + stream.enqueue(1) + stream.enqueue(2) + stream.enqueue(3) + stream.finish(exception=FakeException()) + elements: int = 0 + + with raises(FakeException): + async for _ in stream: + elements += 1 + + assert elements == 4 + + +@mark.asyncio +async def 
test_delivers_updates_when_sending(): + stream: AsyncQueue[int] = AsyncQueue() + stream.enqueue(0) + + elements: list[int] = [] + + async for element in stream: + elements.append(element) + if len(elements) < 10: + stream.enqueue(element + 1) + else: + stream.finish() + + assert elements == list(range(0, 10)) + + +@mark.asyncio +async def test_fails_when_sending_to_finished(): + stream: AsyncQueue[int] = AsyncQueue() + stream.finish() + + with raises(RuntimeError): + stream.enqueue(42) + + +@mark.asyncio +async def test_ignores_when_finishing_when_finished(): + stream: AsyncQueue[int] = AsyncQueue() + stream.finish() + stream.finish() # should not raise diff --git a/tests/test_auto_retry.py b/tests/test_auto_retry.py new file mode 100644 index 0000000..d9392c9 --- /dev/null +++ b/tests/test_auto_retry.py @@ -0,0 +1,296 @@ +from asyncio import CancelledError, Task, sleep +from time import time +from unittest import TestCase + +from haiway import auto_retry +from pytest import mark, raises + + +class FakeException(Exception): + pass + + +@mark.asyncio +async def test_returns_value_without_errors(): + executions: int = 0 + + @auto_retry + def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + return value + + assert compute("expected") == "expected" + assert executions == 1 + + +@mark.asyncio +async def test_retries_with_errors(): + executions: int = 0 + + @auto_retry + def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + if executions == 1: + raise FakeException() + else: + return value + + assert compute("expected") == "expected" + assert executions == 2 + + +@mark.asyncio +async def test_logs_issue_with_errors(): + executions: int = 0 + test_case = TestCase() + + @auto_retry + def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + if executions == 1: + raise FakeException("fake") + else: + return value + + with test_case.assertLogs() as logs: + compute("expected") + assert executions == 2 + assert logs.output == [ + f"ERROR:root:Attempting to retry {compute.__name__}" + f" which failed due to an error: {FakeException("fake")}" + ] + + +@mark.asyncio +async def test_fails_with_exceeding_errors(): + executions: int = 0 + + @auto_retry(limit=1) + def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + raise FakeException() + + with raises(FakeException): + compute("expected") + assert executions == 2 + + +@mark.asyncio +async def test_fails_with_cancellation(): + executions: int = 0 + + @auto_retry(limit=1) + def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + raise CancelledError() + + with raises(CancelledError): + compute("expected") + assert executions == 1 + + +@mark.asyncio +async def test_retries_with_selected_errors(): + executions: int = 0 + + @auto_retry + def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + if executions == 1: + raise FakeException() + else: + return value + + assert compute("expected") == "expected" + assert executions == 2 + + +@mark.asyncio +async def test_fails_with_not_selected_errors(): + executions: int = 0 + + @auto_retry(catching={ValueError}) + def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + raise FakeException() + + with raises(FakeException): + compute("expected") + + assert executions == 1 + + +@mark.asyncio +async def test_async_returns_value_without_errors(): + executions: int = 0 + + @auto_retry + async def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + return 
value + + assert await compute("expected") == "expected" + assert executions == 1 + + +@mark.asyncio +async def test_async_retries_with_errors(): + executions: int = 0 + + @auto_retry + async def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + if executions == 1: + raise FakeException() + else: + return value + + assert await compute("expected") == "expected" + assert executions == 2 + + +@mark.asyncio +async def test_async_fails_with_exceeding_errors(): + executions: int = 0 + + @auto_retry(limit=1) + async def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + raise FakeException() + + with raises(FakeException): + await compute("expected") + assert executions == 2 + + +@mark.asyncio +async def test_async_fails_with_cancellation(): + executions: int = 0 + + @auto_retry(limit=1) + async def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + raise CancelledError() + + with raises(CancelledError): + await compute("expected") + assert executions == 1 + + +@mark.asyncio +async def test_async_fails_when_cancelled(): + executions: int = 0 + + @auto_retry(limit=1) + async def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + await sleep(1) + return value + + with raises(CancelledError): + task = Task(compute("expected")) + await sleep(0.02) + task.cancel() + await task + assert executions == 1 + + +@mark.asyncio +async def test_async_uses_delay_with_errors(): + executions: int = 0 + + @auto_retry(limit=2, delay=0.05) + async def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + raise FakeException() + + time_start: float = time() + with raises(FakeException): + await compute("expected") + assert (time() - time_start) >= 0.1 + assert executions == 3 + + +@mark.asyncio +async def test_async_uses_computed_delay_with_errors(): + executions: int = 0 + + @auto_retry(limit=2, delay=lambda attempt, _: attempt * 0.035) + async def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + raise FakeException() + + time_start: float = time() + with raises(FakeException): + await compute("expected") + assert (time() - time_start) >= 0.1 + assert executions == 3 + + +@mark.asyncio +async def test_async_logs_issue_with_errors(): + executions: int = 0 + test_case = TestCase() + + @auto_retry + async def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + if executions == 1: + raise FakeException("fake") + else: + return value + + with test_case.assertLogs() as logs: + await compute("expected") + assert executions == 2 + assert logs.output[0].startswith( + f"ERROR:root:Attempting to retry {compute.__name__}" " which failed due to an error" + ) + + +@mark.asyncio +async def test_async_retries_with_selected_errors(): + executions: int = 0 + + @auto_retry + async def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + if executions == 1: + raise FakeException() + else: + return value + + assert await compute("expected") == "expected" + assert executions == 2 + + +@mark.asyncio +async def test_async_fails_with_not_selected_errors(): + executions: int = 0 + + @auto_retry(catching={ValueError}) + async def compute(value: str, /) -> str: + nonlocal executions + executions += 1 + raise FakeException() + + with raises(FakeException): + await compute("expected") + + assert executions == 1 diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000..e65d160 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,191 @@ +from 
asyncio import CancelledError, Task, sleep
+from collections.abc import Callable, Generator
+from time import sleep as sync_sleep
+
+from haiway import cached
+from pytest import fixture, mark, raises
+
+
+class FakeException(Exception):
+    pass
+
+
+@fixture
+def fake_random() -> Callable[[], int]:
+    def numbers() -> Generator[int, None, None]:
+        yield from range(0, 65536)
+
+    generator: Generator[int, None, None] = numbers()
+
+    def random_next() -> int:
+        # draw the next pseudo "random" int from the shared generator,
+        # so each call yields a fresh, distinct value
+        return next(generator)
+
+    return random_next
+
+
+def test_returns_cached_value_with_same_argument(fake_random: Callable[[], int]):
+    @cached
+    def randomized(_: str, /) -> int:
+        return fake_random()
+
+    expected: int = randomized("expected")
+    assert randomized("expected") == expected
+
+
+def test_returns_fresh_value_with_different_argument(fake_random: Callable[[], int]):
+    @cached
+    def randomized(_: str, /) -> int:
+        return fake_random()
+
+    expected: int = randomized("expected")
+    assert randomized("checked") != expected
+
+
+def test_returns_fresh_value_with_limit_exceed(fake_random: Callable[[], int]):
+    @cached(limit=1)
+    def randomized(_: str, /) -> int:
+        return fake_random()
+
+    expected: int = randomized("expected")
+    randomized("different")
+    assert randomized("expected") != expected
+
+
+def test_returns_same_value_with_repeating_argument(fake_random: Callable[[], int]):
+    @cached(limit=2)
+    def randomized(_: str, /) -> int:
+        return fake_random()
+
+    expected: int = randomized("expected")
+    randomized("different")
+    randomized("expected")
+    randomized("more_different")
+    randomized("expected")
+    randomized("final_different")
+    assert randomized("expected") == expected
+
+
+def test_fails_with_error():
+    @cached(expiration=0.02)
+    def randomized(_: str, /) -> int:
+        raise FakeException()
+
+    with raises(FakeException):
+        randomized("expected")
+
+
+def test_returns_fresh_value_with_expiration_time_exceed(fake_random: Callable[[], int]):
+    @cached(expiration=0.01)
+    def randomized(_: str, /) -> int:
+        return fake_random()
+
+    expected: int = randomized("expected")
+    sync_sleep(0.02)
+    assert randomized("expected") != expected
+
+
+@mark.asyncio
+async def test_async_returns_cached_value_with_same_argument(fake_random: Callable[[], int]):
+    @cached
+    async def randomized(_: str, /) -> int:
+        return fake_random()
+
+    expected: int = await randomized("expected")
+    assert await randomized("expected") == expected
+
+
+@mark.asyncio
+async def test_async_returns_fresh_value_with_different_argument(fake_random: Callable[[], int]):
+    @cached
+    async def randomized(_: str, /) -> int:
+        return fake_random()
+
+    expected: int = await randomized("expected")
+    assert await randomized("checked") != expected
+
+
+@mark.asyncio
+async def test_async_returns_fresh_value_with_limit_exceed(fake_random: Callable[[], int]):
+    @cached(limit=1)
+    async def randomized(_: str, /) -> int:
+        return fake_random()
+
+    expected: int = await randomized("expected")
+    await randomized("different")
+    assert await randomized("expected") != expected
+
+
+@mark.asyncio
+async def test_async_returns_same_value_with_repeating_argument(fake_random: Callable[[], int]):
+    @cached(limit=2)
+    async def randomized(_: str, /) -> int:
+        return fake_random()
+
+    expected: int = await randomized("expected")
+    await randomized("different")
+    await randomized("expected")
+    await randomized("more_different")
+    await randomized("expected")
+    await randomized("final_different")
+    assert await randomized("expected") == expected
+
+
+@mark.asyncio
+async def test_async_returns_fresh_value_with_expiration_time_exceed(
+    fake_random: Callable[[], int],
+):
+    
@cached(expiration=0.01) + async def randomized(_: str, /) -> int: + return fake_random() + + expected: int = await randomized("expected") + await sleep(0.02) + assert await randomized("expected") != expected + + +@mark.asyncio +async def test_async_cancel_waiting_does_not_cancel_task(): + @cached + async def randomized(_: str, /) -> int: + try: + await sleep(0.5) + return 0 + except CancelledError: + return 42 + + expected: int = await randomized("expected") + cancelled = Task(randomized("expected")) + + async def delayed_cancel() -> None: + cancelled.cancel() + + Task(delayed_cancel()) + assert await randomized("expected") == expected + + +@mark.asyncio +async def test_async_expiration_does_not_cancel_task(): + @cached(expiration=0.01) + async def randomized(_: str, /) -> int: + try: + await sleep(0.02) + return 0 + except CancelledError: + return 42 + + assert await randomized("expected") == 0 + + +@mark.asyncio +async def test_async_expiration_creates_new_task(fake_random: Callable[[], int]): + @cached(expiration=0.01) + async def randomized(_: str, /) -> int: + await sleep(0.02) + return fake_random() + + assert await randomized("expected") != await randomized("expected") + + +@mark.asyncio +async def test_async_fails_with_error(): + @cached(expiration=0.02) + async def randomized(_: str, /) -> int: + raise FakeException() + + with raises(FakeException): + await randomized("expected") diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 0000000..6b5f5e5 --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,101 @@ +from haiway import MissingContext, ScopeMetrics, Structure, ctx +from pytest import mark, raises + + +class ExampleState(Structure): + state: str = "default" + + +class FakeException(Exception): + pass + + +@mark.asyncio +async def test_state_is_available_according_to_context(): + with raises(MissingContext): + assert ctx.state(ExampleState).state == "default" + + async with ctx.scope("default"): + assert ctx.state(ExampleState).state == "default" + + async with ctx.scope("specified", ExampleState(state="specified")): + assert ctx.state(ExampleState).state == "specified" + + async with ctx.scope("modified", ExampleState(state="modified")): + assert ctx.state(ExampleState).state == "modified" + + assert ctx.state(ExampleState).state == "specified" + + assert ctx.state(ExampleState).state == "default" + + with raises(MissingContext): + assert ctx.state(ExampleState).state == "default" + + +@mark.asyncio +async def test_state_update_updates_local_context(): + with raises(MissingContext): + assert ctx.state(ExampleState).state == "default" + + async with ctx.scope("default"): + assert ctx.state(ExampleState).state == "default" + + with ctx.updated(ExampleState(state="updated")): + assert ctx.state(ExampleState).state == "updated" + + with ctx.updated(ExampleState(state="modified")): + assert ctx.state(ExampleState).state == "modified" + + assert ctx.state(ExampleState).state == "updated" + + assert ctx.state(ExampleState).state == "default" + + with raises(MissingContext): + assert ctx.state(ExampleState).state == "default" + + +@mark.asyncio +async def test_exceptions_are_propagated(): + with raises(FakeException): + async with ctx.scope("outer"): + async with ctx.scope("inner"): + raise FakeException() + + +@mark.asyncio +async def test_completions_are_called_according_to_context_exits(): + executions: int = 0 + + async def completion(metrics: ScopeMetrics): + nonlocal executions + executions += 1 + + async with ctx.scope("outer", 
completion=completion): + assert executions == 0 + + async with ctx.scope("inner", completion=completion): + assert executions == 0 + + assert executions == 1 + + assert executions == 2 + + +@mark.asyncio +async def test_metrics_are_recorded_within_context(): + def verify_example_metrics(state: str): + async def completion(metrics: ScopeMetrics): + assert metrics.read(ExampleState, default=ExampleState()).state == state + + return completion + + async with ctx.scope("outer", completion=verify_example_metrics("outer-in-out")): + ctx.record(ExampleState(state="outer-in")) + + async with ctx.scope("inner", completion=verify_example_metrics("inner")): + ctx.record(ExampleState(state="inner")) + + ctx.record( + ExampleState(state="-out"), + merge=lambda lhs, rhs: ExampleState(state=lhs.state + rhs.state), + ) diff --git a/tests/test_structure.py b/tests/test_structure.py new file mode 100644 index 0000000..3f679b9 --- /dev/null +++ b/tests/test_structure.py @@ -0,0 +1,115 @@ +from collections.abc import Callable +from typing import Literal, Protocol, Self, runtime_checkable + +from haiway import Structure, frozenlist + + +def test_basic_initializes_with_arguments() -> None: + @runtime_checkable + class Proto(Protocol): + def __call__(self) -> None: ... + + class Basics(Structure): + string: str + literal: Literal["A", "B"] + sequence: list[str] + frozen: frozenlist[int] + integer: int + union: str | int + optional: str | None + none: None + function: Callable[[], None] + proto: Proto + + basic = Basics( + string="string", + literal="A", + sequence=["a", "b", "c"], + frozen=(1, 2, 3), + integer=0, + union="union", + optional="optional", + none=None, + function=lambda: None, + proto=lambda: None, + ) + assert basic.string == "string" + assert basic.literal == "A" + assert basic.sequence == ["a", "b", "c"] + assert basic.frozen == (1, 2, 3) + assert basic.integer == 0 + assert basic.union == "union" + assert basic.optional == "optional" + assert basic.none is None + assert callable(basic.function) + assert isinstance(basic.proto, Proto) + + +def test_basic_initializes_with_defaults() -> None: + class Basics(Structure): + string: str = "" + integer: int = 0 + optional: str | None = None + + basic = Basics() + assert basic.string == "" + assert basic.integer == 0 + assert basic.optional is None + + +def test_basic_initializes_with_arguments_and_defaults() -> None: + class Basics(Structure): + string: str + integer: int = 0 + optional: str | None = None + + basic = Basics( + string="string", + integer=42, + ) + assert basic.string == "string" + assert basic.integer == 42 + assert basic.optional is None + + +def test_parametrized_initializes_with_proper_parameters() -> None: + class Parametrized[T](Structure): + value: T + + parametrized_string = Parametrized( + value="string", + ) + assert parametrized_string.value == "string" + + parametrized_int = Parametrized( + value=42, + ) + assert parametrized_int.value == 42 + + assert parametrized_string != parametrized_int + + +def test_nested_initializes_with_proper_arguments() -> None: + class Nested(Structure): + string: str + + class Recursive(Structure): + nested: Nested + recursion: "Recursive | None" + self_recursion: Self | None + + recursive = Recursive( + nested=Nested(string="one"), + recursion=Recursive( + nested=Nested(string="two"), + recursion=None, + self_recursion=None, + ), + self_recursion=None, + ) + assert recursive.nested == Nested(string="one") + assert recursive.recursion == Recursive( + nested=Nested(string="two"), + recursion=None, 
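+        # NOTE: this equality check relies on Structure instances comparing
+        # by field values (an assumption the separately constructed literal
+        # below depends on), rather than by identity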
+ self_recursion=None, + ) diff --git a/tests/test_timeout.py b/tests/test_timeout.py new file mode 100644 index 0000000..4a56085 --- /dev/null +++ b/tests/test_timeout.py @@ -0,0 +1,52 @@ +from asyncio import CancelledError, Task, sleep + +from haiway import with_timeout +from pytest import mark, raises + + +class FakeException(Exception): + pass + + +@mark.asyncio +async def test_returns_result_when_returning_value(): + @with_timeout(3) + async def long_running() -> int: + return 42 + + assert await long_running() == 42 + + +@mark.asyncio +async def test_raises_with_error(): + @with_timeout(3) + async def long_running() -> int: + raise FakeException() + + with raises(FakeException): + await long_running() + + +@mark.asyncio +async def test_raises_with_cancel(): + @with_timeout(3) + async def long_running() -> int: + await sleep(1) + raise RuntimeError("Invalid state") + + task = Task(long_running()) + with raises(CancelledError): + await sleep(0.01) + task.cancel() + await task + + +@mark.asyncio +async def test_raises_with_timeout(): + @with_timeout(0.01) + async def long_running() -> int: + await sleep(0.03) + raise RuntimeError("Invalid state") + + with raises(TimeoutError): + await long_running()
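
Taken together, the suites above double as usage documentation. What follows is a minimal end-to-end sketch rather than part of the diff: it assumes only the `AsyncQueue`, `cached`, and `with_timeout` exports exercised by these tests, and the `fibonacci` helper is a hypothetical example introduced purely for illustration.

    import asyncio

    from haiway import AsyncQueue, cached, with_timeout


    @cached(limit=32)
    def fibonacci(n: int) -> int:
        # cached so repeated subproblems are computed only once
        return n if n < 2 else fibonacci(n - 1) + fibonacci(n - 2)


    @with_timeout(1)
    async def drain(queue: AsyncQueue[int]) -> list[int]:
        # iteration stops once the producer calls queue.finish();
        # the timeout guards against a producer that never finishes
        return [element async for element in queue]


    async def main() -> None:
        queue: AsyncQueue[int] = AsyncQueue()
        queue.enqueue(*(fibonacci(n) for n in range(10)))
        queue.finish()
        print(await drain(queue))  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]


    asyncio.run(main())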