Merge branch 'repo-refactor' into purge-moe
reyna-abhyankar authored Oct 9, 2023
2 parents 44982f8 + e1b1be2 commit 8e6d60f
Showing 19 changed files with 1,237 additions and 2,626 deletions.
lib/kernels/include/kernels/batch_norm_kernels.h (71 changes: 52 additions & 19 deletions)
@@ -8,39 +8,64 @@

 namespace FlexFlow {
 
-class BatchNormPerDeviceState : public PerDeviceOpState {
-public:
-  BatchNormPerDeviceState(FFHandler handle,
-                          std::unique_ptr<IAllocator> allocator,
-                          int output_n,
-                          int output_c,
-                          int output_h,
-                          int output_w,
-                          bool relu,
-                          bool profiling);
-  ~BatchNormPerDeviceState(void);
-
-  ffTensorDescriptor_t inputTensor, outputTensor, biasTensor;
+struct BatchNormPerDeviceState {
+  PerDeviceFFHandle handle;
+  Allocator allocator;
+  ffTensorDescriptor_t inputTensor;
+  ffTensorDescriptor_t outputTensor;
+  ffTensorDescriptor_t biasTensor;
   ffActivationDescriptor_t actiDesc;
   ffBatchNormMode_t mode;
-  float *runningMean, *runningVar, *saveMean, *saveVar;
-  bool relu;
-  bool profiling;
-  std::unique_ptr<IAllocator> allocator;
+  float *runningMean;
+  float *runningVar;
+  float *saveMean;
+  float *saveVar;
+  int output_n;
+  int output_c;
+  int output_h;
+  int output_w;
+  req<bool> relu;
 };
 
+FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(BatchNormPerDeviceState,
+                                             handle,
+                                             allocator,
+                                             inputTensor,
+                                             outputTensor,
+                                             biasTensor,
+                                             actiDesc,
+                                             mode,
+                                             runningMean,
+                                             runningVar,
+                                             saveMean,
+                                             saveVar,
+                                             output_n,
+                                             output_c,
+                                             output_h,
+                                             output_w,
+                                             relu);
+
 namespace Kernels {
 namespace BatchNorm {
 
+BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle,
+                                    Allocator allocator,
+                                    float *runningMean,
+                                    int output_n,
+                                    int output_c,
+                                    int output_h,
+                                    int output_w,
+                                    bool relu);
+
 void forward_kernel(ffStream_t stream,
-                    BatchNormPerDeviceState *m,
+                    BatchNormPerDeviceState &m,
                     float const *input_ptr,
                     float *output_ptr,
                     float const *scale_ptr,
                     float const *bias_ptr);
 
 void backward_kernel(ffStream_t stream,
-                     BatchNormPerDeviceState *m,
+                     BatchNormPerDeviceState &m,
                      float const *input_ptr,
                      float *output_grad_ptr,
                      float const *output_ptr,
@@ -50,6 +75,14 @@ void backward_kernel(ffStream_t stream,
                      float *bias_grad_ptr,
                      size_t numElements);
 
+void cleanup_kernel(Allocator allocator,
+                    ffTensorDescriptor_t inputTensor,
+                    ffTensorDescriptor_t biasTensor,
+                    ffTensorDescriptor_t outputTensor,
+                    ffActivationDescriptor_t actiDesc,
+                    bool relu,
+                    float *runningMean);
+
 } // namespace BatchNorm
 } // namespace Kernels
 } // namespace FlexFlow
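
For reference, a minimal call-site sketch of the interface after this change. Nothing below is from the commit itself: the wrapper function, tensor dimensions, and device pointers are illustrative assumptions; only the init_kernel, forward_kernel, and cleanup_kernel signatures come from the header above.

// Hypothetical usage of the new value-type BatchNorm state (sketch only).
#include "kernels/batch_norm_kernels.h"

using namespace FlexFlow;

void batch_norm_example(ffStream_t stream,
                        PerDeviceFFHandle handle,
                        Allocator allocator,
                        float *running_mean, // assumed already device-resident
                        float const *input,
                        float *output,
                        float const *scale,
                        float const *bias) {
  // init_kernel returns the state by value; there is no PerDeviceOpState
  // subclass or destructor anymore.
  BatchNormPerDeviceState m = Kernels::BatchNorm::init_kernel(
      handle, allocator, running_mean,
      /*output_n=*/4, /*output_c=*/64, /*output_h=*/56, /*output_w=*/56,
      /*relu=*/true);

  Kernels::BatchNorm::forward_kernel(stream, m, input, output, scale, bias);

  // Teardown is explicit: descriptors and scratch buffers are released
  // through cleanup_kernel instead of a destructor.
  Kernels::BatchNorm::cleanup_kernel(allocator,
                                     m.inputTensor,
                                     m.biasTensor,
                                     m.outputTensor,
                                     m.actiDesc,
                                     m.relu,
                                     m.runningMean);
}
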
lib/kernels/include/kernels/dropout_kernels.h (53 changes: 39 additions & 14 deletions)
@@ -1,38 +1,63 @@
 #ifndef _FLEXFLOW_OPS_KERNELS_DROPOUT_KERNELS_H
 #define _FLEXFLOW_OPS_KERNELS_DROPOUT_KERNELS_H
 
+#include "kernels/allocation.h"
+#include "kernels/array_shape.h"
 #include "kernels/device.h"
+#include "kernels/ff_handle.h"
+#include <cstddef>
 
 namespace FlexFlow {
 
-class DropoutPerDeviceState : public PerDeviceOpState {
+struct DropoutPerDeviceState {
 public:
-  DropoutPerDeviceState(FFHandler handler,
-                        float rate,
-                        unsigned long long seed,
-                        bool profiling,
-                        Legion::Memory gpu_mem,
-                        Legion::Domain const &output_domain);
-  ~DropoutPerDeviceState(void);
-  Realm::RegionInstance reserveInst;
-  ffTensorDescriptor_t inputTensor, outputTensor;
+  PerDeviceFFHandle handle;
+  Allocator allocator;
+  ffTensorDescriptor_t inputTensor;
+  ffTensorDescriptor_t outputTensor;
   ffDropoutDescriptor_t dropoutDesc;
-  void *reserveSpace, *dropoutStates;
-  size_t reserveSpaceSize, dropoutStateSize;
+  void *reserveSpace;
+  void *dropoutStates;
+  size_t reserveSpaceSize;
+  req<size_t> dropoutStateSize;
 };
 
+FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(DropoutPerDeviceState,
+                                             handle,
+                                             allocator,
+                                             inputTensor,
+                                             outputTensor,
+                                             dropoutDesc,
+                                             reserveSpace,
+                                             dropoutStates,
+                                             reserveSpaceSize,
+                                             dropoutStateSize);
+
 namespace Kernels {
 namespace Dropout {
 
+DropoutPerDeviceState init_kernel(PerDeviceFFHandle handle,
+                                  float rate,
+                                  unsigned long long seed,
+                                  ArrayShape const &output_domain,
+                                  Allocator allocator);
+
 void forward_kernel(ffStream_t stream,
-                    DropoutPerDeviceState *m,
+                    DropoutPerDeviceState &m,
                     float const *input_ptr,
                     float *output_ptr);
 
 void backward_kernel(ffStream_t stream,
-                     DropoutPerDeviceState *m,
+                     DropoutPerDeviceState &m,
                      float const *output_grad_ptr,
                      float *input_grad_ptr);
 
+void cleanup_kernel(Allocator allocator,
+                    ffTensorDescriptor_t inputTensor,
+                    ffTensorDescriptor_t outputTensor,
+                    ffDropoutDescriptor_t dropoutDesc,
+                    void *dropoutStates);
+
 } // namespace Dropout
 } // namespace Kernels
 } // namespace FlexFlow
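
Similarly, a hypothetical call site for the Dropout kernels; the wrapper, rate, seed, and pointers below are assumptions, and only the declared signatures are taken from the header.

// Hypothetical usage of the new Dropout kernel interface (sketch only).
#include "kernels/dropout_kernels.h"

using namespace FlexFlow;

void dropout_example(ffStream_t stream,
                     PerDeviceFFHandle handle,
                     Allocator allocator,
                     ArrayShape const &output_shape,
                     float const *input,
                     float *output) {
  // The Legion::Memory / Legion::Domain constructor is gone; init_kernel now
  // takes an ArrayShape and an Allocator for the reserve/RNG-state buffers.
  DropoutPerDeviceState m = Kernels::Dropout::init_kernel(
      handle, /*rate=*/0.5f, /*seed=*/42ULL, output_shape, allocator);

  Kernels::Dropout::forward_kernel(stream, m, input, output);

  // Explicit teardown of the descriptors and the RNG state buffer.
  Kernels::Dropout::cleanup_kernel(allocator,
                                   m.inputTensor,
                                   m.outputTensor,
                                   m.dropoutDesc,
                                   m.dropoutStates);
}
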
lib/kernels/include/kernels/element_binary_kernels.h (61 changes: 40 additions & 21 deletions)
@@ -1,43 +1,62 @@
 #ifndef _FLEXFLOW_OPS_KERNELS_ELEMENT_BINARY_KERNELS_H
 #define _FLEXFLOW_OPS_KERNELS_ELEMENT_BINARY_KERNELS_H
 
+#include "ff_handle.h"
+#include "kernels/array_shape.h"
 #include "kernels/device.h"
+#include "op-attrs/datatype.h"
+#include "op-attrs/op.h"
 
 namespace FlexFlow {
 
-class ElementBinaryPerDeviceState : public PerDeviceOpState {
-public:
-  ElementBinaryPerDeviceState(FFHandler handle);
-  ffTensorDescriptor_t input1Tensor, input2Tensor, outputTensor;
+struct ElementBinaryPerDeviceState {
+  PerDeviceFFHandle handle;
+  ffTensorDescriptor_t inputLHSTensor;
+  ffTensorDescriptor_t inputRHSTensor;
+  ffTensorDescriptor_t outputTensor;
   ffOpTensorDescriptor_t opDesc;
   ffReduceTensorDescriptor_t reduceAddDesc;
-  OperatorType op_type;
-  bool inplace_a, has_same_operands;
-  bool broadcast_input1, broadcast_input2;
-  char op_name[MAX_OPNAME];
 };
 
+FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ElementBinaryPerDeviceState,
+                                             handle,
+                                             inputLHSTensor,
+                                             inputRHSTensor,
+                                             outputTensor,
+                                             opDesc,
+                                             reduceAddDesc);
+
 namespace Kernels {
 namespace ElementBinary {
 
-void init_kernel(ElementBinaryPerDeviceState *m,
-                 ArrayShape const &input1_domain,
-                 ArrayShape const &input2_domain,
-                 ArrayShape const &output_domain);
+ElementBinaryPerDeviceState init_kernel(PerDeviceFFHandle handle,
+                                        OperatorType op_type,
+                                        bool should_broadcast_lhs,
+                                        bool should_broadcast_rhs,
+                                        ArrayShape lhs_shape,
+                                        ArrayShape rhs_shape,
+                                        ArrayShape output_shape);
 
 void forward_kernel(ffStream_t stream,
-                    ElementBinaryPerDeviceState const *m,
-                    float const *in1_ptr,
-                    float const *in2_ptr,
-                    float *out_ptr);
+                    ElementBinaryPerDeviceState const &m,
+                    float const *lhs_ptr,
+                    float const *rhs_ptr,
+                    float *out_ptr,
+                    OperatorType op_type,
+                    bool broadcast_inputLHS,
+                    PerDeviceFFHandle handle);
 
 void backward_kernel(ffStream_t stream,
-                     ElementBinaryPerDeviceState const *m,
+                     ElementBinaryPerDeviceState const &m,
                      float const *out_grad_ptr,
-                     float const *in1_ptr,
-                     float const *in2_ptr,
-                     float *in1_grad_ptr,
-                     float *in2_grad_ptr);
+                     float const *lhs_ptr,
+                     float const *rhs_ptr,
+                     float *lhs_grad_ptr,
+                     float *rhs_grad_ptr,
+                     OperatorType op_type,
+                     bool broadcast_inputLHS,
+                     bool broadcast_inputRHS,
+                     PerDeviceFFHandle handle);
 
 } // namespace ElementBinary
 } // namespace Kernels
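
A hypothetical call site for the ElementBinary kernels; the wrapper and arguments are assumptions. Note that op_type, the LHS broadcast flag, and the handle are now passed at call time rather than read from the state struct:

// Hypothetical usage of the new ElementBinary kernel interface (sketch only).
#include "kernels/element_binary_kernels.h"

using namespace FlexFlow;

void element_binary_example(ffStream_t stream,
                            PerDeviceFFHandle handle,
                            OperatorType op_type, // e.g. an elementwise add
                            ArrayShape const &shape,
                            float const *lhs,
                            float const *rhs,
                            float *out) {
  // init_kernel builds and returns the descriptor state directly instead of
  // filling in a caller-provided ElementBinaryPerDeviceState *m.
  ElementBinaryPerDeviceState m = Kernels::ElementBinary::init_kernel(
      handle, op_type,
      /*should_broadcast_lhs=*/false, /*should_broadcast_rhs=*/false,
      shape, shape, shape);

  Kernels::ElementBinary::forward_kernel(stream, m, lhs, rhs, out, op_type,
                                         /*broadcast_inputLHS=*/false, handle);
}
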
lib/kernels/include/kernels/flat_kernels.h (16 changes: 5 additions & 11 deletions)
@@ -1,26 +1,20 @@
 #ifndef _FLEXFLOW_OPS_KERNELS_FLAT_KERNELS_H
 #define _FLEXFLOW_OPS_KERNELS_FLAT_KERNELS_H
 
+#include "kernels/accessor.h"
 #include "kernels/device.h"
 
 namespace FlexFlow {
 
-class FlatPerDeviceState : public PerDeviceOpState {
-public:
-  FlatPerDeviceState(FFHandler handle) : PerDeviceOpState(handle){};
-};
-
 namespace Kernels {
 namespace Flat {
 
 void forward_kernel(ffStream_t stream,
-                    float const *input_ptr,
-                    float *output_ptr,
-                    size_t num_elements);
+                    GenericTensorAccessorR input,
+                    float *output_ptr);
 void backward_kernel(ffStream_t stream,
+                     GenericTensorAccessorR input,
                      float *input_grad_ptr,
-                     float const *output_grad_ptr,
-                     size_t num_elements);
+                     float const *output_grad_ptr);
 
 } // namespace Flat
 } // namespace Kernels
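
Finally, a hypothetical call site for the now-stateless Flat kernels; the wrapper and pointers are assumptions. With FlatPerDeviceState deleted, the element count travels inside the input accessor's shape instead of a separate num_elements argument:

// Hypothetical usage of the stateless Flat kernels (sketch only).
#include "kernels/flat_kernels.h"

using namespace FlexFlow;

void flat_example(ffStream_t stream,
                  GenericTensorAccessorR const &input, // read-only input view
                  float *output,
                  float *input_grad,
                  float const *output_grad) {
  // No per-device state to initialize or clean up for Flat.
  Kernels::Flat::forward_kernel(stream, input, output);
  Kernels::Flat::backward_kernel(stream, input, input_grad, output_grad);
}
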