diff --git a/mmengine/model/wrappers/pipeline_distributed.py b/mmengine/model/wrappers/pipeline_distributed.py
index f1b87eede3..cd12e3d71c 100644
--- a/mmengine/model/wrappers/pipeline_distributed.py
+++ b/mmengine/model/wrappers/pipeline_distributed.py
@@ -4,17 +4,17 @@
 from functools import partial
 from queue import Queue
 from threading import Event, Thread, current_thread
-from typing import Dict, List, Iterable, Union, Optional, Tuple
+from typing import Any, Dict, List, Union, Optional, Tuple

 import psutil
 import torch
 import torch.nn as nn

-from mmengine.analysis import get_model_complexity_info
+from mmengine.analysis import FlopAnalyzer
 from mmengine.config import Config, ConfigDict
-from mmengine.registry import MODEL_WRAPPERS
+from mmengine.registry import MODELS, MODEL_WRAPPERS

-ModelType = Union[Config, ConfigDict, dict, nn.Module, str]
+ModelType = Union[Config, ConfigDict, dict, nn.Module]


 @MODEL_WRAPPERS.register_module()
@@ -29,7 +29,7 @@ class MMPipelineParallel(nn.Module):
     hook_visited_times: Dict[str, int] = {}

     def __init__(self,
-                 model: Optional[ModelType] = None,
+                 model: ModelType,
                  weights: Optional[str] = None,
                  num_pipelines: Optional[int] = None,
                  num_chunks: Optional[int] = None,
@@ -37,7 +37,7 @@ def __init__(self,
                  memory_map: Optional[Dict[str, int]] = None,
                  no_split_module_classes: Optional[List[str]] = None,
                  language_module_classes: Optional[str] = None,
-                 device_map: Union[str, Dict[str, str]] = 'auto',
+                 device_map: Union[str, Dict[str, dict]] = 'auto',
                  offload_directory: Optional[str] = None,
                  exec_entry: str = 'forward'):
@@ -64,12 +64,18 @@ def __init__(self,
         self.no_split_module_classes = [] if no_split_module_classes is None \
             else no_split_module_classes
         self.language_module_classes = language_module_classes
+        self.model_tree = self._get_model_tree()

         if isinstance(device_map, dict):
-            self.device_map = self._check_device_map(device_map)
+            self.device_map = device_map
+            self.offload_map = self._init_offload_map()
+            self.module_map = self._init_module_map()
+            self._inited = True
         else:
-            self.device_map = self._init_device_map(device_map)
-            self.offload_map = self._init_offload_map()
-            self.module_map = self._init_module_map()
+            # the device map is initialized lazily, on the first forward,
+            # once the input shapes are known
+            self.device_map_policy = device_map
+            self.offload_map = None
+            self.module_map = None
+            self._inited = False

         # init offload directory
         self.offload_directory = offload_directory
@@ -84,18 +90,27 @@ def __init__(self,
         # init stream contexts
         MMPipelineParallel.stream_contexts = self._init_stream_contexts()

-        # load weights, register hooks and dispatch model
-        self._load_and_dispatch(weights)
-        self._register_hooks()
-
     def _init_model(self, model: ModelType) -> nn.Module:
         """
         TODO
         """
+        if isinstance(model, nn.Module):
+            return model.to('meta')
+        elif isinstance(model, dict):
+            cfg = ConfigDict(model)
+        elif isinstance(model, (Config, ConfigDict)):
+            cfg = model
+        else:
+            raise TypeError(
+                f'Unsupported model type {type(model)}'
+            )
+        with MMPipelineParallel._init_empty():
+            model = MODELS.build(cfg)
+        return model

-    @contextmanager
-    def _init_empty(self):
+    @staticmethod
+    @contextmanager
+    def _init_empty():
         """
         TODO
         """
@@ -183,49 +198,198 @@ def _convert_memory_map(self,
             )
         return memory_map

-    def _check_device_map(self, device_map: Dict[str, str]) -> Dict[str, str]:
+    def _get_model_tree(self) -> Dict[str, Any]:
         """
         TODO
         """
-        pass
+        def dfs(module: nn.Module,
+                prefix: str,
+                info: Dict[str, Any]):
+            # self
+            info['self'] = module
+            info['parameter_size'] = self._parameter_size(module)
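+            # flops and exec_order are unknown until the first real input
+            # arrives; _get_flops and _get_exec_order fill them in lazily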
+            info['flops'] = None
+            info['exec_order'] = None
+            info['checked'] = False
+            # buffer
+            if len(module._buffers) != 0:
+                info['buffers'] = {}
+                for name, buffer in module._buffers.items():
+                    curr_name = name if prefix == '' else f'{prefix}.{name}'
+                    info['buffers'][curr_name] = buffer
+            # submodule
+            module_class_name = module.__class__.__name__
+            if not (len(module._modules) == 0 or
+                    module_class_name in self.no_split_module_classes):
+                info['submodules'] = {}
+                for name, submodule in module._modules.items():
+                    curr_name = name if prefix == '' else f'{prefix}.{name}'
+                    info['submodules'][curr_name] = {}
+                    dfs(submodule, curr_name, info['submodules'][curr_name])
+        tree = {}
+        dfs(self.model, '', tree)
+        return tree
+
+    def _parameter_size(self, module: nn.Module) -> int:
+        """
+        TODO
+        """
+        size = 0
+        for _, param in module.named_parameters():
+            size += param.nelement() * param.element_size()
+        return size

-    def _init_device_map(self, device_map: str) -> Dict[str, str]:
+    def _find_tied_weights(self) -> List[List[str]]:
         """
         TODO
         """
         pass

-    def _init_offload_map(self) -> Dict[str, str]:
+    def _get_meta_data(self, data_sample: Tuple[tuple, dict]):
         """
         TODO
         """
-        pass
+        args, kwargs = data_sample
+        args_meta = []
+        for arg in args:
+            if isinstance(arg, torch.Tensor):
+                args_meta.append(arg.to('meta'))
+            else:
+                args_meta.append(arg)
+        kwargs_meta = {}
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                kwargs_meta[k] = v.to('meta')
+            else:
+                kwargs_meta[k] = v
+        data_meta = (tuple(args_meta), kwargs_meta)
+        return data_meta

-    def _init_module_map(self) -> Dict[str, str]:
+    def _get_flops(self, data_sample: Tuple[tuple, dict]):
+        """
+        TODO
+        """
+        data_meta = self._get_meta_data(data_sample)
+        flop_analyzer = FlopAnalyzer(self.model, inputs=data_meta)
+        flops = flop_analyzer.by_module()
+        # merge into model tree
+        for name, num_flops in flops.items():
+            tree = self.model_tree
+            if name == '':
+                tree['flops'] = num_flops
+            else:
+                name_split = name.split('.')
+                for i in range(len(name_split)):
+                    curr_name = '.'.join(name_split[:i + 1])
+                    # modules inside no_split_module_classes have no
+                    # entry in the tree, so skip them entirely
+                    if 'submodules' not in tree or \
+                            curr_name not in tree['submodules']:
+                        break
+                    if i == len(name_split) - 1:
+                        tree['submodules'][curr_name]['flops'] = num_flops
+                    else:
+                        tree = tree['submodules'][curr_name]
+
+    def _get_exec_order(self, data_sample: Tuple[tuple, dict]):
+        """
+        TODO
+        """
+        exec_order = []
+        # map module objects to their qualified names so the hook records
+        # dotted names that can be matched against the model tree
+        module_names = {m: n for n, m in self.model.named_modules()}
+
+        def return_module_name_hook(module: nn.Module, args: tuple):
+            # global forward pre-hooks receive (module, args) only
+            if module in module_names:
+                exec_order.append(module_names[module])
+
+        handle = nn.modules.module.register_module_forward_pre_hook(
+            return_module_name_hook
+        )
+
+        data_meta = self._get_meta_data(data_sample)
+        with torch.no_grad():
+            self.entry(*data_meta[0], **data_meta[1])
+
+        handle.remove()
+        # merge into model tree
+        for order, name in enumerate(exec_order):
+            tree = self.model_tree
+            if name == '':
+                tree['exec_order'] = order
+            else:
+                name_split = name.split('.')
+                for i in range(len(name_split)):
+                    curr_name = '.'.join(name_split[:i + 1])
+                    # due to no_split_module_classes
+                    if 'submodules' not in tree or \
+                            curr_name not in tree['submodules']:
+                        break
+                    if i == len(name_split) - 1:
+                        tree['submodules'][curr_name]['exec_order'] = order
+                    else:
+                        tree = tree['submodules'][curr_name]
+
+    def _init_device_map(self, device_map: str) -> Dict[str, dict]:
         """
         TODO
         """
         pass

+    def _init_offload_map(self) -> Dict[int, int]:
+        """
+        TODO
+        """
+        curr_part_id = -1
+        offload_map = {}
+        for info in self.device_map.values():
+            if info['part_id'] != curr_part_id:
+                curr_part_id = info['part_id']
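+                # 0 means the part starts offloaded (cpu/disk),
+                # 1 means it starts resident on its execution device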
+                if info['init_device'] == 'cpu' or \
+                        info['init_device'] == 'disk':
+                    offload_map[curr_part_id] = 0
+                else:
+                    offload_map[curr_part_id] = 1
+        return offload_map
+
+    def _init_module_map(self) -> Dict[str, dict]:
+        """
+        TODO
+        """
+        module_map = {}
+        for name, info in self.device_map.items():
+            name_split = name.split('.')
+            tree = self.model_tree
+            module = None
+            for i in range(len(name_split)):
+                curr_name = '.'.join(name_split[:i + 1])
+                if i == len(name_split) - 1:
+                    module = tree['submodules'][curr_name]['self']
+                else:
+                    tree = tree['submodules'][curr_name]
+            module_map[name] = {
+                'module': module,
+                'curr_device': info['init_device'],
+                'part_id': info['part_id'],
+            }
+        return module_map
+
     def _init_queues(self) -> Dict[str, Queue]:
         """
         TODO
         """
         in_queues, out_queues = {}, {}
-        # init move queues
-        for move in ['in', 'out']:
+        # init move queues, one worker per pipeline part
+        for i in range(self.num_pipelines):
             in_queue = Queue()
             out_queue = Queue()
             thread = Thread(
                 target=MMPipelineParallel._worker,
                 args=(in_queue, out_queue),
-                name=f'move-{move}',
+                name=f'part-{i}',
                 daemon=True
             )
             thread.start()
-            in_queues[f'move-{move}'] = in_queue
-            out_queues[f'move-{move}'] = out_queue
+            in_queues[f'part-{i}'] = in_queue
+            out_queues[f'part-{i}'] = out_queue
         # init execution queues
         for i in range(self.num_chunks):
             in_queue = Queue()
@@ -257,12 +421,29 @@ def _init_stream_contexts(self) -> List[torch.cuda.StreamContext]:
         """
         TODO
         """
-        pass
+        curr_part_id = -1
+        inited_streams = {}
+        stream_contexts = []
+        for info in self.device_map.values():
+            if info['part_id'] != curr_part_id:
+                curr_part_id = info['part_id']
+                exec_device = info['exec_device']
+                if exec_device not in inited_streams:
+                    stream = torch.cuda.Stream(exec_device)
+                    inited_streams[exec_device] = stream
+                else:
+                    stream = inited_streams[exec_device]
+                stream_context = torch.cuda.stream(stream)
+                stream_contexts.append(stream_context)
+        return stream_contexts

     def _load_and_dispatch(self, weights: Optional[str] = None):
         """
         TODO
         """
+        if weights is None:
+            return
+        # load weights
         pass

     def _register_hooks(self):
@@ -288,7 +469,7 @@ def _worker(in_queue: Queue, out_queue: Queue):
                 done = (False, None)
                 out_queue.put(done)

-    def _clock(self, num_chunks: int):
+    def _clock(self, num_chunks: int) -> List[Union[Tuple[int, int], str]]:
         """
         TODO
         """
@@ -330,7 +511,7 @@ def __init__(self,
     def __call__(self,
                  module: nn.Module,
                  args: tuple,
-                 kwargs: dict):
+                 kwargs: dict) -> Tuple[tuple, dict]:
         """
         TODO
         """
@@ -392,14 +573,59 @@ def _enter_curr_part(self):
             MMPipelineParallel.stream_contexts[
                 self.part_id].__enter__()

-    def _chunk_data(self, data: Iterable) -> List[Iterable]:
+    def _chunk_data(self, args: tuple,
+                    kwargs: dict) -> List[Tuple[tuple, dict]]:
         """
         TODO
         """
-        pass
+        # args
+        chunked_args = [None for _ in range(len(args))]
+        for i in range(len(args)):
+            if isinstance(args[i], torch.Tensor):
+                chunked_args[i] = torch.chunk(args[i], self.num_chunks)
+            else:
+                chunked_args[i] = [args[i]] * self.num_chunks
+        # kwargs
+        chunked_kwargs = {}
+        for k in kwargs:
+            if isinstance(kwargs[k], torch.Tensor):
+                chunked_kwargs[k] = torch.chunk(kwargs[k], self.num_chunks)
+            else:
+                chunked_kwargs[k] = [kwargs[k]] * self.num_chunks
+        # merge
+        # torch.chunk may return fewer than num_chunks pieces, so take
+        # the minimum over all chunked fields
+        lengths = [len(arg) for arg in chunked_args] + \
+            [len(v) for v in chunked_kwargs.values()]
+        real_num_chunks = min(lengths)
+        chunked_data = []
+
+        for i in range(real_num_chunks):
+            chunked_data.append(
+                (
+                    tuple([arg[i] for arg in chunked_args]),
+                    {k: v[i] for k, v in chunked_kwargs.items()}
+                )
+            )
+        return chunked_data

-    def forward(self, data):
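+    # NOTE: forward accepts arbitrary positional/keyword inputs and
+    # micro-batches them with _chunk_data before running the schedule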
+    def forward(self, *args, **kwargs):
         """
         TODO
         """
-        pass
+        exec_info = None
+        chunked_data = self._chunk_data(args, kwargs)
+        # lazy initialization: profile flops and execution order on the
+        # first chunk, then build the device, offload and module maps
+        if not self._inited:
+            self._get_flops(chunked_data[0])
+            self._get_exec_order(chunked_data[0])
+            self.device_map = self._init_device_map(self.device_map_policy)
+            self.offload_map = self._init_offload_map()
+            self.module_map = self._init_module_map()
+            self._inited = True
+
+        num_chunks = min(len(chunked_data), self.num_chunks)
+        # record finished chunks
+        finished_chunks = set()
+        # clear visited times
+        MMPipelineParallel.hook_visited_times = {}
+        # main loop
+        for schedule in self._clock(num_chunks):
+            pass
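
Example usage (a minimal sketch, not part of the diff above: the import path
is assumed from the file location, the ResNet config, checkpoint path and
input shape are made-up placeholders, and the call only works once the
scheduling loop in `forward` is filled in):

    import torch
    from mmengine.model.wrappers import MMPipelineParallel

    wrapper = MMPipelineParallel(
        model=dict(type='ResNet', depth=50),  # any MODELS-registered config
        weights='checkpoints/resnet50.pth',   # hypothetical checkpoint path
        num_pipelines=2,                      # split the model into 2 parts
        num_chunks=4,                         # micro-batches per forward
        device_map='auto',                    # resolved lazily on first call
    )
    with torch.no_grad():
        out = wrapper(torch.randn(8, 3, 224, 224))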