diff --git a/phylanx/plugins/matrixops/triu_operation.hpp b/phylanx/plugins/matrixops/triu_operation.hpp index 8a643d315..3b44a318e 100644 --- a/phylanx/plugins/matrixops/triu_operation.hpp +++ b/phylanx/plugins/matrixops/triu_operation.hpp @@ -22,20 +22,19 @@ #include namespace phylanx { namespace execution_tree { namespace primitives { - /// \brief Return an N x M matrix with ones on the k-th diagonal and - /// zeros elsewhere. - /// \param N Number of rows in the output. - /// \param M Optional. Number of columns in the output. If None, defaults - /// to N. - /// \param k Optional. Index of the diagonal: 0 (the default) refers to the - /// main diagonal, a positive value refers to an upper diagonal, - /// and a negative value to a lower diagonal. - /// \param dtype Optional. The data-type of the returned array (default: - /// 'float') + class triu_operation : public primitive_component_base , public std::enable_shared_from_this { + + public: + enum tri_mode + { + tri_mode_up, // triu + tri_mode_low // tril + }; + protected: hpx::future eval( primitive_arguments_type const& operands, @@ -43,7 +42,7 @@ namespace phylanx { namespace execution_tree { namespace primitives { eval_context ctx) const override; public: - static match_pattern_type const match_data; + static std::vector const match_data; triu_operation() = default; @@ -51,22 +50,36 @@ namespace phylanx { namespace execution_tree { namespace primitives { std::string const& name, std::string const& codename); private: - template primitive_argument_type triu2d( - ir::node_data&& arg, std::int64_t k) const; + ir::node_data&& arg, std::int64_t k) const; primitive_argument_type triu2d( - primitive_argument_type&& arg, std::int64_t k) const; + primitive_argument_type&& arg, std::int64_t k) const; template primitive_argument_type triu3d( - ir::node_data&& arg, std::int64_t k) const; + ir::node_data&& arg, std::int64_t k) const; primitive_argument_type triu3d( - primitive_argument_type&& arg, std::int64_t k) const; + primitive_argument_type&& arg, std::int64_t k) const; + - + template + primitive_argument_type tril2d( + ir::node_data&& arg, std::int64_t k) const; + + primitive_argument_type tril2d( + primitive_argument_type&& arg, std::int64_t k) const; + + template + primitive_argument_type tril3d( + ir::node_data&& arg, std::int64_t k) const; + + primitive_argument_type tril3d( + primitive_argument_type&& arg, std::int64_t k) const; + + tri_mode mode_; }; inline primitive create_triu_operation(hpx::id_type const& locality, @@ -76,6 +89,14 @@ namespace phylanx { namespace execution_tree { namespace primitives { return create_primitive_component( locality, "triu", std::move(operands), name, codename); } + + inline primitive create_tril_operation(hpx::id_type const& locality, + primitive_arguments_type&& operands, std::string const& name = "", + std::string const& codename = "") + { + return create_primitive_component( + locality, "tril", std::move(operands), name, codename); + } }}} #endif diff --git a/src/plugins/matrixops/matrixops.cpp b/src/plugins/matrixops/matrixops.cpp index fbe07e968..578369499 100644 --- a/src/plugins/matrixops/matrixops.cpp +++ b/src/plugins/matrixops/matrixops.cpp @@ -131,7 +131,9 @@ PHYLANX_REGISTER_PLUGIN_FACTORY(tile_operation_plugin, PHYLANX_REGISTER_PLUGIN_FACTORY(transpose_operation_plugin, phylanx::execution_tree::primitives::transpose_operation::match_data); PHYLANX_REGISTER_PLUGIN_FACTORY(triu_operation_plugin, - phylanx::execution_tree::primitives::triu_operation::match_data); + phylanx::execution_tree::primitives::triu_operation::match_data[0]); +PHYLANX_REGISTER_PLUGIN_FACTORY(tril_operation_plugin, + phylanx::execution_tree::primitives::triu_operation::match_data[1]); PHYLANX_REGISTER_PLUGIN_FACTORY(tuple_slicing_operation_plugin, phylanx::execution_tree::primitives::slicing_operation::match_data[3]); PHYLANX_REGISTER_PLUGIN_FACTORY(unique_operation_plugin, diff --git a/src/plugins/matrixops/triu_operation.cpp b/src/plugins/matrixops/triu_operation.cpp index 829541433..2136f756c 100644 --- a/src/plugins/matrixops/triu_operation.cpp +++ b/src/plugins/matrixops/triu_operation.cpp @@ -8,10 +8,10 @@ #include #include +#include #include #include #include -#include #include #include @@ -27,14 +27,10 @@ namespace phylanx { namespace execution_tree { namespace primitives { /////////////////////////////////////////////////////////////////////////// - match_pattern_type const triu_operation::match_data = + std::vector const triu_operation::match_data = { match_pattern_type{"triu", - std::vector{R"( - triu(_1, - __arg(_2_k, 0) - ) - )"}, + std::vector{"triu(_1,_2)", "triu(_1)"}, &create_triu_operation, &create_primitive, R"( a, k Args: @@ -47,13 +43,41 @@ namespace phylanx { namespace execution_tree { namespace primitives Returns: Return a copy of an array with the elements below the k-th diagonal zeroed.)" - } + }, + + match_pattern_type{"tril", + std::vector{"tril(_1,_2)", "tril(_1)"}, + &create_triu_operation, &create_primitive, R"( + a, k + Args: + + a (array) : a matrix or a tensor + k (optional, integer) : index of the diagonal: 0 (the default) + refers to the main diagonal, a positive value refers to an + upper diagonal, and a negative value to a lower diagonal. + + Returns: + + Return a copy of an array with the elements above the k-th diagonal zeroed.)" + }, }; /////////////////////////////////////////////////////////////////////////// + triu_operation::tri_mode extract_tri_mode(std::string const& name) + { + triu_operation::tri_mode result = triu_operation::tri_mode_up; + + if (name.find("tril") != std::string::npos) + { + result = triu_operation::tri_mode_low; + } + return result; + } + triu_operation::triu_operation(primitive_arguments_type&& operands, std::string const& name, std::string const& codename) : primitive_component_base(std::move(operands), name, codename) + , mode_(extract_tri_mode(name_)) { } /////////////////////////////////////////////////////////////////////////// @@ -66,51 +90,43 @@ namespace phylanx { namespace execution_tree { namespace primitives std::int64_t columns = m.columns(); std::int64_t rows = m.rows(); - + if (!arg.is_ref()) { - - if (k >= columns) { m = static_cast(0); - return primitive_argument_type{ir::node_data{std::move(arg)}}; + return primitive_argument_type{ + ir::node_data{std::move(arg)}}; } - else if (k <= 1 - rows) { - - return primitive_argument_type{ir::node_data{std::move(arg)}}; + return primitive_argument_type{ + ir::node_data{std::move(arg)}}; } - - for (std::int64_t i = 1-rows; i != k ; ++i) + for (std::int64_t i = 1 - rows; i != k; ++i) { - blaze::band(m, i) = static_cast(0) ; + blaze::band(m, i) = static_cast(0); } return primitive_argument_type{ir::node_data{std::move(arg)}}; - } - blaze::DynamicMatrix result(rows, columns, static_cast(0)); if (k >= columns) { return primitive_argument_type{ir::node_data{std::move(result)}}; } - else if (k <= 1 - rows) { result = m; return primitive_argument_type{ir::node_data{std::move(result)}}; } - for (std::int64_t i = k; i != columns; ++i) { blaze::band(result, i) = blaze::band(m, i); } return primitive_argument_type{ir::node_data{std::move(result)}}; - } /////////////////////////////////////////////////////////////////////////// @@ -148,9 +164,8 @@ namespace phylanx { namespace execution_tree { namespace primitives "be numeric data types")); } - /////////////////////////////////////////////////////////////////////////// - - template + /////////////////////////////////////////////////////////////////////////// + template primitive_argument_type triu_operation::triu3d( ir::node_data&& arg, std::int64_t k) const @@ -160,57 +175,54 @@ namespace phylanx { namespace execution_tree { namespace primitives std::int64_t columns = t.columns(); std::int64_t rows = t.rows(); std::size_t pages = t.pages(); - - - // if (!arg.is_ref()) - // { - - - // if (k >= columns) - // { - // m = static_cast(0); - // return primitive_argument_type{ir::node_data{std::move(arg)}}; - // } - - // else if (k <= 1 - rows) - // { - - // return primitive_argument_type{ir::node_data{std::move(arg)}}; - // } - - // for (std::int64_t i = 1-rows; i != k ; ++i) - // { - // blaze::band(m, i) = static_cast(0) ; - // } - // return primitive_argument_type{ir::node_data{std::move(arg)}}; - - // } - - + + if (!arg.is_ref()) + { + if (k >= columns) + { + t = static_cast(0); + return primitive_argument_type{ + ir::node_data{std::move(arg)}}; + } + else if (k <= 1 - rows) + { + return primitive_argument_type{ + ir::node_data{std::move(arg)}}; + } + for (std::size_t p = 0; p != pages; ++p) + { + auto slice = blaze::pageslice(t, p); + + for (std::int64_t i = 1 - rows; i != k; ++i) + { + blaze::band(slice, i) = static_cast(0); + } + } + return primitive_argument_type{ir::node_data{std::move(arg)}}; + } + blaze::DynamicTensor result(pages, rows, columns, static_cast(0)); if (k >= columns) { return primitive_argument_type{ir::node_data{std::move(result)}}; } - else if (k <= 1 - rows) { result = t; return primitive_argument_type{ir::node_data{std::move(result)}}; } - for (std:: size_t p = 0; p!=pages; ++p) + for (std::size_t p = 0; p != pages; ++p) { - auto slice = blaze::pageslice(t,p); - auto result_slice = blaze::pageslice(result,p); + auto slice = blaze::pageslice(t, p); + auto result_slice = blaze::pageslice(result, p); - for (std::int64_t i = k; i != columns; ++i) - { - blaze::band(result_slice, i) = blaze::band(slice, i); - } + for (std::int64_t i = k; i != columns; ++i) + { + blaze::band(result_slice, i) = blaze::band(slice, i); + } } return primitive_argument_type{ir::node_data{std::move(result)}}; - } /////////////////////////////////////////////////////////////////////////// @@ -248,25 +260,204 @@ namespace phylanx { namespace execution_tree { namespace primitives "be numeric data types")); } + /////////////////////////////////////////////////////////////////////////// + template + primitive_argument_type triu_operation::tril2d( + ir::node_data&& arg, std::int64_t k) const + + { + auto m = arg.matrix(); + + std::int64_t columns = m.columns(); + std::int64_t rows = m.rows(); + + if (!arg.is_ref()) + { + if (k <= -rows) + { + m = static_cast(0); + return primitive_argument_type{ + ir::node_data{std::move(arg)}}; + } + else if (k >= columns - 1) + { + return primitive_argument_type{ + ir::node_data{std::move(arg)}}; + } + for (std::int64_t i = k + 1; i != columns; ++i) + { + blaze::band(m, i) = static_cast(0); + } + return primitive_argument_type{ir::node_data{std::move(arg)}}; + } + + blaze::DynamicMatrix result(rows, columns, static_cast(0)); + + if (k <= -rows) + { + return primitive_argument_type{ir::node_data{std::move(result)}}; + } + else if (k >= columns - 1) + { + result = m; + return primitive_argument_type{ir::node_data{std::move(result)}}; + } + for (std::int64_t i = k; i != -rows; --i) + { + blaze::band(result, i) = blaze::band(m, i); + } + return primitive_argument_type{ir::node_data{std::move(result)}}; + } + + /////////////////////////////////////////////////////////////////////////// + primitive_argument_type triu_operation::tril2d( + primitive_argument_type&& arg, std::int64_t k) const + { + switch (extract_common_type(arg)) + { + case node_data_type_bool: + return tril2d( + extract_boolean_value_strict(std::move(arg), name_, codename_), + k); + + case node_data_type_int64: + return tril2d( + extract_integer_value_strict(std::move(arg), name_, codename_), + k); + + case node_data_type_double: + return tril2d( + extract_numeric_value_strict(std::move(arg), name_, codename_), + k); + + case node_data_type_unknown: + return tril2d( + extract_numeric_value(std::move(arg), name_, codename_), k); + + default: + break; + } + + HPX_THROW_EXCEPTION(hpx::bad_parameter, "triu_operation::tril2d", + generate_error_message( + "the tri primitive requires for all arguments to " + "be numeric data types")); + } + + /////////////////////////////////////////////////////////////////////////// + template + primitive_argument_type triu_operation::tril3d( + ir::node_data&& arg, std::int64_t k) const + { + auto t = arg.tensor(); + + std::int64_t columns = t.columns(); + std::int64_t rows = t.rows(); + std::size_t pages = t.pages(); + + if (!arg.is_ref()) + { + if (k <= -rows) + { + t = static_cast(0); + return primitive_argument_type{ + ir::node_data{std::move(arg)}}; + } + else if (k >= columns - 1) + { + return primitive_argument_type{ + ir::node_data{std::move(arg)}}; + } + for (std::size_t p = 0; p != pages; ++p) + { + auto slice = blaze::pageslice(t, p); + + for (std::int64_t i = k + 1; i != columns; ++i) + { + blaze::band(slice, i) = static_cast(0); + } + } + return primitive_argument_type{ir::node_data{std::move(arg)}}; + } + + blaze::DynamicTensor result(pages, rows, columns, static_cast(0)); + + if (k <= -rows) + { + return primitive_argument_type{ir::node_data{std::move(result)}}; + } + else if (k >= columns - 1) + { + result = t; + return primitive_argument_type{ir::node_data{std::move(result)}}; + } + for (std::size_t p = 0; p != pages; ++p) + { + auto slice = blaze::pageslice(t, p); + auto result_slice = blaze::pageslice(result, p); + + for (std::int64_t i = k; i != -rows; --i) + { + blaze::band(result_slice, i) = blaze::band(slice, i); + } + } + return primitive_argument_type{ir::node_data{std::move(result)}}; + } + + /////////////////////////////////////////////////////////////////////////// + primitive_argument_type triu_operation::tril3d( + primitive_argument_type&& arg, std::int64_t k) const + { + switch (extract_common_type(arg)) + { + case node_data_type_bool: + return tril3d( + extract_boolean_value_strict(std::move(arg), name_, codename_), + k); + + case node_data_type_int64: + return tril3d( + extract_integer_value_strict(std::move(arg), name_, codename_), + k); + + case node_data_type_double: + return tril3d( + extract_numeric_value_strict(std::move(arg), name_, codename_), + k); + + case node_data_type_unknown: + return tril3d( + extract_numeric_value(std::move(arg), name_, codename_), k); + + default: + break; + } + + HPX_THROW_EXCEPTION(hpx::bad_parameter, "triu_operation::tril3d", + generate_error_message( + "the tril primitive requires for all arguments to " + "be numeric data types")); + } + /////////////////////////////////////////////////////////////////////////// hpx::future triu_operation::eval( primitive_arguments_type const& operands, primitive_arguments_type const& args, eval_context ctx) const { if (operands.empty() || operands.size() > 2) - { - HPX_THROW_EXCEPTION(hpx::bad_parameter, - "triu_operation::eval", - util::generate_error_message(hpx::util::format( - "the triu_operation primitive can have one or two operands " - "got {}", operands.size()), + { + HPX_THROW_EXCEPTION(hpx::bad_parameter, "triu_operation::eval", + util::generate_error_message( + hpx::util::format("the triu_operation primitive can have " + "one or two operands " + "got {}", + operands.size()), name_, codename_)); } if (!valid(operands[0])) { - HPX_THROW_EXCEPTION(hpx::bad_parameter, - "triu_operation::eval", + HPX_THROW_EXCEPTION(hpx::bad_parameter, "triu_operation::eval", util::generate_error_message( "the triu_operation primitive requires that the " "arguments given by the operands array are " @@ -274,14 +465,12 @@ namespace phylanx { namespace execution_tree { namespace primitives name_, codename_)); } - if (operands.size()==2 && !valid(operands[1])) + if (operands.size() == 2 && !valid(operands[1])) { - HPX_THROW_EXCEPTION(hpx::bad_parameter, - "triu_operation::eval", + HPX_THROW_EXCEPTION(hpx::bad_parameter, "triu_operation::eval", util::generate_error_message( "the triu_operation primitive requires that the " - "arguments given by the operands array are " - "valid", + "arguments given by the operands array are valid", name_, codename_)); } @@ -308,22 +497,41 @@ namespace phylanx { namespace execution_tree { namespace primitives k = extract_scalar_integer_value( std::move(args[1]), this_->name_, this_->codename_); } - + + if (this_->mode_ == tri_mode_up) + { + switch (extract_numeric_value_dimension( + args[0], this_->name_, this_->codename_)) + { + case 2: + return this_->triu2d(std::move(args[0]), k); + + case 3: + return this_->triu3d(std::move(args[0]), k); + + default: + HPX_THROW_EXCEPTION(hpx::bad_parameter, + "triu_operation::eval", + this_->generate_error_message( + "This operand has unsupported " + "number of dimensions")); + } + } switch (extract_numeric_value_dimension( args[0], this_->name_, this_->codename_)) { case 2: - return this_->triu2d(std::move(args[0]), k); + return this_->tril2d(std::move(args[0]), k); case 3: - return this_->triu3d(std::move(args[0]), k); + return this_->tril3d(std::move(args[0]), k); default: HPX_THROW_EXCEPTION(hpx::bad_parameter, - "diag_operation::eval", + "triu_operation::eval", this_->generate_error_message( - "left hand side operand has unsupported " + "This operand has unsupported " "number of dimensions")); } }), diff --git a/tests/unit/plugins/matrixops/triu_operation.cpp b/tests/unit/plugins/matrixops/triu_operation.cpp old mode 100755 new mode 100644 index 2631dc3c5..fefcc0d5b --- a/tests/unit/plugins/matrixops/triu_operation.cpp +++ b/tests/unit/plugins/matrixops/triu_operation.cpp @@ -34,66 +34,137 @@ void test_triu_operation(std::string const& code, /////////////////////////////////////////////////////////////////////////////// int main(int argc, char* argv[]) { - + // triu 2d test_triu_operation("triu([[13, 42, 33], [101, 12, 65]])", - "[[13, 42, 33], [ 0, 12, 65]]"); + "[[13, 42, 33], [0, 12, 65]]"); test_triu_operation("triu([[13, 42, 33],[101, 12, 65],[50, 60, 70],[21, 22, 23]])", - "[[13, 42, 33],[ 0, 12, 65],[ 0, 0, 70],[ 0, 0, 0]]"); + "[[13, 42, 33],[0, 12, 65],[0, 0, 70],[0, 0, 0]]"); test_triu_operation("triu([[13, 42, 33],[101, 12, 65],[50, 60, 70],[21, 22, 23]],-3)", "[[13, 42, 33],[101, 12, 65],[50, 60, 70],[21, 22, 23]]"); test_triu_operation("triu([[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]], 2)", - "[[0, 0, 87],[0, 0, 0],[0, 0, 0],[0, 0, 0]]"); + "[[0, 0, 87], [0, 0, 0], [0, 0, 0], [0, 0, 0]]"); test_triu_operation("triu([[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]], 3)", - "[[0, 0, 0],[0, 0, 0],[0, 0, 0],[0, 0, 0]]"); + "[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]"); test_triu_operation("triu([[13, 42, 33],[101, 12, 65]], 1)", "[[0, 42, 33],[0, 0, 65]]"); test_triu_operation("triu([[13, 42, 33],[101, 12, 65]], -2)", "[[13, 42, 33],[101, 12, 65]]"); - - - - test_triu_operation("triu([[[69, 65, 50],[111, 102, 85]],[[30, 42, 31],[1, 26, 55]]])", - "[[[69, 65, 50],[0, 102, 85]],[[30, 42, 31],[0, 26, 55]]]"); - test_triu_operation("triu([[[87, 60, 12],[101, 72, 62],[21, 64, 56]],[[45, 8, 12],[18, 99, 154],[10,32,98]]], 3)", - "[[[0, 0, 0],[0, 0, 0],[0, 0, 0]],[[0, 0, 0],[0, 0, 0],[0, 0, 0]]]"); - test_triu_operation("triu([[[87, 60, 12],[101, 72, 62],[21, 64, 56]],[[45, 8, 12],[18, 99, 154],[10,32,98]]], 2)", - "[[[0, 0, 12],[0, 0, 0],[0, 0, 0]],[[0, 0, 12],[0, 0, 0],[0, 0, 0]]]"); - test_triu_operation("triu([[[87, 60, 12],[101, 72, 62],[21, 64, 56]],[[45, 8, 12],[18, 99, 154],[10,32,98]]], -2)", - "[[[ 87, 60, 12],[101, 72, 62],[ 21, 64, 56]],[[ 45, 8, 12], [ 18, 99, 154],[ 10, 32, 98]]]"); - test_triu_operation("triu([[[87, 60, 12],[101, 72, 62],[21, 64, 56]],[[45, 8, 12],[18, 99, 154],[10,32,98]]], -1)", - "[[[ 87, 60, 12],[101, 72, 62],[ 0, 64, 56]],[[ 45, 8, 12], [ 18, 99, 154],[ 0, 32, 98]]]"); - - + // triu 3d + test_triu_operation("triu([[[69, 65, 50],[111, 102, 85]],[[30, 42, 31],[1, 26, 55]]] +0)", + "[[[69, 65, 50], [0, 102, 85]], [[30, 42, 31], [0, 26, 55]]]"); + test_triu_operation("triu([[[87, 60, 12],[101, 72, 62],[21, 64, 56]]," + "[[45, 8, 12],[18, 99, 154],[10,32,98]]] +0, 3)", + "[[[0, 0, 0],[0, 0, 0],[0, 0, 0]],[[0, 0, 0],[0, 0, 0],[0, 0, 0]]]"); + test_triu_operation("triu([[[87, 60, 12],[101, 72, 62],[21, 64, 56]]," + "[[45, 8, 12],[18, 99, 154],[10,32,98]]] +0, 2)", + "[[[0, 0, 12],[0, 0, 0],[0, 0, 0]],[[0, 0, 12],[0, 0, 0],[0, 0, 0]]]"); + test_triu_operation("triu([[[87, 60, 12],[101, 72, 62],[21, 64, 56]]," + "[[45, 8, 12],[18, 99, 154],[10,32,98]]] +0, -2)", + "[[[ 87, 60, 12],[101, 72, 62],[ 21, 64, 56]],[[ 45, 8, 12], " + "[ 18, 99, 154],[ 10, 32, 98]]]"); + test_triu_operation("triu([[[87, 60, 12],[101, 72, 62],[21, 64, 56]]," + "[[45, 8, 12],[18, 99, 154],[10,32,98]]] +0, -1)", + "[[[ 87, 60, 12],[101, 72, 62],[ 0, 64, 56]],[[ 45, 8, 12]," + "[18, 99, 154],[ 0, 32, 98]]]"); + test_triu_operation( + "triu([[[69, 65, 50],[111, 102, 85]],[[30, 42, 31],[1, 26, 55]]])", + "[[[69, 65, 50],[0, 102, 85]],[[30, 42, 31],[0, 26, 55]]]"); + test_triu_operation("triu([[[87, 60, 12],[101, 72, 62],[21, 64, 56]], " + "[[45, 8, 12],[18, 99, 154],[10,32,98]]], 3)", + "[[[0, 0, 0],[0, 0, 0],[0, 0, 0]],[[0, 0, 0],[0, 0, 0],[0, 0, 0]]]"); + test_triu_operation("triu([[[87, 60, 12],[101, 72, 62],[21, 64, 56]]," + "[[45, 8, 12],[18, 99, 154],[10,32,98]]], 2)", + "[[[0, 0, 12],[0, 0, 0],[0, 0, 0]],[[0, 0, 12],[0, 0, 0],[0, 0, 0]]]"); + test_triu_operation("triu([[[87, 60, 12],[101, 72, 62],[21, 64, 56]], " + "[[45, 8, 12],[18, 99, 154],[10,32,98]]], -2)", + "[[[ 87, 60, 12],[101, 72, 62],[ 21, 64, 56]],[[ 45, 8, 12], " + "[ 18, 99, 154],[ 10, 32, 98]]]"); + test_triu_operation("triu([[[87, 60, 12],[101, 72, 62],[21, 64, 56]]," + "[[45, 8, 12],[18, 99, 154],[10,32,98]]], -1)", + "[[[ 87, 60, 12],[101, 72, 62],[ 0, 64, 56]],[[ 45, 8, 12]," + "[ 18, 99, 154],[ 0, 32, 98]]]"); /////////////////////////////////////////////////////////////////////////// - - // test_triu_operation("tril([[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]])", - // "[[14, 0, 0],[10, 1, 0],[45, 79, 91],[24, 22, 31]]"); - // test_triu_operation("tril([[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]], 1)", - // "[[14, 20, 0],[10, 1, 25],[45, 79, 91],[24, 22, 31]]"); - // test_triu_operation("tril([[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]], 2)", - // "[[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]]"); - // test_triu_operation("tril([[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]], -2)", - // "[[ 0, 0, 0],[ 0, 0, 0],[45, 0, 0],[24, 22, 0]]"); - // test_triu_operation("tril([[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]], -4)", - // "[[0, 0, 0],[0, 0, 0],[0, 0, 0],[0, 0, 0]]"); + // tril 2d + test_triu_operation("tril([[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]])", + "[[14, 0, 0],[10, 1, 0],[45, 79, 91],[24, 22, 31]]"); + test_triu_operation("tril([[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]], 1)", + "[[14, 20, 0],[10, 1, 25],[45, 79, 91],[24, 22, 31]]"); + test_triu_operation("tril([[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]], 2)", + "[[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]]"); + test_triu_operation("tril([[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]], -2)", + "[[ 0, 0, 0],[ 0, 0, 0],[45, 0, 0],[24, 22, 0]]"); + test_triu_operation("tril([[14, 20, 87],[10, 1, 25],[45, 79, 91],[24, 22, 31]], -4)", + "[[0, 0, 0],[0, 0, 0],[0, 0, 0],[0, 0, 0]]"); - - // test_triu_operation("tril([[[96, 61, 7],[42, 34, 90]],[[52, 82, 38],[21, 9, 123]],[[201, 4, 87],[25, 87,99]]])", - // "[[[ 96, 0, 0],[ 42, 34, 0]],[[ 52, 0, 0],[ 21, 9, 0]],[[201, 0, 0],[ 25, 87, 0]]]"); - // test_triu_operation("tril([[[96, 61, 7],[42, 34, 90]],[[52, 82, 38],[21, 9, 123]],[[201, 4, 87],[25, 87,99]]], 2)", - // "[[[ 96, 61, 7],[ 42, 34, 90]],[[ 52, 82, 38],[ 21, 9, 123]],[[201, 4, 87],[ 25, 87, 99]]]"); - // test_triu_operation("tril([[[96, 61, 7],[42, 34, 90]],[[52, 82, 38],[21, 9, 123]],[[201, 4, 87],[25, 87,99]]], -1)", - // "[[[ 0, 0, 0],[42, 0, 0]],[[ 0, 0, 0],[21, 0, 0]],[[ 0, 0, 0],[25, 0, 0]]]"); - // test_triu_operation("tril([[[96, 61, 7],[42, 34, 90]],[[52, 82, 38],[21, 9, 123]],[[201, 4, 87],[25, 87,99]]], -2)", - // "[[[ 0, 0, 0],[0, 0, 0]],[[ 0, 0, 0],[0, 0, 0]],[[ 0, 0, 0],[0, 0, 0]]]"); - - - - + // tril 3d + test_triu_operation("tril([[[96, 61, 7],[42, 34, 90]],[[52, 82, 38]," + "[21, 9, 123]],[[201, 4, 87],[25, 87,99]]])", + "[[[ 96, 0, 0],[ 42, 34, 0]],[[ 52, 0, 0]," + "[ 21, 9, 0]],[[201, 0, 0],[ 25, 87, 0]]]"); + test_triu_operation("tril([[[96, 61, 7],[42, 34, 90]],[[52, 82, 38]," + "[21, 9, 123]],[[201, 4, 87],[25, 87,99]]], 2)", + "[[[ 96, 61, 7],[ 42, 34, 90]],[[ 52, 82, 38]," + "[ 21, 9, 123]],[[201, 4, 87],[ 25, 87, 99]]]"); + test_triu_operation("tril([[[96, 61, 7],[42, 34, 90]],[[52, 82, 38]," + "[21, 9, 123]],[[201, 4, 87],[25, 87,99]]], -1)", + "[[[ 0, 0, 0],[42, 0, 0]],[[ 0, 0, 0]," + "[21, 0, 0]],[[ 0, 0, 0],[25, 0, 0]]]"); + test_triu_operation("tril([[[96, 61, 7],[42, 34, 90]],[[52, 82, 38]," + "[21, 9, 123]],[[201, 4, 87],[25, 87,99]]], -2)", + "[[[ 0, 0, 0],[0, 0, 0]],[[ 0, 0, 0]," + "[0, 0, 0]],[[ 0, 0, 0],[0, 0, 0]]]"); + test_triu_operation( + "tril([[[96, 61, 7, 5],[42, 34, 90, 15],[20, 54, 101, 99]]," + "[[52, 82, 38, 2],[21, 9, 123, 78],[15, 65, 8, 1]]," + "[[201, 4, 87, 65],[25, 87,99, 30],[21, 2, 211, 60]]], 3)", + "[[[ 96, 61, 7, 5],[ 42, 34, 90, 15],[ 20, 54, 101, 99]]," + "[[ 52, 82, 38, 2],[ 21, 9, 123, 78],[ 15, 65, 8, 1]]," + "[[201, 4, 87, 65],[ 25, 87, 99, 30],[ 21, 2, 211, 60]]]"); + test_triu_operation( + "tril([[[96, 61, 7, 5],[42, 34, 90, 15],[20, 54, 101, 99]]," + "[[52, 82, 38, 2],[21, 9, 123, 78],[15, 65, 8, 1]]," + "[[201, 4, 87, 65],[25, 87,99, 30],[21, 2, 211, 60]]], 4)", + "[[[ 96, 61, 7, 5],[ 42, 34, 90, 15],[ 20, 54, 101, 99]]," + "[[ 52, 82, 38, 2],[ 21, 9, 123, 78],[ 15, 65, 8, 1]]," + "[[201, 4, 87, 65],[ 25, 87, 99, 30],[ 21, 2, 211, 60]]]"); + test_triu_operation( + "tril([[[96, 61, 7, 5],[42, 34, 90, 15],[20, 54, 101, 99]]," + "[[52, 82, 38, 2],[21, 9, 123, 78],[15, 65, 8, 1]]," + "[[201, 4, 87, 65],[25, 87,99, 30],[21, 2, 211, 60]]], 2)", + "[[[ 96, 61, 7, 0],[ 42, 34, 90, 15],[ 20, 54, 101, 99]]," + "[[ 52, 82, 38, 0],[ 21, 9, 123, 78],[ 15, 65, 8, 1]]," + "[[201, 4, 87, 0],[ 25, 87, 99, 30],[ 21, 2, 211, 60]]]"); + test_triu_operation( + "tril([[[96, 61, 7, 5],[42, 34, 90, 15],[20, 54, 101, 99]]," + "[[52, 82, 38, 2],[21, 9, 123, 78],[15, 65, 8, 1]]," + "[[201, 4, 87, 65],[25, 87,99, 30],[21, 2, 211, 60]]] )", + "[[[ 96, 0, 0, 0],[ 42, 34, 0, 0],[ 20, 54, 101, 0]]," + "[[ 52, 0, 0, 0],[ 21, 9, 0, 0],[ 15, 65, 8, 0]]," + "[[201, 0, 0, 0],[ 25, 87, 0, 0],[ 21, 2, 211, 0]]]"); + test_triu_operation( + "tril([[[96, 61, 7, 5],[42, 34, 90, 15],[20, 54, 101, 99]]," + "[[52, 82, 38, 2],[21, 9, 123, 78],[15, 65, 8, 1]]," + "[[201, 4, 87, 65],[25, 87,99, 30],[21, 2, 211, 60]]], -3)", + "[[[0, 0, 0, 0],[0, 0, 0, 0],[0, 0, 0, 0]]," + "[[0, 0, 0, 0],[0, 0, 0, 0],[0, 0, 0, 0]]," + "[[0, 0, 0, 0],[0, 0, 0, 0],[0, 0, 0, 0]]]"); + test_triu_operation( + "tril([[[96, 61, 7, 5],[42, 34, 90, 15],[20, 54, 101, 99]]," + "[[52, 82, 38, 2],[21, 9, 123, 78],[15, 65, 8, 1]]," + "[[201, 4, 87, 65],[25, 87,99, 30],[21, 2, 211, 60]]], -4)", + "[[[0, 0, 0, 0],[0, 0, 0, 0],[0, 0, 0, 0]]," + "[[0, 0, 0, 0],[0, 0, 0, 0],[0, 0, 0, 0]]," + "[[0, 0, 0, 0],[0, 0, 0, 0],[0, 0, 0, 0]]]"); + test_triu_operation( + "tril([[[96, 61, 7, 5],[42, 34, 90, 15],[20, 54, 101, 99]]," + "[[52, 82, 38, 2],[21, 9, 123, 78],[15, 65, 8, 1]]," + "[[201, 4, 87, 65],[25, 87,99, 30],[21, 2, 211, 60]]], -2)", + "[[[0, 0, 0, 0],[0, 0, 0, 0],[20, 0, 0, 0]]," + "[[0, 0, 0, 0],[0, 0, 0, 0],[15, 0, 0, 0]]," + "[[0, 0, 0, 0],[0, 0, 0, 0],[21, 0, 0, 0]]]"); hpx::finalize(); - return hpx::util::report_errors(); }