Note
Go to the end to download the full example code.
Compare JAX and Scipy#
This example compares the JAX and Scipy implementations of the interpolation backend.
import jax.numpy as jnp
import numpy as np
from time import time
import matplotlib.pyplot as plt
from astropy import units as u
from GridPolator import GridSpectra
First let’s get the spectra#
w1 = 5 * u.um
w2 = 12 * u.um
resolving_power = 100
teffs = [2800,2900,3000,3100,3200,3300]
impl_bin = 'rust'
g_jax = GridSpectra.from_vspec(
w1=w1,
w2=w2,
resolving_power=resolving_power,
teffs=teffs,
impl_bin=impl_bin,
impl_interp='jax',
fail_on_missing=False
)
g_scipy = GridSpectra.from_vspec(
w1=w1,
w2=w2,
resolving_power=resolving_power,
teffs=teffs,
impl_bin=impl_bin,
impl_interp='scipy',
fail_on_missing=False
)
Loading Spectra: 0%| | 0/6 [00:00<?, ?it/s]PHOENIX grid for 2800 not found. Downloading...
Downloading teff=2800: 0%| | 0.00/33.9M [00:00<?, ?B/s]
Downloading teff=2800: 0%| | 4.10k/33.9M [00:00<52:18, 10.8kB/s]
Downloading teff=2800: 0%| | 41.0k/33.9M [00:00<06:21, 88.7kB/s]
Downloading teff=2800: 1%| | 246k/33.9M [00:00<01:11, 469kB/s]
Downloading teff=2800: 3%|▎ | 1.06M/33.9M [00:00<00:17, 1.83MB/s]
Downloading teff=2800: 10%|█ | 3.47M/33.9M [00:01<00:04, 6.49MB/s]
Downloading teff=2800: 19%|█▉ | 6.57M/33.9M [00:01<00:02, 10.4MB/s]
Downloading teff=2800: 31%|███▏ | 10.6M/33.9M [00:01<00:01, 17.0MB/s]
Downloading teff=2800: 41%|████ | 13.8M/33.9M [00:01<00:01, 17.4MB/s]
Downloading teff=2800: 52%|█████▏ | 17.8M/33.9M [00:01<00:00, 22.5MB/s]
Downloading teff=2800: 63%|██████▎ | 21.2M/33.9M [00:01<00:00, 21.5MB/s]
Downloading teff=2800: 74%|███████▍ | 25.2M/33.9M [00:01<00:00, 25.6MB/s]
Downloading teff=2800: 85%|████████▍ | 28.7M/33.9M [00:02<00:00, 23.7MB/s]
Downloading teff=2800: 97%|█████████▋| 32.9M/33.9M [00:02<00:00, 27.9MB/s]
Downloading teff=2800: 100%|██████████| 33.9M/33.9M [00:02<00:00, 15.3MB/s]
Loading Spectra: 17%|█▋ | 1/6 [00:05<00:25, 5.12s/it]PHOENIX grid for 2900 not found. Downloading...
Downloading teff=2900: 0%| | 0.00/33.8M [00:00<?, ?B/s]
Downloading teff=2900: 0%| | 4.10k/33.8M [00:00<54:34, 10.3kB/s]
Downloading teff=2900: 0%| | 49.2k/33.8M [00:00<05:29, 103kB/s]
Downloading teff=2900: 1%| | 201k/33.8M [00:00<01:33, 358kB/s]
Downloading teff=2900: 2%|▏ | 840k/33.8M [00:00<00:23, 1.40MB/s]
Downloading teff=2900: 7%|▋ | 2.53M/33.8M [00:01<00:06, 4.59MB/s]
Downloading teff=2900: 13%|█▎ | 4.53M/33.8M [00:01<00:03, 8.10MB/s]
Downloading teff=2900: 23%|██▎ | 7.90M/33.8M [00:01<00:02, 12.0MB/s]
Downloading teff=2900: 32%|███▏ | 10.8M/33.8M [00:01<00:01, 15.9MB/s]
Downloading teff=2900: 41%|████ | 13.8M/33.8M [00:01<00:01, 19.4MB/s]
Downloading teff=2900: 49%|████▉ | 16.7M/33.8M [00:01<00:00, 18.0MB/s]
Downloading teff=2900: 58%|█████▊ | 19.7M/33.8M [00:01<00:00, 20.9MB/s]
Downloading teff=2900: 68%|██████▊ | 23.1M/33.8M [00:01<00:00, 24.0MB/s]
Downloading teff=2900: 77%|███████▋ | 26.2M/33.8M [00:02<00:00, 21.5MB/s]
Downloading teff=2900: 86%|████████▌ | 29.2M/33.8M [00:02<00:00, 23.4MB/s]
Downloading teff=2900: 96%|█████████▌| 32.3M/33.8M [00:02<00:00, 25.5MB/s]
Downloading teff=2900: 100%|██████████| 33.8M/33.8M [00:02<00:00, 14.1MB/s]
Loading Spectra: 33%|███▎ | 2/6 [00:10<00:20, 5.25s/it]
Loading Spectra: 50%|█████ | 3/6 [00:10<00:08, 2.99s/it]
Loading Spectra: 67%|██████▋ | 4/6 [00:11<00:03, 1.93s/it]
Loading Spectra: 83%|████████▎ | 5/6 [00:11<00:01, 1.34s/it]PHOENIX grid for 3300 not found. Downloading...
Downloading teff=3300: 0%| | 0.00/33.6M [00:00<?, ?B/s]
Downloading teff=3300: 0%| | 4.10k/33.6M [00:00<1:05:15, 8.57kB/s]
Downloading teff=3300: 0%| | 45.1k/33.6M [00:00<05:59, 93.1kB/s]
Downloading teff=3300: 0%| | 119k/33.6M [00:00<02:32, 219kB/s]
Downloading teff=3300: 1%|▏ | 492k/33.6M [00:00<00:34, 948kB/s]
Downloading teff=3300: 3%|▎ | 1.06M/33.6M [00:01<00:17, 1.86MB/s]
Downloading teff=3300: 9%|▉ | 2.94M/33.6M [00:01<00:05, 5.75MB/s]
Downloading teff=3300: 18%|█▊ | 6.20M/33.6M [00:01<00:02, 12.4MB/s]
Downloading teff=3300: 26%|██▌ | 8.66M/33.6M [00:01<00:01, 14.6MB/s]
Downloading teff=3300: 37%|███▋ | 12.5M/33.6M [00:01<00:01, 18.5MB/s]
Downloading teff=3300: 49%|████▊ | 16.3M/33.6M [00:01<00:00, 23.4MB/s]
Downloading teff=3300: 59%|█████▉ | 19.8M/33.6M [00:01<00:00, 26.2MB/s]
Downloading teff=3300: 69%|██████▉ | 23.1M/33.6M [00:01<00:00, 28.1MB/s]
Downloading teff=3300: 79%|███████▉ | 26.7M/33.6M [00:01<00:00, 30.2MB/s]
Downloading teff=3300: 90%|█████████ | 30.2M/33.6M [00:02<00:00, 31.7MB/s]
Downloading teff=3300: 100%|█████████▉| 33.5M/33.6M [00:02<00:00, 30.9MB/s]
Downloading teff=3300: 100%|██████████| 33.6M/33.6M [00:02<00:00, 15.7MB/s]
Loading Spectra: 100%|██████████| 6/6 [00:16<00:00, 2.55s/it]
Loading Spectra: 100%|██████████| 6/6 [00:16<00:00, 2.71s/it]
Loading Spectra: 0%| | 0/6 [00:00<?, ?it/s]
Loading Spectra: 17%|█▋ | 1/6 [00:00<00:01, 3.33it/s]
Loading Spectra: 33%|███▎ | 2/6 [00:00<00:01, 3.33it/s]
Loading Spectra: 50%|█████ | 3/6 [00:00<00:00, 3.32it/s]
Loading Spectra: 67%|██████▋ | 4/6 [00:01<00:00, 3.32it/s]
Loading Spectra: 83%|████████▎ | 5/6 [00:01<00:00, 3.31it/s]
Loading Spectra: 100%|██████████| 6/6 [00:01<00:00, 3.31it/s]
Loading Spectra: 100%|██████████| 6/6 [00:01<00:00, 3.32it/s]
Evaluate a single spectrum#
wl_jnp = jnp.linspace(5.0, 11.2, 100)
wl_np = np.linspace(5.0, 11.2, 100)
params_jnp = (jnp.array([2900.]),)
params_np = (np.array([2900.]),)
start = time()
flux_jnp = g_jax.evaluate(params_jnp, wl_jnp)
end = time()
print(f'JAX took {end - start} seconds')
start = time()
flux_np = g_scipy.evaluate(params_np, wl_np)
end = time()
print(f'Scipy took {end - start} seconds')
JAX took 5.801193714141846 seconds
Scipy took 0.005570888519287109 seconds
Now do 1000 of each#
N = 1000
start = time()
for _ in range(N):
flux_jnp = g_jax.evaluate(params_jnp, wl_jnp)
end = time()
print(f'JAX took {end - start} seconds\n\tthat\'s {(end - start) / 1000} seconds per call')
start = time()
for _ in range(N):
flux_np = g_scipy.evaluate(params_np, wl_np)
end = time()
print (f'Scipy took {end - start} seconds\n\tthat\'s {(end - start) / 1000} seconds per call')
JAX took 0.38318514823913574 seconds
that's 0.0003831851482391357 seconds per call
Scipy took 5.159081697463989 seconds
that's 0.00515908169746399 seconds per call
QED#
The takeaway: The first JAX call is expensive, the rest are cheap. For Scipy everything costs the same. Of course, the costs change depending on the complexity of the grid.
fig, ax = plt.subplots(1,1,figsize=(4,3))
N=3000
g_jax = GridSpectra.from_vspec(
w1=w1,
w2=w2,
resolving_power=resolving_power,
teffs=teffs,
impl_bin=impl_bin,
impl_interp='jax',
fail_on_missing=False
)
g_scipy = GridSpectra.from_vspec(
w1=w1,
w2=w2,
resolving_power=resolving_power,
teffs=teffs,
impl_bin=impl_bin,
impl_interp='scipy',
fail_on_missing=False
)
dt_jax = np.zeros(N)
for i in range(N):
start = time()
flux_jnp = g_jax.evaluate(params_jnp, wl_jnp)
end = time()
dt_jax[i] = end - start
dt_scipy = np.zeros(N)
for i in range(N):
start = time()
flux_np = g_scipy.evaluate(params_np, wl_np)
end = time()
dt_scipy[i] = end - start
x = np.arange(N)
ax.plot(x, np.cumsum(dt_jax), label='JAX',c='#B96EBD')
ax.plot(x, np.cumsum(dt_scipy), label='Scipy',c='#0054A6')
ax.set_xlabel('Iteration')
ax.set_ylabel('Time (s)')
fig.tight_layout()
_=ax.legend()

