Skip to content

Commit

Permalink
try at a dtype fix
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust committed Dec 3, 2024
1 parent d9252d0 commit 0dc31e3
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions sklearnex/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,22 +101,43 @@ def validate_data(
# _finite_keyword provides backward compatability for `force_all_finite`
ensure_all_finite = kwargs.pop("ensure_all_finite", True)
kwargs[_finite_keyword] = False
kwargs["validate_separately"] = True

out = _sklearn_validate_data(
_estimator,
X=X,
y=y,
**kwargs,
)

check_x = not isinstance(X, str) or X != "no_validation"
check_y = not (y is None or isinstance(y, str) and y == "no_validation")

if ensure_all_finite:
# run local finite check
allow_nan = ensure_all_finite == "allow-nan"
arg = iter(out if isinstance(out, tuple) else (out,))
if not isinstance(X, str) or X != "no_validation":
if check_x:
assert_all_finite(next(arg), allow_nan=allow_nan, input_name="X")
if not (y is None or isinstance(y, str) and y == "no_validation"):
if check_y:
assert_all_finite(next(arg), allow_nan=allow_nan, input_name="y")

if check_y and "dtype" in kwargs:
# validate_data does not do full dtype conversions, as it uses check_X_y
# oneDAL can make tables from [int32, int64, float32, float64], requiring
# a dtype check and conversion. This will query the array_namespace and
# convert y as necessary. This is done after assert_all_finite, because
# int y arrays do not need to finite check, and this will lead to a speedup
# in comparison to sklearn
dtype = kwargs["dtype"]
if not isinstance(dtype, (tuple, list)):
dtype = tuple(dtype)

outx, outy = out if check_x else (None, out)
if outy.dtype not in dtype:
yp, _ = get_namespace(outy)
outy = yp.astype(outy, dtype=dtype[0])
out = (outx, outy) if check_x else outy

return out


Expand Down

0 comments on commit 0dc31e3

Please sign in to comment.