Skip to content

Commit

Permalink
Fix alignment of dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Oct 20, 2023
1 parent 307f6c8 commit 335d7bd
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 96 deletions.
155 changes: 74 additions & 81 deletions audformat/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def get(
splits: limit search to selected splits
Returns:
labels
dataframe with labels
"""

Expand All @@ -472,15 +472,6 @@ def scheme_in_column(scheme_id, column, column_id):
)
)

def clean_y(y):
    """Pass a label series through unchanged.

    Placeholder for NaN removal and dtype normalization.

    TODO: different categorical dtypes can currently collide;
    a real implementation should align those dtypes instead of
    the simple string conversion that was considered here
    (``y.astype('string')``).
    """
    return y

requested_schemes = audeer.to_list(schemes)

if tables is None:
Expand All @@ -491,93 +482,96 @@ def clean_y(y):
if splits is not None:
splits = audeer.to_list(splits)

print(f'{requested_schemes=}')
# Check if requested schemes
# are stored as labels in other schemes
scheme_mappings = []
for scheme_id, scheme in self.schemes.items():
ys = []
for requested_scheme in requested_schemes:

# --- Check if requested scheme is stored as label in other schemes
scheme_mappings = []
for scheme_id, scheme in self.schemes.items():

print(f'{scheme_id=}')
if scheme.uses_table and scheme_id in self.misc_tables:
print('scheme stored in misc table')
# Labels stored as misc table
for column_id, column in self[scheme_id].columns.items():
print(f'{column_id=}')
for requested_scheme in requested_schemes:
if scheme.uses_table and scheme_id in self.misc_tables:
for column_id, column in self[scheme_id].columns.items():
if scheme_in_column(
requested_scheme, column, column_id
requested_scheme,
column,
column_id,
):
scheme_mappings.append(
(scheme_id, column_id)
)
break
elif isinstance(scheme.labels, dict):
scheme_mappings.append((scheme_id, column_id))

# Labels stored in scheme
labels = pd.DataFrame.from_dict(scheme.labels, orient='index')
for requested_scheme in requested_schemes:
print('scheme stored in labels')
print(f'{scheme.labels_as_list=}')
print(f'{labels=}')
elif isinstance(scheme.labels, dict):
labels = pd.DataFrame.from_dict(
scheme.labels,
orient='index',
)
if requested_scheme in labels:
scheme_mappings.append((scheme_id, requested_scheme))
break
print(f'{scheme_mappings=}')
print(f'{requested_schemes=}')

# Get data points for requested schemes
ys = []
for table_id, table in self.tables.items():

# Limit search to selected tables
if table_id not in tables:
continue

# Limit search to selected splits
if (
splits is not None
and table.split_id not in splits
):
continue

print(f'{table_id=}')
for column_id, column in table.columns.items():
# Get series based on scheme in column
if any(
[
scheme_in_column(
requested_scheme,
column,
column_id,
)
for requested_scheme in requested_schemes
]
):
print('Found requested scheme')
y = self[table_id][column_id].get()
ys.append(clean_y(y))
# Get series based on label of scheme
else:
print('Look for mapped schemes')
for (scheme_id, mapping) in scheme_mappings:
print(f'{scheme_id=}')
if scheme_in_column(scheme_id, column, column_id):
print('Found scheme')
y = self[table_id][column_id].get(map=mapping)
ys.append(clean_y(y))

print(f'{ys=}')
# --- Get data for requested schemes
ys_requested_scheme = []
for table_id, table in self.tables.items():

# Limit search to selected tables
if table_id not in tables:
continue

# Limit search to selected splits
if splits is not None and table.split_id not in splits:
continue

for column_id, column in table.columns.items():
# Scheme directly stored in column
if scheme_in_column(requested_scheme, column, column_id):
y = self[table_id][column_id].get()
ys_requested_scheme.append(y)
# Get series based on label of scheme
else:
for (scheme_id, mapping) in scheme_mappings:
if scheme_in_column(scheme_id, column, column_id):
y = self[table_id][column_id].get(map=mapping)
ys_requested_scheme.append(y)

# Ensure we have a common dtype for requested scheme
categorical_dtypes = [
y.dtype for y in ys_requested_scheme
if isinstance(y.dtype, pd.CategoricalDtype)
]
dtypes_of_categories = [
dtype.categories.dtype
for dtype in categorical_dtypes
]
if len(categorical_dtypes) > 0:
if len(set(dtypes_of_categories)) > 1:
raise ValueError(
'All categorical data must have the same dtype.'
)
dtype = dtypes_of_categories[0]
# Convert everything to categorical data
for n, y in enumerate(ys_requested_scheme):
if not isinstance(y.dtype, pd.CategoricalDtype):
ys_requested_scheme[n] = y.astype(
pd.CategoricalDtype(set(y.array.astype(dtype)))
)
# Find union of categorical data
data = [y.array for y in ys_requested_scheme]
data = pd.api.types.union_categoricals(data)
ys_requested_scheme = [
y.astype(data.dtype)
for y in ys_requested_scheme
]

ys += ys_requested_scheme

index = utils.union([y.index for y in ys])
obj = utils.concat(ys, rename=True).loc[index]
print(f'{obj=}')

print(f'{len(obj)}')
if len(obj) == 0:
obj = pd.DataFrame()

if isinstance(obj, pd.Series):
obj = obj.to_frame()

print(f'{obj=}')
# Start with column names matching requested schemes
matching_columns = [
column for column in requested_schemes
Expand All @@ -588,7 +582,6 @@ def clean_y(y):
if column not in requested_schemes
]
obj = obj[matching_columns + additional_columns]
print(f'{obj=}')

return obj

Expand Down
65 changes: 50 additions & 15 deletions tests/test_database_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@ def mono_db(tmpdir):
db.schemes['age'] = audformat.Scheme('int', minimum=0)
db.schemes['height'] = audformat.Scheme('float')
db.schemes['rating'] = audformat.Scheme('int', labels=[0, 1, 2])
db.schemes['speaker.weight'] = audformat.Scheme(
'str',
labels=['low', 'normal', 'high'],
)
db.schemes['winner'] = audformat.Scheme(
'str',
labels={
'w1': {'year': 1995},
'w2': {'year': 1996},
'w3': {'year': 1997},
},
)
db.schemes['weather'] = audformat.Scheme(
Expand All @@ -46,6 +51,8 @@ def mono_db(tmpdir):
db['speaker']['height-with-10y'].set([1.12, 1.45, 1.01])
db['speaker']['current-height'] = audformat.Column(scheme_id='height')
db['speaker']['current-height'].set([1.76, 1.95, 1.80])
db['speaker']['weight'] = audformat.Column(scheme_id='speaker.weight')
db['speaker']['weight'].set(['normal', 'high', 'low'])

index = pd.Index(['today', 'yesterday'], name='day', dtype='string')
db['weather'] = audformat.MiscTable(index)
Expand All @@ -65,6 +72,11 @@ def mono_db(tmpdir):
db['files']['perceived-age'] = audformat.Column(scheme_id='age')
db['files']['perceived-age'].set([25, 34, 45])

index = audformat.filewise_index(['f1.wav'])
db['files.sub'] = audformat.Table(index)
db['files.sub']['speaker'] = audformat.Column(scheme_id='speaker')
db['files.sub']['speaker'].set('s1')

index = audformat.filewise_index(['f1.wav', 'f3.wav'])
db['gender'] = audformat.Table(index)
db['gender']['sex'] = audformat.Column()
Expand All @@ -89,6 +101,8 @@ def mono_db(tmpdir):
db['segments'] = audformat.Table(index)
db['segments']['rating'] = audformat.Column(scheme_id='rating')
db['segments']['rating'].set([1, 1, 2, 2])
db['segments']['winner'] = audformat.Column(scheme_id='winner')
db['segments']['winner'].set(['w1', 'w1', 'w1', 'w1'])

db.save(path)
audformat.testing.create_audio_files(db, channels=1, file_duration='1s')
Expand Down Expand Up @@ -188,9 +202,9 @@ def stereo_db(tmpdir):
pd.concat(
[
pd.Series(
['female', np.NaN, 'male'],
['female', 'male', np.NaN],
index=audformat.filewise_index(
['f1.wav', 'f2.wav', 'f3.wav']
['f1.wav', 'f3.wav', 'f2.wav']
),
dtype='object',
name='sex',
Expand Down Expand Up @@ -249,17 +263,20 @@ def stereo_db(tmpdir):
['f1.wav', 'f2.wav', 'f3.wav']
),
dtype=pd.CategoricalDtype(
categories=[23.0, 59.0],
categories=[23.0, 59.0, 25.0, 34.0, 45.0],
ordered=False,
),
name='age',
),
pd.Series(
[25, 34, 45],
[25.0, 34.0, 45.0],
index=audformat.filewise_index(
['f1.wav', 'f2.wav', 'f3.wav']
),
dtype='Int64',
dtype=pd.CategoricalDtype(
categories=[23.0, 59.0, 25.0, 34.0, 45.0],
ordered=False,
),
name='perceived-age',
),
],
Expand All @@ -277,7 +294,7 @@ def stereo_db(tmpdir):
['f1.wav', 'f2.wav', 'f3.wav']
),
dtype=pd.CategoricalDtype(
categories=[1.12, 1.45, 1.01],
categories=[1.12, 1.45, 1.01, 1.76, 1.95, 1.8],
ordered=False,
),
name='height-with-10y',
Expand All @@ -288,7 +305,7 @@ def stereo_db(tmpdir):
['f1.wav', 'f2.wav', 'f3.wav']
),
dtype=pd.CategoricalDtype(
categories=[1.76, 1.95, 1.80],
categories=[1.12, 1.45, 1.01, 1.76, 1.95, 1.8],
ordered=False,
),
name='current-height',
Expand All @@ -303,12 +320,21 @@ def stereo_db(tmpdir):
pd.concat(
[
pd.Series(
['w1', 'w1', 'w2'],
index=audformat.filewise_index(
['f1.wav', 'f2.wav', 'f3.wav']
['w1', 'w1', 'w2', 'w1', 'w1', 'w1', 'w1'],
index=audformat.utils.union(
[
audformat.filewise_index(
['f1.wav', 'f2.wav', 'f3.wav']
),
audformat.segmented_index(
['f1.wav', 'f1.wav', 'f1.wav', 'f2.wav'],
[0, 0.1, 0.3, 0],
[0.2, 0.2, 0.5, 0.7],
),
]
),
dtype=pd.CategoricalDtype(
categories=['w1', 'w2'],
['w1', 'w2', 'w3'],
ordered=False,
),
name='winner',
Expand All @@ -323,12 +349,21 @@ def stereo_db(tmpdir):
pd.concat(
[
pd.Series(
[1995, 1995, 1996],
index=audformat.filewise_index(
['f1.wav', 'f2.wav', 'f3.wav']
[1995, 1995, 1996, 1995, 1995, 1995, 1995],
index=audformat.utils.union(
[
audformat.filewise_index(
['f1.wav', 'f2.wav', 'f3.wav']
),
audformat.segmented_index(
['f1.wav', 'f1.wav', 'f1.wav', 'f2.wav'],
[0, 0.1, 0.3, 0],
[0.2, 0.2, 0.5, 0.7],
),
]
),
dtype=pd.CategoricalDtype(
categories=[1995, 1996],
[1995, 1996, 1997],
ordered=False,
),
name='year',
Expand Down

0 comments on commit 335d7bd

Please sign in to comment.