Loading Spectra: 0%| | 0/6 [00:00<?, ?it/s]
Loading Spectra: 17%|█▋ | 1/6 [00:00<00:01, 3.35it/s]
Loading Spectra: 33%|███▎ | 2/6 [00:00<00:01, 3.34it/s]
Loading Spectra: 50%|█████ | 3/6 [00:00<00:00, 3.35it/s]
Loading Spectra: 67%|██████▋ | 4/6 [00:01<00:00, 3.36it/s]
Loading Spectra: 83%|████████▎ | 5/6 [00:01<00:00, 3.36it/s]
Loading Spectra: 100%|██████████| 6/6 [00:01<00:00, 3.36it/s]
Loading Spectra: 100%|██████████| 6/6 [00:01<00:00, 3.36it/s]
Loading Spectra: 0%| | 0/6 [00:00<?, ?it/s]
Loading Spectra: 17%|█▋ | 1/6 [00:00<00:01, 3.36it/s]
Loading Spectra: 33%|███▎ | 2/6 [00:00<00:01, 3.35it/s]
Loading Spectra: 50%|█████ | 3/6 [00:00<00:00, 3.35it/s]
Loading Spectra: 67%|██████▋ | 4/6 [00:01<00:00, 3.35it/s]
Loading Spectra: 83%|████████▎ | 5/6 [00:01<00:00, 3.36it/s]
Loading Spectra: 100%|██████████| 6/6 [00:01<00:00, 3.36it/s]
Loading Spectra: 100%|██████████| 6/6 [00:01<00:00, 3.36it/s]
Total running time of the script: (0 minutes 55.557 seconds)