Skip to content

Commit

Permalink
pass jit options correctly to overload
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Sep 26, 2024
1 parent 90eefda commit 1a61ca9
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 73 deletions.
8 changes: 6 additions & 2 deletions pfb/operators/gridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,9 @@ def image_data_products(dsl,
if ovar:
# scale the natural weights
# RHS is weight relative to unity since wgtp included in ressq
# print(np.mean(ressq[mask>0]/ovar), np.std(ressq[mask>0]/ovar))
meani = np.mean(ressq[mask>0]/ovar)
stdi = np.std(ressq[mask>0]/ovar)
print(f"Band {bandid} before: mean = {meani:.3e}, std = {stdi:.3e}")
# wgt_relative_one = (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar)
# wgt *= wgt_relative_one
wgt *= (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar)
Expand Down Expand Up @@ -471,7 +473,9 @@ def image_data_products(dsl,
ovar = ssq/mask.sum()
wgt /= ovar
ressq = (residual_vis*wgt*residual_vis.conj()).real
print(np.mean(ressq[mask>0]), np.std(ressq[mask>0]))
# Statistics of the weighted squared residuals computed just above.
# Report these values on the "after" line: the original formatted
# meani/stdi, which merely repeated the "before" numbers and are not
# defined here unless the earlier `if ovar:` branch ran.
meanf = np.mean(ressq[mask>0])
stdf = np.std(ressq[mask>0])
print(f"Band {bandid} after: mean = {meanf:.3e}, std = {stdf:.3e}")

# import ipdb; ipdb.set_trace()

Expand Down
93 changes: 28 additions & 65 deletions pfb/opt/pcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# import pyscilog
# log = pyscilog.get_logger('PCG')

@njit(**JIT_OPTIONS, parallel=True)
@njit(nogil=True, cache=False, parallel=False)
def update(x, xp, r, rp, p, Ap, alpha):
return update_impl(x, xp, r, rp, p, Ap, alpha)

Expand All @@ -26,7 +26,7 @@ def update_impl(x, xp, r, rp, p, Ap, alpha):
return NotImplementedError


@overload(update_impl, jit_options=JIT_OPTIONS, parallel=True)
@overload(update_impl, jit_options={**JIT_OPTIONS}) #, "parallel":True})
def nb_update_impl(x, xp, r, rp, p, Ap, alpha):
if x.ndim==3:
def impl(x, xp, r, rp, p, Ap, alpha):
Expand Down Expand Up @@ -60,46 +60,7 @@ def alpha_update_impl(r, y, p, Ap):
return NotImplementedError


@overload(alpha_update_impl, jit_options=JIT_OPTIONS, parallel=True)
def nb_alpha_update_impl(r, y, p, Ap):
    """Select the implementation of ``alpha_update_impl`` from ``r.ndim``.

    The returned ``impl(r, y, p, Ap)`` accumulates ``rnorm = sum(r * y)``
    and ``rnorm_den = sum(p * Ap)`` over all elements and returns
    ``(rnorm, alpha)`` with ``alpha = rnorm / rnorm_den``.

    Returns
    -------
    callable or None
        An implementation for 2-D or 3-D ``r``. ``None`` for any other
        dimensionality, so that an unsupported input is reported as "no
        matching implementation" rather than failing with an
        UnboundLocalError on ``impl``.
    """
    if r.ndim == 2:
        def impl(r, y, p, Ap):
            rnorm = 0.0
            rnorm_den = 0.0
            nx, ny = r.shape
            # Outer loop uses prange; both accumulators are sum reductions.
            for i in prange(nx):
                for j in range(ny):
                    rnorm += r[i, j]*y[i, j]
                    rnorm_den += p[i, j]*Ap[i, j]

            alpha = rnorm/rnorm_den
            return rnorm, alpha
    elif r.ndim == 3:
        def impl(r, y, p, Ap):
            rnorm = 0.0
            rnorm_den = 0.0
            nband, nx, ny = r.shape
            # Leading axis is walked serially; prange is applied to the
            # second axis only, exactly as in the original loop nest.
            for b in range(nband):
                for i in prange(nx):
                    for j in range(ny):
                        rnorm += r[b, i, j]*y[b, i, j]
                        rnorm_den += p[b, i, j]*Ap[b, i, j]

            alpha = rnorm/rnorm_den
            return rnorm, alpha
    else:
        # No implementation for this dimensionality.
        return None
    return impl


@njit(**JIT_OPTIONS, parallel=True)
def alpha_update(r, y, p, Ap):
return alpha_update_impl(r, y, p, Ap)


def alpha_update_impl(r, y, p, Ap):
    """Pure-Python placeholder that exists only as an ``@overload`` target.

    Raises
    ------
    NotImplementedError
        Always. The original returned the exception class instead of
        raising it, so an accidental pure-Python call would silently
        hand back ``NotImplementedError`` as if it were a result.
    """
    raise NotImplementedError


@overload(alpha_update_impl, jit_options=JIT_OPTIONS, parallel=True)
@overload(alpha_update_impl, jit_options={**JIT_OPTIONS, 'parallel':True})
def nb_alpha_update_impl(r, y, p, Ap):
if r.ndim==2:
def impl(r, y, p, Ap):
Expand Down Expand Up @@ -138,7 +99,7 @@ def beta_update_impl(r, y, p, rnorm):
return NotImplementedError


@overload(beta_update_impl, jit_options=JIT_OPTIONS, parallel=True)
@overload(beta_update_impl, jit_options={**JIT_OPTIONS, 'parallel':True})
def nb_beta_update_impl(r, y, p, rnorm):
if r.ndim==2:
def impl(r, y, p, rnorm):
Expand Down Expand Up @@ -231,21 +192,21 @@ def M(x): return x
ti = time()
# x = xp + alpha * p
# r = rp + alpha * Ap
ne.evaluate('xp + alpha*p',
out=x,
local_dict={
'xp': xp,
'alpha': alpha,
'p': p},
casting='unsafe')
ne.evaluate('rp + alpha*Ap',
out=r,
local_dict={
'rp': rp,
'alpha': alpha,
'Ap': Ap},
casting='unsafe')
# x, r = update(x, xp, r, rp, p, Ap, alpha)
# ne.evaluate('xp + alpha*p',
# out=x,
# local_dict={
# 'xp': xp,
# 'alpha': alpha,
# 'p': p},
# casting='unsafe')
# ne.evaluate('rp + alpha*Ap',
# out=r,
# local_dict={
# 'rp': rp,
# 'alpha': alpha,
# 'Ap': Ap},
# casting='unsafe')
x, r = update(x, xp, r, rp, p, Ap, alpha)
tupdate += (time() - ti)
y = M(r)

Expand All @@ -259,15 +220,17 @@ def M(x): return x
ti = time()
rnorm_next = np.vdot(r, y)
beta = rnorm_next / rnorm
ne.evaluate('beta*p-y',
out=p,
local_dict={
'beta': beta,
'p': p,
'y': y},
casting='unsafe')
# ne.evaluate('beta*p-y',
# out=p,
# local_dict={
# 'beta': beta,
# 'p': p,
# 'y': y},
# casting='unsafe')

# p = beta * p - y
p *= beta
p -= y
# rnorm, p = beta_update(r, y, p, rnorm)
tp += (time() - ti)
rnorm = rnorm_next
Expand Down
5 changes: 3 additions & 2 deletions pfb/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ def eval_coeffs_to_slice(time, freq, coeffs, Ix, Iy,
return image_in


@njit(**JIT_OPTIONS, parallel=True)
@njit(nogil=True, cache=True)
def norm_diff(x, xp):
return norm_diff_impl(x, xp)

Expand All @@ -1030,7 +1030,8 @@ def norm_diff_impl(x, xp):
return NotImplementedError


@overload(norm_diff_impl, jit_options=JIT_OPTIONS, parallel=True)
# @overload(norm_diff_impl, jit_options={**JIT_OPTIONS, "parallel":True})
@overload(norm_diff_impl, jit_options={**JIT_OPTIONS}) # parallel reduction slower?
def nb_norm_diff_impl(x, xp):
if x.ndim==3:
def impl(x, xp):
Expand Down
2 changes: 0 additions & 2 deletions pfb/utils/stokes2vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def single_stokes(
msid=None,
wid=None):

print(' Numba cache = ', numba.config.CACHE_DIR)

fieldid = ds.FIELD_ID
ddid = ds.DATA_DESC_ID
scanid = ds.SCAN_NUMBER
Expand Down
4 changes: 2 additions & 2 deletions pfb/utils/weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def filter_extreme_counts(counts, level=10.0):
return counts


@njit(**JIT_OPTIONS, parallel=True)
@njit(nogil=True, cache=False, parallel=False)
def weight_data(data, weight, flag, jones, tbin_idx, tbin_counts,
ant1, ant2, pol, product, nc):

Expand All @@ -217,7 +217,7 @@ def _weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts,
raise NotImplementedError


@overload(_weight_data_impl, **JIT_OPTIONS, parallel=True, prefer_literal=True)
@overload(_weight_data_impl, prefer_literal=True, jit_options={**JIT_OPTIONS, "parallel":True})
def nb_weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts,
ant1, ant2, pol, product, nc):

Expand Down

0 comments on commit 1a61ca9

Please sign in to comment.