-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsgk.py
626 lines (522 loc) · 20 KB
/
sgk.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
import math
import numbers

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pynbody as pb
import scipy
from fast_histogram import histogram2d
from scipy import optimize
from scipy import signal
from scipy import stats
from scipy.interpolate import CubicSpline, griddata
from scipy.ndimage import gaussian_filter
from shapely.geometry import LineString

# Module metadata: authorship, licence and release information.
__author__ = "Steven Gough-Kelly"
__copyright__ = "Copyright 2022, UCLan Galaxy Dynamics"
__credits__ = ["Steven Gough-Kelly", "Victor P. Debattista", "Stuart R. Anderson"]
__license__ = "MIT"
__version__ = "1.0.1"
__maintainer__ = "Steven Gough-Kelly"
# NOTE(review): this value looks like a web-page email-protection placeholder
# rather than a real address — confirm the intended contact email.
__email__ = "[email protected]"
__status__ = "Production"
def bar_align(galaxy, rbar, barfrac = 0.5, zlim=0.5, log=False):
    """
    Align the bar of a pynbody galaxy simulation with the x-axis.

    The galaxy disc is assumed to already lie in the XY plane (e.g. aligned
    using the inertia tensor). The 2D inertia tensor of the star particles
    with cylindrical radius below ``rbar*barfrac`` and ``|z| < zlim`` is
    diagonalised. The eigenvector with the smallest eigenvalue is taken as
    the bar's major axis, and the galaxy is rotated about z so that this
    axis lies along x.

    The function does not return anything: ``galaxy.rotate_z`` acts on the
    simulation object itself, so the rotation is applied to the object
    passed in as ``galaxy``.

    Parameters
    ----------
    galaxy : pynbody simulation object
        Galaxy object in the XY plane to be aligned.
    rbar : float
        Bar radius in simulation units e.g. kpc. NaN means the bar is
        undefined, in which case a 1 kpc selection radius is used.
    barfrac : float
        Dimensionless fraction of ``rbar`` within which the inertia tensor is
        calculated. If ``rbar*barfrac`` is below 1, a 1 kpc selection radius
        is used instead.
    zlim : float
        Vertical limit to calculate the inertia tensor within, in simulation
        units e.g. kpc. Useful in galaxies with thick discs and weak bars.
    log : bool
        Flag to output print statements.

    Returns
    -------
    None
    """
    if np.isnan(rbar):
        if log:
            print('* Bar undefined, using 1 kpc *')
        rbar = 1.0
    elif rbar*barfrac < 1.:
        rbar = 1.0
        if log:
            print('* Short Bar, using 1 kpc *')
    else:
        rbar = rbar*barfrac
        if log:
            print('* Bar defined, aligning to {} kpc *'.format(rbar))
    if log:
        print('* Realigning bar using |z| < {} *'.format(zlim))

    zfilt = pb.filt.LowPass('z', zlim) & pb.filt.HighPass('z', -zlim)
    rfilt = pb.filt.LowPass('rxy', rbar)
    # Evaluate the selection once and reuse the same star subset for both
    # positions and masses.
    stars = galaxy[zfilt & rfilt].star
    pos = np.array(stars['pos'].in_units('kpc'))
    x, y = pos[:, 0], pos[:, 1]
    m = np.array(stars['mass'])

    # 2D inertia tensor about the z-axis: diagonal terms sum(m*y^2) and
    # sum(m*x^2), off-diagonal terms -sum(m*x*y).
    I_yy, I_xx, I_xy = np.sum(m*y**2), np.sum(m*x**2), np.sum(m*x*y)
    I = np.array([[I_yy, -I_xy], [-I_xy, I_xx]])

    eigenvalues, eigenvectors = np.linalg.eig(I)
    # Smallest moment of inertia -> direction along which mass is most
    # extended, i.e. the bar major axis.
    maj_axis = eigenvectors[:, eigenvalues.argmin()]

    # Angle of the major axis from +x; rotating by its negative puts the bar
    # on the x-axis.
    r_angle = np.degrees(np.arctan2(maj_axis[1], maj_axis[0]))
    galaxy.rotate_z(-r_angle)
    if log:
        print('* Bar realigned by {} degrees*'.format(r_angle))
    return None
