diff --git a/colabdesign/shared/utils.py b/colabdesign/shared/utils.py index fd5aa237..67b0dc4f 100644 --- a/colabdesign/shared/utils.py +++ b/colabdesign/shared/utils.py @@ -5,7 +5,6 @@ import sys, gc def clear_mem(): - # clear vram (GPU) backend = jax.lib.xla_bridge.get_backend() if hasattr(backend,'live_buffers'): @@ -16,15 +15,16 @@ def clear_mem(): # https://github.com/google/jax/issues/10828 for module_name, module in sys.modules.items(): if module_name.startswith("jax"): - for obj_name in dir(module): - obj = getattr(module, obj_name) - if hasattr(obj, "cache_clear"): - try: - obj.cache_clear() - except: - pass + if module_name not in ["jax.interpreters.partial_eval"]: + for obj_name in dir(module): + obj = getattr(module, obj_name) + if hasattr(obj, "cache_clear"): + try: + obj.cache_clear() + except: + pass gc.collect() - + def update_dict(D, *args, **kwargs): '''robust function for updating dictionary''' def set_dict(d, x, override=False):