diff --git a/tap_shopify/client.py b/tap_shopify/client.py
index 53bbc59..f478d12 100644
--- a/tap_shopify/client.py
+++ b/tap_shopify/client.py
@@ -2,44 +2,49 @@
 
 from __future__ import annotations
 
-import simplejson
+from datetime import datetime
 from functools import cached_property
+from http import HTTPStatus
 from inspect import stack
+from time import sleep
 from typing import Any, Dict, Iterable, Optional, cast
 
 import requests
-
+import simplejson
+from singer_sdk import typing as th
 from singer_sdk.exceptions import FatalAPIError, RetriableAPIError
 from singer_sdk.helpers.jsonpath import extract_jsonpath
-
-from http import HTTPStatus
-
-from singer_sdk import typing as th
 from singer_sdk.pagination import SinglePagePaginator
 from singer_sdk.streams import GraphQLStream
 
-from tap_shopify.auth import ShopifyAuthenticator
-from tap_shopify.gql_queries import schema_query
-from tap_shopify.paginator import ShopifyPaginator
-from tap_shopify.gql_queries import query_incremental
-
-from datetime import datetime
-from time import sleep
-from singer_sdk.pagination import SinglePagePaginator
-
 from tap_shopify.exceptions import InvalidOperation, OperationFailed
-from tap_shopify.gql_queries import bulk_query, bulk_query_status
+from tap_shopify.gql_queries import (
+    bulk_query,
+    bulk_query_status,
+    query_incremental,
+)
+from tap_shopify.paginator import ShopifyPaginator
 
 
 def verify_recursion(func):
     """Verify if the stream is recursive."""
     objs = []
+    connections = dict(num=0, in_conn=False)
 
     def wrapper(*args, **kwargs):
         if not [f for f in stack() if f.function == func.__name__]:
+            connections["in_conn"] = False
             objs.clear()
+
         field_name = args[1]["name"]
         field_kind = args[1]["kind"]
+
+        if field_kind == "INTERFACE":
+            if connections["in_conn"] or connections["num"] >= 5:
+                return
+            connections["in_conn"] = True
+            connections["num"] += 1
+
         if field_name not in objs:
             if field_kind == "OBJECT":
                 objs.append(args[1]["name"])
@@ -54,9 +59,8 @@ class ShopifyStream(GraphQLStream):
 
     query_name = None
     single_object_params = None
-    ignore_objs = []
+    ignore_objs = ["image", "metafield", "metafields", "metafieldconnection", "privateMetafield", "privateMetafields"]
     _requests_session = None
-    nested_connections = []
     denied_fields = []
 
     @property
@@ -90,15 +94,18 @@ def http_headers(self) -> dict:
     def schema_gql(self) -> dict:
         """Return the schema for the stream."""
         return self._tap.schema_gql
-
+
     @cached_property
     def additional_arguments(self) -> dict:
         """Return the schema for the stream."""
-        gql_query = next(q for q in self._tap.queries_gql if q["name"]==self.query_name)
+        gql_query = next(
+            q for q in self._tap.queries_gql if q["name"] == self.query_name
+        )
         if "includeClosed" in [a["name"] for a in gql_query["args"]]:
             return ["includeClosed: true"]
         return []
 
+    # @verify_connections
     @verify_recursion
     def extract_field_type(self, field) -> str:
         """Extract the field type from the schema."""
@@ -120,9 +127,16 @@ def extract_field_type(self, field) -> str:
             return th.ObjectType(*properties)
         elif kind == "LIST":
             obj_type = field["ofType"]["ofType"]
+            if not obj_type:
+                return None
             list_field_type = self.extract_field_type(obj_type)
             if list_field_type:
+                if obj_type["name"].endswith("Edge") and not "node" in list_field_type.type_dict["properties"].keys():
+                    return None
                 return th.ArrayType(list_field_type)
+        elif kind == "INTERFACE" and self.config.get("bulk"):
+            obj_schema = self.extract_gql_schema(name)
+            properties = self.get_fields_schema(obj_schema["fields"])
         elif kind == "ENUM":
             return th.StringType