def CDF(arr,bins=10_000,rng=None, norm=True):
    """
    Produce a cumulative distribution function of an input array.

    A large number of bins produces curves which step up for each element
    when len(arr) < bins. For len(arr) > bins the default number of bins is
    still large enough to produce smooth lines. NaN values are ignored.

    Parameters
    ----------
    arr : list or numpy array of int or float
        Variable to be binned.
    bins : int
        Number of bins.
    rng : tuple or list of two elements
        Lower and upper limits over which to calculate the CDF. Defaults to
        the NaN-ignoring minimum and maximum of ``arr``.
    norm : bool
        If True the CDF is normalised to run up to 1; otherwise it is the raw
        cumulative count (up to max(cumsum)).

    Returns
    -------
    x : the upper edge of each bin, i.e. the value up to which the matching
        ``cumulative`` entry counts elements.
    cumulative : the cumulative sum values
    """
    if rng is None:
        rng = (np.nanmin(arr), np.nanmax(arr))
    values, edges = np.histogram(arr, bins=bins, range=rng)
    cumulative = np.cumsum(values)
    if norm:
        # A cumsum of non-negative counts is non-decreasing, so the last
        # entry is the maximum.
        cumulative = cumulative/cumulative[-1]
    # cumulative[i] counts everything up to the right-hand edge of bin i.
    x = edges[1:]
    return x, cumulative
def distmod(app_mag=None,abs_mag=None,dist=None,extinc=None,output=None):
    """
    Conversions through the distance modulus equation between apparent
    magnitude, absolute magnitude and distance, with optional extinction.

    Uses m - M = 5*(log10(d) + 2) + A with d in kpc (equivalent to
    5*log10(d_pc) - 5 + A).

    Parameters
    ----------
    app_mag : float, list or ndarray of float
        apparent magnitude(s) of source(s) in mag.
    abs_mag : float, list or ndarray of float
        absolute magnitude(s) of source(s) in mag.
    dist : float, list or ndarray of float
        distance(s) of source(s) in kpc.
    extinc : float or ndarray of float, optional
        extinction value(s) in relevant band of source(s) in mag. Treated as
        zero when omitted. Not used when output='extinc'.
    output : str
        changes appropriate output. Also changes required inputs.
        can take values:
        * 'abs' : calculate the absolute magnitude.
            Requires app_mag and dist.
        * 'app' : calculate the apparent magnitude.
            Requires abs_mag and dist.
        * 'dist' : calculate the distance.
            Requires app_mag and abs_mag.
        * 'extinc' : calculate the extinction.
            Requires app_mag, abs_mag and dist.

    Returns
    -------
    calculation: float or ndarray of float
        dependent on the output parameter. If output is not recognised a
        message is printed and None is returned.
    """
    def _zero_extinction(mag):
        # Zero extinction shaped like ``mag``; works for scalars (0-d array),
        # lists and ndarrays, unlike np.zeros(len(mag)) which fails on floats.
        return np.zeros_like(mag, dtype=float)

    def app2abs(app_mag=None,dist=None,extinc=None):
        if extinc is None:
            extinc = _zero_extinction(app_mag)
        return app_mag - extinc - 5*(np.log10(dist)+2)

    def abs2app(abs_mag=None,dist=None,extinc=None):
        if extinc is None:
            extinc = _zero_extinction(abs_mag)
        return abs_mag + extinc + 5*(np.log10(dist)+2)

    def mag2dist(app_mag=None,abs_mag=None,extinc=None):
        if extinc is None:
            extinc = _zero_extinction(app_mag)
        return 10**(((app_mag - abs_mag - extinc)/5)-2)

    def magdist2extinc(app_mag=None,abs_mag=None,dist=None):
        return app_mag - 5*(np.log10(dist)+2) - abs_mag

    if output == 'abs':
        return app2abs(app_mag=app_mag,dist=dist,extinc=extinc)
    elif output == 'app':
        return abs2app(abs_mag=abs_mag,dist=dist,extinc=extinc)
    elif output == 'dist':
        return mag2dist(app_mag=app_mag,abs_mag=abs_mag,extinc=extinc)
    elif output == 'extinc':
        return magdist2extinc(app_mag=app_mag,abs_mag=abs_mag,dist=dist)
    print('output type not found: '+ \
          'Options "abs","app","dist" and "extinc"')
    return None
