Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sokrypton authored Oct 10, 2022
1 parent 2cb66db commit 9e83f31
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions colabdesign/shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sys, gc

def clear_mem():

# clear vram (GPU)
backend = jax.lib.xla_bridge.get_backend()
if hasattr(backend,'live_buffers'):
Expand All @@ -16,15 +15,16 @@ def clear_mem():
# https://github.com/google/jax/issues/10828
for module_name, module in sys.modules.items():
if module_name.startswith("jax"):
for obj_name in dir(module):
obj = getattr(module, obj_name)
if hasattr(obj, "cache_clear"):
try:
obj.cache_clear()
except:
pass
if module_name not in ["jax.interpreters.partial_eval"]:
for obj_name in dir(module):
obj = getattr(module, obj_name)
if hasattr(obj, "cache_clear"):
try:
obj.cache_clear()
except:
pass
gc.collect()

def update_dict(D, *args, **kwargs):
'''robust function for updating dictionary'''
def set_dict(d, x, override=False):
Expand Down

0 comments on commit 9e83f31

Please sign in to comment.