Skip to content

Commit

Permalink
refactor: Update encryption state_dict serialization and config state
Browse files Browse the repository at this point in the history
  • Loading branch information
bhimrazy committed Jul 18, 2024
1 parent 07a0b6e commit b24e1fe
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
4 changes: 1 addition & 3 deletions src/litdata/streaming/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,7 @@ def get_config(self) -> Dict[str, Any]:
"chunk_bytes": self._chunk_bytes,
"data_format": self._data_format,
"data_spec": treespec_dumps(self._data_spec) if self._data_spec else None,
"encryption": {"algorithm": self._encryption.algorithm, "level": self._encryption.level}
if self._encryption
else None,
"encryption": self._encryption.state_dict() if self._encryption else None,
}

def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]:
Expand Down
5 changes: 3 additions & 2 deletions src/litdata/utilities/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,14 @@ def decrypt(self, data: bytes) -> bytes:
def state_dict(self) -> Dict[str, Any]:
return {
"algorithm": self.algorithm,
"salt": base64.urlsafe_b64encode(self.salt).decode("utf-8"),
"level": self.level,
}

def save(self, file_path: str) -> None:
state = self.state_dict()
state["salt"] = base64.urlsafe_b64encode(self.salt).decode("utf-8")
with open(file_path, "wb") as file:
file.write(json.dumps(self.state_dict()).encode("utf-8"))
file.write(json.dumps(state).encode("utf-8"))

@classmethod
def load(cls, file_path: str, password: str) -> "FernetEncryption":
Expand Down

0 comments on commit b24e1fe

Please sign in to comment.