diff --git a/install.py b/install.py
index abbb11bb4..aabd00824 100755
--- a/install.py
+++ b/install.py
@@ -794,8 +794,8 @@ def driver():
     )
     parser.add_argument(
         "--cuda",
         action=BooleanFlag,
-        default=os.environ.get("USE_CUDA", "0") == "1",
+        default=True,  # os.environ.get("USE_CUDA", "0") == "1"
         help="Build Legate with CUDA support.",
     )
     parser.add_argument(
@@ -895,7 +895,7 @@ def driver():
         "--clean",
         dest="clean_first",
         action=BooleanFlag,
-        default=True,
+        default=False,
         help="Clean before build, and pull latest Legion.",
     )
     parser.add_argument(
diff --git a/legate/core/corelib.py b/legate/core/corelib.py
index ff35071a6..77800d20d 100644
--- a/legate/core/corelib.py
+++ b/legate/core/corelib.py
@@ -24,7 +24,7 @@ class CoreLib(Library):
     def __init__(self):
         self._lib = None
-
+
     def get_name(self):
         return "legate.core"
 
@@ -38,6 +38,7 @@ def get_c_header(self):
     def initialize(self, shared_lib):
         self._lib = shared_lib
         shared_lib.legate_parse_config()
+        # self.fused_id = self._lib.LEGATE_CORE_FUSED_TASK_ID
 
     def get_registration_callback(self):
         return "legate_core_perform_registration"
diff --git a/legate/core/launcher.py b/legate/core/launcher.py
index 6e629132e..837e53068 100644
--- a/legate/core/launcher.py
+++ b/legate/core/launcher.py
@@ -534,6 +534,8 @@ def __init__(self, context, task_id, mapper_id=0, tag=0):
         self._sharding_space = None
         self._point = None
         self._output_regions = list()
+        self._is_fused = False
+        self._fusion_metadata = None
 
     @property
     def library_task_id(self):
@@ -577,7 +579,6 @@ def add_store(self, args, store, proj, perm, tag, flags):
         else:
             region = store.storage.region
             field_id = store.storage.field.field_id
-
             req = RegionReq(region, perm, proj, tag, flags)
 
             self._req_analyzer.insert(req, field_id)
@@ -643,21 +644,35 @@ def set_sharding_space(self, space):
     def set_point(self, point):
         self._point = point
 
+    def add_fusion_metadata(self, is_fused, fusion_metadata):
+        self._is_fused = is_fused
+        self._fusion_metadata = fusion_metadata
+
     @staticmethod
     def pack_args(argbuf, args):
         argbuf.pack_32bit_uint(len(args))
         for arg in args:
             arg.pack(argbuf)
 
+    @staticmethod
+    def pack_fusion_metadata(argbuf, is_fused, fusion_metadata):
+        # A leading flag tells the deserializer whether fusion
+        # metadata follows.
+        argbuf.pack_bool(is_fused)
+        if is_fused:
+            fusion_metadata.pack(argbuf)
+
     def build_task(self, launch_domain, argbuf):
         self._req_analyzer.analyze_requirements()
         self._out_analyzer.analyze_requirements()
 
+        # Pack fusion metadata ahead of the regular arguments.
+        self.pack_fusion_metadata(argbuf, self._is_fused, self._fusion_metadata)
+
         self.pack_args(argbuf, self._inputs)
         self.pack_args(argbuf, self._outputs)
         self.pack_args(argbuf, self._reductions)
         self.pack_args(argbuf, self._scalars)
         task = IndexTask(
             self.legion_task_id,
             launch_domain,
@@ -683,6 +698,9 @@ def build_task(self, launch_domain, argbuf):
     def build_single_task(self, argbuf):
         self._req_analyzer.analyze_requirements()
         self._out_analyzer.analyze_requirements()
+
+        # Pack fusion metadata ahead of the regular arguments.
+        self.pack_fusion_metadata(argbuf, self._is_fused, self._fusion_metadata)
 
         self.pack_args(argbuf, self._inputs)
         self.pack_args(argbuf, self._outputs)
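Every task payload now begins with the fusion flag, fused or not. A minimal sketch of the resulting layout, with a hypothetical `ByteBuffer` standing in for legate's `BufferBuilder`:

```python
import struct

class ByteBuffer:
    """Hypothetical stand-in for legate's BufferBuilder (pack-only)."""
    def __init__(self):
        self.fmt, self.args = ["="], []
    def pack_bool(self, b):
        self.fmt.append("?"); self.args.append(b)
    def pack_32bit_uint(self, u):
        self.fmt.append("I"); self.args.append(u)
    def get_string(self):
        return struct.pack("".join(self.fmt), *self.args)

def build_payload(is_fused, inputs):
    buf = ByteBuffer()
    buf.pack_bool(is_fused)           # fusion flag always leads the payload
    buf.pack_32bit_uint(len(inputs))  # then the usual argument sections
    for arg in inputs:
        buf.pack_32bit_uint(arg)
    return buf.get_string()

# The deserializer reads the flag first; non-fused tasks pay one byte.
assert build_payload(False, [7])[:1] == b"\x00"
```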
diff --git a/legate/core/legion.py b/legate/core/legion.py
index 9fb7a88ce..a3cd653c9 100644
--- a/legate/core/legion.py
+++ b/legate/core/legion.py
@@ -4859,6 +4859,12 @@ def pack_32bit_int(self, arg):
         self.size += 4
         self.add_arg(arg, legion.LEGION_TYPE_INT32)
 
+    def pack_32bit_int_arr(self, arg):
+        # One format entry ("<n>i") covers the whole array; much
+        # faster than packing each int individually.
+        self.fmt.append(str(len(arg)) + "i")
+        size = len(arg)
+        self.size += 4 * size
+        self.args += arg
+
     def pack_64bit_int(self, arg):
         self.fmt.append("q")
         self.size += 8
@@ -5043,7 +5049,7 @@ def pack_dtype(self, dtype):
     def get_string(self):
         if self.string is None or self.arglen != len(self.args):
             fmtstr = "".join(self.fmt)
-            assert len(fmtstr) == len(self.args) + 1
+            # A single array entry in fmt can now cover many args,
+            # so this invariant no longer holds:
+            # assert len(fmtstr) == len(self.args) + 1
             self.string = struct.pack(fmtstr, *self.args)
             self.arglen = len(self.args)
         return self.string
diff --git a/legate/core/operation.py b/legate/core/operation.py
index b9bfc8c29..a5a5e53e4 100644
--- a/legate/core/operation.py
+++ b/legate/core/operation.py
@@ -20,7 +20,11 @@
 from .legion import Future
 from .store import Store
 from .utils import OrderedSet
-
+from .legion import (
+    FieldSpace,
+    Future,
+)
+
 
 class Operation(object):
     def __init__(self, context, mapper_id=0, op_id=0):
@@ -30,6 +34,7 @@ def __init__(self, context, mapper_id=0, op_id=0):
         self._inputs = []
         self._outputs = []
         self._reductions = []
+        self._is_fused = False
         self._input_parts = []
         self._output_parts = []
         self._reduction_parts = []
@@ -145,11 +150,18 @@ def add_broadcast(self, store):
     def add_constraint(self, constraint):
         self._constraints.append(constraint)
 
+    def has_constraint(self, store1, store2):
+        part1 = self._get_unique_partition(store1)
+        part2 = self._get_unique_partition(store2)
+        cons = [str(con) for con in self._constraints]
+        return (str(part1 == part2) in cons) or (str(part2 == part1) in cons)
+
     def execute(self):
         self._context.runtime.submit(self)
 
     def get_tag(self, strategy, part):
         if strategy.is_key_part(part):
+            # Key-store tagging is disabled for now.
+            return 0
             return 1  # LEGATE_CORE_KEY_STORE_TAG
         else:
             return 0
@@ -180,6 +192,7 @@ def __init__(self, context, task_id, mapper_id=0, op_id=0):
         self._task_id = task_id
         self._scalar_args = []
         self._futures = []
+        self._fusion_metadata = None
 
     def get_name(self):
         libname = self.context.library.get_name()
@@ -195,14 +208,24 @@ def add_dtype_arg(self, dtype):
     def add_future(self, future):
         self._futures.append(future)
 
+    def add_fusion_metadata(self, fusion_metadata):
+        self._is_fused = True
+        self._fusion_metadata = fusion_metadata
+
     def launch(self, strategy):
         launcher = TaskLauncher(self.context, self._task_id, self.mapper_id)
 
-        for input, input_part in zip(self._inputs, self._input_parts):
+        if self._is_fused:
+            launcher.add_fusion_metadata(self._is_fused, self._fusion_metadata)
+
+        input_parts = self._input_parts
+        output_parts = self._output_parts
+        reduction_parts = self._reduction_parts
+
+        for input, input_part in zip(self._inputs, input_parts):
             proj = strategy.get_projection(input_part)
             tag = self.get_tag(strategy, input_part)
             launcher.add_input(input, proj, tag=tag)
-        for output, output_part in zip(self._outputs, self._output_parts):
+        for output, output_part in zip(self._outputs, output_parts):
             if output.unbound:
                 continue
             proj = strategy.get_projection(output_part)
@@ -212,7 +235,7 @@ def launch(self, strategy):
             # We update the key partition of a store only when it gets updated
             output.set_key_partition(partition)
         for ((reduction, redop), reduction_part) in zip(
-            self._reductions, self._reduction_parts
+            self._reductions, reduction_parts
         ):
             partition = strategy.get_partition(reduction_part)
             can_read_write = partition.is_disjoint_for(strategy, reduction)
diff --git a/legate/core/partition.py b/legate/core/partition.py
index 6fe166f2c..304f8b36d 100644
--- a/legate/core/partition.py
+++ b/legate/core/partition.py
@@ -180,10 +180,8 @@ def construct(self, region, complete=False):
         transform = Transform(tile_shape.ndim, tile_shape.ndim)
         for idx, size in enumerate(tile_shape):
             transform.trans[idx, idx] = size
-
         lo = Shape((0,) * tile_shape.ndim) + self._offset
         hi = self._tile_shape - 1 + self._offset
-
         extent = Rect(hi, lo, exclusive=False)
 
         color_space = self._runtime.find_or_create_index_space(
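The commented-out assertion above is a direct consequence of `pack_32bit_int_arr`: one format token such as `"4i"` now covers four arguments, so the old one-format-char-per-arg invariant cannot hold. A small sketch using only the standard `struct` module:

```python
import struct

fmt, args = ["="], []

# pack_32bit_int: one format char per value
fmt.append("i"); args.append(5)

# pack_32bit_int_arr: a single "<n>i" token covers n values
arr = [1, 2, 3, 4]
fmt.append("%di" % len(arr)); args.extend(arr)

fmtstr = "".join(fmt)
assert len(fmtstr) != len(args) + 1   # old invariant breaks: "4i" is 2 chars, 4 args
assert len(struct.pack(fmtstr, *args)) == 4 * len(args)
```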
diff --git a/legate/core/runtime.py b/legate/core/runtime.py
index e162ab6b7..3d433637b 100644
--- a/legate/core/runtime.py
+++ b/legate/core/runtime.py
@@ -25,6 +25,8 @@
 
 from legate.core import types as ty
 
+import datetime
+
 from .context import Context
 from .corelib import CoreLib
 from .launcher import TaskLauncher
@@ -41,8 +43,12 @@
 )
 from .partition import Restriction
 from .shape import Shape
-from .solver import Partitioner
-from .store import RegionField, Store
+from .solver import Partitioner, Strategy
+from .store import RegionField, Store, FusionMetadata
+import numpy as np
+from .constraints import Alignment
+
 
 
 # A Field holds a reference to a field in a region tree
@@ -522,7 +528,6 @@ def compute_launch_shape(self, store, restrictions):
         for dim, restriction in enumerate(restrictions):
             if restriction != Restriction.RESTRICTED:
                 to_partition += (shape[dim],)
-
         launch_shape = self._compute_launch_shape(to_partition)
         if launch_shape is None:
             return None
@@ -711,14 +716,287 @@ def record_partition(self, index_space, functor, index_partition):
         self._index_partitions[key] = index_partition
 
 
+class FusionChecker(object):
+    """
+    Holds a list of constraints for fusing ops and decides whether
+    a given list of ops can be fused.
+    """
+
+    def __init__(self, ops, contexts, runtime):
+        self.constraints = []
+        self.ops = ops
+        self.contexts = contexts
+        self.runtime = runtime
+        self.partitioners = []
+        self.strategies = []
+
+    def register_constraint(self, fusion_constraint_rule):
+        self.constraints.append(fusion_constraint_rule)
+
+    def supress_small_fusions(self, intervals, threshold):
+        # Keep fusable sub-windows of length greater than or equal
+        # to the fusion threshold; shorter windows are split back
+        # into single-op intervals.
+        final_set = []
+        fusable = False
+        for interval in intervals:
+            if interval[1] - interval[0] >= threshold:
+                final_set.append(interval)
+                fusable = True
+            else:
+                for i in range(interval[0], interval[1]):
+                    final_set.append((i, i + 1))
+        return fusable, final_set
+
+    def can_fuse(self):
+        # Partition each op individually first.
+        for op in reversed(self.ops):
+            must_be_single = any(
+                len(gop.scalar_outputs) > 0 for gop in [op]
+            )
+            partitioner = Partitioner(
+                self.runtime, [op], must_be_single=must_be_single
+            )
+            self.partitioners.append(partitioner)
+            strategy = partitioner.partition_stores()
+            self.strategies.append(strategy)
+        self.strategies.reverse()
+
+        # Apply every registered constraint to narrow the windows.
+        windows = [(0, len(self.ops))]
+        for constraint in self.constraints:
+            windows = constraint.apply(
+                self.contexts,
+                self.runtime,
+                self.ops,
+                windows,
+                self.partitioners,
+                self.strategies,
+            )
+
+        old_strategies = self.strategies[:]
+        old_strategies.reverse()
+        self.strategies = []
+        ist = 0
+        keyps = []
+        for window in reversed(windows):
+            if window[0] == window[1]:
+                continue
+            local_partitions = []
+            for op in reversed(self.ops[window[0] : window[1]]):
+                strategy = old_strategies[ist]
+                for output, part in zip(op._outputs, op._output_parts):
+                    partition = strategy.get_partition(part)
+                    local_partitions.append(partition)
+                ist += 1
+            # Use the median partition of the window as its key partition.
+            midpoint = len(local_partitions) // 2
+            partition = local_partitions[midpoint]
+            keyps.append(partition)
+            for op in reversed(self.ops[window[0] : window[1]]):
+                for output, part in zip(op._outputs, op._output_parts):
+                    output.reset_key_partition()
+                    output.set_key_partition(partition)
+                    strategy._strategy[part] = partition
+                # Check whether an input should be aligned with the output.
+                for input, ipart in zip(op._inputs, op._input_parts):
+                    if input.shape == output.shape:
+                        input.reset_key_partition()
+                        input.set_key_partition(partition)
+                        strategy._strategy[ipart] = partition
+            self.strategies.append(strategy)
+        self.strategies.reverse()
+        keyps.reverse()
+
+        fusable, final_set = self.supress_small_fusions(
+            windows, self.runtime._fusion_threshold
+        )
+        return fusable, final_set, self.strategies, keyps
+
+
+class FusionConstraint(object):
+    """
+    Abstract class for a rule that constrains which legate
+    operations can be fused.
+    """
+
+    def apply(
+        self, contexts, runtime, ops, baseIntervals, partitioners, strategies
+    ):
+        raise NotImplementedError("Implement in derived classes")
+
+
+class cuNumericContextExists(FusionConstraint):
+    """
+    Fusion currently exists as a cuNumeric operation.
+    This can be removed once fusion becomes a core task.
+    """
+
+    def apply(
+        self, contexts, runtime, ops, baseIntervals, partitioners, strategies
+    ):
+        if "cunumeric" in contexts:
+            return baseIntervals
+        else:
+            return [(i, i + 1) for i in range(len(ops))]
+
+
+class AllValidOps(FusionConstraint):
+    """
+    Class for fusing only potentially fusable ops.
+    This performs the first pass of legality filtering.
+    """
+
+    def __init__(self):
+        self.validIDs = set()
+        self.terminals = set()
+        self.validIDs.add(2)   # Binary op
+        self.validIDs.add(10)  # Fill op
+        self.validIDs.add(21)  # Unary op
+
+    def apply(
+        self, contexts, runtime, ops, baseIntervals, partitioners, strategies
+    ):
+        fusable_intervals = []
+        results = [int(op._task_id) in self.validIDs for op in ops]
+        for baseInterval in baseIntervals:
+            start, end = baseInterval[0], baseInterval[0]
+            while end < baseInterval[1]:
+                ...
+
+    def build_fused_op(self, ops):
+        ...
+        for window in final_set:
+            op_subset = ops[window[0] : window[1]]
+            if len(op_subset) > 1:
+                # initialize the fused task
+                fused_task = numpy_context.create_task(fused_id)
+                # serialize the necessary metadata on all encapsulated ops
+                fusion_metadata = self.serialize_multiop_metadata(
+                    numpy_context, op_subset
+                )
+                # sets fused_task._is_fused to True
+                fused_task.add_fusion_metadata(fusion_metadata)
+
+                # add the usual inputs and outputs of all subtasks
+                # to the fused task
+                for j, op in enumerate(op_subset):
+                    for scalar in op._scalar_args:
+                        fused_task.add_scalar_arg(scalar[0], ty.int32)
+                    for (reduction, redop), part in zip(
+                        op._reductions, op._reduction_parts
+                    ):
+                        fused_task.add_reduction(reduction, redop)
+                    for input, part in zip(op._inputs, op._input_parts):
+                        fused_task.add_input(input)
+                    for output, part in zip(op._outputs, op._output_parts):
+                        fused_task.add_output(output)
+                    for future in op._futures:
+                        fused_task.add_future(future)
+                    for constraint in op._constraints:
+                        if isinstance(constraint, Alignment):
+                            fused_task.add_alignment(
+                                constraint._lhs.store, constraint._rhs.store
+                            )
+                opID += 1
+                new_op_list.append(fused_task)
+
+        strats = []
+        for i, fused_task in enumerate(new_op_list):
+            must_be_single = any(
+                len(gop.scalar_outputs) > 0 for gop in [fused_task]
+            )
+            for output, part in zip(
+                fused_task._outputs, fused_task._output_parts
+            ):
+                output.set_key_partition(keyps[i])
+            for input, part in zip(
+                fused_task._inputs, fused_task._input_parts
+            ):
+                input.set_key_partition(keyps[i])
+            partitioner = Partitioner(
+                self, [fused_task], must_be_single=must_be_single
+            )
+            strategy = partitioner.partition_stores()
+            fused_task.strategy = strategy
+            strats.append(strategy)
+        return new_op_list, strats
+
+    def _launch_outstanding(self, force_eval=True):
+        if len(self._outstanding_ops):
+            ops = self._outstanding_ops
+            self._outstanding_ops = []
+            self._schedule(ops, force_eval)
+
+    def _schedule(self, ops, force_eval=False):
+        # case 1: try fusing the current window of tasks
+        strats = False
+        if len(ops) >= 2 and (not force_eval):
+            fused_task_list, strats = self.build_fused_op(ops)
+            if fused_task_list:
+                self._clearing_pipe = True
+                for task in fused_task_list:
+                    task.execute()
+                self._clearing_pipe = False
+
+        # case 2: tasks processed for fusion already have their
+        # strategy "baked in", as we already partitioned them when
+        # testing fusion legality (in case 1)
+        elif len(ops) == 1 and self._clearing_pipe:
+            strategy = ops[0].strategy
+            ops[0].launch(strategy)
+
+        # case 3: execute the ops normally, partitioning any op
+        # that was not checked for fusability
+        else:
+            if not strats:
+                strats = []
+                for op in ops:
+                    must_be_single = any(
+                        len(gop.scalar_outputs) > 0 for gop in [op]
+                    )
+                    partitioner = Partitioner(
+                        self, [op], must_be_single=must_be_single
+                    )
+                    strategy = partitioner.partition_stores()
+                    strats.append(strategy)
+            for i, op in enumerate(ops):
+                op.launch(strats[i])
+
+    def submit(self, op):
+        # Always launch ops that have been processed for fusion
+        # without re-adding them to the window, as these ops
+        # already waited in the window.
+        if self._clearing_pipe:
+            self._schedule([op])
+        else:
+            self._outstanding_ops.append(op)
+            if len(self._outstanding_ops) >= self._window_size:
+                ops = self._outstanding_ops
+                self._outstanding_ops = []
+                self._schedule(ops)
+
+    def _scheduleNew(self, ops):
         # TODO: For now we run the partitioner for each operation separately.
         #       We will eventually want to compute a trace-wide partitioning
         #       strategy.
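A standalone sketch of the window-suppression rule above (same logic, plain Python): windows shorter than the fusion threshold are split back into singletons.

```python
def suppress_small_fusions(intervals, threshold):
    # Keep windows of length >= threshold; split the rest into singletons.
    final_set, fusable = [], False
    for lo, hi in intervals:
        if hi - lo >= threshold:
            final_set.append((lo, hi))
            fusable = True
        else:
            final_set.extend((i, i + 1) for i in range(lo, hi))
    return fusable, final_set

# A 5-op window survives a threshold of 3; a 2-op window is split.
assert suppress_small_fusions([(0, 5), (5, 7)], 3) == \
    (True, [(0, 5), (5, 6), (6, 7)])
```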
@@ -902,7 +1352,7 @@ def flush_scheduling_window(self):
         self._outstanding_ops = []
         self._schedule(ops)
 
-    def submit(self, op):
+    def submitNew(self, op):
         self._outstanding_ops.append(op)
         if len(self._outstanding_ops) >= self._window_size:
             self.flush_scheduling_window()
@@ -1120,6 +1570,7 @@ def reduce_future_map(self, future_map, redop):
 
 def _cleanup_legate_runtime():
     global _runtime
+    _runtime._launch_outstanding()
     _runtime.destroy()
     del _runtime
     gc.collect()
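A condensed sketch of the three scheduling cases above; `try_fuse` and `launch` are hypothetical stand-ins for `build_fused_op` and `op.launch`:

```python
class WindowedScheduler:
    """Sketch of the windowed submit/fuse pipeline (hypothetical helpers)."""
    def __init__(self, window_size, try_fuse, launch):
        self.window_size = window_size
        self.try_fuse = try_fuse      # ops -> list of fused ops, or None
        self.launch = launch          # op -> None
        self.window = []
        self.clearing_pipe = False

    def submit(self, op):
        if self.clearing_pipe:        # case 2: replaying a fused op
            self.launch(op)
        else:
            self.window.append(op)
            if len(self.window) >= self.window_size:
                ops, self.window = self.window, []
                self.schedule(ops)

    def schedule(self, ops):
        fused = self.try_fuse(ops) if len(ops) >= 2 else None
        if fused is not None:         # case 1: execute the fused ops
            self.clearing_pipe = True
            for op in fused:
                self.submit(op)       # re-enters submit, hits case 2
            self.clearing_pipe = False
        else:                         # case 3: launch ops normally
            for op in ops:
                self.launch(op)
```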
diff --git a/legate/core/solver.py b/legate/core/solver.py
index b4d5c5e80..84aa8180c 100644
--- a/legate/core/solver.py
+++ b/legate/core/solver.py
@@ -35,7 +35,7 @@ def empty(self):
     def _add(self, var1, var2):
         cls = set([var1, var2])
         cls_id = self._next_class_id
-        self._next_class_id + 1
+        self._next_class_id += 1
         self._classes[cls_id] = cls
         self._class_ids[var1] = cls_id
         self._class_ids[var2] = cls_id
@@ -290,7 +290,10 @@ def cost(unknown):
             if isinstance(prev_part, NoPartition):
                 partition = prev_part
             else:
-                partition = store.compute_key_partition(restrictions)
+                if store._key_partition is not None:
+                    partition = store._key_partition
+                else:
+                    partition = store.compute_key_partition(restrictions)
                 key_parts.add(unknown)
 
             cls = constraints.find(unknown)
@@ -298,7 +301,6 @@ def cost(unknown):
             if to_align in partitions:
                 continue
             partitions[to_align] = partition
-
         prev_part = partition
 
     for lhs, rhs in dependent.items():
diff --git a/legate/core/store.py b/legate/core/store.py
index 5d002f9dd..c8063f72f 100644
--- a/legate/core/store.py
+++ b/legate/core/store.py
@@ -489,7 +489,8 @@ def shape(self):
             # If someone wants to access the shape of an unbound
             # store before it is set, that means the producer task is
             # sitting in the queue, so we should flush the queue.
-            self._runtime.flush_scheduling_window()
+            self._runtime._launch_outstanding(False)
             # At this point, we should have the shape set.
             assert self._shape is not None
         return self._shape
@@ -533,7 +534,7 @@ def storage(self):
         # If someone is trying to retreive the storage of a store,
         # we need to execute outstanding operations so that we know
         # it has been initialized correctly.
-        self._runtime.flush_scheduling_window()
+        self._runtime._launch_outstanding(False)
         if self._storage is None:
             if self.unbound:
                 raise RuntimeError(
@@ -908,3 +909,43 @@ def find_or_create_partition(self, functor):
         part = converted.construct(self.storage.region, complete=complete)
         self._partitions[functor] = (part, proj)
         return part, proj
+
+
+class FusionMetadata(object):
+    def __init__(
+        self,
+        input_starts,
+        output_starts,
+        offset_starts,
+        buffer_offsets,
+        reduction_starts,
+        scalar_starts,
+        future_starts,
+        opIDs,
+    ):
+        self._input_starts = input_starts
+        self._output_starts = output_starts
+        self._offset_starts = offset_starts
+        self._buffer_offsets = buffer_offsets
+        self._reduction_starts = reduction_starts
+        self._scalar_starts = scalar_starts
+        self._future_starts = future_starts
+        self._opIDs = opIDs
+
+    def packList(self, meta_list, buf):
+        # Aggregate the ints when packing; much faster than
+        # individually packing each int.
+        buf.pack_32bit_int_arr(meta_list)
+
+    def pack(self, buf):
+        superbuff = [len(self._opIDs)] + [len(self._buffer_offsets)]
+        superbuff += self._input_starts
+        superbuff += self._output_starts
+        superbuff += self._offset_starts
+        superbuff += self._buffer_offsets
+        superbuff += self._reduction_starts
+        superbuff += self._scalar_starts
+        superbuff += self._future_starts
+        superbuff += self._opIDs
+        self.packList(superbuff, buf)
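A worked sketch of the flat layout `FusionMetadata.pack` produces, for a hypothetical two-op fusion; the values are illustrative, only the ordering is taken from the code above:

```python
# Hypothetical two-op fusion: op A has 2 inputs / 1 output,
# op B has 1 input / 1 output. "Starts" arrays are prefix sums,
# so sub-op i owns slice [starts[i]:starts[i+1]).
input_starts     = [0, 2, 3]
output_starts    = [0, 1, 2]
offset_starts    = [0, 3, 5]
buffer_offsets   = [1, 2, -1, 1, -1]   # may contain negative entries
reduction_starts = [0, 0, 0]
scalar_starts    = [0, 0, 0]
future_starts    = [0, 0, 0]
op_ids           = [2, 21]             # e.g. binary op, unary op

superbuff = [len(op_ids), len(buffer_offsets)]
for table in (input_starts, output_starts, offset_starts, buffer_offsets,
              reduction_starts, scalar_starts, future_starts, op_ids):
    superbuff += table

# One pack_32bit_int_arr call ships the whole thing.
assert superbuff[:2] == [2, 5]
assert input_starts[1] - input_starts[0] == 2   # op A's input count
```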
diff --git a/setup.py b/setup.py
index 749f42373..ac6f7dc02 100755
--- a/setup.py
+++ b/setup.py
@@ -69,8 +69,8 @@ def run(self):
 # Remove the recurse argument from the list
 sys.argv.remove("--recurse")
 setup(
-    name="legate.core",
-    version="0.1",
+    name="legate-core",
+    version="21.10.00",
     packages=["legate", "legate.core", "legate.timing"],
     cmdclass={"build_py": my_build_py},
 )
diff --git a/src/core.mk b/src/core.mk
index 57fbc84a5..337925308 100644
--- a/src/core.mk
+++ b/src/core.mk
@@ -32,10 +32,10 @@ GEN_CPU_SRC = core/legate_c.cc \
 	      core/task/task.cc              \
 	      core/utilities/deserializer.cc \
 	      core/utilities/machine.cc      \
-	      core/utilities/linearize.cc
+	      core/utilities/linearize.cc 
 
 ifeq ($(strip $(USE_CUDA)),1)
-GEN_CPU_SRC += core/gpu/cudalibs.cc
+GEN_CPU_SRC += core/gpu/cudalibs.cc 
 endif
 
 # Header files that we need to have installed for client legate libraries
@@ -63,4 +63,4 @@ INSTALL_HEADERS = legate.h \
 	      core/utilities/machine.h     \
 	      core/utilities/span.h        \
 	      core/utilities/type_traits.h \
-	      core/utilities/typedefs.h
+	      core/utilities/typedefs.h 
diff --git a/src/core/data/store.cc b/src/core/data/store.cc
index 2fa1cb77e..4fd19b34c 100644
--- a/src/core/data/store.cc
+++ b/src/core/data/store.cc
@@ -45,6 +45,7 @@ RegionField& RegionField::operator=(RegionField&& other) noexcept
   dim_       = other.dim_;
   pr_        = other.pr_;
   fid_       = other.fid_;
+  readable_  = other.readable_;
   writable_  = other.writable_;
   reducible_ = other.reducible_;
diff --git a/src/core/data/store.h b/src/core/data/store.h
index 15471e389..b14a5d850 100644
--- a/src/core/data/store.h
+++ b/src/core/data/store.h
@@ -331,6 +331,25 @@ class Store {
   bool reducible_{false};
 };
 
+// Contains prefix sums for a sub-op to index into its own data
+struct FusionMetadata {
+ public:
+  bool isFused;
+  int32_t nOps;
+  int32_t nBuffers;
+  std::vector<int32_t> inputStarts;
+  std::vector<int32_t> outputStarts;
+  std::vector<int32_t> offsetStarts;
+  std::vector<int32_t> offsets;  // can contain negative elements
+  std::vector<int32_t> reductionStarts;
+  std::vector<int32_t> scalarStarts;
+  std::vector<int32_t> futureStarts;
+  std::vector<int32_t> opIDs;
+};
+
 }  // namespace legate
 
 #include "core/data/store.inl"
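A sketch of how the prefix-sum tables are meant to be consumed: sub-op `i` slices its own arguments out of the fused task's flat lists (hypothetical helper and data):

```python
def sub_op_views(flat_inputs, input_starts, op_index):
    # inputStarts is a prefix sum: sub-op i owns
    # flat_inputs[input_starts[i] : input_starts[i + 1]]
    lo, hi = input_starts[op_index], input_starts[op_index + 1]
    return flat_inputs[lo:hi]

flat_inputs = ["x", "y", "z"]          # the fused task's combined inputs
assert sub_op_views(flat_inputs, [0, 2, 3], 0) == ["x", "y"]
assert sub_op_views(flat_inputs, [0, 2, 3], 1) == ["z"]
```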
diff --git a/src/core/data/transform.h b/src/core/data/transform.h
index 16b68618d..e39b75962 100644
--- a/src/core/data/transform.h
+++ b/src/core/data/transform.h
@@ -51,7 +51,7 @@ class Shift : public StoreTransform {
 
  private:
   int32_t dim_;
-  int64_t offset_;
+  int64_t offset_; 
 };
 
 class Promote : public StoreTransform {
diff --git a/src/core/mapping/task.cc b/src/core/mapping/task.cc
index a84bd27d2..90a9140d0 100644
--- a/src/core/mapping/task.cc
+++ b/src/core/mapping/task.cc
@@ -121,6 +121,7 @@ Task::Task(const LegionTask* task,
   : task_(task), library_(library)
 {
   MapperDeserializer dez(task, runtime, context);
+  fusionMetadata = dez.unpack<FusionMetadata>();
   inputs_     = dez.unpack<std::vector<Store>>();
   outputs_    = dez.unpack<std::vector<Store>>();
   reductions_ = dez.unpack<std::vector<Store>>();
diff --git a/src/core/mapping/task.h b/src/core/mapping/task.h
index 69efdc034..cf2533688 100644
--- a/src/core/mapping/task.h
+++ b/src/core/mapping/task.h
@@ -20,6 +20,7 @@
 #include "core/data/scalar.h"
+#include "core/data/store.h"
 #include "core/data/transform.h"
 #include "core/runtime/context.h"
@@ -176,7 +177,8 @@ class Task {
   const LibraryContext& library_;
   const Legion::Task* task_;
 
- private:
+ public:
+  FusionMetadata fusionMetadata;
   std::vector<Store> inputs_, outputs_, reductions_;
   std::vector<Scalar> scalars_;
 };
diff --git a/src/core/runtime/context.cc b/src/core/runtime/context.cc
index 7a9586e44..79bfbf6a4 100644
--- a/src/core/runtime/context.cc
+++ b/src/core/runtime/context.cc
@@ -146,10 +146,13 @@ TaskContext::TaskContext(const Legion::Task* task,
   : task_(task), regions_(regions), context_(context), runtime_(runtime)
 {
   TaskDeserializer dez(task, regions);
+  fusionMetadata = dez.unpack<FusionMetadata>();
+
   inputs_     = dez.unpack<std::vector<Store>>();
   outputs_    = dez.unpack<std::vector<Store>>();
   reductions_ = dez.unpack<std::vector<Store>>();
   scalars_    = dez.unpack<std::vector<Scalar>>();
 }
 
 ReturnValues TaskContext::pack_return_values() const
diff --git a/src/core/runtime/context.h b/src/core/runtime/context.h
index 6c4efc197..d93d1662f 100644
--- a/src/core/runtime/context.h
+++ b/src/core/runtime/context.h
@@ -17,6 +17,8 @@
 #pragma once
 
 #include "legion.h"
+#include "core/data/scalar.h"
+#include "core/data/store.h"
 
 #include "core/task/return.h"
@@ -107,11 +109,18 @@ class LibraryContext {
 // of the Legion API.
 class TaskContext {
  public:
+  TaskContext() = default;
+
   TaskContext(const Legion::Task* task,
               const std::vector<Legion::PhysicalRegion>& regions,
               Legion::Context context,
               Legion::Runtime* runtime);
 
+  TaskContext(const Legion::Task* task, const std::vector<Legion::PhysicalRegion> regions)
+    : task_(task), regions_(regions)
+  {
+  }
+
  public:
   std::vector<Store>& inputs() { return inputs_; }
   std::vector<Store>& outputs() { return outputs_; }
@@ -121,13 +130,14 @@ class TaskContext {
  public:
   ReturnValues pack_return_values() const;
 
- private:
+ public:
   const Legion::Task* task_;
   const std::vector<Legion::PhysicalRegion>& regions_;
   Legion::Context context_;
   Legion::Runtime* runtime_;
+  FusionMetadata fusionMetadata;
 
- private:
+ public:
   std::vector<Store> inputs_, outputs_, reductions_;
   std::vector<Scalar> scalars_;
 };
diff --git a/src/core/runtime/runtime.cc b/src/core/runtime/runtime.cc
index aab251955..a39159c66 100644
--- a/src/core/runtime/runtime.cc
+++ b/src/core/runtime/runtime.cc
@@ -32,6 +32,15 @@ Logger log_legate("legate");
 
 // This is the unique string name for our library which can be used
 // from both C++ and Python to generate IDs
+
+using LegateVariantImpl = void (*)(TaskContext&);
+/*static*/ std::vector<std::pair<int64_t, LegateVariantImpl>> Core::opIDs =
+  *(new std::vector<std::pair<int64_t, LegateVariantImpl>>());
+/*static*/ std::vector<std::pair<int64_t, LegateVariantImpl>> Core::gpuOpIDs =
+  *(new std::vector<std::pair<int64_t, LegateVariantImpl>>());
+/*static*/ std::vector<std::pair<int64_t, LegateVariantImpl>> Core::ompOpIDs =
+  *(new std::vector<std::pair<int64_t, LegateVariantImpl>>());
+/*static*/ std::unordered_map<int64_t, LegateVariantImpl> Core::cpuDescriptors =
+  *(new std::unordered_map<int64_t, LegateVariantImpl>());
+/*static*/ std::unordered_map<int64_t, LegateVariantImpl> Core::gpuDescriptors =
+  *(new std::unordered_map<int64_t, LegateVariantImpl>());
+/*static*/ std::unordered_map<int64_t, LegateVariantImpl> Core::ompDescriptors =
+  *(new std::unordered_map<int64_t, LegateVariantImpl>());
+
 static const char* const core_library_name = "legate.core";
 
 /*static*/ bool Core::show_progress = false;
diff --git a/src/core/runtime/runtime.h b/src/core/runtime/runtime.h
index 03b62e0c8..566ad4adb 100644
--- a/src/core/runtime/runtime.h
+++ b/src/core/runtime/runtime.h
@@ -19,9 +19,12 @@
 #include "legion.h"
 
 #include "core/utilities/typedefs.h"
-
+#include "core/runtime/context.h"
+#include <unordered_map>
 
 namespace legate {
 
+using LegateVariantImpl = void (*)(TaskContext&);
+
 extern uint32_t extract_env(const char* env_name,
                             const uint32_t default_value,
                             const uint32_t test_value);
@@ -30,6 +33,12 @@ class Core {
  public:
   static void parse_config(void);
   static void shutdown(void);
+  static std::unordered_map<int64_t, LegateVariantImpl> cpuDescriptors;
+  static std::unordered_map<int64_t, LegateVariantImpl> gpuDescriptors;
+  static std::unordered_map<int64_t, LegateVariantImpl> ompDescriptors;
+  static std::vector<std::pair<int64_t, LegateVariantImpl>> opIDs;
+  static std::vector<std::pair<int64_t, LegateVariantImpl>> gpuOpIDs;
+  static std::vector<std::pair<int64_t, LegateVariantImpl>> ompOpIDs;
 
  public:
   // Configuration settings
diff --git a/src/core/task/task.cc b/src/core/task/task.cc
index a7eb11da9..64adf10b9 100644
--- a/src/core/task/task.cc
+++ b/src/core/task/task.cc
@@ -57,6 +57,23 @@ void LegateTaskRegistrar::record_variant(TaskID tid,
 
 void LegateTaskRegistrar::register_all_tasks(Runtime* runtime, LibraryContext& context)
 {
+  // Rebase the library-local task ids recorded at static registration
+  // time to global ids and fill the per-processor descriptor maps.
+  for (auto& taskIdx : Core::opIDs) {
+    auto newID = context.get_task_id(taskIdx.first);
+    Core::cpuDescriptors.insert(std::pair<int64_t, LegateVariantImpl>((int64_t)newID, taskIdx.second));
+  }
+
+  for (auto& taskIdx : Core::gpuOpIDs) {
+    auto newID = context.get_task_id(taskIdx.first);
+    Core::gpuDescriptors.insert(std::pair<int64_t, LegateVariantImpl>((int64_t)newID, taskIdx.second));
+  }
+
+  for (auto& taskIdx : Core::ompOpIDs) {
+    auto newID = context.get_task_id(taskIdx.first);
+    Core::ompDescriptors.insert(std::pair<int64_t, LegateVariantImpl>((int64_t)newID, taskIdx.second));
+  }
+
   // Do all our registrations
   for (auto& task : pending_task_variants_) {
     task.task_id =
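A sketch of the id-rebasing idea behind these descriptor tables, with a hypothetical `LIBRARY_OFFSET` standing in for `context.get_task_id()`:

```python
# Sketch: rebase library-local task ids to global ids, keeping the
# variant function pointers alongside them.
LIBRARY_OFFSET = 10_000

def square(ctx): ...
def add(ctx): ...

op_ids = [(21, square), (2, add)]   # recorded at static registration time

cpu_descriptors = {
    LIBRARY_OFFSET + local_id: fn for local_id, fn in op_ids
}

# A fused task can now look up and run a sub-op by its global id.
assert cpu_descriptors[LIBRARY_OFFSET + 2] is add
```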
diff --git a/src/core/task/task.h b/src/core/task/task.h
index 690db2684..e83019a08 100644
--- a/src/core/task/task.h
+++ b/src/core/task/task.h
@@ -133,6 +133,14 @@ class LegateTask {
       legion_task_wrapper::template legate_task_wrapper>);
     auto task_id = T::TASK_ID;
 
+    // Record each variant's function pointer under its library-local
+    // task id so fused tasks can later invoke sub-tasks directly.
+    if (kind == Legion::Processor::LOC_PROC) {
+      Core::opIDs.push_back(std::pair<int64_t, LegateVariantImpl>((int64_t)task_id, TASK_PTR));
+    } else if (kind == Legion::Processor::TOC_PROC) {
+      Core::gpuOpIDs.push_back(std::pair<int64_t, LegateVariantImpl>((int64_t)task_id, TASK_PTR));
+    } else if (kind == Legion::Processor::OMP_PROC) {
+      Core::ompOpIDs.push_back(std::pair<int64_t, LegateVariantImpl>((int64_t)task_id, TASK_PTR));
+    }
     T::Registrar::record_variant(task_id,
                                  T::task_name(),
                                  desc,
diff --git a/src/core/utilities/deserializer.cc b/src/core/utilities/deserializer.cc
index 68ab45907..9a0813013 100644
--- a/src/core/utilities/deserializer.cc
+++ b/src/core/utilities/deserializer.cc
@@ -40,6 +40,61 @@ TaskDeserializer::TaskDeserializer(const LegionTask* task,
   first_task_ = !task->is_index_space || (task->index_point == task->index_domain.lo());
 }
 
+void TaskDeserializer::_unpack(FusionMetadata& metadata)
+{
+  metadata.isFused = unpack<bool>();
+  // Exit early if this is not a fused op.
+  if (!metadata.isFused) { return; }
+  metadata.nOps     = unpack<int32_t>();
+  metadata.nBuffers = unpack<int32_t>();
+  int nOps     = metadata.nOps;
+  int nBuffers = metadata.nBuffers;
+
+  metadata.inputStarts.resize(nOps + 1);
+  metadata.outputStarts.resize(nOps + 1);
+  metadata.offsetStarts.resize(nOps + 1);
+  metadata.offsets.resize(nBuffers + 1);
+  metadata.reductionStarts.resize(nOps + 1);
+  metadata.scalarStarts.resize(nOps + 1);
+  metadata.futureStarts.resize(nOps + 1);
+  metadata.opIDs.resize(nOps);
+  // TODO: wrap this up to reuse code
+  for (int i = 0; i < nOps + 1; i++) metadata.inputStarts[i] = unpack<int32_t>();
+  for (int i = 0; i < nOps + 1; i++) metadata.outputStarts[i] = unpack<int32_t>();
+  for (int i = 0; i < nOps + 1; i++) metadata.offsetStarts[i] = unpack<int32_t>();
+  for (int i = 0; i < nBuffers; i++) metadata.offsets[i] = unpack<int32_t>();
+  for (int i = 0; i < nOps + 1; i++) metadata.reductionStarts[i] = unpack<int32_t>();
+  for (int i = 0; i < nOps + 1; i++) metadata.scalarStarts[i] = unpack<int32_t>();
+  for (int i = 0; i < nOps + 1; i++) metadata.futureStarts[i] = unpack<int32_t>();
+  for (int i = 0; i < nOps; i++) metadata.opIDs[i] = unpack<int32_t>();
+}
+
 void TaskDeserializer::_unpack(Store& value)
 {
   auto is_future = unpack<bool>();
@@ -82,7 +137,6 @@ void TaskDeserializer::_unpack(FutureWrapper& value)
     future   = futures_[0];
     futures_ = futures_.subspan(1);
   }
-
   value = FutureWrapper(read_only, field_size, domain, future, has_storage && first_task_);
 }
 
@@ -91,7 +145,6 @@ void TaskDeserializer::_unpack(RegionField& value)
   auto dim = unpack<int32_t>();
   auto idx = unpack<uint32_t>();
   auto fid = unpack<int32_t>();
-
   value = RegionField(dim, regions_[idx], fid);
 }
 
@@ -136,6 +189,62 @@ void MapperDeserializer::_unpack(Store& value)
   }
 }
 
+void MapperDeserializer::_unpack(FusionMetadata& metadata)
+{
+  // Same wire format as TaskDeserializer::_unpack(FusionMetadata&).
+  metadata.isFused = unpack<bool>();
+  // Exit early if this is not a fused op.
+  if (!metadata.isFused) { return; }
+  metadata.nOps     = unpack<int32_t>();
+  metadata.nBuffers = unpack<int32_t>();
+  int nOps     = metadata.nOps;
+  int nBuffers = metadata.nBuffers;
+
+  metadata.inputStarts.resize(nOps + 1);
+  metadata.outputStarts.resize(nOps + 1);
+  metadata.offsetStarts.resize(nOps + 1);
+  metadata.offsets.resize(nBuffers + 1);
+  metadata.reductionStarts.resize(nOps + 1);
+  metadata.scalarStarts.resize(nOps + 1);
+  metadata.futureStarts.resize(nOps + 1);
+  metadata.opIDs.resize(nOps);
+  // TODO: wrap this up to reuse code
+  for (int i = 0; i < nOps + 1; i++) metadata.inputStarts[i] = unpack<int32_t>();
+  for (int i = 0; i < nOps + 1; i++) metadata.outputStarts[i] = unpack<int32_t>();
+  for (int i = 0; i < nOps + 1; i++) metadata.offsetStarts[i] = unpack<int32_t>();
+  for (int i = 0; i < nBuffers; i++) metadata.offsets[i] = unpack<int32_t>();
+  for (int i = 0; i < nOps + 1; i++) metadata.reductionStarts[i] = unpack<int32_t>();
+  for (int i = 0; i < nOps + 1; i++) metadata.scalarStarts[i] = unpack<int32_t>();
+  for (int i = 0; i < nOps + 1; i++) metadata.futureStarts[i] = unpack<int32_t>();
+  for (int i = 0; i < nOps; i++) metadata.opIDs[i] = unpack<int32_t>();
+}
+
 void MapperDeserializer::_unpack(FutureWrapper& value)
 {
   // We still need to deserialize these fields to get to the domain
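A Python mirror of the read order in `_unpack(FusionMetadata&)` above, starting after the leading `isFused` flag; `unpack_fusion_metadata` is a hypothetical helper, and it assumes the offsets table carries exactly `nBuffers` entries, matching what `FusionMetadata.pack` sends:

```python
def unpack_fusion_metadata(ints):
    # Mirrors the C++ read order: two counts, then seven tables,
    # then the op ids, all 32-bit ints.
    it = iter(ints)
    n_ops, n_buffers = next(it), next(it)
    take = lambda n: [next(it) for _ in range(n)]
    return {
        "input_starts":     take(n_ops + 1),
        "output_starts":    take(n_ops + 1),
        "offset_starts":    take(n_ops + 1),
        "offsets":          take(n_buffers),
        "reduction_starts": take(n_ops + 1),
        "scalar_starts":    take(n_ops + 1),
        "future_starts":    take(n_ops + 1),
        "op_ids":           take(n_ops),
    }

meta = unpack_fusion_metadata(
    [1, 2, 0, 1, 0, 1, 0, 2, 5, -3, 0, 0, 0, 0, 0, 0, 21]
)
assert meta["op_ids"] == [21]
```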
diff --git a/src/core/utilities/deserializer.h b/src/core/utilities/deserializer.h
index 5e8cbd227..df1fe4f47 100644
--- a/src/core/utilities/deserializer.h
+++ b/src/core/utilities/deserializer.h
@@ -87,6 +87,7 @@ class TaskDeserializer : public BaseDeserializer<TaskDeserializer> {
   void _unpack(FutureWrapper& value);
   void _unpack(RegionField& value);
   void _unpack(OutputRegionField& value);
+  void _unpack(FusionMetadata& value);
 
  private:
   Span<const Legion::Future> futures_;
@@ -109,6 +110,7 @@ class MapperDeserializer : public BaseDeserializer<MapperDeserializer> {
   void _unpack(Store& value);
   void _unpack(FutureWrapper& value);
   void _unpack(RegionField& value, bool is_output_region);
+  void _unpack(FusionMetadata& value);
 
  private:
   Legion::Mapping::MapperRuntime* runtime_;
diff --git a/src/core/utilities/makeshift_serializer.cc b/src/core/utilities/makeshift_serializer.cc
new file mode 100644
index 000000000..06f2d3997
--- /dev/null
+++ b/src/core/utilities/makeshift_serializer.cc
@@ -0,0 +1,158 @@
+#include "core/utilities/makeshift_serializer.h"
+
+namespace legate {
+
+void MakeshiftSerializer::packScalar(const Scalar& scalar)
+{
+  pack((bool)scalar.is_tuple());
+  pack((LegateTypeCode)scalar.code_);
+  int32_t size = scalar.size();
+  packWithoutType(scalar.data_, size);
+}
+
+// The packing format must mirror unpack_transform() in
+// src/core/utilities/deserializer.cc and Store._serialize_transform
+// in legate/core/store.py.
+void MakeshiftSerializer::packTransform(const StoreTransform* trans)
+{
+  if (trans == nullptr) {
+    int32_t neg = -1;
+    pack((int32_t)neg);
+  } else {
+    int32_t code = trans->getTransformCode();
+    pack((int32_t)code);
+    switch (code) {
+      case -1: {
+        break;
+      }
+      case LEGATE_CORE_TRANSFORM_SHIFT: {
+        Shift* shifter = (Shift*)trans;
+        pack((int32_t)shifter->dim_);
+        pack((int64_t)shifter->offset_);
+        packTransform(trans->parent_.get());
+        break;
+      }
+      case LEGATE_CORE_TRANSFORM_PROMOTE: {
+        Promote* promoter = (Promote*)trans;
+        pack((int32_t)promoter->extra_dim_);
+        pack((int64_t)promoter->dim_size_);
+        packTransform(trans->parent_.get());
+        break;
+      }
+      case LEGATE_CORE_TRANSFORM_PROJECT: {
+        Project* projector = (Project*)trans;
+        pack((int32_t)projector->dim_);
+        pack((int64_t)projector->coord_);
+        packTransform(trans->parent_.get());
+        break;
+      }
+      case LEGATE_CORE_TRANSFORM_TRANSPOSE: {
+        Transpose* transposer = (Transpose*)trans;
+        packTransform(trans->parent_.get());
+        break;
+      }
+      case LEGATE_CORE_TRANSFORM_DELINEARIZE: {
+        Delinearize* delinearizer = (Delinearize*)trans;
+        packTransform(trans->parent_.get());
+        break;
+      }
+    }
+  }
+}
+
+void MakeshiftSerializer::packBuffer(const Store& buffer)
+{
+  pack((bool)buffer.is_future2());  // is_future
+  pack((int32_t)buffer.dim());
+  pack((int32_t)buffer.code());
+  // pack the transform chain
+  packTransform(buffer.transform_.get());
+
+  // if is_future
+  if (buffer.is_future_) {
+    // ...
+  }
+  // pack if dim >= 0
+  else if (buffer.dim() >= 0) {
+    pack((int32_t)buffer.redop_id_);
+    // pack the region field: dim
+    pack((int32_t)buffer.region_field_.dim());
+    // pack the requirement index, remapped to the fused task's indexing
+    unsigned newID = getNewReqID(buffer.region_field_.reqIdx_);
+    pack((uint32_t)newID);
+    // pack the field id
+    pack((int32_t)buffer.region_field_.fid_);
+  } else {
+    pack((int32_t)buffer.redop_id_);
+    // pack the region field; dim is always 1 in a buffer
+    pack((int32_t)1);
+    // pack the requirement index, remapped to the fused task's indexing
+    unsigned newID = getNewReqID(buffer.region_field_.reqIdx_);
+    pack((uint32_t)newID);
+    // pack the field id
+    pack((int32_t)buffer.region_field_.fid_);
+  }
+}
+
+}  // namespace legate
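A sketch of the requirement-id remapping the serializer's id map implements: the first occurrence of each old id claims the next dense slot, so each sub-op's region requirements can be renumbered within the fused launch.

```python
def remap_req_ids(old_ids):
    # First-occurrence order assigns dense new ids, mirroring the
    # serializer's requirement-id map and counter.
    mapping, needed = {}, []
    for old in old_ids:
        if old not in mapping:
            mapping[old] = len(mapping)
            needed.append(old)
    return mapping, needed

mapping, needed = remap_req_ids([4, 7, 4, 9])
assert mapping == {4: 0, 7: 1, 9: 2}
assert needed == [4, 7, 9]
```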
diff --git a/src/core/utilities/makeshift_serializer.h b/src/core/utilities/makeshift_serializer.h
new file mode 100644
index 000000000..81a85b2f2
--- /dev/null
+++ b/src/core/utilities/makeshift_serializer.h
@@ -0,0 +1,168 @@
+#pragma once
+#include <vector>
+#include <map>
+#include <cstring>
+#include <iostream>
+#include "core/data/store.h"
+#include "core/data/scalar.h"
+#include "core/data/transform.h"
+
+namespace legate {
+
+class Scalar;
+class Store;
+
+class MakeshiftSerializer {
+ public:
+  MakeshiftSerializer()
+  {
+    size           = 512;
+    raw.resize(size);
+    write_offset   = 0;
+    read_offset    = 0;
+    buffer_counter = 0;
+  }
+
+  // Reset for reuse without shrinking the backing buffer.
+  void zero()
+  {
+    write_offset   = 0;
+    buffer_counter = 0;
+    neededReqIds.clear();
+    regionReqIdMap.clear();
+  }
+
+  void packScalar(const Scalar& scalar);
+  void packTransform(const StoreTransform* trans);
+  void packBuffer(const Store& buffer);
+
+  template <typename T>
+  void pack(T arg)
+  {
+    int8_t* argAddr = (int8_t*)&arg;
+    if (size <= write_offset + sizeof(T)) { resize(sizeof(T)); }
+    memcpy(raw.data() + write_offset, argAddr, sizeof(T));
+    write_offset += sizeof(T);
+  }
+
+  // Append a raw byte range whose type is implied by the protocol.
+  void packWithoutType(const void* arg, int32_t argSize)
+  {
+    int8_t* argByte = (int8_t*)arg;
+    if (size <= write_offset + argSize) { resize(argSize); }
+    for (int32_t i = 0; i < argSize; i++) {
+      raw[write_offset + i] = *(argByte + i);
+    }
+    write_offset += argSize;
+  }
+
+  template <typename T>
+  T read()
+  {
+    if (read_offset < write_offset) {
+      T datum = *reinterpret_cast<T*>(raw.data() + read_offset);
+      read_offset += sizeof(T);
+      return datum;
+    } else {
+      std::cout << "finished reading buffer" << std::endl;
+      return T();
+    }
+  }
+
+  // Grow the backing buffer so at least `extra` more bytes fit.
+  void resize(size_t extra)
+  {
+    size = 2 * (size + extra);
+    raw.resize(size);
+  }
+
+  int returnAndIncrCounter() { return buffer_counter++; }
+
+  // Record a region-requirement id the first time it is seen and
+  // assign it the next dense id within the fused launch.
+  void recordReqID(int32_t id)
+  {
+    if (regionReqIdMap.find(id) == regionReqIdMap.end()) {
+      regionReqIdMap.insert(std::pair<int32_t, int32_t>(id, returnAndIncrCounter()));
+      neededReqIds.push_back(id);
+    }
+  }
+
+  int32_t getNewReqID(int32_t oldID) { return regionReqIdMap.find(oldID)->second; }
+
+  std::vector<int32_t> getReqIds()
+  {
+    // could use move semantics here
+    std::vector<int32_t> reqIdsCopy(neededReqIds);
+    return reqIdsCopy;
+  }
+
+ private:
+  size_t size;
+  int read_offset;
+  int write_offset;
+  int buffer_counter;
+  std::vector<int8_t> raw;
+
+ private:
+  std::map<int32_t, int32_t> regionReqIdMap;  // maps old req ids to new ones
+  std::vector<int32_t> neededReqIds;          // old req ids needed by the child op
+};
+
+}  // namespace legate
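A Python analogue of the serializer's typed pack/read contract, assuming native little-endian sizes; this is a sketch of the design, not the C++ implementation:

```python
import struct

class MiniSerializer:
    """Python analogue of the MakeshiftSerializer pack/read contract."""
    def __init__(self):
        self.raw = bytearray()
        self.read_offset = 0

    def pack(self, fmt, value):
        # fmt is a struct code, e.g. "i" (int32), "q" (int64), "?" (bool)
        self.raw += struct.pack("=" + fmt, value)

    def read(self, fmt):
        size = struct.calcsize("=" + fmt)
        (value,) = struct.unpack_from("=" + fmt, self.raw, self.read_offset)
        self.read_offset += size
        return value

s = MiniSerializer()
s.pack("i", -1)      # e.g. the "no transform" sentinel
s.pack("q", 42)
assert s.read("i") == -1 and s.read("q") == 42
```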