diff --git a/code/ndarray.c b/code/ndarray.c index ae219d28..f51d1bd0 100644 --- a/code/ndarray.c +++ b/code/ndarray.c @@ -2084,24 +2084,51 @@ MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose); #if NDARRAY_HAS_RESHAPE mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) { ndarray_obj_t *source = MP_OBJ_TO_PTR(oin); - if(!mp_obj_is_type(_shape, &mp_type_tuple)) { - mp_raise_TypeError(translate("shape must be a tuple")); + if(!mp_obj_is_type(_shape, &mp_type_tuple) && !mp_obj_is_int(_shape)) { + mp_raise_TypeError(translate("shape must be integer or tuple of integers")); + } + + mp_obj_tuple_t *shape; + + if(mp_obj_is_int(_shape)) { + mp_obj_t *items = m_new(mp_obj_t, 1); + items[0] = _shape; + shape = mp_obj_new_tuple(1, items); + } else { + shape = MP_OBJ_TO_PTR(_shape); } - mp_obj_tuple_t *shape = MP_OBJ_TO_PTR(_shape); if(shape->len > ULAB_MAX_DIMS) { mp_raise_ValueError(translate("maximum number of dimensions is " MP_STRINGIFY(ULAB_MAX_DIMS))); } - size_t *new_shape = m_new0(size_t, ULAB_MAX_DIMS); size_t new_length = 1; - for(uint8_t i=0; i < shape->len; i++) { - new_shape[ULAB_MAX_DIMS - i - 1] = mp_obj_get_int(shape->items[shape->len - i - 1]); - new_length *= new_shape[ULAB_MAX_DIMS - i - 1]; + size_t *new_shape = m_new0(size_t, ULAB_MAX_DIMS); + uint8_t unknown_dim = 0; + uint8_t unknown_index = 0; + + for(uint8_t i = 0; i < shape->len; i++) { + int32_t ax_len = mp_obj_get_int(shape->items[shape->len - i - 1]); + if(ax_len >= 0) { + new_shape[ULAB_MAX_DIMS - i - 1] = (size_t)ax_len; + new_length *= new_shape[ULAB_MAX_DIMS - i - 1]; + } else { + unknown_dim++; + unknown_index = ULAB_MAX_DIMS - i - 1; + } + } + + if(unknown_dim > 1) { + mp_raise_ValueError(translate("can only specify one unknown dimension")); + } else if(unknown_dim == 1) { + new_shape[unknown_index] = source->len / new_length; + new_length = source->len; } + if(source->len != new_length) { - mp_raise_ValueError(translate("input and output shapes are not compatible")); + mp_raise_ValueError(translate("cannot reshape array")); } + ndarray_obj_t *ndarray; if(ndarray_is_dense(source)) { int32_t *new_strides = strides_from_shape(new_shape, source->dtype); @@ -2118,7 +2145,11 @@ mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) { if(inplace) { mp_raise_ValueError(translate("cannot assign new shape")); } - ndarray = ndarray_new_ndarray_from_tuple(shape, source->dtype); + if(mp_obj_is_type(_shape, &mp_type_tuple)) { + ndarray = ndarray_new_ndarray_from_tuple(shape, source->dtype); + } else { + ndarray = ndarray_new_linear_array(source->len, source->dtype); + } ndarray_copy_array(source, ndarray, 0); } return MP_OBJ_FROM_PTR(ndarray); diff --git a/code/ulab.c b/code/ulab.c index 02b7cb3a..a7842569 100644 --- a/code/ulab.c +++ b/code/ulab.c @@ -33,7 +33,7 @@ #include "user/user.h" #include "utils/utils.h" -#define ULAB_VERSION 6.0.10 +#define ULAB_VERSION 6.0.11 #define xstr(s) str(s) #define str(s) #s diff --git a/docs/ulab-change-log.md b/docs/ulab-change-log.md index efe55e7b..910055ff 100644 --- a/docs/ulab-change-log.md +++ b/docs/ulab-change-log.md @@ -1,5 +1,11 @@ Sat, 6 May 2023 +version 6.0.11 + + .reshape can now interpret unknown shape dimension + +Sat, 6 May 2023 + version 6.0.10 fix binary division