-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Pee Tankulrat
authored and
Pee Tankulrat
committed
Nov 28, 2022
1 parent
bb92182
commit 100a5a5
Showing
5 changed files
with
135 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from pyspark.sql import SparkSession, DataFrame | ||
from pii_anonymizer.common.constants import ( | ||
ANONYMIZE, | ||
ACQUIRE, | ||
OUTPUT_FILE_PATH, | ||
FILE_PATH, | ||
OUTPUT_FILE_FORMAT, | ||
) | ||
|
||
# Formats the anonymized output can be written in; "csv" is the implicit default.
output_file_format = ["csv", "parquet"]
# Error raised when a config requests a format outside the supported list.
output_file_format_err_msg = "Output file format must be " + " or ".join(
    output_file_format
)
|
||
|
||
class OutputWriter: | ||
def __init__(self, spark: SparkSession, config): | ||
self.__validate_config(config) | ||
self.__validate_output_format(config) | ||
self.output_path = config[ANONYMIZE][OUTPUT_FILE_PATH] | ||
self.input_file_name = config[ACQUIRE][FILE_PATH] | ||
self.spark = spark | ||
|
||
def __validate_config(self, config): | ||
if ( | ||
ANONYMIZE not in config | ||
or not config[ANONYMIZE] | ||
or OUTPUT_FILE_PATH not in config[ANONYMIZE] | ||
or not config[ANONYMIZE][OUTPUT_FILE_PATH] | ||
): | ||
raise ValueError( | ||
"Config 'output_file_path' needs to be provided for parsing" | ||
) | ||
|
||
def __validate_output_format(self, config): | ||
self.output_format = config[ANONYMIZE].get(OUTPUT_FILE_FORMAT, "csv") | ||
if self.output_format not in output_file_format: | ||
raise ValueError(output_file_format_err_msg) | ||
|
||
def get_output_file_path(self): | ||
file_name = self.input_file_name.split("/")[-1] | ||
file_name_no_extension = file_name.split(".")[0] | ||
result = f"{self.output_path}/{file_name_no_extension}_anonymized" | ||
return result | ||
|
||
def write(self, df: DataFrame): | ||
match self.output_format: | ||
case "csv": | ||
df.write.mode("overwrite").option("header", "true").csv( | ||
self.get_output_file_path() | ||
) | ||
case "parquet": | ||
df.write.mode("overwrite").option("header", "true").parquet( | ||
self.get_output_file_path() | ||
) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from unittest import TestCase | ||
from unittest.mock import MagicMock, call | ||
from pyspark.sql import SparkSession | ||
from pii_anonymizer.spark.write.output_writer import ( | ||
OutputWriter, | ||
output_file_format_err_msg, | ||
) | ||
|
||
|
||
class TestOutputWriter(TestCase): | ||
def setUp(self) -> None: | ||
self.SPARK = ( | ||
SparkSession.builder.master("local").appName("Test CsvWriter").getOrCreate() | ||
) | ||
|
||
def test_invalid_config_gets_caught_during_initialization(self): | ||
context = {} | ||
with self.assertRaises(ValueError) as ve: | ||
OutputWriter(self.SPARK, config=context) | ||
self.assertEqual( | ||
str(ve.exception), | ||
"Config 'output_file_path' needs to be provided for parsing", | ||
) | ||
|
||
def test_invalid_output_format_gets_caught_during_initialization(self): | ||
context = context = { | ||
"acquire": {"file_path": "/anonymizer/test_data.csv", "delimiter": ","}, | ||
"anonymize": { | ||
"output_file_path": "/anonymizer/output", | ||
"output_file_format": "invalid_format", | ||
}, | ||
} | ||
with self.assertRaises(ValueError) as ve: | ||
OutputWriter(self.SPARK, config=context) | ||
self.assertEqual(str(ve.exception), output_file_format_err_msg) | ||
|
||
def test_correct_output_path_is_generated(self): | ||
context = { | ||
"acquire": {"file_path": "/anonymizer/test_data.csv", "delimiter": ","}, | ||
"anonymize": {"output_file_path": "/anonymizer/output"}, | ||
} | ||
input_file_name = "test_data" | ||
output_directory = "/anonymizer/output" | ||
expected = f"{output_directory}/{input_file_name}_anonymized" | ||
writer = OutputWriter(spark=self.SPARK, config=context) | ||
self.assertEqual(writer.get_output_file_path(), expected) | ||
|
||
def test_writer_call_correct_methods_on_write(self): | ||
output_format_list = ["csv", "parquet"] | ||
for output_format in output_format_list: | ||
with self.subTest(): | ||
context = { | ||
"acquire": { | ||
"file_path": "/anonymizer/test_data.csv", | ||
"delimiter": ",", | ||
}, | ||
"anonymize": { | ||
"output_file_path": "./output", | ||
"output_file_format": output_format, | ||
}, | ||
} | ||
df = MagicMock() | ||
OutputWriter(self.SPARK, config=context).write(df) | ||
match output_format: | ||
case "csv": | ||
kall = ( | ||
call.write.mode("overwrite") | ||
.option("header", "true") | ||
.csv("./output/test_data_anonymized") | ||
) | ||
self.assertEqual(df.mock_calls, kall.call_list()) | ||
case "parquet": | ||
kall = ( | ||
call.write.mode("overwrite") | ||
.option("header", "true") | ||
.parquet("./output/test_data_anonymized") | ||
) | ||
self.assertEqual(df.mock_calls, kall.call_list()) |
100a5a5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Related to issue #30
thoughtworks-datakind/anonymizer#30