diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index 6c30c7e8..656c0a22 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -147,9 +147,7 @@ def get_config(self) -> Dict[str, Any]: "chunk_bytes": self._chunk_bytes, "data_format": self._data_format, "data_spec": treespec_dumps(self._data_spec) if self._data_spec else None, - "encryption": {"algorithm": self._encryption.algorithm, "level": self._encryption.level} - if self._encryption - else None, + "encryption": self._encryption.state_dict() if self._encryption else None, } def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]: diff --git a/src/litdata/utilities/encryption.py b/src/litdata/utilities/encryption.py index 4f99d922..1918e936 100644 --- a/src/litdata/utilities/encryption.py +++ b/src/litdata/utilities/encryption.py @@ -99,13 +99,14 @@ def decrypt(self, data: bytes) -> bytes: def state_dict(self) -> Dict[str, Any]: return { "algorithm": self.algorithm, - "salt": base64.urlsafe_b64encode(self.salt).decode("utf-8"), "level": self.level, } def save(self, file_path: str) -> None: + state = self.state_dict() + state["salt"] = base64.urlsafe_b64encode(self.salt).decode("utf-8") with open(file_path, "wb") as file: - file.write(json.dumps(self.state_dict()).encode("utf-8")) + file.write(json.dumps(state).encode("utf-8")) @classmethod def load(cls, file_path: str, password: str) -> "FernetEncryption":