diff --git a/ptypy/engines/ML.py b/ptypy/engines/ML.py index d9d99b9b7..9d5bb7d1a 100644 --- a/ptypy/engines/ML.py +++ b/ptypy/engines/ML.py @@ -100,10 +100,11 @@ class ML(PositionCorrectionEngine): lowlim = 0 help = Number of iterations before probe update starts - [all_line_coeffs] - default = False - type = bool - help = Whether to use all nine coefficients in the linesearch instead of three + [poly_line_coeffs] + default = quadratic + type = str + help = How many coefficients to be used in the the linesearch + doc = choose between the 'quadratic' approximation (default) or 'all' """ @@ -292,7 +293,7 @@ def engine_iterate(self, num=1): # In principle, the way things are now programmed this part # could be iterated over in a real Newton-Raphson style. t2 = time.time() - if self.p.all_line_coeffs: + if self.p.poly_line_coeffs == "all": B = self.ML_model.poly_line_all_coeffs(self.ob_h, self.pr_h) diffB = np.arange(1,len(B))*B[1:] # coefficients of poly derivative roots = np.roots(np.flip(diffB.astype(np.double))) # roots only supports double @@ -302,10 +303,12 @@ def engine_iterate(self, num=1): else: # find real root with smallest poly objective evalp = lambda root: np.polyval(np.flip(B),root) self.tmin = dt(min(real_roots, key=evalp)) # root with smallest poly objective - else: + elif self.p.poly_line_coeffs == "quadratic": B = self.ML_model.poly_line_coeffs(self.ob_h, self.pr_h) # same as above but quicker when poly quadratic self.tmin = dt(-0.5 * B[1] / B[2]) + else: + raise NotImplementedError("poly_line_coeffs should be 'quadratic' or 'all'") tc += time.time() - t2