
Apple GPU support & float32 dtype #14

Open
TomaSusi opened this issue Oct 7, 2024 · 2 comments

Comments

@TomaSusi

TomaSusi commented Oct 7, 2024

Hi,

I'm just getting started with MACE, but I'm really digging it! I was excited to see that you support Apple GPUs, but is that only for training? When I try to use a mace_off() or mace_mp() ASE calculator and specify both the dtype and the device, I get an error:

Using MACE-OFF23 MODEL for MACECalculator with /Users/tomasusi/.cache/mace/MACE-OFF23_medium.model
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[95], line 13
---> 13 calc = mace_off(device='mps', default_dtype='float32')

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/calculators/foundations_models.py:206, in mace_off(model, device, default_dtype, return_raw_model, **kwargs)
    202 if default_dtype == "float32":
    203     print(
    204         "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
    205     )
--> 206 mace_calc = MACECalculator(
    207     model_paths=model, device=device, default_dtype=default_dtype, **kwargs
    208 )
    209 return mace_calc

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/calculators/mace.py:127, in MACECalculator.__init__(self, model_paths, device, energy_units_to_eV, length_units_to_A, default_dtype, charges_key, model_type, compile_mode, fullgraph, **kwargs)
    125     self.use_compile = True
    126 else:
--> 127     self.models = [
    128         torch.load(f=model_path, map_location=device)
    129         for model_path in model_paths
    130     ]
    131     self.use_compile = False
    132 for model in self.models:

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/calculators/mace.py:128, in <listcomp>(.0)
    125     self.use_compile = True
    126 else:
    127     self.models = [
--> 128         torch.load(f=model_path, map_location=device)
    129         for model_path in model_paths
    130     ]
    131     self.use_compile = False
    132 for model in self.models:

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/serialization.py:1097, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1095             except RuntimeError as e:
   1096                 raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
-> 1097         return _load(
   1098             opened_zipfile,
   1099             map_location,
   1100             pickle_module,
   1101             overall_storage=overall_storage,
   1102             **pickle_load_args,
   1103         )
   1104 if mmap:
   1105     f_name = "" if not isinstance(f, str) else f"{f}, "

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/serialization.py:1525, in _load(zip_file, map_location, pickle_module, pickle_file, overall_storage, **pickle_load_args)
   1522 # Needed for tensors where storage device and rebuild tensor device are
   1523 # not connected (wrapper subclasses and tensors rebuilt using numpy)
   1524 torch._utils._thread_local_state.map_location = map_location
-> 1525 result = unpickler.load()
   1526 del torch._utils._thread_local_state.map_location
   1528 torch._utils._validate_loaded_sparse_tensors()

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/_utils.py:200, in _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    197 def _rebuild_tensor_v2(
    198     storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None
    199 ):
--> 200     tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    201     tensor.requires_grad = requires_grad
    202     if metadata:

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/_utils.py:178, in _rebuild_tensor(storage, storage_offset, size, stride)
    176 def _rebuild_tensor(storage, storage_offset, size, stride):
    177     # first construct a tensor with the correct dtype/device
--> 178     t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
    179     return t.set_(storage._untyped_storage, storage_offset, size, stride)

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Or is this just a simple bug? I am running PyTorch 2.4.1.
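Judging from the traceback, the error is raised inside `torch.load(f=model_path, map_location=device)`, i.e. while the float64 tensors stored in the file are being rebuilt directly on the MPS device, which appears to happen before any dtype conversion is applied. A quick way to see which tensors in a module are float64 (and would therefore be rejected by MPS) is sketched below; `float64_tensors` is a hypothetical helper, and the `nn.Linear` is only a stand-in for a real MACE model.

```python
import torch
from torch import nn


def float64_tensors(module: nn.Module) -> list:
    """Return the names of parameters and buffers stored as float64.

    Every name listed here would be rejected if the module were placed
    on an MPS device, since MPS has no float64 support.
    """
    named = list(module.named_parameters()) + list(module.named_buffers())
    return [name for name, tensor in named if tensor.dtype == torch.float64]


# Toy stand-in for a model trained in float64 (hypothetical example).
toy = nn.Linear(3, 2).double()
print(float64_tensors(toy))  # ['weight', 'bias']
```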

@ilyes319
Contributor

ilyes319 commented Oct 9, 2024

Hey @TomaSusi,

The models were trained in float64, and because MPS does not support float64, it is a bit of a pain to deliver them on MPS.
The short-term hack for you is the following:

  1. Download the models here for MP: https://github.com/ACEsuit/mace-mp/releases/tag/mace_mp_0 or for MACE-OFF: https://github.com/ACEsuit/mace-off. They are the .model files; select the size you want.
  2. Load the model on CPU, convert it to float32, and save it:
     model = torch.load(model_path, device="cpu")
     model = model.float()
     torch.save(model, new_model_path)
  3. Create the calculator on MPS using the new model path:
     calc = MACECalculator(new_model_path, device="mps")
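Step 2 relies on `nn.Module.float()`, which casts every floating-point parameter and buffer to float32 while leaving integer buffers alone. The toy check below illustrates that behaviour and the subsequent move to MPS, falling back to CPU when MPS is unavailable; the small `nn.Sequential` and its `counter` buffer are assumptions standing in for a MACE model.

```python
import torch
from torch import nn

# Toy float64 model standing in for a MACE .model file (assumption).
model = nn.Sequential(nn.Linear(4, 8), nn.SiLU(), nn.Linear(8, 1)).double()
model.register_buffer("counter", torch.zeros(1, dtype=torch.long))

# Cast floating-point parameters and buffers to float32; the long buffer is untouched.
model = model.float()
assert all(p.dtype == torch.float32 for p in model.parameters())
assert model.counter.dtype == torch.long

# Move to MPS if available (it only accepts float32), otherwise stay on CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
model = model.to(device)
out = model(torch.randn(2, 4, device=device))
print(out.dtype)  # torch.float32
```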

@TomaSusi
Author

Thanks for the quick reply!

Loading the model with the given syntax doesn't work on the latest PyTorch:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[334], line 7
      4 model_path_off = 'mace-off/MACE-OFF23_medium.model'
      5 new_model_path_off = 'mace-off/MACE-OFF23_medium_mps.model'
----> 7 model_off = torch.load(model_path_off, device="cpu")
      8 model_off = model_off.float()
      9 torch.save(model_off, new_model_path_off)

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/serialization.py:1114, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1112     except RuntimeError as e:
   1113         raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
-> 1114 return _legacy_load(
   1115     opened_file, map_location, pickle_module, **pickle_load_args
   1116 )

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/serialization.py:1338, in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
   1332 if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
   1333     raise RuntimeError(
   1334         "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
   1335         f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this '
   1336         "functionality.")
-> 1338 magic_number = pickle_module.load(f, **pickle_load_args)
   1339 if magic_number != MAGIC_NUMBER:
   1340     raise RuntimeError("Invalid magic number; corrupt file?")

TypeError: 'device' is an invalid keyword argument for load()

If I update this to map_location='cpu' (or remove the keyword), the model is loaded.
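For reference, the conversion step with the keyword fixed as described above (`map_location` instead of `device`) could look like the sketch below. `convert_to_float32` is a hypothetical helper name. `weights_only=False` is passed explicitly because the `.model` file unpickles to a whole `nn.Module` (the snippet calls `.float()` on it) rather than a plain state dict, and PyTorch 2.6 changed the default of that argument to `True`; only do this for files you trust.

```python
import torch


def convert_to_float32(src_path: str, dst_path: str) -> None:
    """Load a pickled float64 model on CPU, cast it to float32, and save it."""
    # torch.load has no `device` keyword; `map_location` decides where tensors land.
    # Loading on CPU avoids ever creating float64 tensors on MPS.
    model = torch.load(src_path, map_location="cpu", weights_only=False)
    # Cast all floating-point parameters and buffers to float32.
    model = model.float()
    torch.save(model, dst_path)
```

With the converted file, step 3 would then pass `new_model_path_off` to `MACECalculator(..., device="mps")`.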

However, using the converted model then results in another error:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[358], line 9
      7 prim.calc = calc_off
      8 print("forces on primitive:")
----> 9 print(prim.get_forces())

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/ase/atoms.py:812, in Atoms.get_forces(self, apply_constraint, md)
    810 if self._calc is None:
    811     raise RuntimeError('Atoms object has no calculator.')
--> 812 forces = self._calc.get_forces(self)
    814 if apply_constraint:
    815     # We need a special md flag here because for MD we want
    816     # to skip real constraints but include special "constraints"
    817     # Like Hookean.
    818     for constraint in self.constraints:

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/ase/calculators/abc.py:30, in GetPropertiesMixin.get_forces(self, atoms)
     29 def get_forces(self, atoms=None):
---> 30     return self.get_property('forces', atoms)

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/ase/calculators/calculator.py:538, in BaseCalculator.get_property(self, name, atoms, allow_calculation)
    535     if self.use_cache:
    536         self.atoms = atoms.copy()
--> 538     self.calculate(atoms, [name], system_changes)
    540 if name not in self.results:
    541     # For some reason the calculator was not able to do what we want,
    542     # and that is OK.
    543     raise PropertyNotImplementedError(
    544         '{} not present in this ' 'calculation'.format(name)
    545     )

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/calculators/mace.py:244, in MACECalculator.calculate(self, atoms, properties, system_changes)
    242 for i, model in enumerate(self.models):
    243     batch = self._clone_batch(batch_base)
--> 244     out = model(
    245         batch.to_dict(),
    246         compute_stress=compute_stress,
    247         training=self.use_compile,
    248     )
    249     if self.model_type in ["MACE", "EnergyDipoleMACE"]:
    250         ret_tensors["energies"][i] = out["energy"].detach()

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/modules/models.py:395, in ScaleShiftMACE.forward(self, data, training, compute_force, compute_virials, compute_stress, compute_displacement, compute_hessian)
    393 total_energy = e0 + inter_e
    394 node_energy = node_e0 + node_inter_es
--> 395 forces, virials, stress, hessian = get_outputs(
    396     energy=inter_e,
    397     positions=data["positions"],
    398     displacement=displacement,
    399     cell=data["cell"],
    400     training=training,
    401     compute_force=compute_force,
    402     compute_virials=compute_virials,
    403     compute_stress=compute_stress,
    404     compute_hessian=compute_hessian,
    405 )
    406 output = {
    407     "energy": total_energy,
    408     "node_energy": node_energy,
   (...)
    415     "node_feats": node_feats_out,
    416 }
    418 return output

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/modules/utils.py:185, in get_outputs(energy, positions, displacement, cell, training, compute_force, compute_virials, compute_stress, compute_hessian)
    167 def get_outputs(
    168     energy: torch.Tensor,
    169     positions: torch.Tensor,
   (...)
    181     Optional[torch.Tensor],
    182 ]:
    183     if (compute_virials or compute_stress) and displacement is not None:
    184         # forces come for free
--> 185         forces, virials, stress = compute_forces_virials(
    186             energy=energy,
    187             positions=positions,
    188             displacement=displacement,
    189             cell=cell,
    190             compute_stress=compute_stress,
    191             training=(training or compute_hessian),
    192         )
    193     elif compute_force:
    194         forces, virials, stress = (
    195             compute_forces(
    196                 energy=energy,
   (...)
    201             None,
    202         )

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/modules/utils.py:62, in compute_forces_virials(energy, positions, displacement, cell, training, compute_stress)
     60 if compute_stress and virials is not None:
     61     cell = cell.view(-1, 3, 3)
---> 62     volume = torch.linalg.det(cell).abs().unsqueeze(-1)
     63     stress = virials / volume.view(-1, 1, 1)
     64     stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress))

NotImplementedError: The operator 'aten::_linalg_det.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

Setting this environment variable does not seem to help (I tried both %env within the notebook and !export, and editing my .zprofile file). I posted in the suggested issue (pytorch/pytorch#77764 (comment)).
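Two hedged workarounds for this second error. First, in a notebook `!export` runs in a throwaway subshell, and edits to `.zprofile` only affect shells (and Jupyter servers) started afterwards; it is also commonly reported that `PYTORCH_ENABLE_MPS_FALLBACK` must already be set when `torch` is first imported, so the sketch sets it at the very top of a fresh kernel (an assumption, not verified here). Second, the failing call only computes the cell volume for the stress, and for a 3×3 cell the determinant equals the scalar triple product a · (b × c) of the cell rows, which needs only elementwise ops. `cell_volume` is a hypothetical drop-in for the `volume = torch.linalg.det(cell).abs().unsqueeze(-1)` line; it is not part of MACE.

```python
import os

# Assumption: the fallback flag must be set before torch is imported in this process.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (imported after setting the environment variable on purpose)


def cell_volume(cell: torch.Tensor) -> torch.Tensor:
    """|det(cell)| per 3x3 cell via the scalar triple product a . (b x c).

    Returns shape (N, 1), matching torch.linalg.det(cell).abs().unsqueeze(-1).
    """
    cell = cell.reshape(-1, 3, 3)
    a, b, c = cell[:, 0], cell[:, 1], cell[:, 2]  # lattice vectors stored as rows
    b_cross_c = torch.stack(
        (
            b[:, 1] * c[:, 2] - b[:, 2] * c[:, 1],
            b[:, 2] * c[:, 0] - b[:, 0] * c[:, 2],
            b[:, 0] * c[:, 1] - b[:, 1] * c[:, 0],
        ),
        dim=-1,
    )
    return (a * b_cross_c).sum(dim=-1).abs().unsqueeze(-1)


cubic = torch.eye(3).unsqueeze(0) * 2.0  # one cubic cell with 2 Å edges
print(cell_volume(cubic))  # tensor([[8.]])
```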
