Skip to content

Commit

Permalink
Add as_dataframe argument to read_csv() (#421)
Browse files Browse the repository at this point in the history
* Add always_return_dataframe argument to read_csv()

* Improve docstring

* Use as_dataframe for argument name
  • Loading branch information
hagenw authored Apr 25, 2024
1 parent 7b6b683 commit dbc7285
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
30 changes: 19 additions & 11 deletions audformat/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,7 @@ def map_language(language: str) -> str:

def read_csv(
*args,
as_dataframe: bool = False,
**kwargs,
) -> typing.Union[pd.Index, pd.Series, pd.DataFrame]:
r"""Read object from CSV file.
Expand All @@ -1269,11 +1270,13 @@ def read_csv(
conform to :ref:`table specifications <data-tables:Tables>`.
If conversion is not possible, an error is raised.
See :meth:`pandas.read_csv` for supported arguments.
Args:
*args: arguments
**kwargs: keyword arguments
*args: arguments passed on to :func:`pandas.read_csv`
as_dataframe: if ``True``,
a dataframe is always returned.
If ``False``,
a dataframe is only returned for data with two or more columns,
a series for data with one column,
and an index for data with zero columns
**kwargs: keyword arguments passed on to :func:`pandas.read_csv`
Returns:
object conforming to :ref:`table specifications <data-tables:Tables>`
Expand All @@ -1283,19 +1286,24 @@ def read_csv(
:ref:`table specifications <data-tables:Tables>`
Examples:
>>> from io import StringIO
>>> string = StringIO(
... '''file,start,end,value
>>> string = '''file,start,end,value
... f1,00:00:00,00:00:01,0.0
... f1,00:00:01,00:00:02,1.0
... f2,00:00:02,00:00:03,2.0'''
... )
>>> read_csv(string)
>>> with open("file.csv", "w") as file:
... _ = file.write(string)
>>> read_csv("file.csv")
file start end
f1 0 days 00:00:00 0 days 00:00:01 0.0
0 days 00:00:01 0 days 00:00:02 1.0
f2 0 days 00:00:02 0 days 00:00:03 2.0
Name: value, dtype: float64
>>> read_csv("file.csv", as_dataframe=True)
value
file start end
f1 0 days 00:00:00 0 days 00:00:01 0.0
0 days 00:00:01 0 days 00:00:02 1.0
f2 0 days 00:00:02 0 days 00:00:03 2.0
"""
frame = pd.read_csv(*args, **kwargs)
Expand All @@ -1322,11 +1330,11 @@ def read_csv(
index = segmented_index(files, starts=starts, ends=ends)
frame.drop(drop, axis="columns", inplace=True)

if len(frame.columns) == 0:
if len(frame.columns) == 0 and not as_dataframe:
return index

frame = frame.set_index(index)
if len(frame.columns) == 1:
if len(frame.columns) == 1 and not as_dataframe:
return frame[frame.columns[0]]
else:
return frame
Expand Down
8 changes: 8 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,6 +1548,14 @@ def test_read_csv(csv, result):
pd.testing.assert_series_equal(obj, result)
else:
pd.testing.assert_frame_equal(obj, result)
# Request dataframe as return type
csv.seek(0) # rewind string file object
obj = audformat.utils.read_csv(csv, as_dataframe=True)
if isinstance(result, pd.Index):
result = pd.DataFrame([], columns=[], index=result)
elif isinstance(result, pd.Series):
result = result.to_frame()
pd.testing.assert_frame_equal(obj, result)


@pytest.mark.parametrize(
Expand Down

0 comments on commit dbc7285

Please sign in to comment.