Skip to content

Commit

Permalink
Skip zero entries when creating a sparse matrix from numpy (#160) (th…
Browse files Browse the repository at this point in the history
…anks @alebugariu)
  • Loading branch information
volkm authored Mar 19, 2024
2 parents 8ab4121 + c4ced65 commit 6c983bc
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 35 deletions.
45 changes: 19 additions & 26 deletions lib/stormpy/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,53 +6,45 @@
def build_sparse_matrix(array, row_group_indices=[]):
"""
Build a sparse matrix from numpy array.
Zero entries are skipped.
:param numpy array: The array.
:param List[double] row_group_indices: List containing the starting row of each row group in ascending order.
:return: Sparse matrix.
"""

num_row = array.shape[0]
num_col = array.shape[1]

len_group_indices = len(row_group_indices)
if len_group_indices > 0:
builder = storage.SparseMatrixBuilder(rows=num_row, columns=num_col, has_custom_row_grouping=True,
row_groups=len_group_indices)
else:
builder = storage.SparseMatrixBuilder(rows=num_row, columns=num_col)

row_group_index = 0
for r in range(num_row):
# check whether to start a custom row group
if row_group_index < len_group_indices and r == row_group_indices[row_group_index]:
builder.new_row_group(r)
row_group_index += 1
# insert values of the current row
for c in range(num_col):
builder.add_next_value(r, c, array[r, c])

return builder.build()
return _build_sparse_matrix(storage.SparseMatrixBuilder, array, row_group_indices=row_group_indices)


def build_parametric_sparse_matrix(array, row_group_indices=[]):
"""
Build a sparse matrix from numpy array.
Zero entries are skipped.
:param numpy array: The array.
:param List[double] row_group_indices: List containing the starting row of each row group in ascending order.
:return: Parametric sparse matrix.
"""
return _build_sparse_matrix(storage.ParametricSparseMatrixBuilder, array, row_group_indices=row_group_indices)


def _build_sparse_matrix(builder_class, array, row_group_indices=[]):
"""
General method to build a sparse matrix from numpy array.
Zero entries are skipped.
:param class builder_class: The class type used to create the matrix builder.
:param numpy array: The array.
:param List[double] row_group_indices: List containing the starting row of each row group in ascending order.
:return: Sparse matrix.
"""
num_row = array.shape[0]
num_col = array.shape[1]

len_group_indices = len(row_group_indices)
if len_group_indices > 0:
builder = storage.ParametricSparseMatrixBuilder(rows=num_row, columns=num_col, has_custom_row_grouping=True,
row_groups=len_group_indices)
builder = builder_class(rows=num_row, columns=num_col, has_custom_row_grouping=True, row_groups=len_group_indices)
else:
builder = storage.ParametricSparseMatrixBuilder(rows=num_row, columns=num_col)
builder = builder_class(rows=num_row, columns=num_col)

row_group_index = 0
for r in range(num_row):
Expand All @@ -62,7 +54,8 @@ def build_parametric_sparse_matrix(array, row_group_indices=[]):
row_group_index += 1
# insert values of the current row
for c in range(num_col):
builder.add_next_value(r, c, array[r, c])
if array[r, c] != 0:
builder.add_next_value(r, c, array[r, c])

return builder.build()

Expand Down
40 changes: 31 additions & 9 deletions tests/storage/test_matrix_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,17 +211,39 @@ def test_matrix_from_numpy(self):
# Check matrix dimension
assert matrix.nr_rows == array.shape[0]
assert matrix.nr_columns == array.shape[1]
assert matrix.nr_entries == 8
assert matrix.nr_entries == 7

# Check matrix values
for r in range(array.shape[1]):
row = matrix.get_row(r)
for e in row:
assert (e.value() == array[r, e.column])
assert e.value() == array[r, e.column]

@numpy_avail
def test_matrix_from_numpy_zeros(self):
import numpy as np
array = np.array([[0, 0, 1, 0],
[0.1, 0, 0, 0.9],
[0, 0, 0, 0],
[1, 0, 0, 0]], dtype='float64')

matrix = stormpy.build_sparse_matrix(array)

# Check matrix dimension
assert matrix.nr_rows == array.shape[0]
assert matrix.nr_columns == array.shape[1]
assert matrix.nr_entries == 4

# Check matrix values
for r in range(array.shape[1]):
row = matrix.get_row(r)
for e in row:
assert e.value() == array[r, e.column]

@numpy_avail
def test_parametric_matrix_from_numpy(self):
import numpy as np
zero = stormpy.RationalRF(0)
one_pol = stormpy.RationalRF(1)
one_pol = stormpy.FactorizedPolynomial(one_pol)
first_val = stormpy.FactorizedRationalFunction(one_pol, one_pol)
Expand All @@ -231,22 +253,22 @@ def test_parametric_matrix_from_numpy(self):
third_val = stormpy.FactorizedRationalFunction(one_pol, two_pol)

array = np.array([[sec_val, first_val],
[first_val, sec_val],
[sec_val, sec_val],
[first_val, zero],
[0, sec_val],
[third_val, third_val]])

matrix = stormpy.build_parametric_sparse_matrix(array)

# Check matrix dimension
assert matrix.nr_rows == array.shape[0]
assert matrix.nr_columns == array.shape[1]
assert matrix.nr_entries == 8
assert matrix.nr_entries == 6

# Check matrix values
for r in range(array.shape[1]):
row = matrix.get_row(r)
for e in row:
assert (e.value() == array[r, e.column])
assert e.value() == array[r, e.column]

@numpy_avail
def test_matrix_from_numpy_row_grouping(self):
Expand All @@ -261,13 +283,13 @@ def test_matrix_from_numpy_row_grouping(self):
# Check matrix dimension
assert matrix.nr_rows == array.shape[0]
assert matrix.nr_columns == array.shape[1]
assert matrix.nr_entries == 8
assert matrix.nr_entries == 7

# Check matrix values
for r in range(array.shape[1]):
row = matrix.get_row(r)
for e in row:
assert (e.value() == array[r, e.column])
assert e.value() == array[r, e.column]

# Check row groups
assert matrix.get_row_group_start(0) == 1
Expand Down Expand Up @@ -303,7 +325,7 @@ def test_parametric_matrix_from_numpy_row_grouping(self):
for r in range(array.shape[1]):
row = matrix.get_row(r)
for e in row:
assert (e.value() == array[r, e.column])
assert e.value() == array[r, e.column]

# Check row groups
assert matrix.get_row_group_start(0) == 1
Expand Down

0 comments on commit 6c983bc

Please sign in to comment.