From 66bfb1f4a63f35c33527b3374495279ae3c5c88d Mon Sep 17 00:00:00 2001 From: Benedikt Daurer Date: Tue, 5 Mar 2024 14:11:30 +0000 Subject: [PATCH] code restructure, less switching --- ptypy/engines/ML.py | 54 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/ptypy/engines/ML.py b/ptypy/engines/ML.py index 1c2496b34..d9d99b9b7 100644 --- a/ptypy/engines/ML.py +++ b/ptypy/engines/ML.py @@ -294,17 +294,6 @@ def engine_iterate(self, num=1): t2 = time.time() if self.p.all_line_coeffs: B = self.ML_model.poly_line_all_coeffs(self.ob_h, self.pr_h) - else: - B = self.ML_model.poly_line_coeffs(self.ob_h, self.pr_h) - tc += time.time() - t2 - - if np.isinf(B).any() or np.isnan(B).any(): - logger.warning( - 'Warning! inf or nan found! Trying to continue...') - B[np.isinf(B)] = 0. - B[np.isnan(B)] = 0. - - if self.p.all_line_coeffs: diffB = np.arange(1,len(B))*B[1:] # coefficients of poly derivative roots = np.roots(np.flip(diffB.astype(np.double))) # roots only supports double real_roots = np.real(roots[np.isreal(roots)]) # not interested in complex roots @@ -313,8 +302,13 @@ def engine_iterate(self, num=1): else: # find real root with smallest poly objective evalp = lambda root: np.polyval(np.flip(B),root) self.tmin = dt(min(real_roots, key=evalp)) # root with smallest poly objective - else: # same as above but quicker when poly quadratic + else: + B = self.ML_model.poly_line_coeffs(self.ob_h, self.pr_h) + # same as above but quicker when poly quadratic self.tmin = dt(-0.5 * B[1] / B[2]) + + tc += time.time() - t2 + self.ob_h *= self.tmin self.pr_h *= self.tmin self.ob += self.ob_h @@ -614,6 +608,12 @@ def poly_line_coeffs(self, ob_h, pr_h): B += Brenorm * self.regularizer.poly_line_coeffs( ob_h.storages[name].data, s.data) + if np.isinf(B).any() or np.isnan(B).any(): + logger.warning( + 'Warning! inf or nan found! Trying to continue...') + B[np.isinf(B)] = 0. + B[np.isnan(B)] = 0. + self.B = B return B @@ -693,6 +693,12 @@ def poly_line_all_coeffs(self, ob_h, pr_h): B[:3] += Brenorm * self.regularizer.poly_line_coeffs( ob_h.storages[name].data, s.data) + if np.isinf(B).any() or np.isnan(B).any(): + logger.warning( + 'Warning! inf or nan found! Trying to continue...') + B[np.isinf(B)] = 0. + B[np.isnan(B)] = 0. + self.B = B return B @@ -844,6 +850,12 @@ def poly_line_coeffs(self, ob_h, pr_h): B += Brenorm * self.regularizer.poly_line_coeffs( ob_h.storages[name].data, s.data) + if np.isinf(B).any() or np.isnan(B).any(): + logger.warning( + 'Warning! inf or nan found! Trying to continue...') + B[np.isinf(B)] = 0. + B[np.isnan(B)] = 0. + self.B = B return B @@ -922,6 +934,12 @@ def poly_line_all_coeffs(self, ob_h, pr_h): B[:3] += Brenorm * self.regularizer.poly_line_coeffs( ob_h.storages[name].data, s.data) + if np.isinf(B).any() or np.isnan(B).any(): + logger.warning( + 'Warning! inf or nan found! Trying to continue...') + B[np.isinf(B)] = 0. + B[np.isnan(B)] = 0. + self.B = B return B @@ -1088,6 +1106,12 @@ def poly_line_coeffs(self, ob_h, pr_h): B += Brenorm * self.regularizer.poly_line_coeffs( ob_h.storages[name].data, s.data) + if np.isinf(B).any() or np.isnan(B).any(): + logger.warning( + 'Warning! inf or nan found! Trying to continue...') + B[np.isinf(B)] = 0. + B[np.isnan(B)] = 0. + self.B = B return B @@ -1172,6 +1196,12 @@ def poly_line_all_coeffs(self, ob_h, pr_h): B[:3] += Brenorm * self.regularizer.poly_line_coeffs( ob_h.storages[name].data, s.data) + if np.isinf(B).any() or np.isnan(B).any(): + logger.warning( + 'Warning! inf or nan found! Trying to continue...') + B[np.isinf(B)] = 0. + B[np.isnan(B)] = 0. + self.B = B return B