Skip to content

Commit

Permalink
implement unknown shape dimension in reshape (#612)
Browse files Browse the repository at this point in the history
  • Loading branch information
v923z authored May 6, 2023
1 parent 412b13f commit beda4c1
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 10 deletions.
49 changes: 40 additions & 9 deletions code/ndarray.c
Original file line number Diff line number Diff line change
Expand Up @@ -2084,24 +2084,51 @@ MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose);
#if NDARRAY_HAS_RESHAPE
mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) {
ndarray_obj_t *source = MP_OBJ_TO_PTR(oin);
if(!mp_obj_is_type(_shape, &mp_type_tuple)) {
mp_raise_TypeError(translate("shape must be a tuple"));
if(!mp_obj_is_type(_shape, &mp_type_tuple) && !mp_obj_is_int(_shape)) {
mp_raise_TypeError(translate("shape must be integer or tuple of integers"));
}

mp_obj_tuple_t *shape;

if(mp_obj_is_int(_shape)) {
mp_obj_t *items = m_new(mp_obj_t, 1);
items[0] = _shape;
shape = mp_obj_new_tuple(1, items);
} else {
shape = MP_OBJ_TO_PTR(_shape);
}

mp_obj_tuple_t *shape = MP_OBJ_TO_PTR(_shape);
if(shape->len > ULAB_MAX_DIMS) {
mp_raise_ValueError(translate("maximum number of dimensions is " MP_STRINGIFY(ULAB_MAX_DIMS)));
}
size_t *new_shape = m_new0(size_t, ULAB_MAX_DIMS);

size_t new_length = 1;
for(uint8_t i=0; i < shape->len; i++) {
new_shape[ULAB_MAX_DIMS - i - 1] = mp_obj_get_int(shape->items[shape->len - i - 1]);
new_length *= new_shape[ULAB_MAX_DIMS - i - 1];
size_t *new_shape = m_new0(size_t, ULAB_MAX_DIMS);
uint8_t unknown_dim = 0;
uint8_t unknown_index = 0;

for(uint8_t i = 0; i < shape->len; i++) {
int32_t ax_len = mp_obj_get_int(shape->items[shape->len - i - 1]);
if(ax_len >= 0) {
new_shape[ULAB_MAX_DIMS - i - 1] = (size_t)ax_len;
new_length *= new_shape[ULAB_MAX_DIMS - i - 1];
} else {
unknown_dim++;
unknown_index = ULAB_MAX_DIMS - i - 1;
}
}

if(unknown_dim > 1) {
mp_raise_ValueError(translate("can only specify one unknown dimension"));
} else if(unknown_dim == 1) {
new_shape[unknown_index] = source->len / new_length;
new_length = source->len;
}

if(source->len != new_length) {
mp_raise_ValueError(translate("input and output shapes are not compatible"));
mp_raise_ValueError(translate("cannot reshape array"));
}

ndarray_obj_t *ndarray;
if(ndarray_is_dense(source)) {
int32_t *new_strides = strides_from_shape(new_shape, source->dtype);
Expand All @@ -2118,7 +2145,11 @@ mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) {
if(inplace) {
mp_raise_ValueError(translate("cannot assign new shape"));
}
ndarray = ndarray_new_ndarray_from_tuple(shape, source->dtype);
if(mp_obj_is_type(_shape, &mp_type_tuple)) {
ndarray = ndarray_new_ndarray_from_tuple(shape, source->dtype);
} else {
ndarray = ndarray_new_linear_array(source->len, source->dtype);
}
ndarray_copy_array(source, ndarray, 0);
}
return MP_OBJ_FROM_PTR(ndarray);
Expand Down
2 changes: 1 addition & 1 deletion code/ulab.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"

#define ULAB_VERSION 6.0.10
#define ULAB_VERSION 6.0.11
#define xstr(s) str(s)
#define str(s) #s

Expand Down
6 changes: 6 additions & 0 deletions docs/ulab-change-log.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
Sat, 6 May 2023

version 6.0.11

.reshape can now interpret unknown shape dimension

Sat, 6 May 2023

version 6.0.10

fix binary division
Expand Down

0 comments on commit beda4c1

Please sign in to comment.