def Gauss_Hermite(w, n):
    """
    Return the normalised Gauss-Hermite function of order n evaluated at w.

    Gerhard MNRAS (1993) 265, 213-230, Equations 3.1 - 3.7.

    Parameters
    ----------
    w : list or numpy array of int or float
        Standardised velocities/positions at which to evaluate the function.
    n : int
        Order of the Gauss-Hermite function (non-negative).

    Returns
    -------
    numpy array of the Gauss-Hermite function of order n for the input w.
    """
    w = np.array(w)
    p = scipy.special.hermite(n, monic=False)  # Hermite polynomial as poly1d
    # N_n normalisation (Eqn 3.1). math.factorial is used because the
    # np.math alias was removed in NumPy 2.0.
    norm = np.sqrt((2**(n+1))*np.pi*math.factorial(n))
    return (p(w)/norm) * np.exp(-0.5 * w * w)
def GaussHermiteMoment(v, n):
    """
    Calculate the Gauss-Hermite moment of order n for an input distribution.

    Non-finite values are discarded. The remaining values are standardised
    (mean 0, standard deviation 1) before projecting onto the Gauss-Hermite
    function of order n.

    Parameters
    ----------
    v : list or numpy array of int or float
        input distribution of velocities/positions.
    n : int
        nth order of the Gauss Hermite polynomial.

    Returns
    -------
    float Gauss-Hermite moment of order n for the input distribution, or NaN
    when fewer than two finite values remain or all values are identical.
    """
    v = np.asarray(v, dtype=float)  # accept lists as documented
    v = v[np.isfinite(v)]  # remove nans & inf
    # Added SL speed+error catch when used in binned_statistic
    if len(v) <= 1:
        return np.nan
    sigma = np.std(v)
    if sigma == 0:
        # Standardising would divide by zero (0/0 -> NaN); return NaN directly.
        return np.nan
    v_dash = (v - np.mean(v))/sigma  # centre on 0, normalise width to 1 sigma
    hn = np.sum(Gauss_Hermite(v_dash, n))
    return np.sqrt(4*np.pi) * hn / len(v)
def linint(x,y,intercept,axis,pout=False):
    """
    Find the intercept point between an arbitrary line and a horizontal or
    vertical slice.

    Note this does not work when a slice intersects the line multiple times
    (shapely then returns a multi-part geometry without ``.x``/``.y``).
    Most commonly used with cumulative distribution functions.

    Parameters
    ----------
    x : list or array
        x coordinates of line.
    y : list or array
        y coordinates of line.
    intercept : float or int
        x or y axis position of the slice to calculate the intercept for.
    axis : str
        Takes values of 'x' or 'y' to define which axis the slice is taken from.
        Pass 'x' for vertical slice or 'y' for horizontal slice.
    pout : bool
        Set to true if you want to print the (x,y) coordinate of the intercept point.

    Returns
    -------
    shapely point object of the intercept. Use intercept.x or intercept.y to
    access the coordinates.

    Raises
    ------
    ValueError
        If axis is neither 'x' nor 'y'.
    """
    line_1 = LineString(np.column_stack((x, y)))
    if axis == 'x':
        # Vertical slice at x=intercept spanning the full y range of the line.
        line_2 = LineString(np.column_stack(([intercept, intercept], [min(y), max(y)])))
    elif axis == 'y':
        # Horizontal slice at y=intercept spanning the full x range of the line.
        line_2 = LineString(np.column_stack(([min(x), max(x)], [intercept, intercept])))
    else:
        raise ValueError("axis must be 'x' or 'y', got {!r}".format(axis))
    point = line_1.intersection(line_2)  # compute once, reuse below
    if pout:
        print((point.x, point.y))
    return point
