You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Traceback (most recent call last):
File "/compile.py", line 11, in <module>
cv.to_torchscript('model.ptc')
File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1429, in to_torchscript
torchscript_module = torch.jit.script(self.eval(), **kwargs)
File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_script.py", line 1284, in script
return torch.jit._recursive.create_script_module(
File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 480, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
init_fn(script_module)
File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 546, in create_script_module_impl
create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 397, in create_methods_and_properties_from_stubs
concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 867, in try_compile_fn
return torch.jit.script(fn, _rcb=rcb)
File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_script.py", line 1341, in script
fn = torch._C._jit_script_compile(
File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/annotations.py", line 366, in try_ann_to_type
assert maybe_type, msg.format(repr(ann), repr(maybe_type))
AssertionError: Unsupported annotation typing.Union[list, torch.Tensor] could not be resolved because None could not be resolved.
I tried some tinkering, and here is an example of a compilable tda_loss module, although I'm not really sure about the correctness.
code:
#!/usr/bin/env python# =============================================================================# MODULE DOCSTRING# ============================================================================="""Target Discriminant Analysis Loss Function."""__all__= ["TDALoss", "tda_loss"]
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Union, List, Tuple
from warnings import warn

import torch


# =============================================================================
# LOSS FUNCTIONS
# =============================================================================


class TDALoss(torch.nn.Module):
    """Compute a loss function as the distance from a simple Gaussian target distribution."""

    def __init__(
        self,
        n_states: int,
        target_centers: Union[List[float], torch.Tensor],
        target_sigmas: Union[List[float], torch.Tensor],
        alpha: float = 1.0,
        beta: float = 100.0,
    ):
        """Constructor.

        Parameters
        ----------
        n_states : int
            Number of states. The integer labels are expected to be in between
            0 and ``n_states-1``.
        target_centers : list or torch.Tensor
            Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
        target_sigmas : list or torch.Tensor
            Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian
            targets.
        alpha : float, optional
            Centers_loss component prefactor, by default 1.
        beta : float, optional
            Sigmas loss component prefactor, by default 100.
        """
        super().__init__()

        # The targets are stored exactly as received (list or tensor); any
        # conversion is left to ``tda_loss`` at call time.
        self.n_states = n_states
        self.target_centers = target_centers
        self.target_sigmas = target_sigmas
        self.alpha = alpha
        self.beta = beta

    def forward(
        self, H: torch.Tensor, labels: torch.Tensor, return_loss_terms: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Compute the value of the loss function.

        This is a thin wrapper that forwards the stored configuration to
        ``tda_loss``.

        Parameters
        ----------
        H : torch.Tensor
            Shape ``(n_batches, n_features)``. Output of the NN.
        labels : torch.Tensor
            Shape ``(n_batches,)``. Labels of the dataset.
        return_loss_terms : bool, optional
            If ``True``, the loss terms associated to the center and standard
            deviations of the target Gaussians are returned as well. Default
            is ``False``.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        loss_centers : torch.Tensor, optional
            Only returned if ``return_loss_terms is True``. The value of the
            loss term associated to the centers of the target Gaussians.
        loss_sigmas : torch.Tensor, optional
            Only returned if ``return_loss_terms is True``. The value of the
            loss term associated to the standard deviations of the target
            Gaussians.
        """
        return tda_loss(
            H,
            labels,
            self.n_states,
            self.target_centers,
            self.target_sigmas,
            self.alpha,
            self.beta,
            return_loss_terms,
        )
def tda_loss(
    H: torch.Tensor,
    labels: torch.Tensor,
    n_states: int,
    target_centers: Union[List[float], torch.Tensor],
    target_sigmas: Union[List[float], torch.Tensor],
    alpha: float = 1.0,
    beta: float = 100.0,
    return_loss_terms: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Compute a loss function as the distance from a simple Gaussian target distribution.

    For every state ``i`` the mean and standard deviation of the rows of ``H``
    labelled ``i`` are compared with ``target_centers[i]`` and
    ``target_sigmas[i]``; the squared differences are weighted by ``alpha``
    and ``beta`` respectively and summed over all states.

    Parameters
    ----------
    H : torch.Tensor
        Shape ``(n_batches, n_features)``. Output of the NN.
    labels : torch.Tensor
        Shape ``(n_batches,)``. Labels of the dataset.
    n_states : int
        The integer labels are expected to be in between 0 and ``n_states-1``.
    target_centers : list or torch.Tensor
        Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
    target_sigmas : list or torch.Tensor
        Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian
        targets.
    alpha : float, optional
        Centers_loss component prefactor, by default 1.
    beta : float, optional
        Sigmas loss component prefactor, by default 100.
    return_loss_terms : bool, optional
        If ``True``, the loss terms associated to the center and standard
        deviations of the target Gaussians are returned as well. Default is
        ``False``.

    Returns
    -------
    loss : torch.Tensor
        Loss value.
    loss_centers : torch.Tensor, optional
        Only returned if ``return_loss_terms is True``. The value of the loss
        term associated to the centers of the target Gaussians.
    loss_sigmas : torch.Tensor, optional
        Only returned if ``return_loss_terms is True``. The value of the loss
        term associated to the standard deviations of the target Gaussians.

    Raises
    ------
    ValueError
        If one of the states ``0 .. n_states-1`` has no sample in ``labels``.
    """
    if not isinstance(target_centers, torch.Tensor):
        target_centers = torch.tensor(target_centers)
    if not isinstance(target_sigmas, torch.Tensor):
        target_sigmas = torch.tensor(target_sigmas)

    # Bring the targets onto H's device AND dtype. The accumulators below
    # inherit the targets' dtype, so without the cast integer targets
    # (e.g. ``[[1], [10]]``) would yield integer accumulators and every
    # per-state loss term would be truncated on assignment.
    target_centers = target_centers.to(device=H.device, dtype=H.dtype)
    target_sigmas = target_sigmas.to(device=H.device, dtype=H.dtype)

    loss_centers = torch.zeros_like(target_centers)
    loss_sigmas = torch.zeros_like(target_sigmas)

    for i in range(n_states):
        # Build the membership mask once and reuse it for the emptiness
        # check, the sample count and the row selection.
        in_state = labels == i
        n_samples = int(in_state.sum())

        if n_samples == 0:
            raise ValueError(
                f"State {i} was not represented in this batch! Either use bigger batch_size or a more equilibrated dataset composition!"
            )

        H_red = H[in_state]

        # compute mean and standard deviation over the class i
        mu = torch.mean(H_red, 0)
        if n_samples == 1:
            warn(
                f"There is only one sample for state {i} in this batch! Std is set to 0, this may affect the training! Either use bigger batch_size or a more equilibrated dataset composition!"
            )
            # A standard deviation cannot be estimated from one sample, so it
            # is fixed to zero (same dtype/device as the mean).
            sigma = torch.zeros_like(mu)
        else:
            sigma = torch.std(H_red, 0)

        # compute loss function contributes for class i
        loss_centers[i] = alpha * (mu - target_centers[i]).pow(2)
        loss_sigmas[i] = beta * (sigma - target_sigmas[i]).pow(2)

    # get total model loss
    loss_centers = torch.sum(loss_centers)
    loss_sigmas = torch.sum(loss_sigmas)
    loss = loss_centers + loss_sigmas

    if return_loss_terms:
        return loss, loss_centers, loss_sigmas
    return loss
The text was updated successfully, but these errors were encountered:
Example input:
Errors:
conda list:
I tried some tinkering, and here is an example of a compilable `tda_loss` module, although I'm not really sure about the correctness.
code:
The text was updated successfully, but these errors were encountered: