From bd9fba9ec62437b5b62fbd0b2c2c723216cc5a2c Mon Sep 17 00:00:00 2001 From: Shyue Ping Ong Date: Wed, 30 Oct 2024 06:46:31 -0700 Subject: [PATCH] Simplify implementation of PMGDir. --- src/pymatgen/io/common.py | 26 +++++++++++++------------- tests/io/test_common.py | 5 +++-- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/pymatgen/io/common.py b/src/pymatgen/io/common.py index 597d4b066c6..bdf3521923b 100644 --- a/src/pymatgen/io/common.py +++ b/src/pymatgen/io/common.py @@ -475,26 +475,26 @@ def reset(self): changed. """ # Note that py3.12 has Path.walk(). But we need to use os.walk to ensure backwards compatibility for now. - self.files = [str((Path(d) / f).relative_to(self.path)) for d, _, fnames in os.walk(self.path) for f in fnames] - - self._parsed_files: dict[str, Any] = {} + self._files: dict[str, Any] = { + str((Path(d) / f).relative_to(self.path)): None for d, _, fnames in os.walk(self.path) for f in fnames + } def __contains__(self, item): - return item in self.files + return item in self._files def __len__(self): - return len(self.files) + return len(self._files) def __iter__(self): - return iter(self.files) + return iter(self._files) def __getitem__(self, item): - if item in self._parsed_files: - return self._parsed_files[item] + if self._files.get(item): + return self._files.get(item) fpath = self.path / item if not (self.path / item).exists(): - raise ValueError(f"{item} not found in {self.path}. List of files are {self.files}.") + raise ValueError(f"{item} not found in {self.path}. List of files are {self._files.keys()}.") for k, cls_ in PMGDir.FILE_MAPPINGS.items(): if k in item: @@ -502,11 +502,11 @@ def __getitem__(self, item): module = importlib.import_module(modname) class_ = getattr(module, classname) try: - self._parsed_files[item] = class_.from_file(fpath) + self._files[item] = class_.from_file(fpath) except AttributeError: - self._parsed_files[item] = class_(fpath) + self._files[item] = class_(fpath) - return self._parsed_files[item] + return self._files[item] warnings.warn( f"No parser defined for {item}. Contents are returned as a string.", @@ -522,7 +522,7 @@ def get_files_by_name(self, name: str) -> dict[str, Any]: Returns: {filename: object from PMGDir[filename]} """ - return {f: self[f] for f in self.files if name in f} + return {f: self[f] for f in self._files if name in f} def __repr__(self): return f"PMGDir({self.path})" diff --git a/tests/io/test_common.py b/tests/io/test_common.py index a241b9e8948..9292d1f30df 100644 --- a/tests/io/test_common.py +++ b/tests/io/test_common.py @@ -33,7 +33,7 @@ def test_getitem(self): d = PMGDir(f"{TEST_FILES_DIR}/io/vasp/fixtures/scan_relaxation") assert len(d) == 2 - assert "vasprun.xml.gz" in d.files + assert "vasprun.xml.gz" in d assert "OUTCAR" in d assert d["vasprun.xml.gz"].incar["METAGGA"] == "R2scan" @@ -57,4 +57,5 @@ def test_getitem(self): assert all("OUTCAR" for k in outcars) d.reset() - assert len(d._parsed_files) == 0 + for v in d._files.values(): + assert v is None