Skip to content

Commit

Permalink
chore: Refactor encryption classes and add state_dict method
Browse files Browse the repository at this point in the history
  • Loading branch information
bhimrazy committed Jul 11, 2024
1 parent 5b4b649 commit c8a6da0
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/litdata/utilities/encryption.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import os
from abc import ABC, abstractmethod
from typing import Optional
from typing import Any, Dict, Optional, Tuple

from litdata.constants import _CRYPTOGRAPHY_AVAILABLE

Expand Down Expand Up @@ -57,7 +57,7 @@ def _derive_key(self, password: str) -> bytes:
)
return base64.urlsafe_b64encode(kdf.derive(password.encode()))

def state_dict(self) -> dict:
def state_dict(self) -> Dict[str, Any]:
return {"password": self.password}


Expand Down Expand Up @@ -90,18 +90,18 @@ def __init__(
self.password = password
self.extension = "rsa"

def _load_private_key(self, path: str):
def _load_private_key(self, path: str) -> rsa.RSAPrivateKey:
with open(path, "rb") as key_file:
return serialization.load_pem_private_key(
key_file.read(),
password=self.password.encode() if self.password else None,
)

def _load_public_key(self, path: str):
def _load_public_key(self, path: str) -> rsa.RSAPublicKey:
with open(path, "rb") as key_file:
return serialization.load_pem_public_key(key_file.read())

def _generate_keys(self):
def _generate_keys(self) -> Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]:
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
Expand Down Expand Up @@ -133,9 +133,11 @@ def decrypt(self, data: bytes) -> bytes:
),
)

def save_keys(self, private_key_path: str, public_key_path: str, password: str = None) -> None:
def save_keys(self, private_key_path: str, public_key_path: str) -> None:
encryption_algorithm = (
serialization.BestAvailableEncryption(password.encode()) if password else serialization.NoEncryption()
serialization.BestAvailableEncryption(self.password.encode())
if self.password
else serialization.NoEncryption()
)

with open(private_key_path, "wb") as f:
Expand All @@ -154,5 +156,5 @@ def save_keys(self, private_key_path: str, public_key_path: str, password: str =
)
)

def state_dict(self) -> dict:
def state_dict(self) -> Dict[str, Any]:
return {"private_key": self.private_key, "public_key": self.public_key, "password": self.password}

0 comments on commit c8a6da0

Please sign in to comment.