diff --git a/rdl/data_sources/AWSLambdaDataSource.py b/rdl/data_sources/AWSLambdaDataSource.py index 25d767a..4b25ebd 100644 --- a/rdl/data_sources/AWSLambdaDataSource.py +++ b/rdl/data_sources/AWSLambdaDataSource.py @@ -29,6 +29,7 @@ def __init__(self, connection_string, logger=None): .split(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR) ) self.aws_lambda_client = boto3.client("lambda") + self.aws_s3_client = boto3.client("s3") @staticmethod def can_handle_connection_string(connection_string): @@ -43,8 +44,9 @@ def get_connection_string_prefix(): return AWSLambdaDataSource.CONNECTION_STRING_PREFIX def get_table_info(self, table_config, last_known_sync_version): - column_names, last_sync_version, sync_version, full_refresh_required, data_changed_since_last_sync \ - = self.__get_table_info(table_config, last_known_sync_version) + column_names, last_sync_version, sync_version, full_refresh_required, data_changed_since_last_sync = self.__get_table_info( + table_config, last_known_sync_version + ) columns_in_database = column_names change_tracking_info = ChangeTrackingInfo( last_sync_version=last_sync_version, @@ -91,11 +93,13 @@ def __get_table_info(self, table_config, last_known_sync_version): result = self.__invoke_lambda(pay_load) - return result["ColumnNames"], \ - result["LastSyncVersion"], \ - result["CurrentSyncVersion"], \ - result["FullRefreshRequired"], \ - result["DataChangedSinceLastSync"] + return ( + result["ColumnNames"], + result["LastSyncVersion"], + result["CurrentSyncVersion"], + result["FullRefreshRequired"], + result["DataChangedSinceLastSync"], + ) def __get_table_data( self, @@ -116,23 +120,31 @@ def __get_table_data( "BatchSize": batch_config["size"], "LastSyncVersion": change_tracking_info.last_sync_version, "FullRefresh": full_refresh, - "ColumnNames": list(map(lambda cfg: cfg['source_name'], columns_config)), + "ColumnNames": list( + map(lambda cfg: cfg["source_name"], columns_config) + ), "PrimaryKeyColumnNames": table_config["primary_keys"], "LastBatchPrimaryKeys": [ - {"Key": k, "Value": v} for k, v in batch_key_tracker.bookmarks.items() + {"Key": k, "Value": v} + for k, v in batch_key_tracker.bookmarks.items() ], }, } result = self.__invoke_lambda(pay_load) + command_result = self.aws_s3_client.get_object( + Bucket=result["DataBucketName"], Key=result["DataKey"] + ) - return result["ColumnNames"], result["Data"] + data = json.loads(command_result["Body"].read()) + + return result["ColumnNames"], data def __get_data_frame(self, data: [[]], column_names: []): return pandas.DataFrame(data=data, columns=column_names) def __invoke_lambda(self, pay_load): - self.logger.debug('\nRequest being sent to Lambda:') + self.logger.debug("\nRequest being sent to Lambda:") self.logger.debug(pay_load) lambda_response = self.aws_lambda_client.invoke( @@ -142,24 +154,28 @@ def __invoke_lambda(self, pay_load): Payload=json.dumps(pay_load).encode(), ) - response_status_code = int(lambda_response['StatusCode']) + response_status_code = int(lambda_response["StatusCode"]) response_function_error = lambda_response.get("FunctionError") - self.logger.debug('\nResponse received from Lambda:') + self.logger.debug("\nResponse received from Lambda:") self.logger.debug(f'Response - StatusCode = "{response_status_code}"') self.logger.debug(f'Response - FunctionError = "{response_function_error}"') - response_payload = json.loads(lambda_response['Payload'].read()) + response_payload = json.loads(lambda_response["Payload"].read()) if response_status_code != 200 or response_function_error: - self.logger.error(F'Error in response from aws lambda {self.connection_data["function"]}') - self.logger.error(f'Response - Status Code = {response_status_code}') - self.logger.error(f'Response - Error Function = {response_function_error}') - self.logger.error(f'Response - Error Details:') + self.logger.error( + f'Error in response from aws lambda {self.connection_data["function"]}' + ) + self.logger.error(f"Response - Status Code = {response_status_code}") + self.logger.error(f"Response - Error Function = {response_function_error}") + self.logger.error(f"Response - Error Details:") # the below is risky as it may contain actual data if this line is reached in case of a successful result # however, the same Payload field is used to return actual error details in case of real errors # i.e. StatusCode is 200 (since AWS could invoke the lambda) # BUT the lambda barfed with an error and therefore the FunctionError would not be None self.logger.error(response_payload) - raise Exception('Error received when invoking AWS Lambda. See logs for further details.') + raise Exception( + "Error received when invoking AWS Lambda. See logs for further details." + ) return response_payload