Skip to content

Commit

Permalink
Merge pull request #388 from NannyML/feature/extended-common-nan-support
Browse files Browse the repository at this point in the history
Add support for `np.array` in the `common_nan_removal` function
  • Loading branch information
nnansters authored May 16, 2024
2 parents 9a1760f + 107913f commit 27b713d
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 8 deletions.
92 changes: 84 additions & 8 deletions nannyml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import copy
import logging
from abc import ABC, abstractmethod
from typing import Generic, Iterable, List, Optional, Tuple, TypeVar, Union, overload
from typing import Generic, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -616,8 +616,9 @@ def _raise_exception_for_negative_values(column: pd.Series):
)


def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]:
"""Remove rows of dataframe containing NaN values on selected columns.
def _common_nan_removal_dataframe(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]:
"""
Remove rows of dataframe containing NaN values on selected columns.
Parameters
----------
Expand All @@ -634,13 +635,88 @@ def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple
empty:
Boolean indicating whether the resulting data is empty: True when no rows remain, False otherwise.
"""
# If we want target and it's not available we get None
if not set(selected_columns) <= set(data.columns):
raise InvalidArgumentsException(
f"Selected columns: {selected_columns} not all present in provided data columns {list(data.columns)}"
)
df = data.dropna(axis=0, how='any', inplace=False, subset=selected_columns).reset_index(drop=True).infer_objects()
empty: bool = False
if df.shape[0] == 0:
empty = True
return (df, empty)
empty: bool = df.shape[0] == 0
return df, empty


def _common_nan_removal_ndarrays(data: Sequence[np.array], selected_columns: List[int]) -> Tuple[pd.DataFrame, bool]:
    """
    Remove rows of numpy arrays containing NaN values on selected columns.

    Every array in ``data`` becomes one column of a DataFrame, named ``col_<position>``. Rows holding a NaN
    in any of the selected columns are then dropped by ``_common_nan_removal_dataframe``.

    Parameters
    ----------
    data: Sequence[np.array]
        Sequence containing numpy arrays. Each array is treated as a single column.
    selected_columns: List[int]
        Positions, within ``data``, of the arrays to check for NaN values. Negative positions count from
        the end, as with regular Python sequence indexing.

    Returns
    -------
    df:
        Dataframe with rows containing NaN's on selected_columns removed. The columns of the DataFrame are the
        numpy ndarrays in the same order as the input data.
    empty:
        Boolean indicating whether the resulting data is empty: True when no rows remain, False otherwise.

    Raises
    ------
    InvalidArgumentsException
        When a selected position does not refer to one of the provided arrays.
    """
    n_columns = len(data)

    # Each array is one column, so a position is valid when it indexes into `data` itself.
    # The lower bound matters: an out-of-range negative position would otherwise slip through
    # and surface later as a bare IndexError when looking up the column name.
    if not all(-n_columns <= col < n_columns for col in selected_columns):
        # Report the number of columns rather than the shape of the first array: the latter says
        # nothing about which positions are valid and cannot be computed when `data` is empty.
        raise InvalidArgumentsException(
            f"Selected columns: {selected_columns} not all present in provided data "
            f"consisting of {n_columns} columns"
        )

    # Convert the numpy ndarrays to a pandas dataframe, one generated column name per array
    df = pd.DataFrame({f'col_{i}': col for i, col in enumerate(data)})

    # Translate positions into the generated column names and reuse the dataframe implementation
    selected_columns_names = [df.columns[col] for col in selected_columns]
    return _common_nan_removal_dataframe(df, selected_columns_names)


@overload
def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]:
    ...


@overload
def common_nan_removal(data: Sequence[np.array], selected_columns: List[int]) -> Tuple[pd.DataFrame, bool]:
    ...


def common_nan_removal(
    data: Union[pd.DataFrame, Sequence[np.array]], selected_columns: Union[List[str], List[int]]
) -> Tuple[pd.DataFrame, bool]:
    """
    Remove rows containing NaN values on selected columns.

    Wrapper function to handle both pandas DataFrame and sequences of numpy ndarrays. The input is validated
    here and the actual work is delegated to the matching private implementation.

    Parameters
    ----------
    data: Union[pd.DataFrame, Sequence[np.array]]
        Pandas dataframe or sequence of numpy ndarrays containing data.
    selected_columns: Union[List[str], List[int]]
        List containing the column names (for a DataFrame) or the column indices (for a sequence of
        ndarrays). Indices may be Python integers or numpy integer scalars.

    Returns
    -------
    result:
        Dataframe with rows containing NaN's on selected columns removed. All columns of original
        dataframe or ndarrays are being returned.
    empty:
        Boolean whether the resulting data contains any rows (false) or not (true)

    Raises
    ------
    TypeError
        When ``data`` is neither a DataFrame nor a sequence of ndarrays, or when the entries of
        ``selected_columns`` do not match the kind of ``data`` that was provided.
    """
    if isinstance(data, pd.DataFrame):
        if not all(isinstance(col, str) for col in selected_columns):
            raise TypeError("When data is a pandas DataFrame, selected_columns should be a list of strings.")
        return _common_nan_removal_dataframe(data, selected_columns)  # type: ignore

    if isinstance(data, Sequence) and all(isinstance(arr, np.ndarray) for arr in data):
        # Indices frequently originate from numpy (np.arange, np.where, ...). Those are numpy integer
        # scalars, which are not instances of the builtin int, so accept them explicitly.
        if not all(isinstance(col, (int, np.integer)) for col in selected_columns):
            raise TypeError("When data is a sequence of numpy ndarrays, selected_columns should be a list of integers.")
        # Normalise to builtin ints so the implementation only ever deals with one integer type.
        column_indices = [int(col) for col in selected_columns]
        return _common_nan_removal_ndarrays(data, column_indices)  # type: ignore

    raise TypeError("Data should be either a pandas DataFrame or a sequence of numpy ndarrays.")
82 changes: 82 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import numpy as np
import pandas as pd
import pytest

from nannyml.base import common_nan_removal
from nannyml.exceptions import InvalidArgumentsException


def test_common_nan_removal_dataframe():
    """Rows with a NaN in a selected column are dropped; NaNs in unselected columns are kept."""
    frame = pd.DataFrame(
        {
            'A': [1, 2, np.nan, 4],
            'B': [5, np.nan, 7, 8],
            'C': [9, 10, 11, np.nan],
        }
    )

    cleaned, is_empty = common_nan_removal(frame, ['A', 'B'])

    # Only the first and the last row are complete on both 'A' and 'B'; the NaN in 'C' is not selected.
    expected = pd.DataFrame({'A': [1, 4], 'B': [5, 8], 'C': [9, np.nan]})

    # dtypes are not compared because the function under test applies infer_objects
    pd.testing.assert_frame_equal(cleaned, expected, check_dtype=False)
    assert not is_empty


def test_common_nan_removal_dataframe_all_nan():
    """A DataFrame holding only NaNs is reduced to zero rows and reported as empty."""
    column_names = ['A', 'B', 'C']
    frame = pd.DataFrame({name: [np.nan, np.nan] for name in column_names})

    cleaned, is_empty = common_nan_removal(frame, ['A', 'B'])

    # All columns survive, but no rows do.
    expected = pd.DataFrame(columns=column_names)

    pd.testing.assert_frame_equal(cleaned, expected, check_index_type=False, check_dtype=False)
    assert is_empty


def test_common_nan_removal_arrays():
    """Each ndarray is treated as one column; rows with a NaN in a selected column are dropped."""
    arrays = [
        np.array([1, 5, 9]),
        np.array([2, np.nan, 10]),
        np.array([np.nan, 7, 11]),
        np.array([4, 8, np.nan]),
    ]

    # Positions 0 and 1 select the first two arrays, which end up as 'col_0' and 'col_1'.
    cleaned, is_empty = common_nan_removal(arrays, [0, 1])

    # Only the middle row has a NaN in a selected column, so the first and last rows remain.
    expected = pd.DataFrame(
        {
            'col_0': [1, 9],
            'col_1': [2, 10],
            'col_2': [np.nan, 11],
            'col_3': [4, np.nan],
        }
    )

    pd.testing.assert_frame_equal(cleaned, expected, check_dtype=False)
    assert not is_empty


def test_common_nan_removal_arrays_all_nan():
    """Arrays holding only NaNs are reduced to zero rows and reported as empty."""
    arrays = [np.array([np.nan, np.nan]) for _ in range(3)]

    # Positions 0 and 1 select the first two arrays, which end up as 'col_0' and 'col_1'.
    cleaned, is_empty = common_nan_removal(arrays, [0, 1])

    # One generated column per input array, without any rows left.
    expected = pd.DataFrame(columns=['col_0', 'col_1', 'col_2'])

    pd.testing.assert_frame_equal(cleaned, expected, check_index_type=False, check_dtype=False)
    assert is_empty


def test_invalid_dataframe_columns():
    """Selecting a column name that is missing from the DataFrame raises InvalidArgumentsException."""
    frame = pd.DataFrame(
        {
            'A': [1, 2, np.nan, 4],
            'B': [5, np.nan, 7, 8],
            'C': [9, 10, 11, np.nan],
        }
    )
    columns_with_unknown_name = ['A', 'D']  # the frame has no column 'D'

    with pytest.raises(InvalidArgumentsException):
        common_nan_removal(frame, columns_with_unknown_name)


def test_invalid_array_columns():
    """Selecting a position beyond the number of provided arrays raises InvalidArgumentsException."""
    arrays = [np.array([np.nan, np.nan]) for _ in range(3)]
    columns_with_unknown_position = [0, 3]  # only positions 0, 1 and 2 exist

    with pytest.raises(InvalidArgumentsException):
        common_nan_removal(arrays, columns_with_unknown_position)

0 comments on commit 27b713d

Please sign in to comment.