From ccf6b3b7efb3bae2ddaa27b8ef6a98c9975f3550 Mon Sep 17 00:00:00 2001
From: mollon
Date: Fri, 10 May 2024 15:40:25 +0800
Subject: [PATCH] [iree-turbine] Support simple MLP training on CUDA via
 Dynamo; ignore None outputs when they appear in the backward graph

---
 examples/mlp_train/ut_mlp_train.py   | 60 ++++++++++++++++++++++++++++
 shark_turbine/dynamo/backends/cpu.py | 59 +++++++++++++++++++++++----
 shark_turbine/dynamo/executor.py     | 47 ++++++++++++++++------
 3 files changed, 145 insertions(+), 21 deletions(-)
 create mode 100644 examples/mlp_train/ut_mlp_train.py

diff --git a/examples/mlp_train/ut_mlp_train.py b/examples/mlp_train/ut_mlp_train.py
new file mode 100644
index 00000000..9d593dae
--- /dev/null
+++ b/examples/mlp_train/ut_mlp_train.py
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import matplotlib.pyplot as plt
+
+device = 'cuda'
+
+# Fit a linear model y = w * x + b to noisy samples of y = 3x + 2.
+torch.manual_seed(0)
+x = torch.linspace(-1, 1, 100).reshape(-1)
+y = 3 * x + 2 + torch.randn(x.size()) * 0.2
+
+# move the data to the target device
+x = x.to(dtype=torch.float32, device=device)
+y = y.to(dtype=torch.float32, device=device)
+print(x)
+class SimpleMLP(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.weight = nn.Parameter(torch.randn(1, requires_grad=True))
+        print(self.weight)
+        self.bias = nn.Parameter(torch.randn(1, requires_grad=True))
+
+    def forward(self, x: torch.Tensor):
+        out = x * self.weight + self.bias
+        return out
+
+
+# model = SimpleMLP().to(device)
+mod = SimpleMLP().to(device)
+
+model = torch.compile(mod, backend='turbine_cpu')
+
+learning_rate = 0.1
+optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
+loss_func = nn.MSELoss()
+
+epochs = 2000
+for epoch in range(epochs):
+    y_pred = model(x)
+    # print(y_pred)
+
+    loss = loss_func(y_pred.to(device), y.to(device))
+
+    optimizer.zero_grad()
+    # loss = y_pred.sum()
+    # loss = loss.to(device)
+    loss.backward()
+
+    optimizer.step()
+
+    if (epoch + 1) % 10 == 0:
+        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
+
+predicted = model(x).detach().cpu().numpy()
+plt.plot(x.cpu().numpy(), y.cpu().numpy(), 'ro', label='Original data')
+plt.plot(x.cpu().numpy(), predicted, label='Fitted line')
+plt.legend()
+plt.savefig('fitting_result.png')
+plt.close()
diff --git a/shark_turbine/dynamo/backends/cpu.py b/shark_turbine/dynamo/backends/cpu.py
index 521d0380..761fd77b 100644
--- a/shark_turbine/dynamo/backends/cpu.py
+++ b/shark_turbine/dynamo/backends/cpu.py
@@ -6,6 +6,7 @@
 
 import functools
 import sys
+import os
 
 from ...runtime.device import (
     DeviceState,
@@ -16,6 +17,7 @@
 )
 
 from iree.compiler.api import (
+    _initializeGlobalCL,
     Invocation,
     Session,
     Source,
@@ -38,11 +40,31 @@
 import torch
 from torch._dynamo.backends.common import aot_autograd
 from ..passes import turbine_cpu_pass_pipeline
+from typing import Any, List
+from functorch.compile import min_cut_rematerialization_partition
 
-DEFAULT_COMPILER_FLAGS = ("--iree-input-type=torch",)
+DEFAULT_COMPILER_FLAGS = (
+    "--iree-input-type=torch",
+)
+
+global_cl_options = []
+if os.getenv("mlir_print_ir_after_all") is not None:
+    global_cl_options.append("--mlir-print-ir-after-all")
+    global_cl_options.append("--mlir-print-ir-after-change")
+
+if os.getenv("mlir_print_ir_before_all") is not None:
+    global_cl_options.append("--mlir-print-ir-before-all")
+
+
+if len(global_cl_options) != 0:
+    _initializeGlobalCL("dynamo", *global_cl_options)
 
+def device_from_inputs(example_inputs) -> torch.device:
+    for x in example_inputs:
+        if hasattr(x, "device"):
+            return x.device
 
-def _base_backend(gm: torch.fx.GraphModule, example_inputs):
+def _base_backend(gm: torch.fx.GraphModule, example_inputs, is_fw=True):
     # Set up the session, context and invocation.
     # Note that we do this on one in-memory module in a few phases:
     # 1. Build it from the FX graph.
@@ -52,7 +74,18 @@
     # 4. Output to an mmap buffer.
     session = Session()
     session.set_flags(*DEFAULT_COMPILER_FLAGS)
-    session.set_flags("--iree-hal-target-backends=llvm-cpu")
+
+    device = device_from_inputs(example_inputs)
+
+
+    device_index = None
+    device_type = device.type
+    if device_type == "cpu":
+        session.set_flags("--iree-hal-target-backends=llvm-cpu")
+    elif device_type == "cuda":
+        device_index = device.index
+        session.set_flags("--iree-hal-target-backends=cuda")
+
     context = session.context
     importer = FxImporter(context=context)
     module = importer.module
@@ -65,6 +98,8 @@
     gm = turbine_cpu_pass_pipeline(gm, example_inputs)
 
     # Import phase.
+    print("FX graph before import phase:", file=sys.stderr)
+    print(gm.print_readable(), file=sys.stderr)
     importer.import_graph_module(gm)
     print(module, file=sys.stderr)
     with context:
@@ -80,7 +115,7 @@
         inv.output_vm_bytecode(output)
 
     # Set up for runtime.
-    device_state = _get_device_state()
+    device_state = _get_device_state(device_type, device_index)
     # TODO: Switch to wrap_buffer once https://github.com/openxla/iree/issues/14926
     # is fixed.
     # vmfb_module = VmModule.wrap_buffer(
@@ -94,14 +129,22 @@
     )
     output.close()
 
-    return SpecializedExecutable(vmfb_module, device_state)
+    return SpecializedExecutable(vmfb_module, device_state, importer.anticipated_return_value)
 
+def _base_backend_fw(gm: torch.fx.GraphModule, example_inputs):
+    return _base_backend(gm, example_inputs, is_fw=True)
 
-backend = aot_autograd(fw_compiler=_base_backend)
+def _base_backend_bw(gm: torch.fx.GraphModule, example_inputs):
+    return _base_backend(gm, example_inputs, is_fw=False)
 
+backend = aot_autograd(fw_compiler=_base_backend_fw, bw_compiler=_base_backend_bw, partition_fn=functools.partial(min_cut_rematerialization_partition, compiler="turbine_cpu"))
 
 # IREE runtime globals. For the CPU right now, there is no device selection,
 # so it is easy.
 @functools.lru_cache(maxsize=None)
-def _get_device_state() -> DeviceState:
-    return DeviceState(driver="local-task")
+def _get_device_state(device_type, device_index) -> DeviceState:
+    if device_type == "cpu":
+        return DeviceState(driver="local-task")
+    elif device_type == "cuda":
+        return DeviceState(driver="cuda", enumerated_info={'device_id':device_index})
+    
\ No newline at end of file
diff --git a/shark_turbine/dynamo/executor.py b/shark_turbine/dynamo/executor.py
index 561de515..d037d789 100644
--- a/shark_turbine/dynamo/executor.py
+++ b/shark_turbine/dynamo/executor.py
@@ -8,6 +8,7 @@
 import os
 from typing import List, Optional, Sequence, Union
 from dataclasses import dataclass
+import torch.nn as nn
 from iree.runtime import (
     asdevicearray,
     create_hal_module,
@@ -31,7 +32,7 @@
 )
 
 from ..runtime.device import Device, DeviceState
-
+from ..dynamo.tensor import dtype_to_element_type
 
 @functools.lru_cache(maxsize=None)
 def get_vm_instance() -> VmInstance:
@@ -64,12 +65,14 @@ class SpecializedExecutable:
         "entry_function",
         "user_module",
         "vm_context",
+        "anticipated_return_value",
     ]
 
     def __init__(
         self,
         user_module: VmModule,
         device_state: DeviceState,
+        anticipated_return_value: list[bool],
         entry_name: str = "main",
     ):
         self.user_module = user_module
@@ -81,6 +84,7 @@ def __init__(
             ),
         )
         self.device_state = device_state
+        self.anticipated_return_value = anticipated_return_value
         self.entry_function = self.user_module.lookup_function(entry_name)
 
     def __call__(self, *inputs):
@@ -101,26 +105,43 @@ def _inputs_to_device(self, inputs: list, arg_list: VmVariantList):
         # TODO: We are assuming the worst case here which is that we have unknown Torch
         # tensors that we send to the CPU and make continguous. Ideally, we would have
         # fast paths for our own backends and interop.
+        device = self.device_state.device
+        device_name = self.device_state.torch_device
         for input in inputs:
-            input_cpu = input.cpu().contiguous()
-            # Since this is already a fallback case, just use the numpy array interop.
-            # It isn't great, but meh... fallback case.
-            device_array = asdevicearray(self.device_state.device, input_cpu)
-            arg_list.push_ref(device_array._buffer_view)
-
+            # input_cpu = input.cpu().contiguous()
+            # # Since this is already a fallback case, just use the numpy array interop.
+            # # It isn't great, but meh... fallback case.
+            # device_array = asdevicearray(self.device_state.device, input_cpu)
+            # arg_list.push_ref(device_array._buffer_view)
+            if not input.is_contiguous():
+                input = input.contiguous()
+
+            if input.device.type.startswith("cpu"):
+                if device_name.startswith("cuda"):
+                    input = input.to("cuda")
+
+            if isinstance(input, nn.Parameter):
+                buffer_view = device.allocator.import_buffer(device, input.data, dtype_to_element_type(input.dtype))
+            else:
+                buffer_view = device.allocator.import_buffer(device, input, dtype_to_element_type(input.dtype))
+            arg_list.push_ref(buffer_view)
+        
     def _returns_to_user(self, ret_list: VmVariantList):
         # TODO: This is also not good that we are moving back to the CPU like this.
         # We should be returning a custom Tensor implementation which represents
         # our device data and has synchronization hooks for accessing it.
         device = self.device_state.device
-        num_returns = len(ret_list)
+        # num_returns = len(ret_list)
+        num_returns = len(self.anticipated_return_value)
         user_returns = [None] * num_returns
-        for i in range(num_returns):
-            device_buffer_view = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i))
-            device_array = DeviceArray(device, device_buffer_view)
-            host_array = device_array.to_host()
-            user_returns[i] = torch_from_numpy(host_array)  # type: ignore
+        ret_list_idx = 0  # the graph may return None entries that are not materialized in ret_list, so track the runtime result index separately
+        for i in range(num_returns):
+            if self.anticipated_return_value[i]:
+                device_buffer_view = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(ret_list_idx))
+                ret_list_idx += 1
+                element_type = HalElementType(device_buffer_view.element_type)
+                user_returns[i] = device.allocator.export_buffer(device, device_buffer_view, element_type)
         return user_returns
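
Note on anticipated_return_value: SpecializedExecutable now receives
importer.anticipated_return_value, a per-output boolean mask, but the
FxImporter side of that change is not included in this diff. The sketch below
is only a minimal illustration of how such a mask could be derived from an FX
graph whose output tuple contains None entries (as AOTAutograd backward graphs
produce for inputs that do not require gradients); the helper name
compute_anticipated_return_value and the ToyBackward module are hypothetical
and exist only for this example.

    # Hypothetical sketch, not part of this patch: build a True/False mask over
    # the FX graph outputs so the executor knows which slots hold real tensors.
    import torch
    import torch.fx as fx

    def compute_anticipated_return_value(gm: fx.GraphModule) -> list[bool]:
        # The FX "output" node holds the returned values; None entries are
        # never materialized by the compiled module, so mark them False.
        output_node = [n for n in gm.graph.nodes if n.op == "output"][0]
        returns = output_node.args[0]
        if not isinstance(returns, (tuple, list)):
            returns = (returns,)
        return [r is not None for r in returns]

    class ToyBackward(torch.nn.Module):
        def forward(self, grad_out, x):
            # Mimic a backward graph: a gradient for x, None for an input
            # that does not require grad.
            return grad_out * x, None

    gm = fx.symbolic_trace(ToyBackward())
    print(compute_anticipated_return_value(gm))  # [True, False]

_returns_to_user consumes the mask the same way: True slots pop the next
buffer view from ret_list, False slots stay None, so callers see the original
output arity of the backward graph. Note also that the example script still
compiles with backend='turbine_cpu'; the CUDA path is selected from the input
tensors by device_from_inputs, and IR dumps can be enabled by setting the
mlir_print_ir_after_all or mlir_print_ir_before_all environment variables
before running examples/mlp_train/ut_mlp_train.py.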