Skip to content

Commit

Permalink
Better wording
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinGleize committed Dec 13, 2024
1 parent 62f7265 commit 6122a65
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/fairseq2/gang.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def setup_2D_mesh_gangs(
First gang is the row in the mesh, second is the column.
For example, assuming 8 devices denoted by g0 to g7, calling this function
with ``row_length`` = 4 amounts to defining the 2D mesh
[[g0, g1, g2, g3], [g4, g5, g6, g7]] and creating 2 sets of gangs:
[[g0, g1, g2, g3], [g4, g5, g6, g7]] and making 2 sets of gangs:
2 gangs of size 4 (mesh rows):
[g0, g1, g2, g3], [g4, g5, g6, g7]
Expand All @@ -732,11 +732,11 @@ def setup_2D_mesh_gangs(
(for example, 2 hosts: one with g0 to g3, and the other with g4 to g7),
the first gang can be used to maximize local intra-host communication.
Example use-cases include creating tensor- and data- parallel gangs, or
Example use-cases include making tensor- and data- parallel gangs, or
sharding and replicating gangs in FSDP's hybrid sharding.
:param root_gang:
The gang whose topology will be used to create the new gangs.
The gang whose topology will be used to make the new gangs.
:param row_length:
The size of the gangs corresponding to the 2D mesh rows.
:param create_single_rank_process_groups:
Expand Down Expand Up @@ -774,7 +774,9 @@ def setup_2D_mesh_gangs(

output = {}

log.info("Initializing sub-gangs for a 2D device mesh of shape {}.", list(mesh_shape))
log.info(
"Initializing sub-gangs for a 2D device mesh of shape {}.", list(mesh_shape)
)
if dim_descriptions is None:
dim_descriptions = [f"dim-{dim}" for dim in range(2)]

Expand All @@ -783,7 +785,9 @@ def setup_2D_mesh_gangs(

gang_size = mesh_shape[1 - dim]

log.info("Initializing {} gang with a size of {}.", dim_descriptions[dim], gang_size)
log.info(
"Initializing {} gang with a size of {}.", dim_descriptions[dim], gang_size
)

# Match row length (dim 0) or column length (dim 1)
match gang_size:
Expand All @@ -801,7 +805,7 @@ def setup_2D_mesh_gangs(
sub_gang = root_gang.make_gang(ranks.tolist())
if i == rank_coords[dim]:
current_subgang = sub_gang

if current_subgang is None:
raise InternalError(f"`current_gang` ({dim_descriptions[dim]}) is `None`.")

Expand Down Expand Up @@ -843,16 +847,14 @@ def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Gangs:
raise GangError(
f"The number of processes in the root gang is expected to be a multiple of the tensor parallel size ({tp_size}), but is {root_gang.size} instead."
)

output_from_2D_mesh = setup_2D_mesh_gangs(
root_gang,
row_length=tp_size,
dim_descriptions=["tensor parallel", "data parallel"],
)

return Gangs(
root_gang, output_from_2D_mesh[1], output_from_2D_mesh[0]
)
return Gangs(root_gang, output_from_2D_mesh[1], output_from_2D_mesh[0])


def broadcast_flag(gang: Gang, flag: bool, source_rank: int = 0) -> bool:
Expand Down

0 comments on commit 6122a65

Please sign in to comment.