Create WebSky background catalog with Dask
[1]:
import h5pickle as h5py
import numpy as np
import healpy as hp
import matplotlib.pyplot as plt
from tqdm import tqdm
[2]:
%load_ext jupyter_ai
%ai list gemini
[2]:
Provider |
Environment variable |
Set? |
Models |
|---|---|---|---|
| gemini | GOOGLE_API_KEY | ✅ |
gemini:gemini-1.0-progemini:gemini-1.0-pro-001gemini:gemini-1.0-pro-latestgemini:gemini-1.0-pro-vision-latestgemini:gemini-progemini:gemini-pro-vision
|
[3]:
#%%ai gemini:gemini-pro -f code
[4]:
import healpy as hp
# Show the installed healpy version. `hp.version` is a submodule, so it only
# displays the module object's repr; `hp.__version__` holds the actual
# version string.
hp.__version__
[4]:
<module 'healpy.version' from '/global/common/software/cmb/zonca/conda/pycmb/lib/python3.10/site-packages/healpy/version.py'>
[5]:
%alias_magic gm ai -p "gemini:gemini-pro -f code"
Created `%gm` as an alias for `%ai gemini:gemini-pro -f code`.
Created `%%gm` as an alias for `%%ai gemini:gemini-pro -f code`.
[6]:
import os

# Number of Dask workers to launch below; each worker is restricted to a
# single OpenMP thread so the workers do not oversubscribe the node.
num_threads = 128
os.environ.update({"OMP_NUM_THREADS": "1"})
[7]:
# Flux cutoff for source selection — units unclear from this notebook.
# NOTE(review): a separate `cutoff_flux_Jy = 1e-3` is defined later and this
# value appears unused in the visible cells — confirm which threshold is
# actually intended before relying on it.
cutoff_flux = 10000
[8]:
# Destination for the combined catalog on scratch storage.
# NOTE(review): the final `to_netcdf` call writes to a per-field filename
# (`websky_full_catalog_{field}.h5`) instead of this path — confirm which
# one downstream code reads.
output_filename = "/pscratch/sd/z/zonca/websky_full_catalog.h5"
[9]:
# Toggle for optional diagnostic plots; appears unused in the visible cells.
plot = False
[10]:
cd /global/cfs/cdirs/sobs/www/users/Radio_WebSky/matched_catalogs_2
/global/cfs/cdirs/sobs/www/users/Radio_WebSky/matched_catalogs_2
[11]:
%ls
catalog_100.0.h5 catalog_232.0.h5 catalog_353.0.h5 catalog_643.0.h5
catalog_111.0.h5 catalog_24.5.h5 catalog_375.0.h5 catalog_67.8.h5
catalog_129.0.h5 catalog_256.0.h5 catalog_409.0.h5 catalog_70.0.h5
catalog_143.0.h5 catalog_27.3.h5 catalog_41.7.h5 catalog_729.0.h5
catalog_153.0.h5 catalog_275.0.h5 catalog_44.0.h5 catalog_73.7.h5
catalog_164.0.h5 catalog_294.0.h5 catalog_467.0.h5 catalog_79.6.h5
catalog_18.7.h5 catalog_30.0.h5 catalog_47.4.h5 catalog_817.0.h5
catalog_189.0.h5 catalog_306.0.h5 catalog_525.0.h5 catalog_857.0.h5
catalog_21.6.h5 catalog_314.0.h5 catalog_545.0.h5 catalog_90.2.h5
catalog_210.0.h5 catalog_340.0.h5 catalog_584.0.h5 catalog_906.0.h5
catalog_217.0.h5 catalog_35.9.h5 catalog_63.9.h5 flux_coeff.h5
[12]:
# Frequency labels (GHz, as strings matching the catalog filenames) to
# combine — a subset of the catalog_*.h5 files listed above.
freqs = (
    "18.7 24.5 44.0 70.0 100.0 143.0 "
    "217.0 353.0 545.0 643.0 729.0 857.0 906.0"
).split()
[13]:
# Numeric (float64) view of the frequency labels, used as the x-axis of the
# per-source fits below.
freqs_array = np.array([float(freq) for freq in freqs])
[14]:
# Open one frequency's catalog for inspection. `h5py` here is actually
# h5pickle (see the imports cell), which makes File objects picklable so
# they can be shipped to dask workers. The handle is left open for the
# lifetime of the kernel.
cat = h5py.File("catalog_100.0.h5", "r")
[18]:
# Sanity check: first few values of the "theta" dataset (big-endian float64
# on disk). Presumably HEALPix colatitude in radians — TODO confirm.
cat["theta"][:4]
[18]:
array([1.64009452, 1.64009452, 1.64009452, 1.69043016], dtype='>f8')
[15]:
#%%ai gemini:gemini-pro -f code
#find the fields in a h5py File
[16]:
import dask.array as da
There is no metadata in the file; I assume the fluxes are in Jy.
[17]:
# Spawn a local Dask cluster: one single-threaded process per worker
# (processes=True), matching OMP_NUM_THREADS=1 set earlier so each worker
# stays on one core.
from dask.distributed import Client, LocalCluster
cluster = LocalCluster(n_workers=num_threads, threads_per_worker=1, processes=True)
client = Client(cluster)
[18]:
import pandas as pd
import xarray as xr
[19]:
# Which HDF5 dataset to extract from every per-frequency catalog file.
field = 'flux'
[20]:
# Lazily wrap the `field` dataset of every per-frequency catalog file as a
# dask array (1e6-element chunks). h5pickle keeps the file handles
# picklable so the datasets can be shipped to the dask workers.
arrays = []
for freq in freqs:
    dataset = h5py.File(f"catalog_{freq}.h5", "r")[field]
    arrays.append(da.from_array(dataset, chunks=1000000))
