From 8871aa47445d8c6d1c4a73eb85b2ecff6491069b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Goran=20Meki=C4=87?= Date: Thu, 1 Feb 2024 01:43:34 +0100 Subject: [PATCH] Use lifespan function instead of event --- .github/workflows/test-package.yml | 26 ++- benchmarks/conftest.py | 24 +-- docs/fastapi/index.md | 45 ++--- docs_src/fastapi/docs001.py | 32 +--- examples/fastapi_quick_start.py | 21 +-- ormar/models/ormar_config.py | 4 + tests/lifespan.py | 33 ++++ tests/settings.py | 14 ++ .../test_deferred/test_forward_cross_refs.py | 30 +--- tests/test_deferred/test_forward_refs.py | 81 +++------ .../test_more_same_table_joins.py | 53 ++---- tests/test_deferred/test_same_table_joins.py | 59 ++----- .../test_encryption/test_encrypted_columns.py | 30 +--- .../test_complex_relation_tree_performance.py | 22 +-- .../test_dumping_model_to_dict.py | 16 +- .../test_excludable_items.py | 17 +- .../test_excluding_fields_in_fastapi.py | 47 +---- .../test_excluding_fields_with_default.py | 33 +--- .../test_excluding_subset_of_columns.py | 37 ++-- .../test_pydantic_dict_params.py | 34 ++-- tests/test_fastapi/test_binary_fields.py | 41 +---- ...est_docs_with_multiple_relations_to_one.py | 17 +- tests/test_fastapi/test_enum_schema.py | 18 +- .../test_excludes_with_get_pydantic.py | 56 +++--- tests/test_fastapi/test_excluding_fields.py | 45 +---- .../test_extra_ignore_parameter.py | 39 +---- tests/test_fastapi/test_fastapi_docs.py | 38 +--- tests/test_fastapi/test_fastapi_usage.py | 26 ++- .../test_inheritance_concrete_fastapi.py | 162 ++++++++++++++---- .../test_inheritance_mixins_fastapi.py | 63 +++---- tests/test_fastapi/test_json_field_fastapi.py | 44 +---- tests/test_fastapi/test_m2m_forwardref.py | 39 +---- .../test_more_reallife_fastapi.py | 44 +---- tests/test_fastapi/test_nested_saving.py | 52 ++---- tests/test_fastapi/test_recursion_error.py | 52 +----- .../test_relations_with_nested_defaults.py | 40 +---- .../test_schema_not_allowed_params.py | 13 +- .../test_fastapi/test_skip_reverse_models.py | 39 +---- tests/test_fastapi/test_wekref_exclusion.py | 44 +---- ...est_excluding_parent_fields_inheritance.py | 51 ++---- .../test_geting_pydantic_models.py | 14 +- .../test_inheritance_concrete.py | 98 ++++------- .../test_inheritance_mixins.py | 46 ++--- .../test_inheritance_of_property_fields.py | 31 +--- .../test_inheritance_with_default.py | 23 +-- ...erited_class_is_not_abstract_by_default.py | 27 +-- .../test_nested_models_pydantic.py | 16 +- .../test_pydantic_fields_order.py | 27 +-- .../test_validators_are_inherited.py | 18 +- .../test_validators_in_generated_pydantic.py | 22 +-- .../test_check_constraints.py | 25 +-- .../test_index_constraints.py | 25 +-- .../test_unique_constraints.py | 25 +-- tests/test_model_definition/test_aliases.py | 45 ++--- tests/test_model_definition/test_columns.py | 42 ++--- .../test_create_uses_init_for_consistency.py | 22 +-- .../test_dates_with_timezone.py | 51 ++---- .../test_equality_and_hash.py | 29 +--- .../test_extra_ignore_parameter.py | 16 +- .../test_fields_access.py | 23 +-- ...oreign_key_value_used_for_related_model.py | 28 +-- tests/test_model_definition/test_iterate.py | 87 ++++------ .../test_model_construct.py | 53 ++---- .../test_model_definition.py | 59 ++----- tests/test_model_definition/test_models.py | 158 ++++++----------- .../test_models_are_pickable.py | 31 +--- .../test_overwriting_pydantic_field_type.py | 25 +-- .../test_overwriting_sql_nullable.py | 24 +-- .../test_pk_field_is_always_not_null.py | 16 +- .../test_model_definition/test_properties.py | 25 +-- .../test_pydantic_fields.py | 26 +-- .../test_pydantic_only_fields.py | 27 +-- .../test_pydantic_private_attributes.py | 16 +- .../test_model_definition/test_save_status.py | 73 +++----- .../test_saving_nullable_fields.py | 33 +--- .../test_server_default.py | 27 +-- .../test_setting_comments_in_db.py | 21 +-- .../test_excludes_in_load_all.py | 25 +-- tests/test_model_methods/test_load_all.py | 41 ++--- .../test_populate_default_values.py | 18 +- tests/test_model_methods/test_save_related.py | 61 ++----- .../test_save_related_from_dict.py | 77 +++------ .../test_save_related_uuid.py | 32 +--- tests/test_model_methods/test_update.py | 31 +--- tests/test_model_methods/test_upsert.py | 29 +--- .../test_ordering/test_default_model_order.py | 29 +--- .../test_default_relation_order.py | 31 +--- .../test_default_through_relation_order.py | 25 +-- .../test_proper_order_of_sorting_apply.py | 25 +-- tests/test_queries/test_adding_related.py | 29 +--- tests/test_queries/test_aggr_functions.py | 35 ++-- .../test_deep_relations_select_all.py | 81 ++------- tests/test_queries/test_filter_groups.py | 16 +- .../test_indirect_relations_to_self.py | 31 +--- tests/test_queries/test_isnull_filter.py | 31 +--- .../test_nested_reverse_relations.py | 22 +-- .../test_non_relation_fields_not_merged.py | 23 +-- tests/test_queries/test_or_filters.py | 23 +-- tests/test_queries/test_order_by.py | 73 ++------ tests/test_queries/test_pagination.py | 37 ++-- .../test_queryproxy_on_m2m_models.py | 54 ++---- .../test_queryset_level_methods.py | 91 ++++------ ...t_quoting_table_names_in_on_join_clause.py | 37 +--- .../test_reserved_sql_keywords_escaped.py | 25 +-- .../test_queries/test_reverse_fk_queryset.py | 55 ++---- .../test_selecting_subset_of_columns.py | 55 ++---- .../test_values_and_values_list.py | 54 +++--- tests/test_relations/test_cascades.py | 57 ++---- ...ustomizing_through_model_relation_names.py | 31 +--- .../test_database_fk_creation.py | 49 ++---- tests/test_relations/test_foreign_keys.py | 105 ++++-------- .../test_relations/test_m2m_through_fields.py | 47 ++--- tests/test_relations/test_many_to_many.py | 54 ++---- ...est_postgress_select_related_with_limit.py | 25 +-- tests/test_relations/test_prefetch_related.py | 85 +++------ ...efetch_related_multiple_models_relation.py | 37 +--- .../test_python_style_relations.py | 38 ++-- .../test_relations_default_exception.py | 36 ++-- .../test_replacing_models_with_copy.py | 29 +--- tests/test_relations/test_saving_related.py | 37 ++-- .../test_select_related_with_limit.py | 43 ++--- ...select_related_with_m2m_and_pk_name_set.py | 22 +-- .../test_selecting_proper_table_prefix.py | 37 +--- tests/test_relations/test_skipping_reverse.py | 30 +--- .../test_through_relations_fail.py | 17 +- tests/test_relations/test_weakref_checking.py | 24 +-- tests/test_signals/test_signals.py | 65 +++---- .../test_signals_for_relations.py | 47 ++--- tests/test_utils/test_queryset_utils.py | 20 +-- 129 files changed, 1596 insertions(+), 3470 deletions(-) create mode 100644 tests/lifespan.py diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml index 8f1d7bdc1..9456a4300 100644 --- a/.github/workflows/test-package.yml +++ b/.github/workflows/test-package.yml @@ -39,7 +39,7 @@ jobs: POSTGRES_DB: testsuite ports: - 5432:5432 - options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 --name postgres steps: - name: Checkout @@ -71,6 +71,30 @@ jobs: DATABASE_URL: "mysql://username:password@127.0.0.1:3306/testsuite" run: bash scripts/test.sh + - name: Install postgresql-client + run: | + sudo apt-get update + sudo apt-get install --yes postgresql-client + + - name: Connect to PostgreSQL with CLI + run: env PGPASSWORD=password psql -h localhost -U username -c 'SELECT VERSION();' testsuite + + - name: Show max connections + run: env PGPASSWORD=password psql -h localhost -U username -c 'SHOW max_connections;' testsuite + + - name: Alter max connections + run: | + + docker exec -i postgres bash << EOF + sed -i -e 's/max_connections = 100/max_connections = 1000/' /var/lib/postgresql/data/postgresql.conf + sed -i -e 's/shared_buffers = 128MB/shared_buffers = 512MB/' /var/lib/postgresql/data/postgresql.conf + EOF + docker restart --time 0 postgres + sleep 5 + + - name: Show max connections + run: env PGPASSWORD=password psql -h localhost -U username -c 'SHOW max_connections;' testsuite + - name: Run postgres env: DATABASE_URL: "postgresql://username:password@localhost:5432/testsuite" diff --git a/benchmarks/conftest.py b/benchmarks/conftest.py index 99cb8438e..083ab4c5c 100644 --- a/benchmarks/conftest.py +++ b/benchmarks/conftest.py @@ -3,28 +3,20 @@ import string import time -import databases import nest_asyncio import ormar import pytest import pytest_asyncio -import sqlalchemy -from tests.settings import DATABASE_URL -nest_asyncio.apply() +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() +base_ormar_config = create_config() +nest_asyncio.apply() pytestmark = pytest.mark.asyncio -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) - - class Author(ormar.Model): ormar_config = base_ormar_config.copy(tablename="authors") @@ -57,13 +49,7 @@ class Book(ormar.Model): year: int = ormar.Integer(nullable=True) -@pytest.fixture(autouse=True, scope="function") # TODO: fix this to be module -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config, scope="function") @pytest_asyncio.fixture diff --git a/docs/fastapi/index.md b/docs/fastapi/index.md index fe5b8abd3..25521a73c 100644 --- a/docs/fastapi/index.md +++ b/docs/fastapi/index.md @@ -26,7 +26,8 @@ Here you can find a very simple sample application code. ### Imports and initialization -First take care of the imports and initialization +Define startup and shutdown procedures using FastAPI lifespan and use is in the +application. ```python from typing import List, Optional @@ -36,29 +37,29 @@ from fastapi import FastAPI import ormar -app = FastAPI() -metadata = sqlalchemy.MetaData() -database = databases.Database("sqlite:///test.db") -app.state.database = database -``` +from contextlib import asynccontextmanager +from fastapi import FastAPI -### Database connection -Next define startup and shutdown events (or use middleware) -- note that this is `databases` specific setting not the ormar one -```python -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() +@asynccontextmanager +async def lifespan(_: FastAPI) -> AsyncIterator[None]: + if not config.database.is_connected: + await config.database.connect() + config.metadata.drop_all(config.engine) + config.metadata.create_all(config.engine) + + yield + + if config.database.is_connected: + config.metadata.drop_all(config.engine) + await config.database.disconnect() + + +base_ormar_config = ormar.OrmarConfig( + metadata=sqlalchemy.MetaData(), + database=databases.Database("sqlite:///test.db"), +) +app = FastAPI(lifespan=lifespan(base_ormar_config)) ``` !!!info diff --git a/docs_src/fastapi/docs001.py b/docs_src/fastapi/docs001.py index b19912691..68ee72dba 100644 --- a/docs_src/fastapi/docs001.py +++ b/docs_src/fastapi/docs001.py @@ -1,45 +1,25 @@ from typing import List, Optional -import databases import ormar -import sqlalchemy from fastapi import FastAPI -DATABASE_URL = "sqlite:///test.db" +from tests.settings import create_config +from tests.lifespan import lifespan -ormar_base_config = ormar.OrmarConfig( - database=databases.Database(DATABASE_URL), metadata=sqlalchemy.MetaData() -) -app = FastAPI() -metadata = sqlalchemy.MetaData() -database = databases.Database("sqlite:///test.db") -app.state.database = database - - -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) class Category(ormar.Model): - ormar_config = ormar_base_config.copy(tablename="categories") + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class Item(ormar.Model): - ormar_config = ormar_base_config.copy(tablename="items") + ormar_config = base_ormar_config.copy(tablename="items") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) diff --git a/examples/fastapi_quick_start.py b/examples/fastapi_quick_start.py index 8757cfb04..5883b8e8a 100644 --- a/examples/fastapi_quick_start.py +++ b/examples/fastapi_quick_start.py @@ -1,3 +1,4 @@ +from contextlib import asynccontextmanager from typing import List, Optional import databases @@ -13,25 +14,21 @@ ) -app = FastAPI() -metadata = sqlalchemy.MetaData() -database = databases.Database("sqlite:///test.db") -app.state.database = database - - -@app.on_event("startup") -async def startup() -> None: +@asynccontextmanager +async def lifespan(app: FastAPI): database_ = app.state.database if not database_.is_connected: await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: + yield database_ = app.state.database if database_.is_connected: await database_.disconnect() +app = FastAPI(lifespan=lifespan) +metadata = sqlalchemy.MetaData() +database = databases.Database("sqlite:///test.db") +app.state.database = database + class Category(ormar.Model): ormar_config = ormar_base_config.copy(tablename="categories") diff --git a/ormar/models/ormar_config.py b/ormar/models/ormar_config.py index c2a825415..e1895377e 100644 --- a/ormar/models/ormar_config.py +++ b/ormar/models/ormar_config.py @@ -28,6 +28,7 @@ def __init__( self, metadata: Optional[sqlalchemy.MetaData] = None, database: Optional[databases.Database] = None, + engine: Optional[sqlalchemy.engine.Engine] = None, tablename: Optional[str] = None, order_by: Optional[List[str]] = None, abstract: bool = False, @@ -39,6 +40,7 @@ def __init__( self.pkname = None # type: ignore self.metadata = metadata self.database = database # type: ignore + self.engine = engine # type: ignore self.tablename = tablename # type: ignore self.orders_by = order_by or [] self.columns: List[sqlalchemy.Column] = [] @@ -60,6 +62,7 @@ def copy( self, metadata: Optional[sqlalchemy.MetaData] = None, database: Optional[databases.Database] = None, + engine: Optional[sqlalchemy.engine.Engine] = None, tablename: Optional[str] = None, order_by: Optional[List[str]] = None, abstract: Optional[bool] = None, @@ -71,6 +74,7 @@ def copy( return OrmarConfig( metadata=metadata or self.metadata, database=database or self.database, + engine=engine or self.engine, tablename=tablename, order_by=order_by, abstract=abstract or self.abstract, diff --git a/tests/lifespan.py b/tests/lifespan.py new file mode 100644 index 000000000..684f358fd --- /dev/null +++ b/tests/lifespan.py @@ -0,0 +1,33 @@ +import pytest +import sqlalchemy + +from contextlib import asynccontextmanager +from fastapi import FastAPI +from typing import AsyncIterator + + +def lifespan(config): + @asynccontextmanager + async def do_lifespan(_: FastAPI) -> AsyncIterator[None]: + if not config.database.is_connected: + await config.database.connect() + + yield + + if config.database.is_connected: + await config.database.disconnect() + + return do_lifespan + + +def init_tests(config, scope="module"): + @pytest.fixture(autouse=True, scope=scope) + def create_database(): + config.engine = sqlalchemy.create_engine(config.database.url._url) + config.metadata.create_all(config.engine) + + yield + + config.metadata.drop_all(config.engine) + + return create_database diff --git a/tests/settings.py b/tests/settings.py index be1bed2a8..b2b7c9e19 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -1,9 +1,23 @@ import os import databases +import ormar +import sqlalchemy DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db") database_url = databases.DatabaseURL(DATABASE_URL) if database_url.scheme == "postgresql+aiopg": # pragma no cover DATABASE_URL = str(database_url.replace(driver=None)) print("USED DB:", DATABASE_URL) + + +def create_config(**args): + database_ = databases.Database(DATABASE_URL, **args) + metadata_ = sqlalchemy.MetaData() + engine_ = sqlalchemy.create_engine(DATABASE_URL) + + return ormar.OrmarConfig( + metadata=metadata_, + database=database_, + engine=engine_, + ) diff --git a/tests/test_deferred/test_forward_cross_refs.py b/tests/test_deferred/test_forward_cross_refs.py index 3e1aa747f..56af4d776 100644 --- a/tests/test_deferred/test_forward_cross_refs.py +++ b/tests/test_deferred/test_forward_cross_refs.py @@ -1,25 +1,17 @@ # type: ignore from typing import ForwardRef, List, Optional -import databases import ormar import pytest -import sqlalchemy as sa -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sa.MetaData() -db = databases.Database(DATABASE_URL) -engine = create_engine(DATABASE_URL) -TeacherRef = ForwardRef("Teacher") +base_ormar_config = create_config() -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, -) +TeacherRef = ForwardRef("Teacher") class Student(ormar.Model): @@ -76,17 +68,13 @@ class City(ormar.Model): Country.update_forward_refs() -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_double_relations(): - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): t1 = await Teacher.objects.create(name="Mr. Jones") t2 = await Teacher.objects.create(name="Ms. Smith") t3 = await Teacher.objects.create(name="Mr. Quibble") @@ -143,8 +131,8 @@ async def test_double_relations(): @pytest.mark.asyncio async def test_auto_through_model(): - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): england = await Country(name="England").save() france = await Country(name="France").save() london = await City(name="London", country=england).save() diff --git a/tests/test_deferred/test_forward_refs.py b/tests/test_deferred/test_forward_refs.py index 0d2d31981..7842d37af 100644 --- a/tests/test_deferred/test_forward_refs.py +++ b/tests/test_deferred/test_forward_refs.py @@ -1,28 +1,24 @@ # type: ignore from typing import ForwardRef, List, Optional -import databases import ormar import pytest import pytest_asyncio import sqlalchemy as sa from ormar.exceptions import ModelError -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests + + +base_ormar_config = create_config() -metadata = sa.MetaData() -db = databases.Database(DATABASE_URL) -engine = create_engine(DATABASE_URL) PersonRef = ForwardRef("Person") class Person(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -37,10 +33,7 @@ class Person(ormar.Model): class Child(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -54,17 +47,11 @@ class Child(ormar.Model): class ChildFriend(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy() class Game(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -73,17 +60,13 @@ class Game(ormar.Model): Child.update_forward_refs() -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest_asyncio.fixture(scope="function") async def cleanup(): yield - async with db: + async with base_ormar_config.database: await ChildFriend.objects.delete(each=True) await Child.objects.delete(each=True) await Game.objects.delete(each=True) @@ -95,10 +78,7 @@ async def test_not_updated_model_raises_errors(): Person2Ref = ForwardRef("Person2") class Person2(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -119,16 +99,10 @@ async def test_not_updated_model_m2m_raises_errors(): Person3Ref = ForwardRef("Person3") class PersonFriend(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy() class Person3(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -151,19 +125,13 @@ async def test_not_updated_model_m2m_through_raises_errors(): PersonPetRef = ForwardRef("PersonPet") class Pet(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class Person4(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -172,10 +140,7 @@ class Person4(ormar.Model): ) class PersonPet(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy() with pytest.raises(ModelError): await Person4.objects.create(name="Test") @@ -205,8 +170,8 @@ def test_proper_field_init(): @pytest.mark.asyncio async def test_self_relation(): - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): sam = await Person.objects.create(name="Sam") joe = await Person(name="Joe", supervisor=sam).save() assert joe.supervisor.name == "Sam" @@ -223,8 +188,8 @@ async def test_self_relation(): @pytest.mark.asyncio async def test_other_forwardref_relation(cleanup): - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): checkers = await Game.objects.create(name="checkers") uno = await Game(name="Uno").save() @@ -250,8 +215,8 @@ async def test_other_forwardref_relation(cleanup): @pytest.mark.asyncio async def test_m2m_self_forwardref_relation(cleanup): - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): checkers = await Game.objects.create(name="Checkers") uno = await Game(name="Uno").save() jenga = await Game(name="Jenga").save() diff --git a/tests/test_deferred/test_more_same_table_joins.py b/tests/test_deferred/test_more_same_table_joins.py index 38c73bd38..5cebb79d6 100644 --- a/tests/test_deferred/test_more_same_table_joins.py +++ b/tests/test_deferred/test_more_same_table_joins.py @@ -1,44 +1,31 @@ from typing import Optional -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config(force_rollback=True) class Department(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="departments", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="departments") id: int = ormar.Integer(primary_key=True, autoincrement=False) name: str = ormar.String(max_length=100) class SchoolClass(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="schoolclasses", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="schoolclasses") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class Category(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="categories", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -46,11 +33,7 @@ class Category(ormar.Model): class Student(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="students", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="students") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -59,11 +42,7 @@ class Student(ormar.Model): class Teacher(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="teachers", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="teachers") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -71,13 +50,7 @@ class Teacher(ormar.Model): category: Optional[Category] = ormar.ForeignKey(Category, nullable=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) async def create_data(): @@ -95,7 +68,7 @@ async def create_data(): @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(): - async with database: + async with base_ormar_config.database: await create_data() classes = await SchoolClass.objects.select_related( ["teachers__category__department", "students__category__department"] @@ -109,7 +82,7 @@ async def test_model_multiple_instances_of_same_table_in_schema(): @pytest.mark.asyncio async def test_load_all_multiple_instances_of_same_table_in_schema(): - async with database: + async with base_ormar_config.database: await create_data() math_class = await SchoolClass.objects.get(name="Math") assert math_class.name == "Math" @@ -123,7 +96,7 @@ async def test_load_all_multiple_instances_of_same_table_in_schema(): @pytest.mark.asyncio async def test_filter_groups_with_instances_of_same_table_in_schema(): - async with database: + async with base_ormar_config.database: await create_data() math_class = ( await SchoolClass.objects.select_related( diff --git a/tests/test_deferred/test_same_table_joins.py b/tests/test_deferred/test_same_table_joins.py index 17667db9c..548fded6b 100644 --- a/tests/test_deferred/test_same_table_joins.py +++ b/tests/test_deferred/test_same_table_joins.py @@ -1,33 +1,24 @@ from typing import Optional -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Department(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="departments", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="departments") id: int = ormar.Integer(primary_key=True, autoincrement=False) name: str = ormar.String(max_length=100) class SchoolClass(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="schoolclasses", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="schoolclasses") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -35,22 +26,14 @@ class SchoolClass(ormar.Model): class Category(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="categories", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class Student(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="students", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="students") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -59,11 +42,7 @@ class Student(ormar.Model): class Teacher(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="teachers", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="teachers") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -71,13 +50,7 @@ class Teacher(ormar.Model): category: Optional[Category] = ormar.ForeignKey(Category, nullable=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) async def create_data(): @@ -95,8 +68,8 @@ async def create_data(): @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await create_data() classes = await SchoolClass.objects.select_related( ["teachers__category", "students__schoolclass"] @@ -123,8 +96,8 @@ async def test_model_multiple_instances_of_same_table_in_schema(): @pytest.mark.asyncio async def test_right_tables_join(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await create_data() classes = await SchoolClass.objects.select_related( ["teachers__category", "students"] @@ -138,8 +111,8 @@ async def test_right_tables_join(): @pytest.mark.asyncio async def test_multiple_reverse_related_objects(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await create_data() classes = await SchoolClass.objects.select_related( ["teachers__category", "students__category"] diff --git a/tests/test_encryption/test_encrypted_columns.py b/tests/test_encryption/test_encrypted_columns.py index 65f341ec9..589886fd8 100644 --- a/tests/test_encryption/test_encrypted_columns.py +++ b/tests/test_encryption/test_encrypted_columns.py @@ -6,24 +6,16 @@ import uuid from typing import Any -import databases import ormar import pytest -import sqlalchemy from ormar import ModelDefinitionError, NoMatch from ormar.fields.sqlalchemy_encrypted import EncryptedString -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() default_fernet = dict( encrypt_secret="asd123", encrypt_backend=ormar.EncryptBackends.FERNET ) @@ -108,13 +100,7 @@ class Report(ormar.Model): filters = ormar.ManyToMany(Filter) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_error_on_encrypted_pk(): @@ -178,7 +164,7 @@ def test_db_structure(): @pytest.mark.asyncio async def test_save_and_retrieve(): - async with database: + async with base_ormar_config.database: test_uuid = uuid.uuid4() await Author( name="Test", @@ -222,7 +208,7 @@ async def test_save_and_retrieve(): @pytest.mark.asyncio async def test_fernet_filters_nomatch(): - async with database: + async with base_ormar_config.database: await Filter(name="test1").save() await Filter(name="test1").save() @@ -237,7 +223,7 @@ async def test_fernet_filters_nomatch(): @pytest.mark.asyncio async def test_hash_filters_works(): - async with database: + async with base_ormar_config.database: await Hash(name="test1").save() await Hash(name="test2").save() @@ -254,7 +240,7 @@ async def test_hash_filters_works(): @pytest.mark.asyncio async def test_related_model_fields_properly_decrypted(): - async with database: + async with base_ormar_config.database: hash1 = await Hash(name="test1").save() report = await Report.objects.create(name="Report1") await report.filters.create(name="test1", hash=hash1) diff --git a/tests/test_exclude_include_dict/test_complex_relation_tree_performance.py b/tests/test_exclude_include_dict/test_complex_relation_tree_performance.py index 2cc5a076c..a70c74c58 100644 --- a/tests/test_exclude_include_dict/test_complex_relation_tree_performance.py +++ b/tests/test_exclude_include_dict/test_complex_relation_tree_performance.py @@ -7,16 +7,10 @@ import pytest import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - - -base_ormar_config = orm.OrmarConfig( - database=database, - metadata=metadata, -) +base_ormar_config = create_config() class ChagenlogRelease(orm.Model): @@ -304,18 +298,12 @@ class Webhook(orm.Model): error: str = orm.Text(default="") -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_very_complex_relation_map(): - async with database: + async with base_ormar_config.database: tags = [ {"id": 18, "name": "name-18", "ref": "ref-18"}, {"id": 17, "name": "name-17", "ref": "ref-17"}, diff --git a/tests/test_exclude_include_dict/test_dumping_model_to_dict.py b/tests/test_exclude_include_dict/test_dumping_model_to_dict.py index 4a5eb0108..060b531c5 100644 --- a/tests/test_exclude_include_dict/test_dumping_model_to_dict.py +++ b/tests/test_exclude_include_dict/test_dumping_model_to_dict.py @@ -1,20 +1,13 @@ from typing import List, Optional -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Role(ormar.Model): @@ -58,6 +51,9 @@ class Item(ormar.Model): created_by: Optional[User] = ormar.ForeignKey(User) +create_test_database = init_tests(base_ormar_config) + + @pytest.fixture(autouse=True, scope="module") def sample_data(): role = Role(name="User", id=1) diff --git a/tests/test_exclude_include_dict/test_excludable_items.py b/tests/test_exclude_include_dict/test_excludable_items.py index c93f6a691..e1fa142ff 100644 --- a/tests/test_exclude_include_dict/test_excludable_items.py +++ b/tests/test_exclude_include_dict/test_excludable_items.py @@ -1,20 +1,12 @@ from typing import List, Optional -import databases import ormar -import sqlalchemy from ormar.models.excludable import ExcludableItems -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class NickNames(ormar.Model): @@ -58,6 +50,9 @@ class Car(ormar.Model): aircon_type: str = ormar.String(max_length=20, nullable=True) +create_test_database = init_tests(base_ormar_config) + + def compare_results(excludable): car_excludable = excludable.get(Car) assert car_excludable.exclude == {"year", "gearbox_type", "gears", "aircon_type"} diff --git a/tests/test_exclude_include_dict/test_excluding_fields_in_fastapi.py b/tests/test_exclude_include_dict/test_excluding_fields_in_fastapi.py index 00bfe8ed8..50120301e 100644 --- a/tests/test_exclude_include_dict/test_excluding_fields_in_fastapi.py +++ b/tests/test_exclude_include_dict/test_excluding_fields_in_fastapi.py @@ -13,26 +13,12 @@ from ormar import post_save from pydantic import ConfigDict, computed_field -from tests.settings import DATABASE_URL +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -app = FastAPI() -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) -app.state.database = database - -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) # note that you can set orm_mode here @@ -60,11 +46,7 @@ def gen_pass(): class RandomModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="random_users", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="random_users") id: int = ormar.Integer(primary_key=True) password: str = ormar.String(max_length=255, default=gen_pass) @@ -80,11 +62,7 @@ def full_name(self) -> str: class User(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="users", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="users") id: int = ormar.Integer(primary_key=True) email: str = ormar.String(max_length=255) @@ -95,11 +73,7 @@ class User(ormar.Model): class User2(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="users2", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="users2") id: int = ormar.Integer(primary_key=True) email: str = ormar.String(max_length=255, nullable=False) @@ -110,12 +84,7 @@ class User2(ormar.Model): timestamp: datetime.datetime = pydantic.Field(default=datetime.datetime.now) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @app.post("/users/", response_model=User, response_model_exclude={"password"}) diff --git a/tests/test_exclude_include_dict/test_excluding_fields_with_default.py b/tests/test_exclude_include_dict/test_excluding_fields_with_default.py index b754672a9..23c5f87be 100644 --- a/tests/test_exclude_include_dict/test_excluding_fields_with_default.py +++ b/tests/test_exclude_include_dict/test_excluding_fields_with_default.py @@ -1,15 +1,14 @@ import random from typing import Optional -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() def get_position() -> int: @@ -17,11 +16,7 @@ def get_position() -> int: class Album(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="albums", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="albums") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -29,11 +24,7 @@ class Album(ormar.Model): class Track(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="tracks", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="tracks") id: int = ormar.Integer(primary_key=True) album: Optional[Album] = ormar.ForeignKey(Album) @@ -42,19 +33,13 @@ class Track(ormar.Model): play_count: int = ormar.Integer(nullable=True, default=0) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_excluding_field_with_default(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): album = await Album.objects.create(name="Miami") await Track.objects.create(title="Vice City", album=album, play_count=10) await Track.objects.create(title="Beach Sand", album=album, play_count=20) diff --git a/tests/test_exclude_include_dict/test_excluding_subset_of_columns.py b/tests/test_exclude_include_dict/test_excluding_subset_of_columns.py index 3d13942b5..d3f3aa845 100644 --- a/tests/test_exclude_include_dict/test_excluding_subset_of_columns.py +++ b/tests/test_exclude_include_dict/test_excluding_subset_of_columns.py @@ -1,23 +1,18 @@ from typing import Optional -import databases import ormar import pydantic import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Company(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="companies", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="companies") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False) @@ -25,11 +20,7 @@ class Company(ormar.Model): class Car(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="cars", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="cars") id: int = ormar.Integer(primary_key=True) manufacturer: Optional[Company] = ormar.ForeignKey(Company) @@ -40,19 +31,13 @@ class Car(ormar.Model): aircon_type: str = ormar.String(max_length=20, nullable=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_selecting_subset(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): toyota = await Company.objects.create(name="Toyota", founded=1937) await Car.objects.create( manufacturer=toyota, @@ -184,8 +169,8 @@ async def test_selecting_subset(): @pytest.mark.asyncio async def test_excluding_nested_lists_in_dump(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): toyota = await Company.objects.create(name="Toyota", founded=1937) car1 = await Car.objects.create( manufacturer=toyota, diff --git a/tests/test_exclude_include_dict/test_pydantic_dict_params.py b/tests/test_exclude_include_dict/test_pydantic_dict_params.py index a58bdc635..a16d97e3a 100644 --- a/tests/test_exclude_include_dict/test_pydantic_dict_params.py +++ b/tests/test_exclude_include_dict/test_pydantic_dict_params.py @@ -1,22 +1,17 @@ from typing import List -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) + +base_ormar_config = create_config() class Category(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="categories", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, default="Test", nullable=True) @@ -24,11 +19,7 @@ class Category(ormar.Model): class Item(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="items", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="items") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -36,17 +27,12 @@ class Item(ormar.Model): categories: List[Category] = ormar.ManyToMany(Category) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_exclude_default(): - async with database: + async with base_ormar_config.database: category = Category() assert category.model_dump() == { "id": None, @@ -70,7 +56,7 @@ async def test_exclude_default(): @pytest.mark.asyncio async def test_exclude_none(): - async with database: + async with base_ormar_config.database: category = Category(id=2, name=None) assert category.model_dump() == { "id": 2, @@ -105,7 +91,7 @@ async def test_exclude_none(): @pytest.mark.asyncio async def test_exclude_unset(): - async with database: + async with base_ormar_config.database: category = Category(id=3, name="Test 2") assert category.model_dump() == { "id": 3, diff --git a/tests/test_fastapi/test_binary_fields.py b/tests/test_fastapi/test_binary_fields.py index cd387f2a2..3142cbd62 100644 --- a/tests/test_fastapi/test_binary_fields.py +++ b/tests/test_fastapi/test_binary_fields.py @@ -1,35 +1,21 @@ import base64 import uuid -from contextlib import asynccontextmanager from enum import Enum -from typing import AsyncIterator, List +from typing import List -import databases import ormar import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import lifespan, init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() headers = {"content-type": "application/json"} - - -@asynccontextmanager -async def lifespan(app: FastAPI) -> AsyncIterator[None]: - if not database.is_connected: - await database.connect() - yield - if database.is_connected: - await database.disconnect() - - -app = FastAPI(lifespan=lifespan) +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) blob3 = b"\xc3\x83\x28" @@ -38,12 +24,6 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: blob6 = b"\xff" -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) - - class BinaryEnum(Enum): blob3 = blob3 blob4 = blob4 @@ -59,6 +39,9 @@ class BinaryThing(ormar.Model): bt: str = ormar.LargeBinary(represent_as_base64_str=True, max_length=100) +create_test_database = init_tests(base_ormar_config) + + @app.get("/things", response_model=List[BinaryThing]) async def read_things(): return await BinaryThing.objects.order_by("name").all() @@ -70,14 +53,6 @@ async def create_things(thing: BinaryThing): return thing -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) - - @pytest.mark.asyncio async def test_read_main(): client = AsyncClient(app=app, base_url="http://testserver") diff --git a/tests/test_fastapi/test_docs_with_multiple_relations_to_one.py b/tests/test_fastapi/test_docs_with_multiple_relations_to_one.py index c7f285fa5..943653f77 100644 --- a/tests/test_fastapi/test_docs_with_multiple_relations_to_one.py +++ b/tests/test_fastapi/test_docs_with_multiple_relations_to_one.py @@ -1,24 +1,18 @@ from typing import Optional from uuid import UUID, uuid4 -import databases import ormar import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient -app = FastAPI() -DATABASE_URL = "sqlite:///db.sqlite" -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() +from tests.settings import create_config +from tests.lifespan import lifespan, init_tests -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) class CA(ormar.Model): @@ -44,6 +38,9 @@ class CB2(ormar.Model): ca2: Optional[CA] = ormar.ForeignKey(CA, nullable=True) +create_test_database = init_tests(base_ormar_config) + + @app.get("/ca", response_model=CA) async def get_ca(): # pragma: no cover return None diff --git a/tests/test_fastapi/test_enum_schema.py b/tests/test_fastapi/test_enum_schema.py index fa1ac6b5b..a6431128c 100644 --- a/tests/test_fastapi/test_enum_schema.py +++ b/tests/test_fastapi/test_enum_schema.py @@ -1,13 +1,12 @@ from enum import Enum -import databases import ormar -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class MyEnum(Enum): @@ -16,16 +15,15 @@ class MyEnum(Enum): class EnumExample(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="enum_example", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="enum_example") id: int = ormar.Integer(primary_key=True) size: MyEnum = ormar.Enum(enum_class=MyEnum, default=MyEnum.SMALL) +create_test_database = init_tests(base_ormar_config) + + def test_proper_schema(): schema = EnumExample.model_json_schema() assert {"MyEnum": {"title": "MyEnum", "enum": [1, 2], "type": "integer"}} == schema[ diff --git a/tests/test_fastapi/test_excludes_with_get_pydantic.py b/tests/test_fastapi/test_excludes_with_get_pydantic.py index 0db526ced..a06a6e365 100644 --- a/tests/test_fastapi/test_excludes_with_get_pydantic.py +++ b/tests/test_fastapi/test_excludes_with_get_pydantic.py @@ -1,41 +1,45 @@ +import ormar import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient +from typing import ForwardRef, Optional + +from tests.settings import create_config +from tests.lifespan import lifespan, init_tests + + +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) + + +class SelfRef(ormar.Model): + ormar_config = base_ormar_config.copy(tablename="self_refs") + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, default="selfref") + parent = ormar.ForeignKey(ForwardRef("SelfRef"), related_name="children") + + +SelfRef.update_forward_refs() -from tests.settings import DATABASE_URL -from tests.test_inheritance_and_pydantic_generation.test_geting_pydantic_models import ( - Category, - SelfRef, - database, - metadata, -) -app = FastAPI() -app.state.database = database +class Category(ormar.Model): + ormar_config = base_ormar_config.copy(tablename="categories") + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() +class Item(ormar.Model): + ormar_config = base_ormar_config.copy() -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, default="test") + category: Optional[Category] = ormar.ForeignKey(Category, nullable=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) async def create_category(category: Category): diff --git a/tests/test_fastapi/test_excluding_fields.py b/tests/test_fastapi/test_excluding_fields.py index 6a11e7b53..6c33a1711 100644 --- a/tests/test_fastapi/test_excluding_fields.py +++ b/tests/test_fastapi/test_excluding_fields.py @@ -1,64 +1,35 @@ from typing import List -import databases import ormar import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient -from tests.settings import DATABASE_URL +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -app = FastAPI() -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) -app.state.database = database - - -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) class Category(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="categories", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class Item(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="items", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="items") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) categories: List[Category] = ormar.ManyToMany(Category) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) + @app.post("/items/", response_model=Item) diff --git a/tests/test_fastapi/test_extra_ignore_parameter.py b/tests/test_fastapi/test_extra_ignore_parameter.py index 1d1447a99..71b1219db 100644 --- a/tests/test_fastapi/test_extra_ignore_parameter.py +++ b/tests/test_fastapi/test_extra_ignore_parameter.py @@ -1,51 +1,26 @@ -import databases import ormar import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient from ormar import Extra -from tests.settings import DATABASE_URL +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -app = FastAPI() -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) -app.state.database = database - - -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) class Item(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - extra=Extra.ignore, - ) + ormar_config = base_ormar_config.copy(extra=Extra.ignore) id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) + @app.post("/item/", response_model=Item) diff --git a/tests/test_fastapi/test_fastapi_docs.py b/tests/test_fastapi/test_fastapi_docs.py index a3735c967..698a4ae35 100644 --- a/tests/test_fastapi/test_fastapi_docs.py +++ b/tests/test_fastapi/test_fastapi_docs.py @@ -1,42 +1,20 @@ import datetime from typing import List, Optional, Union -import databases import ormar import pydantic import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient from pydantic import Field -from tests.settings import DATABASE_URL +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -app = FastAPI() -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) -app.state.database = database - -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() - - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) class PTestA(pydantic.BaseModel): @@ -68,12 +46,8 @@ class Item(ormar.Model): categories = ormar.ManyToMany(Category) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) + @app.get("/items/", response_model=List[Item]) diff --git a/tests/test_fastapi/test_fastapi_usage.py b/tests/test_fastapi/test_fastapi_usage.py index d276d4c36..ed2f9046e 100644 --- a/tests/test_fastapi/test_fastapi_usage.py +++ b/tests/test_fastapi/test_fastapi_usage.py @@ -1,44 +1,38 @@ from typing import Optional -import databases import ormar import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import lifespan, init_tests -app = FastAPI() -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) class Category(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="categories", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class Item(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="items", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="items") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) category: Optional[Category] = ormar.ForeignKey(Category, nullable=True) +create_test_database = init_tests(base_ormar_config) + + + @app.post("/items/", response_model=Item) async def create_item(item: Item): return item diff --git a/tests/test_fastapi/test_inheritance_concrete_fastapi.py b/tests/test_fastapi/test_inheritance_concrete_fastapi.py index c8f4251f0..3504a4989 100644 --- a/tests/test_fastapi/test_inheritance_concrete_fastapi.py +++ b/tests/test_fastapi/test_inheritance_concrete_fastapi.py @@ -1,43 +1,141 @@ import datetime -from typing import List +from typing import List, Optional import pytest -import sqlalchemy +import ormar from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient +from ormar.relations.relation_proxy import RelationProxy +from pydantic import computed_field -from tests.settings import DATABASE_URL -from tests.test_inheritance_and_pydantic_generation.test_inheritance_concrete import ( # noqa: E501 - Bus, - Bus2, - Category, - Person, - Subject, - Truck, - Truck2, - metadata, -) -from tests.test_inheritance_and_pydantic_generation.test_inheritance_concrete import ( - db as database, -) +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -app = FastAPI() -app.state.database = database +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() +class AuditModel(ormar.Model): + ormar_config = base_ormar_config.copy(abstract=True) + + created_by: str = ormar.String(max_length=100) + updated_by: str = ormar.String(max_length=100, default="Sam") + + @computed_field + def audit(self) -> str: # pragma: no cover + return f"{self.created_by} {self.updated_by}" + + +class DateFieldsModelNoSubclass(ormar.Model): + ormar_config = base_ormar_config.copy(tablename="test_date_models") + + date_id: int = ormar.Integer(primary_key=True) + created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) + updated_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) + + +class DateFieldsModel(ormar.Model): + ormar_config = base_ormar_config.copy( + abstract=True, + constraints=[ + ormar.fields.constraints.UniqueColumns( + "creation_date", + "modification_date", + ), + ormar.fields.constraints.CheckColumns( + "creation_date <= modification_date", + ), + ], + ) + + created_date: datetime.datetime = ormar.DateTime( + default=datetime.datetime.now, name="creation_date" + ) + updated_date: datetime.datetime = ormar.DateTime( + default=datetime.datetime.now, name="modification_date" + ) + + +class Person(ormar.Model): + ormar_config = base_ormar_config.copy() + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Car(ormar.Model): + ormar_config = base_ormar_config.copy(abstract=True) + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50) + owner: Person = ormar.ForeignKey(Person) + co_owner: Person = ormar.ForeignKey(Person, related_name="coowned") + created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) + + +class Car2(ormar.Model): + ormar_config = base_ormar_config.copy(abstract=True) + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50) + owner: Person = ormar.ForeignKey(Person, related_name="owned") + co_owners: RelationProxy[Person] = ormar.ManyToMany(Person, related_name="coowned") + created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) + + +class Bus(Car): + ormar_config = base_ormar_config.copy(tablename="buses") + + owner: Person = ormar.ForeignKey(Person, related_name="buses") + max_persons: int = ormar.Integer() + + +class Bus2(Car2): + ormar_config = base_ormar_config.copy(tablename="buses2") + + max_persons: int = ormar.Integer() + + +class Category(DateFieldsModel, AuditModel): + ormar_config = base_ormar_config.copy(tablename="categories") + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50, unique=True, index=True) + code: int = ormar.Integer() + + @computed_field + def code_name(self) -> str: + return f"{self.code}:{self.name}" + + @computed_field + def audit(self) -> str: + return f"{self.created_by} {self.updated_by}" + + +class Subject(DateFieldsModel): + ormar_config = base_ormar_config.copy() + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50, unique=True, index=True) + category: Optional[Category] = ormar.ForeignKey(Category) + + +class Truck(Car): + ormar_config = base_ormar_config.copy() + + max_capacity: int = ormar.Integer() + + +class Truck2(Car2): + ormar_config = base_ormar_config.copy(tablename="trucks2") + + max_capacity: int = ormar.Integer() + + +create_test_database = init_tests(base_ormar_config) -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() @app.post("/subjects/", response_model=Subject) @@ -113,14 +211,6 @@ async def add_truck_coowner(item_id: int, person: Person): return truck -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) - - @pytest.mark.asyncio async def test_read_main(): client = AsyncClient(app=app, base_url="http://testserver") diff --git a/tests/test_fastapi/test_inheritance_mixins_fastapi.py b/tests/test_fastapi/test_inheritance_mixins_fastapi.py index 20223c038..fa7289079 100644 --- a/tests/test_fastapi/test_inheritance_mixins_fastapi.py +++ b/tests/test_fastapi/test_inheritance_mixins_fastapi.py @@ -1,37 +1,48 @@ import datetime import pytest -import sqlalchemy +import ormar from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient +from typing import Optional -from tests.settings import DATABASE_URL -from tests.test_inheritance_and_pydantic_generation.test_inheritance_mixins import ( # noqa: E501 - Category, - Subject, - metadata, -) -from tests.test_inheritance_and_pydantic_generation.test_inheritance_mixins import ( - db as database, -) +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -app = FastAPI() -app.state.database = database +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() +class AuditMixin: + created_by: str = ormar.String(max_length=100) + updated_by: str = ormar.String(max_length=100, default="Sam") + + +class DateFieldsMixins: + created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) + updated_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) + + +class Category(ormar.Model, DateFieldsMixins, AuditMixin): + ormar_config = base_ormar_config.copy(tablename="categories") + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50, unique=True, index=True) + code: int = ormar.Integer() + + +class Subject(ormar.Model, DateFieldsMixins): + ormar_config = base_ormar_config.copy(tablename="subjects") + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50, unique=True, index=True) + category: Optional[Category] = ormar.ForeignKey(Category) + + +create_test_database = init_tests(base_ormar_config) -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() @app.post("/subjects/", response_model=Subject) @@ -45,14 +56,6 @@ async def create_category(category: Category): return category -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) - - @pytest.mark.asyncio async def test_read_main(): client = AsyncClient(app=app, base_url="http://testserver") diff --git a/tests/test_fastapi/test_json_field_fastapi.py b/tests/test_fastapi/test_json_field_fastapi.py index ff11f458c..394e23e38 100644 --- a/tests/test_fastapi/test_json_field_fastapi.py +++ b/tests/test_fastapi/test_json_field_fastapi.py @@ -2,42 +2,19 @@ import uuid from typing import List -import databases import ormar import pydantic import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient -from tests.settings import DATABASE_URL +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -app = FastAPI() -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() -app.state.database = database - - -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() - - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) class Thing(ormar.Model): @@ -48,6 +25,9 @@ class Thing(ormar.Model): js: pydantic.Json = ormar.JSON() +create_test_database = init_tests(base_ormar_config) + + @app.get("/things", response_model=List[Thing]) async def read_things(): return await Thing.objects.order_by("name").all() @@ -87,14 +67,6 @@ async def read_things_untyped(): return await Thing.objects.order_by("name").all() -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) - - @pytest.mark.asyncio async def test_json_is_required_if_not_nullable(): with pytest.raises(pydantic.ValidationError): @@ -115,7 +87,7 @@ class Thing2(ormar.Model): @pytest.mark.asyncio async def test_setting_values_after_init(): - async with database: + async with base_ormar_config.database: t1 = Thing(id="67a82813-d90c-45ff-b546-b4e38d7030d7", name="t1", js=["thing1"]) assert '["thing1"]' in t1.model_dump_json() await t1.save() diff --git a/tests/test_fastapi/test_m2m_forwardref.py b/tests/test_fastapi/test_m2m_forwardref.py index 9950890a2..4d5e5f1f7 100644 --- a/tests/test_fastapi/test_m2m_forwardref.py +++ b/tests/test_fastapi/test_m2m_forwardref.py @@ -1,41 +1,18 @@ from typing import ForwardRef, List, Optional -import databases import ormar import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient from starlette import status -app = FastAPI() -from tests.settings import DATABASE_URL +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() -app.state.database = database - - -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() - - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) CityRef = ForwardRef("City") @@ -74,12 +51,8 @@ class City(ormar.Model): Country.update_forward_refs() -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) + @app.post("/", response_model=Country, status_code=status.HTTP_201_CREATED) diff --git a/tests/test_fastapi/test_more_reallife_fastapi.py b/tests/test_fastapi/test_more_reallife_fastapi.py index 0fcd01501..3540ad31c 100644 --- a/tests/test_fastapi/test_more_reallife_fastapi.py +++ b/tests/test_fastapi/test_more_reallife_fastapi.py @@ -1,64 +1,36 @@ from typing import List, Optional -import databases import ormar import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient -from tests.settings import DATABASE_URL +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -app = FastAPI() -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) -app.state.database = database - -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) class Category(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="categories", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class Item(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="items", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="items") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) category: Optional[Category] = ormar.ForeignKey(Category, nullable=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) + @app.get("/items", response_model=List[Item]) diff --git a/tests/test_fastapi/test_nested_saving.py b/tests/test_fastapi/test_nested_saving.py index 48d9144a8..e5827ca66 100644 --- a/tests/test_fastapi/test_nested_saving.py +++ b/tests/test_fastapi/test_nested_saving.py @@ -1,53 +1,30 @@ from typing import Any, Dict, Optional, Set, Type, Union, cast -import databases import ormar import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient from ormar.queryset.utils import translate_list_to_dict -from tests.settings import DATABASE_URL +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -app = FastAPI() -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) -app.state.database = database +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) headers = {"content-type": "application/json"} -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() - - class Department(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) department_name: str = ormar.String(max_length=100) class Course(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) course_name: str = ormar.String(max_length=100) @@ -56,25 +33,13 @@ class Course(ormar.Model): class Student(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) courses = ormar.ManyToMany(Course) -# create db and tables -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) - - to_exclude = { "id": ..., "courses": { @@ -90,6 +55,9 @@ def create_test_database(): } +create_test_database = init_tests(base_ormar_config) + + def auto_exclude_id_field(to_exclude: Any) -> Union[Dict, Set]: if isinstance(to_exclude, dict): for key in to_exclude.keys(): diff --git a/tests/test_fastapi/test_recursion_error.py b/tests/test_fastapi/test_recursion_error.py index 9eb34b07a..962d0a21e 100644 --- a/tests/test_fastapi/test_recursion_error.py +++ b/tests/test_fastapi/test_recursion_error.py @@ -2,39 +2,21 @@ from datetime import datetime from typing import List -import databases import ormar import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import Depends, FastAPI from httpx import AsyncClient from pydantic import BaseModel, Json -from tests.settings import DATABASE_URL - -router = FastAPI() -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) -router.state.database = database +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config +base_ormar_config = create_config() +router = FastAPI(lifespan=lifespan(base_ormar_config)) headers = {"content-type": "application/json"} -@router.on_event("startup") -async def startup() -> None: - database_ = router.state.database - if not database_.is_connected: - await database_.connect() - - -@router.on_event("shutdown") -async def shutdown() -> None: - database_ = router.state.database - if database_.is_connected: - await database_.disconnect() - - class User(ormar.Model): """ The user model @@ -48,11 +30,7 @@ class User(ormar.Model): verify_key: str = ormar.String(unique=True, max_length=100, nullable=True) created_at: datetime = ormar.DateTime(default=datetime.now()) - ormar_config = ormar.OrmarConfig( - tablename="users", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="users") class UserSession(ormar.Model): @@ -65,11 +43,7 @@ class UserSession(ormar.Model): session_key: str = ormar.String(unique=True, max_length=64) created_at: datetime = ormar.DateTime(default=datetime.now()) - ormar_config = ormar.OrmarConfig( - tablename="user_sessions", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="user_sessions") class QuizAnswer(BaseModel): @@ -97,19 +71,11 @@ class Quiz(ormar.Model): user_id: uuid.UUID = ormar.UUID(foreign_key=User.id) questions: Json = ormar.JSON(nullable=False) - ormar_config = ormar.OrmarConfig( - tablename="quiz", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="quiz") + +create_test_database = init_tests(base_ormar_config) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) async def get_current_user(): diff --git a/tests/test_fastapi/test_relations_with_nested_defaults.py b/tests/test_fastapi/test_relations_with_nested_defaults.py index 139eb1958..6a7a4a8bd 100644 --- a/tests/test_fastapi/test_relations_with_nested_defaults.py +++ b/tests/test_fastapi/test_relations_with_nested_defaults.py @@ -1,41 +1,18 @@ from typing import Optional -import databases import ormar import pytest import pytest_asyncio -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient -from tests.settings import DATABASE_URL +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() -app = FastAPI() -app.state.database = database - - -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() - - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) class Country(ormar.Model): @@ -63,17 +40,12 @@ class Book(ormar.Model): year: int = ormar.Integer(nullable=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest_asyncio.fixture async def sample_data(): - async with database: + async with base_ormar_config.database: country = await Country(id=1, name="USA").save() author = await Author(id=1, name="bug", rating=5, country=country).save() await Book( diff --git a/tests/test_fastapi/test_schema_not_allowed_params.py b/tests/test_fastapi/test_schema_not_allowed_params.py index 7384f5825..c5555802c 100644 --- a/tests/test_fastapi/test_schema_not_allowed_params.py +++ b/tests/test_fastapi/test_schema_not_allowed_params.py @@ -2,15 +2,11 @@ import ormar import sqlalchemy -DATABASE_URL = "sqlite:///db.sqlite" -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() +from tests.lifespan import init_tests +from tests.settings import create_config -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Author(ormar.Model): @@ -21,6 +17,9 @@ class Author(ormar.Model): contents: str = ormar.Text() +create_test_database = init_tests(base_ormar_config) + + def test_schema_not_allowed(): schema = Author.model_json_schema() for field_schema in schema.get("properties").values(): diff --git a/tests/test_fastapi/test_skip_reverse_models.py b/tests/test_fastapi/test_skip_reverse_models.py index 4e879e2ff..5eef08573 100644 --- a/tests/test_fastapi/test_skip_reverse_models.py +++ b/tests/test_fastapi/test_skip_reverse_models.py @@ -1,43 +1,20 @@ from typing import List, Optional -import databases import ormar import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient -from tests.settings import DATABASE_URL +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -app = FastAPI() -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) -app.state.database = database +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) headers = {"content-type": "application/json"} -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() - - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) - - class Author(ormar.Model): ormar_config = base_ormar_config.copy() @@ -66,12 +43,8 @@ class Post(ormar.Model): author: Optional[Author] = ormar.ForeignKey(Author, skip_reverse=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) + @app.post("/categories/forbid/", response_model=Category2) diff --git a/tests/test_fastapi/test_wekref_exclusion.py b/tests/test_fastapi/test_wekref_exclusion.py index 83149e0cd..6002dc484 100644 --- a/tests/test_fastapi/test_wekref_exclusion.py +++ b/tests/test_fastapi/test_wekref_exclusion.py @@ -1,51 +1,19 @@ from typing import List, Optional from uuid import UUID, uuid4 -import databases import ormar import pydantic import pytest -import sqlalchemy from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient -from tests.settings import DATABASE_URL +from tests.lifespan import lifespan, init_tests +from tests.settings import create_config -app = FastAPI() -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - -app.state.database = database - - -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() - - -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) - - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() +app = FastAPI(lifespan=lifespan(base_ormar_config)) class OtherThing(ormar.Model): @@ -65,6 +33,10 @@ class Thing(ormar.Model): other_thing: Optional[OtherThing] = ormar.ForeignKey(OtherThing, nullable=True) +create_test_database = init_tests(base_ormar_config) + + + @app.post("/test/1") async def post_test_1(): # don't split initialization and attribute assignment diff --git a/tests/test_inheritance_and_pydantic_generation/test_excluding_parent_fields_inheritance.py b/tests/test_inheritance_and_pydantic_generation/test_excluding_parent_fields_inheritance.py index e7ac3c29e..a50cd0977 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_excluding_parent_fields_inheritance.py +++ b/tests/test_inheritance_and_pydantic_generation/test_excluding_parent_fields_inheritance.py @@ -1,49 +1,38 @@ import datetime -import databases import ormar import pytest -import sqlalchemy as sa -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sa.MetaData() -db = databases.Database(DATABASE_URL) -engine = create_engine(DATABASE_URL) + +base_ormar_config = create_config() class User(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="users", - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(tablename="users") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=50, unique=True, index=True) class RelationalAuditModel(ormar.Model): - ormar_config = ormar.OrmarConfig(abstract=True) + ormar_config = base_ormar_config.copy(abstract=True) created_by: User = ormar.ForeignKey(User, nullable=False) updated_by: User = ormar.ForeignKey(User, nullable=False) class AuditModel(ormar.Model): - ormar_config = ormar.OrmarConfig(abstract=True) + ormar_config = base_ormar_config.copy(abstract=True) created_by: str = ormar.String(max_length=100) updated_by: str = ormar.String(max_length=100, default="Sam") class DateFieldsModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - abstract=True, - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(abstract=True) created_date: datetime.datetime = ormar.DateTime( default=datetime.datetime.now, name="creation_date" @@ -54,7 +43,7 @@ class DateFieldsModel(ormar.Model): class Category(DateFieldsModel, AuditModel): - ormar_config = ormar.OrmarConfig( + ormar_config = base_ormar_config.copy( tablename="categories", exclude_parent_fields=["updated_by", "updated_date"], ) @@ -65,7 +54,7 @@ class Category(DateFieldsModel, AuditModel): class Item(DateFieldsModel, AuditModel): - ormar_config = ormar.OrmarConfig( + ormar_config = base_ormar_config.copy( tablename="items", exclude_parent_fields=["updated_by", "updated_date"], ) @@ -77,7 +66,7 @@ class Item(DateFieldsModel, AuditModel): class Gun(RelationalAuditModel, DateFieldsModel): - ormar_config = ormar.OrmarConfig( + ormar_config = base_ormar_config.copy( tablename="guns", exclude_parent_fields=["updated_by"], ) @@ -86,11 +75,7 @@ class Gun(RelationalAuditModel, DateFieldsModel): name: str = ormar.String(max_length=50) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_model_definition(): @@ -111,8 +96,8 @@ def test_model_definition(): @pytest.mark.asyncio async def test_model_works_as_expected(): - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): test = await Category(name="Cat", code=2, created_by="Joe").save() assert test.created_date is not None @@ -123,8 +108,8 @@ async def test_model_works_as_expected(): @pytest.mark.asyncio async def test_exclude_with_redefinition(): - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): test = await Item(name="Item", code=3, created_by="Anna").save() assert test.created_date is not None assert test.updated_by == "Bob" @@ -136,8 +121,8 @@ async def test_exclude_with_redefinition(): @pytest.mark.asyncio async def test_exclude_with_relation(): - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): user = await User(name="Michail Kalasznikow").save() test = await Gun(name="AK47", created_by=user).save() assert test.created_date is not None diff --git a/tests/test_inheritance_and_pydantic_generation/test_geting_pydantic_models.py b/tests/test_inheritance_and_pydantic_generation/test_geting_pydantic_models.py index fcf6c61bb..b9f1e525c 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_geting_pydantic_models.py +++ b/tests/test_inheritance_and_pydantic_generation/test_geting_pydantic_models.py @@ -6,16 +6,11 @@ import sqlalchemy from pydantic_core import PydanticUndefined -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class SelfRef(ormar.Model): @@ -62,6 +57,9 @@ class MutualB(ormar.Model): MutualA.update_forward_refs() +create_test_database = init_tests(base_ormar_config) + + def test_getting_pydantic_model(): PydanticCategory = Category.get_pydantic() assert issubclass(PydanticCategory, pydantic.BaseModel) diff --git a/tests/test_inheritance_and_pydantic_generation/test_inheritance_concrete.py b/tests/test_inheritance_and_pydantic_generation/test_inheritance_concrete.py index 4446d3451..6046e0e31 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_inheritance_concrete.py +++ b/tests/test_inheritance_and_pydantic_generation/test_inheritance_concrete.py @@ -2,7 +2,6 @@ from collections import Counter from typing import Optional -import databases import ormar import ormar.fields.constraints import pydantic @@ -13,19 +12,15 @@ from ormar.models.metaclass import get_constraint_copy from ormar.relations.relation_proxy import RelationProxy from pydantic import computed_field -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sa.MetaData() -db = databases.Database(DATABASE_URL) -engine = create_engine(DATABASE_URL) +base_ormar_config = create_config() class AuditModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - abstract=True, - ) + ormar_config = base_ormar_config.copy(abstract=True) created_by: str = ormar.String(max_length=100) updated_by: str = ormar.String(max_length=100, default="Sam") @@ -36,11 +31,7 @@ def audit(self) -> str: # pragma: no cover class DateFieldsModelNoSubclass(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="test_date_models", - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(tablename="test_date_models") date_id: int = ormar.Integer(primary_key=True) created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) @@ -48,10 +39,8 @@ class DateFieldsModelNoSubclass(ormar.Model): class DateFieldsModel(ormar.Model): - ormar_config = ormar.OrmarConfig( + ormar_config = base_ormar_config.copy( abstract=True, - metadata=metadata, - database=db, constraints=[ ormar.fields.constraints.UniqueColumns( "creation_date", @@ -72,7 +61,7 @@ class DateFieldsModel(ormar.Model): class Category(DateFieldsModel, AuditModel): - ormar_config = ormar.OrmarConfig( + ormar_config = base_ormar_config.copy( tablename="categories", constraints=[ormar.fields.constraints.UniqueColumns("name", "code")], ) @@ -91,7 +80,7 @@ def audit(self) -> str: class Subject(DateFieldsModel): - ormar_config = ormar.OrmarConfig() + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=50, unique=True, index=True) @@ -99,21 +88,14 @@ class Subject(DateFieldsModel): class Person(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class Car(ormar.Model): - ormar_config = ormar.OrmarConfig( - abstract=True, - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(abstract=True) id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=50) @@ -129,22 +111,14 @@ class Truck(Car): class Bus(Car): - ormar_config = ormar.OrmarConfig( - tablename="buses", - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(tablename="buses") owner: Person = ormar.ForeignKey(Person, related_name="buses") max_persons: int = ormar.Integer() class Car2(ormar.Model): - ormar_config = ormar.OrmarConfig( - abstract=True, - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(abstract=True) id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=50) @@ -154,13 +128,13 @@ class Car2(ormar.Model): class Truck2(Car2): - ormar_config = ormar.OrmarConfig(tablename="trucks2") + ormar_config = base_ormar_config.copy(tablename="trucks2") max_capacity: int = ormar.Integer() class Bus2(Car2): - ormar_config = ormar.OrmarConfig(tablename="buses2") + ormar_config = base_ormar_config.copy(tablename="buses2") max_persons: int = ormar.Integer() @@ -169,11 +143,7 @@ class ImmutablePerson(Person): model_config = dict(frozen=True, validate_assignment=False) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_init_of_abstract_model() -> None: @@ -193,11 +163,7 @@ class Bus3(Car2): # pragma: no cover def test_field_redefining_in_concrete_models() -> None: class RedefinedField(DateFieldsModel): - ormar_config = ormar.OrmarConfig( - tablename="redefines", - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(tablename="redefines") id: int = ormar.Integer(primary_key=True) created_date: str = ormar.String( @@ -221,11 +187,7 @@ def test_model_subclassing_that_redefines_constraints_column_names() -> None: with pytest.raises(ModelDefinitionError): class WrongField2(DateFieldsModel): # pragma: no cover - ormar_config = ormar.OrmarConfig( - tablename="wrongs", - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(tablename="wrongs") id: int = ormar.Integer(primary_key=True) created_date: str = ormar.String(max_length=200) # type: ignore @@ -235,18 +197,14 @@ def test_model_subclassing_non_abstract_raises_error() -> None: with pytest.raises(ModelDefinitionError): class WrongField2(DateFieldsModelNoSubclass): # pragma: no cover - ormar_config = ormar.OrmarConfig( - tablename="wrongs", - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(tablename="wrongs") id: int = ormar.Integer(primary_key=True) def test_params_are_inherited() -> None: - assert Category.ormar_config.metadata == metadata - assert Category.ormar_config.database == db + assert Category.ormar_config.metadata == base_ormar_config.metadata + assert Category.ormar_config.database == base_ormar_config.database assert len(Category.ormar_config.property_fields) == 2 constraints = Counter(map(lambda c: type(c), Category.ormar_config.constraints)) @@ -265,8 +223,8 @@ def round_date_to_seconds( @pytest.mark.asyncio async def test_fields_inherited_from_mixin() -> None: - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): cat = await Category( name="Foo", code=123, created_by="Sam", updated_by="Max" ).save() @@ -282,6 +240,8 @@ async def test_fields_inherited_from_mixin() -> None: assert all( field in Subject.ormar_config.model_fields for field in mixin_columns ) + assert cat.code_name == "123:Foo" + assert cat.audit == "Sam Max" assert sub.created_date is not None assert sub.updated_date is not None @@ -293,7 +253,7 @@ async def test_fields_inherited_from_mixin() -> None: for field in mixin2_columns ) - inspector = sa.inspect(engine) + inspector = sa.inspect(base_ormar_config.engine) assert "categories" in inspector.get_table_names() table_columns = [x.get("name") for x in inspector.get_columns("categories")] assert all( @@ -342,8 +302,8 @@ async def test_fields_inherited_from_mixin() -> None: @pytest.mark.asyncio async def test_inheritance_with_relation() -> None: - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): sam = await Person(name="Sam").save() joe = await Person(name="Joe").save() await Truck( @@ -391,8 +351,8 @@ async def test_inheritance_with_relation() -> None: @pytest.mark.asyncio async def test_inheritance_with_multi_relation() -> None: - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): sam = await Person(name="Sam").save() joe = await Person(name="Joe").save() alex = await Person(name="Alex").save() diff --git a/tests/test_inheritance_and_pydantic_generation/test_inheritance_mixins.py b/tests/test_inheritance_and_pydantic_generation/test_inheritance_mixins.py index 46f7b9062..923c0c897 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_inheritance_mixins.py +++ b/tests/test_inheritance_and_pydantic_generation/test_inheritance_mixins.py @@ -1,17 +1,15 @@ import datetime from typing import Optional -import databases import ormar import pytest import sqlalchemy as sa -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sa.MetaData() -db = databases.Database(DATABASE_URL) -engine = create_engine(DATABASE_URL) + +base_ormar_config = create_config() class AuditMixin: @@ -25,11 +23,7 @@ class DateFieldsMixins: class Category(ormar.Model, DateFieldsMixins, AuditMixin): - ormar_config = ormar.OrmarConfig( - tablename="categories", - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=50, unique=True, index=True) @@ -37,31 +31,19 @@ class Category(ormar.Model, DateFieldsMixins, AuditMixin): class Subject(ormar.Model, DateFieldsMixins): - ormar_config = ormar.OrmarConfig( - tablename="subjects", - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(tablename="subjects") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=50, unique=True, index=True) category: Optional[Category] = ormar.ForeignKey(Category) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_field_redefining() -> None: class RedefinedField(ormar.Model, DateFieldsMixins): - ormar_config = ormar.OrmarConfig( - tablename="redefined", - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(tablename="redefined") id: int = ormar.Integer(primary_key=True) created_date: datetime.datetime = ormar.DateTime(name="creation_date") @@ -81,11 +63,7 @@ class RedefinedField(ormar.Model, DateFieldsMixins): def test_field_redefining_in_second() -> None: class RedefinedField2(ormar.Model, DateFieldsMixins): - ormar_config = ormar.OrmarConfig( - tablename="redefines2", - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(tablename="redefines2") id: int = ormar.Integer(primary_key=True) created_date: str = ormar.String( @@ -119,8 +97,8 @@ def round_date_to_seconds( @pytest.mark.asyncio async def test_fields_inherited_from_mixin() -> None: - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): cat = await Category( name="Foo", code=123, created_by="Sam", updated_by="Max" ).save() @@ -146,7 +124,7 @@ async def test_fields_inherited_from_mixin() -> None: for field in mixin2_columns ) - inspector = sa.inspect(engine) + inspector = sa.inspect(base_ormar_config.engine) assert "categories" in inspector.get_table_names() table_columns = [x.get("name") for x in inspector.get_columns("categories")] assert all(col in table_columns for col in mixin_columns + mixin2_columns) diff --git a/tests/test_inheritance_and_pydantic_generation/test_inheritance_of_property_fields.py b/tests/test_inheritance_and_pydantic_generation/test_inheritance_of_property_fields.py index a9d0cf166..f8d9834ee 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_inheritance_of_property_fields.py +++ b/tests/test_inheritance_and_pydantic_generation/test_inheritance_of_property_fields.py @@ -1,18 +1,15 @@ -import databases import ormar -import pytest -import sqlalchemy -import sqlalchemy as sa from pydantic import computed_field -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sa.MetaData() -database = databases.Database(DATABASE_URL) + +base_ormar_config = create_config() class BaseFoo(ormar.Model): - ormar_config = ormar.OrmarConfig(abstract=True) + ormar_config = base_ormar_config.copy(abstract=True) name: str = ormar.String(max_length=100) @@ -22,10 +19,7 @@ def prefixed_name(self) -> str: class Foo(BaseFoo): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy() @computed_field def double_prefixed_name(self) -> str: @@ -35,10 +29,7 @@ def double_prefixed_name(self) -> str: class Bar(BaseFoo): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy() @computed_field def prefixed_name(self) -> str: @@ -47,13 +38,7 @@ def prefixed_name(self) -> str: id: int = ormar.Integer(primary_key=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_property_fields_are_inherited(): diff --git a/tests/test_inheritance_and_pydantic_generation/test_inheritance_with_default.py b/tests/test_inheritance_and_pydantic_generation/test_inheritance_with_default.py index 27db40dee..900b0ffe3 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_inheritance_with_default.py +++ b/tests/test_inheritance_and_pydantic_generation/test_inheritance_with_default.py @@ -1,21 +1,14 @@ import datetime import uuid -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL) - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class BaseModel(ormar.Model): @@ -35,13 +28,7 @@ class Member(BaseModel): last_name: str = ormar.String(max_length=50) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_model_structure(): @@ -58,6 +45,6 @@ def test_model_structure(): @pytest.mark.asyncio async def test_fields_inherited_with_default(): - async with database: + async with base_ormar_config.database: await Member(first_name="foo", last_name="bar").save() await Member.objects.create(first_name="foo", last_name="bar") diff --git a/tests/test_inheritance_and_pydantic_generation/test_inherited_class_is_not_abstract_by_default.py b/tests/test_inheritance_and_pydantic_generation/test_inherited_class_is_not_abstract_by_default.py index a038fab5c..44a75791d 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_inherited_class_is_not_abstract_by_default.py +++ b/tests/test_inheritance_and_pydantic_generation/test_inherited_class_is_not_abstract_by_default.py @@ -5,18 +5,15 @@ import pytest import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL) + +base_ormar_config = create_config() class TableBase(ormar.Model): - ormar_config = ormar.OrmarConfig( - abstract=True, - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(abstract=True) id: int = ormar.Integer(primary_key=True) created_by: str = ormar.String(max_length=20, default="test") @@ -28,7 +25,7 @@ class TableBase(ormar.Model): class NationBase(ormar.Model): - ormar_config = ormar.OrmarConfig(abstract=True) + ormar_config = base_ormar_config.copy(abstract=True) name: str = ormar.String(max_length=50) alpha2_code: str = ormar.String(max_length=2) @@ -37,21 +34,15 @@ class NationBase(ormar.Model): class Nation(NationBase, TableBase): - ormar_config = ormar.OrmarConfig() + ormar_config = base_ormar_config.copy() -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_model_is_not_abstract_by_default(): - async with database: + async with base_ormar_config.database: sweden = await Nation( name="Sweden", alpha2_code="SE", region="Europe", subregion="Scandinavia" ).save() diff --git a/tests/test_inheritance_and_pydantic_generation/test_nested_models_pydantic.py b/tests/test_inheritance_and_pydantic_generation/test_nested_models_pydantic.py index 08e99c9fc..87a2b7bad 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_nested_models_pydantic.py +++ b/tests/test_inheritance_and_pydantic_generation/test_nested_models_pydantic.py @@ -1,17 +1,10 @@ -import databases import ormar -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Library(ormar.Model): @@ -46,6 +39,9 @@ class TicketPackage(ormar.Model): package: Package = ormar.ForeignKey(Package, related_name="tickets") +create_test_database = init_tests(base_ormar_config) + + def test_have_proper_children(): TicketPackageOut = TicketPackage.get_pydantic(exclude={"ticket"}) assert "package" in TicketPackageOut.model_fields diff --git a/tests/test_inheritance_and_pydantic_generation/test_pydantic_fields_order.py b/tests/test_inheritance_and_pydantic_generation/test_pydantic_fields_order.py index b53aa5676..f4661338d 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_pydantic_fields_order.py +++ b/tests/test_inheritance_and_pydantic_generation/test_pydantic_fields_order.py @@ -1,25 +1,14 @@ -import databases import ormar -import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL) - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class NewTestModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() a: int = ormar.Integer(primary_key=True) b: str = ormar.String(max_length=1) @@ -29,13 +18,7 @@ class NewTestModel(ormar.Model): f: str = ormar.String(max_length=1) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_model_field_order(): diff --git a/tests/test_inheritance_and_pydantic_generation/test_validators_are_inherited.py b/tests/test_inheritance_and_pydantic_generation/test_validators_are_inherited.py index b09e2740a..7afb2cb77 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_validators_are_inherited.py +++ b/tests/test_inheritance_and_pydantic_generation/test_validators_are_inherited.py @@ -1,25 +1,18 @@ import enum -import databases import ormar import pytest -import sqlalchemy from pydantic import ValidationError, field_validator -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL) - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class BaseModel(ormar.Model): - ormar_config = ormar.OrmarConfig(abstract=True) + ormar_config = base_ormar_config.copy(abstract=True) id: int = ormar.Integer(primary_key=True) str_field: str = ormar.String(min_length=5, max_length=10, nullable=False) @@ -46,6 +39,9 @@ class ModelExample(BaseModel): ModelExampleCreate = ModelExample.get_pydantic(exclude={"id"}) +create_test_database = init_tests(base_ormar_config) + + def test_ormar_validator(): ModelExample(str_field="a aaaaaa", enum_field="A") with pytest.raises(ValidationError) as e: diff --git a/tests/test_inheritance_and_pydantic_generation/test_validators_in_generated_pydantic.py b/tests/test_inheritance_and_pydantic_generation/test_validators_in_generated_pydantic.py index e2ffe047f..96ebc17c6 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_validators_in_generated_pydantic.py +++ b/tests/test_inheritance_and_pydantic_generation/test_validators_in_generated_pydantic.py @@ -1,21 +1,14 @@ import enum -import databases import ormar import pytest -import sqlalchemy from pydantic import ValidationError, field_validator -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL) - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class EnumExample(str, enum.Enum): @@ -25,11 +18,7 @@ class EnumExample(str, enum.Enum): class ModelExample(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - tablename="examples", - ) + ormar_config = base_ormar_config.copy(tablename="examples") id: int = ormar.Integer(primary_key=True) str_field: str = ormar.String(min_length=5, max_length=10, nullable=False) @@ -45,6 +34,9 @@ def validate_str_field(cls, v): ModelExampleCreate = ModelExample.get_pydantic(exclude={"id"}) +create_test_database = init_tests(base_ormar_config) + + def test_ormar_validator(): ModelExample(str_field="a aaaaaa", enum_field="A") with pytest.raises(ValidationError) as e: diff --git a/tests/test_meta_constraints/test_check_constraints.py b/tests/test_meta_constraints/test_check_constraints.py index c55a4d9f5..4af656464 100644 --- a/tests/test_meta_constraints/test_check_constraints.py +++ b/tests/test_meta_constraints/test_check_constraints.py @@ -1,22 +1,19 @@ import sqlite3 import asyncpg # type: ignore -import databases import ormar.fields.constraints import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Product(ormar.Model): - ormar_config = ormar.OrmarConfig( + ormar_config = base_ormar_config.copy( tablename="products", - metadata=metadata, - database=database, constraints=[ ormar.fields.constraints.CheckColumns("inventory > buffer"), ], @@ -29,20 +26,14 @@ class Product(ormar.Model): buffer: int = ormar.Integer() -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_check_columns_exclude_mysql(): if Product.ormar_config.database._backend._dialect.name != "mysql": - async with database: # pragma: no cover - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: # pragma: no cover + async with base_ormar_config.database.transaction(force_rollback=True): await Product.objects.create( name="Mars", company="Nestle", inventory=100, buffer=10 ) diff --git a/tests/test_meta_constraints/test_index_constraints.py b/tests/test_meta_constraints/test_index_constraints.py index 056b90204..291839628 100644 --- a/tests/test_meta_constraints/test_index_constraints.py +++ b/tests/test_meta_constraints/test_index_constraints.py @@ -1,19 +1,16 @@ -import databases import ormar.fields.constraints import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Product(ormar.Model): - ormar_config = ormar.OrmarConfig( + ormar_config = base_ormar_config.copy( tablename="products", - metadata=metadata, - database=database, constraints=[ ormar.fields.constraints.IndexColumns("company", "name", name="my_index"), ormar.fields.constraints.IndexColumns("location", "company_type"), @@ -27,13 +24,7 @@ class Product(ormar.Model): company_type: str = ormar.String(max_length=200) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_table_structure(): @@ -52,8 +43,8 @@ def test_table_structure(): @pytest.mark.asyncio async def test_index_is_not_unique(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await Product.objects.create( name="Cookies", company="Nestle", location="A", company_type="B" ) diff --git a/tests/test_meta_constraints/test_unique_constraints.py b/tests/test_meta_constraints/test_unique_constraints.py index 91623c935..64da9ad5b 100644 --- a/tests/test_meta_constraints/test_unique_constraints.py +++ b/tests/test_meta_constraints/test_unique_constraints.py @@ -1,23 +1,20 @@ import sqlite3 import asyncpg # type: ignore -import databases import ormar.fields.constraints import pymysql import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Product(ormar.Model): - ormar_config = ormar.OrmarConfig( + ormar_config = base_ormar_config.copy( tablename="products", - metadata=metadata, - database=database, constraints=[ormar.fields.constraints.UniqueColumns("name", "company")], ) @@ -26,19 +23,13 @@ class Product(ormar.Model): company: str = ormar.String(max_length=200) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_unique_columns(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await Product.objects.create(name="Cookies", company="Nestle") await Product.objects.create(name="Mars", company="Mars") await Product.objects.create(name="Mars", company="Nestle") diff --git a/tests/test_model_definition/test_aliases.py b/tests/test_model_definition/test_aliases.py index 230e3d58d..13fbb530d 100644 --- a/tests/test_model_definition/test_aliases.py +++ b/tests/test_model_definition/test_aliases.py @@ -1,22 +1,17 @@ from typing import List, Optional -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Child(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="children", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="children") id: int = ormar.Integer(name="child_id", primary_key=True) first_name: str = ormar.String(name="fname", max_length=100) @@ -25,11 +20,7 @@ class Child(ormar.Model): class Artist(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="artists", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="artists") id: int = ormar.Integer(name="artist_id", primary_key=True) first_name: str = ormar.String(name="fname", max_length=100) @@ -39,24 +30,14 @@ class Artist(ormar.Model): class Album(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="music_albums", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="music_albums") id: int = ormar.Integer(name="album_id", primary_key=True) name: str = ormar.String(name="album_name", max_length=100) artist: Optional[Artist] = ormar.ForeignKey(Artist, name="artist_id") -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_table_structure(): @@ -69,8 +50,8 @@ def test_table_structure(): @pytest.mark.asyncio async def test_working_with_aliases(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): artist = await Artist.objects.create( first_name="Ted", last_name="Mosbey", born_year=1975 ) @@ -127,7 +108,7 @@ async def test_working_with_aliases(): @pytest.mark.asyncio async def test_bulk_operations_and_fields(): - async with database: + async with base_ormar_config.database: d1 = Child(first_name="Daughter", last_name="1", born_year=1990) d2 = Child(first_name="Daughter", last_name="2", born_year=1991) await Child.objects.bulk_create([d1, d2]) @@ -158,8 +139,8 @@ async def test_bulk_operations_and_fields(): @pytest.mark.asyncio async def test_working_with_aliases_get_or_create(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): artist, created = await Artist.objects.get_or_create( first_name="Teddy", last_name="Bear", born_year=2020 ) diff --git a/tests/test_model_definition/test_columns.py b/tests/test_model_definition/test_columns.py index 85eae2a97..94b24f862 100644 --- a/tests/test_model_definition/test_columns.py +++ b/tests/test_model_definition/test_columns.py @@ -1,17 +1,16 @@ import datetime from enum import Enum -import databases import ormar import pydantic import pytest -import sqlalchemy from ormar import ModelDefinitionError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config(force_rollback=True) def time(): @@ -24,11 +23,7 @@ class MyEnum(Enum): class Example(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="example", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="example") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=200, default="aaa") @@ -42,22 +37,13 @@ class Example(ormar.Model): class EnumExample(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="enum_example", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="enum_example") id: int = ormar.Integer(primary_key=True) size: MyEnum = ormar.Enum(enum_class=MyEnum, default=MyEnum.SMALL) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_proper_enum_column_type(): @@ -75,7 +61,7 @@ class WrongEnum(Enum): @pytest.mark.asyncio async def test_enum_bulk_operations(): - async with database: + async with base_ormar_config.database: examples = [EnumExample(), EnumExample()] await EnumExample.objects.bulk_create(examples) @@ -92,7 +78,7 @@ async def test_enum_bulk_operations(): @pytest.mark.asyncio async def test_enum_filter(): - async with database: + async with base_ormar_config.database: examples = [EnumExample(), EnumExample(size=MyEnum.BIG)] await EnumExample.objects.bulk_create(examples) @@ -105,7 +91,7 @@ async def test_enum_filter(): @pytest.mark.asyncio async def test_model_crud(): - async with database: + async with base_ormar_config.database: example = Example() await example.save() @@ -133,15 +119,11 @@ async def test_model_crud(): @pytest.mark.asyncio async def test_invalid_enum_field() -> None: - async with database: + async with base_ormar_config.database: with pytest.raises(ModelDefinitionError): class Example2(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="example", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="example2") id: int = ormar.Integer(primary_key=True) size: MyEnum = ormar.Enum(enum_class=[]) # type: ignore diff --git a/tests/test_model_definition/test_create_uses_init_for_consistency.py b/tests/test_model_definition/test_create_uses_init_for_consistency.py index 6142bfce9..736248008 100644 --- a/tests/test_model_definition/test_create_uses_init_for_consistency.py +++ b/tests/test_model_definition/test_create_uses_init_for_consistency.py @@ -1,16 +1,15 @@ import uuid from typing import ClassVar -import databases import ormar import pytest -import sqlalchemy from pydantic import model_validator -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Mol(ormar.Model): @@ -19,9 +18,7 @@ class Mol(ormar.Model): "12345678-abcd-1234-abcd-123456789abc" ) - ormar_config = ormar.OrmarConfig( - database=database, metadata=metadata, tablename="mols" - ) + ormar_config = base_ormar_config.copy(tablename="mols") id: uuid.UUID = ormar.UUID(primary_key=True, index=True, uuid_format="hex") smiles: str = ormar.String(nullable=False, unique=True, max_length=256) @@ -43,17 +40,12 @@ def uuid(cls, smiles): return id_, smiles -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_json_column(): - async with database: + async with base_ormar_config.database: await Mol.objects.create(smiles="Cc1ccccc1") count = await Mol.objects.count() assert count == 1 diff --git a/tests/test_model_definition/test_dates_with_timezone.py b/tests/test_model_definition/test_dates_with_timezone.py index a524938c9..954ca350e 100644 --- a/tests/test_model_definition/test_dates_with_timezone.py +++ b/tests/test_model_definition/test_dates_with_timezone.py @@ -1,21 +1,17 @@ from datetime import date, datetime, time, timedelta, timezone -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class DateFieldsModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) created_date: datetime = ormar.DateTime( @@ -29,30 +25,21 @@ class DateFieldsModel(ormar.Model): class SampleModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) updated_at: datetime = ormar.DateTime() class TimeModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) elapsed: time = ormar.Time() class DateModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) creation_date: date = ormar.Date() @@ -62,23 +49,15 @@ class MyModel(ormar.Model): id: int = ormar.Integer(primary_key=True) created_at: datetime = ormar.DateTime(timezone=True, nullable=False) - ormar_config = ormar.OrmarConfig( - tablename="mymodels", metadata=metadata, database=database - ) + ormar_config = base_ormar_config.copy(tablename="mymodels") -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_model_crud_with_timezone(): - async with database: + async with base_ormar_config.database: datemodel = await DateFieldsModel().save() assert datemodel.created_date is not None assert datemodel.updated_date is not None @@ -86,7 +65,7 @@ async def test_model_crud_with_timezone(): @pytest.mark.asyncio async def test_query_with_datetime_in_filter(): - async with database: + async with base_ormar_config.database: creation_dt = datetime(2021, 5, 18, 0, 0, 0, 0) sample = await SampleModel.objects.create(updated_at=creation_dt) @@ -100,7 +79,7 @@ async def test_query_with_datetime_in_filter(): @pytest.mark.asyncio async def test_query_with_date_in_filter(): - async with database: + async with base_ormar_config.database: sample = await TimeModel.objects.create(elapsed=time(0, 20, 20)) await TimeModel.objects.create(elapsed=time(0, 12, 0)) await TimeModel.objects.create(elapsed=time(0, 19, 55)) @@ -116,7 +95,7 @@ async def test_query_with_date_in_filter(): @pytest.mark.asyncio async def test_query_with_time_in_filter(): - async with database: + async with base_ormar_config.database: await DateModel.objects.create(creation_date=date(2021, 5, 18)) sample2 = await DateModel.objects.create(creation_date=date(2021, 5, 19)) sample3 = await DateModel.objects.create(creation_date=date(2021, 5, 20)) @@ -132,7 +111,7 @@ async def test_query_with_time_in_filter(): @pytest.mark.asyncio async def test_filtering_by_timezone_with_timedelta(): - async with database: + async with base_ormar_config.database: now_utc = datetime.now(timezone.utc) object = MyModel(created_at=now_utc) await object.save() diff --git a/tests/test_model_definition/test_equality_and_hash.py b/tests/test_model_definition/test_equality_and_hash.py index e8b52a61b..0d46dadf0 100644 --- a/tests/test_model_definition/test_equality_and_hash.py +++ b/tests/test_model_definition/test_equality_and_hash.py @@ -1,38 +1,27 @@ # type: ignore -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Song(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="songs", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="songs") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_equality(): - async with database: + async with base_ormar_config.database: song1 = await Song.objects.create(name="Song") song2 = await Song.objects.create(name="Song") song3 = Song(name="Song") @@ -49,7 +38,7 @@ async def test_equality(): @pytest.mark.asyncio async def test_hash_doesnt_change_with_fields_if_pk(): - async with database: + async with base_ormar_config.database: song1 = await Song.objects.create(name="Song") prev_hash = hash(song1) @@ -59,7 +48,7 @@ async def test_hash_doesnt_change_with_fields_if_pk(): @pytest.mark.asyncio async def test_hash_changes_with_fields_if_no_pk(): - async with database: + async with base_ormar_config.database: song1 = Song(name="Song") prev_hash = hash(song1) diff --git a/tests/test_model_definition/test_extra_ignore_parameter.py b/tests/test_model_definition/test_extra_ignore_parameter.py index 9726db449..7f20cc503 100644 --- a/tests/test_model_definition/test_extra_ignore_parameter.py +++ b/tests/test_model_definition/test_extra_ignore_parameter.py @@ -1,19 +1,16 @@ -import databases import ormar -import sqlalchemy from ormar import Extra -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config(force_rollback=True) class Child(ormar.Model): - ormar_config = ormar.OrmarConfig( + ormar_config = base_ormar_config.copy( tablename="children", - metadata=metadata, - database=database, extra=Extra.ignore, ) @@ -22,6 +19,9 @@ class Child(ormar.Model): last_name: str = ormar.String(name="lname", max_length=100) +create_test_database = init_tests(base_ormar_config) + + def test_allow_extra_parameter(): child = Child(first_name="Test", last_name="Name", extra_param="Unexpected") assert child.first_name == "Test" diff --git a/tests/test_model_definition/test_fields_access.py b/tests/test_model_definition/test_fields_access.py index a4c34e534..c457f1438 100644 --- a/tests/test_model_definition/test_fields_access.py +++ b/tests/test_model_definition/test_fields_access.py @@ -4,16 +4,11 @@ import sqlalchemy from ormar import BaseField -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class PriceList(ormar.Model): @@ -40,13 +35,7 @@ class Product(ormar.Model): category = ormar.ForeignKey(Category) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_fields_access(): @@ -189,8 +178,8 @@ def test_combining_groups_together(): @pytest.mark.asyncio async def test_filtering_by_field_access(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): category = await Category(name="Toys").save() product2 = await Product( name="My Little Pony", rating=3.8, category=category diff --git a/tests/test_model_definition/test_foreign_key_value_used_for_related_model.py b/tests/test_model_definition/test_foreign_key_value_used_for_related_model.py index 9df1ef505..5ebc7f9bb 100644 --- a/tests/test_model_definition/test_foreign_key_value_used_for_related_model.py +++ b/tests/test_model_definition/test_foreign_key_value_used_for_related_model.py @@ -1,21 +1,14 @@ import uuid from typing import Optional -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class PageLink(ormar.Model): @@ -52,18 +45,13 @@ class Course(ormar.Model): department: Optional[Department] = ormar.ForeignKey(Department) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_pass_int_values_as_fk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): link = await PageLink(id=1, value="test", country="USA").save() await Post.objects.create(title="My post", link=link.id) post_check = await Post.objects.select_related("link").get() @@ -72,7 +60,7 @@ async def test_pass_int_values_as_fk(): @pytest.mark.asyncio async def test_pass_uuid_value_as_fk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): dept = await Department(name="Department test").save() await Course(name="Test course", department=dept.id).save() diff --git a/tests/test_model_definition/test_iterate.py b/tests/test_model_definition/test_iterate.py index 2dcad639e..c5a9f78f0 100644 --- a/tests/test_model_definition/test_iterate.py +++ b/tests/test_model_definition/test_iterate.py @@ -1,34 +1,25 @@ import uuid -import databases import ormar import pytest -import sqlalchemy from ormar.exceptions import QueryDefinitionError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class User(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="users3", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="users3") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, default="") class User2(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="users4", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="users4") id: uuid.UUID = ormar.UUID( uuid_format="string", primary_key=True, default=uuid.uuid4 @@ -37,11 +28,7 @@ class User2(ormar.Model): class Task(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="tasks", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="tasks") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, default="") @@ -49,11 +36,7 @@ class Task(ormar.Model): class Task2(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="tasks2", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="tasks2") id: uuid.UUID = ormar.UUID( uuid_format="string", primary_key=True, default=uuid.uuid4 @@ -62,27 +45,21 @@ class Task2(ormar.Model): user: User2 = ormar.ForeignKey(to=User2) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_empty_result(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): async for user in User.objects.iterate(): pass # pragma: no cover @pytest.mark.asyncio async def test_model_iterator(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): tom = await User.objects.create(name="Tom") jane = await User.objects.create(name="Jane") lucy = await User.objects.create(name="Lucy") @@ -93,8 +70,8 @@ async def test_model_iterator(): @pytest.mark.asyncio async def test_model_iterator_filter(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): tom = await User.objects.create(name="Tom") await User.objects.create(name="Jane") await User.objects.create(name="Lucy") @@ -105,8 +82,8 @@ async def test_model_iterator_filter(): @pytest.mark.asyncio async def test_model_iterator_relations(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): tom = await User.objects.create(name="Tom") jane = await User.objects.create(name="Jane") lucy = await User.objects.create(name="Lucy") @@ -125,8 +102,8 @@ async def test_model_iterator_relations(): @pytest.mark.asyncio async def test_model_iterator_relations_queryset_proxy(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): tom = await User.objects.create(name="Tom") jane = await User.objects.create(name="Jane") @@ -151,8 +128,8 @@ async def test_model_iterator_relations_queryset_proxy(): @pytest.mark.asyncio async def test_model_iterator_uneven_number_of_relations(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): tom = await User.objects.create(name="Tom") jane = await User.objects.create(name="Jane") lucy = await User.objects.create(name="Lucy") @@ -173,8 +150,8 @@ async def test_model_iterator_uneven_number_of_relations(): @pytest.mark.asyncio async def test_model_iterator_uuid_pk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): tom = await User2.objects.create(name="Tom") jane = await User2.objects.create(name="Jane") lucy = await User2.objects.create(name="Lucy") @@ -185,8 +162,8 @@ async def test_model_iterator_uuid_pk(): @pytest.mark.asyncio async def test_model_iterator_filter_uuid_pk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): tom = await User2.objects.create(name="Tom") await User2.objects.create(name="Jane") await User2.objects.create(name="Lucy") @@ -197,8 +174,8 @@ async def test_model_iterator_filter_uuid_pk(): @pytest.mark.asyncio async def test_model_iterator_relations_uuid_pk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): tom = await User2.objects.create(name="Tom") jane = await User2.objects.create(name="Jane") lucy = await User2.objects.create(name="Lucy") @@ -217,8 +194,8 @@ async def test_model_iterator_relations_uuid_pk(): @pytest.mark.asyncio async def test_model_iterator_relations_queryset_proxy_uuid_pk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): tom = await User2.objects.create(name="Tom") jane = await User2.objects.create(name="Jane") @@ -243,8 +220,8 @@ async def test_model_iterator_relations_queryset_proxy_uuid_pk(): @pytest.mark.asyncio async def test_model_iterator_uneven_number_of_relations_uuid_pk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): tom = await User2.objects.create(name="Tom") jane = await User2.objects.create(name="Jane") lucy = await User2.objects.create(name="Lucy") @@ -267,7 +244,7 @@ async def test_model_iterator_uneven_number_of_relations_uuid_pk(): @pytest.mark.asyncio async def test_model_iterator_with_prefetch_raises_error(): - async with database: + async with base_ormar_config.database: with pytest.raises(QueryDefinitionError): async for user in User.objects.prefetch_related(User.tasks).iterate(): pass # pragma: no cover diff --git a/tests/test_model_definition/test_model_construct.py b/tests/test_model_definition/test_model_construct.py index 1ef8ec559..85ad4101d 100644 --- a/tests/test_model_definition/test_model_construct.py +++ b/tests/test_model_definition/test_model_construct.py @@ -1,41 +1,28 @@ from typing import List -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class NickNames(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="nicks", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="nicks") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") class NicksHq(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="nicks_x_hq", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="nicks_x_hq") class HQ(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="hqs", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="hqs") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") @@ -43,11 +30,7 @@ class HQ(ormar.Model): class Company(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="companies", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="companies") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="company_name") @@ -55,19 +38,13 @@ class Company(ormar.Model): hq: HQ = ormar.ForeignKey(HQ) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_construct_with_empty_relation(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await HQ.objects.create(name="Main") comp = Company(name="Banzai", hq=None, founded=1988) comp2 = Company.model_construct( @@ -78,8 +55,8 @@ async def test_construct_with_empty_relation(): @pytest.mark.asyncio async def test_init_and_construct_has_same_effect(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): hq = await HQ.objects.create(name="Main") comp = Company(name="Banzai", hq=hq, founded=1988) comp2 = Company.model_construct(**dict(name="Banzai", hq=hq, founded=1988)) @@ -93,8 +70,8 @@ async def test_init_and_construct_has_same_effect(): @pytest.mark.asyncio async def test_init_and_construct_has_same_effect_with_m2m(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): n1 = await NickNames(name="test").save() n2 = await NickNames(name="test2").save() hq = HQ(name="Main", nicks=[n1, n2]) diff --git a/tests/test_model_definition/test_model_definition.py b/tests/test_model_definition/test_model_definition.py index 0393d5111..7d841bf18 100644 --- a/tests/test_model_definition/test_model_definition.py +++ b/tests/test_model_definition/test_model_definition.py @@ -3,7 +3,6 @@ import decimal import typing -import databases import ormar import pydantic import pytest @@ -11,18 +10,15 @@ from ormar.exceptions import ModelDefinitionError from ormar.models import Model -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) + +base_ormar_config = create_config() class ExampleModel(Model): - ormar_config = ormar.OrmarConfig( - tablename="example", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="example") test: int = ormar.Integer(primary_key=True) test_string: str = ormar.String(max_length=250) @@ -53,22 +49,13 @@ class ExampleModel(Model): class ExampleModel2(Model): - ormar_config = ormar.OrmarConfig( - tablename="examples", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="example2") test: int = ormar.Integer(primary_key=True) test_string: str = ormar.String(max_length=250) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.fixture() @@ -117,7 +104,7 @@ def test_missing_metadata(): class JsonSample2(ormar.Model): ormar_config = ormar.OrmarConfig( tablename="jsons2", - database=database, + database=base_ormar_config.database, ) id: int = ormar.Integer(primary_key=True) @@ -176,11 +163,7 @@ def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): # type: ignore class ExampleModel2(Model): # type: ignore - ormar_config = ormar.OrmarConfig( - tablename="example2", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="example2") test_string: str = ormar.String(max_length=250) # type: ignore @@ -191,11 +174,7 @@ def test_two_pks_in_model_definition(): @typing.no_type_check class ExampleModel2(Model): - ormar_config = ormar.OrmarConfig( - tablename="example3", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="example3") id: int = ormar.Integer(primary_key=True) test_string: str = ormar.String(max_length=250, primary_key=True) @@ -206,11 +185,7 @@ def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): class ExampleModel2(Model): - ormar_config = ormar.OrmarConfig( - tablename="example5", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="example5") test: decimal.Decimal = ormar.Decimal(primary_key=True) @@ -220,11 +195,7 @@ def test_binary_error_without_length_model_definition(): with pytest.raises(ModelDefinitionError): class ExampleModel2(Model): - ormar_config = ormar.OrmarConfig( - tablename="example6", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="example6") test: bytes = ormar.LargeBinary(primary_key=True, max_length=-1) @@ -234,11 +205,7 @@ def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): class ExampleModel2(Model): - ormar_config = ormar.OrmarConfig( - tablename="example6", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="example6") test: str = ormar.String(primary_key=True, max_length=0) diff --git a/tests/test_model_definition/test_models.py b/tests/test_model_definition/test_models.py index 1f099100e..951edafd1 100644 --- a/tests/test_model_definition/test_models.py +++ b/tests/test_model_definition/test_models.py @@ -5,25 +5,21 @@ import uuid from enum import Enum -import databases import ormar import pydantic import pytest import sqlalchemy from ormar.exceptions import ModelError, NoMatch, QueryDefinitionError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class JsonSample(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="jsons", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="jsons") id: int = ormar.Integer(primary_key=True) test_json = ormar.JSON(nullable=True) @@ -34,11 +30,7 @@ class JsonSample(ormar.Model): class LargeBinarySample(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="my_bolbs", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="my_bolbs") id: int = ormar.Integer(primary_key=True) test_binary: bytes = ormar.LargeBinary(max_length=100000) @@ -49,11 +41,7 @@ class LargeBinarySample(ormar.Model): class LargeBinaryStr(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="my_str_blobs", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="my_str_bolbs") id: int = ormar.Integer(primary_key=True) test_binary: str = ormar.LargeBinary( @@ -62,11 +50,7 @@ class LargeBinaryStr(ormar.Model): class LargeBinaryNullableStr(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="my_str_blobs2", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="my_str_bolbs2") id: int = ormar.Integer(primary_key=True) test_binary: str = ormar.LargeBinary( @@ -78,22 +62,14 @@ class LargeBinaryNullableStr(ormar.Model): class UUIDSample(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="uuids", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="uuids") id: uuid.UUID = ormar.UUID(primary_key=True, default=uuid.uuid4) test_text: str = ormar.Text() class UUIDSample2(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="uuids2", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="uuids2") id: uuid.UUID = ormar.UUID( primary_key=True, default=uuid.uuid4, uuid_format="string" @@ -102,33 +78,21 @@ class UUIDSample2(ormar.Model): class User(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="users", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="users") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, default="") class User2(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="users2", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="users2") id: str = ormar.String(primary_key=True, max_length=100) name: str = ormar.String(max_length=100, default="") class Product(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="product", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="product") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -152,11 +116,7 @@ class CountryCodeEnum(int, Enum): class Country(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="country", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="country") id: int = ormar.Integer(primary_key=True) name: CountryNameEnum = ormar.Enum(enum_class=CountryNameEnum, default="Canada") @@ -165,34 +125,20 @@ class Country(ormar.Model): class NullableCountry(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="country2", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="country2") id: int = ormar.Integer(primary_key=True) name: CountryNameEnum = ormar.Enum(enum_class=CountryNameEnum, nullable=True) class NotNullableCountry(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="country3", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="country3") id: int = ormar.Integer(primary_key=True) name: CountryNameEnum = ormar.Enum(enum_class=CountryNameEnum, nullable=False) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_model_class(): @@ -219,8 +165,8 @@ def test_model_pk(): @pytest.mark.asyncio async def test_json_column(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await JsonSample.objects.create(test_json=dict(aa=12)) await JsonSample.objects.create(test_json='{"aa": 12}') @@ -235,8 +181,8 @@ async def test_json_column(): @pytest.mark.asyncio async def test_binary_column(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await LargeBinarySample.objects.create(test_binary=blob) await LargeBinarySample.objects.create(test_binary=blob2) @@ -251,8 +197,8 @@ async def test_binary_column(): @pytest.mark.asyncio async def test_binary_str_column(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await LargeBinaryStr(test_binary=blob3).save() await LargeBinaryStr.objects.create(test_binary=blob4) @@ -267,8 +213,8 @@ async def test_binary_str_column(): @pytest.mark.asyncio async def test_binary_nullable_str_column(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await LargeBinaryNullableStr().save() await LargeBinaryNullableStr.objects.create() items = await LargeBinaryNullableStr.objects.all() @@ -298,8 +244,8 @@ async def test_binary_nullable_str_column(): @pytest.mark.asyncio async def test_uuid_column(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): u1 = await UUIDSample.objects.create(test_text="aa") u2 = await UUIDSample.objects.create(test_text="bb") @@ -333,8 +279,8 @@ async def test_uuid_column(): @pytest.mark.asyncio async def test_model_crud(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): users = await User.objects.all() assert users == [] @@ -360,8 +306,8 @@ async def test_model_crud(): @pytest.mark.asyncio async def test_model_get(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): with pytest.raises(ormar.NoMatch): await User.objects.get() @@ -385,8 +331,8 @@ async def test_model_get(): @pytest.mark.asyncio async def test_model_filter(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await User.objects.create(name="Tom") await User.objects.create(name="Jane") await User.objects.create(name="Lucy") @@ -441,7 +387,7 @@ async def test_model_filter(): @pytest.mark.asyncio async def test_wrong_query_contains_model(): - async with database: + async with base_ormar_config.database: with pytest.raises(QueryDefinitionError): product = Product(name="90%-Cotton", rating=2) await Product.objects.filter(name__contains=product).count() @@ -449,8 +395,8 @@ async def test_wrong_query_contains_model(): @pytest.mark.asyncio async def test_model_exists(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await User.objects.create(name="Tom") assert await User.objects.filter(name="Tom").exists() is True assert await User.objects.filter(name="Jane").exists() is False @@ -458,8 +404,8 @@ async def test_model_exists(): @pytest.mark.asyncio async def test_model_count(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await User.objects.create(name="Tom") await User.objects.create(name="Jane") await User.objects.create(name="Lucy") @@ -470,8 +416,8 @@ async def test_model_count(): @pytest.mark.asyncio async def test_model_limit(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await User.objects.create(name="Tom") await User.objects.create(name="Jane") await User.objects.create(name="Lucy") @@ -481,8 +427,8 @@ async def test_model_limit(): @pytest.mark.asyncio async def test_model_limit_with_filter(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await User.objects.create(name="Tom") await User.objects.create(name="Tom") await User.objects.create(name="Tom") @@ -494,8 +440,8 @@ async def test_model_limit_with_filter(): @pytest.mark.asyncio async def test_offset(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await User.objects.create(name="Tom") await User.objects.create(name="Jane") @@ -505,8 +451,8 @@ async def test_offset(): @pytest.mark.asyncio async def test_model_first(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): tom = await User.objects.create(name="Tom") jane = await User.objects.create(name="Jane") @@ -522,7 +468,7 @@ async def test_model_first(): @pytest.mark.asyncio async def test_model_choices(): """Test that choices work properly for various types of fields.""" - async with database: + async with base_ormar_config.database: # Test valid choices. await asyncio.gather( Country.objects.create(name="Canada", taxed=True, country_code=1), @@ -562,7 +508,7 @@ async def test_model_choices(): @pytest.mark.asyncio async def test_nullable_field_model_choices(): """Test that choices work properly for according to nullable setting""" - async with database: + async with base_ormar_config.database: c1 = await NullableCountry(name=None).save() assert c1.name is None @@ -572,8 +518,8 @@ async def test_nullable_field_model_choices(): @pytest.mark.asyncio async def test_start_and_end_filters(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await User.objects.create(name="Markos Uj") await User.objects.create(name="Maqua Bigo") await User.objects.create(name="maqo quidid") @@ -602,8 +548,8 @@ async def test_start_and_end_filters(): @pytest.mark.asyncio async def test_get_and_first(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await User.objects.create(name="Tom") await User.objects.create(name="Jane") await User.objects.create(name="Lucy") diff --git a/tests/test_model_definition/test_models_are_pickable.py b/tests/test_model_definition/test_models_are_pickable.py index 42d7adcbc..c0c90e1f0 100644 --- a/tests/test_model_definition/test_models_are_pickable.py +++ b/tests/test_model_definition/test_models_are_pickable.py @@ -1,23 +1,18 @@ import pickle from typing import Optional -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class User(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="users", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="users") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -25,29 +20,19 @@ class User(ormar.Model): class Post(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="posts", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="posts") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) created_by: Optional[User] = ormar.ForeignKey(User) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_dumping_and_loading_model_works(): - async with database: + async with base_ormar_config.database: user = await User(name="Test", properties={"aa": "bb"}).save() post = Post(name="Test post") await user.posts.add(post) diff --git a/tests/test_model_definition/test_overwriting_pydantic_field_type.py b/tests/test_model_definition/test_overwriting_pydantic_field_type.py index cd2bb3f45..bb026415d 100644 --- a/tests/test_model_definition/test_overwriting_pydantic_field_type.py +++ b/tests/test_model_definition/test_overwriting_pydantic_field_type.py @@ -1,23 +1,18 @@ from typing import Dict, Optional -import databases import ormar import pytest -import sqlalchemy from pydantic import Json, PositiveInt, ValidationError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class OverwriteTest(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="overwrites", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="overwrites") id: int = ormar.Integer(primary_key=True) my_int: int = ormar.Integer(overwrite_pydantic_type=PositiveInt) @@ -26,13 +21,7 @@ class OverwriteTest(ormar.Model): ) # type: ignore -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_constraints(): @@ -50,7 +39,7 @@ def test_constraints(): @pytest.mark.asyncio async def test_saving(): - async with database: + async with base_ormar_config.database: await OverwriteTest(my_int=5, constraint_dict={"aa": 123}).save() test = await OverwriteTest.objects.get() diff --git a/tests/test_model_definition/test_overwriting_sql_nullable.py b/tests/test_model_definition/test_overwriting_sql_nullable.py index 08b103556..f967f05bd 100644 --- a/tests/test_model_definition/test_overwriting_sql_nullable.py +++ b/tests/test_model_definition/test_overwriting_sql_nullable.py @@ -2,23 +2,16 @@ from typing import Optional import asyncpg -import databases import ormar import pymysql import pytest -import sqlalchemy -from sqlalchemy import create_engine, text +from sqlalchemy import text -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -db = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, -) +base_ormar_config = create_config() class PrimaryModel(ormar.Model): @@ -32,17 +25,12 @@ class PrimaryModel(ormar.Model): ) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_create_models(): - async with db: + async with base_ormar_config.database: primary = await PrimaryModel( name="Foo", some_text="Bar", some_other_text="Baz" ).save() diff --git a/tests/test_model_definition/test_pk_field_is_always_not_null.py b/tests/test_model_definition/test_pk_field_is_always_not_null.py index c4a37e6e4..285c82dc3 100644 --- a/tests/test_model_definition/test_pk_field_is_always_not_null.py +++ b/tests/test_model_definition/test_pk_field_is_always_not_null.py @@ -1,17 +1,10 @@ -import databases import ormar -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class AutoincrementModel(ormar.Model): @@ -32,6 +25,9 @@ class ExplicitNullableModel(ormar.Model): id: int = ormar.Integer(primary_key=True, nullable=True) +create_test_database = init_tests(base_ormar_config) + + def test_pk_field_is_not_null(): for model in [AutoincrementModel, NonAutoincrementModel, ExplicitNullableModel]: assert not model.ormar_config.table.c.get("id").nullable diff --git a/tests/test_model_definition/test_properties.py b/tests/test_model_definition/test_properties.py index 4d54bf73c..085f6e030 100644 --- a/tests/test_model_definition/test_properties.py +++ b/tests/test_model_definition/test_properties.py @@ -1,22 +1,17 @@ # type: ignore -import databases import ormar import pytest -import sqlalchemy from pydantic import PydanticUserError, computed_field -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Song(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="songs", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="songs") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -35,18 +30,12 @@ def sample2(self) -> str: return "sample2" -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_sort_order_on_main_model(): - async with database: + async with base_ormar_config.database: await Song.objects.create(name="Song 3", sort_order=3) await Song.objects.create(name="Song 1", sort_order=1) await Song.objects.create(name="Song 2", sort_order=2) diff --git a/tests/test_model_definition/test_pydantic_fields.py b/tests/test_model_definition/test_pydantic_fields.py index df69a6163..ff9b1e474 100644 --- a/tests/test_model_definition/test_pydantic_fields.py +++ b/tests/test_model_definition/test_pydantic_fields.py @@ -1,22 +1,16 @@ import random from typing import Optional -import databases import ormar import pytest -import sqlalchemy from pydantic import BaseModel, Field, HttpUrl from pydantic_extra_types.payment import PaymentCardNumber -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class ModelTest(ormar.Model): @@ -70,18 +64,12 @@ def __init__(self, **kwargs): pydantic_test: PydanticTest -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_working_with_pydantic_fields(): - async with database: + async with base_ormar_config.database: test = ModelTest(name="Test") assert test.name == "Test" assert test.url == "https://www.example.com" @@ -101,7 +89,7 @@ async def test_working_with_pydantic_fields(): @pytest.mark.asyncio async def test_default_factory_for_pydantic_fields(): - async with database: + async with base_ormar_config.database: test = ModelTest2(name="Test2", number="4000000000000002") assert test.name == "Test2" assert test.url == "https://www.example2.com" @@ -121,7 +109,7 @@ async def test_default_factory_for_pydantic_fields(): @pytest.mark.asyncio async def test_init_setting_for_pydantic_fields(): - async with database: + async with base_ormar_config.database: test = ModelTest3(name="Test3") assert test.name == "Test3" assert test.url == "https://www.example3.com" diff --git a/tests/test_model_definition/test_pydantic_only_fields.py b/tests/test_model_definition/test_pydantic_only_fields.py index f7c0d676a..024ecba3c 100644 --- a/tests/test_model_definition/test_pydantic_only_fields.py +++ b/tests/test_model_definition/test_pydantic_only_fields.py @@ -1,24 +1,19 @@ import datetime -import databases import ormar import pydantic import pytest -import sqlalchemy from pydantic import computed_field -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Album(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="albums", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="albums") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -41,19 +36,13 @@ def name40(self) -> str: return self.name + "_40" -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_pydantic_only_fields(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): album = await Album.objects.create(name="Hitchcock") assert album.pk is not None assert album.saved diff --git a/tests/test_model_definition/test_pydantic_private_attributes.py b/tests/test_model_definition/test_pydantic_private_attributes.py index c990dcf74..591dabaea 100644 --- a/tests/test_model_definition/test_pydantic_private_attributes.py +++ b/tests/test_model_definition/test_pydantic_private_attributes.py @@ -1,20 +1,13 @@ from typing import List -import databases import ormar -import sqlalchemy from pydantic import PrivateAttr -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Subscription(ormar.Model): @@ -29,6 +22,9 @@ def add_payment(self, payment: str): self._add_payments.append(payment) +create_test_database = init_tests(base_ormar_config) + + def test_private_attribute(): sub = Subscription(stripe_subscription_id="2312312sad231") sub.add_payment("test") diff --git a/tests/test_model_definition/test_save_status.py b/tests/test_model_definition/test_save_status.py index c142a5874..52f58fb32 100644 --- a/tests/test_model_definition/test_save_status.py +++ b/tests/test_model_definition/test_save_status.py @@ -1,23 +1,18 @@ from typing import List -import databases import ormar import pytest -import sqlalchemy from ormar.exceptions import ModelPersistenceError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class NickNames(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="nicks", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="nicks") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") @@ -25,19 +20,11 @@ class NickNames(ormar.Model): class NicksHq(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="nicks_x_hq", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="nicks_x_hq") class HQ(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="hqs", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="hqs") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") @@ -45,11 +32,7 @@ class HQ(ormar.Model): class Company(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="companies", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="companies") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="company_name") @@ -57,19 +40,13 @@ class Company(ormar.Model): hq: HQ = ormar.ForeignKey(HQ) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_instantiation_false_save_true(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): comp = Company(name="Banzai", founded=1988) assert not comp.saved await comp.save() @@ -78,8 +55,8 @@ async def test_instantiation_false_save_true(): @pytest.mark.asyncio async def test_saved_edited_not_saved(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): comp = await Company.objects.create(name="Banzai", founded=1988) assert comp.saved comp.name = "Banzai2" @@ -100,8 +77,8 @@ async def test_saved_edited_not_saved(): @pytest.mark.asyncio async def test_adding_related_gets_dirty(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): hq = await HQ.objects.create(name="Main") comp = await Company.objects.create(name="Banzai", founded=1988) assert comp.saved @@ -127,8 +104,8 @@ async def test_adding_related_gets_dirty(): @pytest.mark.asyncio async def test_adding_many_to_many_does_not_gets_dirty(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): nick1 = await NickNames.objects.create(name="Bazinga", is_lame=False) nick2 = await NickNames.objects.create(name="Bazinga2", is_lame=True) @@ -156,8 +133,8 @@ async def test_adding_many_to_many_does_not_gets_dirty(): @pytest.mark.asyncio async def test_delete(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): comp = await Company.objects.create(name="Banzai", founded=1988) assert comp.saved await comp.delete() @@ -169,8 +146,8 @@ async def test_delete(): @pytest.mark.asyncio async def test_load(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): comp = await Company.objects.create(name="Banzai", founded=1988) assert comp.saved comp.name = "AA" @@ -183,8 +160,8 @@ async def test_load(): @pytest.mark.asyncio async def test_queryset_methods(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await Company.objects.create(name="Banzai", founded=1988) await Company.objects.create(name="Yuhu", founded=1989) await Company.objects.create(name="Konono", founded=1990) @@ -226,8 +203,8 @@ async def test_queryset_methods(): @pytest.mark.asyncio async def test_bulk_methods(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): c1 = Company(name="Banzai", founded=1988) c2 = Company(name="Yuhu", founded=1989) diff --git a/tests/test_model_definition/test_saving_nullable_fields.py b/tests/test_model_definition/test_saving_nullable_fields.py index 18cb0ef29..05a2a9405 100644 --- a/tests/test_model_definition/test_saving_nullable_fields.py +++ b/tests/test_model_definition/test_saving_nullable_fields.py @@ -1,23 +1,17 @@ from typing import Optional -import databases import ormar import pytest -import sqlalchemy -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -db = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class PrimaryModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - tablename="primary_models", - ) + ormar_config = base_ormar_config.copy(tablename="primary_models") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=255, index=True) @@ -27,11 +21,7 @@ class PrimaryModel(ormar.Model): class SecondaryModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - tablename="secondary_models", - ) + ormar_config = base_ormar_config.copy(tablename="secondary_models") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -40,18 +30,13 @@ class SecondaryModel(ormar.Model): ) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_create_models(): - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): primary = await PrimaryModel( name="Foo", some_text="Bar", some_other_text="Baz" ).save() diff --git a/tests/test_model_definition/test_server_default.py b/tests/test_model_definition/test_server_default.py index 45e8bf6ed..089d78463 100644 --- a/tests/test_model_definition/test_server_default.py +++ b/tests/test_model_definition/test_server_default.py @@ -1,24 +1,19 @@ import time from datetime import datetime -import databases import ormar import pytest -import sqlalchemy from sqlalchemy import func, text -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Product(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="product", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="product") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -27,13 +22,7 @@ class Product(ormar.Model): created: datetime = ormar.DateTime(server_default=func.now()) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_table_defined_properly(): @@ -46,8 +35,8 @@ def test_table_defined_properly(): @pytest.mark.asyncio async def test_model_creation(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): p1 = Product(name="Test") assert p1.created is None await p1.save() diff --git a/tests/test_model_definition/test_setting_comments_in_db.py b/tests/test_model_definition/test_setting_comments_in_db.py index 213c0cf7f..08388be5c 100644 --- a/tests/test_model_definition/test_setting_comments_in_db.py +++ b/tests/test_model_definition/test_setting_comments_in_db.py @@ -1,33 +1,22 @@ -import databases import ormar import pytest -import sqlalchemy from ormar.models import Model -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) +base_ormar_config = create_config() class Comment(Model): - ormar_config = ormar.OrmarConfig( - tablename="comments", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="comments") test: int = ormar.Integer(primary_key=True, comment="primary key of comments") test_string: str = ormar.String(max_length=250, comment="test that it works") -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio diff --git a/tests/test_model_methods/test_excludes_in_load_all.py b/tests/test_model_methods/test_excludes_in_load_all.py index 62238ca5e..49811e386 100644 --- a/tests/test_model_methods/test_excludes_in_load_all.py +++ b/tests/test_model_methods/test_excludes_in_load_all.py @@ -1,20 +1,13 @@ import uuid -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config(force_rollback=True) class JimmyUser(ormar.Model): @@ -45,18 +38,12 @@ class JimmyAccount(ormar.Model): user: JimmyUser = ormar.ForeignKey(to=JimmyUser) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_excluding_one_relation(): - async with database: + async with base_ormar_config.database: user = JimmyUser() await user.save() @@ -70,7 +57,7 @@ async def test_excluding_one_relation(): @pytest.mark.asyncio async def test_excluding_other_relation(): - async with database: + async with base_ormar_config.database: user = JimmyUser() await user.save() diff --git a/tests/test_model_methods/test_load_all.py b/tests/test_model_methods/test_load_all.py index 4a9f9086f..af0452654 100644 --- a/tests/test_model_methods/test_load_all.py +++ b/tests/test_model_methods/test_load_all.py @@ -1,20 +1,13 @@ from typing import List -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Language(ormar.Model): @@ -59,19 +52,13 @@ class Company(ormar.Model): hq: HQ = ormar.ForeignKey(HQ, related_name="companies") -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_load_all_fk_rel(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): hq = await HQ.objects.create(name="Main") company = await Company.objects.create(name="Banzai", founded=1988, hq=hq) @@ -90,8 +77,8 @@ async def test_load_all_fk_rel(): @pytest.mark.asyncio async def test_load_all_many_to_many(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): nick1 = await NickName.objects.create(name="BazingaO", is_lame=False) nick2 = await NickName.objects.create(name="Bazinga20", is_lame=True) hq = await HQ.objects.create(name="Main") @@ -116,8 +103,8 @@ async def test_load_all_many_to_many(): @pytest.mark.asyncio async def test_load_all_with_order(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): nick1 = await NickName.objects.create(name="Barry", is_lame=False) nick2 = await NickName.objects.create(name="Joe", is_lame=True) hq = await HQ.objects.create(name="Main") @@ -150,8 +137,8 @@ async def test_load_all_with_order(): @pytest.mark.asyncio async def test_loading_reversed_relation(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): hq = await HQ.objects.create(name="Main") await Company.objects.create(name="Banzai", founded=1988, hq=hq) @@ -166,8 +153,8 @@ async def test_loading_reversed_relation(): @pytest.mark.asyncio async def test_loading_nested(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): language = await Language.objects.create(name="English") level = await CringeLevel.objects.create(name="High", language=language) level2 = await CringeLevel.objects.create(name="Low", language=language) diff --git a/tests/test_model_methods/test_populate_default_values.py b/tests/test_model_methods/test_populate_default_values.py index 1ae5ba802..820d313c1 100644 --- a/tests/test_model_methods/test_populate_default_values.py +++ b/tests/test_model_methods/test_populate_default_values.py @@ -1,18 +1,11 @@ -import databases import ormar -import sqlalchemy from sqlalchemy import text -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Task(ormar.Model): @@ -20,7 +13,7 @@ class Task(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String( - max_length=255, minimum=0, server_default=text("Default Name"), nullable=False + max_length=255, minimum=0, server_default=text("'Default Name'"), nullable=False ) points: int = ormar.Integer( default=0, minimum=0, server_default=text("0"), nullable=False @@ -28,6 +21,9 @@ class Task(ormar.Model): score: int = ormar.Integer(default=5) +create_test_database = init_tests(base_ormar_config) + + def test_populate_default_values(): new_kwargs = { "id": None, diff --git a/tests/test_model_methods/test_save_related.py b/tests/test_model_methods/test_save_related.py index c10f9d3e4..7201050af 100644 --- a/tests/test_model_methods/test_save_related.py +++ b/tests/test_model_methods/test_save_related.py @@ -5,29 +5,22 @@ import pytest import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class CringeLevel(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="levels", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="levels") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class NickName(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="nicks", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="nicks") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") @@ -36,19 +29,11 @@ class NickName(ormar.Model): class NicksHq(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="nicks_x_hq", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="nicks_x_hq") class HQ(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="hqs", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="hqs") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") @@ -56,11 +41,7 @@ class HQ(ormar.Model): class Company(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="companies", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="companies") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="company_name") @@ -68,19 +49,13 @@ class Company(ormar.Model): hq: HQ = ormar.ForeignKey(HQ, related_name="companies") -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_saving_related_fk_rel(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): hq = await HQ.objects.create(name="Main") comp = await Company.objects.create(name="Banzai", founded=1988, hq=hq) assert comp.saved @@ -107,8 +82,8 @@ async def test_saving_related_fk_rel(): @pytest.mark.asyncio async def test_saving_many_to_many(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): nick1 = await NickName.objects.create(name="BazingaO", is_lame=False) nick2 = await NickName.objects.create(name="Bazinga20", is_lame=True) @@ -149,8 +124,8 @@ async def test_saving_many_to_many(): @pytest.mark.asyncio async def test_saving_reversed_relation(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): hq = await HQ.objects.create(name="Main") await Company.objects.create(name="Banzai", founded=1988, hq=hq) @@ -190,8 +165,8 @@ async def test_saving_reversed_relation(): @pytest.mark.asyncio async def test_saving_nested(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): level = await CringeLevel.objects.create(name="High") level2 = await CringeLevel.objects.create(name="Low") nick1 = await NickName.objects.create( diff --git a/tests/test_model_methods/test_save_related_from_dict.py b/tests/test_model_methods/test_save_related_from_dict.py index 46013071f..089360ab5 100644 --- a/tests/test_model_methods/test_save_related_from_dict.py +++ b/tests/test_model_methods/test_save_related_from_dict.py @@ -1,33 +1,24 @@ from typing import List -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class CringeLevel(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="levels", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="levels") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class NickName(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="nicks", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="nicks") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") @@ -36,22 +27,14 @@ class NickName(ormar.Model): class NicksHq(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="nicks_x_hq", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="nicks_x_hq") id: int = ormar.Integer(primary_key=True) new_field: str = ormar.String(max_length=200, nullable=True) class HQ(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="hqs", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="hqs") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") @@ -59,11 +42,7 @@ class HQ(ormar.Model): class Company(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="companies", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="companies") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="company_name") @@ -71,19 +50,13 @@ class Company(ormar.Model): hq: HQ = ormar.ForeignKey(HQ, related_name="companies") -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_saving_related_reverse_fk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): payload = {"companies": [{"name": "Banzai"}], "name": "Main"} hq = HQ(**payload) count = await hq.save_related(follow=True, save_all=True) @@ -99,8 +72,8 @@ async def test_saving_related_reverse_fk(): @pytest.mark.asyncio async def test_saving_related_reverse_fk_multiple(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): payload = { "companies": [{"name": "Banzai"}, {"name": "Yamate"}], "name": "Main", @@ -121,8 +94,8 @@ async def test_saving_related_reverse_fk_multiple(): @pytest.mark.asyncio async def test_saving_related_fk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): payload = {"hq": {"name": "Main"}, "name": "Banzai"} comp = Company(**payload) count = await comp.save_related(follow=True, save_all=True) @@ -137,8 +110,8 @@ async def test_saving_related_fk(): @pytest.mark.asyncio async def test_saving_many_to_many_wo_through(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): payload = { "name": "Main", "nicks": [ @@ -160,9 +133,9 @@ async def test_saving_many_to_many_wo_through(): @pytest.mark.asyncio async def test_saving_many_to_many_with_through(): - async with database: - async with database.transaction(force_rollback=True): - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): + async with base_ormar_config.database.transaction(force_rollback=True): payload = { "name": "Main", "nicks": [ @@ -194,8 +167,8 @@ async def test_saving_many_to_many_with_through(): @pytest.mark.asyncio async def test_saving_nested_with_m2m_and_rev_fk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): payload = { "name": "Main", "nicks": [ @@ -219,8 +192,8 @@ async def test_saving_nested_with_m2m_and_rev_fk(): @pytest.mark.asyncio async def test_saving_nested_with_m2m_and_rev_fk_and_through(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): payload = { "hq": { "name": "Yoko", diff --git a/tests/test_model_methods/test_save_related_uuid.py b/tests/test_model_methods/test_save_related_uuid.py index db7f10180..2139cb573 100644 --- a/tests/test_model_methods/test_save_related_uuid.py +++ b/tests/test_model_methods/test_save_related_uuid.py @@ -6,27 +6,22 @@ import pytest import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Department(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: uuid.UUID = ormar.UUID(primary_key=True, default=uuid.uuid4) department_name: str = ormar.String(max_length=100) class Course(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: uuid.UUID = ormar.UUID(primary_key=True, default=uuid.uuid4) course_name: str = ormar.String(max_length=100) @@ -35,28 +30,19 @@ class Course(ormar.Model): class Student(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: uuid.UUID = ormar.UUID(primary_key=True, default=uuid.uuid4) name: str = ormar.String(max_length=100) courses = ormar.ManyToMany(Course) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_uuid_pk_in_save_related(): - async with database: + async with base_ormar_config.database: to_save = { "department_name": "Ormar", "courses": [ diff --git a/tests/test_model_methods/test_update.py b/tests/test_model_methods/test_update.py index 22a81c677..fbf11b823 100644 --- a/tests/test_model_methods/test_update.py +++ b/tests/test_model_methods/test_update.py @@ -5,18 +5,15 @@ import pytest import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Director(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="directors", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="directors") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="first_name") @@ -24,11 +21,7 @@ class Director(ormar.Model): class Movie(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="movies", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="movies") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="title") @@ -37,18 +30,12 @@ class Movie(ormar.Model): director: Optional[Director] = ormar.ForeignKey(Director) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_updating_selected_columns(): - async with database: + async with base_ormar_config.database: director1 = await Director(name="Peter", last_name="Jackson").save() director2 = await Director(name="James", last_name="Cameron").save() @@ -84,7 +71,7 @@ async def test_updating_selected_columns(): @pytest.mark.asyncio async def test_not_passing_columns_or_empty_list_saves_all(): - async with database: + async with base_ormar_config.database: director = await Director(name="James", last_name="Cameron").save() terminator = await Movie( name="Terminator", year=1984, director=director, profit=0.078 diff --git a/tests/test_model_methods/test_upsert.py b/tests/test_model_methods/test_upsert.py index ad216862a..43b241f1d 100644 --- a/tests/test_model_methods/test_upsert.py +++ b/tests/test_model_methods/test_upsert.py @@ -5,18 +5,15 @@ import pytest import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Director(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="directors", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="directors") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="first_name") @@ -24,11 +21,7 @@ class Director(ormar.Model): class Movie(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="movies", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="movies") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="title") @@ -37,18 +30,12 @@ class Movie(ormar.Model): director: Optional[Director] = ormar.ForeignKey(Director) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_updating_selected_columns(): - async with database: + async with base_ormar_config.database: director1 = await Director(name="Peter", last_name="Jackson").save() await Movie( diff --git a/tests/test_ordering/test_default_model_order.py b/tests/test_ordering/test_default_model_order.py index 90f400649..20bc4a608 100644 --- a/tests/test_ordering/test_default_model_order.py +++ b/tests/test_ordering/test_default_model_order.py @@ -1,21 +1,14 @@ from typing import Optional -import databases import ormar import pytest import pytest_asyncio -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Author(ormar.Model): @@ -37,26 +30,20 @@ class Book(ormar.Model): ranking: int = ormar.Integer(nullable=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest_asyncio.fixture(autouse=True, scope="function") async def cleanup(): yield - async with database: + async with base_ormar_config.database: await Book.objects.delete(each=True) await Author.objects.delete(each=True) @pytest.mark.asyncio async def test_default_orders_is_applied(): - async with database: + async with base_ormar_config.database: tolkien = await Author(name="J.R.R. Tolkien").save() sapkowski = await Author(name="Andrzej Sapkowski").save() king = await Author(name="Stephen King").save() @@ -77,7 +64,7 @@ async def test_default_orders_is_applied(): @pytest.mark.asyncio async def test_default_orders_is_applied_on_related(): - async with database: + async with base_ormar_config.database: tolkien = await Author(name="J.R.R. Tolkien").save() silmarillion = await Book( author=tolkien, title="The Silmarillion", year=1977 @@ -100,7 +87,7 @@ async def test_default_orders_is_applied_on_related(): @pytest.mark.asyncio async def test_default_orders_is_applied_on_related_two_fields(): - async with database: + async with base_ormar_config.database: sanders = await Author(name="Brandon Sanderson").save() twok = await Book( author=sanders, title="The Way of Kings", year=2010, ranking=10 diff --git a/tests/test_ordering/test_default_relation_order.py b/tests/test_ordering/test_default_relation_order.py index 27e0a628f..6b20cc292 100644 --- a/tests/test_ordering/test_default_relation_order.py +++ b/tests/test_ordering/test_default_relation_order.py @@ -1,22 +1,15 @@ from typing import List, Optional from uuid import UUID, uuid4 -import databases import ormar import pytest import pytest_asyncio -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Author(ormar.Model): @@ -59,26 +52,20 @@ class Human(ormar.Model): ) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest_asyncio.fixture(autouse=True, scope="function") async def cleanup(): yield - async with database: + async with base_ormar_config.database: await Book.objects.delete(each=True) await Author.objects.delete(each=True) @pytest.mark.asyncio async def test_default_orders_is_applied_from_reverse_relation(): - async with database: + async with base_ormar_config.database: tolkien = await Author(name="J.R.R. Tolkien").save() hobbit = await Book(author=tolkien, title="The Hobbit", year=1933).save() silmarillion = await Book( @@ -96,7 +83,7 @@ async def test_default_orders_is_applied_from_reverse_relation(): @pytest.mark.asyncio async def test_default_orders_is_applied_from_relation(): - async with database: + async with base_ormar_config.database: bret = await Author(name="Peter V. Bret").save() tds = await Book( author=bret, title="The Desert Spear", year=2010, ranking=9 @@ -113,7 +100,7 @@ async def test_default_orders_is_applied_from_relation(): @pytest.mark.asyncio async def test_default_orders_is_applied_from_relation_on_m2m(): - async with database: + async with base_ormar_config.database: alice = await Human(name="Alice").save() spot = await Animal(name="Spot", specie="Cat").save() @@ -132,7 +119,7 @@ async def test_default_orders_is_applied_from_relation_on_m2m(): @pytest.mark.asyncio async def test_default_orders_is_applied_from_reverse_relation_on_m2m(): - async with database: + async with base_ormar_config.database: max = await Animal(name="Max", specie="Dog").save() joe = await Human(name="Joe").save() zack = await Human(name="Zack").save() diff --git a/tests/test_ordering/test_default_through_relation_order.py b/tests/test_ordering/test_default_through_relation_order.py index b2a59b07e..1d0e1daa5 100644 --- a/tests/test_ordering/test_default_through_relation_order.py +++ b/tests/test_ordering/test_default_through_relation_order.py @@ -1,10 +1,8 @@ from typing import Any, Dict, List, Tuple, Type, cast from uuid import UUID, uuid4 -import databases import ormar import pytest -import sqlalchemy from ormar import ( Model, ModelDefinitionError, @@ -14,16 +12,11 @@ pre_update, ) -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Animal(ormar.Model): @@ -66,18 +59,12 @@ class Human2(ormar.Model): ) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_ordering_by_through_fail(): - async with database: + async with base_ormar_config.database: alice = await Human2(name="Alice").save() spot = await Animal(name="Spot").save() await alice.favoriteAnimals.add(spot) @@ -252,7 +239,7 @@ async def reorder_links_on_remove( @pytest.mark.asyncio async def test_ordering_by_through_on_m2m_field(): - async with database: + async with base_ormar_config.database: def verify_order(instance, expected): field_name = ( diff --git a/tests/test_ordering/test_proper_order_of_sorting_apply.py b/tests/test_ordering/test_proper_order_of_sorting_apply.py index 96dfb1fc7..9f7c117b7 100644 --- a/tests/test_ordering/test_proper_order_of_sorting_apply.py +++ b/tests/test_ordering/test_proper_order_of_sorting_apply.py @@ -1,21 +1,14 @@ from typing import Optional -import databases import ormar import pytest import pytest_asyncio -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Author(ormar.Model): @@ -37,26 +30,20 @@ class Book(ormar.Model): ranking: int = ormar.Integer(nullable=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest_asyncio.fixture(autouse=True, scope="function") async def cleanup(): yield - async with database: + async with base_ormar_config.database: await Book.objects.delete(each=True) await Author.objects.delete(each=True) @pytest.mark.asyncio async def test_default_orders_is_applied_from_reverse_relation(): - async with database: + async with base_ormar_config.database: tolkien = await Author(name="J.R.R. Tolkien").save() hobbit = await Book(author=tolkien, title="The Hobbit", year=1933).save() silmarillion = await Book( diff --git a/tests/test_queries/test_adding_related.py b/tests/test_queries/test_adding_related.py index c6ce9e2dc..247999241 100644 --- a/tests/test_queries/test_adding_related.py +++ b/tests/test_queries/test_adding_related.py @@ -1,31 +1,24 @@ from typing import Optional -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Department(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class Course(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -33,18 +26,12 @@ class Course(ormar.Model): department: Optional[Department] = ormar.ForeignKey(Department) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_adding_relation_to_reverse_saves_the_child(): - async with database: + async with base_ormar_config.database: department = await Department(name="Science").save() course = Course(name="Math", completed=False) diff --git a/tests/test_queries/test_aggr_functions.py b/tests/test_queries/test_aggr_functions.py index c74b77fbe..c7986d783 100644 --- a/tests/test_queries/test_aggr_functions.py +++ b/tests/test_queries/test_aggr_functions.py @@ -1,22 +1,15 @@ from typing import Optional -import databases import ormar import pytest import pytest_asyncio -import sqlalchemy from ormar.exceptions import QueryDefinitionError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Author(ormar.Model): @@ -38,19 +31,13 @@ class Book(ormar.Model): ranking: int = ormar.Integer(nullable=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest_asyncio.fixture(autouse=True, scope="function") async def cleanup(): yield - async with database: + async with base_ormar_config.database: await Book.objects.delete(each=True) await Author.objects.delete(each=True) @@ -64,7 +51,7 @@ async def sample_data(): @pytest.mark.asyncio async def test_min_method(): - async with database: + async with base_ormar_config.database: await sample_data() assert await Book.objects.min("year") == 1920 result = await Book.objects.min(["year", "ranking"]) @@ -88,7 +75,7 @@ async def test_min_method(): @pytest.mark.asyncio async def test_max_method(): - async with database: + async with base_ormar_config.database: await sample_data() assert await Book.objects.max("year") == 1930 result = await Book.objects.max(["year", "ranking"]) @@ -112,7 +99,7 @@ async def test_max_method(): @pytest.mark.asyncio async def test_sum_method(): - async with database: + async with base_ormar_config.database: await sample_data() assert await Book.objects.sum("year") == 5773 result = await Book.objects.sum(["year", "ranking"]) @@ -137,7 +124,7 @@ async def test_sum_method(): @pytest.mark.asyncio async def test_avg_method(): - async with database: + async with base_ormar_config.database: await sample_data() assert round(float(await Book.objects.avg("year")), 2) == 1924.33 result = await Book.objects.avg(["year", "ranking"]) @@ -165,7 +152,7 @@ async def test_avg_method(): @pytest.mark.asyncio async def test_queryset_method(): - async with database: + async with base_ormar_config.database: await sample_data() author = await Author.objects.select_related("books").get() assert await author.books.min("year") == 1920 @@ -179,7 +166,7 @@ async def test_queryset_method(): @pytest.mark.asyncio async def test_count_method(): - async with database: + async with base_ormar_config.database: await sample_data() count = await Author.objects.select_related("books").count() diff --git a/tests/test_queries/test_deep_relations_select_all.py b/tests/test_queries/test_deep_relations_select_all.py index 7f2caf46f..3e7bc84f9 100644 --- a/tests/test_queries/test_deep_relations_select_all.py +++ b/tests/test_queries/test_deep_relations_select_all.py @@ -1,21 +1,16 @@ -import databases import ormar import pytest -import sqlalchemy from sqlalchemy import func -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Chart(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="charts", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="charts") chart_id = ormar.Integer(primary_key=True, autoincrement=True) name = ormar.String(max_length=200, unique=True, index=True) @@ -29,11 +24,7 @@ class Chart(ormar.Model): class Report(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="reports", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="reports") report_id = ormar.Integer(primary_key=True, autoincrement=True) name = ormar.String(max_length=200, unique=True, index=True) @@ -42,11 +33,7 @@ class Report(ormar.Model): class Language(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="languages", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="languages") language_id = ormar.Integer(primary_key=True, autoincrement=True) code = ormar.String(max_length=5) @@ -54,22 +41,14 @@ class Language(ormar.Model): class TranslationNode(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="translation_nodes", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="translation_nodes") node_id = ormar.Integer(primary_key=True, autoincrement=True) node_type = ormar.String(max_length=200) class Translation(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="translations", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="translations") translation_id = ormar.Integer(primary_key=True, autoincrement=True) node_id = ormar.ForeignKey(TranslationNode, related_name="translations") @@ -78,11 +57,7 @@ class Translation(ormar.Model): class Filter(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="filters", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="filters") filter_id = ormar.Integer(primary_key=True, autoincrement=True) name = ormar.String(max_length=200, unique=True, index=True) @@ -96,11 +71,7 @@ class Filter(ormar.Model): class FilterValue(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="filter_values", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="filter_values") value_id = ormar.Integer(primary_key=True, autoincrement=True) value = ormar.String(max_length=300) @@ -110,11 +81,7 @@ class FilterValue(ormar.Model): class FilterXReport(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="filters_x_reports", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="filters_x_reports") filter_x_report_id = ormar.Integer(primary_key=True) filter = ormar.ForeignKey(Filter, name="filter_id", related_name="reports") @@ -125,11 +92,7 @@ class FilterXReport(ormar.Model): class ChartXReport(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="charts_x_reports", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="charts_x_reports") chart_x_report_id = ormar.Integer(primary_key=True) chart = ormar.ForeignKey(Chart, name="chart_id", related_name="reports") @@ -139,11 +102,7 @@ class ChartXReport(ormar.Model): class ChartColumn(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="charts_columns", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="charts_columns") column_id = ormar.Integer(primary_key=True, autoincrement=True) chart = ormar.ForeignKey(Chart, name="chart_id", related_name="columns") @@ -152,17 +111,11 @@ class ChartColumn(ormar.Model): translation = ormar.ForeignKey(TranslationNode, name="translation_node_id") -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_saving_related_fk_rel(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await Report.objects.select_all(follow=True).all() diff --git a/tests/test_queries/test_filter_groups.py b/tests/test_queries/test_filter_groups.py index 1eb002b49..c0af4c8bf 100644 --- a/tests/test_queries/test_filter_groups.py +++ b/tests/test_queries/test_filter_groups.py @@ -1,19 +1,12 @@ from typing import Optional -import databases import ormar -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Author(ormar.Model): @@ -32,6 +25,9 @@ class Book(ormar.Model): year: int = ormar.Integer(nullable=True) +create_test_database = init_tests(base_ormar_config) + + def test_or_group(): result = ormar.or_(name="aa", books__title="bb") result.resolve(model_cls=Author) diff --git a/tests/test_queries/test_indirect_relations_to_self.py b/tests/test_queries/test_indirect_relations_to_self.py index c9e5094be..ae4a91736 100644 --- a/tests/test_queries/test_indirect_relations_to_self.py +++ b/tests/test_queries/test_indirect_relations_to_self.py @@ -1,22 +1,17 @@ from datetime import datetime -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Node(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="node", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="node") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=120) @@ -25,11 +20,7 @@ class Node(ormar.Model): class Edge(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="edge", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="edge") id: str = ormar.String(primary_key=True, max_length=12) src_node: Node = ormar.ForeignKey(Node, related_name="next_edges") @@ -38,18 +29,12 @@ class Edge(ormar.Model): created_at: datetime = ormar.DateTime(timezone=True, default=datetime.now) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_sort_order_on_main_model(): - async with database: + async with base_ormar_config.database: node1 = await Node(name="Node 1").save() node2 = await Node(name="Node 2").save() node3 = await Node(name="Node 3").save() diff --git a/tests/test_queries/test_isnull_filter.py b/tests/test_queries/test_isnull_filter.py index e3e421459..4f7fc79ff 100644 --- a/tests/test_queries/test_isnull_filter.py +++ b/tests/test_queries/test_isnull_filter.py @@ -1,20 +1,13 @@ from typing import Optional -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Author(ormar.Model): @@ -34,11 +27,7 @@ class Book(ormar.Model): class JsonModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, - tablename="jsons", - ) + ormar_config = base_ormar_config.copy(tablename="jsons") id = ormar.Integer(primary_key=True) text_field = ormar.Text(nullable=True) @@ -46,18 +35,12 @@ class JsonModel(ormar.Model): json_not_null = ormar.JSON() -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_is_null(): - async with database: + async with base_ormar_config.database: tolkien = await Author.objects.create(name="J.R.R. Tolkien") await Book.objects.create(author=tolkien, title="The Hobbit") await Book.objects.create( @@ -99,7 +82,7 @@ async def test_is_null(): @pytest.mark.asyncio async def test_isnull_json(): - async with database: + async with base_ormar_config.database: author = await JsonModel.objects.create(json_not_null=None) assert author.json_field is None non_null_text_fields = await JsonModel.objects.all(text_field__isnull=False) diff --git a/tests/test_queries/test_nested_reverse_relations.py b/tests/test_queries/test_nested_reverse_relations.py index 679b53b7c..b55f11bcf 100644 --- a/tests/test_queries/test_nested_reverse_relations.py +++ b/tests/test_queries/test_nested_reverse_relations.py @@ -1,20 +1,14 @@ from typing import Optional -import databases import ormar import pytest import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class DataSource(ormar.Model): @@ -45,18 +39,12 @@ class DataSourceTableColumn(ormar.Model): ) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): # pragma: no cover - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_double_nested_reverse_relation(): - async with database: + async with base_ormar_config.database: data_source = await DataSource(name="local").save() test_tables = [ { diff --git a/tests/test_queries/test_non_relation_fields_not_merged.py b/tests/test_queries/test_non_relation_fields_not_merged.py index b64302bd2..97b3c70a7 100644 --- a/tests/test_queries/test_non_relation_fields_not_merged.py +++ b/tests/test_queries/test_non_relation_fields_not_merged.py @@ -1,20 +1,13 @@ from typing import Optional -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Chart(ormar.Model): @@ -31,18 +24,12 @@ class Config(ormar.Model): chart: Optional[Chart] = ormar.ForeignKey(Chart) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_list_field_that_is_not_relation_is_not_merged(): - async with database: + async with base_ormar_config.database: chart = await Chart.objects.create(datasets=[{"test": "ok"}]) await Config.objects.create(chart=chart) await Config.objects.create(chart=chart) diff --git a/tests/test_queries/test_or_filters.py b/tests/test_queries/test_or_filters.py index 039d3df84..05c2e556e 100644 --- a/tests/test_queries/test_or_filters.py +++ b/tests/test_queries/test_or_filters.py @@ -1,21 +1,14 @@ from typing import Optional -import databases import ormar import pytest -import sqlalchemy from ormar.exceptions import QueryDefinitionError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Author(ormar.Model): @@ -34,18 +27,12 @@ class Book(ormar.Model): year: int = ormar.Integer(nullable=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_or_filters(): - async with database: + async with base_ormar_config.database: tolkien = await Author(name="J.R.R. Tolkien").save() await Book(author=tolkien, title="The Hobbit", year=1933).save() await Book(author=tolkien, title="The Lord of the Rings", year=1955).save() diff --git a/tests/test_queries/test_order_by.py b/tests/test_queries/test_order_by.py index 187381aa4..6c091805f 100644 --- a/tests/test_queries/test_order_by.py +++ b/tests/test_queries/test_order_by.py @@ -1,22 +1,17 @@ from typing import List, Optional -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Song(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="songs", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="songs") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -24,33 +19,21 @@ class Song(ormar.Model): class Owner(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="owners", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="owners") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class AliasNested(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="aliases_nested", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="aliases_nested") id: int = ormar.Integer(name="alias_id", primary_key=True) name: str = ormar.String(name="alias_name", max_length=100) class AliasTest(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="aliases", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="aliases") id: int = ormar.Integer(name="alias_id", primary_key=True) name: str = ormar.String(name="alias_name", max_length=100) @@ -58,11 +41,7 @@ class AliasTest(ormar.Model): class Toy(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="toys", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="toys") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -70,22 +49,14 @@ class Toy(ormar.Model): class Factory(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="factories", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="factories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class Car(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="cars", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="cars") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -93,29 +64,19 @@ class Car(ormar.Model): class User(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="users", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="users") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) cars: List[Car] = ormar.ManyToMany(Car) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_sort_order_on_main_model(): - async with database: + async with base_ormar_config.database: await Song.objects.create(name="Song 3", sort_order=3) await Song.objects.create(name="Song 1", sort_order=1) await Song.objects.create(name="Song 2", sort_order=2) @@ -174,7 +135,7 @@ async def test_sort_order_on_main_model(): @pytest.mark.asyncio async def test_sort_order_on_related_model(): - async with database: + async with base_ormar_config.database: aphrodite = await Owner.objects.create(name="Aphrodite") hermes = await Owner.objects.create(name="Hermes") zeus = await Owner.objects.create(name="Zeus") @@ -260,7 +221,7 @@ async def test_sort_order_on_related_model(): @pytest.mark.asyncio async def test_sort_order_on_many_to_many(): - async with database: + async with base_ormar_config.database: factory1 = await Factory.objects.create(name="Factory 1") factory2 = await Factory.objects.create(name="Factory 2") @@ -334,7 +295,7 @@ async def test_sort_order_on_many_to_many(): @pytest.mark.asyncio async def test_sort_order_with_aliases(): - async with database: + async with base_ormar_config.database: al1 = await AliasTest.objects.create(name="Test4") al2 = await AliasTest.objects.create(name="Test2") al3 = await AliasTest.objects.create(name="Test1") diff --git a/tests/test_queries/test_pagination.py b/tests/test_queries/test_pagination.py index b36360c1f..5215b9f89 100644 --- a/tests/test_queries/test_pagination.py +++ b/tests/test_queries/test_pagination.py @@ -1,19 +1,12 @@ -import databases import ormar import pytest -import sqlalchemy from ormar.exceptions import QueryDefinitionError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Car(ormar.Model): @@ -35,19 +28,13 @@ class User(ormar.Model): cars = ormar.ManyToMany(Car, through=UsersCar) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_limit_zero(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): for i in range(5): await Car(name=f"{i}").save() @@ -58,8 +45,8 @@ async def test_limit_zero(): @pytest.mark.asyncio async def test_pagination_errors(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): with pytest.raises(QueryDefinitionError): await Car.objects.paginate(0).all() @@ -69,8 +56,8 @@ async def test_pagination_errors(): @pytest.mark.asyncio async def test_pagination_on_single_model(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): for i in range(20): await Car(name=f"{i}").save() @@ -93,8 +80,8 @@ async def test_pagination_on_single_model(): @pytest.mark.asyncio async def test_proxy_pagination(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): user = await User(name="Jon").save() for i in range(20): diff --git a/tests/test_queries/test_queryproxy_on_m2m_models.py b/tests/test_queries/test_queryproxy_on_m2m_models.py index 4108a52a7..93a09b874 100644 --- a/tests/test_queries/test_queryproxy_on_m2m_models.py +++ b/tests/test_queries/test_queryproxy_on_m2m_models.py @@ -1,34 +1,25 @@ from typing import List, Optional, Union -import databases import ormar import pytest -import sqlalchemy from ormar.exceptions import QueryDefinitionError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Subject(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="subjects", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="subjects") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=80) class Author(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="authors", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="authors") id: int = ormar.Integer(primary_key=True) first_name: str = ormar.String(max_length=80) @@ -36,11 +27,7 @@ class Author(ormar.Model): class Category(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="categories", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=40) @@ -49,19 +36,11 @@ class Category(ormar.Model): class PostCategory(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="posts_categories", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="posts_categories") class Post(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="posts", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="posts") id: int = ormar.Integer(primary_key=True) title: str = ormar.String(max_length=200) @@ -71,18 +50,13 @@ class Post(ormar.Model): author: Optional[Author] = ormar.ForeignKey(Author) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_queryset_methods(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): guido = await Author.objects.create( first_name="Guido", last_name="Van Rossum" ) @@ -187,8 +161,8 @@ async def test_queryset_methods(): @pytest.mark.asyncio async def test_queryset_update(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): guido = await Author.objects.create( first_name="Guido", last_name="Van Rossum" ) diff --git a/tests/test_queries/test_queryset_level_methods.py b/tests/test_queries/test_queryset_level_methods.py index b4509c3e5..93805c28f 100644 --- a/tests/test_queries/test_queryset_level_methods.py +++ b/tests/test_queries/test_queryset_level_methods.py @@ -1,11 +1,9 @@ from enum import Enum from typing import Optional -import databases import ormar import pydantic import pytest -import sqlalchemy from ormar import QuerySet from ormar.exceptions import ( ModelListEmptyError, @@ -14,10 +12,11 @@ ) from pydantic import Json -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config(force_rollback=True) class MySize(Enum): @@ -26,11 +25,7 @@ class MySize(Enum): class Book(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="books", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="books") id: int = ormar.Integer(primary_key=True) title: str = ormar.String(max_length=200) @@ -43,11 +38,7 @@ class Book(ormar.Model): class ToDo(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="todos", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="todos") id: int = ormar.Integer(primary_key=True) text: str = ormar.String(max_length=500) @@ -57,22 +48,14 @@ class ToDo(ormar.Model): class Category(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="categories", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=500) class Note(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="notes", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="notes") id: int = ormar.Integer(primary_key=True) text: str = ormar.String(max_length=500) @@ -80,11 +63,7 @@ class Note(ormar.Model): class ItemConfig(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, - tablename="item_config", - ) + ormar_config = base_ormar_config.copy(tablename="item_config") id: Optional[int] = ormar.Integer(primary_key=True) item_id: str = ormar.String(max_length=32, index=True) @@ -102,9 +81,7 @@ async def first_or_404(self, *args, **kwargs): class Customer(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, + ormar_config = base_ormar_config.copy( tablename="customer", queryset_class=QuerySetCls, ) @@ -114,29 +91,19 @@ class Customer(ormar.Model): class JsonTestModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, - tablename="test_model", - ) + ormar_config = base_ormar_config.copy(tablename="test_model") id: int = ormar.Integer(primary_key=True) json_field: Json = ormar.JSON() -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_delete_and_update(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await Book.objects.create( title="Tom Sawyer", author="Twain, Mark", genre="Adventure" ) @@ -191,7 +158,7 @@ async def test_delete_and_update(): @pytest.mark.asyncio async def test_get_or_create(): - async with database: + async with base_ormar_config.database: tom, created = await Book.objects.get_or_create( title="Volume I", author="Anonymous", genre="Fiction" ) @@ -217,7 +184,7 @@ async def test_get_or_create(): @pytest.mark.asyncio async def test_get_or_create_with_defaults(): - async with database: + async with base_ormar_config.database: book, created = await Book.objects.get_or_create( title="Nice book", _defaults={"author": "Mojix", "genre": "Historic"} ) @@ -253,7 +220,7 @@ async def test_get_or_create_with_defaults(): @pytest.mark.asyncio async def test_update_or_create(): - async with database: + async with base_ormar_config.database: tom = await Book.objects.update_or_create( title="Volume I", author="Anonymous", genre="Fiction" ) @@ -276,7 +243,7 @@ async def test_update_or_create(): @pytest.mark.asyncio async def test_bulk_create(): - async with database: + async with base_ormar_config.database: await ToDo.objects.bulk_create( [ ToDo(text="Buy the groceries."), @@ -299,7 +266,7 @@ async def test_bulk_create(): @pytest.mark.asyncio async def test_bulk_create_json_field(): - async with database: + async with base_ormar_config.database: json_value = {"a": 1} test_model_1 = JsonTestModel(id=1, json_field=json_value) test_model_2 = JsonTestModel(id=2, json_field=json_value) @@ -319,7 +286,7 @@ async def test_bulk_create_json_field(): query = table.select().where(table.c.json_field["a"].as_integer() == 1) res = [ JsonTestModel.from_row(record, source_model=JsonTestModel) - for record in await database.fetch_all(query) + for record in await base_ormar_config.database.fetch_all(query) ] assert test_model_1 in res @@ -329,7 +296,7 @@ async def test_bulk_create_json_field(): @pytest.mark.asyncio async def test_bulk_create_with_relation(): - async with database: + async with base_ormar_config.database: category = await Category.objects.create(name="Sample Category") await Note.objects.bulk_create( @@ -347,7 +314,7 @@ async def test_bulk_create_with_relation(): @pytest.mark.asyncio async def test_bulk_update(): - async with database: + async with base_ormar_config.database: await ToDo.objects.bulk_create( [ ToDo(text="Buy the groceries."), @@ -378,7 +345,7 @@ async def test_bulk_update(): @pytest.mark.asyncio async def test_bulk_update_with_only_selected_columns(): - async with database: + async with base_ormar_config.database: await ToDo.objects.bulk_create( [ ToDo(text="Reset the world simulation.", completed=False), @@ -407,7 +374,7 @@ async def test_bulk_update_with_only_selected_columns(): @pytest.mark.asyncio async def test_bulk_update_with_relation(): - async with database: + async with base_ormar_config.database: category = await Category.objects.create(name="Sample Category") category2 = await Category.objects.create(name="Sample II Category") @@ -436,7 +403,7 @@ async def test_bulk_update_with_relation(): @pytest.mark.asyncio async def test_bulk_update_not_saved_objts(): - async with database: + async with base_ormar_config.database: category = await Category.objects.create(name="Sample Category") with pytest.raises(ModelPersistenceError): await Note.objects.bulk_update( @@ -452,7 +419,7 @@ async def test_bulk_update_not_saved_objts(): @pytest.mark.asyncio async def test_bulk_operations_with_json(): - async with database: + async with base_ormar_config.database: items = [ ItemConfig(item_id="test1"), ItemConfig(item_id="test2"), @@ -480,14 +447,14 @@ async def test_bulk_operations_with_json(): query = table.select().where(table.c.pairs["b"].as_integer() == 2) res = [ ItemConfig.from_row(record, source_model=ItemConfig) - for record in await database.fetch_all(query) + for record in await base_ormar_config.database.fetch_all(query) ] assert len(res) == 2 @pytest.mark.asyncio async def test_custom_queryset_cls(): - async with database: + async with base_ormar_config.database: with pytest.raises(ValueError): await Customer.objects.first_or_404(id=1) @@ -498,7 +465,7 @@ async def test_custom_queryset_cls(): @pytest.mark.asyncio async def test_filter_enum(): - async with database: + async with base_ormar_config.database: it = ItemConfig(item_id="test_1") await it.save() diff --git a/tests/test_queries/test_quoting_table_names_in_on_join_clause.py b/tests/test_queries/test_quoting_table_names_in_on_join_clause.py index 921cbbb90..299246209 100644 --- a/tests/test_queries/test_quoting_table_names_in_on_join_clause.py +++ b/tests/test_queries/test_quoting_table_names_in_on_join_clause.py @@ -2,25 +2,18 @@ import uuid from typing import Dict, Optional, Union -import databases import ormar import pytest -import sqlalchemy -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() -engine = create_engine(DATABASE_URL) + +base_ormar_config = create_config() class Team(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="team", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="team") id: uuid.UUID = ormar.UUID(default=uuid.uuid4, primary_key=True, index=True) name = ormar.Text(nullable=True) @@ -30,11 +23,7 @@ class Team(ormar.Model): class User(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="user", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="user") id: uuid.UUID = ormar.UUID(default=uuid.uuid4, primary_key=True, index=True) client_user_id = ormar.Text() @@ -43,24 +32,16 @@ class User(ormar.Model): class Order(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="order", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="order") id: uuid.UUID = ormar.UUID(default=uuid.uuid4, primary_key=True, index=True) user: Optional[Union[User, Dict]] = ormar.ForeignKey(User) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_quoting_on_clause_without_prefix(): - async with database: + async with base_ormar_config.database: await User.objects.select_related("orders").all() diff --git a/tests/test_queries/test_reserved_sql_keywords_escaped.py b/tests/test_queries/test_reserved_sql_keywords_escaped.py index 544d714d3..ad0d3a45c 100644 --- a/tests/test_queries/test_reserved_sql_keywords_escaped.py +++ b/tests/test_queries/test_reserved_sql_keywords_escaped.py @@ -1,18 +1,11 @@ -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config(force_rollback=True) class User(ormar.Model): @@ -39,18 +32,12 @@ class Task(ormar.Model): user = ormar.ForeignKey(User) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_single_model_quotes(): - async with database: + async with base_ormar_config.database: await User.objects.create( user="test", first="first", @@ -66,7 +53,7 @@ async def test_single_model_quotes(): @pytest.mark.asyncio async def test_two_model_quotes(): - async with database: + async with base_ormar_config.database: user = await User.objects.create( user="test", first="first", diff --git a/tests/test_queries/test_reverse_fk_queryset.py b/tests/test_queries/test_reverse_fk_queryset.py index 425653ed6..1a629b21e 100644 --- a/tests/test_queries/test_reverse_fk_queryset.py +++ b/tests/test_queries/test_reverse_fk_queryset.py @@ -1,23 +1,18 @@ from typing import Optional -import databases import ormar import pytest -import sqlalchemy from ormar import NoMatch -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Album(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="albums", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="albums") id: int = ormar.Integer(primary_key=True, name="album_id") name: str = ormar.String(max_length=100) @@ -25,22 +20,14 @@ class Album(ormar.Model): class Writer(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="writers", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="writers") id: int = ormar.Integer(primary_key=True, name="writer_id") name: str = ormar.String(max_length=100) class Track(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="tracks", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="tracks") id: int = ormar.Integer(primary_key=True) album: Optional[Album] = ormar.ForeignKey(Album, name="album_id") @@ -70,19 +57,13 @@ async def get_sample_data(): return album, [track1, track2, tracks3] -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_quering_by_reverse_fk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): sample_data = await get_sample_data() track1 = sample_data[1][0] album = await Album.objects.first() @@ -136,8 +117,8 @@ async def test_quering_by_reverse_fk(): @pytest.mark.asyncio async def test_getting(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): sample_data = await get_sample_data() album = sample_data[0] track1 = await album.tracks.fields(["album", "title", "position"]).get( @@ -206,8 +187,8 @@ async def test_getting(): @pytest.mark.asyncio async def test_cleaning_related(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): sample_data = await get_sample_data() album = sample_data[0] await album.tracks.clear(keep_reversed=False) @@ -221,8 +202,8 @@ async def test_cleaning_related(): @pytest.mark.asyncio async def test_loading_related(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): sample_data = await get_sample_data() album = sample_data[0] tracks = await album.tracks.select_related("written_by").all() @@ -240,8 +221,8 @@ async def test_loading_related(): @pytest.mark.asyncio async def test_adding_removing(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): sample_data = await get_sample_data() album = sample_data[0] track_new = await Track(title="Rainbow", position=5, play_count=300).save() diff --git a/tests/test_queries/test_selecting_subset_of_columns.py b/tests/test_queries/test_selecting_subset_of_columns.py index e839103aa..90dc80550 100644 --- a/tests/test_queries/test_selecting_subset_of_columns.py +++ b/tests/test_queries/test_selecting_subset_of_columns.py @@ -2,25 +2,20 @@ import itertools from typing import List, Optional -import databases import ormar import pydantic import pytest import pytest_asyncio -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class NickNames(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="nicks", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="nicks") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") @@ -28,19 +23,11 @@ class NickNames(ormar.Model): class NicksHq(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="nicks_x_hq", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="nicks_x_hq") class HQ(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="hqs", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="hqs") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") @@ -48,11 +35,7 @@ class HQ(ormar.Model): class Company(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="companies", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="companies") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="company_name") @@ -61,11 +44,7 @@ class Company(ormar.Model): class Car(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="cars", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="cars") id: int = ormar.Integer(primary_key=True) manufacturer: Optional[Company] = ormar.ForeignKey(Company) @@ -76,13 +55,7 @@ class Car(ormar.Model): aircon_type: str = ormar.String(max_length=20, nullable=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.fixture(scope="module") @@ -94,7 +67,7 @@ def event_loop(): @pytest_asyncio.fixture(autouse=True, scope="module") async def sample_data(event_loop, create_test_database): - async with database: + async with base_ormar_config.database: nick1 = await NickNames.objects.create(name="Nippon", is_lame=False) nick2 = await NickNames.objects.create(name="EroCherry", is_lame=True) hq = await HQ.objects.create(name="Japan") @@ -131,8 +104,8 @@ async def sample_data(event_loop, create_test_database): @pytest.mark.asyncio async def test_selecting_subset(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): all_cars = ( await Car.objects.select_related(["manufacturer__hq__nicks"]) .fields( @@ -240,7 +213,7 @@ async def test_selecting_subset(): @pytest.mark.asyncio async def test_selecting_subset_of_through_model(): - async with database: + async with base_ormar_config.database: car = ( await Car.objects.select_related(["manufacturer__hq__nicks"]) .fields( diff --git a/tests/test_queries/test_values_and_values_list.py b/tests/test_queries/test_values_and_values_list.py index 323ef55fa..bb67b38fe 100644 --- a/tests/test_queries/test_values_and_values_list.py +++ b/tests/test_queries/test_values_and_values_list.py @@ -1,23 +1,16 @@ import asyncio from typing import List, Optional -import databases import ormar import pytest import pytest_asyncio -import sqlalchemy from ormar.exceptions import QueryDefinitionError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class User(ormar.Model): @@ -52,12 +45,7 @@ class Post(ormar.Model): category: Optional[Category] = ormar.ForeignKey(Category) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.fixture(scope="module") @@ -69,7 +57,7 @@ def event_loop(): @pytest_asyncio.fixture(autouse=True, scope="module") async def sample_data(event_loop, create_test_database): - async with database: + async with base_ormar_config.database: creator = await User(name="Anonymous").save() admin = await Role(name="admin").save() editor = await Role(name="editor").save() @@ -83,7 +71,7 @@ async def sample_data(event_loop, create_test_database): @pytest.mark.asyncio async def test_simple_queryset_values(): - async with database: + async with base_ormar_config.database: posts = await Post.objects.values() assert posts == [ {"id": 1, "name": "Ormar strikes again!", "category": 1}, @@ -94,7 +82,7 @@ async def test_simple_queryset_values(): @pytest.mark.asyncio async def test_queryset_values_nested_relation(): - async with database: + async with base_ormar_config.database: posts = await Post.objects.select_related("category__created_by").values() assert posts == [ { @@ -135,7 +123,7 @@ async def test_queryset_values_nested_relation(): @pytest.mark.asyncio async def test_queryset_values_nested_relation_subset_of_fields(): - async with database: + async with base_ormar_config.database: posts = await Post.objects.select_related("category__created_by").values( ["name", "category__name", "category__created_by__name"] ) @@ -160,7 +148,7 @@ async def test_queryset_values_nested_relation_subset_of_fields(): @pytest.mark.asyncio async def test_queryset_simple_values_list(): - async with database: + async with base_ormar_config.database: posts = await Post.objects.values_list() assert posts == [ (1, "Ormar strikes again!", 1), @@ -171,7 +159,7 @@ async def test_queryset_simple_values_list(): @pytest.mark.asyncio async def test_queryset_nested_relation_values_list(): - async with database: + async with base_ormar_config.database: posts = await Post.objects.select_related("category__created_by").values_list() assert posts == [ (1, "Ormar strikes again!", 1, 1, "News", 0, 1, 1, "Anonymous"), @@ -192,7 +180,7 @@ async def test_queryset_nested_relation_values_list(): @pytest.mark.asyncio async def test_queryset_nested_relation_subset_of_fields_values_list(): - async with database: + async with base_ormar_config.database: posts = await Post.objects.select_related("category__created_by").values_list( ["name", "category__name", "category__created_by__name"] ) @@ -205,7 +193,7 @@ async def test_queryset_nested_relation_subset_of_fields_values_list(): @pytest.mark.asyncio async def test_m2m_values(): - async with database: + async with base_ormar_config.database: user = await User.objects.select_related("roles").values() assert user == [ { @@ -231,7 +219,7 @@ async def test_m2m_values(): @pytest.mark.asyncio async def test_nested_m2m_values(): - async with database: + async with base_ormar_config.database: user = ( await Role.objects.select_related("users__categories") .filter(name="admin") @@ -256,7 +244,7 @@ async def test_nested_m2m_values(): @pytest.mark.asyncio async def test_nested_m2m_values_without_through_explicit(): - async with database: + async with base_ormar_config.database: user = ( await Role.objects.select_related("users__categories") .filter(name="admin") @@ -275,7 +263,7 @@ async def test_nested_m2m_values_without_through_explicit(): @pytest.mark.asyncio async def test_nested_m2m_values_without_through_param(): - async with database: + async with base_ormar_config.database: user = ( await Role.objects.select_related("users__categories") .filter(name="admin") @@ -293,7 +281,7 @@ async def test_nested_m2m_values_without_through_param(): @pytest.mark.asyncio async def test_nested_m2m_values_no_through_and_m2m_models_but_keep_end_model(): - async with database: + async with base_ormar_config.database: user = ( await Role.objects.select_related("users__categories") .filter(name="admin") @@ -306,7 +294,7 @@ async def test_nested_m2m_values_no_through_and_m2m_models_but_keep_end_model(): @pytest.mark.asyncio async def test_nested_flatten_and_exception(): - async with database: + async with base_ormar_config.database: with pytest.raises(QueryDefinitionError): (await Role.objects.fields({"name", "id"}).values_list(flatten=True)) @@ -316,7 +304,7 @@ async def test_nested_flatten_and_exception(): @pytest.mark.asyncio async def test_empty_result(): - async with database: + async with base_ormar_config.database: roles = await Role.objects.filter(Role.name == "test").values_list() roles2 = await Role.objects.filter(Role.name == "test").values() assert roles == roles2 == [] @@ -324,7 +312,7 @@ async def test_empty_result(): @pytest.mark.asyncio async def test_queryset_values_multiple_select_related(): - async with database: + async with base_ormar_config.database: posts = ( await Category.objects.select_related(["created_by__roles", "posts"]) .filter(Category.created_by.roles.name == "editor") @@ -357,7 +345,7 @@ async def test_queryset_values_multiple_select_related(): @pytest.mark.asyncio async def test_querysetproxy_values(): - async with database: + async with base_ormar_config.database: role = ( await Role.objects.select_related("users__categories") .filter(name="admin") @@ -403,7 +391,7 @@ async def test_querysetproxy_values(): @pytest.mark.asyncio async def test_querysetproxy_values_list(): - async with database: + async with base_ormar_config.database: role = ( await Role.objects.select_related("users__categories") .filter(name="admin") diff --git a/tests/test_relations/test_cascades.py b/tests/test_relations/test_cascades.py index 931f7625c..a85e2a70b 100644 --- a/tests/test_relations/test_cascades.py +++ b/tests/test_relations/test_cascades.py @@ -1,44 +1,31 @@ from typing import Optional -import databases import ormar import pytest import pytest_asyncio -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Band(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="bands", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="bands") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class ArtistsBands(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="artists_x_bands", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="artists_x_bands") id: int = ormar.Integer(primary_key=True) class Artist(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="artists", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="artists") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -46,11 +33,7 @@ class Artist(ormar.Model): class Album(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="albums", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="albums") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -58,37 +41,27 @@ class Album(ormar.Model): class Track(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="tracks", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="tracks") id: int = ormar.Integer(primary_key=True) album: Optional[Album] = ormar.ForeignKey(Album, ondelete="CASCADE") title: str = ormar.String(max_length=100) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest_asyncio.fixture(scope="function") async def cleanup(): yield - async with database: + async with base_ormar_config.database: await Band.objects.delete(each=True) await Artist.objects.delete(each=True) @pytest.mark.asyncio async def test_simple_cascade(cleanup): - async with database: + async with base_ormar_config.database: artist = await Artist(name="Dr Alban").save() await Album(name="Jamaica", artist=artist).save() await Artist.objects.delete(id=artist.id) @@ -101,7 +74,7 @@ async def test_simple_cascade(cleanup): @pytest.mark.asyncio async def test_nested_cascade(cleanup): - async with database: + async with base_ormar_config.database: artist = await Artist(name="Dr Alban").save() album = await Album(name="Jamaica", artist=artist).save() await Track(title="Yuhu", album=album).save() @@ -120,7 +93,7 @@ async def test_nested_cascade(cleanup): @pytest.mark.asyncio async def test_many_to_many_cascade(cleanup): - async with database: + async with base_ormar_config.database: artist = await Artist(name="Dr Alban").save() band = await Band(name="Scorpions").save() await artist.bands.add(band) @@ -142,7 +115,7 @@ async def test_many_to_many_cascade(cleanup): @pytest.mark.asyncio async def test_reverse_many_to_many_cascade(cleanup): - async with database: + async with base_ormar_config.database: artist = await Artist(name="Dr Alban").save() band = await Band(name="Scorpions").save() await artist.bands.add(band) diff --git a/tests/test_relations/test_customizing_through_model_relation_names.py b/tests/test_relations/test_customizing_through_model_relation_names.py index 76c0088c0..8a8f22204 100644 --- a/tests/test_relations/test_customizing_through_model_relation_names.py +++ b/tests/test_relations/test_customizing_through_model_relation_names.py @@ -1,29 +1,22 @@ -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) + +base_ormar_config = create_config() class Course(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) course_name: str = ormar.String(max_length=100) class Student(ormar.Model): - ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -34,13 +27,7 @@ class Student(ormar.Model): ) -# create db and tables -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_tables_columns(): @@ -53,8 +40,8 @@ def test_tables_columns(): @pytest.mark.asyncio async def test_working_with_changed_through_names(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): to_save = { "course_name": "basic1", "students": [{"name": "Jack"}, {"name": "Abi"}], diff --git a/tests/test_relations/test_database_fk_creation.py b/tests/test_relations/test_database_fk_creation.py index 58efffbbc..9566dcde5 100644 --- a/tests/test_relations/test_database_fk_creation.py +++ b/tests/test_relations/test_database_fk_creation.py @@ -1,35 +1,26 @@ from typing import Optional -import databases import ormar import pytest import sqlalchemy from ormar.fields.foreign_key import validate_referential_action -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() -engine = sqlalchemy.create_engine(DATABASE_URL) + +base_ormar_config = create_config() class Artist(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="artists", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="artists") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class Album(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="albums", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="albums") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -37,20 +28,14 @@ class Album(ormar.Model): class A(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=64, nullalbe=False) class B(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=64, nullalbe=False) @@ -58,26 +43,18 @@ class B(ormar.Model): class C(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy() id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=64, nullalbe=False) b: B = ormar.ForeignKey(to=B, ondelete=ormar.ReferentialAction.CASCADE) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_simple_cascade(): - inspector = sqlalchemy.inspect(engine) + inspector = sqlalchemy.inspect(base_ormar_config.engine) columns = inspector.get_columns("albums") assert len(columns) == 3 col_names = [col.get("name") for col in columns] @@ -103,8 +80,8 @@ def test_validations_referential_action(): @pytest.mark.asyncio async def test_cascade_clear(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): a = await A.objects.create(name="a") b = await B.objects.create(name="b", a=a) await C.objects.create(name="c", b=b) diff --git a/tests/test_relations/test_foreign_keys.py b/tests/test_relations/test_foreign_keys.py index 1899aa00f..f3561c123 100644 --- a/tests/test_relations/test_foreign_keys.py +++ b/tests/test_relations/test_foreign_keys.py @@ -1,23 +1,18 @@ from typing import Optional -import databases import ormar import pytest -import sqlalchemy from ormar.exceptions import MultipleMatches, NoMatch, RelationshipInstanceError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Album(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="albums", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="albums") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -25,11 +20,7 @@ class Album(ormar.Model): class Track(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="tracks", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="tracks") id: int = ormar.Integer(primary_key=True) album: Optional[Album] = ormar.ForeignKey(Album) @@ -40,11 +31,7 @@ class Track(ormar.Model): class Cover(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="covers", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="covers") id: int = ormar.Integer(primary_key=True) album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures") @@ -52,22 +39,14 @@ class Cover(ormar.Model): class Organisation(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="org", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="org") id: int = ormar.Integer(primary_key=True) ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"]) class Team(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="teams", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="teams") id: int = ormar.Integer(primary_key=True) org: Optional[Organisation] = ormar.ForeignKey(Organisation) @@ -75,44 +54,34 @@ class Team(ormar.Model): class Member(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="members", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="members") id: int = ormar.Integer(primary_key=True) team: Optional[Team] = ormar.ForeignKey(Team) email: str = ormar.String(max_length=100) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_wrong_query_foreign_key_type(): - async with database: + async with base_ormar_config.database: with pytest.raises(RelationshipInstanceError): Track(title="The Error", album="wrong_pk_type") @pytest.mark.asyncio async def test_setting_explicitly_empty_relation(): - async with database: + async with base_ormar_config.database: track = Track(album=None, title="The Bird", position=1) assert track.album is None @pytest.mark.asyncio async def test_related_name(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): album = await Album.objects.create(name="Vanilla") await Cover.objects.create(album=album, title="The cover file") assert len(album.cover_pictures) == 1 @@ -120,8 +89,8 @@ async def test_related_name(): @pytest.mark.asyncio async def test_model_crud(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): album = Album(name="Jamaica") await album.save() track1 = Track(album=album, title="The Bird", position=1) @@ -152,8 +121,8 @@ async def test_model_crud(): @pytest.mark.asyncio async def test_select_related(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): album = Album(name="Malibu") await album.save() track1 = Track(album=album, title="The Bird", position=1) @@ -181,8 +150,8 @@ async def test_select_related(): @pytest.mark.asyncio async def test_model_removal_from_relations(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): album = Album(name="Chichi") await album.save() track1 = Track(album=album, title="The Birdman", position=1) @@ -223,8 +192,8 @@ async def test_model_removal_from_relations(): @pytest.mark.asyncio async def test_fk_filter(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): malibu = Album(name="Malibu%") await malibu.save() await Track.objects.create(album=malibu, title="The Bird", position=1) @@ -283,8 +252,8 @@ async def test_fk_filter(): @pytest.mark.asyncio async def test_multiple_fk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): acme = await Organisation.objects.create(ident="ACME Ltd") red_team = await Team.objects.create(org=acme, name="Red Team") blue_team = await Team.objects.create(org=acme, name="Blue Team") @@ -309,8 +278,8 @@ async def test_multiple_fk(): @pytest.mark.asyncio async def test_pk_filter(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): fantasies = await Album.objects.create(name="Test") track = await Track.objects.create( album=fantasies, title="Test1", position=1 @@ -332,8 +301,8 @@ async def test_pk_filter(): @pytest.mark.asyncio async def test_limit_and_offset(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): fantasies = await Album.objects.create(name="Limitless") await Track.objects.create( id=None, album=fantasies, title="Sample", position=1 @@ -364,8 +333,8 @@ async def test_limit_and_offset(): @pytest.mark.asyncio async def test_get_exceptions(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): fantasies = await Album.objects.create(name="Test") with pytest.raises(NoMatch): @@ -380,8 +349,8 @@ async def test_get_exceptions(): @pytest.mark.asyncio async def test_wrong_model_passed_as_fk(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): with pytest.raises(RelationshipInstanceError): org = await Organisation.objects.create(ident="ACME Ltd") await Track.objects.create(album=org, title="Test1", position=1) @@ -389,8 +358,8 @@ async def test_wrong_model_passed_as_fk(): @pytest.mark.asyncio async def test_bulk_update_model_with_no_children(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): album = await Album.objects.create(name="Test") album.name = "Test2" await Album.objects.bulk_update([album], columns=["name"]) @@ -401,8 +370,8 @@ async def test_bulk_update_model_with_no_children(): @pytest.mark.asyncio async def test_bulk_update_model_with_children(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): best_seller = await Album.objects.create(name="to_be_best_seller") best_seller2 = await Album.objects.create(name="to_be_best_seller2") not_best_seller = await Album.objects.create(name="unpopular") diff --git a/tests/test_relations/test_m2m_through_fields.py b/tests/test_relations/test_m2m_through_fields.py index 4f1eb849c..99eacf0dd 100644 --- a/tests/test_relations/test_m2m_through_fields.py +++ b/tests/test_relations/test_m2m_through_fields.py @@ -1,20 +1,13 @@ from typing import Any, ForwardRef -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config(force_rollback=True) class Category(ormar.Model): @@ -48,13 +41,7 @@ class Post(ormar.Model): blog = ormar.ForeignKey(Blog) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) class PostCategory2(ormar.Model): @@ -74,7 +61,7 @@ class Post2(ormar.Model): @pytest.mark.asyncio async def test_forward_ref_is_updated(): - async with database: + async with base_ormar_config.database: assert Post2.ormar_config.requires_ref_update Post2.update_forward_refs() @@ -83,7 +70,7 @@ async def test_forward_ref_is_updated(): @pytest.mark.asyncio async def test_setting_fields_on_through_model(): - async with database: + async with base_ormar_config.database: post = await Post(title="Test post").save() category = await Category(name="Test category").save() await post.categories.add(category) @@ -94,7 +81,7 @@ async def test_setting_fields_on_through_model(): @pytest.mark.asyncio async def test_setting_additional_fields_on_through_model_in_add(): - async with database: + async with base_ormar_config.database: post = await Post(title="Test post").save() category = await Category(name="Test category").save() await post.categories.add(category, sort_order=1) @@ -104,7 +91,7 @@ async def test_setting_additional_fields_on_through_model_in_add(): @pytest.mark.asyncio async def test_setting_additional_fields_on_through_model_in_create(): - async with database: + async with base_ormar_config.database: post = await Post(title="Test post").save() await post.categories.create( name="Test category2", postcategory={"sort_order": 2} @@ -115,7 +102,7 @@ async def test_setting_additional_fields_on_through_model_in_create(): @pytest.mark.asyncio async def test_getting_additional_fields_from_queryset() -> Any: - async with database: + async with base_ormar_config.database: post = await Post(title="Test post").save() await post.categories.create( name="Test category1", postcategory={"sort_order": 1} @@ -137,7 +124,7 @@ async def test_getting_additional_fields_from_queryset() -> Any: @pytest.mark.asyncio async def test_only_one_side_has_through() -> Any: - async with database: + async with base_ormar_config.database: post = await Post(title="Test post").save() await post.categories.create( name="Test category1", postcategory={"sort_order": 1} @@ -162,7 +149,7 @@ async def test_only_one_side_has_through() -> Any: @pytest.mark.asyncio async def test_filtering_by_through_model() -> Any: - async with database: + async with base_ormar_config.database: post = await Post(title="Test post").save() await post.categories.create( name="Test category1", @@ -189,7 +176,7 @@ async def test_filtering_by_through_model() -> Any: @pytest.mark.asyncio async def test_deep_filtering_by_through_model() -> Any: - async with database: + async with base_ormar_config.database: blog = await Blog(title="My Blog").save() post = await Post(title="Test post", blog=blog).save() @@ -220,7 +207,7 @@ async def test_deep_filtering_by_through_model() -> Any: @pytest.mark.asyncio async def test_ordering_by_through_model() -> Any: - async with database: + async with base_ormar_config.database: post = await Post(title="Test post").save() await post.categories.create( name="Test category1", @@ -255,7 +242,7 @@ async def test_ordering_by_through_model() -> Any: @pytest.mark.asyncio async def test_update_through_models_from_queryset_on_through() -> Any: - async with database: + async with base_ormar_config.database: post = await Post(title="Test post").save() await post.categories.create( name="Test category1", @@ -284,7 +271,7 @@ async def test_update_through_models_from_queryset_on_through() -> Any: @pytest.mark.asyncio async def test_update_through_model_after_load() -> Any: - async with database: + async with base_ormar_config.database: post = await Post(title="Test post").save() await post.categories.create( name="Test category1", @@ -303,7 +290,7 @@ async def test_update_through_model_after_load() -> Any: @pytest.mark.asyncio async def test_update_through_from_related() -> Any: - async with database: + async with base_ormar_config.database: post = await Post(title="Test post").save() await post.categories.create( name="Test category1", @@ -332,7 +319,7 @@ async def test_update_through_from_related() -> Any: @pytest.mark.asyncio async def test_excluding_fields_on_through_model() -> Any: - async with database: + async with base_ormar_config.database: post = await Post(title="Test post").save() await post.categories.create( name="Test category1", diff --git a/tests/test_relations/test_many_to_many.py b/tests/test_relations/test_many_to_many.py index 044724a16..ad1a6fb40 100644 --- a/tests/test_relations/test_many_to_many.py +++ b/tests/test_relations/test_many_to_many.py @@ -1,25 +1,20 @@ import asyncio from typing import List, Optional -import databases import ormar import pytest import pytest_asyncio -import sqlalchemy from ormar.exceptions import ModelPersistenceError, NoMatch, RelationshipInstanceError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config(force_rollback=True) class Author(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="authors", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="authors") id: int = ormar.Integer(primary_key=True) first_name: str = ormar.String(max_length=80) @@ -27,22 +22,14 @@ class Author(ormar.Model): class Category(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="categories", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=40) class Post(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="posts", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="posts") id: int = ormar.Integer(primary_key=True) title: str = ormar.String(max_length=200) @@ -57,18 +44,13 @@ def event_loop(): loop.close() -@pytest_asyncio.fixture(autouse=True, scope="module") -async def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest_asyncio.fixture(scope="function") async def cleanup(): yield - async with database: + async with base_ormar_config.database: PostCategory = Post.ormar_config.model_fields["categories"].through await PostCategory.objects.delete(each=True) await Post.objects.delete(each=True) @@ -78,7 +60,7 @@ async def cleanup(): @pytest.mark.asyncio async def test_not_saved_raises_error(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author(first_name="Guido", last_name="Van Rossum").save() post = await Post.objects.create(title="Hello, M2M", author=guido) news = Category(name="News") @@ -89,7 +71,7 @@ async def test_not_saved_raises_error(cleanup): @pytest.mark.asyncio async def test_not_existing_raises_error(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author(first_name="Guido", last_name="Van Rossum").save() post = await Post.objects.create(title="Hello, M2M", author=guido) @@ -101,7 +83,7 @@ async def test_not_existing_raises_error(cleanup): @pytest.mark.asyncio async def test_assigning_related_objects(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") post = await Post.objects.create(title="Hello, M2M", author=guido) news = await Category.objects.create(name="News") @@ -124,7 +106,7 @@ async def test_assigning_related_objects(cleanup): @pytest.mark.asyncio async def test_quering_of_the_m2m_models(cleanup): - async with database: + async with base_ormar_config.database: # orm can do this already. guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") post = await Post.objects.create(title="Hello, M2M", author=guido) @@ -159,7 +141,7 @@ async def test_quering_of_the_m2m_models(cleanup): @pytest.mark.asyncio async def test_removal_of_the_relations(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") post = await Post.objects.create(title="Hello, M2M", author=guido) news = await Category.objects.create(name="News") @@ -186,7 +168,7 @@ async def test_removal_of_the_relations(cleanup): @pytest.mark.asyncio async def test_selecting_related(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") post = await Post.objects.create(title="Hello, M2M", author=guido) news = await Category.objects.create(name="News") @@ -213,7 +195,7 @@ async def test_selecting_related(cleanup): @pytest.mark.asyncio async def test_selecting_related_fail_without_saving(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") post = Post(title="Hello, M2M", author=guido) with pytest.raises(RelationshipInstanceError): @@ -222,7 +204,7 @@ async def test_selecting_related_fail_without_saving(cleanup): @pytest.mark.asyncio async def test_adding_unsaved_related(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") post = await Post.objects.create(title="Hello, M2M", author=guido) news = Category(name="News") @@ -236,7 +218,7 @@ async def test_adding_unsaved_related(cleanup): @pytest.mark.asyncio async def test_removing_unsaved_related(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") post = await Post.objects.create(title="Hello, M2M", author=guido) news = Category(name="News") diff --git a/tests/test_relations/test_postgress_select_related_with_limit.py b/tests/test_relations/test_postgress_select_related_with_limit.py index e432255b7..93b4e0914 100644 --- a/tests/test_relations/test_postgress_select_related_with_limit.py +++ b/tests/test_relations/test_postgress_select_related_with_limit.py @@ -4,16 +4,14 @@ from enum import Enum from typing import Optional -import databases import ormar import pytest -import sqlalchemy -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config(force_rollback=True) class PrimaryKeyMixin: @@ -25,12 +23,6 @@ class Level(Enum): STAFF = "1" -base_ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, -) - - class User(PrimaryKeyMixin, ormar.Model): """User Model Class to Implement Method for Operations of User Entity""" @@ -65,17 +57,12 @@ class Task(PrimaryKeyMixin, ormar.Model): ) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_selecting_related_with_limit(): - async with database: + async with base_ormar_config.database: user1 = await User(mobile="9928917653", password="pass1").save() user2 = await User(mobile="9928917654", password="pass2").save() await Task(name="one", user=user1).save() diff --git a/tests/test_relations/test_prefetch_related.py b/tests/test_relations/test_prefetch_related.py index ab054d50e..5a2dc245a 100644 --- a/tests/test_relations/test_prefetch_related.py +++ b/tests/test_relations/test_prefetch_related.py @@ -1,33 +1,24 @@ from typing import List, Optional -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config(force_rollback=True) class RandomSet(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="randoms", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="randoms") id: int = ormar.Integer(name="random_id", primary_key=True) name: str = ormar.String(max_length=100) class Tonation(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="tonations", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="tonations") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(name="tonation_name", max_length=100) @@ -35,22 +26,14 @@ class Tonation(ormar.Model): class Division(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="divisions", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="divisions") id: int = ormar.Integer(name="division_id", primary_key=True) name: str = ormar.String(max_length=100, nullable=True) class Shop(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="shops", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="shops") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=True) @@ -58,19 +41,11 @@ class Shop(ormar.Model): class AlbumShops(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="albums_x_shops", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="albums_x_shops") class Album(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="albums", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="albums") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=True) @@ -78,11 +53,7 @@ class Album(ormar.Model): class Track(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="tracks", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="tracks") id: int = ormar.Integer(name="track_id", primary_key=True) album: Optional[Album] = ormar.ForeignKey(Album) @@ -92,11 +63,7 @@ class Track(ormar.Model): class Cover(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="covers", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="covers") id: int = ormar.Integer(primary_key=True) album: Optional[Album] = ormar.ForeignKey( @@ -106,19 +73,13 @@ class Cover(ormar.Model): artist: str = ormar.String(max_length=200, nullable=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_prefetch_related(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): album = Album(name="Malibu") await album.save() ton1 = await Tonation.objects.create(name="B-mol") @@ -196,8 +157,8 @@ async def test_prefetch_related(): @pytest.mark.asyncio async def test_prefetch_related_with_many_to_many(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): div = await Division.objects.create(name="Div 1") shop1 = await Shop.objects.create(name="Shop 1", division=div) shop2 = await Shop.objects.create(name="Shop 2", division=div) @@ -245,8 +206,8 @@ async def test_prefetch_related_with_many_to_many(): @pytest.mark.asyncio async def test_prefetch_related_empty(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): await Track.objects.create(title="The Bird", position=1) track = await Track.objects.prefetch_related(["album__cover_pictures"]).get( title="The Bird" @@ -257,8 +218,8 @@ async def test_prefetch_related_empty(): @pytest.mark.asyncio async def test_prefetch_related_with_select_related(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): div = await Division.objects.create(name="Div 1") shop1 = await Shop.objects.create(name="Shop 1", division=div) shop2 = await Shop.objects.create(name="Shop 2", division=div) @@ -330,8 +291,8 @@ async def test_prefetch_related_with_select_related(): @pytest.mark.asyncio async def test_prefetch_related_with_select_related_and_fields(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): div = await Division.objects.create(name="Div 1") shop1 = await Shop.objects.create(name="Shop 1", division=div) shop2 = await Shop.objects.create(name="Shop 2", division=div) diff --git a/tests/test_relations/test_prefetch_related_multiple_models_relation.py b/tests/test_relations/test_prefetch_related_multiple_models_relation.py index bd077a3fb..ecaf722fc 100644 --- a/tests/test_relations/test_prefetch_related_multiple_models_relation.py +++ b/tests/test_relations/test_prefetch_related_multiple_models_relation.py @@ -1,44 +1,30 @@ from typing import List, Optional -import databases import ormar import pytest -import sqlalchemy -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -db = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class User(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - tablename="test_users", - ) + ormar_config = base_ormar_config.copy(tablename="test_users") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=50) class Signup(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - tablename="test_signup", - ) + ormar_config = base_ormar_config.copy(tablename="test_signup") id: int = ormar.Integer(primary_key=True) class Session(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - tablename="test_sessions", - ) + ormar_config = base_ormar_config.copy(tablename="test_sessions") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=255, index=True) @@ -52,17 +38,12 @@ class Session(ormar.Model): ) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_add_students(): - async with db: + async with base_ormar_config.database: for user_id in [1, 2, 3, 4, 5]: await User.objects.create(name=f"User {user_id}") diff --git a/tests/test_relations/test_python_style_relations.py b/tests/test_relations/test_python_style_relations.py index 422b9ec26..d9287ff4b 100644 --- a/tests/test_relations/test_python_style_relations.py +++ b/tests/test_relations/test_python_style_relations.py @@ -1,23 +1,18 @@ from typing import List, Optional -import databases import ormar import pytest import pytest_asyncio -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Author(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="authors", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="authors") id: int = ormar.Integer(primary_key=True) first_name: str = ormar.String(max_length=80) @@ -25,22 +20,14 @@ class Author(ormar.Model): class Category(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="categories", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=40) class Post(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="posts", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="posts") id: int = ormar.Integer(primary_key=True) title: str = ormar.String(max_length=200) @@ -48,18 +35,13 @@ class Post(ormar.Model): author: Optional[Author] = ormar.ForeignKey(Author, related_name="author_posts") -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest_asyncio.fixture(scope="function") async def cleanup(): yield - async with database: + async with base_ormar_config.database: PostCategory = Post.ormar_config.model_fields["categories"].through await PostCategory.objects.delete(each=True) await Post.objects.delete(each=True) @@ -69,7 +51,7 @@ async def cleanup(): @pytest.mark.asyncio async def test_selecting_related(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") post = await Post.objects.create(title="Hello, M2M", author=guido) news = await Category.objects.create(name="News") diff --git a/tests/test_relations/test_relations_default_exception.py b/tests/test_relations/test_relations_default_exception.py index d88d899b4..2066caf33 100644 --- a/tests/test_relations/test_relations_default_exception.py +++ b/tests/test_relations/test_relations_default_exception.py @@ -1,24 +1,19 @@ # type: ignore from typing import List, Optional -import databases import ormar import pytest -import sqlalchemy from ormar.exceptions import ModelDefinitionError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Author(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="authors", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="authors") id: int = ormar.Integer(primary_key=True) first_name: str = ormar.String(max_length=80) @@ -26,25 +21,20 @@ class Author(ormar.Model): class Category(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="categories", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=40) +create_test_database = init_tests(base_ormar_config) + + def test_fk_error(): with pytest.raises(ModelDefinitionError): class Post(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="posts", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="posts") id: int = ormar.Integer(primary_key=True) title: str = ormar.String(max_length=200) @@ -56,11 +46,7 @@ def test_m2m_error(): with pytest.raises(ModelDefinitionError): class Post(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="posts", - database=database, - metadata=metadata, - ) + ormar_config = base_ormar_config.copy(tablename="posts") id: int = ormar.Integer(primary_key=True) title: str = ormar.String(max_length=200) diff --git a/tests/test_relations/test_replacing_models_with_copy.py b/tests/test_relations/test_replacing_models_with_copy.py index 240caa0b5..436a7006b 100644 --- a/tests/test_relations/test_replacing_models_with_copy.py +++ b/tests/test_relations/test_replacing_models_with_copy.py @@ -1,22 +1,17 @@ from typing import Any, Optional, Tuple, Union -import databases import ormar import pytest -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Album(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="albums", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="albums") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -26,11 +21,7 @@ class Album(ormar.Model): class Track(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="tracks", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="tracks") id: int = ormar.Integer(primary_key=True) album: Optional[Album] = ormar.ForeignKey(Album) @@ -41,13 +32,7 @@ class Track(ormar.Model): properties: Tuple[str, Any] -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio diff --git a/tests/test_relations/test_saving_related.py b/tests/test_relations/test_saving_related.py index 31e9d5d52..2c777a457 100644 --- a/tests/test_relations/test_saving_related.py +++ b/tests/test_relations/test_saving_related.py @@ -1,24 +1,18 @@ from typing import Union -import databases import ormar import pytest -import sqlalchemy as sa from ormar.exceptions import ModelPersistenceError -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -metadata = sa.MetaData() -db = databases.Database(DATABASE_URL) + +base_ormar_config = create_config() class Category(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="categories", - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(tablename="categories") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=50, unique=True, index=True) @@ -26,11 +20,7 @@ class Category(ormar.Model): class Workshop(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="workshops", - metadata=metadata, - database=db, - ) + ormar_config = base_ormar_config.copy(tablename="workshops") id: int = ormar.Integer(primary_key=True) topic: str = ormar.String(max_length=255, index=True) @@ -39,18 +29,13 @@ class Workshop(ormar.Model): ) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_model_relationship(): - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): cat = await Category(name="Foo", code=123).save() ws = await Workshop(topic="Topic 1", category=cat).save() @@ -68,8 +53,8 @@ async def test_model_relationship(): @pytest.mark.asyncio async def test_model_relationship_with_not_saved(): - async with db: - async with db.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): cat = Category(name="Foo", code=123) with pytest.raises(ModelPersistenceError): await Workshop(topic="Topic 1", category=cat).save() diff --git a/tests/test_relations/test_select_related_with_limit.py b/tests/test_relations/test_select_related_with_limit.py index 36d3785ce..1be283a26 100644 --- a/tests/test_relations/test_select_related_with_limit.py +++ b/tests/test_relations/test_select_related_with_limit.py @@ -1,44 +1,30 @@ from typing import List, Optional -import databases import ormar import pytest -import sqlalchemy -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -db = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class Keyword(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - tablename="keywords", - ) + ormar_config = base_ormar_config.copy(tablename="keywords") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=50) class KeywordPrimaryModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - tablename="primary_models_keywords", - ) + ormar_config = base_ormar_config.copy(tablename="primary_models_keywords") id: int = ormar.Integer(primary_key=True) class PrimaryModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - tablename="primary_models", - ) + ormar_config = base_ormar_config.copy(tablename="primary_models") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=255, index=True) @@ -50,11 +36,7 @@ class PrimaryModel(ormar.Model): class SecondaryModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=db, - tablename="secondary_models", - ) + ormar_config = base_ormar_config.copy(tablename="secondary_models") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -65,7 +47,7 @@ class SecondaryModel(ormar.Model): @pytest.mark.asyncio async def test_create_primary_models(): - async with db: + async with base_ormar_config.database: for name, some_text, some_other_text in [ ("Primary 1", "Some text 1", "Some other text 1"), ("Primary 2", "Some text 2", "Some other text 2"), @@ -153,9 +135,4 @@ async def test_create_primary_models(): assert len(models5[2].keywords) == 0 -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) diff --git a/tests/test_relations/test_select_related_with_m2m_and_pk_name_set.py b/tests/test_relations/test_select_related_with_m2m_and_pk_name_set.py index a79d79485..0f9beebbc 100644 --- a/tests/test_relations/test_select_related_with_m2m_and_pk_name_set.py +++ b/tests/test_relations/test_select_related_with_m2m_and_pk_name_set.py @@ -2,23 +2,16 @@ from datetime import date from typing import List, Optional, Union -import databases import ormar import pytest import sqlalchemy from ormar import ModelDefinitionError -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Role(ormar.Model): @@ -53,12 +46,7 @@ class User(ormar.Model): lastupdate: date = ormar.DateTime(server_default=sqlalchemy.func.now()) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) def test_wrong_model(): @@ -74,7 +62,7 @@ class User(ormar.Model): @pytest.mark.asyncio async def test_create_primary_models(): - async with database: + async with base_ormar_config.database: await Role.objects.create( name="user", order=0, description="no administration right" ) diff --git a/tests/test_relations/test_selecting_proper_table_prefix.py b/tests/test_relations/test_selecting_proper_table_prefix.py index da5393241..b0de7e05a 100644 --- a/tests/test_relations/test_selecting_proper_table_prefix.py +++ b/tests/test_relations/test_selecting_proper_table_prefix.py @@ -1,44 +1,30 @@ from typing import List, Optional -import databases import ormar import pytest -import sqlalchemy -from sqlalchemy import create_engine -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class User(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, - tablename="test_users", - ) + ormar_config = base_ormar_config.copy(tablename="test_users") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=50) class Signup(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, - tablename="test_signup", - ) + ormar_config = base_ormar_config.copy(tablename="test_signup") id: int = ormar.Integer(primary_key=True) class Session(ormar.Model): - ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, - tablename="test_sessions", - ) + ormar_config = base_ormar_config.copy(tablename="test_sessions") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=255, index=True) @@ -47,17 +33,12 @@ class Session(ormar.Model): students: Optional[List[User]] = ormar.ManyToMany(User, through=Signup) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest.mark.asyncio async def test_list_sessions_for_user(): - async with database: + async with base_ormar_config.database: for user_id in [1, 2, 3, 4, 5]: await User.objects.create(name=f"User {user_id}") diff --git a/tests/test_relations/test_skipping_reverse.py b/tests/test_relations/test_skipping_reverse.py index b4d2f38e7..3cc49df8f 100644 --- a/tests/test_relations/test_skipping_reverse.py +++ b/tests/test_relations/test_skipping_reverse.py @@ -1,21 +1,14 @@ from typing import List, Optional -import databases import ormar import pytest import pytest_asyncio -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - -base_ormar_config = ormar.OrmarConfig( - metadata=metadata, - database=database, -) +base_ormar_config = create_config() class Author(ormar.Model): @@ -42,18 +35,13 @@ class Post(ormar.Model): author: Optional[Author] = ormar.ForeignKey(Author, skip_reverse=True) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest_asyncio.fixture(scope="function") async def cleanup(): yield - async with database: + async with base_ormar_config.database: PostCategory = Post.ormar_config.model_fields["categories"].through await PostCategory.objects.delete(each=True) await Post.objects.delete(each=True) @@ -81,7 +69,7 @@ def test_model_definition(): @pytest.mark.asyncio async def test_assigning_related_objects(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") post = await Post.objects.create(title="Hello, M2M", author=guido) news = await Category.objects.create(name="News") @@ -109,7 +97,7 @@ async def test_assigning_related_objects(cleanup): @pytest.mark.asyncio async def test_quering_of_related_model_works_but_no_result(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") post = await Post.objects.create(title="Hello, M2M", author=guido) news = await Category.objects.create(name="News") @@ -149,7 +137,7 @@ async def test_quering_of_related_model_works_but_no_result(cleanup): @pytest.mark.asyncio async def test_removal_of_the_relations(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") post = await Post.objects.create(title="Hello, M2M", author=guido) news = await Category.objects.create(name="News") @@ -174,7 +162,7 @@ async def test_removal_of_the_relations(cleanup): @pytest.mark.asyncio async def test_selecting_related(cleanup): - async with database: + async with base_ormar_config.database: guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") guido2 = await Author.objects.create( first_name="Guido2", last_name="Van Rossum" diff --git a/tests/test_relations/test_through_relations_fail.py b/tests/test_relations/test_through_relations_fail.py index 129fc21b2..5057d907b 100644 --- a/tests/test_relations/test_through_relations_fail.py +++ b/tests/test_relations/test_through_relations_fail.py @@ -1,23 +1,17 @@ # type: ignore -import databases import ormar import pytest -import sqlalchemy from ormar import ModelDefinitionError -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() +base_ormar_config = create_config() -def test_through_with_relation_fails(): - base_ormar_config = ormar.OrmarConfig( - database=database, - metadata=metadata, - ) +def test_through_with_relation_fails(): class Category(ormar.Model): ormar_config = base_ormar_config.copy(tablename="categories") @@ -46,3 +40,6 @@ class Post(ormar.Model): id: int = ormar.Integer(primary_key=True) title: str = ormar.String(max_length=200) categories = ormar.ManyToMany(Category, through=PostCategory) + + +create_test_database = init_tests(base_ormar_config) diff --git a/tests/test_relations/test_weakref_checking.py b/tests/test_relations/test_weakref_checking.py index b246e0502..3a666bbae 100644 --- a/tests/test_relations/test_weakref_checking.py +++ b/tests/test_relations/test_weakref_checking.py @@ -1,30 +1,21 @@ -import databases import ormar -import sqlalchemy -from tests.settings import DATABASE_URL +from tests.settings import create_config -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() +from tests.lifespan import init_tests class Band(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="bands", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="bands") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) class Artist(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="artists", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="artists") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -32,6 +23,9 @@ class Artist(ormar.Model): band: Band = ormar.ForeignKey(Band) +create_test_database = init_tests(base_ormar_config) + + def test_weakref_init(): band = Band(name="Band") artist1 = Artist(name="Artist 1", band=band) diff --git a/tests/test_signals/test_signals.py b/tests/test_signals/test_signals.py index 251e91734..5e6405b8c 100644 --- a/tests/test_signals/test_signals.py +++ b/tests/test_signals/test_signals.py @@ -1,11 +1,9 @@ from typing import Optional -import databases import ormar import pydantic import pytest import pytest_asyncio -import sqlalchemy from ormar import ( post_bulk_update, post_delete, @@ -18,18 +16,15 @@ from ormar.exceptions import SignalDefinitionError from ormar.signals import SignalEmitter -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class AuditLog(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="audits", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="audits") id: int = ormar.Integer(primary_key=True) event_type: str = ormar.String(max_length=100) @@ -37,22 +32,14 @@ class AuditLog(ormar.Model): class Cover(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="covers", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="covers") id: int = ormar.Integer(primary_key=True) title: str = ormar.String(max_length=100) class Album(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="albums", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="albums") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -61,19 +48,13 @@ class Album(ormar.Model): cover: Optional[Cover] = ormar.ForeignKey(Cover) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest_asyncio.fixture(scope="function") async def cleanup(): yield - async with database: + async with base_ormar_config.database: await AuditLog.objects.delete(each=True) @@ -98,8 +79,8 @@ def test_invalid_signal(): @pytest.mark.asyncio async def test_signal_functions(cleanup): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): @pre_save(Album) async def before_save(sender, instance, **kwargs): @@ -221,8 +202,8 @@ async def after_bulk_update(sender, instances, **kwargs): @pytest.mark.asyncio async def test_multiple_signals(cleanup): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): @pre_save(Album) async def before_save(sender, instance, **kwargs): @@ -252,8 +233,8 @@ async def before_save2(sender, instance, **kwargs): @pytest.mark.asyncio async def test_static_methods_as_signals(cleanup): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): class AlbumAuditor: event_type = "ALBUM_INSTANCE" @@ -277,8 +258,8 @@ async def before_save(sender, instance, **kwargs): @pytest.mark.asyncio async def test_methods_as_signals(cleanup): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): class AlbumAuditor: def __init__(self): @@ -304,8 +285,8 @@ async def before_save(self, sender, instance, **kwargs): @pytest.mark.asyncio async def test_multiple_senders_signal(cleanup): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): @pre_save([Album, Cover]) async def before_save(sender, instance, **kwargs): @@ -332,8 +313,8 @@ async def before_save(sender, instance, **kwargs): @pytest.mark.asyncio async def test_modifing_the_instance(cleanup): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): @pre_update(Album) async def before_update(sender, instance, **kwargs): @@ -358,8 +339,8 @@ async def before_update(sender, instance, **kwargs): @pytest.mark.asyncio async def test_custom_signal(cleanup): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): async def after_update(sender, instance, **kwargs): if instance.play_count > 50 and not instance.is_best_seller: diff --git a/tests/test_signals/test_signals_for_relations.py b/tests/test_signals/test_signals_for_relations.py index 08ddcc130..5db964a06 100644 --- a/tests/test_signals/test_signals_for_relations.py +++ b/tests/test_signals/test_signals_for_relations.py @@ -1,11 +1,9 @@ from typing import Optional -import databases import ormar import pydantic import pytest import pytest_asyncio -import sqlalchemy from ormar import ( post_relation_add, post_relation_remove, @@ -13,18 +11,15 @@ pre_relation_remove, ) -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() + +base_ormar_config = create_config() class AuditLog(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="audits", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="audits") id: int = ormar.Integer(primary_key=True) event_type: str = ormar.String(max_length=100) @@ -32,33 +27,21 @@ class AuditLog(ormar.Model): class Cover(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="covers", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="covers") id: int = ormar.Integer(primary_key=True) title: str = ormar.String(max_length=100) class Artist(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="artists", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="artists") id: int = ormar.Integer(name="artist_id", primary_key=True) name: str = ormar.String(name="fname", max_length=100) class Album(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="albums", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="albums") id: int = ormar.Integer(primary_key=True) title: str = ormar.String(max_length=100) @@ -66,26 +49,20 @@ class Album(ormar.Model): artists = ormar.ManyToMany(Artist) -@pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.drop_all(engine) - metadata.create_all(engine) - yield - metadata.drop_all(engine) +create_test_database = init_tests(base_ormar_config) @pytest_asyncio.fixture(autouse=True, scope="function") async def cleanup(): yield - async with database: + async with base_ormar_config.database: await AuditLog.objects.delete(each=True) @pytest.mark.asyncio async def test_relation_signal_functions(): - async with database: - async with database.transaction(force_rollback=True): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): @pre_relation_add([Album, Cover, Artist]) async def before_relation_add( diff --git a/tests/test_utils/test_queryset_utils.py b/tests/test_utils/test_queryset_utils.py index c57129e30..c5c86bd36 100644 --- a/tests/test_utils/test_queryset_utils.py +++ b/tests/test_utils/test_queryset_utils.py @@ -1,6 +1,4 @@ -import databases import ormar -import sqlalchemy from ormar.queryset.queries.prefetch_query import sort_models from ormar.queryset.utils import ( subtract_dict, @@ -9,7 +7,11 @@ update_dict_from_list, ) -from tests.settings import DATABASE_URL +from tests.settings import create_config +from tests.lifespan import init_tests + + +base_ormar_config = create_config() def test_list_to_dict_translation(): @@ -173,16 +175,8 @@ def test_subtracting_with_set_and_dict(): assert test == {"translation": {"translations": {"language": Ellipsis}}} -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() - - class SortModel(ormar.Model): - ormar_config = ormar.OrmarConfig( - tablename="sorts", - metadata=metadata, - database=database, - ) + ormar_config = base_ormar_config.copy(tablename="sorts") id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -212,3 +206,5 @@ def test_sorting_models(): orders_by = {"sort_order": "asc", "none": ..., "id": "asc", "uu": 2, "aa": None} models = sort_models(models, orders_by) assert [model.id for model in models] == [1, 4, 2, 3, 5, 6] + +create_test_database = init_tests(base_ormar_config)