Skip to content

Commit

Permalink
code restructure, less switching
Browse files Browse the repository at this point in the history
  • Loading branch information
daurer committed Mar 5, 2024
1 parent 8c4ca44 commit 66bfb1f
Showing 1 changed file with 42 additions and 12 deletions.
54 changes: 42 additions & 12 deletions ptypy/engines/ML.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,17 +294,6 @@ def engine_iterate(self, num=1):
t2 = time.time()
if self.p.all_line_coeffs:
B = self.ML_model.poly_line_all_coeffs(self.ob_h, self.pr_h)
else:
B = self.ML_model.poly_line_coeffs(self.ob_h, self.pr_h)
tc += time.time() - t2

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

if self.p.all_line_coeffs:
diffB = np.arange(1,len(B))*B[1:] # coefficients of poly derivative
roots = np.roots(np.flip(diffB.astype(np.double))) # roots only supports double
real_roots = np.real(roots[np.isreal(roots)]) # not interested in complex roots
Expand All @@ -313,8 +302,13 @@ def engine_iterate(self, num=1):
else: # find real root with smallest poly objective
evalp = lambda root: np.polyval(np.flip(B),root)
self.tmin = dt(min(real_roots, key=evalp)) # root with smallest poly objective
else: # same as above but quicker when poly quadratic
else:
B = self.ML_model.poly_line_coeffs(self.ob_h, self.pr_h)
# same as above but quicker when poly quadratic
self.tmin = dt(-0.5 * B[1] / B[2])

tc += time.time() - t2

self.ob_h *= self.tmin
self.pr_h *= self.tmin
self.ob += self.ob_h
Expand Down Expand Up @@ -614,6 +608,12 @@ def poly_line_coeffs(self, ob_h, pr_h):
B += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.B = B

return B
Expand Down Expand Up @@ -693,6 +693,12 @@ def poly_line_all_coeffs(self, ob_h, pr_h):
B[:3] += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.B = B

return B
Expand Down Expand Up @@ -844,6 +850,12 @@ def poly_line_coeffs(self, ob_h, pr_h):
B += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.B = B

return B
Expand Down Expand Up @@ -922,6 +934,12 @@ def poly_line_all_coeffs(self, ob_h, pr_h):
B[:3] += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.B = B

return B
Expand Down Expand Up @@ -1088,6 +1106,12 @@ def poly_line_coeffs(self, ob_h, pr_h):
B += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.B = B

return B
Expand Down Expand Up @@ -1172,6 +1196,12 @@ def poly_line_all_coeffs(self, ob_h, pr_h):
B[:3] += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.B = B

return B
Expand Down

0 comments on commit 66bfb1f

Please sign in to comment.