[21]:
# Combine the per-frequency arrays into one (n_freqs, n_sources) dask array.
flux = da.stack(arrays, axis=0)
[22]:
# Rechunk so each task sees ALL frequencies of a source at once: a single
# chunk along the frequency axis is required for the per-source curve fits
# below. Use len(freqs) instead of the magic number 13 so the list of
# frequencies can change without silently breaking the chunking.
flux = flux.rechunk(chunks=(len(freqs), 1000000))
[23]:
flux
[23]:
|
||||||||||||||||
[24]:
# Only keep sources below cutoff
# NOTE(review): an earlier cell defines `cutoff_flux = 10000` with a
# different value (and implied units) — confirm which threshold is intended.
cutoff_flux_Jy = 1e-3
# The filter below is left disabled; enabling it produces unknown chunk
# sizes (see the commented compute_chunk_sizes call in the next cell).
# flux = flux[:, flux[4, :] < cutoff_flux_Jy]
[25]:
# flux.compute_chunk_sizes()
[26]:
from numba import njit
@njit
def model(freq, a, b, c, d, e):
    """4th-degree polynomial in log(freq), fitted per source to its SED.

    Parameters
    ----------
    freq : frequency value(s), same units as ``freqs_array`` (GHz,
        presumably — TODO confirm)
    a, b, c, d, e : polynomial coefficients, from power 4 down to 0

    Compiled with numba so scipy's curve_fit can evaluate it cheaply for
    hundreds of millions of sources.
    """
    log_freq = np.log(freq)
    return a * log_freq**4 + b * log_freq**3 + c * log_freq**2 + d * log_freq + e
