diff --git a/src/litdata/utilities/encryption.py b/src/litdata/utilities/encryption.py index f285cb43..286b4176 100644 --- a/src/litdata/utilities/encryption.py +++ b/src/litdata/utilities/encryption.py @@ -23,6 +23,10 @@ def encrypt(self, data: bytes) -> bytes: def decrypt(self, data: bytes) -> bytes: pass + @abstractmethod + def state_dict(self) -> dict: + pass + class FernetEncryption(Encryption): """Encryption for the Fernet package. @@ -53,6 +57,9 @@ def _derive_key(self, password: str) -> bytes: ) return base64.urlsafe_b64encode(kdf.derive(password.encode())) + def state_dict(self) -> dict: + return {"password": self.password} + class RSAEncryption: """Encryption for the RSA package. @@ -68,7 +75,7 @@ def __init__( password: Optional[str] = None, ): if private_key_path: - self.private_key = self._load_private_key(private_key_path, password) + self.private_key = self._load_private_key(private_key_path) else: self.private_key = None @@ -79,13 +86,15 @@ def __init__( if not private_key_path and not public_key_path: self.private_key, self.public_key = self._generate_keys() + + self.password = password self.extension = "rsa" - def _load_private_key(self, path: str, password: str = None): + def _load_private_key(self, path: str): with open(path, "rb") as key_file: return serialization.load_pem_private_key( key_file.read(), - password=password.encode() if password else None, + password=self.password.encode() if self.password else None, ) def _load_public_key(self, path: str): @@ -144,3 +153,6 @@ def save_keys(self, private_key_path: str, public_key_path: str, password: str = format=serialization.PublicFormat.SubjectPublicKeyInfo, ) ) + + def state_dict(self) -> dict: + return {"private_key": self.private_key, "public_key": self.public_key, "password": self.password}