elif kind == "NON_NULL": @@ -133,25 +147,27 @@ def extract_field_type(self, field) -> str: def get_fields_schema(self, fields) -> dict: """Build the schema for the stream.""" + # Filtering the fields that are not needed + field_names = [f["name"] for f in fields] + if "edges" in field_names: + fields = [f for f in fields if f["name"]=="edges"] + elif "node" in field_names: + fields = [f for f in fields if f["name"]=="node"] + properties = [] for field in fields: field_name = field["name"] + type_def = field.get("type", field) + type_def = type_def["ofType"] or type_def # Ignore all the fields that need arguments if field.get("isDeprecated") and self.config.get("ignore_deprecated"): continue - if field.get("args"): - if field["args"][0]["name"] == "first": - self.nested_connections.append(field_name) - continue if field_name in self.ignore_objs: continue - if field["type"]["kind"] == "INTERFACE": - continue required = field["type"].get("kind") == "NON_NULL" - type_def = field.get("type", field) - type_def = type_def["ofType"] or type_def field_type = self.extract_field_type(type_def) + if field_type: property = th.Property(field_name, field_type, required=required) properties.append(property) @@ -180,7 +196,7 @@ def schema(self) -> dict: stream_catalog = next(stream, None) if stream_catalog: return stream_catalog["schema"] - + stream_type = self.extract_gql_schema(self.gql_type) properties = self.get_fields_schema(stream_type["fields"]) return th.PropertiesList(*properties).to_dict() @@ -223,7 +239,6 @@ def denest_schema(schema): return denest_schema(catalog) - def validate_response(self, response: requests.Response) -> None: """Validate HTTP response.""" @@ -235,11 +250,11 @@ def validate_response(self, response: requests.Response) -> None: ): msg = self.response_error_message(response) raise RetriableAPIError(msg, response) - + json_resp = response.json() - if errors:=json_resp.get("errors"): - if len(errors)==1: + if errors := json_resp.get("errors"): + if len(errors) == 1: error = errors[0] code = error.get("extensions", {}).get("code") if code in ["THROTTLED", "MAX_COST_EXCEEDED"]: @@ -260,7 +275,7 @@ def convert_id_fields(self, row: dict) -> dict: if not isinstance(row, dict): return row for key, value in row.items(): - if key=="id" and isinstance(value, str): + if key == "id" and isinstance(value, str): row["id"] = row["id"].split("/")[-1].split("?")[0] elif isinstance(value, dict): row[key] = self.convert_id_fields(value) @@ -280,7 +295,6 @@ def post_process( return row - def query_gql(self) -> str: """Set or return the GraphQL query string.""" base_query = query_incremental @@ -310,7 +324,7 @@ def get_url_params( if self.single_object_params: params = self.single_object_params return params - + def prepare_request_payload( self, context: Optional[dict], next_page_token: Optional[Any] ) -> Optional[dict]: @@ -334,7 +348,6 @@ def parse_response_gql(self, response: requests.Response) -> Iterable[dict]: yield from extract_jsonpath(json_path, json_resp) - def query_bulk(self) -> str: """Set or return the GraphQL query string.""" base_query = bulk_query @@ -411,9 +424,7 @@ def parse_response_bulk(self, response: requests.Response) -> Iterable[dict]: errors = next(extract_jsonpath(error_jsonpath, json_resp), None) if errors: raise InvalidOperation(simplejson.dumps(errors)) - operation_id = next( - extract_jsonpath(operation_id_jsonpath, json_resp) - ) + operation_id = next(extract_jsonpath(operation_id_jsonpath, json_resp)) url = self.check_status(operation_id) @@ -431,12 +442,12 @@ def 
     @cached_property
     def query(self) -> str:
         """Set or return the GraphQL query string."""
-        self.evaluate_query()
+        # TODO: figure out how to handle interfaces
+        # self.evaluate_query()
         if self.config.get("bulk"):
             return self.query_bulk()
         return self.query_gql()
 
-
     def evaluate_query(self) -> dict:
         query = self.query_gql().lstrip()
         params = self.get_url_params(None, None)
@@ -464,4 +475,6 @@ def evaluate_query(self) -> dict:
                 self.denied_fields.append(message.split(" ")[3])
             else:
                 raise FatalAPIError(error.get("message", ""), response)