[27]:
from scipy.optimize import curve_fit
[28]:
# Sanity check: fit the model to the first source's SED across all
# frequencies; displays the best-fit coefficient vector (a..e).
curve_fit(model, freqs_array, flux[:,0])[0]
[28]:
array([ 5.37899652e-09, -1.29664725e-07, 1.20804354e-06, -5.22231671e-06,
9.00274650e-06])
[29]:
def run_curve_fit(source_flux):
    """Fit ``model`` to one source's fluxes across ``freqs_array``.

    Returns the best-fit coefficient vector (a, b, c, d, e); the covariance
    matrix from curve_fit is discarded.
    """
    popt, _pcov = curve_fit(model, freqs_array, source_flux)
    return popt

# Map the fit over every source (axis 0 holds the frequencies), lazily.
coeff = da.apply_along_axis(run_curve_fit, 0, flux)
[30]:
%%time
coeff[:,:10].compute()
CPU times: user 2min 2s, sys: 1min 28s, total: 3min 31s
Wall time: 5min 4s
[30]:
array([[ 5.37899652e-09, 6.42822511e-09, 3.31959319e-09,
1.38862575e-08, 3.78195705e-09, 1.08966713e-08,
8.68670067e-08, 3.64081373e-09, 3.66938830e-09,
3.02356547e-09],
[-1.29664725e-07, -1.25727514e-07, -6.53823561e-08,
-2.67337421e-07, -8.86265294e-08, -2.11607996e-07,
-1.96845594e-06, -8.47114797e-08, -7.11052032e-08,
-6.49361314e-08],
[ 1.20804354e-06, 8.85551144e-07, 4.83957483e-07,
1.85677682e-06, 8.11717509e-07, 1.47786561e-06,
1.66899947e-05, 7.71488341e-07, 5.08938118e-07,
5.45560777e-07],
[-5.22231671e-06, -2.61195734e-06, -1.68793603e-06,
-5.55460268e-06, -3.51684601e-06, -4.35226803e-06,
-6.28894194e-05, -3.33517575e-06, -1.66773519e-06,
-2.22563537e-06],
[ 9.00274650e-06, 2.94062995e-06, 2.71462806e-06,
6.94587053e-06, 6.25448262e-06, 5.04064302e-06,
8.91598585e-05, 5.95275131e-06, 2.49203039e-06,
3.97960150e-06]])
[31]:
import xarray as xr
[32]:
coeff.shape
[32]:
(5, 281756376)
[33]:
# Wrap the coefficient array in a labelled xarray container: "power" is the
# power of log(freq) each row corresponds to (4 down to 0, matching the
# model's a..e ordering); "index" is the source index within the catalog.
# NOTE(review): the DataArray is named "flux" but actually holds fit
# coefficients — confirm downstream readers expect this name.
xr_flux = xr.DataArray(
    data=coeff,
    coords={"power": np.arange(4, -1, -1), "index": da.arange(coeff.shape[1])},
    name="flux",
)
[34]:
xr_flux
[34]:
<xarray.DataArray 'flux' (power: 5, index: 281756376)> dask.array<run_curve_fit-along-axis, shape=(5, 281756376), dtype=float64, chunksize=(5, 1000000), chunktype=numpy.ndarray> Coordinates: * power (power) int64 4 3 2 1 0 * index (index) int64 0 1 2 3 4 ... 281756372 281756373 281756374 281756375
[35]:
# Persist the coefficients; writing triggers the full (lazy) dask
# computation over all sources.
# NOTE(review): this per-field path differs from `output_filename` defined
# near the top of the notebook — confirm which one downstream code reads.
xr_flux.to_netcdf(
    f"/pscratch/sd/z/zonca/websky_full_catalog_{field}.h5", format="NETCDF4") # requires netcdf4 package
[36]:
xr_flux
[36]:
<xarray.DataArray 'flux' (power: 5, index: 281756376)> dask.array<run_curve_fit-along-axis, shape=(5, 281756376), dtype=float64, chunksize=(5, 1000000), chunktype=numpy.ndarray> Coordinates: * power (power) int64 4 3 2 1 0 * index (index) int64 0 1 2 3 4 ... 281756372 281756373 281756374 281756375
[ ]: