Skip to content

Commit

Permalink
Spark write csv/parquet
Browse files Browse the repository at this point in the history
  • Loading branch information
Pee Tankulrat authored and Pee Tankulrat committed Nov 28, 2022
1 parent bb92182 commit 100a5a5
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 64 deletions.
4 changes: 2 additions & 2 deletions pii_anonymizer/spark/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pii_anonymizer.spark.acquire.csv_parser import CsvParser
from pii_anonymizer.spark.analyze.detectors.pii_detector import PIIDetector
from pii_anonymizer.spark.constants import ACQUIRE, REPORT
from pii_anonymizer.spark.write.csv_writer import CsvWriter
from pii_anonymizer.spark.write.output_writer import OutputWriter
from pii_anonymizer.common.get_args import get_args


Expand All @@ -37,7 +37,7 @@ def run(self):
print("NO PII VALUES WERE FOUND!")
else:
report_generator.generate(results_df=pii_analysis_report)
CsvWriter(spark, config=self.config).write_csv(df=redacted_data_frame)
OutputWriter(spark, config=self.config).write(df=redacted_data_frame)


if __name__ == "__main__":
Expand Down
32 changes: 0 additions & 32 deletions pii_anonymizer/spark/write/csv_writer.py

This file was deleted.

55 changes: 55 additions & 0 deletions pii_anonymizer/spark/write/output_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from pyspark.sql import SparkSession, DataFrame
from pii_anonymizer.common.constants import (
ANONYMIZE,
ACQUIRE,
OUTPUT_FILE_PATH,
FILE_PATH,
OUTPUT_FILE_FORMAT,
)

# Output formats the writer knows how to produce.
output_file_format = ["csv", "parquet"]

# Message raised when the configured format is not one of the above.
output_file_format_err_msg = "Output file format must be " + " or ".join(
    output_file_format
)


class OutputWriter:
def __init__(self, spark: SparkSession, config):
self.__validate_config(config)
self.__validate_output_format(config)
self.output_path = config[ANONYMIZE][OUTPUT_FILE_PATH]
self.input_file_name = config[ACQUIRE][FILE_PATH]
self.spark = spark

def __validate_config(self, config):
if (
ANONYMIZE not in config
or not config[ANONYMIZE]
or OUTPUT_FILE_PATH not in config[ANONYMIZE]
or not config[ANONYMIZE][OUTPUT_FILE_PATH]
):
raise ValueError(
"Config 'output_file_path' needs to be provided for parsing"
)

def __validate_output_format(self, config):
self.output_format = config[ANONYMIZE].get(OUTPUT_FILE_FORMAT, "csv")
if self.output_format not in output_file_format:
raise ValueError(output_file_format_err_msg)

def get_output_file_path(self):
file_name = self.input_file_name.split("/")[-1]
file_name_no_extension = file_name.split(".")[0]
result = f"{self.output_path}/{file_name_no_extension}_anonymized"
return result

def write(self, df: DataFrame):
match self.output_format:
case "csv":
df.write.mode("overwrite").option("header", "true").csv(
self.get_output_file_path()
)
case "parquet":
df.write.mode("overwrite").option("header", "true").parquet(
self.get_output_file_path()
)
30 changes: 0 additions & 30 deletions pii_anonymizer/spark/write/tests/test_csv_writer.py

This file was deleted.

78 changes: 78 additions & 0 deletions pii_anonymizer/spark/write/tests/test_output_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from unittest import TestCase
from unittest.mock import MagicMock, call
from pyspark.sql import SparkSession
from pii_anonymizer.spark.write.output_writer import (
OutputWriter,
output_file_format_err_msg,
)


class TestOutputWriter(TestCase):
def setUp(self) -> None:
self.SPARK = (
SparkSession.builder.master("local").appName("Test CsvWriter").getOrCreate()
)

def test_invalid_config_gets_caught_during_initialization(self):
context = {}
with self.assertRaises(ValueError) as ve:
OutputWriter(self.SPARK, config=context)
self.assertEqual(
str(ve.exception),
"Config 'output_file_path' needs to be provided for parsing",
)

def test_invalid_output_format_gets_caught_during_initialization(self):
context = context = {
"acquire": {"file_path": "/anonymizer/test_data.csv", "delimiter": ","},
"anonymize": {
"output_file_path": "/anonymizer/output",
"output_file_format": "invalid_format",
},
}
with self.assertRaises(ValueError) as ve:
OutputWriter(self.SPARK, config=context)
self.assertEqual(str(ve.exception), output_file_format_err_msg)

def test_correct_output_path_is_generated(self):
context = {
"acquire": {"file_path": "/anonymizer/test_data.csv", "delimiter": ","},
"anonymize": {"output_file_path": "/anonymizer/output"},
}
input_file_name = "test_data"
output_directory = "/anonymizer/output"
expected = f"{output_directory}/{input_file_name}_anonymized"
writer = OutputWriter(spark=self.SPARK, config=context)
self.assertEqual(writer.get_output_file_path(), expected)

def test_writer_call_correct_methods_on_write(self):
output_format_list = ["csv", "parquet"]
for output_format in output_format_list:
with self.subTest():
context = {
"acquire": {
"file_path": "/anonymizer/test_data.csv",
"delimiter": ",",
},
"anonymize": {
"output_file_path": "./output",
"output_file_format": output_format,
},
}
df = MagicMock()
OutputWriter(self.SPARK, config=context).write(df)
match output_format:
case "csv":
kall = (
call.write.mode("overwrite")
.option("header", "true")
.csv("./output/test_data_anonymized")
)
self.assertEqual(df.mock_calls, kall.call_list())
case "parquet":
kall = (
call.write.mode("overwrite")
.option("header", "true")
.parquet("./output/test_data_anonymized")
)
self.assertEqual(df.mock_calls, kall.call_list())

1 comment on commit 100a5a5

@pee-tw
Copy link
Collaborator

@pee-tw pee-tw commented on 100a5a5 Dec 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.