init_device_mesh is the recommended way to create a DeviceMesh, especially an nD DeviceMesh. However, init_device_mesh can only be used to initialize a world mesh, meaning that the mesh must contain all the ranks. Otherwise, DeviceMesh's inference rules can be wrong and result in incorrect PG creation. More specifically, DeviceMesh uses get_rank() to determine which PG this rank belongs to.
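To make the failure mode concrete, here is a minimal pure-Python simulation (no torch required). It is not the DeviceMesh implementation. It assumes a row-major mesh of global ranks and looks up the group along one dimension that contains a given rank, standing in for the result of get_rank(). The helpers `dim_groups` and `group_of` are hypothetical names used only for illustration. The "assumed" mesh mimics building the layout from ranks 0..N-1, as a world-mesh constructor would. That only works when the mesh really does cover every rank in the job.

```python
import math
from itertools import product


def dim_groups(mesh, shape, dim):
    """Return every rank group along `dim` of a row-major mesh.

    `mesh` is a flat list of global ranks laid out row-major with `shape`.
    (Hypothetical helper, for illustration only.)
    """
    # Row-major strides, e.g. shape (2, 4) -> strides [4, 1].
    strides = [math.prod(shape[i + 1:]) for i in range(len(shape))]
    other = [range(s) for i, s in enumerate(shape) if i != dim]
    groups = []
    for coords in product(*other):
        coords = list(coords)
        group = []
        for k in range(shape[dim]):
            # Re-insert the varying index at position `dim`.
            full = coords[:dim] + [k] + coords[dim:]
            group.append(mesh[sum(c * s for c, s in zip(full, strides))])
        groups.append(group)
    return groups


def group_of(mesh, shape, dim, rank):
    """Find the group along `dim` containing `rank` (stand-in for get_rank())."""
    for group in dim_groups(mesh, shape, dim):
        if rank in group:
            return group
    return None  # this rank is not part of the mesh at all


if __name__ == "__main__":
    # World mesh: 8 ranks laid out as (2, 4) from ranks 0..7.
    world = list(range(8))
    print(group_of(world, (2, 4), 1, 5))  # [4, 5, 6, 7]

    # Hypothetical non-world case: this replica actually owns global
    # ranks 4..7 as a (2, 2) mesh, but the layout is assumed to be 0..3.
    assumed = list(range(4))
    actual = [4, 5, 6, 7]
    print(group_of(assumed, (2, 2), 1, 5))  # None: rank 5 is not in the assumed mesh
    print(group_of(actual, (2, 2), 1, 5))   # [4, 5]
```

With the world mesh, the global rank maps cleanly onto a group. Once the mesh is only a subset of the job, the global rank no longer matches the assumed layout, so the inferred PG membership is wrong.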
DeviceMesh.from_group() can be used to supply the PG information manually. This is correct, but for an nD DeviceMesh it is impractical for users to work out all of that PG information themselves.
Combining init_device_mesh with DeviceMesh.from_group(), as in extend_device_mesh(), is likely to be wrong, because init_device_mesh requires the mesh to be a world mesh.
The proposed solution is to have TorchFT provide ft_init_device_mesh, which lies to DeviceMesh about the replicate dimension. However, this still seems incorrect, because the other dimensions will still get incorrect PG information due to the use of get_rank().
This is a tracking issue for anything related to getting FSDP working with torchtitan.