From 1a61ca91336fee40cda50f5ac9f017f2b9f13d05 Mon Sep 17 00:00:00 2001 From: landmanbester Date: Thu, 26 Sep 2024 10:07:46 +0200 Subject: [PATCH] pass jit options correctly to overload --- pfb/operators/gridder.py | 8 +++- pfb/opt/pcg.py | 93 ++++++++++++---------------------------- pfb/utils/misc.py | 5 ++- pfb/utils/stokes2vis.py | 2 - pfb/utils/weighting.py | 4 +- 5 files changed, 39 insertions(+), 73 deletions(-) diff --git a/pfb/operators/gridder.py b/pfb/operators/gridder.py index d91ca038..091de45f 100644 --- a/pfb/operators/gridder.py +++ b/pfb/operators/gridder.py @@ -431,7 +431,9 @@ def image_data_products(dsl, if ovar: # scale the natural weights # RHS is weight relative to unity since wgtp included in ressq - # print(np.mean(ressq[mask>0]/ovar), np.std(ressq[mask>0]/ovar)) + meani = np.mean(ressq[mask>0]/ovar) + stdi = np.std(ressq[mask>0]/ovar) + print(f"Band {bandid} before: mean = {meani:.3e}, std = {stdi:.3e}") # wgt_relative_one = (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar) # wgt *= wgt_relative_one wgt *= (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar) @@ -471,7 +473,9 @@ def image_data_products(dsl, ovar = ssq/mask.sum() wgt /= ovar ressq = (residual_vis*wgt*residual_vis.conj()).real - print(np.mean(ressq[mask>0]), np.std(ressq[mask>0])) + meanf = np.mean(ressq[mask>0]) + stdf = np.std(ressq[mask>0]) + print(f"Band {bandid} after: mean = {meanf:.3e}, std = {stdf:.3e}") # import ipdb; ipdb.set_trace() diff --git a/pfb/opt/pcg.py b/pfb/opt/pcg.py index c11ef656..5c4b77a5 100644 --- a/pfb/opt/pcg.py +++ b/pfb/opt/pcg.py @@ -17,7 +17,7 @@ # import pyscilog # log = pyscilog.get_logger('PCG') -@njit(**JIT_OPTIONS, parallel=True) +@njit(nogil=True, cache=False, parallel=False) def update(x, xp, r, rp, p, Ap, alpha): return update_impl(x, xp, r, rp, p, Ap, alpha) @@ -26,7 +26,7 @@ def update_impl(x, xp, r, rp, p, Ap, alpha): return NotImplementedError -@overload(update_impl, jit_options=JIT_OPTIONS, parallel=True) +@overload(update_impl, 
jit_options={**JIT_OPTIONS}) #, "parallel":True}) def nb_update_impl(x, xp, r, rp, p, Ap, alpha): if x.ndim==3: def impl(x, xp, r, rp, p, Ap, alpha): @@ -60,46 +60,7 @@ def alpha_update_impl(r, y, p, Ap): return NotImplementedError -@overload(alpha_update_impl, jit_options=JIT_OPTIONS, parallel=True) -def nb_alpha_update_impl(r, y, p, Ap): - if r.ndim==2: - def impl(r, y, p, Ap): - rnorm = 0.0 - rnorm_den = 0.0 - nx, ny = r.shape - for i in prange(nx): - for j in range(ny): - rnorm += r[i,j]*y[i,j] - rnorm_den += p[i,j]*Ap[i,j] - - alpha = rnorm/rnorm_den - return rnorm, alpha - elif r.ndim==3: - def impl(r, y, p, Ap): - rnorm = 0.0 - rnorm_den = 0.0 - nband, nx, ny = r.shape - for b in range(nband): - for i in prange(nx): - for j in range(ny): - rnorm += r[b,i,j]*y[b,i,j] - rnorm_den += p[b,i,j]*Ap[b,i,j] - - alpha = rnorm/rnorm_den - return rnorm, alpha - return impl - - -@njit(**JIT_OPTIONS, parallel=True) -def alpha_update(r, y, p, Ap): - return alpha_update_impl(r, y, p, Ap) - - -def alpha_update_impl(r, y, p, Ap): - return NotImplementedError - - -@overload(alpha_update_impl, jit_options=JIT_OPTIONS, parallel=True) +@overload(alpha_update_impl, jit_options={**JIT_OPTIONS, 'parallel':True}) def nb_alpha_update_impl(r, y, p, Ap): if r.ndim==2: def impl(r, y, p, Ap): @@ -138,7 +99,7 @@ def beta_update_impl(r, y, p, rnorm): return NotImplementedError -@overload(beta_update_impl, jit_options=JIT_OPTIONS, parallel=True) +@overload(beta_update_impl, jit_options={**JIT_OPTIONS, 'parallel':True}) def nb_beta_update_impl(r, y, p, rnorm): if r.ndim==2: def impl(r, y, p, rnorm): @@ -231,21 +192,21 @@ def M(x): return x ti = time() # x = xp + alpha * p # r = rp + alpha * Ap - ne.evaluate('xp + alpha*p', - out=x, - local_dict={ - 'xp': xp, - 'alpha': alpha, - 'p': p}, - casting='unsafe') - ne.evaluate('rp + alpha*Ap', - out=r, - local_dict={ - 'rp': rp, - 'alpha': alpha, - 'Ap': Ap}, - casting='unsafe') - # x, r = update(x, xp, r, rp, p, Ap, alpha) + # ne.evaluate('xp + 
alpha*p', + # out=x, + # local_dict={ + # 'xp': xp, + # 'alpha': alpha, + # 'p': p}, + # casting='unsafe') + # ne.evaluate('rp + alpha*Ap', + # out=r, + # local_dict={ + # 'rp': rp, + # 'alpha': alpha, + # 'Ap': Ap}, + # casting='unsafe') + x, r = update(x, xp, r, rp, p, Ap, alpha) tupdate += (time() - ti) y = M(r) @@ -259,15 +220,17 @@ def M(x): return x ti = time() rnorm_next = np.vdot(r, y) beta = rnorm_next / rnorm - ne.evaluate('beta*p-y', - out=p, - local_dict={ - 'beta': beta, - 'p': p, - 'y': y}, - casting='unsafe') + # ne.evaluate('beta*p-y', + # out=p, + # local_dict={ + # 'beta': beta, + # 'p': p, + # 'y': y}, + # casting='unsafe') # p = beta * p - y + p *= beta + p -= y # rnorm, p = beta_update(r, y, p, rnorm) tp += (time() - ti) rnorm = rnorm_next diff --git a/pfb/utils/misc.py b/pfb/utils/misc.py index 775c41cb..c26d5908 100644 --- a/pfb/utils/misc.py +++ b/pfb/utils/misc.py @@ -1021,7 +1021,7 @@ def eval_coeffs_to_slice(time, freq, coeffs, Ix, Iy, return image_in -@njit(**JIT_OPTIONS, parallel=True) +@njit(nogil=True, cache=True) def norm_diff(x, xp): return norm_diff_impl(x, xp) @@ -1030,7 +1030,8 @@ def norm_diff_impl(x, xp): return NotImplementedError -@overload(norm_diff_impl, jit_options=JIT_OPTIONS, parallel=True) +# @overload(norm_diff_impl, jit_options={**JIT_OPTIONS, "parallel":True}) +@overload(norm_diff_impl, jit_options={**JIT_OPTIONS}) # parallel reduction slower? 
def nb_norm_diff_impl(x, xp): if x.ndim==3: def impl(x, xp): diff --git a/pfb/utils/stokes2vis.py b/pfb/utils/stokes2vis.py index 4420745e..c1fe2e62 100644 --- a/pfb/utils/stokes2vis.py +++ b/pfb/utils/stokes2vis.py @@ -43,8 +43,6 @@ def single_stokes( msid=None, wid=None): - print(' Numba cache = ', numba.config.CACHE_DIR) - fieldid = ds.FIELD_ID ddid = ds.DATA_DESC_ID scanid = ds.SCAN_NUMBER diff --git a/pfb/utils/weighting.py b/pfb/utils/weighting.py index e5955bec..0fbb968d 100644 --- a/pfb/utils/weighting.py +++ b/pfb/utils/weighting.py @@ -198,7 +198,7 @@ def filter_extreme_counts(counts, level=10.0): return counts -@njit(**JIT_OPTIONS, parallel=True) +@njit(nogil=True, cache=False, parallel=False) def weight_data(data, weight, flag, jones, tbin_idx, tbin_counts, ant1, ant2, pol, product, nc): @@ -217,7 +217,7 @@ def _weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts, raise NotImplementedError -@overload(_weight_data_impl, **JIT_OPTIONS, parallel=True, prefer_literal=True) +@overload(_weight_data_impl, prefer_literal=True, jit_options={**JIT_OPTIONS, "parallel":True}) def nb_weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts, ant1, ant2, pol, product, nc):