diff --git a/examples/huggingface/pippy_gpt2.py b/examples/huggingface/pippy_gpt2.py
index e19ef7974..2a600cc4e 100644
--- a/examples/huggingface/pippy_gpt2.py
+++ b/examples/huggingface/pippy_gpt2.py
@@ -21,7 +21,7 @@ def run(args):
     config.n_embd = args.n_embd or config.n_embd
     config.n_layer = args.n_layer or config.n_layer
     config.n_head = args.n_head or config.n_head
-    print("Using device:", args.device)
+    print("[Rank {}] Using device: {}".format(args.rank, args.device))
 
     # Create model
     model_class = GPT2ForSequenceClassification
@@ -38,6 +38,8 @@ def run(args):
     example_inputs = generate_inputs_for_model(
         model_class, gpt2, model_name, args.batch_size, args.device)
 
+    assert not args.autosplit or not args.graphsplit
+
     split_policy = None
     split_spec = None
 
@@ -45,6 +47,10 @@
         # Automatic split
        from pippy import split_into_equal_size
         split_policy = split_into_equal_size(args.world_size)
+    elif args.graphsplit:
+        # Graph-based split
+        from pippy import split_by_graph
+        split_policy = split_by_graph(args.world_size)
     else:
         # Use manual split spec
         decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size
@@ -106,6 +112,7 @@ def run(args):
     parser.add_argument('--n_layer', type=int, default=None)
     parser.add_argument('--n_head', type=int, default=None)
     parser.add_argument('--autosplit', action="store_true")
+    parser.add_argument('--graphsplit', action="store_true")
 
     args = parser.parse_args()
diff --git a/pippy/ModelSplit.py b/pippy/ModelSplit.py
index bb0e8be1e..db5082e0a 100644
--- a/pippy/ModelSplit.py
+++ b/pippy/ModelSplit.py
@@ -5,6 +5,8 @@ import torch
 import torch.fx as fx
 
+from pippy.graphsplit import split_by_graph_with_num_stages
+
 from ._IR import aten_pipe_split_alias
diff --git a/pippy/_IR.py b/pippy/_IR.py
index 5e2e9c23a..639e41e97 100644
--- a/pippy/_IR.py
+++ b/pippy/_IR.py
@@ -925,10 +925,10 @@ def set_multi_use_param_spec(
         if isinstance(multi_use_param_spec, MultiUseParameterConfig):
             multi_use_params_qualnames[param] = multi_use_param_spec
         elif isinstance(multi_use_param_spec, dict):
-            multi_use_params_qualnames[
-                param
-            ] = multi_use_param_spec.get(
-                param, MultiUseParameterConfig.TRANSMIT
+            multi_use_params_qualnames[param] = (
+                multi_use_param_spec.get(
+                    param, MultiUseParameterConfig.TRANSMIT
+                )
             )
         else:
             raise ValueError(
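Reviewer note: the example hunks above make `--autosplit` and `--graphsplit` mutually exclusive (the new assert) and route the latter to the new `split_by_graph` policy. A minimal sketch of the resulting selection logic; the `Namespace` literal at the bottom is illustrative only, not part of the patch:

```python
from argparse import Namespace

def choose_split_policy(args: Namespace):
    # --autosplit and --graphsplit are mutually exclusive, per the assert above
    assert not args.autosplit or not args.graphsplit
    if args.autosplit:
        from pippy import split_into_equal_size
        return split_into_equal_size(args.world_size)
    if args.graphsplit:
        from pippy import split_by_graph
        return split_by_graph(args.world_size)
    return None  # caller falls back to the manual split spec

policy = choose_split_policy(
    Namespace(autosplit=False, graphsplit=True, world_size=4)
)
```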
diff --git a/pippy/graphsplit.py b/pippy/graphsplit.py
index 0acaa45e1..1b5aa01db 100644
--- a/pippy/graphsplit.py
+++ b/pippy/graphsplit.py
@@ -28,21 +28,26 @@
 @dataclass
 class Node:
     name: str
-    weight: int
+    memory_weight: int
+    comm_weight: int
     stage: Optional[int]
-    gm_node: fx.Node
+    gm_nodes: List[fx.Node]
+
+    def __hash__(self):
+        return hash(self.name)
 
 
 @dataclass
 class Edge:
     source: int
     target: int
-    weight: int
+    comm_weight: int
 
 
-MAX_MEMORY_IMBALANCE = 2.0
-MAX_COMMUNICATION_IMBALANCE = 1.05
-SCIPY_TIME_LIMIT_SEC = 60
+MAX_MEMORY_IMBALANCE = 2.5
+MAX_COMMUNICATION_IMBALANCE = 1.1
+SCIPY_TIME_LIMIT_SEC = 30
+PRESOLVE = True
 
 
 """
@@ -75,6 +80,11 @@ def split_by_graph_with_num_stages(
     # Extract the graph data
     nodes, edges = _build_splitting_graph(gm, node_param_sizes)
 
+    # Pre-process the input graph by merging pairs of nodes that need to be in
+    # the same stage. This reduces the size of the instance for the main solver.
+    if PRESOLVE:
+        nodes, edges = _split_presolve(nodes, edges)
+
     # Run the splitting algorithm with the specified options
     _split_by_milp(
         nodes,
@@ -88,7 +98,14 @@
     if PIPPY_VERBOSITY == "DEBUG":
         _print_splitting_stats(nodes, edges, num_stages)
 
-    return {n.gm_node: n.stage for n in nodes if n.stage is not None}
+    # Prepare the result
+    result = {}
+    for node in nodes:
+        if node.stage is None:
+            continue
+        for gm_node in node.gm_nodes:
+            result[gm_node] = node.stage
+    return result
 
 
 """
@@ -119,15 +136,15 @@
             weight = sum(v for _, v in node_param_sizes[node].items())
         else:
             weight = 0
-        if "example_value" in node.meta:
-            tensors = node.meta["example_value"]
+        if "val" in node.meta:
+            tensors = node.meta["val"]
             if isinstance(tensors, torch.Tensor):
                 tensors = [tensors]
             activation_size[node] = sum(t.numel() for t in tensors)
         else:
             activation_size[node] = 0
         node_index[node.name] = len(nodes)
-        nodes.append(Node(node.name, weight, None, node))
+        nodes.append(Node(node.name, weight, 0, None, [node]))
 
     # Build edges
     edges: List[Edge] = []
@@ -144,10 +161,10 @@
 
     # Verify the collected data
     assert (
-        sum(node.weight for node in nodes) > 0
+        sum(node.memory_weight for node in nodes) > 0
     ), "node weights cannot be empty"
     assert (
-        sum(edge.weight for edge in edges) > 0
+        sum(edge.comm_weight for edge in edges) > 0
     ), "edge weights cannot be empty"
 
     return nodes, edges
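To make the reworked data model concrete, here is a standalone toy version of the `Node`/`Edge` records that `_build_splitting_graph` now produces. Plain strings stand in for `fx.Node`, the weights are element counts as in the hunks above, and the example values are made up:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Node:
    name: str
    memory_weight: int    # parameter size owned by the (possibly merged) node
    comm_weight: int      # activation volume charged to the node itself
    stage: Optional[int]  # assigned later by the MILP solver
    gm_nodes: List[str]   # stand-in for List[fx.Node]; >1 entry after presolve

    def __hash__(self):   # explicit hash: nodes serve as dict keys in presolve
        return hash(self.name)

@dataclass
class Edge:
    source: int           # index of the producing node in the node list
    target: int           # index of the consuming node
    comm_weight: int      # activation volume sent along the edge

# A three-operator chain with one activation tensor per hop
nodes = [
    Node("embed", 1_000, 0, None, ["embed"]),
    Node("block", 5_000, 0, None, ["block"]),
    Node("head", 1_000, 0, None, ["head"]),
]
edges = [Edge(0, 1, 512), Edge(1, 2, 512)]
```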
@@ -179,7 +196,7 @@ def _split_by_milp(
     M = len(edges)
     K = num_stages
     logger.info(
-        "Splitting graph with {} nodes and {} edges into {} stages".format(
+        "Splitting a graph with {} nodes and {} edges into {} stages".format(
             N, M, K
         )
     )
@@ -237,14 +254,6 @@ def edge_var(edge_idx: int, stage: int) -> int:
         constraints.append(LinearConstraint(A=A, lb=1, ub=1))
 
     # Constraint 3:
-    # - every stage contains at least one edge;
-    for j in range(K):
-        A = zeros(num_vars)
-        for i in range(M):
-            A[edge_var(i, j)] = 1
-        constraints.append(LinearConstraint(A=A, lb=1))
-
-    # Constraint 4:
     # - edges go from a lower-index stage to an upper-index stage;
     multiplier = [2 ** (K - j - 1) for j in range(K)]
     for i in range(M):
         A = zeros(num_vars)
         for j in range(K):
             A[node_var(edge.source, j)] = -multiplier[j]
         constraints.append(LinearConstraint(A=A, ub=0))
 
-    # Constraint 5:
+    # Constraint 4:
     # - nodes in every stage have (approximately) the same total weight;
-    sum_node_weights = sum(node.weight for node in nodes)
+    sum_node_weights = sum(node.memory_weight for node in nodes)
     max_node_weight_per_stage = (
         sum_node_weights * allowed_node_imbalance / float(K)
     )
     for j in range(K):
         A = zeros(num_vars)
         for i in range(N):
-            A[node_var(i, j)] = nodes[i].weight
+            A[node_var(i, j)] = nodes[i].memory_weight
         constraints.append(LinearConstraint(A=A, ub=max_node_weight_per_stage))
 
-    # Constraint 6:
+    # Constraint 5:
     # - edges in every stage have (approximately) the same total weight;
-    sum_edge_weights = sum(edge.weight for edge in edges)
+    sum_edge_weights = sum(edge.comm_weight for edge in edges) + sum(
+        node.comm_weight for node in nodes
+    )
     max_edge_weight_per_stage = (
         sum_edge_weights * allowed_edge_imbalance / float(K)
     )
     for j in range(K):
         A = zeros(num_vars)
         for i in range(M):
-            A[edge_var(i, j)] = edges[i].weight
+            A[edge_var(i, j)] = edges[i].comm_weight
+        for i in range(N):
+            A[node_var(i, j)] = nodes[i].comm_weight
         constraints.append(LinearConstraint(A=A, ub=max_edge_weight_per_stage))
 
     # Define the optimization objective:
-    # - the auxiliary variable equals to the maximum total edge weight in a stage;
+    # - the auxiliary variable equals the maximum total edge weight in a stage;
     edge_weight_per_stage = sum_edge_weights / float(K)
     for j in range(K):
         A = zeros(num_vars)
         A[edge_aux_var] = -edge_weight_per_stage
         for i in range(M):
-            A[node_var(edges[i].source, j)] += edges[i].weight
-            A[node_var(edges[i].target, j)] += edges[i].weight
-            A[edge_var(i, j)] = -edges[i].weight
+            edge = edges[i]
+            A[node_var(edge.source, j)] += edge.comm_weight
+            A[node_var(edge.target, j)] += edge.comm_weight
+            A[edge_var(i, j)] = -edge.comm_weight
+        for i in range(N):
+            A[node_var(i, j)] += nodes[i].comm_weight
         constraints.append(LinearConstraint(A=A, ub=0))
 
     # - minimize the total weight of inter-stage edges;
     c = zeros(num_vars)
-    for i in range(M):
-        for j in range(K):
-            c[edge_var(i, j)] = -edges[i].weight
+    for j in range(K):
+        for i in range(M):
+            c[edge_var(i, j)] = -edges[i].comm_weight - 1
+        for i in range(N):
+            c[node_var(i, j)] = -nodes[i].comm_weight
     c[edge_aux_var] = edge_weight_per_stage
 
     # Solve the MILP problem using scipy
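All of the constraint blocks above feed `scipy.optimize.milp`. Below is a self-contained miniature of that call pattern (1-D `LinearConstraint` rows, integrality flags, a time limit mirroring `SCIPY_TIME_LIMIT_SEC`) on a toy two-item assignment problem; it demonstrates the solver API only, not the actual splitting formulation:

```python
import numpy as np
from scipy.optimize import Bounds, LinearConstraint, milp

# Variables: x[2 * i + j] == 1 iff item i goes to stage j (2 items, 2 stages)
c = np.array([1.0, 3.0, 4.0, 1.0])  # toy placement costs to minimize

constraints = []
for i in range(2):  # each item is assigned to exactly one stage
    A = np.zeros(4)
    A[2 * i] = A[2 * i + 1] = 1
    constraints.append(LinearConstraint(A=A, lb=1, ub=1))
for j in range(2):  # each stage receives exactly one item
    A = np.zeros(4)
    A[j] = A[2 + j] = 1
    constraints.append(LinearConstraint(A=A, lb=1, ub=1))

res = milp(
    c=c,
    constraints=constraints,
    integrality=np.ones(4),  # every variable is integral (binary via bounds)
    bounds=Bounds(0, 1),
    options={"time_limit": 30},
)
print(res.x)  # ~[1, 0, 0, 1]: item 0 -> stage 0, item 1 -> stage 1
```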
@@ -344,6 +362,97 @@ def edge_var(edge_idx: int, stage: int) -> int:
     )
 
 
+"""
+Pre-solve the splitting problem by merging nodes that need to be in the same
+stage. The algorithm greedily merges edges in decreasing order of weight,
+collapsing sources with a unique successor, sinks with a unique predecessor,
+and chains of degree-1 nodes.
+"""
+
+
+def _split_presolve(nodes: List[Node], edges: List[Edge]):
+    # Count the in- and out-degree of each node
+    in_degree: Dict[Node, int] = defaultdict(int)
+    out_degree: Dict[Node, int] = defaultdict(int)
+    for edge in edges:
+        out_degree[nodes[edge.source]] += 1
+        in_degree[nodes[edge.target]] += 1
+
+    # Initialize singleton clusters
+    clusters: List[List[Node]] = []
+    node2cluster: Dict[Node, int] = defaultdict(int)
+    for node in nodes:
+        node2cluster[node] = len(clusters)
+        clusters.append([node])
+
+    def should_merge_edge(src, dst):
+        """Decide whether the edge src->dst should be merged at pre-solving"""
+        # already merged
+        if node2cluster[src] == node2cluster[dst]:
+            return False
+        # always merge sources having a unique successor
+        if in_degree[src] == 0 and out_degree[src] == 1:
+            return True
+        # always merge sinks having a unique predecessor
+        if in_degree[dst] == 1 and out_degree[dst] == 0:
+            return True
+        # merge chains of degree-1 nodes
+        if in_degree[src] == 1 and out_degree[src] == 1 and in_degree[dst] == 1:
+            return True
+        return False
+
+    # Merge edges in decreasing order of their weight
+    sorted_edges = sorted(edges, key=lambda e: e.comm_weight, reverse=True)
+    for edge in sorted_edges:
+        src = nodes[edge.source]
+        dst = nodes[edge.target]
+        if not should_merge_edge(src, dst):
+            continue
+        cluster_src = clusters[node2cluster[src]]
+        cluster_dst = clusters[node2cluster[dst]]
+        cluster_src.extend(cluster_dst)
+        for node_dst in cluster_dst:
+            node2cluster[node_dst] = node2cluster[src]
+        cluster_dst.clear()
+
+    # Collect the resulting nodes
+    merged_nodes: List[Node] = []
+    node_index = {}
+    for cluster in clusters:
+        if len(cluster) == 0:
+            continue
+        name = cluster[0].name
+        gm_nodes = []
+        for node in cluster:
+            node_index[node.name] = len(merged_nodes)
+            gm_nodes.extend(node.gm_nodes)
+        mem_weight = sum(node.memory_weight for node in cluster)
+        comm_weight = sum(
+            edge.comm_weight
+            for edge in edges
+            if nodes[edge.source] in cluster and nodes[edge.target] in cluster
+        )
+        merged_nodes.append(Node(name, mem_weight, comm_weight, None, gm_nodes))
+
+    # Collect the resulting edges
+    merged_edges: List[Edge] = []
+    for edge in edges:
+        src = nodes[edge.source]
+        dst = nodes[edge.target]
+        if node2cluster[src] == node2cluster[dst]:
+            continue
+        source_idx = node_index[src.name]
+        target_idx = node_index[dst.name]
+        merged_edges.append(Edge(source_idx, target_idx, edge.comm_weight))
+
+    logger.info(
+        "merged {} nodes and {} edges; max cluster has size {}".format(
+            len(nodes) - len(merged_nodes),
+            len(edges) - len(merged_edges),
+            max(len(c) for c in clusters),
+        )
+    )
+    return merged_nodes, merged_edges
+
+
 def _print_splitting_stats(nodes, edges, num_stages):
     """Compute and print various stats related to graph splitting"""
     print(
@@ -351,7 +460,7 @@ def _print_splitting_stats(nodes, edges, num_stages):
             len(nodes), len(edges), num_stages
         )
     )
-    sum_node_weights = sum(node.weight for node in nodes)
+    sum_node_weights = sum(node.memory_weight for node in nodes)
     node_bound_per_stage = sum_node_weights / float(num_stages)
     print(
         "Max allowed node weight per stage: {:,} ({})".format(
@@ -359,7 +468,9 @@ def _print_splitting_stats(nodes, edges, num_stages):
             MAX_MEMORY_IMBALANCE,
         )
     )
-    sum_edge_weights = sum(edge.weight for edge in edges)
+    sum_edge_weights = sum(edge.comm_weight for edge in edges) + sum(
+        node.comm_weight for node in nodes
+    )
     edge_bound_per_stage = sum_edge_weights / float(num_stages)
     print(
         "Max allowed edge weight per stage: {:,} ({})".format(
@@ -375,13 +486,15 @@ def _print_splitting_stats(nodes, edges, num_stages):
     # Extract nodes/edges/weight per stage
     num_nodes: Dict[int, int] = defaultdict(int)
     num_edges: Dict[int, int] = defaultdict(int)
-    node_weight: Dict[int, int] = defaultdict(int)
-    edge_weight: Dict[int, int] = defaultdict(int)
+    mem_weight: Dict[int, int] = defaultdict(int)
+    com_weight: Dict[int, int] = defaultdict(int)
     adj_weight: Dict[int, int] = defaultdict(int)
     for node in nodes:
-        num_nodes[node.stage] += 1
-        node_weight[node.stage] += node.weight
+        num_nodes[node.stage] += len(node.gm_nodes)
+        mem_weight[node.stage] += node.memory_weight
+        com_weight[node.stage] += node.comm_weight
+        adj_weight[node.stage] += node.comm_weight
 
     cross_weight = 0
     cross_edges = 0
@@ -393,38 +506,39 @@ def _print_splitting_stats(nodes, edges, num_stages):
         dst_stage = nodes[edge.target].stage
         if src_stage == dst_stage:
             num_edges[src_stage] += 1
-            edge_weight[src_stage] += edge.weight
-            adj_weight[src_stage] += edge.weight
+            com_weight[src_stage] += edge.comm_weight
+            adj_weight[src_stage] += edge.comm_weight
         else:
-            cross_weight += edge.weight
+            cross_weight += edge.comm_weight
             cross_edges += 1
-            cross_stage_weight[src_stage][dst_stage] += edge.weight
-            adj_weight[src_stage] += edge.weight
-            adj_weight[dst_stage] += edge.weight
+            cross_stage_weight[src_stage][dst_stage] += edge.comm_weight
+            adj_weight[src_stage] += edge.comm_weight
+            adj_weight[dst_stage] += edge.comm_weight
 
     # Print the stats
+    total_nodes = sum(len(node.gm_nodes) for node in nodes)
     for stage in range(num_stages):
         print("  Stage {}:".format(stage), end="")
         print(
-            " #nodes = {:3d} ({:.1f}%); node_weight = {:,} ({:5.1f}%);".format(
+            " #nodes = {:3d} ({:.1f}%); mem_weight = {:,} ({:5.1f}%);".format(
                 num_nodes[stage],
-                100.0 * num_nodes[stage] / len(nodes),
-                node_weight[stage],
-                100.0 * node_weight[stage] / node_bound_per_stage,
+                100.0 * num_nodes[stage] / total_nodes,
+                mem_weight[stage],
+                100.0 * mem_weight[stage] / node_bound_per_stage,
             ),
             end="",
         )
         print(
-            " #edges = {:3d} ({:.1f}%);".format(
+            " #edges = {:3d} ({:4.1f}%);".format(
                 num_edges[stage],
                 100.0 * num_edges[stage] / len(edges),
             ),
             end="",
         )
         print(
-            " edge_weight = {:,} ({:5.1f}%); adj_weight = {:,} ({:5.1f}%)".format(
-                edge_weight[stage],
-                100.0 * edge_weight[stage] / edge_bound_per_stage,
+            " com_weight = {:,} ({:5.1f}%); adj_weight = {:,} ({:5.1f}%)".format(
+                com_weight[stage],
+                100.0 * com_weight[stage] / edge_bound_per_stage,
                 adj_weight[stage],
                 100.0 * adj_weight[stage] / edge_bound_per_stage,
             )
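For intuition about the `_split_presolve` pass added above, here is a standalone miniature of the same greedy clustering, using a small union-find over node names instead of the patch's cluster lists. The degree rules mirror `should_merge_edge`; the toy DAG and its weights are made up:

```python
from collections import defaultdict

# Toy DAG: chain head s -> a, diamond a -> {b, c} -> d, tail d -> t,
# written as (src, dst, comm_weight) triples.
edges = [("s", "a", 9), ("a", "b", 8), ("a", "c", 7),
         ("b", "d", 6), ("c", "d", 5), ("d", "t", 4)]

in_deg = defaultdict(int)
out_deg = defaultdict(int)
for s, t, _ in edges:
    out_deg[s] += 1
    in_deg[t] += 1

parent = {}  # tiny union-find keyed by node name

def find(x):
    parent.setdefault(x, x)
    while parent[x] != x:
        parent[x] = parent[parent[x]]  # path halving
        x = parent[x]
    return x

def should_merge(src, dst):
    """Same degree rules as should_merge_edge in the patch."""
    if find(src) == find(dst):
        return False  # already in one cluster
    if in_deg[src] == 0 and out_deg[src] == 1:
        return True   # source with a unique successor
    if in_deg[dst] == 1 and out_deg[dst] == 0:
        return True   # sink with a unique predecessor
    return in_deg[src] == 1 and out_deg[src] == 1 and in_deg[dst] == 1

# Merge in decreasing weight order, as the patch does
for src, dst, _ in sorted(edges, key=lambda e: e[2], reverse=True):
    if should_merge(src, dst):
        parent[find(dst)] = find(src)

print({n: find(n) for n in "sabcdt"})
# -> {'s': 's', 'a': 's', 'b': 'b', 'c': 'c', 'd': 'd', 't': 'd'}
# Six nodes collapse into four clusters; the diamond interior is
# left for the MILP to place.
```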
diff --git a/test/test_graphsplit.py b/test/test_graphsplit.py
index aa746a69a..31c40dd25 100644
--- a/test/test_graphsplit.py
+++ b/test/test_graphsplit.py
@@ -7,9 +7,7 @@
 import torch
 import torch.distributed as dist
 
-from pippy import pipeline, split_by_graph
-from pippy.PipelineSchedule import ScheduleGPipe
-from pippy.PipelineStage import PipelineStage
+from pippy import pipeline, PipelineStage, ScheduleGPipe, split_by_graph
 
 pippy.microbatch._debug_mask_minibatches = True
@@ -48,11 +46,13 @@ def run_worker(args):
 
     x = torch.randn(batch_size, d_hid, device=args.device)
 
+    split_policy = split_by_graph(args.world_size)
+
     pipe = pipeline(
         mod,
         args.chunks,
         example_args=(x,),
-        split_policy=split_by_graph(args.world_size, args.rank),
+        split_policy=split_policy,
     )
 
     # Check returned number of stages
diff --git a/test/test_stage_backward.py b/test/test_stage_backward.py
index 06a85cb47..dcf8e9dac 100644
--- a/test/test_stage_backward.py
+++ b/test/test_stage_backward.py
@@ -10,6 +10,7 @@
 d_hid = 512
 batch_size = 256
 
+
 # MLP as a stage module
 class MLPModule(torch.nn.Module):
     def __init__(self, d_hid):
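For completeness, this is roughly how the updated test drives the pipeline end to end. The `PipelineStage`/`ScheduleGPipe` wiring below follows the usual PiPPy driver pattern from the examples; treat the exact signatures as assumptions rather than part of this diff:

```python
import torch
from pippy import PipelineStage, ScheduleGPipe, pipeline, split_by_graph

d_hid, batch_size = 512, 256  # module-level constants, as in the tests

def run_worker(args, mod):
    x = torch.randn(batch_size, d_hid, device=args.device)

    pipe = pipeline(
        mod,
        args.chunks,
        example_args=(x,),
        # split_by_graph now takes only the world size; the per-rank
        # argument was dropped by this patch
        split_policy=split_by_graph(args.world_size),
    )
    stage = PipelineStage(pipe, args.rank, device=args.device)  # assumed API
    schedule = ScheduleGPipe(stage, args.chunks)                # assumed API
    # Only rank 0 feeds real inputs; the other ranks receive activations
    return schedule.step(x) if args.rank == 0 else schedule.step()
```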