Skip to content

Commit

Permalink
fix(ingestion/gcs): fix stateful ingestion for GCS source
Browse files Browse the repository at this point in the history
Remove pipeline name before passing context to equivalent s3 source to avoid error "Checkpointing provider DatahubIngestionCheckpointingProvider already registered."
  • Loading branch information
josges committed Jan 9, 2025
1 parent 42b2cd3 commit 86158d6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import logging
from typing import Dict, Iterable, List, Optional
from urllib.parse import unquote
Expand Down Expand Up @@ -88,6 +89,7 @@ def __init__(self, config: GCSSourceConfig, ctx: PipelineContext):
super().__init__(config, ctx)
self.config = config
self.report = GCSSourceReport()
self.platform: str = PLATFORM_GCS
self.s3_source = self.create_equivalent_s3_source(ctx)

@classmethod
Expand Down Expand Up @@ -135,7 +137,9 @@ def create_equivalent_s3_path_specs(self):

def create_equivalent_s3_source(self, ctx: PipelineContext) -> S3Source:
config = self.create_equivalent_s3_config()
return self.s3_source_overrides(S3Source(config, ctx))
s3_ctx = copy.deepcopy(ctx)
s3_ctx.pipeline_name = None
return self.s3_source_overrides(S3Source(config, s3_ctx))

def s3_source_overrides(self, source: S3Source) -> S3Source:
source.source_config.platform = PLATFORM_GCS
Expand Down
7 changes: 6 additions & 1 deletion metadata-ingestion/tests/unit/test_gcs_source.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from unittest import mock

import pytest
from pydantic import ValidationError

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.data_lake_common.data_lake_utils import PLATFORM_GCS
from datahub.ingestion.source.gcs.gcs_source import GCSSource


def test_gcs_source_setup():
ctx = PipelineContext(run_id="test-gcs")
graph = mock.MagicMock(spec=DataHubGraph)
ctx = PipelineContext(run_id="test-gcs", graph=graph, pipeline_name="test-gcs")

# Baseline: valid config
source: dict = {
Expand All @@ -18,6 +22,7 @@ def test_gcs_source_setup():
}
],
"credential": {"hmac_access_id": "id", "hmac_access_secret": "secret"},
"stateful_ingestion": {"enabled": "true"},
}
gcs = GCSSource.create(source, ctx)
assert gcs.s3_source.source_config.platform == PLATFORM_GCS
Expand Down

0 comments on commit 86158d6

Please sign in to comment.