Skip to content

Commit

Permalink
Scalar unary
Browse files Browse the repository at this point in the history
  • Loading branch information
reyna-abhyankar committed Oct 8, 2023
1 parent a909c49 commit e09080c
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 62 deletions.
70 changes: 70 additions & 0 deletions lib/runtime/src/ops/element_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ enum Slots {
PER_DEVICE_STATE
};

/* ElementUnary */
OpTaskInvocation init(ElementUnaryAttrs const &attrs) {
OpTaskBinding b;

Expand Down Expand Up @@ -52,6 +53,36 @@ OpTaskInvocation backward(ElementUnaryAttrs const &attrs) {
return {ELEMENTUNARY_BWD_TASK_ID, b};
}

/* ElementScalarUnary */
OpTaskInvocation init(ElementScalarUnaryAttrs const &attrs) {
OpTaskBinding b;

b.bind_arg(HANDLE, ff_handle());
b.bind_arg(ATTRS, attrs);
b.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0));

return {ELEMENTUNARY_INIT_TASK_ID, b};
}

OpTaskInvocation forward(ElementScalarUnaryAttrs const &attrs) {
OpTaskBinding b;

b.bind(INPUT, input_tensor(0));
b.bind(OUTPUT, output_tensor(0));

b.bind_arg(PROFILING, profiling_settings());
b.bind_arg(PER_DEVICE_STATE,
per_device_op_state<ElementUnaryPerDeviceState>());

return {ELEMENTUNARY_FWD_TASK_ID, b};
}

OpTaskInvocation backward(ElementScalarUnaryAttrs const &attrs) {
OpTaskBinding b = infer_bwd_binding(forward(attrs).binding);

return {ELEMENTUNARY_BWD_TASK_ID, b};
}

static DeviceSpecific<ElementUnaryPerDeviceState>
init_task_impl(TaskArgumentAccessor const &acc) {

Expand Down Expand Up @@ -171,6 +202,45 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim,
return make_metrics(forward_time, backward_time, sync_time, env);
}

CostMetrics measure_operator_cost(SimEnvFactory const &sim,
ElementScalarUnaryAttrs const &attrs,
InputParallelTensorDesc const &input_shape,
ProfilingSettings const &settings,
MachineView const &mv) {
auto env = sim.new_environment();

ParallelTensorShape output_shape = get_output_shape(attrs, input_shape);

SimTaskBinding init_binding;
init_binding.bind_arg(HANDLE, ff_handle());
init_binding.bind_arg(ATTRS, attrs);
init_binding.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0));

auto init_accessor =
env.get_init_accessor(ELEMENTUNARY_INIT_TASK_ID, init_binding);
DeviceSpecific<ElementUnaryPerDeviceState> per_device_state =
init_task_impl(init_accessor);

SimTaskBinding fwd_binding;
fwd_binding.bind(INPUT, input_shape);
fwd_binding.bind(OUTPUT, output_shape);
fwd_binding.bind_arg(PROFILING, settings);
fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state);

SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding);

auto fwd_accessor =
env.get_fwd_accessor(ELEMENTUNARY_FWD_TASK_ID, fwd_binding);
auto bwd_accessor =
env.get_bwd_accessor(ELEMENTUNARY_BWD_TASK_ID, bwd_binding);

float forward_time = forward_task_impl(fwd_accessor).value();
float backward_time = backward_task_impl(bwd_accessor).value();

float sync_time = default_estimate_sync_time(env);
return make_metrics(forward_time, backward_time, sync_time, env);
}

template <>
OpTaskSignature init_signature<ELEMENTUNARY_INIT_TASK_ID>() {
OpTaskSignature init(OpTaskType::INIT);
Expand Down
64 changes: 2 additions & 62 deletions lib/runtime/src/ops/element_unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,76 +24,16 @@ OpTaskInvocation backward(ElementScalarUnaryAttrs const &);

CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory,
ElementUnaryAttrs const &attrs,
ParallelTensorShape const &input_shape,
InputParallelTensorDesc const &input_shape,
ProfilingSettings const &settings,
MachineView const &machine_view);

CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory,
ElementScalarUnaryAttrs const &attrs,
ParallelTensorShape const &input_shape,
InputParallelTensorDesc const &input_shape,
ProfilingSettings const &settings,
MachineView const &machine_view);

/* class ElementUnary : public Op { */
/* public: */
/* ElementUnary(FFModel &model, */
/* OperatorType type, */
/* const ParallelTensor x, */
/* bool inplace, */
/* char const *name, */
/* float scalar); */
/* void init(FFModel const &) override; */
/* void forward(FFModel const &) override; */
/* void backward(FFModel const &) override; */
/* void map_output_tensors(FFModel &model) override; */
/* bool can_inplace_output() override; */
/* bool has_inplace_output() override; */
/* void do_inplace_output() override; */
/* static Op * */
/* create_operator_from_layer(FFModel &model, */
/* Layer const *layer, */
/* std::vector<ParallelTensor> const &inputs);
*/

/* static PerDeviceOpState *init_task(Legion::Task const *task, */
/* std::vector<Legion::PhysicalRegion> const
* &regions, */
/* Legion::Context ctx, */
/* Legion::Runtime *runtime); */
/* static void forward_task(Legion::Task const *task, */
/* std::vector<Legion::PhysicalRegion> const
* &regions, */
/* Legion::Context ctx, */
/* Legion::Runtime *runtime); */
/* static void backward_task(Legion::Task const *task, */
/* std::vector<Legion::PhysicalRegion> const
* &regions, */
/* Legion::Context ctx, */
/* Legion::Runtime *runtime); */
/* template <typename T> */
/* static void */
/* forward_task_with_type(Legion::Task const *task, */
/* std::vector<Legion::PhysicalRegion> const
* &regions, */
/* Legion::Context ctx, */
/* Legion::Runtime *runtime); */
/* template <typename T> */
/* static void backward_task_with_type( */
/* Legion::Task const *task, */
/* std::vector<Legion::PhysicalRegion> const &regions, */
/* Legion::Context ctx, */
/* Legion::Runtime *runtime); */
/* bool measure_operator_cost(Simulator *sim, */
/* MachineView const &pc, */
/* CostMetrics &cost_metrics) const override; */

/* private: */
/* bool inplace; */

/* public: */
/* float scalar; */
/* }; */

} // namespace FlexFlow

#endif

0 comments on commit e09080c

Please sign in to comment.