Skip to content

Commit

Permalink
Check length of column and reconstructed column to ensure they match
Browse files Browse the repository at this point in the history
  • Loading branch information
botirk38 committed Aug 14, 2024
1 parent a39e018 commit 14454f4
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions huggingface_pipelines/metric_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The updated batch with the metric scores, predictions, and references.
"""

# Check if the lengths of columns and reconstructed_columns match
if len(self.config.columns) != len(self.config.reconstructed_columns):
raise ValueError(
f"Mismatch in number of columns ({len(self.config.columns)}) "
f"and reconstructed columns ({len(self.config.reconstructed_columns)})"
)

for column, reconstructed_column in zip(
self.config.columns, self.config.reconstructed_columns
):
Expand All @@ -81,8 +89,7 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(original_data[0], list):
original_data = [" ".join(item) for item in original_data]
if isinstance(reconstructed_data[0], list):
reconstructed_data = [" ".join(item)
for item in reconstructed_data]
reconstructed_data = [" ".join(item) for item in reconstructed_data]

references = [[ref.split()] for ref in original_data]
predictions = [pred.split() for pred in reconstructed_data]
Expand Down

0 comments on commit 14454f4

Please sign in to comment.