Skip to content

Commit

Permalink
Applying review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderKalistratov committed Nov 28, 2024
1 parent 494b841 commit a424e41
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 50 deletions.
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/statistics/bincount.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@

#pragma once

#include <dpctl4pybind11.hpp>
#include <pybind11/pybind11.h>
#include <sycl/sycl.hpp>

#include "dispatch_table.hpp"
#include "dpctl4pybind11.hpp"

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

Expand Down
14 changes: 6 additions & 8 deletions dpnp/backend/extensions/statistics/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,16 @@ struct IsNan
static bool isnan(const T &v)
{
if constexpr (type_utils::is_complex_v<T>) {
const auto real1 = std::real(v);
const auto imag1 = std::imag(v);

using vT = typename T::value_type;

const vT real1 = std::real(v);
const vT imag1 = std::imag(v);

return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
}
else {
if constexpr (std::is_floating_point_v<T> ||
std::is_same_v<T, sycl::half>) {
return sycl::isnan(v);
}
else if constexpr (std::is_floating_point_v<T> ||
std::is_same_v<T, sycl::half>) {
return sycl::isnan(v);
}

return false;
Expand Down
2 changes: 2 additions & 0 deletions dpnp/backend/extensions/statistics/histogram.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@

#pragma once

#include <pybind11/pybind11.h>
#include <sycl/sycl.hpp>

#include "dispatch_table.hpp"
#include "dpctl4pybind11.hpp"

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

Expand Down
4 changes: 2 additions & 2 deletions dpnp/backend/extensions/statistics/histogramdd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ struct HistogramddF
};

template <typename T, typename HistType = size_t>
using HistogramddF2 = HistogramddF<T, T, HistType>;
using HistogramddF_ = HistogramddF<T, T, HistType>;

using SupportedTypes =
std::tuple<std::tuple<uint64_t, float>,
Expand All @@ -268,7 +268,7 @@ using SupportedTypes =

Histogramdd::Histogramdd() : dispatch_table("sample", "histogram")
{
dispatch_table.populate_dispatch_table<SupportedTypes, HistogramddF2>();
dispatch_table.populate_dispatch_table<SupportedTypes, HistogramddF_>();
}

std::tuple<sycl::event, sycl::event> Histogramdd::call(
Expand Down
4 changes: 2 additions & 2 deletions dpnp/backend/extensions/statistics/histogramdd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@

#pragma once

#include "dispatch_table.hpp"
#include <pybind11/pybind11.h>
#include <sycl/sycl.hpp>

// dpctl tensor headers
#include "dpctl4pybind11.hpp"

#include "dispatch_table.hpp"

namespace statistics
{
namespace histogram
Expand Down
50 changes: 14 additions & 36 deletions dpnp/dpnp_iface_histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,32 +235,6 @@ def _get_bin_edges(a, bins, range, usm_type):
return bin_edges, None


def _normalize_array(a, dtype, usm_type=None):
if usm_type is None:
usm_type = a.usm_type

try:
return dpnp.asarray(
a,
dtype=dtype,
usm_type=usm_type,
sycl_queue=a.sycl_queue,
order="C",
copy=False,
)
except ValueError:
pass

return dpnp.asarray(
a,
dtype=dtype,
usm_type=usm_type,
sycl_queue=a.sycl_queue,
order="C",
copy=True,
)


def _bincount_validate(x, weights, minlength):
if x.ndim > 1:
raise ValueError("object too deep for desired array")
Expand Down Expand Up @@ -426,16 +400,16 @@ def bincount(x, weights=None, minlength=None):
"supported types"
)

x_casted = _normalize_array(x, dtype=x_casted_dtype)
x_casted = dpnp.asarray(x, dtype=x_casted_dtype, order="C")

if weights is not None:
weights_casted = _normalize_array(weights, dtype=ntype_casted)
weights_casted = dpnp.asarray(weights, dtype=ntype_casted, order="C")

n_casted = _bincount_run_native(
x_casted, weights_casted, minlength, ntype_casted, usm_type
)

n = _normalize_array(n_casted, dtype=ntype, usm_type=usm_type)
n = dpnp.asarray(n_casted, dtype=ntype, usm_type=usm_type, order="C")

return n

Expand Down Expand Up @@ -643,10 +617,12 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
"supported types"
)

a_casted = _normalize_array(a, a_bin_dtype)
bin_edges_casted = _normalize_array(bin_edges, a_bin_dtype)
a_casted = dpnp.asarray(a, dtype=a_bin_dtype, order="C")
bin_edges_casted = dpnp.asarray(bin_edges, dtype=a_bin_dtype, order="C")
weights_casted = (
_normalize_array(weights, hist_dtype) if weights is not None else None
dpnp.asarray(weights, dtype=hist_dtype, order="C")
if weights is not None
else None
)

# histogram implementation uses atomics, but atomics doesn't work with
Expand Down Expand Up @@ -681,7 +657,7 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
)
_manager.add_event_pair(mem_ev, ht_ev)

n = _normalize_array(n_casted, dtype=ntype, usm_type=usm_type)
n = dpnp.asarray(n_casted, dtype=ntype, usm_type=usm_type, order="C")

if density:
db = dpnp.astype(
Expand Down Expand Up @@ -1055,9 +1031,11 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=False):

_histdd_check_monotonicity(bin_edges_view_list)

sample_ = _normalize_array(sample, sample_dtype)
sample_ = dpnp.asarray(sample, dtype=sample_dtype, order="C")
weights_ = (
_normalize_array(weights, hist_dtype) if weights is not None else None
dpnp.asarray(weights, dtype=hist_dtype, order="C")
if weights is not None
else None
)
n = _histdd_run_native(
sample_,
Expand All @@ -1069,7 +1047,7 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=False):
)

expexted_hist_dtype = _histdd_hist_dtype(queue, weights)
n = _normalize_array(n, expexted_hist_dtype, usm_type)
n = dpnp.asarray(n, dtype=expexted_hist_dtype, usm_type=usm_type, order="C")

if density:
# calculate the probability density function
Expand Down
7 changes: 6 additions & 1 deletion dpnp/tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,10 +642,15 @@ def test_linspace_data(self, dtype):
assert_array_equal(result_hist, expected_hist)

@pytest.mark.parametrize("xp", [numpy, dpnp])
def test_invalid_bin(self, xp):
def test_invalid_bin_float(self, xp):
a = xp.array([[1, 2]])
assert_raises(ValueError, xp.histogramdd, a, bins=0.1)

@pytest.mark.parametrize("xp", [numpy, dpnp])
def test_invalid_bin_2d_array(self, xp):
a = xp.array([[1, 2]])
assert_raises(ValueError, xp.histogramdd, a, bins=[[[10]]])

@pytest.mark.parametrize(
"bins",
[
Expand Down

0 comments on commit a424e41

Please sign in to comment.