diff --git a/src/fairseq2/gang.py b/src/fairseq2/gang.py index b4d268a6e..4a91a0562 100644 --- a/src/fairseq2/gang.py +++ b/src/fairseq2/gang.py @@ -720,7 +720,7 @@ def setup_2D_mesh_gangs( First gang is the row in the mesh, second is the column. For example, assuming 8 devices denoted by g0 to g7, calling this function with ``row_length`` = 4 amounts to defining the 2D mesh - [[g0, g1, g2, g3], [g4, g5, g6, g7]] and creating 2 sets of gangs: + [[g0, g1, g2, g3], [g4, g5, g6, g7]] and making 2 sets of gangs: 2 gangs of size 4 (mesh rows): [g0, g1, g2, g3], [g4, g5, g6, g7] @@ -732,11 +732,11 @@ def setup_2D_mesh_gangs( (for example, 2 hosts: one with g0 to g3, and the other with g4 to g7), the first gang can be used to maximize local intra-host communication. - Example use-cases include creating tensor- and data- parallel gangs, or + Example use-cases include making tensor- and data- parallel gangs, or sharding and replicating gangs in FSDP's hybrid sharding. :param root_gang: - The gang whose topology will be used to create the new gangs. + The gang whose topology will be used to make the new gangs. :param row_length: The size of the gangs corresponding to the 2D mesh rows. :param create_single_rank_process_groups: @@ -774,7 +774,9 @@ def setup_2D_mesh_gangs( output = {} - log.info("Initializing sub-gangs for a 2D device mesh of shape {}.", list(mesh_shape)) + log.info( + "Initializing sub-gangs for a 2D device mesh of shape {}.", list(mesh_shape) + ) if dim_descriptions is None: dim_descriptions = [f"dim-{dim}" for dim in range(2)] @@ -783,7 +785,9 @@ def setup_2D_mesh_gangs( gang_size = mesh_shape[1 - dim] - log.info("Initializing {} gang with a size of {}.", dim_descriptions[dim], gang_size) + log.info( + "Initializing {} gang with a size of {}.", dim_descriptions[dim], gang_size + ) # Match row length (dim 0) or column length (dim 1) match gang_size: @@ -801,7 +805,7 @@ def setup_2D_mesh_gangs( sub_gang = root_gang.make_gang(ranks.tolist()) if i == rank_coords[dim]: current_subgang = sub_gang - + if current_subgang is None: raise InternalError(f"`current_gang` ({dim_descriptions[dim]}) is `None`.") @@ -843,16 +847,14 @@ def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Gangs: raise GangError( f"The number of processes in the root gang is expected to be a multiple of the tensor parallel size ({tp_size}), but is {root_gang.size} instead." ) - + output_from_2D_mesh = setup_2D_mesh_gangs( root_gang, row_length=tp_size, dim_descriptions=["tensor parallel", "data parallel"], ) - return Gangs( - root_gang, output_from_2D_mesh[1], output_from_2D_mesh[0] - ) + return Gangs(root_gang, output_from_2D_mesh[1], output_from_2D_mesh[0]) def broadcast_flag(gang: Gang, flag: bool, source_rank: int = 0) -> bool: