diff --git a/odetoolbox/system_of_shapes.py b/odetoolbox/system_of_shapes.py index 001a2f14..9e35e877 100644 --- a/odetoolbox/system_of_shapes.py +++ b/odetoolbox/system_of_shapes.py @@ -35,32 +35,27 @@ from .sympy_helpers import _custom_simplify_expr, _is_zero -def off_diagonal_is_zero(row: int, A) -> bool: - for col in range(A.shape[1]): - if col != row and not _is_zero(A[row, col]): - return False - return True - - def get_block_diagonal_blocks(A): assert A.shape[0] == A.shape[1], "matrix A should be square" N = A.shape[0] + + # A_mirrored = (A + A.T).applyfunc(_is_zero) # make the matrix symmetric so we only have to check one triangle + A_mirrored = (A + A.T) != 0 # make the matrix symmetric so we only have to check one triangle - A_mirrored = A + A.T # make the matrix symmetric so we only have to check one triangle + graph_components = scipy.sparse.csgraph.connected_components(A_mirrored)[1] - blocks = [] - start = 0 - blocksize = 0 + assert all(np.diff(graph_components) >= 0), "Matrix is not ordered" - while start < N: - blocksize += 1 - - if np.all(A_mirrored[start:start + blocksize, start + blocksize:N] == 0): - block = A[start:start + blocksize, start:start + blocksize] - blocks.append(block) - start += blocksize - blocksize = 0 + blocks = [] + for i in np.unique(graph_components): + idx = np.where(graph_components == i)[0] + assert all(np.diff(idx) > 0) + assert len(idx) == 1 or (len(np.unique(np.diff(idx))) == 1 and np.unique(np.diff(idx))[0] == 1) + idx_min = np.amin(idx) + idx_max = np.amax(idx) + block = A[idx_min:idx_max + 1, idx_min:idx_max + 1] + blocks.append(block) return blocks