Skip to content

Commit

Permalink
Load csv tables with pandas if pyarrow fails (#450)
Browse files Browse the repository at this point in the history
* Add failing test

* Fix test

* Fix tests

* Improve comments

* Specify test in class
  • Loading branch information
hagenw authored Jul 12, 2024
1 parent 0f0b069 commit 60cd1ed
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 12 deletions.
63 changes: 51 additions & 12 deletions audformat/core/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from audformat.core.column import Column
from audformat.core.common import HeaderBase
from audformat.core.common import HeaderDict
from audformat.core.common import to_pandas_dtype
from audformat.core.errors import BadIdError
from audformat.core.index import filewise_index
from audformat.core.index import index_type
Expand Down Expand Up @@ -880,25 +881,63 @@ def _load_csv(self, path: str):
than the method applied here.
We first load the CSV file as a :class:`pyarrow.Table`
and convert it to a dataframe afterwards.
If this fails,
we fall back to :func:`pandas.read_csv()`.
Args:
path: path to table, including file extension
"""
levels = list(self._levels_and_dtypes.keys())
columns = list(self.columns.keys())
table = csv.read_csv(
path,
read_options=csv.ReadOptions(
column_names=levels + columns,
skip_rows=1,
),
convert_options=csv.ConvertOptions(
column_types=self._pyarrow_csv_schema(),
strings_can_be_null=True,
),
)
df = self._pyarrow_table_to_dataframe(table, from_csv=True)
try:
table = csv.read_csv(
path,
read_options=csv.ReadOptions(
column_names=levels + columns,
skip_rows=1,
),
convert_options=csv.ConvertOptions(
column_types=self._pyarrow_csv_schema(),
strings_can_be_null=True,
),
)
df = self._pyarrow_table_to_dataframe(table, from_csv=True)
except pa.lib.ArrowInvalid:
# If pyarrow fails to parse the CSV file
# https://github.com/audeering/audformat/issues/449

# Collect csv file columns and data types.
# index
columns_and_dtypes = self._levels_and_dtypes
# columns
for column_id, column in self.columns.items():
if column.scheme_id is not None:
columns_and_dtypes[column_id] = self.db.schemes[
column.scheme_id
].dtype
else:
columns_and_dtypes[column_id] = define.DataType.OBJECT

# Replace data type with converter for dates or timestamps
converters = {}
dtypes_wo_converters = {}
for column, dtype in columns_and_dtypes.items():
if dtype == define.DataType.DATE:
converters[column] = lambda x: pd.to_datetime(x)
elif dtype == define.DataType.TIME:
converters[column] = lambda x: pd.to_timedelta(x)
else:
dtypes_wo_converters[column] = to_pandas_dtype(dtype)

df = pd.read_csv(
path,
usecols=list(columns_and_dtypes.keys()),
dtype=dtypes_wo_converters,
index_col=levels,
converters=converters,
float_precision="round_trip",
)

self._df = df

Expand Down
73 changes: 73 additions & 0 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,79 @@ def test_load(tmpdir):
os.remove(f"{path_no_ext}.{ext}")


class TestLoadBrokenCsv:
    r"""Test loading of malformed csv files.

    If csv files contain a lot of special characters,
    or a different number of columns
    than specified in the database header,
    loading of them should not fail.

    See https://github.com/audeering/audformat/issues/449

    """

    def database_with_hidden_columns(self) -> audformat.Database:
        r"""Database with hidden columns.

        Create a database with hidden columns
        that are stored in csv,
        but not in the header of the table.

        Ensure:

        * it contains an empty table
        * the columns use schemes with time and date data types
        * at least one column has no scheme

        as those cases needed special care with csv files,
        before switching to use ``pyarrow.csv.read_csv()``
        in https://github.com/audeering/audformat/pull/419.

        Returns:
            database

        """
        db = audformat.Database("mydb")
        db.schemes["date"] = audformat.Scheme("date")
        db.schemes["time"] = audformat.Scheme("time")
        db["table"] = audformat.Table(audformat.filewise_index("file.wav"))
        db["table"]["date"] = audformat.Column(scheme_id="date")
        db["table"]["date"].set([pd.to_datetime("2018-10-26")])
        db["table"]["time"] = audformat.Column(scheme_id="time")
        db["table"]["time"].set([pd.Timedelta(1)])
        db["table"]["no-scheme"] = audformat.Column()
        db["table"]["no-scheme"].set(["label"])
        db["empty-table"] = audformat.Table(audformat.filewise_index())
        db["empty-table"]["column"] = audformat.Column()
        # Add a hidden column to the table dataframes,
        # without adding it to the table header
        db["table"].df["hidden"] = ["hidden"]
        db["empty-table"].df["hidden"] = []
        return db

    def test_load_broken_csv(self, tmpdir):
        r"""Test loading a database table from broken csv files.

        Broken csv files
        refer to csv tables,
        that raise an error
        when loading with ``pyarrow.csv.read_csv()``.

        Args:
            tmpdir: tmpdir fixture

        """
        db = self.database_with_hidden_columns()
        build_dir = audeer.mkdir(tmpdir, "build")
        db.save(build_dir, storage_format="csv")
        db_loaded = audformat.Database.load(build_dir, load_data=True)
        assert "table" in db_loaded
        assert "empty-table" in db_loaded
        # Columns present in the csv file
        # but absent from the table header
        # must be dropped on load
        assert "hidden" not in db_loaded["table"].df
        # Fixed: the original asserted on "hidden-column",
        # a name that never exists, making the check vacuous;
        # the hidden column is named "hidden" (see fixture above)
        assert "hidden" not in db_loaded["empty-table"].df


def test_load_old_pickle(tmpdir):
# We have stored string dtype as object dtype before
# and have to fix this when loading old PKL files from cache.
Expand Down

0 comments on commit 60cd1ed

Please sign in to comment.