From c8a6da03c15fe26849d449febfaaf3e5a520ee4d Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 12 Jul 2024 03:13:33 +0545 Subject: [PATCH] chore: Refactor encryption classes and add state_dict method --- src/litdata/utilities/encryption.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/litdata/utilities/encryption.py b/src/litdata/utilities/encryption.py index 286b4176..3ba8be15 100644 --- a/src/litdata/utilities/encryption.py +++ b/src/litdata/utilities/encryption.py @@ -1,7 +1,7 @@ import base64 import os from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Dict, Optional, Tuple from litdata.constants import _CRYPTOGRAPHY_AVAILABLE @@ -57,7 +57,7 @@ def _derive_key(self, password: str) -> bytes: ) return base64.urlsafe_b64encode(kdf.derive(password.encode())) - def state_dict(self) -> dict: + def state_dict(self) -> Dict[str, Any]: return {"password": self.password} @@ -90,18 +90,18 @@ def __init__( self.password = password self.extension = "rsa" - def _load_private_key(self, path: str): + def _load_private_key(self, path: str) -> rsa.RSAPrivateKey: with open(path, "rb") as key_file: return serialization.load_pem_private_key( key_file.read(), password=self.password.encode() if self.password else None, ) - def _load_public_key(self, path: str): + def _load_public_key(self, path: str) -> rsa.RSAPublicKey: with open(path, "rb") as key_file: return serialization.load_pem_public_key(key_file.read()) - def _generate_keys(self): + def _generate_keys(self) -> Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: private_key = rsa.generate_private_key( public_exponent=65537, key_size=2048, @@ -133,9 +133,11 @@ def decrypt(self, data: bytes) -> bytes: ), ) - def save_keys(self, private_key_path: str, public_key_path: str, password: str = None) -> None: + def save_keys(self, private_key_path: str, public_key_path: str) -> None: encryption_algorithm = ( - serialization.BestAvailableEncryption(password.encode()) if password else serialization.NoEncryption() + serialization.BestAvailableEncryption(self.password.encode()) + if self.password + else serialization.NoEncryption() ) with open(private_key_path, "wb") as f: @@ -154,5 +156,5 @@ def save_keys(self, private_key_path: str, public_key_path: str, password: str = ) ) - def state_dict(self) -> dict: + def state_dict(self) -> Dict[str, Any]: return {"private_key": self.private_key, "public_key": self.public_key, "password": self.password}