def plot_linint(x,y,intercept,axis,plot_axis,pout=False,**plt_kwargs):
    """
    Draw guide lines marking where a line crosses a slice found by ``linint``.

    Two segments are drawn on ``plot_axis``: a vertical one from the lowest y
    value up to the intercept, and a horizontal one from the lowest x value
    across to the intercept.

    Parameters
    ----------
    x : list or array
        x coordinates of line.
    y : list or array
        y coordinates of line.
    intercept : float or int
        x or y axis position of the slice to calculate the intercept for.
    axis : str
        'x' for a vertical slice or 'y' for a horizontal slice.
    plot_axis : axis object
        Matplotlib axis on which to draw the intercept guide lines.
    pout : bool
        Set to true to print the (x,y) coordinate of the intercept point.
    **plt_kwargs
        Forwarded unchanged to both ``plot_axis.plot`` calls.

    Returns
    -------
    None
    """
    crossing = linint(x, y, intercept, axis, pout)
    cx, cy = crossing.x, crossing.y
    y_floor = min(y)
    x_floor = min(x)
    # vertical drop from the crossing to the bottom of the data
    plot_axis.plot([cx, cx], [y_floor, cy], **plt_kwargs)
    # horizontal run from the left of the data to the crossing
    plot_axis.plot([x_floor, cx], [cy, cy], **plt_kwargs)
    return None
def sn_bins(var, n, w=None, cent='avg', order='asc', leftover='join'):
    """
    Define bins with equal 'n' elements in each bin. Useful for sparse data.

    Order controls the direction in which bins are defined, since either the
    last or the first bin will not contain exactly the target n.

    Parameters
    ----------
    var : list or numpy array of int or float
        Variable to be binned.
    n : int
        Number of elements to put in a bin.
    w : list or numpy array of int or float
        Weight of each element in var. Used only when cent=='avg'.
    cent : str
        Takes values of 'avg' or 'mid' to choose the definition of bin centres:
        either the weighted average of each bin or the midpoint between edges.
    order : str
        Takes values of 'asc' or 'dec' to set the direction in which you
        define bins. In 'asc' mode the final bin holds the remainder; in
        'dec' mode the first bin does.
        If 'ValueError: The smallest edge difference is numerically 0.' is
        raised in binned_statistic try changing order.
    leftover : str
        'join' drops the last edge found so the remainder is merged into the
        neighbouring bin; any other value keeps it as its own bin.
        NOTE(review): 'join' always drops that edge, so when len(var) is an
        exact multiple of n a final bin that already had n elements is merged
        too — confirm whether that is intended.

    Returns
    -------
    bins : the bin edges
    bin_cents : the bin centres

    Raises
    ------
    ValueError
        If order or cent takes an unsupported value.
    """
    var = np.asarray(var)  # lists cannot be fancy-indexed
    if order not in ('asc', 'dec'):
        raise ValueError("order must be 'asc' or 'dec', got {!r}".format(order))

    idx = np.argsort(var)
    if order == 'dec':
        idx = idx[::-1]
    # Every n-th sorted value becomes an edge; the closing edge is the extreme
    # value at the far end of the sort.
    edges = var[idx[::n]]
    if leftover == 'join':
        edges = edges[:-1]
    bins = np.append(edges, var[idx[-1]])
    if order == 'dec':
        bins = bins[::-1]  # binned_statistic needs ascending edges

    if cent == 'avg':
        w = np.ones(len(var)) if w is None else np.asarray(w)
        weighted_sum, _, _ = stats.binned_statistic(var, w*var, statistic='sum', bins=bins)
        weight_sum, _, _ = stats.binned_statistic(var, w, statistic='sum', bins=bins)
        bin_cents = weighted_sum/weight_sum
    elif cent == 'mid':
        bin_cents = bins[:-1] + np.diff(bins)/2
    else:
        raise ValueError("cent must be 'avg' or 'mid', got {!r}".format(cent))
    return bins, bin_cents
