Compare JAX and SciPy

This example compares the JAX and SciPy implementations of the interpolation backend.

import jax.numpy as jnp
import numpy as np
from time import time
import matplotlib.pyplot as plt
from astropy import units as u

from GridPolator import GridSpectra

First, let's get the spectra

# Wavelength window, resolving power, effective temperatures and binning
# implementation shared by both grids.
w1, w2 = 5 * u.um, 12 * u.um
resolving_power = 100
teffs = list(range(2800, 3400, 100))
impl_bin = 'rust'

# The two grids are built from identical settings; only the interpolation
# backend (`impl_interp`) differs, so the comparison isolates that choice.
shared_kwargs = dict(
    w1=w1,
    w2=w2,
    resolving_power=resolving_power,
    teffs=teffs,
    impl_bin=impl_bin,
    fail_on_missing=False,
)
g_jax = GridSpectra.from_vspec(impl_interp='jax', **shared_kwargs)
g_scipy = GridSpectra.from_vspec(impl_interp='scipy', **shared_kwargs)
Loading Spectra:   0%|          | 0/6 [00:00<?, ?it/s]PHOENIX grid for 2800 not found. Downloading...


Downloading teff=2800:   0%|          | 0.00/33.9M [00:00<?, ?B/s]

Downloading teff=2800:   0%|          | 4.10k/33.9M [00:00<52:18, 10.8kB/s]

Downloading teff=2800:   0%|          | 41.0k/33.9M [00:00<06:21, 88.7kB/s]

Downloading teff=2800:   1%|          | 246k/33.9M [00:00<01:11, 469kB/s]

Downloading teff=2800:   3%|▎         | 1.06M/33.9M [00:00<00:17, 1.83MB/s]

Downloading teff=2800:  10%|█         | 3.47M/33.9M [00:01<00:04, 6.49MB/s]

Downloading teff=2800:  19%|█▉        | 6.57M/33.9M [00:01<00:02, 10.4MB/s]

Downloading teff=2800:  31%|███▏      | 10.6M/33.9M [00:01<00:01, 17.0MB/s]

Downloading teff=2800:  41%|████      | 13.8M/33.9M [00:01<00:01, 17.4MB/s]

Downloading teff=2800:  52%|█████▏    | 17.8M/33.9M [00:01<00:00, 22.5MB/s]

Downloading teff=2800:  63%|██████▎   | 21.2M/33.9M [00:01<00:00, 21.5MB/s]

Downloading teff=2800:  74%|███████▍  | 25.2M/33.9M [00:01<00:00, 25.6MB/s]

Downloading teff=2800:  85%|████████▍ | 28.7M/33.9M [00:02<00:00, 23.7MB/s]

Downloading teff=2800:  97%|█████████▋| 32.9M/33.9M [00:02<00:00, 27.9MB/s]
Downloading teff=2800: 100%|██████████| 33.9M/33.9M [00:02<00:00, 15.3MB/s]

Loading Spectra:  17%|█▋        | 1/6 [00:05<00:25,  5.12s/it]PHOENIX grid for 2900 not found. Downloading...


Downloading teff=2900:   0%|          | 0.00/33.8M [00:00<?, ?B/s]

Downloading teff=2900:   0%|          | 4.10k/33.8M [00:00<54:34, 10.3kB/s]

Downloading teff=2900:   0%|          | 49.2k/33.8M [00:00<05:29, 103kB/s]

Downloading teff=2900:   1%|          | 201k/33.8M [00:00<01:33, 358kB/s]

Downloading teff=2900:   2%|▏         | 840k/33.8M [00:00<00:23, 1.40MB/s]

Downloading teff=2900:   7%|▋         | 2.53M/33.8M [00:01<00:06, 4.59MB/s]

Downloading teff=2900:  13%|█▎        | 4.53M/33.8M [00:01<00:03, 8.10MB/s]

Downloading teff=2900:  23%|██▎       | 7.90M/33.8M [00:01<00:02, 12.0MB/s]

Downloading teff=2900:  32%|███▏      | 10.8M/33.8M [00:01<00:01, 15.9MB/s]

Downloading teff=2900:  41%|████      | 13.8M/33.8M [00:01<00:01, 19.4MB/s]

Downloading teff=2900:  49%|████▉     | 16.7M/33.8M [00:01<00:00, 18.0MB/s]

Downloading teff=2900:  58%|█████▊    | 19.7M/33.8M [00:01<00:00, 20.9MB/s]

Downloading teff=2900:  68%|██████▊   | 23.1M/33.8M [00:01<00:00, 24.0MB/s]

Downloading teff=2900:  77%|███████▋  | 26.2M/33.8M [00:02<00:00, 21.5MB/s]

Downloading teff=2900:  86%|████████▌ | 29.2M/33.8M [00:02<00:00, 23.4MB/s]

Downloading teff=2900:  96%|█████████▌| 32.3M/33.8M [00:02<00:00, 25.5MB/s]
Downloading teff=2900: 100%|██████████| 33.8M/33.8M [00:02<00:00, 14.1MB/s]

Loading Spectra:  33%|███▎      | 2/6 [00:10<00:20,  5.25s/it]
Loading Spectra:  50%|█████     | 3/6 [00:10<00:08,  2.99s/it]
Loading Spectra:  67%|██████▋   | 4/6 [00:11<00:03,  1.93s/it]
Loading Spectra:  83%|████████▎ | 5/6 [00:11<00:01,  1.34s/it]PHOENIX grid for 3300 not found. Downloading...


Downloading teff=3300:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

Downloading teff=3300:   0%|          | 4.10k/33.6M [00:00<1:05:15, 8.57kB/s]

Downloading teff=3300:   0%|          | 45.1k/33.6M [00:00<05:59, 93.1kB/s]

Downloading teff=3300:   0%|          | 119k/33.6M [00:00<02:32, 219kB/s]

Downloading teff=3300:   1%|▏         | 492k/33.6M [00:00<00:34, 948kB/s]

Downloading teff=3300:   3%|▎         | 1.06M/33.6M [00:01<00:17, 1.86MB/s]

Downloading teff=3300:   9%|▉         | 2.94M/33.6M [00:01<00:05, 5.75MB/s]

Downloading teff=3300:  18%|█▊        | 6.20M/33.6M [00:01<00:02, 12.4MB/s]

Downloading teff=3300:  26%|██▌       | 8.66M/33.6M [00:01<00:01, 14.6MB/s]

Downloading teff=3300:  37%|███▋      | 12.5M/33.6M [00:01<00:01, 18.5MB/s]

Downloading teff=3300:  49%|████▊     | 16.3M/33.6M [00:01<00:00, 23.4MB/s]

Downloading teff=3300:  59%|█████▉    | 19.8M/33.6M [00:01<00:00, 26.2MB/s]

Downloading teff=3300:  69%|██████▉   | 23.1M/33.6M [00:01<00:00, 28.1MB/s]

Downloading teff=3300:  79%|███████▉  | 26.7M/33.6M [00:01<00:00, 30.2MB/s]

Downloading teff=3300:  90%|█████████ | 30.2M/33.6M [00:02<00:00, 31.7MB/s]

Downloading teff=3300: 100%|█████████▉| 33.5M/33.6M [00:02<00:00, 30.9MB/s]
Downloading teff=3300: 100%|██████████| 33.6M/33.6M [00:02<00:00, 15.7MB/s]

Loading Spectra: 100%|██████████| 6/6 [00:16<00:00,  2.55s/it]
Loading Spectra: 100%|██████████| 6/6 [00:16<00:00,  2.71s/it]

Loading Spectra:   0%|          | 0/6 [00:00<?, ?it/s]
Loading Spectra:  17%|█▋        | 1/6 [00:00<00:01,  3.33it/s]
Loading Spectra:  33%|███▎      | 2/6 [00:00<00:01,  3.33it/s]
Loading Spectra:  50%|█████     | 3/6 [00:00<00:00,  3.32it/s]
Loading Spectra:  67%|██████▋   | 4/6 [00:01<00:00,  3.32it/s]
Loading Spectra:  83%|████████▎ | 5/6 [00:01<00:00,  3.31it/s]
Loading Spectra: 100%|██████████| 6/6 [00:01<00:00,  3.31it/s]
Loading Spectra: 100%|██████████| 6/6 [00:01<00:00,  3.32it/s]

Evaluate a single spectrum

# Build the same 100-point wavelength axis and single-temperature parameter
# tuple twice: once as JAX arrays, once as NumPy arrays.
wl_lo, wl_hi, n_wl = 5.0, 11.2, 100
teff_eval = 2900.
wl_jnp = jnp.linspace(wl_lo, wl_hi, n_wl)
wl_np = np.linspace(wl_lo, wl_hi, n_wl)
params_jnp = (jnp.array([teff_eval]),)
params_np = (np.array([teff_eval]),)

# Time the very first evaluation of the JAX-backed grid.
t0 = time()
flux_jnp = g_jax.evaluate(params_jnp, wl_jnp)
jax_elapsed = time() - t0
print(f'JAX took {jax_elapsed} seconds')

# Time the very first evaluation of the SciPy-backed grid.
t0 = time()
flux_np = g_scipy.evaluate(params_np, wl_np)
scipy_elapsed = time() - t0
print(f'Scipy took {scipy_elapsed} seconds')
JAX took 5.801193714141846 seconds
Scipy took 0.005570888519287109 seconds

Now do 1000 of each

# Time N back-to-back evaluations of each backend and report both the total
# and the mean time per call.
N = 1000

# NOTE(review): JAX dispatches work asynchronously; if `evaluate` returns a
# JAX array that has not been materialized, this loop may under-count the
# real compute time. Confirm whether a `block_until_ready()` is needed.
start = time()
for _ in range(N):
    flux_jnp = g_jax.evaluate(params_jnp, wl_jnp)
end = time()
# Divide by N (not a hard-coded 1000) so the per-call figure stays correct
# if N is changed.
print(f'JAX took {end - start} seconds\n\tthat\'s {(end - start) / N} seconds per call')

start = time()
for _ in range(N):
    flux_np = g_scipy.evaluate(params_np, wl_np)
end = time()
print(f'Scipy took {end - start} seconds\n\tthat\'s {(end - start) / N} seconds per call')
JAX took 0.38318514823913574 seconds
        that's 0.0003831851482391357 seconds per call
Scipy took 5.159081697463989 seconds
        that's 0.00515908169746399 seconds per call

QED

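The takeaway: the first JAX call is expensive (about 5.8 s in the run above), but every later call is cheap (about 0.4 ms each). For SciPy, every call costs roughly the same (about 5 ms each), so JAX pays off once enough evaluations are made. Of course, these costs change depending on the complexity of the grid.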

def _time_evaluations(grid, params, wavelengths, n_calls):
    """Time `n_calls` successive `grid.evaluate` calls.

    Returns the per-call wall-clock durations (in seconds, as measured by
    `time()`) together with the result of the last call.
    """
    durations = np.zeros(n_calls)
    last_result = None
    for idx in range(n_calls):
        tic = time()
        last_result = grid.evaluate(params, wavelengths)
        durations[idx] = time() - tic
    return durations, last_result


N = 3000

# Rebuild both grids so that the first timed call of each backend is made on
# a newly constructed object.
grid_kwargs = dict(
    w1=w1,
    w2=w2,
    resolving_power=resolving_power,
    teffs=teffs,
    impl_bin=impl_bin,
    fail_on_missing=False,
)
g_jax = GridSpectra.from_vspec(impl_interp='jax', **grid_kwargs)
g_scipy = GridSpectra.from_vspec(impl_interp='scipy', **grid_kwargs)

# Per-call durations: all JAX calls first, then all SciPy calls.
dt_jax, flux_jnp = _time_evaluations(g_jax, params_jnp, wl_jnp, N)
dt_scipy, flux_np = _time_evaluations(g_scipy, params_np, wl_np, N)

# Plot the running total of time spent against the iteration index.
fig, ax = plt.subplots(1, 1, figsize=(4, 3))
x = np.arange(N)
for durations, name, colour in (
    (dt_jax, 'JAX', '#B96EBD'),
    (dt_scipy, 'Scipy', '#0054A6'),
):
    ax.plot(x, np.cumsum(durations), label=name, c=colour)
ax.set_xlabel('Iteration')
ax.set_ylabel('Time (s)')
fig.tight_layout()
_ = ax.legend()
plot compare jax scipy
Loading Spectra:   0%|          | 0/6 [00:00<?, ?it/s]
Loading Spectra:  17%|█▋        | 1/6 [00:00<00:01,  3.35it/s]
Loading Spectra:  33%|███▎      | 2/6 [00:00<00:01,  3.34it/s]
Loading Spectra:  50%|█████     | 3/6 [00:00<00:00,  3.35it/s]
Loading Spectra:  67%|██████▋   | 4/6 [00:01<00:00,  3.36it/s]
Loading Spectra:  83%|████████▎ | 5/6 [00:01<00:00,  3.36it/s]
Loading Spectra: 100%|██████████| 6/6 [00:01<00:00,  3.36it/s]
Loading Spectra: 100%|██████████| 6/6 [00:01<00:00,  3.36it/s]

Loading Spectra:   0%|          | 0/6 [00:00<?, ?it/s]
Loading Spectra:  17%|█▋        | 1/6 [00:00<00:01,  3.36it/s]
Loading Spectra:  33%|███▎      | 2/6 [00:00<00:01,  3.35it/s]
Loading Spectra:  50%|█████     | 3/6 [00:00<00:00,  3.35it/s]
Loading Spectra:  67%|██████▋   | 4/6 [00:01<00:00,  3.35it/s]
Loading Spectra:  83%|████████▎ | 5/6 [00:01<00:00,  3.36it/s]
Loading Spectra: 100%|██████████| 6/6 [00:01<00:00,  3.36it/s]
Loading Spectra: 100%|██████████| 6/6 [00:01<00:00,  3.36it/s]

Total running time of the script: (0 minutes 55.557 seconds)

Gallery generated by Sphinx-Gallery