Skip to content

Commit

Permalink
added tril_operation
Browse files Browse the repository at this point in the history
  • Loading branch information
Karame committed May 4, 2021
1 parent 059d402 commit 8d40f2d
Show file tree
Hide file tree
Showing 4 changed files with 445 additions and 143 deletions.
55 changes: 38 additions & 17 deletions phylanx/plugins/matrixops/triu_operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,51 +22,64 @@
#include <vector>

namespace phylanx { namespace execution_tree { namespace primitives {
/// \brief Return an N x M matrix with ones on the k-th diagonal and
/// zeros elsewhere.
/// \param N Number of rows in the output.
/// \param M Optional. Number of columns in the output. If None, defaults
/// to N.
/// \param k Optional. Index of the diagonal: 0 (the default) refers to the
/// main diagonal, a positive value refers to an upper diagonal,
/// and a negative value to a lower diagonal.
/// \param dtype Optional. The data-type of the returned array (default:
/// 'float')

class triu_operation
: public primitive_component_base
, public std::enable_shared_from_this<triu_operation>
{

public:
enum tri_mode
{
tri_mode_up, // triu
tri_mode_low // tril
};

protected:
hpx::future<primitive_argument_type> eval(
primitive_arguments_type const& operands,
primitive_arguments_type const& args,
eval_context ctx) const override;

public:
static match_pattern_type const match_data;
static std::vector<match_pattern_type> const match_data;

triu_operation() = default;

triu_operation(primitive_arguments_type&& operands,
std::string const& name, std::string const& codename);

private:

template <typename T>
primitive_argument_type triu2d(
ir::node_data<T>&& arg, std::int64_t k) const;
ir::node_data<T>&& arg, std::int64_t k) const;

primitive_argument_type triu2d(
primitive_argument_type&& arg, std::int64_t k) const;
primitive_argument_type&& arg, std::int64_t k) const;

template <typename T>
primitive_argument_type triu3d(
ir::node_data<T>&& arg, std::int64_t k) const;
ir::node_data<T>&& arg, std::int64_t k) const;

primitive_argument_type triu3d(
primitive_argument_type&& arg, std::int64_t k) const;
primitive_argument_type&& arg, std::int64_t k) const;



template <typename T>
primitive_argument_type tril2d(
ir::node_data<T>&& arg, std::int64_t k) const;

primitive_argument_type tril2d(
primitive_argument_type&& arg, std::int64_t k) const;

template <typename T>
primitive_argument_type tril3d(
ir::node_data<T>&& arg, std::int64_t k) const;

primitive_argument_type tril3d(
primitive_argument_type&& arg, std::int64_t k) const;

tri_mode mode_;
};

inline primitive create_triu_operation(hpx::id_type const& locality,
Expand All @@ -76,6 +89,14 @@ namespace phylanx { namespace execution_tree { namespace primitives {
return create_primitive_component(
locality, "triu", std::move(operands), name, codename);
}

inline primitive create_tril_operation(hpx::id_type const& locality,
primitive_arguments_type&& operands, std::string const& name = "",
std::string const& codename = "")
{
return create_primitive_component(
locality, "tril", std::move(operands), name, codename);
}
}}}

#endif
4 changes: 3 additions & 1 deletion src/plugins/matrixops/matrixops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ PHYLANX_REGISTER_PLUGIN_FACTORY(tile_operation_plugin,
PHYLANX_REGISTER_PLUGIN_FACTORY(transpose_operation_plugin,
phylanx::execution_tree::primitives::transpose_operation::match_data);
PHYLANX_REGISTER_PLUGIN_FACTORY(triu_operation_plugin,
phylanx::execution_tree::primitives::triu_operation::match_data);
phylanx::execution_tree::primitives::triu_operation::match_data[0]);
PHYLANX_REGISTER_PLUGIN_FACTORY(tril_operation_plugin,
phylanx::execution_tree::primitives::triu_operation::match_data[1]);
PHYLANX_REGISTER_PLUGIN_FACTORY(tuple_slicing_operation_plugin,
phylanx::execution_tree::primitives::slicing_operation::match_data[3]);
PHYLANX_REGISTER_PLUGIN_FACTORY(unique_operation_plugin,
Expand Down
Loading

0 comments on commit 8d40f2d

Please sign in to comment.