Skip to content

Commit

Permalink
giving up on custom serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
quadrismegistus committed Aug 21, 2024
1 parent 7ca7508 commit 2980752
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 15 deletions.
14 changes: 12 additions & 2 deletions hashstash/serializers/custom/custom.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import *
from .serialize import _serialize
from .deserialize import _deserialize

@log.debug
def serialize_numpy(obj):
Expand Down Expand Up @@ -89,18 +90,27 @@ def deserialize_pandas_series(obj):

@log.debug
def serialize_bytes(obj):
return _encode(obj, compress=False, b64=True, as_string=True)
return encode(obj, compress=False, b64=True, as_string=True)

@log.debug
def deserialize_bytes(obj):
return _decode(obj, compress=False, b64=True)
return decode(obj, compress=False, b64=True)

@log.debug
def serialize_set(obj):
return [_serialize(v) for v in sorted(obj, key=lambda x: str(x))] # ensure pseudo sorted for deterministic output

@log.debug
def deserialize_set(*args):
return {_deserialize(v) for v in args}

@log.debug
def serialize_tuple(obj):
return [_serialize(v) for v in obj]

@log.debug
def deserialize_tuple(*args):
return tuple([_deserialize(v) for v in args])

CUSTOM_OBJECT_SERIALIZERS = {
'pandas.core.frame.DataFrame':serialize_pandas_df,
Expand Down
6 changes: 4 additions & 2 deletions hashstash/serializers/custom/deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ def deserialize_python(data: dict, init_funcs=["from_dict"]):
x = obj_func(*args, **kwargs)
return x

@log.debug
def call_init():
return call_function_politely(obj, *args, **kwargs)

@log.debug
def call_init_func(func_name, args, kwargs):
if hasattr(obj, func_name):
func = getattr(obj, func_name)
Expand All @@ -53,9 +55,9 @@ def call_init_func(func_name, args, kwargs):
if obj is None and OBJ_SRC_KEY in data:
src = data[OBJ_SRC_KEY]
if src.startswith("class "):
return deserialize_class(data)
obj = deserialize_class(data)
else:
return deserialize_function(src, obj_addr)
obj = deserialize_function(src, obj_addr)

if args or kwargs:
for initfunc in init_funcs:
Expand Down
21 changes: 16 additions & 5 deletions hashstash/serializers/custom/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,14 @@ def serialize_object(obj):
# any dictionary at all?
obj_d = prune_none_values(get_dict(obj, __ignore=False))
if obj_d:
return {
out = {
OBJ_ADDR_KEY: obj_addr,
OBJ_KWARGS_KEY: {k: _serialize(v) for k, v in obj_d.items()},
}
if not can_import_object(obj):
out[OBJ_SRC_KEY] = reconstruct_class_source(obj.__class__)

return out

# otherwise unknown
return serialize_unknown(obj)
Expand Down Expand Up @@ -180,7 +184,7 @@ def serialize_class(obj):
# Check if the class is defined in __main__ or can't be imported
if obj.__module__ == '__main__' or not can_import_object(obj):
try:
out[OBJ_SRC_KEY] = inspect.getsource(obj)
out[OBJ_SRC_KEY] = reformat_python_source(inspect.getsource(obj))
except OSError:
logger.warning(f"Could not get source for {obj.__name__} using inspect. Attempting to reconstruct.")
try:
Expand All @@ -199,6 +203,7 @@ def serialize_class(obj):

return out

@log.info
def reconstruct_class_source(cls):
lines = [f"class {cls.__name__}:"]

Expand All @@ -219,11 +224,17 @@ def reconstruct_class_source(cls):
try:
func_source = get_function_str(value)
# Remove any leading whitespace and add proper indentation
func_lines = [line.strip() for line in func_source.split('\n')]
func_lines = [" " + line for line in func_lines if line]
func_lines = func_source.split('\n')
func_lines = [" " + line if line.strip() else line for line in func_lines]
lines.extend(func_lines)
lines.append("") # Add an empty line after each method
except OSError:
logger.warning(f"Could not get source for method {name}")

return "\n".join(lines)
src = "\n".join(lines)
out = reformat_python_source(src)
print("reconstructed")
print(out)
print()

return out
9 changes: 7 additions & 2 deletions hashstash/utils/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ def ff(x):

@log.debug
def f(x):
def fff(y):
log.debug('inner inner hello')
time.sleep(1)
return y*2
log.debug('outer hello')
time.sleep(1)
return ff(x)*2
# time.sleep(1)
from hashstash import serialize
return serialize(fff, 'jsonpickle_ext')
11 changes: 10 additions & 1 deletion hashstash/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,13 @@ def is_dataframe(obj):

def get_fn_ext(fn):
# without period
return fn.split('.')[-1]
return fn.split('.')[-1]


def reformat_python_source(src):
lines = src.split('\n')
if lines:
leading_spaces = len(lines[0]) - len(lines[0].lstrip())
return '\n'.join(line[leading_spaces:] for line in lines)
else:
return src
184 changes: 181 additions & 3 deletions tests/test_serialize.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from hashstash import *
import numpy as np
import pandas as pd
config.set_serializer('custom')
import pytest
import base64
from hashstash.constants import SERIALIZER_TYPES
logger.setLevel(logging.DEBUG)

SERIALIZER_TYPES = ['jsonpickle_ext']
config.set_serializer('jsonpickle_ext')

def test_serialize_jsonable():
assert json.loads(serialize(42)) == 42
Expand All @@ -14,7 +19,7 @@ def test_serialize_numpy():
arr = np.array([1, 2, 3])
result = json.loads(serialize(arr))
assert result[OBJ_ADDR_KEY] == 'numpy.ndarray'
assert result[OBJ_ARGS_KEY] == [b64encode(arr.tobytes()).decode('utf-8')]
assert result[OBJ_ARGS_KEY] == [base64.b64encode(arr.tobytes()).decode('utf-8')]

def test_serialize_pandas_df():
df = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})
Expand Down Expand Up @@ -97,4 +102,177 @@ def test_roundtrip():
assert deserialized['dict'] == original['dict']
assert np.array_equal(deserialized['numpy'], original['numpy'])
assert deserialized['pandas_df'].equals(original['pandas_df'])
assert deserialized['pandas_series'].equals(original['pandas_series'])
assert deserialized['pandas_series'].equals(original['pandas_series'])