def lin_sn_bins(var,nmin,range=None,mode='min'):
    """
    Define linear bins with a minimum 'nmin' elements in each bin using a
    root finder.

    Two modes allow for the minimum number of bins possible (default) or the
    maximum number of bins (high resolution).

    Parameters
    ----------
    var : list or numpy array of int or float
        Variable to be binned.
    nmin : int
        Minimum number of elements required for each bin.
    range : tuple or list of two elements.
        The min and max values to determine bins over. Defaults to the
        min and max of var.
    mode : str
        Takes values of 'min' or 'max' to allow for the minimum number of
        bins that fits the nmin criterion, or the maximum possible to still
        ensure nmin in each bin.

    Returns
    -------
    nbins : min/max number of bins possible to ensure each bin has nmin elements.

    Raises
    ------
    ValueError
        If mode is not 'min' or 'max'.
    """
    if mode not in ('min', 'max'):
        raise ValueError("mode must be 'min' or 'max', got {!r}".format(mode))
    if range is None:
        range = (np.min(var), np.max(var))

    # Both objectives are step functions of int(nbins) whose sign flips
    # around the target bin count; brentq only needs that sign change across
    # the bracket [1, len(var)]. The exact-nmin case returns a tiny negative.
    def sn_min(nbins, var, nmin, range):
        # Positive while every bin still holds more than nmin elements.
        count, edges = np.histogram(var, range=range, bins=int(nbins))
        if np.nanmin(count) > nmin:
            return (1/(np.diff(edges)[0]))
        elif np.nanmin(count) == nmin:
            return -1e-10
        else:
            return -(1/(np.diff(edges)[0]))

    def sn_max(nbins, var, nmin, range):
        # Positive while some bin holds fewer than nmin elements.
        count, edges = np.histogram(var, range=range, bins=int(nbins))
        if np.nanmin(count) < nmin:
            return (1/(np.diff(edges)[0]))
        elif np.nanmin(count) == nmin:
            return -1e-10
        else:
            return -(1/(np.diff(edges)[0]))

    objective = sn_min if mode == 'min' else sn_max
    sol = optimize.root_scalar(objective, args=(var, nmin, range),
                               bracket=[1, len(var)], method='brentq')
    return int(sol.root)
def buttersmooth(x, y, order=2, crit=None, interp=True, cs_np=1000,log=True):
    """
    A lowpass Butterworth frequency filter with optional CubicSpline
    interpolation to smooth a 1D array input.

    The filter is applied forwards and backwards (``signal.filtfilt``).

    Parameters
    ----------
    x : list or numpy array of int or float
        x coordinates of arr (must be strictly increasing if interp=True).
    y : list or numpy array of int or float
        y coordinates of arr
    order : int
        Order of the Butterworth filter, which controls how strongly
        frequencies above the critical frequency are suppressed.
    crit : float or None
        Critical frequency passed to ``signal.butter``. If None, it is
        estimated as the x-extent spanned by the first len(x)/10 samples,
        which can be used as a 'good first guess'.
    interp : bool
        Option to implement CubicSpline interpolation.
    cs_np : int
        The number of data points to interpolate the arr to. Does nothing if
        interp=False.
    log : bool
        Option to print the calculated critical frequency if crit=None.

    Returns
    -------
    xout : smoothed x coordinates with len(xout)==len(x) if interp=False,
        len(xout)==cs_np if interp=True.
    yout : smoothed y coordinates with len(yout)==len(x) if interp=False,
        len(yout)==cs_np if interp=True.
    """
    if crit is None:
        # Heuristic window: x-extent of the first 10% of samples. It is in
        # the units of x; signal.butter expects a normalised frequency in
        # (0, 1), so this only works for suitably scaled x.
        crit = np.sum(np.diff(x)[:int(len(x)/10)])
        if log:
            print('Using Critical frequency of {}'.format(crit))
    b, a = signal.butter(order, crit, 'lowpass')
    ys = signal.filtfilt(b, a, y)
    if not interp:
        return x, ys
    x_cs = np.linspace(np.min(x), np.max(x), cs_np)
    y_cs = CubicSpline(x, ys)(x_cs)
    return x_cs, y_cs
def interpolate_missing_pixels(
    image: np.ndarray,
    mask: np.ndarray,
    method: str = 'nearest',
    fill_value: float = np.nan
):
    """
    Interpolate missing pixels in an image to allow for smoothing
    and unsharp mask. Adapted from
    https://stackoverflow.com/questions/37662180/interpolate-missing-values-2d-python

    :param image: a 2D image
    :param mask: a 2D image, True (non-zero) indicates missing values; it is
        converted to bool before use
    :param method: interpolation method, one of
        'nearest', 'linear', 'cubic'.
    :param fill_value: which value to use for filling up data outside the
        convex hull of known pixel values.
        Default is NaN. Has no effect for 'nearest'.
    :return: a copy of the image with missing values interpolated
    """
    # With a non-bool (e.g. int) mask, ``~mask`` would be a bitwise NOT and
    # the indexing below would silently pick the wrong pixels.
    mask = np.asarray(mask, dtype=bool)
    interp_image = image.copy()
    if not mask.any():
        return interp_image  # nothing to fill

    h, w = image.shape[:2]
    xx, yy = np.meshgrid(np.arange(w), np.arange(h))
    known = ~mask
    missing_x = xx[mask]
    missing_y = yy[mask]
    interp_values = griddata(
        (xx[known], yy[known]), image[known], (missing_x, missing_y),
        method=method, fill_value=fill_value
    )
    interp_image[missing_y, missing_x] = interp_values
    return interp_image
