diff --git a/nannyml/base.py b/nannyml/base.py index f81241ef..9cc61326 100644 --- a/nannyml/base.py +++ b/nannyml/base.py @@ -8,7 +8,7 @@ import copy import logging from abc import ABC, abstractmethod -from typing import Generic, Iterable, List, Optional, Tuple, TypeVar, Union, overload +from typing import Generic, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union, overload import numpy as np import pandas as pd @@ -616,8 +616,9 @@ def _raise_exception_for_negative_values(column: pd.Series): ) -def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]: - """Remove rows of dataframe containing NaN values on selected columns. +def _common_nan_removal_dataframe(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]: + """ + Remove rows of dataframe containing NaN values on selected columns. Parameters ---------- @@ -634,13 +635,88 @@ def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple empty: Boolean whether the resulting data are contain any rows (false) or not (true) """ - # If we want target and it's not available we get None if not set(selected_columns) <= set(data.columns): raise InvalidArgumentsException( f"Selected columns: {selected_columns} not all present in provided data columns {list(data.columns)}" ) df = data.dropna(axis=0, how='any', inplace=False, subset=selected_columns).reset_index(drop=True).infer_objects() - empty: bool = False - if df.shape[0] == 0: - empty = True - return (df, empty) + empty: bool = df.shape[0] == 0 + return df, empty + + +def _common_nan_removal_ndarrays(data: Sequence[np.array], selected_columns: List[int]) -> Tuple[pd.DataFrame, bool]: + """ + Remove rows of numpy arrays containing NaN values on selected columns. + + Parameters + ---------- + data: Sequence[np.array] + Sequence containing numpy arrays. + selected_columns: List[int] + List containing the indices of column numbers + + Returns + ------- + df: + Dataframe with rows containing NaN's on selected_columns removed. The columns of the DataFrame are the + numpy ndarrays in the same order as the input data. + empty: + Boolean whether the resulting data are contain any rows (false) or not (true) + """ + # Check if all selected_columns indices are valid for the first ndarray + if not all(col < len(data) for col in selected_columns): + raise InvalidArgumentsException( + f"Selected columns: {selected_columns} not all present in provided data columns with shape {data[0].shape}" + ) + + # Convert the numpy ndarrays to a pandas dataframe + df = pd.DataFrame({f'col_{i}': col for i, col in enumerate(data)}) + + # Use the dataframe function to remove NaNs + selected_columns_names = [df.columns[col] for col in selected_columns] + result, empty = _common_nan_removal_dataframe(df, selected_columns_names) + + return result, empty + + +@overload +def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]: + ... + + +@overload +def common_nan_removal(data: Sequence[np.array], selected_columns: List[int]) -> Tuple[pd.DataFrame, bool]: + ... + + +def common_nan_removal( + data: Union[pd.DataFrame, Sequence[np.array]], selected_columns: Union[List[str], List[int]] +) -> Tuple[pd.DataFrame, bool]: + """ + Wrapper function to handle both pandas DataFrame and sequences of numpy ndarrays. + + Parameters + ---------- + data: Union[pd.DataFrame, Sequence[np.array]] + Pandas dataframe or sequence of numpy ndarrays containing data. + selected_columns: Union[List[str], List[int]] + List containing the column names or indices + + Returns + ------- + result: + Dataframe with rows containing NaN's on selected columns removed. All columns of original + dataframe or ndarrays are being returned. + empty: + Boolean whether the resulting data contains any rows (false) or not (true) + """ + if isinstance(data, pd.DataFrame): + if not all(isinstance(col, str) for col in selected_columns): + raise TypeError("When data is a pandas DataFrame, selected_columns should be a list of strings.") + return _common_nan_removal_dataframe(data, selected_columns) # type: ignore + elif isinstance(data, Sequence) and all(isinstance(arr, np.ndarray) for arr in data): + if not all(isinstance(col, int) for col in selected_columns): + raise TypeError("When data is a sequence of numpy ndarrays, selected_columns should be a list of integers.") + return _common_nan_removal_ndarrays(data, selected_columns) # type: ignore + else: + raise TypeError("Data should be either a pandas DataFrame or a sequence of numpy ndarrays.") diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 00000000..d553f725 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,82 @@ +import numpy as np +import pandas as pd +import pytest + +from nannyml.base import common_nan_removal +from nannyml.exceptions import InvalidArgumentsException + + +def test_common_nan_removal_dataframe(): + data = pd.DataFrame({'A': [1, 2, np.nan, 4], 'B': [5, np.nan, 7, 8], 'C': [9, 10, 11, np.nan]}) + selected_columns = ['A', 'B'] + df_cleaned, is_empty = common_nan_removal(data, selected_columns) + + expected_df = pd.DataFrame({'A': [1, 4], 'B': [5, 8], 'C': [9, np.nan]}).reset_index(drop=True) + + pd.testing.assert_frame_equal(df_cleaned, expected_df, check_dtype=False) # ignore types because of infer_objects + assert not is_empty + + +def test_common_nan_removal_dataframe_all_nan(): + data = pd.DataFrame({'A': [np.nan, np.nan], 'B': [np.nan, np.nan], 'C': [np.nan, np.nan]}) + selected_columns = ['A', 'B'] + df_cleaned, is_empty = common_nan_removal(data, selected_columns) + + expected_df = pd.DataFrame(columns=['A', 'B', 'C']) + + pd.testing.assert_frame_equal(df_cleaned, expected_df, check_index_type=False, check_dtype=False) + assert is_empty + + +def test_common_nan_removal_arrays(): + data = [np.array([1, 5, 9]), np.array([2, np.nan, 10]), np.array([np.nan, 7, 11]), np.array([4, 8, np.nan])] + selected_columns_indices = [0, 1] # Corresponds to columns 'A' and 'B' + + df_cleaned, is_empty = common_nan_removal(data, selected_columns_indices) + + expected_df = pd.DataFrame( + { + 'col_0': [1, 9], + 'col_1': [2, 10], + 'col_2': [np.nan, 11], + 'col_3': [4, np.nan], + } + ).reset_index(drop=True) + + pd.testing.assert_frame_equal(df_cleaned, expected_df, check_dtype=False) + assert not is_empty + + +def test_common_nan_removal_arrays_all_nan(): + data = [ + np.array([np.nan, np.nan]), + np.array([np.nan, np.nan]), + np.array([np.nan, np.nan]), + ] + selected_columns_indices = [0, 1] # Corresponds to columns 'A' and 'B' + + df_cleaned, is_empty = common_nan_removal(data, selected_columns_indices) + + expected_df = pd.DataFrame(columns=['col_0', 'col_1', 'col_2']) + + pd.testing.assert_frame_equal(df_cleaned, expected_df, check_index_type=False, check_dtype=False) + assert is_empty + + +def test_invalid_dataframe_columns(): + data = pd.DataFrame({'A': [1, 2, np.nan, 4], 'B': [5, np.nan, 7, 8], 'C': [9, 10, 11, np.nan]}) + selected_columns = ['A', 'D'] # 'D' does not exist + with pytest.raises(InvalidArgumentsException): + common_nan_removal(data, selected_columns) + + +def test_invalid_array_columns(): + data = [ + np.array([np.nan, np.nan]), + np.array([np.nan, np.nan]), + np.array([np.nan, np.nan]), + ] + selected_columns_indices = [0, 3] # Index 3 does not exist in ndarray + + with pytest.raises(InvalidArgumentsException): + common_nan_removal(data, selected_columns_indices)