diff --git a/lib/kernels/CMakeLists.txt b/lib/kernels/CMakeLists.txt index 8ccd7c1011..5a6a0d1357 100644 --- a/lib/kernels/CMakeLists.txt +++ b/lib/kernels/CMakeLists.txt @@ -8,6 +8,8 @@ file(GLOB_RECURSE SRC LIST_DIRECTORIES False src/*.cc src/cuda/cuda_helper.cu + src/cuda/loss_function_kernels.cu + src/cuda/optimizer_kernels.cu src/cuda/ops/*.cu ) diff --git a/lib/kernels/include/kernels/array_shape.h b/lib/kernels/include/kernels/array_shape.h index 96a3b3b281..fd66697793 100644 --- a/lib/kernels/include/kernels/array_shape.h +++ b/lib/kernels/include/kernels/array_shape.h @@ -14,9 +14,10 @@ namespace FlexFlow { struct ArrayShape { public: ArrayShape() = delete; - ArrayShape(size_t *dims, size_t num_dims); - ArrayShape(TensorShape const &shape); - ArrayShape(std::vector const &); + explicit ArrayShape(size_t *dims, size_t num_dims); + explicit ArrayShape(TensorShape const &shape); + explicit ArrayShape(std::vector const &); + explicit ArrayShape(LegionTensorDims const &); /** * @brief Alias of ArrayShape::num_elements for compatibility with @@ -42,9 +43,16 @@ struct ArrayShape { std::optional at_maybe(legion_dim_t) const; std::optional at_maybe(ff_dim_t) const; - ArrayShape - sub_shape(std::optional> start, - std::optional> end) const; + ArrayShape sub_shape(legion_dim_t start, ff_dim_t end) const; + + ArrayShape sub_shape(std::optional start, + std::optional end) const; + + ArrayShape sub_shape(std::optional start, + std::optional end) const; + + bool operator==(ArrayShape const &) const; + bool operator!=(ArrayShape const &) const; public: LegionTensorDims dims; diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h index e4dd9723b8..29c5e29a93 100644 --- a/lib/kernels/include/kernels/legion_dim.h +++ b/lib/kernels/include/kernels/legion_dim.h @@ -10,6 +10,9 @@ legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value); legion_dim_t legion_dim_from_ff_dim(ff_dim_t, int num_dimensions); +std::optional legion_dim_from_ff_dim(std::optional, + int num_dimensions); + template using LegionOrdered = DimOrdered; diff --git a/lib/kernels/include/kernels/optimizer_kernels.h b/lib/kernels/include/kernels/optimizer_kernels.h index 9ca6bf8e2b..ed7c2778dd 100644 --- a/lib/kernels/include/kernels/optimizer_kernels.h +++ b/lib/kernels/include/kernels/optimizer_kernels.h @@ -1,7 +1,8 @@ #ifndef _FLEXFLOW_KERNELS_INCLUDE_KERNELS_OPTIMIZER_KERNELS_H #define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_OPTIMIZER_KERNELS_H -#include "device.h" +#include "kernels/device.h" +#include "kernels/ff_handle.h" namespace FlexFlow { @@ -20,7 +21,8 @@ void sgd_nccl_update_task_gpu(ffStream_t, float lr, float momentum, bool nesterov, - float weight_decay PerDeviceFFHandle const &, + float weight_decay, + PerDeviceFFHandle const &, float const *weight_grad_ptr, size_t size, float *weight_ptr, @@ -32,6 +34,8 @@ void adam_ps_update_task_gpu(ffStream_t, float beta2, float weight_decay, float epsilon, + size_t size, + int num_replicas, float const *weight_grad_ptr, float *adam_m_ptr, float *adam_v_ptr, @@ -43,6 +47,7 @@ void adam_nccl_update_task_gpu(ffStream_t, float beta2, float weight_decay, float epsilon, + size_t size, PerDeviceFFHandle const &, float const *weight_grad_ptr, float *adam_m_ptr, diff --git a/lib/kernels/src/allocation.cc b/lib/kernels/src/allocation.cc index ccd88580db..b57fbee257 100644 --- a/lib/kernels/src/allocation.cc +++ b/lib/kernels/src/allocation.cc @@ -14,7 +14,7 @@ void Allocator::deallocate(void *ptr) { GenericTensorAccessorW Allocator::allocate_tensor(TensorShape const &tensor_shape) { void *ptr = this->allocate(get_size_in_bytes(tensor_shape)); - return {tensor_shape.data_type, tensor_shape, ptr}; + return {tensor_shape.data_type, ArrayShape{tensor_shape}, ptr}; } } // namespace FlexFlow diff --git a/lib/kernels/src/array_shape.cc b/lib/kernels/src/array_shape.cc index d5e2f1167d..31ee7b6001 100644 --- a/lib/kernels/src/array_shape.cc +++ b/lib/kernels/src/array_shape.cc @@ -1,4 +1,5 @@ #include "kernels/array_shape.h" +#include "op-attrs/dim_ordered/slice.h" #include "utils/containers/product.h" namespace FlexFlow { @@ -19,6 +20,9 @@ ArrayShape::ArrayShape(TensorShape const &shape) ArrayShape::ArrayShape(std::vector const &input_dims) : dims(input_dims) {} +ArrayShape::ArrayShape(LegionTensorDims const &legion_tensor_dims) + : dims(legion_tensor_dims) {} + std::size_t ArrayShape::get_volume() const { return this->num_elements(); } @@ -50,10 +54,23 @@ std::size_t ArrayShape::at(ff_dim_t idx) const { return dims.at(legion_dim_from_ff_dim(idx, this->num_dims())); } -ArrayShape ArrayShape::sub_shape( - std::optional> start, - std::optional> end) const { - NOT_IMPLEMENTED(); +ArrayShape ArrayShape::sub_shape(legion_dim_t start, ff_dim_t end) const { + legion_dim_t legion_end = legion_dim_from_ff_dim(end, num_dims()); + return this->sub_shape(start, legion_end); +} + +ArrayShape ArrayShape::sub_shape(std::optional start, + std::optional end) const { + std::optional legion_start = + legion_dim_from_ff_dim(start, num_dims()); + std::optional legion_end = + legion_dim_from_ff_dim(end, num_dims()); + return this->sub_shape(legion_start, legion_end); +} + +ArrayShape ArrayShape::sub_shape(std::optional start, + std::optional end) const { + return ArrayShape{slice(this->dims, start, end)}; } std::optional ArrayShape::at_maybe(legion_dim_t index) const { @@ -77,6 +94,14 @@ TensorShape get_tensor_shape(ArrayShape const &shape, DataType dtype) { dtype}; } +bool ArrayShape::operator==(ArrayShape const &other) const { + return this->dims == other.dims; +} + +bool ArrayShape::operator!=(ArrayShape const &other) const { + return this->dims != other.dims; +} + std::string format_as(ArrayShape const &x) { std::ostringstream oss; oss << " <<>>( - (float *)w_grad_ptr, src, size, 1.0f); + (float *)weight_grad_ptr, src, size, 1.0f); } // checkCUDA(cudaDeviceSynchronize()); // Step 2: SGD update - sgd_update<<>>( - size, - op->lr, - op->weight_decay, - op->momentum, - op->nesterov, - w_grad_ptr, - v_ptr, - w_ptr); + sgd_update<<>>(size, + lr, + weight_decay, + momentum, + nesterov, + weight_grad_ptr, + sgd_v_ptr, + weight_ptr); // checkCUDA(cudaDeviceSynchronize()); } #ifdef FF_USE_NCCL -__host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op, - PerDeviceOpState const *meta, - float const *w_grad_ptr, - size_t size, - float *w_ptr, - float *v_ptr) { +void sgd_nccl_update_task_gpu(cudaStream_t stream, + float lr, + float momentum, + bool nesterov, + float weight_decay, + PerDeviceFFHandle const &handle, + float const *weight_grad_ptr, + size_t size, + float *weight_ptr, + float *sgd_v_ptr) { // Use NCCL to sync gradients // fprintf(stderr, "weight(%p) Before ncclAllReduce...\n", w_grad_ptr); - cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - checkNCCL(ncclAllReduce(w_grad_ptr, - (float *)w_grad_ptr, + checkNCCL(ncclAllReduce(weight_grad_ptr, + (float *)weight_grad_ptr, size, - ncclFloat, - ncclSum, - meta->handle.ncclComm, + ncclDataType_t::ncclFloat, + ncclRedOp_t::ncclSum, + handle.ncclComm, stream)); // fprintf(stderr, "weight(%p) After ncclAllReduce...\n", w_grad_ptr); // print_tensor((float*)w_grad_ptr, 16, "[After ncclAllReduce]"); // Step 2: SGD update - sgd_update<<>>( - size, - op->lr, - op->weight_decay, - op->momentum, - op->nesterov, - w_grad_ptr, - v_ptr, - w_ptr); + sgd_update<<>>(size, + lr, + weight_decay, + momentum, + nesterov, + weight_grad_ptr, + sgd_v_ptr, + weight_ptr); // checkCUDA(cudaDeviceSynchronize()); } #endif @@ -144,20 +150,24 @@ __global__ void adam_update(int count, } } -__host__ void AdamOptimizer::ps_update_task_gpu(AdamOptimizer const *op, - float const *w_grad_ptr, - size_t size, - int num_replicas, - float *w_ptr, - float *v_ptr, - float *m_ptr) { - cudaStream_t stream; +void adam_ps_update_task_gpu(cudaStream_t stream, + float alpha_t, + float beta1, + float beta2, + float weight_decay, + float epsilon, + size_t size, + int num_replicas, + float const *weight_grad_ptr, + float *adam_m_ptr, + float *adam_v_ptr, + float *weight_ptr) { checkCUDA(get_legion_stream(&stream)); // Step 1: Gather gradients in the first replica for (int i = 1; i < num_replicas; i++) { - float const *src = w_grad_ptr + i * size; + float const *src = weight_grad_ptr + i * size; add_kernel<<>>( - size, 1.0f, src, (float *)w_grad_ptr); + size, 1.0f, src, (float *)weight_grad_ptr); } // checkCUDA(cudaDeviceSynchronize()); // fprintf(stderr, "alpha = %.8lf alpha_t = %.8lf decay = %.8lf\n", @@ -165,50 +175,54 @@ __host__ void AdamOptimizer::ps_update_task_gpu(AdamOptimizer const *op, // Step 2: Adam update adam_update<<>>( size, - op->alpha_t, - op->beta1, - op->beta2, - op->weight_decay, - op->epsilon, - w_grad_ptr, - m_ptr, - v_ptr, - w_ptr); + alpha_t, + beta1, + beta2, + weight_decay, + epsilon, + weight_grad_ptr, + adam_m_ptr, + adam_v_ptr, + weight_ptr); // checkCUDA(cudaDeviceSynchronize()); } #ifdef FF_USE_NCCL -__host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op, - PerDeviceOpState const *meta, - float const *w_grad_ptr, - size_t size, - float *w_ptr, - float *v_ptr, - float *m_ptr) { +void adam_nccl_update_task_gpu(cudaStream_t stream, + float alpha_t, + float beta1, + float beta2, + float weight_decay, + float epsilon, + size_t size, + PerDeviceFFHandle const &handle, + float const *weight_grad_ptr, + float *adam_m_ptr, + float *adam_v_ptr, + float *weight_ptr) { // Use NCCL to sync gradients - cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - checkNCCL(ncclAllReduce(w_grad_ptr, - (float *)w_grad_ptr, + checkNCCL(ncclAllReduce(weight_grad_ptr, + (float *)weight_grad_ptr, size, - ncclFloat, - ncclSum, - meta->handle.ncclComm, + ncclDataType_t::ncclFloat, + ncclRedOp_t::ncclSum, + handle.ncclComm, stream)); // fprintf(stderr, "alpha = %.8lf alpha_t = %.8lf decay = %.8lf\n", // op->alpha, op->alpha_t, op->weight_decay); // Step 2: Adam update adam_update<<>>( size, - op->alpha_t, - op->beta1, - op->beta2, - op->weight_decay, - op->epsilon, - w_grad_ptr, - m_ptr, - v_ptr, - w_ptr); + alpha_t, + beta1, + beta2, + weight_decay, + epsilon, + weight_grad_ptr, + adam_m_ptr, + adam_v_ptr, + weight_ptr); // checkCUDA(cudaDeviceSynchronize()); } #endif diff --git a/lib/kernels/src/legion_dim.cc b/lib/kernels/src/legion_dim.cc index 9ef47d40ae..c190a02220 100644 --- a/lib/kernels/src/legion_dim.cc +++ b/lib/kernels/src/legion_dim.cc @@ -10,4 +10,13 @@ legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, int num_dimensions) { return legion_dim_t(num_dimensions - ff_dim.value - 1); } +std::optional + legion_dim_from_ff_dim(std::optional ff_dim, int num_dimensions) { + if (ff_dim.has_value()) { + return legion_dim_from_ff_dim(ff_dim.value(), num_dimensions); + } else { + return std::nullopt; + } +} + } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/arg_ref.h b/lib/local-execution/include/local-execution/arg_ref.h index 30326b0e84..30da405c13 100644 --- a/lib/local-execution/include/local-execution/arg_ref.h +++ b/lib/local-execution/include/local-execution/arg_ref.h @@ -60,6 +60,20 @@ struct ArgRefSpec { friend struct std::hash>; }; +template +std::string format_as(ArgRefSpec const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} + +template +std::ostream &operator<<(std::ostream &s, ArgRefSpec const &x) { + return (s << fmt::to_string(x)); +} + } // namespace FlexFlow namespace std { @@ -68,8 +82,7 @@ template struct hash<::FlexFlow::ArgRefSpec> { size_t operator()(::FlexFlow::ArgRefSpec const &s) const { size_t result = 0; - hash_combine(s.type_idx); - hash_combine(s.ref_type); + ::FlexFlow::hash_combine(result, s.type_idx); return result; } }; diff --git a/lib/local-execution/include/local-execution/concrete_arg.h b/lib/local-execution/include/local-execution/concrete_arg.h index 3bc2714a71..ac5d97f3c4 100644 --- a/lib/local-execution/include/local-execution/concrete_arg.h +++ b/lib/local-execution/include/local-execution/concrete_arg.h @@ -3,6 +3,7 @@ #include "fmt/format.h" #include "local-execution/serialization.h" +#include "utils/hash-utils.h" #include "utils/type_index.h" #include @@ -53,4 +54,17 @@ std::ostream &operator<<(std::ostream &, ConcreteArgSpec const &); } // namespace FlexFlow +namespace std { + +template <> +struct hash<::FlexFlow::ConcreteArgSpec> { + size_t operator()(::FlexFlow::ConcreteArgSpec const &s) const { + size_t result = 0; + ::FlexFlow::hash_combine(result, s.get_type_index()); + return result; + } +}; + +} // namespace std + #endif diff --git a/lib/local-execution/include/local-execution/fwd_bwd_op_task_impl_function.h b/lib/local-execution/include/local-execution/fwd_bwd_op_task_impl_function.h new file mode 100644 index 0000000000..cc82291f6a --- /dev/null +++ b/lib/local-execution/include/local-execution/fwd_bwd_op_task_impl_function.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_FWD_BWD_TASK_IMPL_FUNCTION_H +#define _FLEXFLOW_LOCAL_EXECUTION_FWD_BWD_TASK_IMPL_FUNCTION_H + +#include "local-execution/task_argument_accessor.h" + +namespace FlexFlow { + +struct FwdBwdOpTaskImplFunction { + + std::optional (*function_ptr)(TaskArgumentAccessor const &); + + bool operator==(FwdBwdOpTaskImplFunction const &) const; + bool operator!=(FwdBwdOpTaskImplFunction const &) const; + bool operator<(FwdBwdOpTaskImplFunction const &) const; + bool operator>(FwdBwdOpTaskImplFunction const &) const; + bool operator<=(FwdBwdOpTaskImplFunction const &) const; + bool operator>=(FwdBwdOpTaskImplFunction const &) const; +}; + +std::string format_as(FwdBwdOpTaskImplFunction const &x); +std::ostream &operator<<(std::ostream &s, FwdBwdOpTaskImplFunction const &x); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::FwdBwdOpTaskImplFunction> { + size_t operator()(::FlexFlow::FwdBwdOpTaskImplFunction const &) const; +}; +} // namespace std + +#endif diff --git a/lib/local-execution/include/local-execution/fwd_bwd_task_impl_function.h b/lib/local-execution/include/local-execution/fwd_bwd_task_impl_function.h deleted file mode 100644 index 7f80af77f3..0000000000 --- a/lib/local-execution/include/local-execution/fwd_bwd_task_impl_function.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_FWD_BWD_TASK_IMPL_FUNCTION_H -#define _FLEXFLOW_LOCAL_EXECUTION_FWD_BWD_TASK_IMPL_FUNCTION_H - -#include "local-execution/task_argument_accessor.h" - -namespace FlexFlow { - -struct FwdBwdTaskImplFunction { - - std::optional (*function_ptr)(TaskArgumentAccessor const &); - - bool operator==(FwdBwdTaskImplFunction const &) const; - bool operator!=(FwdBwdTaskImplFunction const &) const; - bool operator<(FwdBwdTaskImplFunction const &) const; - bool operator>(FwdBwdTaskImplFunction const &) const; - bool operator<=(FwdBwdTaskImplFunction const &) const; - bool operator>=(FwdBwdTaskImplFunction const &) const; -}; - -std::string format_as(FwdBwdTaskImplFunction const &x); -std::ostream &operator<<(std::ostream &s, FwdBwdTaskImplFunction const &x); - -} // namespace FlexFlow - -namespace std { -template <> -struct hash<::FlexFlow::FwdBwdTaskImplFunction> { - size_t operator()(::FlexFlow::FwdBwdTaskImplFunction const &) const; -}; -} // namespace std - -#endif diff --git a/lib/local-execution/include/local-execution/generic_task_impl_function.h b/lib/local-execution/include/local-execution/generic_task_impl_function.h new file mode 100644 index 0000000000..425740f61d --- /dev/null +++ b/lib/local-execution/include/local-execution/generic_task_impl_function.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_GENERIC_TASK_IMPL_FUNCTION_H +#define _FLEXFLOW_LOCAL_EXECUTION_GENERIC_TASK_IMPL_FUNCTION_H + +#include "local-execution/device_specific_device_states.dtg.h" +#include "local-execution/task_argument_accessor.h" + +namespace FlexFlow { + +struct GenericTaskImplFunction { + + void (*function_ptr)(TaskArgumentAccessor const &); + + bool operator==(GenericTaskImplFunction const &) const; + bool operator!=(GenericTaskImplFunction const &) const; + bool operator<(GenericTaskImplFunction const &) const; + bool operator>(GenericTaskImplFunction const &) const; + bool operator<=(GenericTaskImplFunction const &) const; + bool operator>=(GenericTaskImplFunction const &) const; +}; + +std::string format_as(GenericTaskImplFunction const &x); +std::ostream &operator<<(std::ostream &s, GenericTaskImplFunction const &x); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::GenericTaskImplFunction> { + size_t operator()(::FlexFlow::GenericTaskImplFunction const &) const; +}; +} // namespace std + +#endif diff --git a/lib/local-execution/include/local-execution/init_op_task_impl_function.h b/lib/local-execution/include/local-execution/init_op_task_impl_function.h new file mode 100644 index 0000000000..7b23a2bc64 --- /dev/null +++ b/lib/local-execution/include/local-execution/init_op_task_impl_function.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_INIT_TASK_IMPL_FUNCTION_H +#define _FLEXFLOW_LOCAL_EXECUTION_INIT_TASK_IMPL_FUNCTION_H + +#include "local-execution/device_specific_device_states.dtg.h" +#include "local-execution/task_argument_accessor.h" + +namespace FlexFlow { + +struct InitOpTaskImplFunction { + + DeviceSpecificDeviceStates (*function_ptr)(TaskArgumentAccessor const &); + + bool operator==(InitOpTaskImplFunction const &) const; + bool operator!=(InitOpTaskImplFunction const &) const; + bool operator<(InitOpTaskImplFunction const &) const; + bool operator>(InitOpTaskImplFunction const &) const; + bool operator<=(InitOpTaskImplFunction const &) const; + bool operator>=(InitOpTaskImplFunction const &) const; +}; + +std::string format_as(InitOpTaskImplFunction const &x); +std::ostream &operator<<(std::ostream &s, InitOpTaskImplFunction const &x); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::InitOpTaskImplFunction> { + size_t operator()(::FlexFlow::InitOpTaskImplFunction const &) const; +}; +} // namespace std + +#endif diff --git a/lib/local-execution/include/local-execution/init_task_impl_function.h b/lib/local-execution/include/local-execution/init_task_impl_function.h deleted file mode 100644 index b85944e13a..0000000000 --- a/lib/local-execution/include/local-execution/init_task_impl_function.h +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_INIT_TASK_IMPL_FUNCTION_H -#define _FLEXFLOW_LOCAL_EXECUTION_INIT_TASK_IMPL_FUNCTION_H - -#include "local-execution/device_specific_device_states.dtg.h" -#include "local-execution/task_argument_accessor.h" - -namespace FlexFlow { - -struct InitTaskImplFunction { - - DeviceSpecificDeviceStates (*function_ptr)(TaskArgumentAccessor const &); - - bool operator==(InitTaskImplFunction const &) const; - bool operator!=(InitTaskImplFunction const &) const; - bool operator<(InitTaskImplFunction const &) const; - bool operator>(InitTaskImplFunction const &) const; - bool operator<=(InitTaskImplFunction const &) const; - bool operator>=(InitTaskImplFunction const &) const; -}; - -std::string format_as(InitTaskImplFunction const &x); -std::ostream &operator<<(std::ostream &s, InitTaskImplFunction const &x); - -} // namespace FlexFlow - -namespace std { -template <> -struct hash<::FlexFlow::InitTaskImplFunction> { - size_t operator()(::FlexFlow::InitTaskImplFunction const &) const; -}; -} // namespace std - -#endif diff --git a/lib/local-execution/include/local-execution/itask_argument_accessor.h b/lib/local-execution/include/local-execution/itask_argument_accessor.h index b4d188e4a3..9eff9460c2 100644 --- a/lib/local-execution/include/local-execution/itask_argument_accessor.h +++ b/lib/local-execution/include/local-execution/itask_argument_accessor.h @@ -5,6 +5,7 @@ #include "local-execution/concrete_arg.h" #include "local-execution/op_task_signature.h" #include "local-execution/privilege_tensor_accessor.h" +#include "local-execution/tensor_type.dtg.h" namespace FlexFlow { @@ -15,10 +16,11 @@ struct ITaskArgumentAccessor { virtual ConcreteArgSpec const &get_concrete_arg(slot_id_t) const = 0; - virtual GenericTensorAccessor - get_tensor(slot_id_t slot, Permissions priv, IsGrad is_grad) const = 0; + virtual GenericTensorAccessor get_tensor(slot_id_t slot, + Permissions priv, + TensorType tensor_type) const = 0; virtual VariadicGenericTensorAccessor get_variadic_tensor( - slot_id_t slot, Permissions priv, IsGrad is_grad) const = 0; + slot_id_t slot, Permissions priv, TensorType tensor_type) const = 0; virtual Allocator get_allocator() const = 0; virtual size_t get_device_idx() const = 0; diff --git a/lib/local-execution/include/local-execution/layer_tensor_key.struct.toml b/lib/local-execution/include/local-execution/layer_tensor_key.struct.toml new file mode 100644 index 0000000000..3ec6d7b0f1 --- /dev/null +++ b/lib/local-execution/include/local-execution/layer_tensor_key.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "LayerTensorKey" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "local-execution/reduced_tensor_t.dtg.h" +] + +[[fields]] +name = "layer_guid" +type = "::FlexFlow::layer_guid_t" + +[[fields]] +name = "reduced_tensor" +type = "::FlexFlow::reduced_tensor_t" diff --git a/lib/local-execution/include/local-execution/local_slots_backing.h b/lib/local-execution/include/local-execution/local_slots_backing.h index 5b826c7022..a632f432cf 100644 --- a/lib/local-execution/include/local-execution/local_slots_backing.h +++ b/lib/local-execution/include/local-execution/local_slots_backing.h @@ -1,53 +1,82 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_SLOT_REGISTRY_H -#define _FLEXFLOW_LOCAL_EXECUTION_SLOT_REGISTRY_H +#ifndef _FLEXFLOW_LOCAL_EXECUTION_LOCAL_SLOTS_BACKING_H +#define _FLEXFLOW_LOCAL_EXECUTION_LOCAL_SLOTS_BACKING_H #include "kernels/accessor.h" +#include "local-execution/layer_tensor_key.dtg.h" #include "local-execution/local_task_argument_accessor.h" +#include "local-execution/non_graph_tensor_guid_t.dtg.h" #include "local-execution/op_task_invocation.h" #include "local-execution/per_device_op_state.h" #include "local-execution/runtime_arg_config.h" +#include "local-execution/task_invocation.dtg.h" +#include "local-execution/tensor_role.dtg.h" #include "pcg/computation_graph.dtg.h" -#include "pcg/layer_guid_t.dtg.h" #include "pcg/tensor_guid_t.dtg.h" namespace FlexFlow { +using LayerTensorBackingMap = + std::unordered_map; + using TensorBackingMap = - std::unordered_map; + std::unordered_map; struct LocalSlotsBacking { - LocalSlotsBacking(TensorBackingMap const &, RuntimeArgConfig const &); + LocalSlotsBacking(LayerTensorBackingMap const &allocated_forward_tensors, + TensorBackingMap const &allocated_non_graph_tensors, + RuntimeArgConfig const &); public: void add_per_device_op_state(layer_guid_t const &, DeviceSpecificDeviceStates const &); - void allocate_outgoing_tensors(layer_guid_t const &, - ComputationGraph const &, - Allocator &); + void allocate_layer_tensors(layer_guid_t const &, + ComputationGraph const &, + Allocator &); + void allocate_tensors_by_role(TensorRole const &, + layer_guid_t const &, + ComputationGraph const &, + Allocator &); + void allocate_optimizer_tensors(layer_guid_t const &weight_layer, + tensor_guid_t const &, + ComputationGraph const &, + Allocator &, + TaskSignature const &); TensorSlotsBacking construct_tensor_slots_backing(OpTaskBinding const &, layer_guid_t const &) const; + TensorSlotsBacking + construct_tensor_slots_backing(TaskBinding const &, + std::optional const &) const; ArgSlotsBacking construct_arg_slots_backing(OpTaskBinding const &, layer_guid_t const &) const; + ArgSlotsBacking construct_arg_slots_backing(TaskBinding const &) const; ConcreteArgSpec resolve_runtime_arg_ref_spec(RuntimeArgRefSpec const &) const; ConcreteArgSpec resolve_op_arg_ref_spec(OpArgRefSpec const &, layer_guid_t const &) const; -private: - bool is_tensor_allocated(tensor_guid_t const &) const; - bool is_gradient_tensor_allocated(tensor_guid_t const &) const; - GenericTensorAccessorW const &get_tensor_backing(tensor_guid_t const &, - IsGrad) const; + GenericTensorAccessorW const & + get_tensor_backing(TensorType const &, + reduced_tensor_t const &, + std::optional const &) const; + + bool is_forward_tensor_allocated(LayerTensorKey const &) const; + bool is_non_graph_tensor_allocated(reduced_tensor_t const &) const; public: // tensors - TensorBackingMap tensor_mapping; - TensorBackingMap gradient_tensor_mapping; - std::unordered_map> + LayerTensorBackingMap tensor_mapping; + LayerTensorBackingMap gradient_tensor_mapping; + LayerTensorBackingMap optimizer_tensor_mapping; + TensorBackingMap non_graph_tensor_mapping; + std::unordered_map> input_tensor_slots; - std::unordered_map> + std::unordered_map> + weight_tensor_slots; + std::unordered_map> output_tensor_slots; + std::unordered_map> + weight_optimizer_tensor_guids; // arguments std::unordered_map diff --git a/lib/local-execution/include/local-execution/local_task_argument_accessor.h b/lib/local-execution/include/local-execution/local_task_argument_accessor.h index 1e1516a0de..db0e98c2b1 100644 --- a/lib/local-execution/include/local-execution/local_task_argument_accessor.h +++ b/lib/local-execution/include/local-execution/local_task_argument_accessor.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LOCAL_EXECUTION_LOCAL_TASK_ARGUMENT_ACCESSOR_H #define _FLEXFLOW_LOCAL_EXECUTION_LOCAL_TASK_ARGUMENT_ACCESSOR_H -#include "local-execution/slot_grad_id.dtg.h" +#include "local-execution/slot_tensor_type_id.dtg.h" #include "local-execution/task_argument_accessor.h" #include #include @@ -9,7 +9,7 @@ namespace FlexFlow { using TensorSlotsBacking = std::unordered_map< - SlotGradId, + SlotTensorTypeId, std::variant>>; using ArgSlotsBacking = std::unordered_map; @@ -25,9 +25,9 @@ struct LocalTaskArgumentAccessor : public ITaskArgumentAccessor { GenericTensorAccessor get_tensor(slot_id_t slot, Permissions priv, - IsGrad is_grad) const override; + TensorType tensor_type) const override; VariadicGenericTensorAccessor get_variadic_tensor( - slot_id_t slot, Permissions priv, IsGrad is_grad) const override; + slot_id_t slot, Permissions priv, TensorType tensor_type) const override; Allocator get_allocator() const override; @@ -40,7 +40,7 @@ struct LocalTaskArgumentAccessor : public ITaskArgumentAccessor { }; using TensorSlotsBackingWithoutAddresses = std::unordered_map< - SlotGradId, + SlotTensorTypeId, std::variant, std::vector>>>; diff --git a/lib/local-execution/include/local-execution/local_training_backing.h b/lib/local-execution/include/local-execution/local_training_backing.h index 6789624076..cbab4bf031 100644 --- a/lib/local-execution/include/local-execution/local_training_backing.h +++ b/lib/local-execution/include/local-execution/local_training_backing.h @@ -3,7 +3,9 @@ #include "local-execution/local_slots_backing.h" #include "local-execution/task_registry.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "pcg/computation_graph.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" namespace FlexFlow { @@ -13,16 +15,26 @@ using PerLayerElapsedTime = struct LocalTrainingBacking { LocalTrainingBacking(Allocator const &, ComputationGraph const &, - TensorBackingMap const &, + LayerTensorBackingMap const &allocated_forward_tensors, + TensorBackingMap const &allocated_non_graph_tensors, RuntimeArgConfig const &); - - void execute_init(); - PerLayerElapsedTime execute_forward(); - PerLayerElapsedTime execute_backward(); - void execute_update(); - - TaskArgumentAccessor get_task_arg_accessor(OpTaskInvocation const &, - layer_guid_t const &) const; + void register_and_allocate_layer(layer_guid_t const &); + void allocate_layer_optimizer_tensors(layer_guid_t const &, + OptimizerAttrs const &); + + void execute_init(layer_guid_t const &); + std::optional execute_forward(layer_guid_t const &); + void compute_loss(LossAttrs const &loss_attrs, + reduced_tensor_t const &logit_tensor, + reduced_tensor_t const &label_tensor); + std::optional execute_backward(layer_guid_t const &); + void execute_update(layer_guid_t const &, OptimizerAttrs const &); + + TaskArgumentAccessor + get_task_arg_accessor(TaskInvocation const &, + std::optional const &) const; + TaskArgumentAccessor get_op_task_arg_accessor(OpTaskInvocation const &, + layer_guid_t const &) const; private: DeviceSpecificDeviceStates call_init_task_impl(task_id_t, diff --git a/lib/runtime/src/loss_functions.h b/lib/local-execution/include/local-execution/loss_functions.h similarity index 60% rename from lib/runtime/src/loss_functions.h rename to lib/local-execution/include/local-execution/loss_functions.h index 620ebc6936..4ce74da766 100644 --- a/lib/runtime/src/loss_functions.h +++ b/lib/local-execution/include/local-execution/loss_functions.h @@ -13,24 +13,20 @@ * limitations under the License. */ -#ifndef _FF_LOSS_FUNCTIONS_H_ -#define _FF_LOSS_FUNCTIONS_H_ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_LOSS_FUNCTIONS_H_ +#define _FLEXFLOW_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_LOSS_FUNCTIONS_H_ +#include "local-execution/task_impl_function.dtg.h" +#include "local-execution/task_invocation.dtg.h" +#include "local-execution/task_signature.h" #include "op-attrs/ops/loss_functions.h" -#include "pcg/operator.h" -#include "pcg/parallel_tensor.h" -#include "pcg/parallel_tensor_guid_t.h" -#include "task_spec/task_invocation.h" -#include "tasks.h" namespace FlexFlow { -template <> -void register_task(); - -TaskInvocation backward(LossAttrs const &, - parallel_tensor_guid_t logit, - parallel_tensor_guid_t label); +TaskImplFunction get_loss_bwd_task_impl(); +TaskSignature get_loss_bwd_signature(); +TaskInvocation + backward(LossAttrs const &, reduced_tensor_t logit, reduced_tensor_t label); } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/model_training_instance.h b/lib/local-execution/include/local-execution/model_training_instance.h new file mode 100644 index 0000000000..5cc13f0b40 --- /dev/null +++ b/lib/local-execution/include/local-execution/model_training_instance.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_MODEL_TRAINING_INSTANCE_H +#define _FLEXFLOW_LOCAL_EXECUTION_MODEL_TRAINING_INSTANCE_H + +#include "local-execution/local_training_backing.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" + +namespace FlexFlow { + +using PerLayerElapsedTime = + std::unordered_map>; + +struct ModelTrainingInstance { + ModelTrainingInstance(Allocator const &, + ComputationGraph const &, + LayerTensorBackingMap const &allocated_forward_tensors, + TensorBackingMap const &allocated_non_graph_tensors, + RuntimeArgConfig const &, + LossAttrs const &, + reduced_tensor_t const &logit_tensor, + reduced_tensor_t const &label_tensor, + OptimizerAttrs const &); + + void execute_init(); + PerLayerElapsedTime execute_forward(); + PerLayerElapsedTime execute_backward(); + void execute_update(); + + ComputationGraph computation_graph; + LocalTrainingBacking training_backing; + LossAttrs loss_attrs; + reduced_tensor_t logit_tensor; + reduced_tensor_t label_tensor; + OptimizerAttrs optimizer_attrs; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/non_graph_tensor_guid_t.struct.toml b/lib/local-execution/include/local-execution/non_graph_tensor_guid_t.struct.toml new file mode 100644 index 0000000000..4832ecaafa --- /dev/null +++ b/lib/local-execution/include/local-execution/non_graph_tensor_guid_t.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "non_graph_tensor_guid_t" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] + +[[fields]] +name = "raw_uid" +type = "int" diff --git a/lib/local-execution/include/local-execution/op_task_invocation.h b/lib/local-execution/include/local-execution/op_task_invocation.h index 0f351c3a0e..6484981ebf 100644 --- a/lib/local-execution/include/local-execution/op_task_invocation.h +++ b/lib/local-execution/include/local-execution/op_task_invocation.h @@ -10,7 +10,7 @@ #include "local-execution/op_tensor_spec.h" #include "local-execution/profiling.h" #include "local-execution/runtime_arg_ref.h" -#include "local-execution/slot_grad_id.dtg.h" +#include "local-execution/slot_tensor_type_id.dtg.h" #include "local-execution/task_id_t.dtg.h" #include "local-execution/variadic_tensor_ref.h" #include @@ -84,14 +84,14 @@ struct OpTaskBinding { bool operator==(OpTaskBinding const &other) const; bool operator!=(OpTaskBinding const &other) const; - std::unordered_map const & + std::unordered_map const & get_tensor_bindings() const; std::unordered_map const &get_arg_bindings() const; void bind_from_forward(OpTaskBinding const &fwd); private: - std::unordered_map tensor_bindings; + std::unordered_map tensor_bindings; std::unordered_map arg_bindings; private: diff --git a/lib/local-execution/include/local-execution/op_tensor_slot_spec.struct.toml b/lib/local-execution/include/local-execution/op_tensor_slot_spec.struct.toml index 590dbe6362..54638a7eb6 100644 --- a/lib/local-execution/include/local-execution/op_tensor_slot_spec.struct.toml +++ b/lib/local-execution/include/local-execution/op_tensor_slot_spec.struct.toml @@ -11,7 +11,7 @@ includes = [ "local-execution/slot_id_t.dtg.h", "local-execution/slot_type.dtg.h", "local-execution/tensor_role.dtg.h", - "local-execution/is_grad.dtg.h", + "local-execution/tensor_type.dtg.h", "local-execution/op_slot_options.dtg.h", ] @@ -28,8 +28,8 @@ name = "tensor_role" type = "::FlexFlow::TensorRole" [[fields]] -name = "is_grad" -type = "::FlexFlow::IsGrad" +name = "tensor_type" +type = "::FlexFlow::TensorType" [[fields]] name = "slot_option" diff --git a/lib/local-execution/src/ops/batch_matmul.h b/lib/local-execution/include/local-execution/ops/batch_matmul.h similarity index 100% rename from lib/local-execution/src/ops/batch_matmul.h rename to lib/local-execution/include/local-execution/ops/batch_matmul.h diff --git a/lib/local-execution/src/ops/batch_norm.h b/lib/local-execution/include/local-execution/ops/batch_norm.h similarity index 100% rename from lib/local-execution/src/ops/batch_norm.h rename to lib/local-execution/include/local-execution/ops/batch_norm.h diff --git a/lib/local-execution/src/ops/cast.h b/lib/local-execution/include/local-execution/ops/cast.h similarity index 100% rename from lib/local-execution/src/ops/cast.h rename to lib/local-execution/include/local-execution/ops/cast.h diff --git a/lib/local-execution/src/ops/combine.h b/lib/local-execution/include/local-execution/ops/combine.h similarity index 100% rename from lib/local-execution/src/ops/combine.h rename to lib/local-execution/include/local-execution/ops/combine.h diff --git a/lib/local-execution/src/ops/concat.h b/lib/local-execution/include/local-execution/ops/concat.h similarity index 100% rename from lib/local-execution/src/ops/concat.h rename to lib/local-execution/include/local-execution/ops/concat.h diff --git a/lib/local-execution/src/ops/conv_2d.h b/lib/local-execution/include/local-execution/ops/conv_2d.h similarity index 100% rename from lib/local-execution/src/ops/conv_2d.h rename to lib/local-execution/include/local-execution/ops/conv_2d.h diff --git a/lib/local-execution/src/ops/dropout.h b/lib/local-execution/include/local-execution/ops/dropout.h similarity index 100% rename from lib/local-execution/src/ops/dropout.h rename to lib/local-execution/include/local-execution/ops/dropout.h diff --git a/lib/local-execution/src/ops/element_binary.h b/lib/local-execution/include/local-execution/ops/element_binary.h similarity index 100% rename from lib/local-execution/src/ops/element_binary.h rename to lib/local-execution/include/local-execution/ops/element_binary.h diff --git a/lib/local-execution/src/ops/element_unary.h b/lib/local-execution/include/local-execution/ops/element_unary.h similarity index 100% rename from lib/local-execution/src/ops/element_unary.h rename to lib/local-execution/include/local-execution/ops/element_unary.h diff --git a/lib/local-execution/src/ops/embedding.h b/lib/local-execution/include/local-execution/ops/embedding.h similarity index 100% rename from lib/local-execution/src/ops/embedding.h rename to lib/local-execution/include/local-execution/ops/embedding.h diff --git a/lib/local-execution/src/ops/flat.h b/lib/local-execution/include/local-execution/ops/flat.h similarity index 100% rename from lib/local-execution/src/ops/flat.h rename to lib/local-execution/include/local-execution/ops/flat.h diff --git a/lib/local-execution/src/ops/gather.h b/lib/local-execution/include/local-execution/ops/gather.h similarity index 100% rename from lib/local-execution/src/ops/gather.h rename to lib/local-execution/include/local-execution/ops/gather.h diff --git a/lib/local-execution/src/ops/input.h b/lib/local-execution/include/local-execution/ops/input.h similarity index 100% rename from lib/local-execution/src/ops/input.h rename to lib/local-execution/include/local-execution/ops/input.h diff --git a/lib/local-execution/src/ops/layer_norm.h b/lib/local-execution/include/local-execution/ops/layer_norm.h similarity index 100% rename from lib/local-execution/src/ops/layer_norm.h rename to lib/local-execution/include/local-execution/ops/layer_norm.h diff --git a/lib/local-execution/src/ops/linear.h b/lib/local-execution/include/local-execution/ops/linear.h similarity index 100% rename from lib/local-execution/src/ops/linear.h rename to lib/local-execution/include/local-execution/ops/linear.h diff --git a/lib/local-execution/src/ops/noop.h b/lib/local-execution/include/local-execution/ops/noop.h similarity index 100% rename from lib/local-execution/src/ops/noop.h rename to lib/local-execution/include/local-execution/ops/noop.h diff --git a/lib/local-execution/src/ops/parallel_op.h b/lib/local-execution/include/local-execution/ops/parallel_op.h similarity index 100% rename from lib/local-execution/src/ops/parallel_op.h rename to lib/local-execution/include/local-execution/ops/parallel_op.h diff --git a/lib/local-execution/src/ops/pool_2d.h b/lib/local-execution/include/local-execution/ops/pool_2d.h similarity index 100% rename from lib/local-execution/src/ops/pool_2d.h rename to lib/local-execution/include/local-execution/ops/pool_2d.h diff --git a/lib/local-execution/src/ops/reduce.h b/lib/local-execution/include/local-execution/ops/reduce.h similarity index 100% rename from lib/local-execution/src/ops/reduce.h rename to lib/local-execution/include/local-execution/ops/reduce.h diff --git a/lib/local-execution/src/ops/reduction.h b/lib/local-execution/include/local-execution/ops/reduction.h similarity index 100% rename from lib/local-execution/src/ops/reduction.h rename to lib/local-execution/include/local-execution/ops/reduction.h diff --git a/lib/local-execution/src/ops/repartition.h b/lib/local-execution/include/local-execution/ops/repartition.h similarity index 100% rename from lib/local-execution/src/ops/repartition.h rename to lib/local-execution/include/local-execution/ops/repartition.h diff --git a/lib/local-execution/src/ops/replicate.h b/lib/local-execution/include/local-execution/ops/replicate.h similarity index 100% rename from lib/local-execution/src/ops/replicate.h rename to lib/local-execution/include/local-execution/ops/replicate.h diff --git a/lib/local-execution/src/ops/reshape.h b/lib/local-execution/include/local-execution/ops/reshape.h similarity index 100% rename from lib/local-execution/src/ops/reshape.h rename to lib/local-execution/include/local-execution/ops/reshape.h diff --git a/lib/local-execution/src/ops/reverse.h b/lib/local-execution/include/local-execution/ops/reverse.h similarity index 100% rename from lib/local-execution/src/ops/reverse.h rename to lib/local-execution/include/local-execution/ops/reverse.h diff --git a/lib/local-execution/src/ops/softmax.h b/lib/local-execution/include/local-execution/ops/softmax.h similarity index 100% rename from lib/local-execution/src/ops/softmax.h rename to lib/local-execution/include/local-execution/ops/softmax.h diff --git a/lib/local-execution/src/ops/split.h b/lib/local-execution/include/local-execution/ops/split.h similarity index 100% rename from lib/local-execution/src/ops/split.h rename to lib/local-execution/include/local-execution/ops/split.h diff --git a/lib/local-execution/src/ops/topk.h b/lib/local-execution/include/local-execution/ops/topk.h similarity index 100% rename from lib/local-execution/src/ops/topk.h rename to lib/local-execution/include/local-execution/ops/topk.h diff --git a/lib/local-execution/src/ops/transpose.h b/lib/local-execution/include/local-execution/ops/transpose.h similarity index 100% rename from lib/local-execution/src/ops/transpose.h rename to lib/local-execution/include/local-execution/ops/transpose.h diff --git a/lib/local-execution/src/ops/weight.h b/lib/local-execution/include/local-execution/ops/weight.h similarity index 100% rename from lib/local-execution/src/ops/weight.h rename to lib/local-execution/include/local-execution/ops/weight.h diff --git a/lib/local-execution/include/local-execution/optimizer.h b/lib/local-execution/include/local-execution/optimizer.h new file mode 100644 index 0000000000..2eb480a0c1 --- /dev/null +++ b/lib/local-execution/include/local-execution/optimizer.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_OPTIMIZER_H_ +#define _FLEXFLOW_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_OPTIMIZER_H_ + +#include "local-execution/non_graph_tensor_guid_t.dtg.h" +#include "local-execution/task_impl_function.dtg.h" +#include "local-execution/task_invocation.dtg.h" +#include "local-execution/task_signature.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "pcg/optimizers/adam_optimizer_attrs.dtg.h" +#include "pcg/optimizers/sgd_optimizer_attrs.dtg.h" + +namespace FlexFlow { + +TaskSignature get_update_signature(OptimizerAttrs const &); +TaskInvocation get_update_invocation( + OptimizerAttrs const &, + reduced_tensor_t const &weight, + std::vector const &grad_buffer_tensors); +TaskImplFunction get_update_task_impl(OptimizerAttrs const &); + +TaskSignature get_sgd_update_signature(); +TaskInvocation sgd_update(SGDOptimizerAttrs const &, + reduced_tensor_t const &weight, + reduced_tensor_t const &sgd_v); +TaskImplFunction get_sgd_update_task_impl(); + +TaskSignature get_adam_update_signature(); +TaskInvocation adam_update(AdamOptimizerAttrs const &, + reduced_tensor_t const &weight, + reduced_tensor_t const &adam_v, + reduced_tensor_t const &adam_m); +TaskImplFunction get_adam_update_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/reduced_tensor_t.struct.toml b/lib/local-execution/include/local-execution/reduced_tensor_t.struct.toml new file mode 100644 index 0000000000..726249c970 --- /dev/null +++ b/lib/local-execution/include/local-execution/reduced_tensor_t.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "reduced_tensor_t" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + + +[[fields]] +name = "raw_index" +type = "int" diff --git a/lib/local-execution/include/local-execution/runtime_arg_ref.h b/lib/local-execution/include/local-execution/runtime_arg_ref.h index 279d854a27..fd79e23126 100644 --- a/lib/local-execution/include/local-execution/runtime_arg_ref.h +++ b/lib/local-execution/include/local-execution/runtime_arg_ref.h @@ -5,6 +5,8 @@ #include "local-execution/config.h" #include "local-execution/device_specific.h" #include "local-execution/profiling.h" +#include "utils/fmt.h" +#include "utils/type_index.h" namespace FlexFlow { @@ -14,6 +16,8 @@ enum class RuntimeArgRefType { FF_ITERATION_CONFIG }; +std::string to_string(RuntimeArgRefType const &); + template using RuntimeArgRef = ArgRef; @@ -23,6 +27,18 @@ RuntimeArgRef profiling_settings(); RuntimeArgRef> ff_handle(); RuntimeArgRef iteration_config(); +// std::string format_as(RuntimeArgRefSpec const & x) { +// std::ostringstream oss; +// oss << ""; +// return oss.str(); +// } + +// std::ostream &operator<<(std::ostream & s, RuntimeArgRefSpec const & x) { +// return (s << fmt::to_string(x)); +// } + } // namespace FlexFlow #endif diff --git a/lib/local-execution/include/local-execution/slot_grad_id.struct.toml b/lib/local-execution/include/local-execution/slot_tensor_type_id.struct.toml similarity index 62% rename from lib/local-execution/include/local-execution/slot_grad_id.struct.toml rename to lib/local-execution/include/local-execution/slot_tensor_type_id.struct.toml index 256091d272..b3b3a320c7 100644 --- a/lib/local-execution/include/local-execution/slot_grad_id.struct.toml +++ b/lib/local-execution/include/local-execution/slot_tensor_type_id.struct.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "SlotGradId" +name = "SlotTensorTypeId" features = [ "eq", "ord", @@ -8,7 +8,7 @@ features = [ ] includes = [ - "local-execution/is_grad.dtg.h", + "local-execution/tensor_type.dtg.h", "local-execution/slot_id_t.dtg.h", ] @@ -17,5 +17,5 @@ name = "slot_id" type = "::FlexFlow::slot_id_t" [[fields]] -name = "is_grad" -type = "::FlexFlow::IsGrad" +name = "tensor_type" +type = "::FlexFlow::TensorType" diff --git a/lib/local-execution/include/local-execution/task_arg_spec.variant.toml b/lib/local-execution/include/local-execution/task_arg_spec.variant.toml new file mode 100644 index 0000000000..271e3b73d6 --- /dev/null +++ b/lib/local-execution/include/local-execution/task_arg_spec.variant.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "TaskArgSpec" +features = [ + "eq", + "fmt", + "hash" +] + +includes = [ + "local-execution/concrete_arg.h", + "local-execution/runtime_arg_ref.h" +] + +[[values]] +type = "::FlexFlow::ConcreteArgSpec" +key = "concrete_arg_spec" + +[[values]] +type = "::FlexFlow::RuntimeArgRefSpec" +key = "runtime_arg_ref" diff --git a/lib/local-execution/include/local-execution/task_argument_accessor.h b/lib/local-execution/include/local-execution/task_argument_accessor.h index 54c8dfc5f1..29d5fb8fbe 100644 --- a/lib/local-execution/include/local-execution/task_argument_accessor.h +++ b/lib/local-execution/include/local-execution/task_argument_accessor.h @@ -8,6 +8,7 @@ namespace FlexFlow { struct TaskArgumentAccessor { + // arguments template T const &get_argument(slot_id_t slot) const { if constexpr (PerDeviceOpState::IsPartOfPerDeviceOpState_v) { @@ -24,6 +25,7 @@ struct TaskArgumentAccessor { return this->get_argument(slot_id_t{slot}); } + // tensors template privilege_mode_to_accessor get_tensor(int slot) const { return this->get_tensor(slot_id_t{slot}); @@ -32,7 +34,7 @@ struct TaskArgumentAccessor { template privilege_mode_to_accessor get_tensor(slot_id_t slot) const { return std::get>( - this->ptr->get_tensor(slot, PRIV, IsGrad::NO)); + this->ptr->get_tensor(slot, PRIV, TensorType::FORWARD)); } template @@ -43,9 +45,32 @@ struct TaskArgumentAccessor { template privilege_mode_to_accessor get_tensor_grad(slot_id_t slot) const { return std::get>( - this->ptr->get_tensor(slot, PRIV, IsGrad::YES)); + this->ptr->get_tensor(slot, PRIV, TensorType::GRADIENT)); } + template + privilege_mode_to_accessor get_optimizer_tensor(int slot) const { + return this->get_tensor_grad(slot_id_t{slot}); + } + + template + privilege_mode_to_accessor get_optimizer_tensor(slot_id_t slot) const { + return std::get>( + this->ptr->get_tensor(slot, PRIV, TensorType::OPTIMIZER)); + } + + template + privilege_mode_to_accessor get_non_graph_tensor(int slot) const { + return this->get_tensor_grad(slot_id_t{slot}); + } + + template + privilege_mode_to_accessor get_non_graph_tensor(slot_id_t slot) const { + return std::get>( + this->ptr->get_tensor(slot, PRIV, TensorType::NON_GRAPH)); + } + + // variadic tensors template std::vector> get_variadic_tensor(int slot) const { @@ -56,7 +81,7 @@ struct TaskArgumentAccessor { std::vector> get_variadic_tensor(slot_id_t slot) const { return std::get>>( - this->ptr->get_variadic_tensor(slot, PRIV, IsGrad::NO)); + this->ptr->get_variadic_tensor(slot, PRIV, TensorType::FORWARD)); } template @@ -69,7 +94,33 @@ struct TaskArgumentAccessor { std::vector> get_variadic_tensor_grad(slot_id_t slot) const { return std::get>>( - this->ptr->get_variadic_tensor(slot, PRIV, IsGrad::YES)); + this->ptr->get_variadic_tensor(slot, PRIV, TensorType::GRADIENT)); + } + + template + std::vector> + get_variadic_optimizer_tensor(int slot) const { + return this->get_variadic_tensor_grad(slot_id_t{slot}); + } + + template + std::vector> + get_variadic_optimizer_tensor(slot_id_t slot) const { + return std::get>>( + this->ptr->get_variadic_tensor(slot, PRIV, TensorType::OPTIMIZER)); + } + + template + std::vector> + get_variadic_non_graph_tensor(int slot) const { + return this->get_variadic_tensor_grad(slot_id_t{slot}); + } + + template + std::vector> + get_variadic_non_graph_tensor(slot_id_t slot) const { + return std::get>>( + this->ptr->get_variadic_tensor(slot, PRIV, TensorType::NON_GRAPH)); } Allocator get_allocator() const { diff --git a/lib/local-execution/include/local-execution/task_binding.h b/lib/local-execution/include/local-execution/task_binding.h new file mode 100644 index 0000000000..93461e2e55 --- /dev/null +++ b/lib/local-execution/include/local-execution/task_binding.h @@ -0,0 +1,76 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_BINDING_H +#define _FLEXFLOW_LOCAL_EXECUTION_TASK_BINDING_H + +#include "local-execution/reduced_tensor_t.dtg.h" +#include "local-execution/slot_id_t.dtg.h" +#include "local-execution/slot_tensor_type_id.dtg.h" +#include "local-execution/task_arg_spec.dtg.h" +#include "local-execution/task_id_t.dtg.h" +#include "local-execution/task_signature.dtg.h" +#include "utils/hash/unordered_map.h" + +namespace FlexFlow { + +struct TaskBinding { + TaskBinding() = default; + + void bind(int, TensorType const &, reduced_tensor_t const &); + void bind(slot_id_t, TensorType const &, reduced_tensor_t const &); + + template + void bind_arg(int name, T const &t) { + this->bind_arg(slot_id_t{name}, t); + } + + template + void bind_arg(slot_id_t name, T const &t) { + this->insert_arg_spec(name, TaskArgSpec{ConcreteArgSpec::create(t)}); + } + + template + void bind_arg(int name, RuntimeArgRef const &t) { + this->bind_arg(slot_id_t{name}, t); + } + + template + void bind_arg(slot_id_t name, RuntimeArgRef const &ref) { + this->insert_arg_spec(name, TaskArgSpec{RuntimeArgRefSpec::create(ref)}); + } + + bool operator==(TaskBinding const &other) const; + bool operator!=(TaskBinding const &other) const; + + std::unordered_map const & + get_tensor_bindings() const; + std::unordered_map const &get_arg_bindings() const; + +private: + std::unordered_map tensor_bindings; + std::unordered_map arg_bindings; + +private: + void insert_arg_spec(slot_id_t name, TaskArgSpec const &arg_spec); + std::tuple + tie() const; +}; + +std::string format_as(TaskBinding const &x); +std::ostream &operator<<(std::ostream &s, TaskBinding const &x); + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::TaskBinding> { + size_t operator()(::FlexFlow::TaskBinding const &s) const { + size_t result = 0; + hash_combine(result, s.get_tensor_bindings()); + hash_combine(result, s.get_arg_bindings()); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/local-execution/include/local-execution/task_impl_function.variant.toml b/lib/local-execution/include/local-execution/task_impl_function.variant.toml index a12be37da2..48cab9eb01 100644 --- a/lib/local-execution/include/local-execution/task_impl_function.variant.toml +++ b/lib/local-execution/include/local-execution/task_impl_function.variant.toml @@ -8,14 +8,19 @@ features = [ ] includes = [ - "local-execution/init_task_impl_function.h", - "local-execution/fwd_bwd_task_impl_function.h", + "local-execution/init_op_task_impl_function.h", + "local-execution/fwd_bwd_op_task_impl_function.h", + "local-execution/generic_task_impl_function.h", ] [[values]] -type = "::FlexFlow::InitTaskImplFunction" -key = "init_task_impl_function" +type = "::FlexFlow::InitOpTaskImplFunction" +key = "init_op_task_impl_function" [[values]] -type = "::FlexFlow::FwdBwdTaskImplFunction" -key = "fwd_bwd_task_impl_function" +type = "::FlexFlow::FwdBwdOpTaskImplFunction" +key = "fwd_bwd_op_task_impl_function" + +[[values]] +type = "::FlexFlow::GenericTaskImplFunction" +key = "generic_task_impl_function" diff --git a/lib/local-execution/include/local-execution/task_invocation.h b/lib/local-execution/include/local-execution/task_invocation.h new file mode 100644 index 0000000000..d03d6ac8e1 --- /dev/null +++ b/lib/local-execution/include/local-execution/task_invocation.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_INVOCATION_H +#define _FLEXFLOW_LOCAL_EXECUTION_TASK_INVOCATION_H + +#include "local-execution/task_invocation.dtg.h" + +namespace FlexFlow { + +bool is_invocation_valid(TaskSignature const &sig, TaskInvocation const &inv); + +} + +#endif diff --git a/lib/local-execution/include/local-execution/task_invocation.struct.toml b/lib/local-execution/include/local-execution/task_invocation.struct.toml new file mode 100644 index 0000000000..c9e1e22ba1 --- /dev/null +++ b/lib/local-execution/include/local-execution/task_invocation.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TaskInvocation" +features = [ + "eq", + "fmt", + "hash" +] + +includes = [ + "local-execution/task_binding.h", + "local-execution/task_id_t.dtg.h" +] + + +[[fields]] +name = "task_id" +type = "::FlexFlow::task_id_t" + +[[fields]] +name = "binding" +type = "::FlexFlow::TaskBinding" diff --git a/lib/local-execution/include/local-execution/task_registry.h b/lib/local-execution/include/local-execution/task_registry.h index e00cc183da..24790a28e3 100644 --- a/lib/local-execution/include/local-execution/task_registry.h +++ b/lib/local-execution/include/local-execution/task_registry.h @@ -2,6 +2,7 @@ #ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_REGISTRY_H #define _FLEXFLOW_LOCAL_EXECUTION_TASK_REGISTRY_H +#include "local-execution/op_task_type.dtg.h" #include "local-execution/task_registry.dtg.h" #include "op-attrs/computation_graph_op_attrs.h" @@ -13,6 +14,10 @@ void register_tasks_for_layer(TaskRegistry &, layer_guid_t const &, ComputationGraphOpAttrs const &attrs); +bool registry_contains_op_task(TaskRegistry const &, + layer_guid_t const &, + OpTaskType const &); + } // namespace FlexFlow #endif diff --git a/lib/local-execution/include/local-execution/task_signature.h b/lib/local-execution/include/local-execution/task_signature.h new file mode 100644 index 0000000000..b10edce6d4 --- /dev/null +++ b/lib/local-execution/include/local-execution/task_signature.h @@ -0,0 +1,60 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_SIGNATURE_H +#define _FLEXFLOW_LOCAL_EXECUTION_TASK_SIGNATURE_H + +#include "local-execution/task_signature.dtg.h" +#include "utils/type_index.h" + +namespace FlexFlow { + +TaskSignature make_empty_task_signature(); + +void add_slot(TaskSignature &, + int name, + TensorType, + SlotType slot_type = SlotType::TENSOR); +void add_slot(TaskSignature &, + slot_id_t name, + TensorType, + SlotType slot_type = SlotType::TENSOR); + +template +void add_arg_slot(TaskSignature &task_signature, int name) { + add_arg_slot(task_signature, slot_id_t{name}); +} + +template +void add_arg_slot(TaskSignature &task_signature, slot_id_t name) { + // static_assert(is_serializable::value, "Type must be serializable"); + task_signature.task_arg_types.insert({name, get_type_index_for_type()}); +} + +template +void add_return_value(TaskSignature &task_signature) { + task_signature.return_value = get_type_index_for_type(); +} + +/** + * @brief Adds an argument slot without checking if it is serializable. + * + * This function is used for arguments that are device-specific. + */ + +template +void add_unchecked_arg_slot(TaskSignature &task_signature, int name) { + add_unchecked_arg_slot(task_signature, slot_id_t{name}); +} + +/** + * @brief Adds an argument slot without checking if it is serializable. + * + * This function is used for arguments that are device-specific. + */ + +template +void add_unchecked_arg_slot(TaskSignature &task_signature, slot_id_t name) { + task_signature.task_arg_types.insert({name, get_type_index_for_type()}); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/task_signature.struct.toml b/lib/local-execution/include/local-execution/task_signature.struct.toml new file mode 100644 index 0000000000..7efb0c658a --- /dev/null +++ b/lib/local-execution/include/local-execution/task_signature.struct.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "TaskSignature" +features = [ + "eq", + "fmt", + "hash" +] + +includes = [ + "local-execution/tensor_type_slot_spec.dtg.h", + "local-execution/slot_id_t.dtg.h", + "", + "" +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", + "utils/fmt/optional.h", + "utils/type_index.h" +] + +[[fields]] +name = "return_value" +type = "std::optional" + +[[fields]] +name = "task_arg_types" +type = "std::unordered_map<::FlexFlow::slot_id_t, std::type_index>" + +[[fields]] +name = "tensor_guid_slots" +type = "std::unordered_map<::FlexFlow::slot_id_t, ::FlexFlow::TensorTypeSlotSpec>" diff --git a/lib/local-execution/include/local-execution/tensor_reduction.h b/lib/local-execution/include/local-execution/tensor_reduction.h new file mode 100644 index 0000000000..eb55b07ee4 --- /dev/null +++ b/lib/local-execution/include/local-execution/tensor_reduction.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_TENSOR_REDUCTION_H +#define _FLEXFLOW_LOCAL_EXECUTION_TENSOR_REDUCTION_H + +#include "local-execution/reduced_tensor_t.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" + +namespace FlexFlow { + +reduced_tensor_t lower(tensor_guid_t const &); + +std::vector lower(std::vector const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/tensor_type.enum.toml b/lib/local-execution/include/local-execution/tensor_type.enum.toml new file mode 100644 index 0000000000..31ce5ba83a --- /dev/null +++ b/lib/local-execution/include/local-execution/tensor_type.enum.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "TensorType" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "NON_GRAPH" + +[[values]] +name = "FORWARD" + +[[values]] +name = "GRADIENT" + +[[values]] +name = "OPTIMIZER" diff --git a/lib/local-execution/include/local-execution/tensor_type_slot_spec.struct.toml b/lib/local-execution/include/local-execution/tensor_type_slot_spec.struct.toml new file mode 100644 index 0000000000..ceba809474 --- /dev/null +++ b/lib/local-execution/include/local-execution/tensor_type_slot_spec.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "TensorTypeSlotSpec" +features = [ + "eq", + "fmt", + "hash", + "ord", +] + +includes = [ + "local-execution/slot_type.dtg.h", + "local-execution/slot_id_t.dtg.h", + "local-execution/tensor_type.dtg.h", +] + +[[fields]] +name = "slot_id" +type = "::FlexFlow::slot_id_t" + +[[fields]] +name = "tensor_type" +type = "::FlexFlow::TensorType" + +[[fields]] +name = "slot_type" +type = "::FlexFlow::SlotType" diff --git a/lib/local-execution/src/fwd_bwd_op_task_impl_function.cc b/lib/local-execution/src/fwd_bwd_op_task_impl_function.cc new file mode 100644 index 0000000000..308dbfd3ae --- /dev/null +++ b/lib/local-execution/src/fwd_bwd_op_task_impl_function.cc @@ -0,0 +1,54 @@ +#include "local-execution/fwd_bwd_op_task_impl_function.h" + +namespace FlexFlow { + +bool FwdBwdOpTaskImplFunction::operator==( + FwdBwdOpTaskImplFunction const &other) const { + return this->function_ptr == other.function_ptr; +} + +bool FwdBwdOpTaskImplFunction::operator!=( + FwdBwdOpTaskImplFunction const &other) const { + return this->function_ptr != other.function_ptr; +} + +bool FwdBwdOpTaskImplFunction::operator<( + FwdBwdOpTaskImplFunction const &other) const { + return this->function_ptr < other.function_ptr; +} + +bool FwdBwdOpTaskImplFunction::operator>( + FwdBwdOpTaskImplFunction const &other) const { + return this->function_ptr > other.function_ptr; +} + +bool FwdBwdOpTaskImplFunction::operator<=( + FwdBwdOpTaskImplFunction const &other) const { + return this->function_ptr <= other.function_ptr; +} + +bool FwdBwdOpTaskImplFunction::operator>=( + FwdBwdOpTaskImplFunction const &other) const { + return this->function_ptr >= other.function_ptr; +} + +std::string format_as(FwdBwdOpTaskImplFunction const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} + +std::ostream &operator<<(std::ostream &s, FwdBwdOpTaskImplFunction const &x) { + return s << fmt::to_string(x); +} + +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::FwdBwdOpTaskImplFunction const &x) const { + return std::hash{}(x.function_ptr); +} +} // namespace std diff --git a/lib/local-execution/src/fwd_bwd_task_impl_function.cc b/lib/local-execution/src/fwd_bwd_task_impl_function.cc deleted file mode 100644 index f85d7cec61..0000000000 --- a/lib/local-execution/src/fwd_bwd_task_impl_function.cc +++ /dev/null @@ -1,54 +0,0 @@ -#include "local-execution/fwd_bwd_task_impl_function.h" - -namespace FlexFlow { - -bool FwdBwdTaskImplFunction::operator==( - FwdBwdTaskImplFunction const &other) const { - return this->function_ptr == other.function_ptr; -} - -bool FwdBwdTaskImplFunction::operator!=( - FwdBwdTaskImplFunction const &other) const { - return this->function_ptr != other.function_ptr; -} - -bool FwdBwdTaskImplFunction::operator<( - FwdBwdTaskImplFunction const &other) const { - return this->function_ptr < other.function_ptr; -} - -bool FwdBwdTaskImplFunction::operator>( - FwdBwdTaskImplFunction const &other) const { - return this->function_ptr > other.function_ptr; -} - -bool FwdBwdTaskImplFunction::operator<=( - FwdBwdTaskImplFunction const &other) const { - return this->function_ptr <= other.function_ptr; -} - -bool FwdBwdTaskImplFunction::operator>=( - FwdBwdTaskImplFunction const &other) const { - return this->function_ptr >= other.function_ptr; -} - -std::string format_as(FwdBwdTaskImplFunction const &x) { - std::ostringstream oss; - oss << ""; - return oss.str(); -} - -std::ostream &operator<<(std::ostream &s, FwdBwdTaskImplFunction const &x) { - return s << fmt::to_string(x); -} - -} // namespace FlexFlow - -namespace std { -size_t hash::operator()( - ::FlexFlow::FwdBwdTaskImplFunction const &x) const { - return std::hash{}(x.function_ptr); -} -} // namespace std diff --git a/lib/local-execution/src/generic_task_impl_function.cc b/lib/local-execution/src/generic_task_impl_function.cc new file mode 100644 index 0000000000..87d4db53e6 --- /dev/null +++ b/lib/local-execution/src/generic_task_impl_function.cc @@ -0,0 +1,53 @@ +#include "local-execution/generic_task_impl_function.h" + +namespace FlexFlow { + +bool GenericTaskImplFunction::operator==( + GenericTaskImplFunction const &other) const { + return this->function_ptr == other.function_ptr; +} + +bool GenericTaskImplFunction::operator!=( + GenericTaskImplFunction const &other) const { + return this->function_ptr != other.function_ptr; +} + +bool GenericTaskImplFunction::operator<( + GenericTaskImplFunction const &other) const { + return this->function_ptr < other.function_ptr; +} + +bool GenericTaskImplFunction::operator>( + GenericTaskImplFunction const &other) const { + return this->function_ptr > other.function_ptr; +} + +bool GenericTaskImplFunction::operator<=( + GenericTaskImplFunction const &other) const { + return this->function_ptr <= other.function_ptr; +} + +bool GenericTaskImplFunction::operator>=( + GenericTaskImplFunction const &other) const { + return this->function_ptr >= other.function_ptr; +} + +std::string format_as(GenericTaskImplFunction const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, GenericTaskImplFunction const &x) { + return s << fmt::to_string(x); +} + +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::GenericTaskImplFunction const &x) const { + return std::hash{}(x.function_ptr); +} +} // namespace std diff --git a/lib/local-execution/src/init_op_task_impl_function.cc b/lib/local-execution/src/init_op_task_impl_function.cc new file mode 100644 index 0000000000..abe84b828e --- /dev/null +++ b/lib/local-execution/src/init_op_task_impl_function.cc @@ -0,0 +1,53 @@ +#include "local-execution/init_op_task_impl_function.h" + +namespace FlexFlow { + +bool InitOpTaskImplFunction::operator==( + InitOpTaskImplFunction const &other) const { + return this->function_ptr == other.function_ptr; +} + +bool InitOpTaskImplFunction::operator!=( + InitOpTaskImplFunction const &other) const { + return this->function_ptr != other.function_ptr; +} + +bool InitOpTaskImplFunction::operator<( + InitOpTaskImplFunction const &other) const { + return this->function_ptr < other.function_ptr; +} + +bool InitOpTaskImplFunction::operator>( + InitOpTaskImplFunction const &other) const { + return this->function_ptr > other.function_ptr; +} + +bool InitOpTaskImplFunction::operator<=( + InitOpTaskImplFunction const &other) const { + return this->function_ptr <= other.function_ptr; +} + +bool InitOpTaskImplFunction::operator>=( + InitOpTaskImplFunction const &other) const { + return this->function_ptr >= other.function_ptr; +} + +std::string format_as(InitOpTaskImplFunction const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, InitOpTaskImplFunction const &x) { + return s << fmt::to_string(x); +} + +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::InitOpTaskImplFunction const &x) const { + return std::hash{}(x.function_ptr); +} +} // namespace std diff --git a/lib/local-execution/src/init_task_impl_function.cc b/lib/local-execution/src/init_task_impl_function.cc deleted file mode 100644 index 9501f72dd6..0000000000 --- a/lib/local-execution/src/init_task_impl_function.cc +++ /dev/null @@ -1,47 +0,0 @@ -#include "local-execution/init_task_impl_function.h" - -namespace FlexFlow { - -bool InitTaskImplFunction::operator==(InitTaskImplFunction const &other) const { - return this->function_ptr == other.function_ptr; -} - -bool InitTaskImplFunction::operator!=(InitTaskImplFunction const &other) const { - return this->function_ptr != other.function_ptr; -} - -bool InitTaskImplFunction::operator<(InitTaskImplFunction const &other) const { - return this->function_ptr < other.function_ptr; -} - -bool InitTaskImplFunction::operator>(InitTaskImplFunction const &other) const { - return this->function_ptr > other.function_ptr; -} - -bool InitTaskImplFunction::operator<=(InitTaskImplFunction const &other) const { - return this->function_ptr <= other.function_ptr; -} - -bool InitTaskImplFunction::operator>=(InitTaskImplFunction const &other) const { - return this->function_ptr >= other.function_ptr; -} - -std::string format_as(InitTaskImplFunction const &x) { - std::ostringstream oss; - oss << ""; - return oss.str(); -} -std::ostream &operator<<(std::ostream &s, InitTaskImplFunction const &x) { - return s << fmt::to_string(x); -} - -} // namespace FlexFlow - -namespace std { -size_t hash::operator()( - ::FlexFlow::InitTaskImplFunction const &x) const { - return std::hash{}(x.function_ptr); -} -} // namespace std diff --git a/lib/local-execution/src/local_cost_estimator.cc b/lib/local-execution/src/local_cost_estimator.cc index b42aec10bb..c99a2b154f 100644 --- a/lib/local-execution/src/local_cost_estimator.cc +++ b/lib/local-execution/src/local_cost_estimator.cc @@ -4,6 +4,7 @@ #include "local-execution/tracked_allocator.h" #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/pcg_operator_attrs.h" +#include "pcg/computation_graph/layer_added_result.dtg.h" #include "pcg/computation_graph_builder.h" #include "pcg/parallel_tensor_attrs.h" #include "utils/containers/transform.h" @@ -66,25 +67,27 @@ CostDetails LocalCostEstimator::estimate_cost( }; // add operator to graph - std::vector output_tensor_ids = - cg_builder.add_layer(layer_attrs, - input_tensor_ids, - transform(get_vector_piece_attrs(weights), - [&](TensorAttrs const &a) { - return cg_builder.create_weight(a); - }), - get_vector_piece_attrs(outputs)); + LayerAddedResult layer_added_result = + cg_builder.add_layer_and_get_layer_added_result( + layer_attrs, + input_tensor_ids, + transform(get_vector_piece_attrs(weights), + [&](TensorAttrs const &a) { + return cg_builder.create_weight(a); + }), + get_vector_piece_attrs(outputs)); LocalTrainingBacking local_backing(allocator, cg_builder.computation_graph, tensor_backing_map, this->runtime_arg_config); + local_backing.register_and_allocate_layer(layer_added_result.layer); + local_backing.execute_init(layer_added_result.layer); + float fwd = local_backing.execute_forward(layer_added_result.layer).value(); + float bwd = local_backing.execute_backward(layer_added_result.layer).value(); + float total_execution_time = fwd + bwd; - local_backing.execute_init(); - PerLayerElapsedTime fwd = local_backing.execute_forward(); - PerLayerElapsedTime bwd = local_backing.execute_backward(); - - return CostDetails{get_total_elapsed_time(fwd, bwd), + return CostDetails{total_execution_time, tracked_allocator_ptr->get_current_mem_usage()}; } diff --git a/lib/local-execution/src/local_slots_backing.cc b/lib/local-execution/src/local_slots_backing.cc index ac35d63c0b..f1bb5a9a5b 100644 --- a/lib/local-execution/src/local_slots_backing.cc +++ b/lib/local-execution/src/local_slots_backing.cc @@ -1,14 +1,19 @@ #include "local-execution/local_slots_backing.h" +#include "local-execution/tensor_reduction.h" #include "op-attrs/parallel_tensor_shape.h" #include "pcg/computation_graph.h" #include "utils/containers/contains_key.h" +#include "utils/containers/map_values.h" #include "utils/overload.h" namespace FlexFlow { -LocalSlotsBacking::LocalSlotsBacking(TensorBackingMap const &allocated_tensors, - RuntimeArgConfig const &runtime_arg_config) - : tensor_mapping(allocated_tensors), +LocalSlotsBacking::LocalSlotsBacking( + LayerTensorBackingMap const &allocated_forward_tensors, + TensorBackingMap const &allocated_non_graph_tensors, + RuntimeArgConfig const &runtime_arg_config) + : tensor_mapping(allocated_forward_tensors), + non_graph_tensor_mapping(allocated_non_graph_tensors), runtime_arg_config(runtime_arg_config){}; void LocalSlotsBacking::add_per_device_op_state( @@ -17,74 +22,133 @@ void LocalSlotsBacking::add_per_device_op_state( this->per_device_op_states.insert({op_guid, device_state}); } -void LocalSlotsBacking::allocate_outgoing_tensors( +void LocalSlotsBacking::allocate_layer_tensors( layer_guid_t const &layer_guid, ComputationGraph const &computation_graph, Allocator &allocator) { - std::vector incoming_tensors = - get_incoming_tensors(computation_graph, layer_guid); - std::vector outgoing_tensors = - get_outgoing_tensors(computation_graph, layer_guid); - for (tensor_guid_t const &output_tensor : outgoing_tensors) { - TensorAttrs tensor_attrs = - get_tensor_attrs(computation_graph, output_tensor); + this->allocate_tensors_by_role( + TensorRole::INPUT, layer_guid, computation_graph, allocator); + this->allocate_tensors_by_role( + TensorRole::WEIGHT, layer_guid, computation_graph, allocator); + this->allocate_tensors_by_role( + TensorRole::OUTPUT, layer_guid, computation_graph, allocator); +} + +void LocalSlotsBacking::allocate_tensors_by_role( + TensorRole const &role, + layer_guid_t const &layer_guid, + ComputationGraph const &computation_graph, + Allocator &allocator) { + std::vector tensors; + switch (role) { + case TensorRole::INPUT: + tensors = get_incoming_inputs(computation_graph, layer_guid); + this->input_tensor_slots.insert({layer_guid, lower(tensors)}); + break; + case TensorRole::WEIGHT: + tensors = get_incoming_weights(computation_graph, layer_guid); + this->weight_tensor_slots.insert({layer_guid, lower(tensors)}); + break; + case TensorRole::OUTPUT: + tensors = get_outgoing_tensors(computation_graph, layer_guid); + this->output_tensor_slots.insert({layer_guid, lower(tensors)}); + break; + default: + throw mk_runtime_error("Invalid tensor role, got {}", role); + } + + for (tensor_guid_t const &tensor : tensors) { + TensorAttrs tensor_attrs = get_tensor_attrs(computation_graph, tensor); + reduced_tensor_t reduced_tensor = lower(tensor); + LayerTensorKey layer_tensor_key = + LayerTensorKey{layer_guid, reduced_tensor}; // tensor allocation - if (!is_tensor_allocated(output_tensor)) { + if (!is_forward_tensor_allocated(layer_tensor_key)) { GenericTensorAccessorW tensor_backing = allocator.allocate_tensor(tensor_attrs.shape); - this->tensor_mapping.insert({output_tensor, tensor_backing}); + this->tensor_mapping.insert({layer_tensor_key, tensor_backing}); } // gradient tensor allocation - if (tensor_attrs.create_gradients == CreateGrad::YES && - !is_gradient_tensor_allocated(output_tensor)) { + if (tensor_attrs.create_gradients == CreateGrad::YES) { GenericTensorAccessorW gradient_tensor_backing = allocator.allocate_tensor(tensor_attrs.shape); this->gradient_tensor_mapping.insert( - {output_tensor, gradient_tensor_backing}); + {layer_tensor_key, gradient_tensor_backing}); } } +} - this->input_tensor_slots.insert({layer_guid, incoming_tensors}); - this->output_tensor_slots.insert({layer_guid, outgoing_tensors}); +void LocalSlotsBacking::allocate_optimizer_tensors( + layer_guid_t const &weight_layer, + tensor_guid_t const &weight, + ComputationGraph const &cg, + Allocator &allocator, + TaskSignature const &sig) { + GenericTensorAccessorW weight_backing = this->get_tensor_backing( + TensorType::FORWARD, lower(weight), weight_layer); + int num_grad_buffer_tensors = + sig.tensor_guid_slots.size() - 2; // ignore 2 (weight and weight_grad) + std::vector optimizer_buffer_tensors; + for (int i = 0; i < num_grad_buffer_tensors; ++i) { + reduced_tensor_t buffer_tensor = reduced_tensor_t{i}; + GenericTensorAccessorW buffer_backing = allocator.allocate_tensor( + get_tensor_shape(weight_backing.shape, weight_backing.data_type)); + this->optimizer_tensor_mapping.insert( + {LayerTensorKey{weight_layer, buffer_tensor}, buffer_backing}); + optimizer_buffer_tensors.push_back(buffer_tensor); + } + this->weight_optimizer_tensor_guids.insert( + {weight_layer, optimizer_buffer_tensors}); } -bool LocalSlotsBacking::is_tensor_allocated( - tensor_guid_t const &tensor_id) const { - return contains_key(this->tensor_mapping, tensor_id); +bool LocalSlotsBacking::is_forward_tensor_allocated( + LayerTensorKey const &layer_tensor_id) const { + return contains_key(this->tensor_mapping, layer_tensor_id); } -bool LocalSlotsBacking::is_gradient_tensor_allocated( - tensor_guid_t const &tensor_id) const { - return contains_key(this->gradient_tensor_mapping, tensor_id); +bool LocalSlotsBacking::is_non_graph_tensor_allocated( + reduced_tensor_t const &tensor_id) const { + return contains_key(this->non_graph_tensor_mapping, tensor_id); } -GenericTensorAccessorW const & - LocalSlotsBacking::get_tensor_backing(tensor_guid_t const &tensor_id, - IsGrad is_grad) const { - switch (is_grad) { - case IsGrad::NO: - assert(contains_key(this->tensor_mapping, tensor_id)); - return this->tensor_mapping.at(tensor_id); - case IsGrad::YES: - assert(contains_key(this->gradient_tensor_mapping, tensor_id)); - return this->gradient_tensor_mapping.at(tensor_id); +GenericTensorAccessorW const &LocalSlotsBacking::get_tensor_backing( + TensorType const &tensor_type, + reduced_tensor_t const &tensor_id, + std::optional const &layer_guid) const { + switch (tensor_type) { + case TensorType::FORWARD: + return this->tensor_mapping.at( + LayerTensorKey{layer_guid.value(), tensor_id}); + case TensorType::NON_GRAPH: + return this->non_graph_tensor_mapping.at(tensor_id); + case TensorType::GRADIENT: + return this->gradient_tensor_mapping.at( + LayerTensorKey{layer_guid.value(), tensor_id}); + case TensorType::OPTIMIZER: + return this->optimizer_tensor_mapping.at( + LayerTensorKey{layer_guid.value(), tensor_id}); default: - throw mk_runtime_error(fmt::format( - "IsGrad should only have YES or NO, received {}", is_grad)); + throw mk_runtime_error( + fmt::format("Invalid tensor type {}", tensor_type)); } } TensorSlotsBacking LocalSlotsBacking::construct_tensor_slots_backing( OpTaskBinding const &binding, layer_guid_t const &op_guid) const { TensorSlotsBacking mapping; + for (auto const &tensor_binding : binding.get_tensor_bindings()) { - SlotGradId slot_grad_id = tensor_binding.first; + SlotTensorTypeId slot_grad_id = tensor_binding.first; OpTensorSpec tensor_spec = tensor_binding.second; - std::vector tensor_guids; + std::vector tensor_guids; + int weight_adjusted_idx = 0; switch (tensor_spec.role) { - case TensorRole::INPUT: case TensorRole::WEIGHT: + assert(contains_key(this->weight_tensor_slots, op_guid)); + tensor_guids = this->weight_tensor_slots.at(op_guid); + break; + case TensorRole::INPUT: assert(contains_key(this->input_tensor_slots, op_guid)); tensor_guids = this->input_tensor_slots.at(op_guid); break; @@ -94,39 +158,59 @@ TensorSlotsBacking LocalSlotsBacking::construct_tensor_slots_backing( break; default: throw mk_runtime_error( - fmt::format("Invalid TensorRole")); // inserting role yields - // "type_is_unformattable" error + fmt::format("Invalid TensorRole {}", tensor_spec.role)); } - assert(tensor_guids.size() > tensor_spec.idx); - IsGrad is_grad = slot_grad_id.is_grad; - GenericTensorAccessorW tensor_backing = - this->get_tensor_backing(tensor_guids.at(tensor_spec.idx), is_grad); + mapping.insert({slot_grad_id, + this->get_tensor_backing(slot_grad_id.tensor_type, + tensor_guids.at(tensor_spec.idx), + op_guid)}); + } + return mapping; +} - mapping.insert({slot_grad_id, tensor_backing}); +TensorSlotsBacking LocalSlotsBacking::construct_tensor_slots_backing( + TaskBinding const &binding, + std::optional const &layer_guid) const { + TensorSlotsBacking mapping; + + for (auto const &tensor_binding : binding.get_tensor_bindings()) { + reduced_tensor_t tensor_id = tensor_binding.second; + SlotTensorTypeId slot_tensor_type_id = tensor_binding.first; + GenericTensorAccessorW accessor = this->get_tensor_backing( + slot_tensor_type_id.tensor_type, tensor_id, layer_guid); + mapping.insert({slot_tensor_type_id, accessor}); } + return mapping; } ArgSlotsBacking LocalSlotsBacking::construct_arg_slots_backing( OpTaskBinding const &binding, layer_guid_t const &op_guid) const { - ArgSlotsBacking mapping; - for (auto const &arg_binding : binding.get_arg_bindings()) { - slot_id_t arg_slot = arg_binding.first; - OpArgSpec op_arg_spec = arg_binding.second; - - mapping.insert({arg_slot, - op_arg_spec.visit(overload{ - [&](OpArgRefSpec const &s) { - return this->resolve_op_arg_ref_spec(s, op_guid); - }, - [&](RuntimeArgRefSpec const &s) { - return this->resolve_runtime_arg_ref_spec(s); - }, - [](ConcreteArgSpec const &s) { return s; }, - })}); - } - return mapping; + return map_values( + binding.get_arg_bindings(), [&](OpArgSpec const &arg_binding) { + return arg_binding.template visit( + overload{[&](OpArgRefSpec const &s) { + return this->resolve_op_arg_ref_spec(s, op_guid); + }, + [&](RuntimeArgRefSpec const &s) { + return this->resolve_runtime_arg_ref_spec(s); + }, + [](ConcreteArgSpec const &s) { return s; }}); + }); +} + +ArgSlotsBacking LocalSlotsBacking::construct_arg_slots_backing( + TaskBinding const &binding) const { + return map_values( + binding.get_arg_bindings(), [&](TaskArgSpec const &arg_binding) { + return arg_binding.template visit( + overload{[&](RuntimeArgRefSpec const &s) { + return this->resolve_runtime_arg_ref_spec(s); + }, + [](ConcreteArgSpec const &s) { return s; }}); + }); + ; } ConcreteArgSpec LocalSlotsBacking::resolve_op_arg_ref_spec( @@ -143,12 +227,14 @@ ConcreteArgSpec LocalSlotsBacking::resolve_op_arg_ref_spec( op_arg_ref_spec.get_ref_type().get(); assert(contains_key(this->input_tensor_slots, op_guid)); - std::vector input_tensor_guids = + std::vector input_tensor_guids = this->input_tensor_slots.at(op_guid); assert(input_tensor_guids.size() > index_op_arg_ref.idx); - GenericTensorAccessorW tensor_backing = this->get_tensor_backing( - input_tensor_guids.at(index_op_arg_ref.idx), IsGrad::NO); + GenericTensorAccessorW tensor_backing = + this->get_tensor_backing(TensorType::FORWARD, + input_tensor_guids.at(index_op_arg_ref.idx), + op_guid); ParallelTensorShape shape = lift_to_parallel( get_tensor_shape(tensor_backing.shape, tensor_backing.data_type)); return ConcreteArgSpec::create(shape); diff --git a/lib/local-execution/src/local_task_argument_accessor.cc b/lib/local-execution/src/local_task_argument_accessor.cc index 5d0156201e..75479a1f88 100644 --- a/lib/local-execution/src/local_task_argument_accessor.cc +++ b/lib/local-execution/src/local_task_argument_accessor.cc @@ -19,10 +19,10 @@ ConcreteArgSpec const & } GenericTensorAccessor LocalTaskArgumentAccessor::get_tensor( - slot_id_t slot, Permissions priv, IsGrad is_grad) const { - SlotGradId slot_grad_pair = SlotGradId{slot, is_grad}; + slot_id_t slot, Permissions priv, TensorType tensor_type) const { + SlotTensorTypeId slot_tensor_type = SlotTensorTypeId{slot, tensor_type}; auto tensor_backing = std::get( - this->tensor_slots_backing.at(slot_grad_pair)); + this->tensor_slots_backing.at(slot_tensor_type)); if (priv == Permissions::RO) { GenericTensorAccessorR readonly_tensor_backing = { tensor_backing.data_type, tensor_backing.shape, tensor_backing.ptr}; @@ -34,10 +34,10 @@ GenericTensorAccessor LocalTaskArgumentAccessor::get_tensor( } } VariadicGenericTensorAccessor LocalTaskArgumentAccessor::get_variadic_tensor( - slot_id_t slot, Permissions priv, IsGrad is_grad) const { - SlotGradId slot_grad_pair = SlotGradId{slot, is_grad}; + slot_id_t slot, Permissions priv, TensorType tensor_type) const { + SlotTensorTypeId slot_tensor_type = SlotTensorTypeId{slot, tensor_type}; auto variadic_tensor_backing = std::get>( - this->tensor_slots_backing.at(slot_grad_pair)); + this->tensor_slots_backing.at(slot_tensor_type)); if (priv == Permissions::RO) { std::vector readonly_variadic_tensor_backing = {}; for (GenericTensorAccessorW const &tensor_backing : diff --git a/lib/local-execution/src/local_training_backing.cc b/lib/local-execution/src/local_training_backing.cc index 0fdf1761e3..e432b1afe9 100644 --- a/lib/local-execution/src/local_training_backing.cc +++ b/lib/local-execution/src/local_training_backing.cc @@ -1,7 +1,14 @@ #include "local-execution/local_training_backing.h" +#include "local-execution/loss_functions.h" +#include "local-execution/optimizer.h" +#include "local-execution/task_invocation.h" #include "local-execution/task_signature_impl.h" +#include "local-execution/tensor_reduction.h" #include "pcg/computation_graph.h" -#include "utils/containers/reversed.h" +#include "pcg/optimizer_attrs.h" +#include "utils/containers/contains.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/get_only.h" #include "utils/exception.h" namespace FlexFlow { @@ -9,22 +16,34 @@ namespace FlexFlow { LocalTrainingBacking::LocalTrainingBacking( Allocator const &allocator, ComputationGraph const &computation_graph, - TensorBackingMap const &tensor_backing_mapping, + LayerTensorBackingMap const &allocated_forward_tensors, + TensorBackingMap const &allocated_non_graph_tensors, RuntimeArgConfig const &runtime_arg_config) : allocator(allocator), computation_graph(computation_graph), - local_slots_backing(tensor_backing_mapping, runtime_arg_config), - task_registry(empty_task_registry()) { - - for (layer_guid_t const &node : topological_ordering(computation_graph)) { - ComputationGraphOpAttrs attrs = - get_layer_attrs(computation_graph, node).attrs; - - // allocate outgoing tensors - this->local_slots_backing.allocate_outgoing_tensors( - node, computation_graph, this->allocator); + local_slots_backing(allocated_forward_tensors, + allocated_non_graph_tensors, + runtime_arg_config), + task_registry(empty_task_registry()) {} + +void LocalTrainingBacking::register_and_allocate_layer( + layer_guid_t const &node) { + ComputationGraphOpAttrs attrs = + get_layer_attrs(this->computation_graph, node).attrs; + this->local_slots_backing.allocate_layer_tensors( + node, this->computation_graph, this->allocator); + register_tasks_for_layer(this->task_registry, node, attrs); +} - // register tasks - register_tasks_for_layer(this->task_registry, node, attrs); +void LocalTrainingBacking::allocate_layer_optimizer_tensors( + layer_guid_t const &node, OptimizerAttrs const &optimizer_attrs) { + ComputationGraphOpAttrs attrs = + get_layer_attrs(this->computation_graph, node).attrs; + if (attrs.has()) { + TaskSignature sig = get_update_signature(optimizer_attrs); + tensor_guid_t weight_tensor = + get_only(get_outgoing_tensors(this->computation_graph, node)); + this->local_slots_backing.allocate_optimizer_tensors( + node, weight_tensor, this->computation_graph, this->allocator, sig); } } @@ -34,7 +53,7 @@ DeviceSpecificDeviceStates TaskSignatureAndImpl task_sig_impl = this->task_registry.task_mapping.at(task_id); auto fn = - task_sig_impl.impl_function.get().function_ptr; + task_sig_impl.impl_function.get().function_ptr; return fn(acc); } @@ -44,71 +63,109 @@ std::optional TaskSignatureAndImpl task_sig_impl = this->task_registry.task_mapping.at(task_id); auto fn = - task_sig_impl.impl_function.get().function_ptr; + task_sig_impl.impl_function.get().function_ptr; return fn(acc); } -void LocalTrainingBacking::execute_init() { - for (layer_guid_t const &operator_node : - topological_ordering(this->computation_graph)) { - if (this->task_registry.init_task_ids.at(operator_node).has_value()) { - ComputationGraphOpAttrs attrs = - get_layer_attrs(this->computation_graph, operator_node).attrs; - - OpTaskInvocation invocation = init(attrs); - TaskArgumentAccessor accessor = - this->get_task_arg_accessor(invocation, operator_node); - DeviceSpecificDeviceStates device_state = - this->call_init_task_impl(invocation.task_id, accessor); - this->local_slots_backing.add_per_device_op_state(operator_node, - device_state); - } +void LocalTrainingBacking::execute_init(layer_guid_t const &operator_node) { + if (registry_contains_op_task( + this->task_registry, operator_node, OpTaskType::INIT)) { + ComputationGraphOpAttrs attrs = + get_layer_attrs(this->computation_graph, operator_node).attrs; + + OpTaskInvocation invocation = init(attrs); + TaskArgumentAccessor accessor = + this->get_op_task_arg_accessor(invocation, operator_node); + DeviceSpecificDeviceStates device_state = + this->call_init_task_impl(invocation.task_id, accessor); + this->local_slots_backing.add_per_device_op_state(operator_node, + device_state); } } -PerLayerElapsedTime LocalTrainingBacking::execute_forward() { - PerLayerElapsedTime per_op_elapsed_time; - for (layer_guid_t const &operator_node : - topological_ordering(this->computation_graph)) { - if (this->task_registry.forward_task_ids.at(operator_node).has_value()) { - ComputationGraphOpAttrs attrs = - get_layer_attrs(this->computation_graph, operator_node).attrs; - - OpTaskInvocation invocation = forward(attrs); - TaskArgumentAccessor accessor = - this->get_task_arg_accessor(invocation, operator_node); - std::optional elapsed_time = - this->call_task_impl(invocation.task_id, accessor); - per_op_elapsed_time.insert({operator_node, elapsed_time}); - } +std::optional + LocalTrainingBacking::execute_forward(layer_guid_t const &operator_node) { + if (registry_contains_op_task( + this->task_registry, operator_node, OpTaskType::FWD)) { + ComputationGraphOpAttrs attrs = + get_layer_attrs(this->computation_graph, operator_node).attrs; + + OpTaskInvocation invocation = forward(attrs); + TaskArgumentAccessor accessor = + this->get_op_task_arg_accessor(invocation, operator_node); + return this->call_task_impl(invocation.task_id, accessor); + } else { + return std::nullopt; } - return per_op_elapsed_time; } -PerLayerElapsedTime LocalTrainingBacking::execute_backward() { - PerLayerElapsedTime per_op_elapsed_time; - for (layer_guid_t const &operator_node : - reversed(topological_ordering(this->computation_graph))) { - if (this->task_registry.backward_task_ids.at(operator_node).has_value()) { - ComputationGraphOpAttrs attrs = - get_layer_attrs(this->computation_graph, operator_node).attrs; - - OpTaskInvocation invocation = backward(attrs); - TaskArgumentAccessor accessor = - this->get_task_arg_accessor(invocation, operator_node); - std::optional elapsed_time = - this->call_task_impl(invocation.task_id, accessor); - per_op_elapsed_time.insert({operator_node, elapsed_time}); - } +void LocalTrainingBacking::compute_loss(LossAttrs const &loss_attrs, + reduced_tensor_t const &logit_tensor, + reduced_tensor_t const &label_tensor) { + assert( + this->local_slots_backing.is_non_graph_tensor_allocated(logit_tensor) && + this->local_slots_backing.is_non_graph_tensor_allocated(label_tensor)); + TaskInvocation loss_invocation = + backward(loss_attrs, logit_tensor, label_tensor); + // assert(is_invocation_valid(get_loss_bwd_signature(), loss_invocation)); + TaskArgumentAccessor loss_accessor = + this->get_task_arg_accessor(loss_invocation, std::nullopt); + TaskImplFunction loss_impl_fn = get_loss_bwd_task_impl(); + loss_impl_fn.get().function_ptr(loss_accessor); +} + +std::optional + LocalTrainingBacking::execute_backward(layer_guid_t const &operator_node) { + if (registry_contains_op_task( + this->task_registry, operator_node, OpTaskType::BWD)) { + ComputationGraphOpAttrs attrs = + get_layer_attrs(this->computation_graph, operator_node).attrs; + + OpTaskInvocation invocation = backward(attrs); + TaskArgumentAccessor accessor = + this->get_op_task_arg_accessor(invocation, operator_node); + return this->call_task_impl(invocation.task_id, accessor); + } else { + return std::nullopt; } - return per_op_elapsed_time; } -void LocalTrainingBacking::execute_update() { - NOT_IMPLEMENTED(); +void LocalTrainingBacking::execute_update( + layer_guid_t const &node, OptimizerAttrs const &optimizer_attrs) { + LayerAttrs layer_attrs = get_layer_attrs(this->computation_graph, node); + if (layer_attrs.attrs.has()) { + // get tensors + reduced_tensor_t weight_tensor = + lower(get_only(get_outgoing_tensors(this->computation_graph, node))); + std::vector optimizer_buffer_tensors = + this->local_slots_backing.weight_optimizer_tensor_guids.at(node); + + // get invocation + TaskInvocation invocation = get_update_invocation( + optimizer_attrs, weight_tensor, optimizer_buffer_tensors); + // assert(is_invocation_valid(get_update_signature(attrs), invocation)); + + // execute update + TaskArgumentAccessor accessor = + this->get_task_arg_accessor(invocation, node); + TaskImplFunction update_impl_fn = get_update_task_impl(optimizer_attrs); + update_impl_fn.get().function_ptr(accessor); + } } TaskArgumentAccessor LocalTrainingBacking::get_task_arg_accessor( + TaskInvocation const &invocation, + std::optional const &layer_guid) const { + TensorSlotsBacking tensor_slots_backing = + this->local_slots_backing.construct_tensor_slots_backing( + invocation.binding, layer_guid); + ArgSlotsBacking arg_slots_backing = + this->local_slots_backing.construct_arg_slots_backing(invocation.binding); + return TaskArgumentAccessor::create( + this->allocator, tensor_slots_backing, arg_slots_backing); +} + +TaskArgumentAccessor LocalTrainingBacking::get_op_task_arg_accessor( OpTaskInvocation const &invocation, layer_guid_t const &op_guid) const { TensorSlotsBacking tensor_slots_backing = this->local_slots_backing.construct_tensor_slots_backing( diff --git a/lib/runtime/src/loss_functions.cc b/lib/local-execution/src/loss_functions.cc similarity index 53% rename from lib/runtime/src/loss_functions.cc rename to lib/local-execution/src/loss_functions.cc index b0d5ac2029..e54841acb5 100644 --- a/lib/runtime/src/loss_functions.cc +++ b/lib/local-execution/src/loss_functions.cc @@ -13,56 +13,48 @@ * limitations under the License. */ -#include "loss_functions.h" +#include "op-attrs/ops/loss_functions.h" #include "kernels/loss_function_kernels.h" -#include "legion.h" -#include "runtime/profiling.h" -#include "task_spec/task_argument_accessor.h" +#include "local-execution/loss_functions.h" +#include "local-execution/profiling.h" namespace FlexFlow { -enum LossSlots { - LOGIT_GRAD, - LOGIT, - LABEL, - LOSS_ATTRS, - BATCH_SIZE, - PROFILING_SETTINGS -}; +enum Slots { LOGIT, LABEL, ATTRS, PROFILING }; -TaskInvocation backward_invocation(LossAttrs const &attrs, - EnableProfiling enable_profiling, - parallel_tensor_guid_t logit, - parallel_tensor_guid_t label) { - auto binding = IndexTaskBinding{LOGIT}; - StandardTypedTaskArg arg = attrs; - binding.bind_arg(LOSS_ATTRS, attrs); - binding.bind(LOGIT, logit); - binding.bind(LABEL, label); - binding.bind(LOGIT_GRAD, grad(logit)); - binding.bind_arg(PROFILING_SETTINGS, profiling_settings()); +TaskSignature get_loss_bwd_signature() { + TaskSignature sig = make_empty_task_signature(); + add_slot(sig, LOGIT, TensorType::NON_GRAPH); + add_slot(sig, LABEL, TensorType::NON_GRAPH); + add_slot(sig, LOGIT, TensorType::GRADIENT); - /* if ((logit_domain != part_domain) || (label_domain != part_domain)) { */ // TODO @lockshaw make sure this is still checked - /* fprintf(stderr, */ - /* "Encounter inconsistency in parallelizing loss computation"); */ - /* assert(false); */ - /* } */ - return {LOSS_BWD_TASK_ID, binding}; + add_arg_slot(sig, ATTRS); + add_arg_slot(sig, PROFILING); + return sig; } -static void - loss_backward_task(Legion::Task const *task, - std::vector const ®ions, - Legion::Context ctx, - Legion::Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - auto attrs = acc.get_argument(LOSS_ATTRS); - auto profiling_settings = - acc.get_argument(PROFILING_SETTINGS); - auto batch_size = acc.get_argument(BATCH_SIZE); - auto logit_grad = acc.get_tensor(LOGIT_GRAD); +TaskInvocation backward(LossAttrs const &attrs, + reduced_tensor_t logit, + reduced_tensor_t label) { + TaskBinding b; + b.bind(LOGIT, TensorType::NON_GRAPH, logit); + b.bind(LABEL, TensorType::NON_GRAPH, label); + b.bind(LOGIT, TensorType::GRADIENT, logit); + + b.bind_arg(ATTRS, attrs); + b.bind_arg(PROFILING, profiling_settings()); + + return TaskInvocation{task_id_t::LOSS_BWD_TASK_ID, b}; +} + +static void backward_task_impl(TaskArgumentAccessor const &acc) { + auto attrs = acc.get_argument(ATTRS); + auto profiling = acc.get_argument(PROFILING); + auto logit_grad = acc.get_tensor_grad(LOGIT); auto logit = acc.get_tensor(LOGIT); auto label = acc.get_tensor(LABEL); + int batch_size = logit.shape.at(legion_dim_t{1}); + // assuming logit shape is [batch dim, num classes] LossFunction loss_type = get_loss_function(attrs); float scale_factor = 1.0f / batch_size; @@ -72,50 +64,45 @@ static void } if (loss_type == LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY) { - // assertion the outter-most dim is replica dim and replica degree is 1 - auto scce_attrs = get(attrs); + // label shape is [batch dim, 1] + auto scce_attrs = attrs.get(); size_t ndim = logit.shape.num_dims(); - assert(logit.shape.at(legion_dim_t(ndim - 1)) == 1); - int num_samples = logit.shape.at(legion_dim_t(ndim - 2)); - int num_classes = logit.shape.get_volume() / num_samples; + int num_classes = logit.shape.at(legion_dim_t{0}); assert(logit_grad.shape == logit.shape); int k = 1; if (scce_attrs.replace_labels) { k = logit.shape.at(legion_dim_t(ndim - 1)) / label.shape.at(legion_dim_t( ndim - 1)); // TODO FIXME something seems wrong here, isn't the - // numerator guaranteed to be 1? + // numerator guaranteed to be 1? <--- this is not the + // case because of the potential parallel dim } - assert(label.shape.sub_shape(legion_dim_t(1), nullopt) == - logit.shape.sub_shape(legion_dim_t(1), nullopt)); + assert(label.shape.sub_shape(legion_dim_t(1), std::nullopt) == + logit.shape.sub_shape(legion_dim_t(1), std::nullopt)); assert(k * label.shape.at(legion_dim_t(ndim - 1)) == logit.shape.at(legion_dim_t(ndim - 1))); assert(label.shape.at(legion_dim_t(0)) == 1); profile(sparse_categorical_crossentropy_loss_backward_kernel, - profiling_settings, + profiling, "[SparseCategoricalCrossEntropyLoss] backward_time = %.2lfms\n", get_float_ptr(logit_grad), get_float_ptr(logit), - get_int32_ptr(label), - logit.shape.get_volume(), + reinterpret_cast(get_float_ptr(label)), + get_volume(logit.shape), get_volume(logit_grad.shape), - num_samples, + batch_size, num_classes, k, scale_factor); } else { assert(logit.shape == label.shape); assert(logit_grad.shape == logit.shape); - // assertion the outter-most dim is replica dim and replica degree is 1 - size_t ndim = logit.shape.num_dims(); - assert(logit.shape.at(legion_dim_t(ndim - 1)) == 1); - int num_samples = label.shape.at(legion_dim_t(ndim - 1)); - int num_channels = logit.shape.get_volume() / num_samples; + int num_channels = logit.shape.at(legion_dim_t{0}); switch (loss_type) { case LossFunction::CATEGORICAL_CROSSENTROPY: { profile(categorical_crossentropy_loss_backward_kernel, - profiling_settings, + profiling, "[CategoricalCrossEntropyLoss] backward_time = %.2lfms\n", get_float_ptr(logit_grad), get_float_ptr(logit), @@ -127,7 +114,7 @@ static void } case LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE: { profile(mean_squared_error_avg_loss_backward_kernel, - profiling_settings, + profiling, "[MeanSquaredErrorAvgLoss] backward_time = %.2lfms\n", get_float_ptr(logit_grad), get_float_ptr(logit), @@ -139,7 +126,7 @@ static void } case LossFunction::IDENTITY: { profile(identity_loss_backward_kernel, - profiling_settings, + profiling, "[IdentityLoss] backward_time = %.2lfms\n", get_float_ptr(logit_grad), get_float_ptr(logit), @@ -156,16 +143,8 @@ static void } } -template <> -void register_task() { - TaskSignature sig; - sig.add_arg_slot(LOSS_ATTRS); - sig.add_arg_slot(PROFILING_SETTINGS); - sig.add_slot(LOGIT, {SlotType::TENSOR, Permissions::RO}); - sig.add_slot(LABEL, {SlotType::TENSOR, Permissions::RO}); - sig.add_slot(LOGIT_GRAD, {SlotType::TENSOR, Permissions::RW}); - - register_task(LOSS_BWD_TASK_ID, "Loss Backward", sig, loss_backward_task); +TaskImplFunction get_loss_bwd_task_impl() { + return TaskImplFunction{GenericTaskImplFunction{backward_task_impl}}; } } // namespace FlexFlow diff --git a/lib/local-execution/src/model_training_instance.cc b/lib/local-execution/src/model_training_instance.cc new file mode 100644 index 0000000000..5a58e4c524 --- /dev/null +++ b/lib/local-execution/src/model_training_instance.cc @@ -0,0 +1,77 @@ +#include "local-execution/model_training_instance.h" +#include "pcg/computation_graph.h" +#include "pcg/optimizer_attrs.h" +#include "utils/containers/reversed.h" + +namespace FlexFlow { + +ModelTrainingInstance::ModelTrainingInstance( + Allocator const &allocator, + ComputationGraph const &computation_graph, + LayerTensorBackingMap const &allocated_forward_tensors, + TensorBackingMap const &allocated_non_graph_tensors, + RuntimeArgConfig const &runtime_arg_config, + LossAttrs const &loss_attrs, + reduced_tensor_t const &logit_tensor, + reduced_tensor_t const &label_tensor, + OptimizerAttrs const &optimizer_attrs) + : computation_graph(computation_graph), + training_backing(allocator, + computation_graph, + allocated_forward_tensors, + allocated_non_graph_tensors, + runtime_arg_config), + loss_attrs(loss_attrs), logit_tensor(logit_tensor), + label_tensor(label_tensor), optimizer_attrs(optimizer_attrs) { + + // allocate each layer's tensors + for (layer_guid_t const &node : + topological_ordering(this->computation_graph)) { + this->training_backing.register_and_allocate_layer(node); + this->training_backing.allocate_layer_optimizer_tensors( + node, this->optimizer_attrs); + } +} + +void ModelTrainingInstance::execute_init() { + for (layer_guid_t const &node : + topological_ordering(this->computation_graph)) { + this->training_backing.execute_init(node); + } +} + +PerLayerElapsedTime ModelTrainingInstance::execute_forward() { + PerLayerElapsedTime per_layer_elapsed_time; + for (layer_guid_t const &node : + topological_ordering(this->computation_graph)) { + std::optional elapsed_time = + this->training_backing.execute_forward(node); + per_layer_elapsed_time.insert({node, elapsed_time}); + } + return per_layer_elapsed_time; +} + +PerLayerElapsedTime ModelTrainingInstance::execute_backward() { + this->training_backing.compute_loss( + this->loss_attrs, this->logit_tensor, this->label_tensor); + + PerLayerElapsedTime per_layer_elapsed_time; + for (layer_guid_t const &node : + reversed(topological_ordering(this->computation_graph))) { + std::optional elapsed_time = + this->training_backing.execute_backward(node); + per_layer_elapsed_time.insert({node, elapsed_time}); + } + return per_layer_elapsed_time; +} + +void ModelTrainingInstance::execute_update() { + for (layer_guid_t const &node : + topological_ordering(this->computation_graph)) { + this->training_backing.execute_update(node, this->optimizer_attrs); + } + this->optimizer_attrs = + get_next_iteration_optimizer_attrs(this->optimizer_attrs); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/op_task_invocation.cc b/lib/local-execution/src/op_task_invocation.cc index 19c8894b05..81bf185911 100644 --- a/lib/local-execution/src/op_task_invocation.cc +++ b/lib/local-execution/src/op_task_invocation.cc @@ -20,7 +20,8 @@ void OpTaskBinding::bind(int slot, OpTensorSpec const &tensor_spec) { } void OpTaskBinding::bind(slot_id_t slot, OpTensorSpec const &tensor_spec) { - this->tensor_bindings.insert({SlotGradId{slot, IsGrad::NO}, tensor_spec}); + this->tensor_bindings.insert( + {SlotTensorTypeId{slot, TensorType::FORWARD}, tensor_spec}); } void OpTaskBinding::bind_grad(int slot, OpTensorSpec const &tensor_spec) { @@ -28,7 +29,8 @@ void OpTaskBinding::bind_grad(int slot, OpTensorSpec const &tensor_spec) { } void OpTaskBinding::bind_grad(slot_id_t slot, OpTensorSpec const &tensor_spec) { - this->tensor_bindings.insert({SlotGradId{slot, IsGrad::YES}, tensor_spec}); + this->tensor_bindings.insert( + {SlotTensorTypeId{slot, TensorType::GRADIENT}, tensor_spec}); } void OpTaskBinding::insert_arg_spec(slot_id_t name, OpArgSpec const &arg_spec) { @@ -44,13 +46,13 @@ bool OpTaskBinding::operator!=(OpTaskBinding const &other) const { return this->tie() != other.tie(); } -std::tuple const &, +std::tuple const &, std::unordered_map const &> OpTaskBinding::tie() const { return std::tie(this->tensor_bindings, this->arg_bindings); } -std::unordered_map const & +std::unordered_map const & OpTaskBinding::get_tensor_bindings() const { return this->tensor_bindings; } @@ -89,8 +91,8 @@ bool is_tensor_invocation_valid(OpTaskSignature const &sig, OpTaskInvocation const &inv) { auto tensor_bindings = inv.binding.get_tensor_bindings(); for (OpTensorSlotSpec const &op_tensor_slot_spec : sig.get_tensor_slots()) { - SlotGradId tensor_key = - SlotGradId{op_tensor_slot_spec.name, op_tensor_slot_spec.is_grad}; + SlotTensorTypeId tensor_key = SlotTensorTypeId{ + op_tensor_slot_spec.name, op_tensor_slot_spec.tensor_type}; OpTensorSpec op_tensor_spec = tensor_bindings.at(tensor_key); if (is_op_tensor_spec_invalid(op_tensor_slot_spec, op_tensor_spec)) { return false; diff --git a/lib/local-execution/src/op_task_signature.cc b/lib/local-execution/src/op_task_signature.cc index 932b330453..5c8b19265a 100644 --- a/lib/local-execution/src/op_task_signature.cc +++ b/lib/local-execution/src/op_task_signature.cc @@ -12,8 +12,12 @@ void OpTaskSignature::add_input_slot(int name, SlotType slot_type) { } void OpTaskSignature::add_input_slot(slot_id_t name, SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{ - name, slot_type, TensorRole::INPUT, IsGrad::NO, OpSlotOptions::NECESSARY}; + OpTensorSlotSpec op_tensor_slot_spec = + OpTensorSlotSpec{name, + slot_type, + TensorRole::INPUT, + TensorType::FORWARD, + OpSlotOptions::NECESSARY}; this->op_tensor_slots.insert(op_tensor_slot_spec); } @@ -23,8 +27,12 @@ void OpTaskSignature::add_optional_input_slot(int name, SlotType slot_type) { void OpTaskSignature::add_optional_input_slot(slot_id_t name, SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{ - name, slot_type, TensorRole::INPUT, IsGrad::NO, OpSlotOptions::OPTIONAL}; + OpTensorSlotSpec op_tensor_slot_spec = + OpTensorSlotSpec{name, + slot_type, + TensorRole::INPUT, + TensorType::FORWARD, + OpSlotOptions::OPTIONAL}; this->op_tensor_slots.insert(op_tensor_slot_spec); } @@ -38,7 +46,7 @@ void OpTaskSignature::add_untrainable_input_slot(slot_id_t name, OpTensorSlotSpec{name, slot_type, TensorRole::INPUT, - IsGrad::NO, + TensorType::FORWARD, OpSlotOptions::UNTRAINABLE}; this->op_tensor_slots.insert(op_tensor_slot_spec); } @@ -54,7 +62,7 @@ void OpTaskSignature::add_optional_untrainable_input_slot(slot_id_t name, OpTensorSlotSpec{name, slot_type, TensorRole::INPUT, - IsGrad::NO, + TensorType::FORWARD, OpSlotOptions::OPTIONAL_UNTRAINABLE}; this->op_tensor_slots.insert(op_tensor_slot_spec); } @@ -68,7 +76,7 @@ void OpTaskSignature::add_output_slot(slot_id_t name, SlotType slot_type) { OpTensorSlotSpec{name, slot_type, TensorRole::OUTPUT, - IsGrad::NO, + TensorType::FORWARD, OpSlotOptions::NECESSARY}; this->op_tensor_slots.insert(op_tensor_slot_spec); } @@ -80,8 +88,12 @@ void OpTaskSignature::add_bwd_optional_output_slot(int name, void OpTaskSignature::add_bwd_optional_output_slot(slot_id_t name, SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{ - name, slot_type, TensorRole::OUTPUT, IsGrad::NO, OpSlotOptions::OPTIONAL}; + OpTensorSlotSpec op_tensor_slot_spec = + OpTensorSlotSpec{name, + slot_type, + TensorRole::OUTPUT, + TensorType::FORWARD, + OpSlotOptions::OPTIONAL}; this->op_tensor_slots.insert(op_tensor_slot_spec); } @@ -94,7 +106,7 @@ void OpTaskSignature::add_weight_slot(slot_id_t name, SlotType slot_type) { OpTensorSlotSpec{name, slot_type, TensorRole::WEIGHT, - IsGrad::NO, + TensorType::FORWARD, OpSlotOptions::NECESSARY}; this->op_tensor_slots.insert(op_tensor_slot_spec); } @@ -105,8 +117,12 @@ void OpTaskSignature::add_optional_weight_slot(int name, SlotType slot_type) { void OpTaskSignature::add_optional_weight_slot(slot_id_t name, SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{ - name, slot_type, TensorRole::WEIGHT, IsGrad::NO, OpSlotOptions::OPTIONAL}; + OpTensorSlotSpec op_tensor_slot_spec = + OpTensorSlotSpec{name, + slot_type, + TensorRole::WEIGHT, + TensorType::FORWARD, + OpSlotOptions::OPTIONAL}; this->op_tensor_slots.insert(op_tensor_slot_spec); } @@ -130,7 +146,7 @@ OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd) { OpTensorSlotSpec{op_tensor_slot_spec.name, op_tensor_slot_spec.slot_type, op_tensor_slot_spec.tensor_role, - IsGrad::YES, + TensorType::GRADIENT, op_tensor_slot_spec.slot_option}; bwd.op_tensor_slots.insert(grad_spec); } diff --git a/lib/local-execution/src/ops/attention.cc b/lib/local-execution/src/ops/attention.cc index eebef9039d..b4c5d1ff8a 100644 --- a/lib/local-execution/src/ops/attention.cc +++ b/lib/local-execution/src/ops/attention.cc @@ -202,13 +202,13 @@ static std::optional } TaskImplFunction get_attention_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_attention_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_attention_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_attention_init_signature() { diff --git a/lib/local-execution/src/ops/batch_matmul.cc b/lib/local-execution/src/ops/batch_matmul.cc index 1eae409ae2..e358e0a645 100644 --- a/lib/local-execution/src/ops/batch_matmul.cc +++ b/lib/local-execution/src/ops/batch_matmul.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "batch_matmul.h" +#include "local-execution/ops/batch_matmul.h" #include "kernels/batch_matmul_kernels.h" #include "local-execution/op_task_signature.h" #include "op-attrs/get_output_shapes.h" @@ -153,10 +153,10 @@ static std::optional } TaskImplFunction get_batch_matmul_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_batch_matmul_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_batch_matmul_fwd_signature() { diff --git a/lib/local-execution/src/ops/batch_norm.cc b/lib/local-execution/src/ops/batch_norm.cc index 851566fc02..62155aa161 100644 --- a/lib/local-execution/src/ops/batch_norm.cc +++ b/lib/local-execution/src/ops/batch_norm.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "batch_norm.h" +#include "local-execution/ops/batch_norm.h" #include "kernels/batch_norm_kernels.h" namespace FlexFlow { @@ -144,13 +144,13 @@ static std::optional } TaskImplFunction get_batch_norm_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_batch_norm_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_batch_norm_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_batch_norm_init_signature() { diff --git a/lib/local-execution/src/ops/cast.cc b/lib/local-execution/src/ops/cast.cc index 3e7baf49a9..846faa9262 100644 --- a/lib/local-execution/src/ops/cast.cc +++ b/lib/local-execution/src/ops/cast.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "cast.h" +#include "local-execution/ops/cast.h" #include "kernels/cast_kernels.h" #include "local-execution/op_task_signature.h" @@ -79,10 +79,10 @@ static std::optional } TaskImplFunction get_cast_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_cast_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_cast_fwd_signature() { diff --git a/lib/local-execution/src/ops/combine.cc b/lib/local-execution/src/ops/combine.cc index ccc82cce17..b7e84878f4 100644 --- a/lib/local-execution/src/ops/combine.cc +++ b/lib/local-execution/src/ops/combine.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "combine.h" +#include "local-execution/ops/combine.h" #include "kernels/combine_kernels.h" #include "local-execution/op_task_invocation.h" #include "utils/hash-utils.h" @@ -85,10 +85,10 @@ OpTaskSignature get_combine_bwd_signature() { } TaskImplFunction get_combine_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_combine_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/concat.cc b/lib/local-execution/src/ops/concat.cc index 4c3462e694..dee1dd08e5 100644 --- a/lib/local-execution/src/ops/concat.cc +++ b/lib/local-execution/src/ops/concat.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "concat.h" +#include "local-execution/ops/concat.h" #include "kernels/concat_kernels.h" #include "local-execution/op_task_signature.h" @@ -79,10 +79,10 @@ static std::optional } TaskImplFunction get_concat_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_concat_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_concat_fwd_signature() { diff --git a/lib/local-execution/src/ops/conv_2d.cc b/lib/local-execution/src/ops/conv_2d.cc index d5c6e7f851..7ae92d70c7 100644 --- a/lib/local-execution/src/ops/conv_2d.cc +++ b/lib/local-execution/src/ops/conv_2d.cc @@ -1,4 +1,4 @@ -#include "conv_2d.h" +#include "local-execution/ops/conv_2d.h" #include "kernels/conv_2d_kernels.h" #include "op-attrs/get_output_shapes.h" @@ -132,13 +132,13 @@ static std::optional } TaskImplFunction get_conv_2d_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_conv_2d_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_conv_2d_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_conv_2d_init_signature() { diff --git a/lib/local-execution/src/ops/dropout.cc b/lib/local-execution/src/ops/dropout.cc index cac08866cc..017d023ec4 100644 --- a/lib/local-execution/src/ops/dropout.cc +++ b/lib/local-execution/src/ops/dropout.cc @@ -1,4 +1,4 @@ -#include "dropout.h" +#include "local-execution/ops/dropout.h" #include "kernels/dropout_kernels.h" #include "local-execution/op_task_invocation.h" #include "local-execution/op_task_signature.h" @@ -87,13 +87,13 @@ static std::optional } TaskImplFunction get_dropout_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_dropout_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_dropout_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_dropout_init_signature() { diff --git a/lib/local-execution/src/ops/element_binary.cc b/lib/local-execution/src/ops/element_binary.cc index 48c6c699a2..d4c12c7285 100644 --- a/lib/local-execution/src/ops/element_binary.cc +++ b/lib/local-execution/src/ops/element_binary.cc @@ -1,4 +1,4 @@ -#include "element_binary.h" +#include "local-execution/ops/element_binary.h" #include "kernels/element_binary_kernels.h" #include "local-execution/task_signature_impl.h" #include "op-attrs/get_output_shapes.h" @@ -126,15 +126,15 @@ static std::optional } TaskImplFunction get_element_binary_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_element_binary_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_element_binary_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_element_binary_init_signature() { diff --git a/lib/local-execution/src/ops/element_unary.cc b/lib/local-execution/src/ops/element_unary.cc index 4ee609bd6c..85ecf3db23 100644 --- a/lib/local-execution/src/ops/element_unary.cc +++ b/lib/local-execution/src/ops/element_unary.cc @@ -1,4 +1,4 @@ -#include "element_unary.h" +#include "local-execution/ops/element_unary.h" #include "kernels/element_unary_kernels.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/parallel_tensor_shape.h" @@ -35,7 +35,9 @@ OpTaskInvocation forward(ElementUnaryAttrs const &attrs) { b.bind(INPUT, input_tensor(0)); b.bind(OUTPUT, output_tensor(0)); + b.bind_arg(ATTRS, attrs); + b.bind_arg(HANDLE, ff_handle()); b.bind_arg(PROFILING, profiling_settings()); b.bind_arg(PER_DEVICE_STATE, per_device_op_state()); @@ -52,15 +54,17 @@ OpTaskInvocation backward(ElementUnaryAttrs const &attrs) { static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { - auto const &attrs = acc.get_argument(ATTRS); - ProfilingSettings profiling = acc.get_argument(PROFILING); + auto attrs = acc.get_argument(ATTRS); + ParallelTensorShape input_shape = acc.get_argument(INPUT_SHAPE); ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - ElementUnaryPerDeviceState per_device_state = init_kernel( - get_piece_shape(input_shape), get_piece_shape(output_shape), attrs); + ElementUnaryPerDeviceState per_device_state = + init_kernel(ArrayShape{get_piece_shape(input_shape)}, + ArrayShape{get_piece_shape(output_shape)}, + attrs); return DeviceSpecificDeviceStates{ DeviceSpecific::create(per_device_state)}; @@ -69,7 +73,7 @@ static DeviceSpecificDeviceStates static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); - auto const &attrs = acc.get_argument(ATTRS); + auto attrs = acc.get_argument(ATTRS); auto handle = acc.get_argument(HANDLE); @@ -114,13 +118,13 @@ static std::optional } TaskImplFunction get_element_unary_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_element_unary_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_element_unary_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_element_unary_init_signature() { diff --git a/lib/local-execution/src/ops/flat.cc b/lib/local-execution/src/ops/flat.cc index 3fe5029fa1..ef4dc7ab68 100644 --- a/lib/local-execution/src/ops/flat.cc +++ b/lib/local-execution/src/ops/flat.cc @@ -1,4 +1,4 @@ -#include "flat.h" +#include "local-execution/ops/flat.h" #include "kernels/flat_kernels.h" #include "op-attrs/get_output_shapes.h" @@ -53,10 +53,10 @@ static std::optional } TaskImplFunction get_flat_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_flat_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_flat_fwd_signature() { diff --git a/lib/local-execution/src/ops/gather.cc b/lib/local-execution/src/ops/gather.cc index a015c64f4d..180026e9ba 100644 --- a/lib/local-execution/src/ops/gather.cc +++ b/lib/local-execution/src/ops/gather.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "gather.h" +#include "local-execution/ops/gather.h" #include "kernels/gather_kernels.h" #include "local-execution/legion_tensor_shape.h" #include "op-attrs/get_output_shapes.h" @@ -122,13 +122,13 @@ static std::optional } TaskImplFunction get_gather_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_gather_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_gather_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_gather_init_signature() { diff --git a/lib/local-execution/src/ops/input.cc b/lib/local-execution/src/ops/input.cc index 56d19fa1ba..d7a3888220 100644 --- a/lib/local-execution/src/ops/input.cc +++ b/lib/local-execution/src/ops/input.cc @@ -1,4 +1,4 @@ -#include "input.h" +#include "local-execution/ops/input.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/layer_norm.cc b/lib/local-execution/src/ops/layer_norm.cc index e99d27319c..c9e2a8d55e 100644 --- a/lib/local-execution/src/ops/layer_norm.cc +++ b/lib/local-execution/src/ops/layer_norm.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "layer_norm.h" +#include "local-execution/ops/layer_norm.h" #include "kernels/layer_norm_kernels.h" #include "local-execution/legion_tensor_shape.h" #include "op-attrs/get_output_shapes.h" @@ -146,13 +146,13 @@ static DeviceSpecificDeviceStates } TaskImplFunction get_layer_norm_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_layer_norm_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_layer_norm_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_layer_norm_fwd_signature() { diff --git a/lib/local-execution/src/ops/linear.cc b/lib/local-execution/src/ops/linear.cc index 9934e2a45c..075aa1d9e4 100644 --- a/lib/local-execution/src/ops/linear.cc +++ b/lib/local-execution/src/ops/linear.cc @@ -1,4 +1,4 @@ -#include "linear.h" +#include "local-execution/ops/linear.h" #include "kernels/linear_kernels.h" #include "local-execution/task_argument_accessor.h" #include "op-attrs/ff_dim.h" @@ -161,13 +161,13 @@ static std::optional } TaskImplFunction get_linear_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_linear_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_linear_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_linear_init_signature() { diff --git a/lib/local-execution/src/ops/noop.cc b/lib/local-execution/src/ops/noop.cc index e35fdec275..7357806880 100644 --- a/lib/local-execution/src/ops/noop.cc +++ b/lib/local-execution/src/ops/noop.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "noop.h" +#include "local-execution/ops/noop.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/pool_2d.cc b/lib/local-execution/src/ops/pool_2d.cc index 33d62b713c..66f27fa69f 100644 --- a/lib/local-execution/src/ops/pool_2d.cc +++ b/lib/local-execution/src/ops/pool_2d.cc @@ -1,4 +1,4 @@ -#include "pool_2d.h" +#include "local-execution/ops/pool_2d.h" #include "kernels/pool_2d_kernels.h" #include "op-attrs/get_output_shapes.h" @@ -141,13 +141,13 @@ static std::optional } TaskImplFunction get_pool_2d_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_pool_2d_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_pool_2d_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_pool_2d_init_signature() { diff --git a/lib/local-execution/src/ops/reduce.cc b/lib/local-execution/src/ops/reduce.cc index a043d9f847..c157a98b36 100644 --- a/lib/local-execution/src/ops/reduce.cc +++ b/lib/local-execution/src/ops/reduce.cc @@ -1,4 +1,4 @@ -#include "reduce.h" +#include "local-execution/ops/reduce.h" #include "kernels/reduce_kernels.h" #include "op-attrs/get_output_shapes.h" @@ -102,13 +102,13 @@ static std::optional } TaskImplFunction get_reduce_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_reduce_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_reduce_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_reduce_init_signature() { diff --git a/lib/local-execution/src/ops/reduction.cc b/lib/local-execution/src/ops/reduction.cc index a58d79a4f8..95962661e2 100644 --- a/lib/local-execution/src/ops/reduction.cc +++ b/lib/local-execution/src/ops/reduction.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "reduction.h" +#include "local-execution/ops/reduction.h" #include "kernels/reduction_kernels.h" #include "op-attrs/get_output_shapes.h" #include "utils/exception.h" @@ -74,10 +74,10 @@ static std::optional } TaskImplFunction get_reduction_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_reduction_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_reduction_fwd_signature() { diff --git a/lib/local-execution/src/ops/repartition.cc b/lib/local-execution/src/ops/repartition.cc index 73692f4a13..9bba8109f3 100644 --- a/lib/local-execution/src/ops/repartition.cc +++ b/lib/local-execution/src/ops/repartition.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "repartition.h" +#include "local-execution/ops/repartition.h" #include "kernels/partition_kernels.h" #include "op-attrs/get_output_shapes.h" #include "utils/exception.h" @@ -98,13 +98,13 @@ static std::optional } TaskImplFunction get_repartition_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_repartition_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_repartition_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_repartition_init_signature() { diff --git a/lib/local-execution/src/ops/replicate.cc b/lib/local-execution/src/ops/replicate.cc index 135475a711..5ae93c4439 100644 --- a/lib/local-execution/src/ops/replicate.cc +++ b/lib/local-execution/src/ops/replicate.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "replicate.h" +#include "local-execution/ops/replicate.h" #include "kernels/replicate_kernels.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/parallel_tensor_shape.h" @@ -73,10 +73,10 @@ static std::optional } TaskImplFunction get_replicate_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_replicate_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_replicate_fwd_signature() { diff --git a/lib/local-execution/src/ops/reshape.cc b/lib/local-execution/src/ops/reshape.cc index 7584d405eb..838542a8eb 100644 --- a/lib/local-execution/src/ops/reshape.cc +++ b/lib/local-execution/src/ops/reshape.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "reshape.h" +#include "local-execution/ops/reshape.h" #include "kernels/reshape_kernels.h" #include "op-attrs/get_output_shapes.h" @@ -92,13 +92,13 @@ static std::optional } TaskImplFunction get_reshape_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_reshape_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_reshape_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_reshape_init_signature() { diff --git a/lib/local-execution/src/ops/reverse.cc b/lib/local-execution/src/ops/reverse.cc index 366a579bea..63032585b8 100644 --- a/lib/local-execution/src/ops/reverse.cc +++ b/lib/local-execution/src/ops/reverse.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "reverse.h" +#include "local-execution/ops/reverse.h" #include "kernels/accessor.h" #include "kernels/reverse_kernels.h" #include "op-attrs/get_output_shapes.h" @@ -103,10 +103,10 @@ static std::optional } TaskImplFunction get_reverse_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_reverse_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_reverse_fwd_signature() { diff --git a/lib/local-execution/src/ops/softmax.cc b/lib/local-execution/src/ops/softmax.cc index 4c7979ae9b..5e78781ddc 100644 --- a/lib/local-execution/src/ops/softmax.cc +++ b/lib/local-execution/src/ops/softmax.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "softmax.h" +#include "local-execution/ops/softmax.h" #include "kernels/softmax_kernels.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/parallel_tensor_shape.h" @@ -108,13 +108,13 @@ static std::optional } TaskImplFunction get_softmax_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_softmax_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_softmax_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_softmax_init_signature() { diff --git a/lib/local-execution/src/ops/split.cc b/lib/local-execution/src/ops/split.cc index 9f039d84f8..556d30109b 100644 --- a/lib/local-execution/src/ops/split.cc +++ b/lib/local-execution/src/ops/split.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "split.h" +#include "local-execution/ops/split.h" #include "kernels/array_shape.h" #include "kernels/split_kernels.h" #include "op-attrs/get_output_shapes.h" @@ -114,10 +114,10 @@ static std::optional } TaskImplFunction get_split_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_split_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_split_fwd_signature() { diff --git a/lib/local-execution/src/ops/topk.cc b/lib/local-execution/src/ops/topk.cc index 7f3519529a..41a28340db 100644 --- a/lib/local-execution/src/ops/topk.cc +++ b/lib/local-execution/src/ops/topk.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "topk.h" +#include "local-execution/ops/topk.h" #include "kernels/topk_kernels.h" #include "op-attrs/get_output_shapes.h" #include "utils/exception.h" @@ -120,13 +120,13 @@ static std::optional } TaskImplFunction get_topk_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_topk_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_topk_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_topk_init_signature() { diff --git a/lib/local-execution/src/ops/transpose.cc b/lib/local-execution/src/ops/transpose.cc index 3e4ac15db3..78e9fbde6f 100644 --- a/lib/local-execution/src/ops/transpose.cc +++ b/lib/local-execution/src/ops/transpose.cc @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "transpose.h" +#include "local-execution/ops/transpose.h" #include "kernels/transpose_kernels.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/transpose.h" @@ -99,13 +99,13 @@ OpTaskInvocation backward(TransposeAttrs const &attrs) { } TaskImplFunction get_transpose_init_task_impl() { - return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; } TaskImplFunction get_transpose_fwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_transpose_bwd_task_impl() { - return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_transpose_init_signature() { diff --git a/lib/local-execution/src/ops/weight.cc b/lib/local-execution/src/ops/weight.cc index 5537163e85..f96c104f33 100644 --- a/lib/local-execution/src/ops/weight.cc +++ b/lib/local-execution/src/ops/weight.cc @@ -1,4 +1,4 @@ -#include "weight.h" +#include "local-execution/ops/weight.h" namespace FlexFlow { diff --git a/lib/local-execution/src/optimizer.cc b/lib/local-execution/src/optimizer.cc new file mode 100644 index 0000000000..5c0d6c54f2 --- /dev/null +++ b/lib/local-execution/src/optimizer.cc @@ -0,0 +1,215 @@ +#include "local-execution/optimizer.h" +#include "kernels/optimizer_kernels.h" +#include "local-execution/profiling.h" +#include "utils/overload.h" + +namespace FlexFlow { + +enum Slots { ATTRS, WEIGHT, SGD_V, PROFILING, ADAM_M, ADAM_V, HANDLE }; + +TaskSignature get_sgd_update_signature() { + TaskSignature sig = make_empty_task_signature(); + add_slot(sig, WEIGHT, TensorType::FORWARD); + add_slot(sig, WEIGHT, TensorType::GRADIENT); + add_slot(sig, SGD_V, TensorType::OPTIMIZER); + + add_arg_slot(sig, ATTRS); + add_arg_slot(sig, PROFILING); + if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { + add_unchecked_arg_slot(sig, HANDLE); + } + return sig; +} + +TaskInvocation sgd_update(SGDOptimizerAttrs const &attrs, + reduced_tensor_t const &weight, + reduced_tensor_t const &sgd_v) { + TaskBinding b; + b.bind(WEIGHT, TensorType::FORWARD, weight); + b.bind(WEIGHT, TensorType::GRADIENT, weight); + + if (attrs.momentum > 0.0f) { + b.bind(SGD_V, TensorType::OPTIMIZER, sgd_v); + } + b.bind_arg(ATTRS, attrs); + b.bind_arg(PROFILING, profiling_settings()); + + if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { + b.bind_arg(HANDLE, ff_handle()); + return TaskInvocation{task_id_t::SGD_UPD_NCCL_TASK_ID, b}; + } else { + return TaskInvocation{task_id_t::SGD_UPD_PS_TASK_ID, b}; + } +} + +static void sgd_update_task_impl(TaskArgumentAccessor const &acc) { + auto attrs = acc.get_argument(ATTRS); + auto weight_grad = acc.get_tensor_grad(WEIGHT); + auto weight = acc.get_tensor(WEIGHT); + auto profiling = acc.get_argument(PROFILING); + + assert(weight.shape == weight_grad.shape); + size_t size = weight_grad.shape.get_volume(); + + assert(weight_grad.shape.get_volume() & weight.shape.get_volume() == 0); + size_t num_replicas = + weight_grad.shape.get_volume() / weight.shape.get_volume(); + + float *sgd_v_ptr; + if (attrs.momentum > 0.0f) { + auto sgd_v = acc.get_tensor(SGD_V); + assert(sgd_v.shape == weight.shape); + sgd_v_ptr = sgd_v.get_float_ptr(); + } + + if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { + auto handle = acc.get_argument(HANDLE); + profile(sgd_nccl_update_task_gpu, + profiling, + "[SGD NCCL] update_time = %.2lfms\n", + attrs.lr, + attrs.momentum, + attrs.nesterov, + attrs.weight_decay, + handle, + weight_grad.get_float_ptr(), + size, + weight.get_float_ptr(), + sgd_v_ptr); + + } else { + profile(sgd_ps_update_task_gpu, + profiling, + "[SGD PS] update_time = %.2lfms\n", + attrs.lr, + attrs.momentum, + attrs.nesterov, + attrs.weight_decay, + weight_grad.get_float_ptr(), + size, + num_replicas, + weight.get_float_ptr(), + sgd_v_ptr); + } +} + +TaskImplFunction get_sgd_update_task_impl() { + return TaskImplFunction{GenericTaskImplFunction{sgd_update_task_impl}}; +} + +TaskSignature get_adam_update_signature() { + TaskSignature sig = make_empty_task_signature(); + add_slot(sig, WEIGHT, TensorType::FORWARD); + add_slot(sig, WEIGHT, TensorType::GRADIENT); + add_slot(sig, ADAM_V, TensorType::OPTIMIZER); + add_slot(sig, ADAM_M, TensorType::OPTIMIZER); + + add_arg_slot(sig, ATTRS); + add_arg_slot(sig, PROFILING); + if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { + add_unchecked_arg_slot(sig, HANDLE); + } + return sig; +} + +TaskInvocation adam_update(AdamOptimizerAttrs const &attrs, + reduced_tensor_t const &weight, + reduced_tensor_t const &adam_v, + reduced_tensor_t const &adam_m) { + TaskBinding b; + b.bind(WEIGHT, TensorType::FORWARD, weight); + b.bind(WEIGHT, TensorType::GRADIENT, weight); + b.bind(ADAM_M, TensorType::OPTIMIZER, adam_m); + b.bind(ADAM_V, TensorType::OPTIMIZER, adam_v); + b.bind_arg(ATTRS, attrs); + b.bind_arg(PROFILING, profiling_settings()); + + if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { + b.bind_arg(HANDLE, ff_handle()); + return TaskInvocation{task_id_t::ADAM_UPD_NCCL_TASK_ID, b}; + } else { + return TaskInvocation{task_id_t::ADAM_UPD_PS_TASK_ID, b}; + } +} + +static void adam_update_task_impl(TaskArgumentAccessor const &acc) { + auto attrs = acc.get_argument(ATTRS); + auto weight_grad = acc.get_tensor_grad(WEIGHT); + auto weight = acc.get_tensor(WEIGHT); + auto v_tensor = acc.get_tensor(ADAM_V); + auto m_tensor = acc.get_tensor(ADAM_M); + + auto profiling = acc.get_argument(PROFILING); + + assert(weight.shape == weight_grad.shape); + size_t size = weight_grad.shape.get_volume(); + + assert(weight_grad.shape.get_volume() % weight.shape.get_volume() == 0); + size_t num_replicas = + weight_grad.shape.get_volume() / weight.shape.get_volume(); + + if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { + auto handle = acc.get_argument(HANDLE); + profile(adam_nccl_update_task_gpu, + profiling, + "[Adam NCCL] update_time = %.2lfms\n", + attrs.alpha_t, + attrs.beta1, + attrs.beta2, + attrs.weight_decay, + attrs.epsilon, + size, + handle, + weight_grad.get_float_ptr(), + m_tensor.get_float_ptr(), + v_tensor.get_float_ptr(), + weight.get_float_ptr()); + } else { + profile(adam_ps_update_task_gpu, + profiling, + "[Adam NCCL] update_time = %.2lfms\n", + attrs.alpha_t, + attrs.beta1, + attrs.beta2, + attrs.weight_decay, + attrs.epsilon, + size, + num_replicas, + weight_grad.get_float_ptr(), + m_tensor.get_float_ptr(), + v_tensor.get_float_ptr(), + weight.get_float_ptr()); + } +} + +TaskImplFunction get_adam_update_task_impl() { + return TaskImplFunction{GenericTaskImplFunction{adam_update_task_impl}}; +} + +TaskSignature get_update_signature(OptimizerAttrs const &attrs) { + return attrs.visit(overload{ + [&](SGDOptimizerAttrs const &) { return get_sgd_update_signature(); }, + [&](AdamOptimizerAttrs const &) { return get_adam_update_signature(); }}); +} + +TaskInvocation get_update_invocation( + OptimizerAttrs const &attrs, + reduced_tensor_t const &weight, + std::vector const &grad_buffer_tensors) { + return attrs.visit(overload{ + [&](SGDOptimizerAttrs const &s) { + return sgd_update(s, weight, grad_buffer_tensors.at(0)); + }, + [&](AdamOptimizerAttrs const &s) { + return adam_update( + s, weight, grad_buffer_tensors.at(0), grad_buffer_tensors.at(1)); + }}); +} + +TaskImplFunction get_update_task_impl(OptimizerAttrs const &attrs) { + return attrs.visit(overload{ + [&](SGDOptimizerAttrs const &) { return get_sgd_update_task_impl(); }, + [&](AdamOptimizerAttrs const &) { return get_adam_update_task_impl(); }}); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/runtime_arg_ref.cc b/lib/local-execution/src/runtime_arg_ref.cc index 56201a5c55..1f591b4d82 100644 --- a/lib/local-execution/src/runtime_arg_ref.cc +++ b/lib/local-execution/src/runtime_arg_ref.cc @@ -3,6 +3,19 @@ namespace FlexFlow { +std::string to_string(RuntimeArgRefType const &runtime_arg_ref_type) { + switch (runtime_arg_ref_type) { + case RuntimeArgRefType::FF_HANDLE: + return "FF_HANDLE"; + case RuntimeArgRefType::PROFILING_SETTINGS: + return "PROFILING_SETTINGS"; + case RuntimeArgRefType::FF_ITERATION_CONFIG: + return "FF_ITERATION_CONFIG"; + default: + return "Unknown"; + } +} + RuntimeArgRef profiling_settings() { return {RuntimeArgRefType::PROFILING_SETTINGS}; } diff --git a/lib/local-execution/src/task_binding.cc b/lib/local-execution/src/task_binding.cc new file mode 100644 index 0000000000..5261eec217 --- /dev/null +++ b/lib/local-execution/src/task_binding.cc @@ -0,0 +1,60 @@ +#include "local-execution/task_binding.h" +#include "utils/containers/contains_key.h" +#include "utils/fmt/unordered_map.h" + +namespace FlexFlow { + +void TaskBinding::bind(int name, + TensorType const &tensor_type, + reduced_tensor_t const &binding) { + this->bind(slot_id_t{name}, tensor_type, binding); +} + +void TaskBinding::bind(slot_id_t name, + TensorType const &tensor_type, + reduced_tensor_t const &binding) { + this->tensor_bindings.insert({SlotTensorTypeId{name, tensor_type}, binding}); +} + +void TaskBinding::insert_arg_spec(slot_id_t name, TaskArgSpec const &arg_spec) { + assert(!contains_key(this->arg_bindings, name)); + this->arg_bindings.insert({name, arg_spec}); +} + +bool TaskBinding::operator==(TaskBinding const &other) const { + return this->tie() == other.tie(); +} + +bool TaskBinding::operator!=(TaskBinding const &other) const { + return this->tie() != other.tie(); +} + +std::tuple const &, + std::unordered_map const &> + TaskBinding::tie() const { + return std::tie(this->tensor_bindings, this->arg_bindings); +} + +std::unordered_map const & + TaskBinding::get_tensor_bindings() const { + return this->tensor_bindings; +} + +std::unordered_map const & + TaskBinding::get_arg_bindings() const { + return this->arg_bindings; +} + +std::string format_as(TaskBinding const &x) { + std::ostringstream oss; + oss << "> task_ids; + switch (op_task_type) { + case OpTaskType::INIT: + task_ids = task_registry.init_task_ids; + break; + case OpTaskType::FWD: + task_ids = task_registry.forward_task_ids; + break; + case OpTaskType::BWD: + task_ids = task_registry.backward_task_ids; + break; + default: + throw mk_runtime_error("Invalid OpTaskType, got {}", op_task_type); + } + + return task_ids.at(op).has_value(); +} + } // namespace FlexFlow diff --git a/lib/local-execution/src/task_signature.cc b/lib/local-execution/src/task_signature.cc new file mode 100644 index 0000000000..a608ab8ab8 --- /dev/null +++ b/lib/local-execution/src/task_signature.cc @@ -0,0 +1,25 @@ +#include "local-execution/task_signature.h" + +namespace FlexFlow { + +TaskSignature make_empty_task_signature() { + return TaskSignature(std::nullopt, {}, {}); +} + +void add_slot(TaskSignature &task_signature, + int name, + TensorType tensor_type, + SlotType slot_type) { + add_slot(task_signature, slot_id_t{name}, tensor_type, slot_type); +} + +void add_slot(TaskSignature &task_signature, + slot_id_t name, + TensorType tensor_type, + SlotType slot_type) { + TensorTypeSlotSpec tensor_guid_slot_spec = + TensorTypeSlotSpec{slot_type, tensor_type}; + task_signature.tensor_guid_slots.insert({name, tensor_guid_slot_spec}); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/task_signature_impl.cc b/lib/local-execution/src/task_signature_impl.cc index ca428aad25..199e232a6b 100644 --- a/lib/local-execution/src/task_signature_impl.cc +++ b/lib/local-execution/src/task_signature_impl.cc @@ -1,33 +1,33 @@ #include "local-execution/task_signature_impl.h" #include "local-execution/ops/attention.h" -#include "ops/batch_matmul.h" -#include "ops/batch_norm.h" -#include "ops/cast.h" -#include "ops/combine.h" -#include "ops/concat.h" -#include "ops/conv_2d.h" -#include "ops/dropout.h" -#include "ops/element_binary.h" -#include "ops/element_unary.h" -#include "ops/embedding.h" -#include "ops/flat.h" -#include "ops/gather.h" -#include "ops/input.h" -#include "ops/layer_norm.h" -#include "ops/linear.h" -#include "ops/noop.h" -#include "ops/pool_2d.h" -#include "ops/reduce.h" -#include "ops/reduction.h" -#include "ops/repartition.h" -#include "ops/replicate.h" -#include "ops/reshape.h" -#include "ops/reverse.h" -#include "ops/softmax.h" -#include "ops/split.h" -#include "ops/topk.h" -#include "ops/transpose.h" -#include "ops/weight.h" +#include "local-execution/ops/batch_matmul.h" +#include "local-execution/ops/batch_norm.h" +#include "local-execution/ops/cast.h" +#include "local-execution/ops/combine.h" +#include "local-execution/ops/concat.h" +#include "local-execution/ops/conv_2d.h" +#include "local-execution/ops/dropout.h" +#include "local-execution/ops/element_binary.h" +#include "local-execution/ops/element_unary.h" +#include "local-execution/ops/embedding.h" +#include "local-execution/ops/flat.h" +#include "local-execution/ops/gather.h" +#include "local-execution/ops/input.h" +#include "local-execution/ops/layer_norm.h" +#include "local-execution/ops/linear.h" +#include "local-execution/ops/noop.h" +#include "local-execution/ops/pool_2d.h" +#include "local-execution/ops/reduce.h" +#include "local-execution/ops/reduction.h" +#include "local-execution/ops/repartition.h" +#include "local-execution/ops/replicate.h" +#include "local-execution/ops/reshape.h" +#include "local-execution/ops/reverse.h" +#include "local-execution/ops/softmax.h" +#include "local-execution/ops/split.h" +#include "local-execution/ops/topk.h" +#include "local-execution/ops/transpose.h" +#include "local-execution/ops/weight.h" #include "utils/overload.h" namespace FlexFlow { @@ -50,8 +50,8 @@ TaskSignatureAndImpl get_task_sig_impl(task_id_t const &task_id) { return TaskSignatureAndImpl{get_element_unary_fwd_task_impl(), get_element_unary_fwd_signature()}; case task_id_t::ELEMENTUNARY_BWD_TASK_ID: - return TaskSignatureAndImpl{get_element_binary_bwd_task_impl(), - get_element_binary_bwd_signature()}; + return TaskSignatureAndImpl{get_element_unary_bwd_task_impl(), + get_element_unary_bwd_signature()}; case task_id_t::CONV2D_INIT_TASK_ID: return TaskSignatureAndImpl{get_conv_2d_init_task_impl(), get_conv_2d_init_signature()}; diff --git a/lib/local-execution/src/tensor_reduction.cc b/lib/local-execution/src/tensor_reduction.cc new file mode 100644 index 0000000000..19324509bb --- /dev/null +++ b/lib/local-execution/src/tensor_reduction.cc @@ -0,0 +1,17 @@ +#include "local-execution/tensor_reduction.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +reduced_tensor_t lower(tensor_guid_t const &tensor_guid) { + return reduced_tensor_t{tensor_guid.raw_graph_output.idx}; +} + +std::vector + lower(std::vector const &tensor_guids) { + return transform(tensor_guids, [&](tensor_guid_t const &tensor_guid) { + return lower(tensor_guid); + }); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_local_cost_estimator.cc b/lib/local-execution/test/src/test_local_cost_estimator.cc index 4c01df53e9..6ff551828f 100644 --- a/lib/local-execution/test/src/test_local_cost_estimator.cc +++ b/lib/local-execution/test/src/test_local_cost_estimator.cc @@ -5,6 +5,7 @@ #include "op-attrs/ops/attention.h" #include "op-attrs/parallel_tensor_shape.h" #include "pcg/computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "test_utils.h" using namespace ::FlexFlow; @@ -46,13 +47,18 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { DataType::FLOAT, }); + std::vector weights; ParallelTensorShape weights_shape = throw_if_unexpected( get_weights_shape(attrs, inputs_shape, inputs_shape, inputs_shape)); - ParallelTensorAttrs weight_attrs = - ParallelTensorAttrs{weights_shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - CreateGrad::YES}; + weights.push_back(make_weight_attrs(weights_shape, std::nullopt)); + ParallelTensorShape input_bias_shape = + throw_if_unexpected(get_input_bias_shape( + attrs, inputs_shape, inputs_shape, inputs_shape)); + weights.push_back(make_weight_attrs(input_bias_shape, std::nullopt)); + ParallelTensorShape output_bias_shape = + throw_if_unexpected(get_output_bias_shape( + attrs, inputs_shape, inputs_shape, inputs_shape)); + weights.push_back(make_weight_attrs(output_bias_shape, std::nullopt)); ParallelTensorShape output_shape = throw_if_unexpected( get_output_shape(attrs, inputs_shape, inputs_shape, inputs_shape)); @@ -66,7 +72,7 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { PCGOperatorAttrs{attrs}, std::vector{ inputs_shape, inputs_shape, inputs_shape}, - std::vector{weight_attrs}, + weights, std::vector{output_attrs}, make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1})); diff --git a/lib/local-execution/test/src/test_local_slots_backing.cc b/lib/local-execution/test/src/test_local_slots_backing.cc index 1ec441fbca..5d58e7e757 100644 --- a/lib/local-execution/test/src/test_local_slots_backing.cc +++ b/lib/local-execution/test/src/test_local_slots_backing.cc @@ -81,7 +81,7 @@ TEST_SUITE(FF_TEST_SUITE) { LocalSlotsBacking local_slots_backing = {tensor_backing_map, runtime_arg_config}; - SUBCASE("LocalSlotsBacking::allocate_outgoing_tensors") { + SUBCASE("LocalSlotsBacking::allocate_tensors_by_role") { auto get_result_shape_and_dtype_for_tensor_guid_and_map = [&](tensor_guid_t t, TensorBackingMap m) -> std::pair { @@ -92,14 +92,11 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("Input (QKV) and gradient tensors allocation") { // allocate all tensors from input nodes - for (layer_guid_t const &node : - topological_ordering(cg_builder.computation_graph)) { - if (node == layer_guid) { - break; - } - local_slots_backing.allocate_outgoing_tensors( - node, cg_builder.computation_graph, allocator); - } + local_slots_backing.allocate_tensors_by_role( + TensorRole::INPUT, + layer_guid, + cg_builder.computation_graph, + allocator); SUBCASE("Query grad") { std::pair result = @@ -127,8 +124,11 @@ TEST_SUITE(FF_TEST_SUITE) { } } SUBCASE("Output and gradient tensors allocation") { - local_slots_backing.allocate_outgoing_tensors( - layer_guid, cg_builder.computation_graph, allocator); + local_slots_backing.allocate_tensors_by_role( + TensorRole::OUTPUT, + layer_guid, + cg_builder.computation_graph, + allocator); SUBCASE("Output") { std::pair result = get_result_shape_and_dtype_for_tensor_guid_and_map( @@ -154,14 +154,20 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Tensor slots") { - local_slots_backing.allocate_outgoing_tensors( + local_slots_backing.allocate_layer_tensors( layer_guid, cg_builder.computation_graph, allocator); SUBCASE("Input tensor slots") { - std::vector correct_incoming_tensors = - get_incoming_tensors(cg_builder.computation_graph, layer_guid); - CHECK(correct_incoming_tensors == + std::vector correct_incoming_input_tensors = + get_incoming_inputs(cg_builder.computation_graph, layer_guid); + CHECK(correct_incoming_input_tensors == local_slots_backing.input_tensor_slots.at(layer_guid)); } + SUBCASE("Weight tensor slots") { + std::vector correct_incoming_weight_tensors = + get_incoming_weights(cg_builder.computation_graph, layer_guid); + CHECK(correct_incoming_weight_tensors == + local_slots_backing.weight_tensor_slots.at(layer_guid)); + } SUBCASE("Output tensor slots") { std::vector correct_outgoing_tensors = get_outgoing_tensors(cg_builder.computation_graph, layer_guid); @@ -192,7 +198,7 @@ TEST_SUITE(FF_TEST_SUITE) { b.bind(QUERY, input_tensor(0)); b.bind(KEY, input_tensor(1)); b.bind(VALUE, input_tensor(2)); - b.bind(WEIGHTS, weight_tensor(3)); + b.bind(WEIGHTS, weight_tensor(0)); b.bind(OUTPUT, output_tensor(0)); b.bind_grad(QUERY, input_tensor(0)); @@ -205,12 +211,8 @@ TEST_SUITE(FF_TEST_SUITE) { return b; }(); - // allocate all incoming and outgoing tensors for graph - for (layer_guid_t const &node : - topological_ordering(cg_builder.computation_graph)) { - local_slots_backing.allocate_outgoing_tensors( - node, cg_builder.computation_graph, allocator); - } + local_slots_backing.allocate_layer_tensors( + layer_guid, cg_builder.computation_graph, allocator); SUBCASE("LocalSlotsBacking::construct_tensor_slots_backing") { TensorSlotsBackingWithoutAddresses result = diff --git a/lib/local-execution/test/src/test_loss_e2e.cc b/lib/local-execution/test/src/test_loss_e2e.cc new file mode 100644 index 0000000000..c4662d624c --- /dev/null +++ b/lib/local-execution/test/src/test_loss_e2e.cc @@ -0,0 +1,98 @@ +#include "doctest/doctest.h" +#include "kernels/local_cuda_allocator.h" +#include "kernels/managed_ff_stream.h" +#include "kernels/managed_per_device_ff_handle.h" +#include "local-execution/local_training_backing.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "test_utils.h" + +namespace FlexFlow { + +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE("Local Execution E2E") { + // initialize runtime configs + ManagedPerDeviceFFHandle managed_handle{}; + + RuntimeArgConfig runtime_arg_config = RuntimeArgConfig{ + DeviceSpecific::create(managed_handle.raw_handle()), + EnableProfiling::YES, + ProfilingSettings{/*warmup_iters=*/0, /*measure_iters=*/1}}; + + // construct graph + ComputationGraphBuilder cg_builder; + + size_t batch_size = 10; + size_t data_dim = 100; + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + tensor_guid_t input_tensor = + cg_builder.create_input(input_shape, CreateGrad::YES); + + float scalar = 4.0; + std::string layer_name = "scalar multiply"; + tensor_guid_t logit_tensor = + cg_builder.scalar_multiply(input_tensor, scalar, layer_name); + + // allocate memory + Allocator allocator = create_local_cuda_memory_allocator(); + TensorBackingMap tensor_backing_map; + GenericTensorAccessorW input_backing = + allocator.allocate_tensor(input_shape); + tensor_backing_map.insert({input_tensor, input_backing}); + + LocalTrainingBacking local_backing(allocator, + cg_builder.computation_graph, + tensor_backing_map, + runtime_arg_config); + // for (layer_guid_t const & node: + // topological_ordering(cg_builder.computation_graph)) { + // local_backing.register_and_allocate_layer(node); + // } + local_backing.register_and_allocate_layer( + get_layer_by_name(cg_builder.computation_graph, layer_name)); + + SUBCASE("SparseCategoricalCrossEntropyLossAttrs") { + TensorShape label_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, 1}}, DataType::FLOAT}; + tensor_guid_t label_tensor = + cg_builder.create_input(label_shape, CreateGrad::NO); + GenericTensorAccessorW label_backing = + allocator.allocate_tensor(label_shape); + local_backing.insert_tensor(label_tensor, label_backing); + LossAttrs loss_attrs = LossAttrs{ + SparseCategoricalCrossEntropyLossAttrs{/*replace_labels=*/false}}; + local_backing.compute_loss(loss_attrs, logit_tensor, label_tensor); + } + + SUBCASE("NonconfigurableLossAttrs") { + tensor_guid_t label_tensor = + cg_builder.create_input(input_shape, CreateGrad::NO); + GenericTensorAccessorW label_backing = + allocator.allocate_tensor(input_shape); + local_backing.insert_tensor(label_tensor, label_backing); + + SUBCASE("LossFunction::CATEGORICAL_CROSSENTROPY") { + LossAttrs loss_attrs = LossAttrs{ + NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + local_backing.compute_loss(loss_attrs, logit_tensor, label_tensor); + } + + SUBCASE("LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE") { + LossAttrs loss_attrs = LossAttrs{NonconfigurableLossAttrs{ + LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE}}; + local_backing.compute_loss(loss_attrs, logit_tensor, label_tensor); + } + + SUBCASE("LossFunction::IDENTITY") { + LossAttrs loss_attrs = + LossAttrs{NonconfigurableLossAttrs{LossFunction::IDENTITY}}; + local_backing.compute_loss(loss_attrs, logit_tensor, label_tensor); + } + } + } +} + +} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_task_registry.cc b/lib/local-execution/test/src/test_task_registry.cc index e18b7ea2de..b71809208e 100644 --- a/lib/local-execution/test/src/test_task_registry.cc +++ b/lib/local-execution/test/src/test_task_registry.cc @@ -1,7 +1,6 @@ #include "doctest/doctest.h" #include "kernels/local_cuda_allocator.h" #include "local-execution/local_cost_estimator.h" -#include "local-execution/ops/attention.h" #include "local-execution/task_signature_impl.h" #include "pcg/computation_graph_builder.h" #include "utils/fmt/optional.h" diff --git a/lib/local-execution/test/src/test_update_e2e.cc b/lib/local-execution/test/src/test_update_e2e.cc new file mode 100644 index 0000000000..b48214d89d --- /dev/null +++ b/lib/local-execution/test/src/test_update_e2e.cc @@ -0,0 +1,96 @@ +#include "doctest/doctest.h" +#include "kernels/local_cuda_allocator.h" +#include "kernels/managed_ff_stream.h" +#include "kernels/managed_per_device_ff_handle.h" +#include "local-execution/local_training_backing.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "test_utils.h" + +namespace FlexFlow { + +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE("Local Execution Update E2E") { + // initialize runtime configs + ManagedPerDeviceFFHandle managed_handle{}; + + RuntimeArgConfig runtime_arg_config = RuntimeArgConfig{ + DeviceSpecific::create(managed_handle.raw_handle()), + EnableProfiling::YES, + ProfilingSettings{/*warmup_iters=*/0, /*measure_iters=*/1}}; + + // construct graph + ComputationGraphBuilder cg_builder; + + size_t batch_size = 10; + size_t data_dim = 100; + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + tensor_guid_t input_tensor = + cg_builder.create_input(input_shape, CreateGrad::YES); + + float scalar = 4.0; + std::string layer_name = "scalar_multiply"; + tensor_guid_t logit_tensor = + cg_builder.scalar_multiply(input_tensor, scalar, layer_name); + + // allocate memory + Allocator allocator = create_local_cuda_memory_allocator(); + TensorBackingMap tensor_backing_map; + GenericTensorAccessorW input_backing = + allocator.allocate_tensor(input_shape); + tensor_backing_map.insert({input_tensor, input_backing}); + + LocalTrainingBacking local_backing(allocator, + cg_builder.computation_graph, + tensor_backing_map, + runtime_arg_config); + // for (layer_guid_t const & node: + // topological_ordering(cg_builder.computation_graph)) { + // local_backing.register_and_allocate_layer(node); + // } + layer_guid_t layer_guid = + get_layer_by_name(cg_builder.computation_graph, layer_name); + local_backing.register_and_allocate_layer(layer_guid); + + SUBCASE("SGDOptimizerAttrs") { + SUBCASE("momentum=0") { + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.0f, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + local_backing.allocate_layer_optimizer_tensors(layer_guid, + optimizer_attrs); + local_backing.execute_update(layer_guid, optimizer_attrs); + } + SUBCASE("momentum=0.9") { + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + local_backing.allocate_layer_optimizer_tensors(layer_guid, + optimizer_attrs); + local_backing.execute_update(layer_guid, optimizer_attrs); + } + } + SUBCASE("AdamOptimizerAttrs") { + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{AdamOptimizerAttrs{/*alpha=*/0.001, + /*beta1=*/0.9, + /*beta2=*/0.999, + /*weight_decay=*/0.001, + /*alpha_t=*/0.001, + /*beta_t=*/0.9, + /*beta2_t=*/0.999, + /*epsilon=*/1e-8}}; + local_backing.allocate_layer_optimizer_tensors(layer_guid, + optimizer_attrs); + local_backing.execute_update(layer_guid, optimizer_attrs); + } + } +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index 268554b5be..73473d6ac5 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -1,37 +1,37 @@ #ifndef _OPERATOR_PARAMS_H #define _OPERATOR_PARAMS_H +#include "local-execution/ops/attention.h" +#include "local-execution/ops/batch_matmul.h" +#include "local-execution/ops/batch_norm.h" +#include "local-execution/ops/broadcast.h" +#include "local-execution/ops/cast.h" +#include "local-execution/ops/combine.h" +#include "local-execution/ops/concat.h" +#include "local-execution/ops/conv_2d.h" +#include "local-execution/ops/dropout.h" +#include "local-execution/ops/element_binary.h" +#include "local-execution/ops/element_unary.h" +#include "local-execution/ops/embedding.h" +#include "local-execution/ops/flat.h" +#include "local-execution/ops/gather.h" +#include "local-execution/ops/input.h" +#include "local-execution/ops/layer_norm.h" +#include "local-execution/ops/linear.h" +#include "local-execution/ops/noop.h" +#include "local-execution/ops/pool_2d.h" +#include "local-execution/ops/reduce.h" +#include "local-execution/ops/reduction.h" +#include "local-execution/ops/repartition.h" +#include "local-execution/ops/replicate.h" +#include "local-execution/ops/reshape.h" +#include "local-execution/ops/reverse.h" +#include "local-execution/ops/softmax.h" +#include "local-execution/ops/split.h" +#include "local-execution/ops/topk.h" +#include "local-execution/ops/transpose.h" #include "op-attrs/ops/core.h" #include "op-attrs/pcg_operator_attrs.dtg.h" -#include "ops/attention.h" -#include "ops/batch_matmul.h" -#include "ops/batch_norm.h" -#include "ops/broadcast.h" -#include "ops/cast.h" -#include "ops/combine.h" -#include "ops/concat.h" -#include "ops/conv_2d.h" -#include "ops/dropout.h" -#include "ops/element_binary.h" -#include "ops/element_unary.h" -#include "ops/embedding.h" -#include "ops/flat.h" -#include "ops/gather.h" -#include "ops/input.h" -#include "ops/layer_norm.h" -#include "ops/linear.h" -#include "ops/noop.h" -#include "ops/pool_2d.h" -#include "ops/reduce.h" -#include "ops/reduction.h" -#include "ops/repartition.h" -#include "ops/replicate.h" -#include "ops/reshape.h" -#include "ops/reverse.h" -#include "ops/softmax.h" -#include "ops/split.h" -#include "ops/topk.h" -#include "ops/transpose.h" #include "utils/record_formatter.h" #include "utils/variant.h" #include diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h b/lib/op-attrs/include/op-attrs/ops/loss_functions.h similarity index 62% rename from lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h rename to lib/op-attrs/include/op-attrs/ops/loss_functions.h index ca8f3e6602..657f8d91dc 100644 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions.h @@ -4,16 +4,13 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "op-attrs/ops/loss_functions/loss_function.dtg.h" +#include "op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.dtg.h" +#include "op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(LossAttrs); - -LossFunction parse_loss_function_name(std::string const &); - -LossFunction get_loss_function(OtherLossAttrs const &); -LossFunction get_loss_function(SparseCategoricalCrossEntropyLossAttrs const &); LossFunction get_loss_function(LossAttrs const &); +LossFunction parse_loss_name(std::string const &raw_name); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml index 17293095e4..943760d949 100644 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml @@ -11,7 +11,7 @@ features = [ includes = [ "op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.dtg.h", - "op-attrs/ops/loss_functions/other_loss_attrs.dtg.h", + "op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.dtg.h", ] [[values]] @@ -19,5 +19,5 @@ type = "::FlexFlow::SparseCategoricalCrossEntropyLossAttrs" key = "sparse_categorical_cross_entropy_loss" [[values]] -type = "::FlexFlow::OtherLossAttrs" -key = "other_loss" +type = "::FlexFlow::NonconfigurableLossAttrs" +key = "nonconfigurable_loss_attrs" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.struct.toml similarity index 67% rename from lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml rename to lib/op-attrs/include/op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.struct.toml index 284a4b1d7d..3fe7ac86c5 100644 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.struct.toml @@ -1,16 +1,16 @@ namespace = "FlexFlow" -name = "OtherLossAttrs" +name = "NonconfigurableLossAttrs" features = [ "eq", "ord", "hash", - "fmt", - "rapidcheck", "json", + "rapidcheck", + "fmt", ] includes = [ - "op-attrs/ops/loss_functions/loss_function.dtg.h", + "op-attrs/ops/loss_functions/loss_function.dtg.h" ] [[fields]] diff --git a/lib/op-attrs/src/op-attrs/ops/loss_functions.cc b/lib/op-attrs/src/op-attrs/ops/loss_functions.cc index e756d08547..2b9a7533f0 100644 --- a/lib/op-attrs/src/op-attrs/ops/loss_functions.cc +++ b/lib/op-attrs/src/op-attrs/ops/loss_functions.cc @@ -1,22 +1,18 @@ -#include "op-attrs/ops/loss_functions/loss_functions.h" +#include "op-attrs/ops/loss_functions.h" #include "utils/containers/transform.h" +#include "utils/exception.h" +#include "utils/overload.h" #include #include namespace FlexFlow { -LossFunction get_loss_type(OtherLossAttrs const &attrs) { - return attrs.loss_type; -} - -LossFunction - get_loss_type(SparseCategoricalCrossEntropyLossAttrs const &attrs) { - return LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY; -} - -LossFunction get_loss_type(LossAttrs const &attrs) { +LossFunction get_loss_function(LossAttrs const &attrs) { return attrs.visit( - [](auto const &t) { return get_loss_type(t); }); + overload{[&](SparseCategoricalCrossEntropyLossAttrs const &s) { + return LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY; + }, + [&](NonconfigurableLossAttrs const &s) { return s.loss_type; }}); } LossFunction parse_loss_name(std::string const &raw_name) { diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 45cde0de57..585399ea1d 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_BUILDER_H #include "pcg/computation_graph.dtg.h" +#include "pcg/computation_graph/layer_added_result.dtg.h" #include "pcg/initializer_attrs.dtg.h" #include "pcg/tensor_guid_t.dtg.h" @@ -256,6 +257,12 @@ struct ComputationGraphBuilder { std::vector get_outputs(LayerAttrs const &) const; tensor_guid_t get_output(LayerAttrs const &, int idx) const; + LayerAddedResult add_layer_and_get_layer_added_result( + LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); + std::vector add_layer(LayerAttrs const &layer, std::vector const &inputs, diff --git a/lib/pcg/include/pcg/optimizer_attrs.h b/lib/pcg/include/pcg/optimizer_attrs.h index 4bac74b999..d4abd1b52f 100644 --- a/lib/pcg/include/pcg/optimizer_attrs.h +++ b/lib/pcg/include/pcg/optimizer_attrs.h @@ -1,13 +1,12 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H -#include "pcg/optimizers/adam_optimizer_attrs.h" -#include "pcg/optimizers/sgd_optimizer_attrs.h" -#include "utils/variant.h" +#ifndef _FLEXFLOW_PCG_OPTIMIZER_ATTRS_H +#define _FLEXFLOW_PCG_OPTIMIZER_ATTRS_H + +#include "pcg/optimizer_attrs.dtg.h" namespace FlexFlow { -using OptimizerAttrs = std::variant; +OptimizerAttrs get_next_iteration_optimizer_attrs(OptimizerAttrs const &old); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/optimizer_attrs.variant.toml b/lib/pcg/include/pcg/optimizer_attrs.variant.toml new file mode 100644 index 0000000000..585c150700 --- /dev/null +++ b/lib/pcg/include/pcg/optimizer_attrs.variant.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "OptimizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "pcg/optimizers/sgd_optimizer_attrs.dtg.h", + "pcg/optimizers/adam_optimizer_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::SGDOptimizerAttrs" +key = "sgd_optimizer" + +[[values]] +type = "::FlexFlow::AdamOptimizerAttrs" +key = "adam_optimizer" diff --git a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml index fd3e83cc4a..c25baa6c89 100644 --- a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml +++ b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml @@ -36,3 +36,7 @@ type = "double" [[fields]] name = "beta2_t" type = "double" + +[[fields]] +name = "epsilon" +type = "double" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index 019b120936..35113553f2 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -179,6 +179,10 @@ struct ParallelComputationGraphBuilder { ParallelComputationGraph pcg; }; +ParallelTensorAttrs + make_weight_attrs(ParallelTensorShape const &shape, + std::optional const &initializer_attrs); + } // namespace FlexFlow #endif diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 4a565476bd..4c619288cb 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -106,7 +106,7 @@ static void check_incoming_tensor_roles(LayerAttrs const &layer, } } -std::vector ComputationGraphBuilder::add_layer( +LayerAddedResult ComputationGraphBuilder::add_layer_and_get_layer_added_result( LayerAttrs const &layer, std::vector const &inputs, std::vector const &weights, @@ -115,7 +115,17 @@ std::vector ComputationGraphBuilder::add_layer( LayerAddedResult added = ::FlexFlow::add_layer( this->computation_graph, layer, concat_vectors(inputs, weights), outputs); - return added.outputs; + return added; +} + +std::vector ComputationGraphBuilder::add_layer( + LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { + return this + ->add_layer_and_get_layer_added_result(layer, inputs, weights, outputs) + .outputs; } tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, diff --git a/lib/pcg/src/pcg/optimizer_attrs.cc b/lib/pcg/src/pcg/optimizer_attrs.cc new file mode 100644 index 0000000000..8d66f7af7e --- /dev/null +++ b/lib/pcg/src/pcg/optimizer_attrs.cc @@ -0,0 +1,25 @@ +#include "pcg/optimizer_attrs.h" + +namespace FlexFlow { + +OptimizerAttrs + get_next_iteration_optimizer_attrs(OptimizerAttrs const &old_attrs) { + if (old_attrs.has()) { + AdamOptimizerAttrs old = old_attrs.get(); + double new_beta1_t = old.beta_t * old.beta1; + double new_beta2_t = old.beta2_t * old.beta2; + double new_alpha_t = old.alpha * sqrt(1 - new_beta2_t) / (1 - new_beta1_t); + return OptimizerAttrs{AdamOptimizerAttrs{old.alpha, + old.beta1, + old.beta2, + old.weight_decay, + new_alpha_t, + new_beta1_t, + new_beta2_t, + old.epsilon}}; + } else { + return old_attrs; + } +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index ce00ea62f4..b56156fe8a 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -20,7 +20,7 @@ static std::string get_default_name(PCGOperatorAttrs const &attrs) { return get_default_name(get_op_type(attrs)); } -static ParallelTensorAttrs make_weight_attrs( +ParallelTensorAttrs make_weight_attrs( ParallelTensorShape const &shape, std::optional const &initializer_attrs) { return ParallelTensorAttrs{