Skip to content

Commit

Permalink
feat: support custom forward use in torch components + generic batch_to_device
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Nov 13, 2024
1 parent f0c226e commit 696a410
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions edsnlp/core/torch_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def cached(key, store_key=False):
def wrapper(fn):
@wraps(fn)
def wrapped(self: "TorchComponent", *args, **kwargs):
if self._current_cache_id is None:
if self._current_cache_id is None or len(args) == 0:
return fn(self, *args, **kwargs)
cache_key = (
fn.__name__,
Expand Down Expand Up @@ -301,17 +301,17 @@ def batch_to_device(
-------
BatchInput
"""
return {
name: (
(value.to(device) if device is not None else value)
if hasattr(value, "to")
else getattr(self, name).batch_to_device(value, device=device)
if hasattr(self, name)
and hasattr(getattr(self, name), "batch_to_device")
else value
)
for name, value in batch.items()
}

def rec(x):
    """Recursively move every tensor-like leaf of a nested batch structure
    to ``device``, preserving the container types along the way.

    Anything exposing a ``.to`` method (e.g. ``torch.Tensor``) is moved;
    dicts, lists, tuples (including namedtuples) and sets are rebuilt with
    their values converted recursively; any other value is returned as-is.

    NOTE: ``device`` is a free variable captured from the enclosing
    ``batch_to_device`` scope — presumably a ``torch.device`` or ``None``.
    """
    if hasattr(x, "to"):
        return x.to(device)
    if isinstance(x, dict):
        return {name: rec(value) for name, value in x.items()}
    if isinstance(x, tuple) and hasattr(x, "_fields"):
        # namedtuple: ``type(x)(generator)`` would raise a TypeError because
        # the generated constructor expects one positional argument per
        # field, so unpack the converted values instead.
        return type(x)(*(rec(value) for value in x))
    if isinstance(x, (list, tuple, set)):
        return type(x)(rec(value) for value in x)
    return x

return rec(batch)

def forward(self, batch: BatchInput) -> BatchOutput:
"""
Expand Down

0 comments on commit 696a410

Please sign in to comment.