Skip to content

Commit

Permalink
Add Database.root (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankenjoe authored May 4, 2021
1 parent ce38147 commit 3ba5c9d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
18 changes: 18 additions & 0 deletions audformat/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def __init__(
)
r"""Dictionary of tables"""

self._root = None

@property
def files(self) -> pd.Index:
r"""Files referenced in the database.
Expand All @@ -179,6 +181,18 @@ def files(self) -> pd.Index:
index = index.union(table.files.drop_duplicates())
return index.drop_duplicates()

@property
def root(self) -> typing.Optional[str]:
r"""Database root directory.
Returns ``None`` if database has not been stored yet.
Returns:
root directory
"""
return self._root

@property
def segments(self) -> pd.MultiIndex:
r"""Segments referenced in the database.
Expand Down Expand Up @@ -409,6 +423,8 @@ def job(table_id, table):
task_description='Save tables',
)

self._root = root

def update(
self,
others: typing.Union['Database', typing.Sequence['Database']],
Expand Down Expand Up @@ -656,6 +672,8 @@ def job(table_id):
task_description='Load tables',
)

db._root = root

return db

@staticmethod
Expand Down
8 changes: 8 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import filecmp
import os

import audeer
import pandas as pd
import pytest

Expand Down Expand Up @@ -176,11 +177,14 @@ def test_map_files(num_workers):
)
def test_save_and_load(tmpdir, db, storage_format, num_workers):

assert db.root is None
db.save(
tmpdir,
storage_format=storage_format,
num_workers=num_workers,
)
assert db.root == tmpdir

expected_formats = [storage_format]
for table_id in db.tables:
for ext in audformat.define.TableStorageFormat.attribute_values():
Expand All @@ -196,15 +200,18 @@ def test_save_and_load(tmpdir, db, storage_format, num_workers):
and db.tables
):
db2 = audformat.testing.create_db()
assert db2.root is None
db2.save(
tmpdir,
storage_format=audformat.define.TableStorageFormat.PICKLE,
num_workers=num_workers,
)
assert db.root == tmpdir

# Load prefers PKL files over CSV files,
# which means we are loading the second database here
db_load = audformat.Database.load(tmpdir)
assert db_load.root == db2.root
assert db_load == db2
assert db_load != db

Expand Down Expand Up @@ -266,6 +273,7 @@ def test_save_and_load(tmpdir, db, storage_format, num_workers):
load_data=False,
num_workers=num_workers,
)
assert db_load.root == tmpdir
for table_id, table in db_load.tables.items():
assert list(db_load.files) == []
assert table._id == table_id
Expand Down

0 comments on commit 3ba5c9d

Please sign in to comment.