Skip to content

Commit

Permalink
Added more link/activation functions (tanh, softplus, leaky relu) for #4
Browse files Browse the repository at this point in the history
. More could be added, so it stays open.
  • Loading branch information
aromanro committed Mar 11, 2023
1 parent 1a3bcd5 commit 7c789f9
Showing 1 changed file with 145 additions and 0 deletions.
145 changes: 145 additions & 0 deletions MachineLearning/MachineLearning/LinkFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,88 @@ template<> class SigmoidFunction<double, double>
double beta;
};

template<typename InputOutputType, typename WeightsType> class TanhFunction
{
public:
TanhFunction(int size = 1)
{
}

const InputOutputType operator()(const InputOutputType& input) const
{
return input.tanh();
}

const InputOutputType derivative(const InputOutputType& input) const
{
const InputOutputType fx = operator()(input);

return 1. - fx.cwiseProduct(fx);
}
};

template<> class TanhFunction<double, double>
{
public:
TanhFunction()
{
}

const double operator()(const double& input) const
{
return tanh(input);
}

const double derivative(const double& input) const
{
const double fx = operator()(input);

return 1. - fx * fx;
}
};

template<typename InputOutputType, typename WeightsType> class SoftplusFunction
{
public:
SoftplusFunction(int size = 1)
{
}

const InputOutputType operator()(const InputOutputType& input) const
{
return input.exp() + 1.;
}

const InputOutputType derivative(const InputOutputType& input) const
{
const InputOutputType fx = operator()(-input);

return fx.cwiseInverse();
}
};

template<> class SoftplusFunction<double, double>
{
public:
SoftplusFunction()
{
}

const double operator()(const double& input) const
{
return 1. + exp(input);
}

const double derivative(const double& input) const
{
const double fx = operator()(-input);

return 1. / fx;
}
};



template<typename InputOutputType> class RELUFunction
{
public:
Expand Down Expand Up @@ -143,3 +225,66 @@ template<> class RELUFunction<double>
return (input < 0) ? 0 : 1;
}
};


template<typename InputOutputType> class LeakyRELUFunction
{
public:
LeakyRELUFunction()
{
}

void setParams(double a)
{
alpha = a;
}

const InputOutputType operator()(const InputOutputType& input) const
{
InputOutputType out = input;

for (unsigned int i = 0; i < out.size(); ++i)
out(i) *= (out(i) < 0) ? alpha : 1.;

return out;
}

const InputOutputType derivative(const InputOutputType& input) const
{
InputOutputType out = input;

for (unsigned int i = 0; i < out.size(); ++i)
out(i) = (out(i) < 0) ? alpha : 1.;

return out;
}

protected:
double alpha = 0.01;
};

template<> class LeakyRELUFunction<double>
{
public:
LeakyRELUFunction()
{
}

void setParams(double a)
{
alpha = a;
}

const double operator()(const double& input) const
{
return ((input < 0) ? alpha : 1.) * input;
}

const double derivative(const double& input) const
{
return (input < 0) ? alpha : 1.;
}

protected:
double alpha = 0.01;
};

0 comments on commit 7c789f9

Please sign in to comment.