Skip to content

Commit

Permalink
Cleanup around paralleization and testing
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Apr 7, 2023
1 parent f5ffcc8 commit 4e0fdee
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 57 deletions.
36 changes: 18 additions & 18 deletions inferelator_velocity/metrics/circcorrcoef.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,13 @@ def circular_rank_correlation(

n = radian_array.shape[1]

slices = list(
gen_even_slices(
n,
effective_n_jobs(n_jobs)
)
)

if n_jobs != 1:
slices = list(
gen_even_slices(
n,
effective_n_jobs(n_jobs)
)
)
views = Parallel(n_jobs=n_jobs)(
delayed(_circcorrcoef_array)(
radian_array,
Expand All @@ -63,11 +62,11 @@ def circular_rank_correlation(
for i, c in zip(slices, views):
corr[:, i] = c

return corr

else:
return _circcorrcoef_array(radian_array)

return corr


def _circcorrcoef_array(
X,
Expand Down Expand Up @@ -112,13 +111,6 @@ def _rank_circular_array(

n = X.shape[1]

slices = list(
gen_even_slices(
n,
effective_n_jobs(n_jobs)
)
)

def _array_apply(x_sub):
return np.apply_along_axis(
_radian_rank_vector,
Expand All @@ -127,6 +119,14 @@ def _array_apply(x_sub):
)

if n_jobs != 1:

slices = list(
gen_even_slices(
n,
effective_n_jobs(n_jobs)
)
)

views = Parallel(n_jobs=n_jobs)(
delayed(_array_apply)(
X[:, i]
Expand All @@ -142,11 +142,11 @@ def _array_apply(x_sub):
for i, r in zip(slices, views):
rad_array[:, i] = r

return rad_array

else:
return _array_apply(X)

return rad_array


def _rank_vector(x):

Expand Down
76 changes: 49 additions & 27 deletions inferelator_velocity/metrics/information.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,25 +111,34 @@ def mutual_information(
m = x.shape[1]
n = x.shape[1] if y is None else y.shape[1]

slices = list(gen_even_slices(n, effective_n_jobs(n_jobs)))
if n_jobs != 1:
slices = list(gen_even_slices(n, effective_n_jobs(n_jobs)))

views = Parallel(n_jobs=n_jobs)(
delayed(_mi_slice)(
x,
bins,
y_slicer=i,
y=y,
logtype=logtype
)
for i in slices
)

mutual_info = np.empty((m, n), dtype=float)

for i, r in zip(slices, views):
mutual_info[:, i] = r

return mutual_info

views = Parallel(n_jobs=n_jobs)(
delayed(_mi_slice)(
else:
return _mi_slice(
x,
i,
bins,
y=y,
logtype=logtype
)
for i in slices
)

mutual_info = np.empty((m, n), dtype=float)

for i, r in zip(slices, views):
mutual_info[:, i] = r

return mutual_info


def _shannon_entropy(
Expand Down Expand Up @@ -159,23 +168,31 @@ def _shannon_entropy(

m, n = discrete_array.shape

slices = list(gen_even_slices(n, effective_n_jobs(n_jobs)))
if n_jobs != 1:
slices = list(gen_even_slices(n, effective_n_jobs(n_jobs)))

views = Parallel(n_jobs=n_jobs)(
delayed(_entropy_slice)(
discrete_array[:, i],
bins,
logtype=logtype
views = Parallel(n_jobs=n_jobs)(
delayed(_entropy_slice)(
discrete_array[:, i],
bins,
logtype=logtype
)
for i in slices
)
for i in slices
)

entropy = np.empty(n, dtype=float)
entropy = np.empty(n, dtype=float)

for i, r in zip(slices, views):
entropy[i] = r
for i, r in zip(slices, views):
entropy[i] = r

return entropy
return entropy

else:
return _entropy_slice(
discrete_array,
bins,
logtype=logtype
)


def _entropy_slice(
Expand All @@ -186,15 +203,20 @@ def _entropy_slice(

def _entropy(vec):
px = np.bincount(vec, minlength=bins) / vec.size
return -1 * np.nansum(px * logtype(px))
log_px = logtype(
px,
out=np.full_like(px, np.nan),
where=px > 0
)
return -1 * np.nansum(px * log_px)

return np.apply_along_axis(_entropy, 0, x)


def _mi_slice(
x,
y_slicer,
bins,
y_slicer=slice(None),
y=None,
logtype=np.log
):
Expand Down
53 changes: 45 additions & 8 deletions inferelator_velocity/tests/test_programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@

class TestProgramMetrics(unittest.TestCase):

n_jobs = 1

def test_binning(self):

expr = _make_array_discrete(EXPRESSION, BINS)
Expand All @@ -65,14 +67,23 @@ def test_binning(self):
def test_entropy(self):

expr = _make_array_discrete(EXPRESSION, BINS)
entropy = _shannon_entropy(expr, 10, logtype=np.log2)
entropy = _shannon_entropy(
expr,
10,
logtype=np.log2,
n_jobs=self.n_jobs
)

print(entropy)
self.assertTrue(np.all(entropy >= 0))
npt.assert_almost_equal(entropy[4], np.log2(BINS))
npt.assert_almost_equal(entropy[3], 0.)

entropy = _shannon_entropy(expr, 10, logtype=np.log)
entropy = _shannon_entropy(
expr,
10,
logtype=np.log,
n_jobs=self.n_jobs
)

self.assertTrue(np.all(entropy >= 0))
npt.assert_almost_equal(entropy[3], 0.)
Expand All @@ -82,8 +93,18 @@ def test_mutual_info(self):

expr = _make_array_discrete(EXPRESSION, BINS)

entropy = _shannon_entropy(expr, 10, logtype=np.log2)
mi = mutual_information(expr, 10, logtype=np.log2)
entropy = _shannon_entropy(
expr,
10,
logtype=np.log2,
n_jobs=self.n_jobs
)
mi = mutual_information(
expr,
10,
logtype=np.log2,
n_jobs=self.n_jobs
)

self.assertTrue(np.all(mi >= 0))
npt.assert_array_equal(mi[:, 3], np.zeros_like(mi[:, 3]))
Expand All @@ -94,8 +115,18 @@ def test_info_distance(self):

expr = _make_array_discrete(EXPRESSION, BINS)

entropy = _shannon_entropy(expr, 10, logtype=np.log2)
mi = mutual_information(expr, 10, logtype=np.log2)
entropy = _shannon_entropy(
expr,
10,
logtype=np.log2,
n_jobs=self.n_jobs
)
mi = mutual_information(
expr,
10,
logtype=np.log2,
n_jobs=self.n_jobs
)

with np.errstate(divide='ignore', invalid='ignore'):
calc_dist = 1 - mi / (entropy[:, None] + entropy[None, :] - mi)
Expand All @@ -105,7 +136,8 @@ def test_info_distance(self):
expr,
BINS,
logtype=np.log2,
return_information=True
return_information=True,
n_jobs=self.n_jobs
)

self.assertTrue(np.all(i_dist >= 0))
Expand All @@ -117,6 +149,11 @@ def test_info_distance(self):
)


class TestProgramMetricsParallel(TestProgramMetrics):

n_jobs = 2


class TestProgram(unittest.TestCase):

def test_find_program(self):
Expand Down
5 changes: 1 addition & 4 deletions inferelator_velocity/tests/test_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ def test_times(self):
verbose=True
)

print(EXPR)
print(LAB)

self.assertListEqual(
[0, 0.5, 1.],
[times[v] for k, v in {'a': 2, 'b': 5, 'c': 9}.items()]
Expand All @@ -50,4 +47,4 @@ def test_times(self):
class TestTimeFunctions(unittest.TestCase):

def test_wrap_time(self):
pass
pass

0 comments on commit 4e0fdee

Please sign in to comment.