# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
jax backend gpu
[GpuDevice(id=0, process_index=0)]
🎛 The 3-step workflow 🎛
It is already very useful to be able to execute a JAX function on a dataframe in a single step, with a single line of code, thanks to WAX-ML accessors.
The 1-step workflow of WAX-ML’s stream API looks like this:
<data-container>.stream(...).apply(...)
But this is not optimal because, under the hood, there are mainly three costly steps:
(1) (synchronize | data tracing | encode): make the data “JAX ready”
(2) (compile | code tracing | execution): compile and optimize a function for XLA, then execute it.
(3) (format): convert data back to pandas/xarray/numpy format.
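To make the split concrete, here is a toy sketch in plain NumPy, with no WAX-ML or JAX involved. The names encode, compute and format_outputs are hypothetical stand-ins for steps (1), (2) and (3). They are not part of the WAX-ML API.

```python
import numpy as np

# Hypothetical stand-ins for the three costly steps (not WAX-ML functions).

def encode(columns):
    """Step (1): stack named float64 columns into one float32 array of shape (T, N)."""
    names = list(columns)
    array = np.stack([np.asarray(columns[n], dtype=np.float32) for n in names], axis=1)
    return names, array

def compute(array):
    """Step (2): the numerical kernel, here a running mean over time."""
    counts = np.arange(1, array.shape[0] + 1, dtype=np.float32)[:, None]
    return np.cumsum(array, axis=0) / counts

def format_outputs(names, array):
    """Step (3): convert the raw array back to named columns."""
    return {name: array[:, i] for i, name in enumerate(names)}

def one_step(columns):
    # 1-step workflow: every call pays for encoding and formatting again.
    names, array = encode(columns)
    return format_outputs(names, compute(array))

columns = {"a": [1.0, 2.0, 3.0], "b": [4.0, 6.0, 8.0]}

# 3-step workflow: encode once, iterate on `compute` as often as needed, format at the end.
names, array = encode(columns)
raw = compute(array)
result = format_outputs(names, raw)
```

When you iterate on the calculation, only compute needs to be re-run. The encoded array can be reused as-is.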
With the wax.stream primitives, it is quite easy to explicitly split the 1-step workflow into a 3-step workflow.
This gives the user full control over each step and lets them iterate on each one separately.
Iterating on step (2), the “calculation step”, is especially useful when you are doing research.
You can then take full advantage of the JAX primitives, especially the jit primitive.
Let’s illustrate how to reimplement WAX-ML EWMA yourself with the WAX-ML 3-step workflow.
Imports
import haiku as hk
import numpy as onp
import pandas as pd
import xarray as xr
from wax.accessors import register_wax_accessors
from wax.compile import jit_init_apply
from wax.external.eagerpy import convert_to_tensors
from wax.format import format_dataframe
from wax.modules import EWMA
from wax.stream import tree_access_data
from wax.unroll import dynamic_unroll
register_wax_accessors()
Performance on big dataframes
Generate data
T = 1.0e5
N = 1000
%%time
T, N = map(int, (T, N))
dataframe = pd.DataFrame(
    onp.random.normal(size=(T, N)), index=pd.date_range("1970", periods=T, freq="s")
)
CPU times: user 3.87 s, sys: 143 ms, total: 4.01 s
Wall time: 3.99 s
Pandas EWMA
%%time
df_ewma_pandas = dataframe.ewm(alpha=1.0 / 10.0).mean()
CPU times: user 1.95 s, sys: 480 ms, total: 2.42 s
Wall time: 2.42 s
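For reference, pandas’ ewm(alpha=...).mean() uses adjust=True by default. It computes the weighted average y_t = Σ_i (1-α)^i x_{t-i} / Σ_i (1-α)^i, with i running from 0 to t.

The NumPy sketch below evaluates this average in two ways and checks that they agree:

- an O(T) recursion on a running numerator and denominator;
- a direct evaluation of the formula.

The helper names are our own, and the sketch ignores the NaN handling that pandas also performs.

```python
import numpy as np

def ewma_adjusted(x, alpha):
    """Adjusted EWMA along axis 0, via a running numerator and denominator."""
    x = np.asarray(x, dtype=np.float64)
    decay = 1.0 - alpha
    num = np.zeros(x.shape[1:])
    den = np.zeros(x.shape[1:])
    out = np.empty_like(x)
    for t in range(x.shape[0]):
        num = x[t] + decay * num  # sum_i decay**i * x[t - i]
        den = 1.0 + decay * den   # sum_i decay**i
        out[t] = num / den
    return out

def ewma_direct(x, alpha):
    """Same quantity from the closed-form weights (O(T**2), for checking only)."""
    x = np.asarray(x, dtype=np.float64)
    decay = 1.0 - alpha
    out = np.empty_like(x)
    for t in range(x.shape[0]):
        w = decay ** np.arange(t, -1, -1)  # weights of x[0], ..., x[t]
        out[t] = np.tensordot(w, x[: t + 1], axes=1) / w.sum()
    return out

x = np.random.default_rng(0).normal(size=(50, 3))
assert np.allclose(ewma_adjusted(x, 0.1), ewma_direct(x, 0.1))
```

On a DataFrame without missing values, ewma_adjusted(dataframe.values, 0.1) should match dataframe.ewm(alpha=0.1).mean() up to floating-point error.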
WAX-ML EWMA
%%time
df_ewma_wax = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
CPU times: user 2.21 s, sys: 679 ms, total: 2.89 s
Wall time: 3.11 s
In this run, it is not faster than pandas; it is even slightly slower…
WAX-ML EWMA (without format step)
Let’s disable the final formatting step (the output is now in raw JAX format):
%%time
df_ewma_wax_no_format = dataframe.wax.ewm(alpha=1.0 / 10.0, format_outputs=False).mean()
CPU times: user 1.77 s, sys: 563 ms, total: 2.33 s
Wall time: 2.22 s
type(df_ewma_wax_no_format)
jaxlib.xla_extension.DeviceArray
Let’s check the device on which the calculation was performed (if you have a GPU available, this should be GpuDevice, otherwise it will be CpuDevice):
df_ewma_wax_no_format.device()
GpuDevice(id=0, process_index=0)
That’s better! In fact, as shown below, the final formatting step has a performance problem. See WEP3 for a proposal to improve it.
Generate data (in dataset format)
The WAX-ML Stream object works on datasets, so let’s transform the DataFrame into an xarray Dataset:
dataset = xr.DataArray(dataframe).to_dataset(name="dataarray")
Step (1) (synchronize | data tracing | encode)
In this step, WAX-ML does the following:
“data tracing”: prepare the indices for fast access to the data in the JAX function access_data;
synchronize streams if there are multiple ones; this functionality has the options freq and ffills;
encode and convert data from numpy to JAX: use encoders for the datetime64 and string_ dtypes. Be aware that by default JAX works in float32 (see JAX’s Common Gotchas to work in float64).
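The three ideas in this list can be illustrated with a small NumPy sketch built on our own assumptions:

- Synchronization: two hypothetical streams, sampled at different times, are aligned on the union of their timestamps with forward fill.
- Data tracing: the row each stream exposes at each step is computed once, up front.
- Encoding: datetime64 values become int64 and float64 values become float32.

This is not WAX-ML’s implementation, only an illustration of the general technique.

```python
import numpy as np

# Two hypothetical streams sampled at different times.
t_a = np.array(["2021-01-01T00:00:00", "2021-01-01T00:00:02"], dtype="datetime64[s]")
v_a = np.array([1.0, 2.0])
t_b = np.array(
    ["2021-01-01T00:00:01", "2021-01-01T00:00:02", "2021-01-01T00:00:03"],
    dtype="datetime64[s]",
)
v_b = np.array([10.0, 20.0, 30.0])

# Synchronize: the common time grid is the sorted union of all timestamps.
grid = np.union1d(t_a, t_b)

def ffill_indices(times, grid):
    """Row of the last observation at or before each grid time (-1 if none yet)."""
    return np.searchsorted(times, grid, side="right") - 1

# "Data tracing": compute once which row each stream exposes at each step.
idx_a = ffill_indices(t_a, grid)  # [0, 0, 1, 1]
idx_b = ffill_indices(t_b, grid)  # [-1, 0, 1, 2]

# Encode: datetimes become int64 (seconds here, given the [s] unit) and
# float64 values become float32, JAX's default precision.
grid_encoded = grid.astype(np.int64)
v_a32, v_b32 = v_a.astype(np.float32), v_b.astype(np.float32)

def access(values, idx, step):
    """Per-step lookup through the precomputed indices (NaN before the first observation)."""
    i = idx[step]
    return values[i] if i >= 0 else np.float32(np.nan)
```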
The function Stream.prepare implements this Step (1).
It wraps the input function, together with the actual data and indices, in a pair of pure functions (a Haiku TransformedWithState tuple).
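The idea of closing over the data can be sketched without Haiku. The hypothetical prepare_sketch below captures the encoded data and the per-step row indices. It returns an (init, apply) pair of pure functions whose only per-step input is the integer step.

The names and signatures only loosely mimic Haiku’s TransformedWithState and are our own assumptions, not the Stream.prepare API.

```python
from typing import Callable, NamedTuple

import numpy as np

class InitApply(NamedTuple):
    """A plain (init, apply) pair, loosely shaped like Haiku's TransformedWithState."""
    init: Callable
    apply: Callable

def prepare_sketch(data, index, step_fn, init_state_fn):
    """Wrap step_fn so that it receives the data row for a given integer step."""
    def init(rng, step):
        params = {}
        state = init_state_fn(data[index[step]])
        return params, state

    def apply(params, state, rng, step):
        row = data[index[step]]  # per-step access through precomputed indices
        return step_fn(params, state, row)

    return InitApply(init, apply)

# Hypothetical step function: running sum, returning (output, new_state).
def running_sum(params, state, row):
    new_state = state + row
    return new_state, new_state

data = np.arange(6.0, dtype=np.float32).reshape(3, 2)  # 3 rows, 2 columns
index = np.array([0, 0, 1, 2])  # precomputed row per step, e.g. after forward fill
fn = prepare_sketch(data, index, running_sum, lambda row: np.zeros_like(row))
params, state = fn.init(None, 0)
out, state = fn.apply(params, state, None, 0)
```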
%%time
stream = dataframe.wax.stream()
CPU times: user 20 µs, sys: 5 µs, total: 25 µs
Wall time: 29.8 µs
Define our custom function to be applied on a dict of arrays having the same structure as the original dataset:
def my_ewma_on_dataset(dataset):
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
transform_dataset, jxs = stream.prepare(dataset, my_ewma_on_dataset)
Let’s define the initial parameters and state of the transformation we will apply.
Init params and state
from wax.unroll import init_params_state
rng = jax.random.PRNGKey(42)
params, state = init_params_state(transform_dataset, rng, jxs)
params
FlatMapping({'ewma': FlatMapping({'alpha': DeviceArray(0.1, dtype=float32)})})
assert state["ewma"]["count"].shape == (N,)
assert state["ewma"]["mean"].shape == (N,)
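The asserted shapes are consistent with an adjusted EWMA whose state holds one count and one mean per column. Here is a minimal NumPy sketch of such an update, assuming this count/mean formulation; WAX-ML’s EWMA module may organize its computation differently.

With decay = 1 - alpha, the update is:

- count_t = 1 + decay * count_{t-1}
- mean_t = mean_{t-1} + (x_t - mean_{t-1}) / count_t

This reproduces the adjust=True weighted average.

```python
import numpy as np

def ewma_init(n_columns):
    """State holding one count and one running mean per column."""
    return {
        "count": np.zeros(n_columns, np.float32),
        "mean": np.zeros(n_columns, np.float32),
    }

def ewma_update(state, x, alpha):
    """One adjusted-EWMA step; returns (output, new_state)."""
    count = 1.0 + (1.0 - alpha) * state["count"]
    mean = state["mean"] + (x - state["mean"]) / count
    return mean, {"count": count, "mean": mean}

N = 4
state = ewma_init(N)
xs = np.random.default_rng(1).normal(size=(20, N)).astype(np.float32)
for x in xs:
    out, state = ewma_update(state, x, alpha=0.1)
```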
Step (2) (compile | code tracing | execution)
In this step we:
prepare a pure function (with Haiku’s transform mechanism) by defining a “transformation” function which:
accesses the data,
applies another transformation, here EWMA;
compile it with jax.jit;
perform code tracing and execution (the last line below): unroll the transformation on the “steps” xs (a np.arange vector).
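Conceptually, unrolling means two things:

- threading the state through the transformation once per step in xs;
- stacking the per-step outputs.

This is much like what jax.lax.scan does. Here is a plain-Python sketch of that loop. unroll_sketch is a hypothetical helper, not the signature of wax.unroll.dynamic_unroll, and it skips compilation entirely.

```python
import numpy as np

def unroll_sketch(apply, params, state, xs):
    """Apply `apply` once per step in xs, threading state and stacking outputs."""
    outputs = []
    for step in xs:
        out, state = apply(params, state, step)
        outputs.append(out)
    return np.stack(outputs), state

# Hypothetical transformation: read row `step` of some data and keep a running max.
data = np.array([[1.0, 5.0], [3.0, 2.0], [2.0, 7.0]], dtype=np.float32)

def running_max(params, state, step):
    new_state = np.maximum(state, data[step])
    return new_state, new_state

xs = np.arange(len(data))  # the "steps" vector
init_state = np.full(2, -np.inf, np.float32)
outputs, final_state = unroll_sketch(running_max, None, init_state, xs)
```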
rng = next(hk.PRNGSequence(42))
outputs, state = dynamic_unroll(transform_dataset, params, state, rng, False, jxs)
outputs.device()
GpuDevice(id=0, process_index=0)
Once it has been compiled and “traced” by JAX, the function is much faster to execute:
%%timeit
outputs, _ = dynamic_unroll(transform_dataset, params, state, rng, False, jxs)
1 loop, best of 5: 1.18 s per loop
%%time
outputs, _ = dynamic_unroll(transform_dataset, params, state, rng, False, jxs)
CPU times: user 1.29 s, sys: 6.21 ms, total: 1.29 s
Wall time: 1.24 s
With the timings shown here (1.18 s versus 2.42 s for pandas), this is about 2x faster than the pandas implementation!
(The speed-up factor is obtained by measuring the execution with %timeit. We don’t know why, but when a code cell is executed only once, its execution time can vary a lot, and some executions show a speed-up of up to 100x.)
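%%time measures a single run. %timeit repeats the cell and, in the version used here, reports the best of 5 runs, which filters out one-off costs and background noise.

The same two measurements can be reproduced outside IPython with the standard timeit module; in the sketch below, workload is a stand-in for the dynamic_unroll call.

```python
import timeit

import numpy as np

x = np.random.default_rng(0).normal(size=(10_000, 10))

def workload():
    """Stand-in for the computation being benchmarked."""
    return np.cumsum(x, axis=0)

# Like %%time: a single measurement, exposed to warm-up effects and noise.
single = timeit.timeit(workload, number=1)

# Like "%timeit ... best of 5": five independent runs, keep the fastest.
runs = timeit.repeat(workload, number=1, repeat=5)
best = min(runs)
print(f"single run: {single:.4f} s, best of 5: {best:.4f} s")
```

One factor worth keeping in mind, though we have not checked that it explains the variations above: JAX dispatches computations asynchronously. A fair timing of a JAX call should therefore wait for the result, for example with outputs.block_until_ready().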
Manually prepare the data and manage the device
In order to manage the device on which the computations take place, we need even more control over the execution flow.
Instead of calling stream.prepare to build the transform_dataset function, we can do it ourselves by:
using the stream.trace_dataset function;
converting the numpy data to JAX ourselves;
putting the data on the device we want.
np_data, np_index, xs = stream.trace_dataset(dataset)
jnp_data, jnp_index, jxs = convert_to_tensors((np_data, np_index, xs), "jax")
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:2983: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
lax._check_user_dtype_supported(dtype, "asarray")
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:2983: UserWarning: Explicitly requested dtype int64 requested in asarray is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
lax._check_user_dtype_supported(dtype, "asarray")
We explicitly put the data on the CPU (this is not needed if you only have CPUs):
from jax.tree_util import tree_leaves, tree_map
cpus = jax.devices("cpu")
jnp_data, jnp_index, jxs = tree_map(
    lambda x: jax.device_put(x, cpus[0]), (jnp_data, jnp_index, jxs)
)
print("data copied to CPU device.")
data copied to CPU device.
We now have “JAX-ready” data for later fast access.
Let’s define the transformation that wraps the actual data and indices in a pair of pure functions:
%%time
@jit_init_apply
@hk.transform_with_state
def transform_dataset(step):
    dataset = tree_access_data(jnp_data, jnp_index, step)
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
CPU times: user 277 µs, sys: 0 ns, total: 277 µs
Wall time: 289 µs
And we can call it as before:
%%time
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)
CPU times: user 373 ms, sys: 163 ms, total: 537 ms
Wall time: 533 ms
outputs.device()
CpuDevice(id=0)
Step (3) (format)
Let’s come back to pandas/xarray:
%%time
y = format_dataframe(
    dataset.coords, onp.array(outputs), format_dims=dataset.dataarray.dims
)
CPU times: user 2.14 s, sys: 9.66 ms, total: 2.15 s
Wall time: 2.13 s
It’s quite slow (see WEP3 enhancement proposal).
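The case handled in this notebook is simple: one (T, N) output array aligned with a known time index and known column labels. For that case, the formatting step amounts to wrapping the array in a DataFrame.

The sketch below does only that, on a small hypothetical array. It is not a replacement for format_dataframe, which works from the dataset coordinates and dimensions, and we make no claim about its relative speed.

```python
import numpy as onp
import pandas as pd

T, N = 5, 3
index = pd.date_range("1970", periods=T, freq="s")
columns = pd.RangeIndex(N)
# Hypothetical stand-in for onp.array(outputs).
raw_outputs = onp.arange(T * N, dtype=onp.float32).reshape(T, N)

# Wrap the raw (T, N) array back into a labelled DataFrame.
df = pd.DataFrame(raw_outputs, index=index, columns=columns)
```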
GPU execution
Let’s now look at execution on GPU:
try:
    gpus = jax.devices("gpu")
    jnp_data, jnp_index, jxs = tree_map(
        lambda x: jax.device_put(x, gpus[0]), (jnp_data, jnp_index, jxs)
    )
    print("data copied to GPU device.")
    GPU_AVAILABLE = True
except RuntimeError as err:
    print(err)
    GPU_AVAILABLE = False
data copied to GPU device.
Let’s check that our data is on the GPUs:
tree_leaves(jnp_data)[0].device()
GpuDevice(id=0, process_index=0)
tree_leaves(jnp_index)[0].device()
GpuDevice(id=0, process_index=0)
jxs.device()
GpuDevice(id=0, process_index=0)
%%time
if GPU_AVAILABLE:
    rng = next(hk.PRNGSequence(42))
    outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)
CPU times: user 1.7 s, sys: 555 ms, total: 2.25 s
Wall time: 2.14 s
Let’s redefine our transform_dataset function, this time explicitly passing the device option to jax.jit.
%%time
if GPU_AVAILABLE:

    @hk.transform_with_state
    def transform_dataset(step):
        dataset = tree_access_data(jnp_data, jnp_index, step)
        return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])

    transform_dataset = type(transform_dataset)(
        transform_dataset.init, jax.jit(transform_dataset.apply, device=gpus[0])
    )
    rng = next(hk.PRNGSequence(42))
    outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)
CPU times: user 1.48 s, sys: 17.6 ms, total: 1.5 s
Wall time: 2.13 s
outputs.device()
GpuDevice(id=0, process_index=0)
%%timeit
if GPU_AVAILABLE:
    outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)
1 loop, best of 5: 1.18 s per loop