def fill_holes(
    image: np.ndarray,
    interpolate: bool = True,
    fill_value: float = 0.0
):
    """
    Fill nan/inf values in an image using either interpolation or a fill value.

    :param image: a 2D image
    :param interpolate: if True, fill non-finite pixels with
        ``interpolate_missing_pixels`` (default method); otherwise use
        ``fill_value``
    :param fill_value: (optional) real number (int, float or numpy scalar)
        used to fill missing data. Has no effect when interpolate==True. If it
        is not a real number (e.g. None) missing values are left unchanged.
    :return: a copy of the image with missing values filled
    """
    missing = ~np.isfinite(image)
    filled = np.copy(image)
    if interpolate:
        return interpolate_missing_pixels(filled, missing)
    # numbers.Real accepts 0, 0.0 and numpy scalars; an exact
    # ``type(...) == float`` test silently skipped filling for those.
    if isinstance(fill_value, numbers.Real):
        filled[missing] = fill_value
    return filled
def unsharp(
    image: np.ndarray,
    sigma: int = 0,
    rtn_smooth: bool = False,
    log: bool = True
):
    """
    Produce an astronomy-style unsharp-masked image.

    The result is the fractional difference between the image and a
    Gaussian-smoothed copy of itself: (image - smoothed) / smoothed.

    :param image: a 2D image
    :param sigma: (optional) Gaussian sigma in pixels. 0 (the default) means
        int(mean(image.shape) / 10), i.e. a tenth of (w+h)/2 for a 2D image
    :param rtn_smooth: if True, return (smoothed, unsharp) instead of only
        the unsharp image
    :param log: flag to control print statements
    :return: the unsharp image, or the tuple (smoothed, unsharp)
    """
    if sigma == 0:
        mean_side = np.mean(image.shape)
        sigma = int(mean_side / 10)
        if log:
            print('Using sigma of:', sigma)
    blurred = gaussian_filter(image, sigma=sigma)
    fractional = (image - blurred) / blurred
    return (blurred, fractional) if rtn_smooth else fractional
def quick_hist(x, y, nbin=1000, cmap='magma', xlim=None, ylim=None, vmin=1, vmax=None):
    """
    Produce a simple 2D histogram with colourbar to quickly study the
    distribution of two variables, using ``fast_histogram.histogram2d``
    (presumably faster than the numpy/scipy equivalents).

    The histogram spans the full data range of x and y, is drawn with a
    logarithmic colour scale and is shown immediately with ``plt.show()``.

    Parameters
    ----------
    x : list or array
        x variable
    y : list or array
        y variable
    nbin : int
        The number of bins in BOTH x and y.
    cmap : string
        Matplotlib colormap to display.
    xlim : two-element sequence (min, max)
        The x limits of the plot. Default None will show full range of x.
    ylim : two-element sequence (min, max)
        The y limits of the plot. Default None will show full range of y.
    vmin : float
        The lower limit of the colour scale. Default imposes a lower limit of 1.
    vmax : float
        The upper limit of the colour scale. Default None uses the
        largest bin count of the histogram.

    Returns
    -------
    None
    """
    # Accept lists as documented (they have no .min()/.max()).
    x = np.asarray(x)
    y = np.asarray(y)
    xmin, xmax = x.min(), x.max()
    ymin, ymax = y.min(), y.max()
    h = histogram2d(x, y, range=[[xmin, xmax], [ymin, ymax]], bins=nbin)
    if vmax is None:
        vmax = h.max()
    # Transpose so x runs along the horizontal image axis; origin='lower'
    # places ymin at the bottom.
    plt.imshow(h.T, extent=[xmin, xmax, ymin, ymax], origin='lower',
               norm=colors.LogNorm(vmin=vmin, vmax=vmax),
               cmap=cmap)
    plt.xlabel('x')
    plt.ylabel('y')
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    plt.colorbar(label='N')
    plt.show()