-        self.evaluate_query()
\ No newline at end of file
+        self.evaluate_query()
+
+        # TODO: get query cost from here
diff --git a/tap_shopify/client_bulk.py b/tap_shopify/client_bulk.py
deleted file mode 100644
index 416d30f..0000000
--- a/tap_shopify/client_bulk.py
+++ /dev/null
@@ -1,107 +0,0 @@
-"""GraphQL client handling, including shopify-betaStream base class."""
-
-from datetime import datetime
-from time import sleep
-from typing import Any, Iterable, Optional, cast
-
-import requests
-import simplejson
-from singer_sdk.helpers.jsonpath import extract_jsonpath
-from singer_sdk.pagination import SinglePagePaginator
-
-from tap_shopify.client import ShopifyStream
-from tap_shopify.exceptions import InvalidOperation, OperationFailed
-from tap_shopify.gql_queries import bulk_query, bulk_query_status, simple_query
-
-
-class shopifyBulkStream(ShopifyStream):
-    """shopify stream class."""
-    def query(self) -> str:
-        """Set or return the GraphQL query string."""
-        if self.name == "shop":
-            base_query = simple_query
-        else:
-            base_query = bulk_query
-
-        query = base_query.replace("__query_name__", self.query_name)
-        query = query.replace("__selected_fields__", self.gql_selected_fields)
-        filters = f"({self.filters})" if self.filters else ""
-        query = query.replace("__filters__", filters)
-
-        return query
-
-    @property
-    def filters(self):
-        """Return a dictionary of values to be used in URL parameterization."""
-        filters = []
-        if self.additional_arguments:
-            filters.extend(self.additional_arguments)
-        if self.replication_key:
-            start_date = self.get_starting_timestamp({})
-            if start_date:
-                date = start_date.strftime("%Y-%m-%dT%H:%M:%S")
-                filters.append(f'query: "updated_at:>{date}"')
-        return ",".join(filters)
-
-    def get_operation_status(self):
-        headers = self.http_headers
-        authenticator = self.authenticator
-        if authenticator:
-            headers.update(authenticator.auth_headers or {})
-
-        request = cast(
-            requests.PreparedRequest,
-            self.requests_session.prepare_request(
-                requests.Request(
-                    method=self.rest_method,
-                    url=self.get_url({}),
-                    headers=headers,
-                    json=dict(query=bulk_query_status, variables={}),
-                ),
-            ),
-        )
-
-        decorated_request = self.request_decorator(self._request)
-        response = decorated_request(request, {})
-
-        return response
-
-    def check_status(self, operation_id, sleep_time=10, timeout=1800):
-        status_jsonpath = "$.data.currentBulkOperation"
-        start = datetime.now().timestamp()
-
-        while datetime.now().timestamp() < (start + timeout):
-            status_response = self.get_operation_status()
-            status = next(
-                extract_jsonpath(status_jsonpath, input=status_response.json())
-            )
-            if status["id"] != operation_id:
-                raise InvalidOperation(
-                    "The current job was not triggered by the process, "
-                    "check if other service is using the Bulk API"
-                )
-            if status["url"]:
-                return status["url"]
-            if status["status"] == "FAILED":
-                raise InvalidOperation(f"Job failed: {status['errorCode']}")
-            sleep(sleep_time)
-        raise OperationFailed("Job Timeout")
-
-    def parse_response(self, response: requests.Response) -> Iterable[dict]:
-        """Parse the response and return an iterator of result rows."""
-        operation_id_jsonpath = "$.data.bulkOperationRunQuery.bulkOperation.id"
-        error_jsonpath = "$.data.bulkOperationRunQuery.userErrors"
-        json_resp = response.json()
-        errors = next(extract_jsonpath(error_jsonpath, json_resp), None)
-        if errors:
-            raise InvalidOperation(simplejson.dumps(errors))
-        operation_id = next(
-            extract_jsonpath(operation_id_jsonpath, json_resp)
-        )
-
-        url = self.check_status(operation_id)
-
-        output = requests.get(url, stream=True, timeout=30)
-
-        for line in output.iter_lines():
-            yield simplejson.loads(line)
diff --git a/tap_shopify/client_gql.py b/tap_shopify/client_gql.py
deleted file mode 100644
index 4d2bd76..0000000
--- a/tap_shopify/client_gql.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""GraphQL client handling, including shopify-betaStream base class."""
-
-from __future__ import annotations
-
-from typing import Any, Dict, Iterable, Optional
-
-import requests  # noqa: TCH002
-from singer_sdk.helpers.jsonpath import extract_jsonpath
-
-from tap_shopify.client import ShopifyStream
-from tap_shopify.gql_queries import query_incremental
-
-
-class shopifyGqlStream(ShopifyStream):
-    """shopify stream class."""
-
-    def query_gql(self) -> str:
-        """Set or return the GraphQL query string."""
-        base_query = query_incremental
-
-        query = base_query.replace("__query_name__", self.query_name)
-        query = query.replace("__selected_fields__", self.gql_selected_fields)
-        additional_args = ", " + ", ".join(self.additional_arguments)
-        query = query.replace("__additional_args__", additional_args)
-
-        return query
-
-    def get_url_params(
-        self, context: Optional[dict], next_page_token: Optional[Any]
-    ) -> Dict[str, Any]:
-        """Return a dictionary of values to be used in URL parameterization."""
-        params = {}
-
-        if next_page_token:
-            params.update(next_page_token)
-        else:
-            params["first"] = 1
-        if self.replication_key:
-            start_date = self.get_starting_timestamp(context)
-            if start_date:
-                date = start_date.strftime("%Y-%m-%dT%H:%M:%S")
-                params["filter"] = f"updated_at:>{date}"
-        if self.single_object_params:
-            params = self.single_object_params
-        return params
-
-    def prepare_request_payload(
-        self, context: Optional[dict], next_page_token: Optional[Any]
-    ) -> Optional[dict]:
-        """Prepare the data payload for the GraphQL API request."""
-        params = self.get_url_params(context, next_page_token)
-        query = self.query.lstrip()
-        request_data = {
-            "query": query,
-            "variables": params,
-        }
-        self.logger.debug(f"Attempting query:\n{query}")
-        return request_data
-
-    def parse_response_gql(self, response: requests.Response) -> Iterable[dict]:
-        """Parse the response and return an iterator of result rows."""
-        if self.replication_key:
-            json_path = f"$.data.{self.query_name}.edges[*].node"
-        else:
-            json_path = f"$.data.{self.query_name}"
-        json_resp = response.json()
-
-        yield from extract_jsonpath(json_path, json_resp)
diff --git a/tap_shopify/gql_queries.py b/tap_shopify/gql_queries.py
index e72f02d..d36bcee 100644
--- a/tap_shopify/gql_queries.py
+++ b/tap_shopify/gql_queries.py
@@ -190,4 +190,4 @@
       }
     }
   }
-}"""
\ No newline at end of file
+}"""
diff --git a/tap_shopify/tap.py b/tap_shopify/tap.py
index 7273993..5a2ba51 100644
--- a/tap_shopify/tap.py
+++ b/tap_shopify/tap.py
@@ -2,59 +2,17 @@
 
 from __future__ import annotations
 
-from singer_sdk import Tap
-from singer_sdk import typing as th
 from functools import cached_property
-from tap_shopify.gql_queries import schema_query, queries_query
 from typing import Any, Iterable
-from singer_sdk.helpers.jsonpath import extract_jsonpath as jp
-import requests
 import inflection
+import requests
+from singer_sdk import Tap
+from singer_sdk import typing as th
+from singer_sdk.helpers.jsonpath import extract_jsonpath as jp
 
 from tap_shopify.client import ShopifyStream
-
-# from tap_shopify.client_bulk import shopifyBulkStream
-# from tap_shopify.client_gql import shopifyGqlStream
-
-
-# class ShopifyStream(shopifyGqlStream, shopifyBulkStream):
-#     """Define base based on the type GraphQL or Bulk."""
-
-#     def parse_response(self, response: requests.Response) -> Iterable[dict]:
-#         """Parse the response and return an iterator of result rows."""
-#         if self.config.get("bulk"):
-#             return shopifyBulkStream.parse_response(self, response)
-#         else:
-#             return shopifyGqlStream.parse_response(self, response)
-
-#     @cached_property
-#     def query(self) -> str:
-#         """Set or return the GraphQL query string."""
-#         if self.config.get("bulk"):
-#             return shopifyBulkStream.query(self)
-#         else:
-#             return shopifyGqlStream.query(self)
-
-    # def evaluate_query(self) -> dict:
-    #     query = shopifyGqlStream.query(self)
-    #     params = self.get_url_params(None, None)
-    #     query = self.query.lstrip()
-    #     request_data = {
-    #         "query": query,
-    #         "variables": params,
-    #     }
-
-    #     response = requests.request(
-    #         method=self.rest_method,
-    #         url=self.get_url({}),
-    #         params=params,
-    #         headers=self.http_headers,
-    #         json=request_data,
-    #     )
-
-    #     return response
+from tap_shopify.gql_queries import queries_query, schema_query
 
 
 class TapShopify(Tap):
@@ -92,7 +50,7 @@ class TapShopify(Tap):
         th.Property(
             "bulk",
             th.BooleanType,
-            default=False,
+            default=True,
             description="To use the bulk API or not.",
         ),
         th.Property(
@@ -172,6 +130,8 @@ def get_type_fields(self, gql_type: str) -> list[dict]:
             field_kind = next(jp("type.ofType.kind", field), None)
             if type_kind == "NON_NULL" and field_kind == "SCALAR":
                 filtered_fields.append(field)
+            elif type_kind == "NON_NULL" and field_kind == "OBJECT":
+                filtered_fields.append(field)
 
         return {f["name"]: f["type"]["ofType"] for f in filtered_fields}