@pytest.fixture(params=list(SERIALIZER_TYPES))
def serializer_type(request):
return request.param

def test_serialize_deserialize_all_types(serializer_type):
config.set_serializer(serializer_type)

original = {
'int': 42,
'str': 'hello',
'list': [1, 2, 3],
'dict': {'a': 1, 'b': 2},
'numpy': np.array([1, 2, 3]),
'pandas_df': pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}),
'pandas_series': pd.Series([1, 2, 3], name='test'),
'set': {1, 2, 3},
'tuple': (1, 2, 3),
'bytes': b'hello'
}

serialized = serialize(original)
deserialized = deserialize(serialized)

assert deserialized['int'] == original['int']
assert deserialized['str'] == original['str']
assert deserialized['list'] == original['list']
assert deserialized['dict'] == original['dict']
assert np.array_equal(deserialized['numpy'], original['numpy'])
assert deserialized['pandas_df'].equals(original['pandas_df'])
assert deserialized['pandas_series'].equals(original['pandas_series'])
assert deserialized['set'] == original['set']
assert deserialized['tuple'] == original['tuple']
assert deserialized['bytes'] == original['bytes']

def test_get_working_serializers():
working_serializers = get_working_serializers()
assert isinstance(working_serializers, list)
assert all(s in working_serializers for s in SERIALIZER_TYPES)

def test_serialize_class():
class TestClass:
def __init__(self, x):
self.x = x

def method(self):
return self.x * 2

serialized = serialize(TestClass)
deserialized = deserialize(serialized)

assert isinstance(deserialized, type)
instance = deserialized(5)
assert instance.x == 5
assert instance.method() == 10

def test_serialize_instance():
class TestClass:
def __init__(self, x):
self.x = x

def method(self):
return self.x * 2

original = TestClass(5)
serialized = serialize(original)
deserialized = deserialize(serialized)

# assert isinstance(deserialized, TestClass)
assert deserialized.__class__.__name__ == 'TestClass'
assert deserialized.x == 5
assert deserialized.method() == 10

def test_serialize_nested_structure():
nested = {
'list_of_dicts': [{'a': 1}, {'b': 2}],
'dict_of_lists': {'x': [1, 2], 'y': [3, 4]},
'tuple_with_list': ([1, 2], [3, 4]),
'set_of_tuples': {(1, 2), (3, 4)}
}

serialized = serialize(nested)
deserialized = deserialize(serialized)

assert deserialized == nested

def test_serialize_custom_objects():
class CustomObject:
def __init__(self, value):
self.value = value

def __eq__(self, other):
return isinstance(other, CustomObject) and self.value == other.value

original = CustomObject(42)
serialized = serialize(original)
deserialized = deserialize(serialized)

assert deserialized == original

def test_serialize_lambda():
original = lambda x: x * 2
serialized = serialize(original)
deserialized = deserialize(serialized)

assert callable(deserialized)
assert deserialized(3) == 6

def test_serialize_generator():
def gen():
yield from range(3)

original = gen()
serialized = serialize(original)
deserialized = deserialize(serialized)

assert list(deserialized) == [0, 1, 2]

def test_serialize_large_data():
large_list = list(range(1000000))
serialized = serialize(large_list)
deserialized = deserialize(serialized)

assert deserialized == large_list

def test_serialize_circular_reference():
a = []
a.append(a)

serialized = serialize(a)
deserialized = deserialize(serialized)

assert isinstance(deserialized, list)
assert deserialized[0] is deserialized

def test_serialize_set():
s = {1, 2, 3}
result = json.loads(serialize(s))
assert result[OBJ_ADDR_KEY] == 'builtins.set'
assert result[OBJ_ARGS_KEY] == [1, 2, 3]

def test_serialize_tuple():
t = (1, 2, 3)
result = json.loads(serialize(t))
assert result[OBJ_ADDR_KEY] == 'builtins.tuple'
assert result[OBJ_ARGS_KEY] == [1, 2, 3]

def test_serialize_bytes():
b = b'hello'
result = json.loads(serialize(b))
assert result[OBJ_ADDR_KEY] == 'builtins.bytes'
assert result[OBJ_ARGS_KEY] == [base64.b64encode(b).decode('utf-8')]

def test_deserialize_set():
s = {1, 2, 3}
serialized = serialize(s)
result = deserialize(serialized)
assert isinstance(result, set)
assert result == s

def test_deserialize_tuple():
t = (1, 2, 3)
serialized = serialize(t)
result = deserialize(serialized)
assert isinstance(result, tuple)
assert result == t

def test_deserialize_bytes():
b = b'hello'
serialized = serialize(b)
result = deserialize(serialized)
assert isinstance(result, bytes)
assert result == b

0 comments on commit 2980752

Please sign in to comment.