end-to-end test rest_api_source on all destinations. Removes redundant helpers from tests/utils.py

willi-mueller committed Sep 2, 2024
1 parent fca33ce commit 6104550
Showing 5 changed files with 24 additions and 71 deletions.
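
For orientation before the diffs: the commit swaps a plain destination-name parametrization for dlt's destination-configuration fixtures. Below is a minimal sketch of the before/after pattern, not part of the commit; destinations_configs, DestinationTestConfiguration, and setup_pipeline are assumed to behave as their usage in the diffs suggests.

# Sketch, not part of the commit: the before/after parametrization pattern.
from typing import Any

import pytest

from tests.load.utils import DestinationTestConfiguration, destinations_configs


# Old style (removed): a flat list of destination names from tests/utils.py,
# defaulting to ["duckdb"], so only one destination was exercised by default.
# @pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
# def test_example(destination_name: str) -> None: ...


# New style (added): one DestinationTestConfiguration per default SQL
# destination, with readable per-destination test IDs.
@pytest.mark.parametrize(
    "destination_config",
    destinations_configs(default_sql_configs=True),
    ids=lambda x: x.name,
)
def test_example(destination_config: DestinationTestConfiguration, request: Any) -> None:
    # setup_pipeline names the pipeline after the running test node;
    # dev_mode=True gives each run a fresh dataset.
    pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True)
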
Empty file.
@@ -1,10 +1,15 @@
+from typing import Any
 import dlt
 import pytest
 from dlt.sources.rest_api.typing import RESTAPIConfig
 from dlt.sources.helpers.rest_client.paginators import SinglePagePaginator
 
 from dlt.sources.rest_api import rest_api_source
-from tests.utils import ALL_DESTINATIONS, assert_load_info, load_table_counts
+from tests.pipeline.utils import assert_load_info, load_table_counts
+from tests.load.utils import (
+    destinations_configs,
+    DestinationTestConfiguration,
+)
 
 
 def _make_pipeline(destination_name: str):
@@ -16,8 +21,12 @@ def _make_pipeline(destination_name: str):
     )
 
 
-@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
-def test_rest_api_source(destination_name: str) -> None:
+@pytest.mark.parametrize(
+    "destination_config",
+    destinations_configs(default_sql_configs=True),
+    ids=lambda x: x.name,
+)
+def test_rest_api_source(destination_config: DestinationTestConfiguration, request: Any) -> None:
     config: RESTAPIConfig = {
         "client": {
             "base_url": "https://pokeapi.co/api/v2/",
@@ -39,9 +48,8 @@ def test_rest_api_source(destination_name: str) -> None:
         ],
     }
     data = rest_api_source(config)
-    pipeline = _make_pipeline(destination_name)
+    pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True)
     load_info = pipeline.run(data)
-    print(load_info)
     assert_load_info(load_info)
     table_names = [t["name"] for t in pipeline.default_schema.data_tables()]
     table_counts = load_table_counts(pipeline, *table_names)
@@ -53,8 +61,12 @@ def test_rest_api_source(destination_name: str) -> None:
     assert table_counts["location"] == 1036
 
 
-@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
-def test_dependent_resource(destination_name: str) -> None:
+@pytest.mark.parametrize(
+    "destination_config",
+    destinations_configs(default_sql_configs=True),
+    ids=lambda x: x.name,
+)
+def test_dependent_resource(destination_config: DestinationTestConfiguration, request: Any) -> None:
     config: RESTAPIConfig = {
         "client": {
             "base_url": "https://pokeapi.co/api/v2/",
@@ -96,7 +108,7 @@ def test_dependent_resource(destination_name: str) -> None:
     }
 
     data = rest_api_source(config)
-    pipeline = _make_pipeline(destination_name)
+    pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True)
     load_info = pipeline.run(data)
     assert_load_info(load_info)
     table_names = [t["name"] for t in pipeline.default_schema.data_tables()]
tests/sources/rest_api/integration/test_offline.py (2 changes: 1 addition & 1 deletion)
@@ -16,7 +16,7 @@
     rest_api_source,
 )
 from tests.sources.rest_api.conftest import DEFAULT_PAGE_SIZE, DEFAULT_TOTAL_PAGES
-from tests.utils import assert_load_info, assert_query_data, load_table_counts
+from tests.pipeline.utils import assert_load_info, assert_query_data, load_table_counts
 
 
 def test_load_mock_api(mock_api_server):
tests/sources/rest_api/integration/test_processing_steps.py (10 changes: 0 additions & 10 deletions)
@@ -1,18 +1,8 @@
 from typing import Any, Callable, Dict, List
 
-import dlt
 from dlt.sources.rest_api import RESTAPIConfig, rest_api_source
 
 
-def _make_pipeline(destination_name: str):
-    return dlt.pipeline(
-        pipeline_name="rest_api",
-        destination=destination_name,
-        dataset_name="rest_api_data",
-        full_refresh=True,
-    )
-
-
 def test_rest_api_source_filtered(mock_api_server) -> None:
     config: RESTAPIConfig = {
         "client": {
tests/utils.py (55 changes: 3 additions & 52 deletions)
@@ -3,7 +3,7 @@
 import platform
 import sys
 from os import environ
-from typing import Any, Iterable, Iterator, Literal, List, Union, get_args
+from typing import Any, Iterable, Iterator, Literal, Union, get_args
 from unittest.mock import patch
 
 import pytest
@@ -18,23 +18,18 @@
 from dlt.common.configuration.specs.config_providers_context import (
     ConfigProvidersContext,
 )
-from dlt.common.pipeline import LoadInfo, PipelineContext
+from dlt.common.pipeline import PipelineContext
 from dlt.common.runtime.init import init_logging
 from dlt.common.runtime.telemetry import start_telemetry, stop_telemetry
 from dlt.common.schema import Schema
 from dlt.common.storages import FileStorage
 from dlt.common.storages.versioned_storage import VersionedStorage
-from dlt.common.typing import DictStrAny, StrAny, TDataItem
+from dlt.common.typing import StrAny, TDataItem
 from dlt.common.utils import custom_environ, uniq_id
 from dlt.common.pipeline import SupportsPipeline
 
 TEST_STORAGE_ROOT = "_storage"
 
-ALL_DESTINATIONS = dlt.config.get("ALL_DESTINATIONS", list) or [
-    "duckdb",
-]
-
-
 # destination constants
 IMPLEMENTED_DESTINATIONS = {
     "athena",
@@ -338,47 +333,3 @@ def is_running_in_github_fork() -> bool:
 skipifgithubfork = pytest.mark.skipif(
     is_running_in_github_fork(), reason="Skipping test because it runs on a PR coming from fork"
 )
-
-
-def assert_load_info(info: LoadInfo, expected_load_packages: int = 1) -> None:
-    """Asserts that expected number of packages was loaded and there are no failed jobs"""
-    assert len(info.loads_ids) == expected_load_packages
-    # all packages loaded
-    assert all(package.state == "loaded" for package in info.load_packages) is True
-    # no failed jobs in any of the packages
-    info.raise_on_failed_jobs()
-
-
-def load_table_counts(p: dlt.Pipeline, *table_names: str) -> DictStrAny:
-    """Returns row counts for `table_names` as dict"""
-    with p.sql_client() as c:
-        query = "\nUNION ALL\n".join(
-            [
-                f"SELECT '{name}' as name, COUNT(1) as c FROM {c.make_qualified_table_name(name)}"
-                for name in table_names
-            ]
-        )
-        with c.execute_query(query) as cur:
-            rows = list(cur.fetchall())
-            return {r[0]: r[1] for r in rows}
-
-
-def assert_query_data(
-    p: dlt.Pipeline,
-    sql: str,
-    table_data: List[Any],
-    schema_name: str = None,
-    info: LoadInfo = None,
-) -> None:
-    """Asserts that query selecting single column of values matches `table_data`. If `info` is provided, second column must contain one of load_ids in `info`"""
-    with p.sql_client(schema_name=schema_name) as c:
-        with c.execute_query(sql) as cur:
-            rows = list(cur.fetchall())
-            assert len(rows) == len(table_data)
-            for r, d in zip(rows, table_data):
-                row = list(r)
-                # first element comes from the data
-                assert row[0] == d
-                # the second is load id
-                if info:
-                    assert row[1] in info.loads_ids
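
The deleted helpers are consumed from tests.pipeline.utils going forward, as the import changes above show. A hypothetical smoke check of the relocated helpers, not part of the commit, assuming they keep the signatures deleted here:

# Sketch, not part of the commit: exercising assert_load_info and
# load_table_counts from their new home, assuming unchanged behavior.
import dlt

from tests.pipeline.utils import assert_load_info, load_table_counts

pipeline = dlt.pipeline(
    pipeline_name="rest_api",
    destination="duckdb",
    dataset_name="rest_api_data",
    full_refresh=True,  # fresh dataset per run, like the removed _make_pipeline
)
load_info = pipeline.run([{"id": 1}, {"id": 2}], table_name="pokemon")
assert_load_info(load_info)  # exactly one load package, no failed jobs
assert load_table_counts(pipeline, "pokemon") == {"pokemon": 2}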
