# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[<jaxlib.xla_extension.Device at 0x17678ceb0>]

🎛 The 3-steps workflow 🎛¶


Thanks to WAX-ML accessors, it is already very convenient to execute a JAX function on a dataframe in a single step, with a single line of code.

The one-step WAX-ML stream API works like this:

<data-container>.stream(...).apply(...)
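For instance, the EWMA computed throughout this notebook is available as a one-liner through the accessor. A minimal sketch (the same .wax.ewm(...).mean() call is timed on a much bigger DataFrame below):

import numpy as onp
import pandas as pd
from wax.accessors import register_wax_accessors

register_wax_accessors()  # enables the .wax accessor on pandas/xarray objects

# a small toy DataFrame, just to show the shape of the call
df = pd.DataFrame(
    onp.random.normal(size=(100, 3)),
    index=pd.date_range("1970", periods=100, freq="s"),
)

# the whole pipeline (encode, trace/compile, execute, format) runs in this one call
df_ewma = df.wax.ewm(alpha=1.0 / 10.0).mean()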

But this is not optimal because, under the hood, there are mainly three costly steps:

  • (1) (synchronize | data tracing | encode): make the data "JAX-ready".

  • (2) (compile | code tracing | execution): compile and optimize a function for XLA, then execute it.

  • (3) (format): convert the data back to pandas/xarray/numpy format.

With the wax.stream primitives, it is quite easy to explicitly split the 1-step workflow into a 3-step workflow.

This will allow the user to have full control over each step and iterate on each one.

It is actually very useful to iterate on step (2), the "calculation step", when you are doing research. You can then take full advantage of the JAX primitives, especially the jit primitive.
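As a reminder of what jit buys you when you iterate on step (2), independently of WAX-ML: the first call to a jitted function pays the tracing/compilation cost, and subsequent calls reuse the compiled XLA program. A minimal, generic sketch:

import jax
import jax.numpy as jnp

@jax.jit
def step_fn(x):
    # element-wise work that XLA fuses into a single compiled kernel
    return jnp.tanh(x) * 0.5 + 0.5

x = jnp.ones((1000, 1000))
_ = step_fn(x).block_until_ready()  # first call: trace + compile
_ = step_fn(x).block_until_ready()  # subsequent calls: compiled execution only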

Let’s illustrate how to reimplement WAX-ML EWMA yourself with the WAX-ML 3-step workflow.
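For reference, assuming WAX-ML's EWMA(alpha, adjust=True) follows the same adjusted convention as pandas' ewm(adjust=True), the quantity we are reimplementing is:

EWMA_t = (x_t + (1 - alpha) x_{t-1} + ... + (1 - alpha)^t x_0) / (1 + (1 - alpha) + ... + (1 - alpha)^t)

with alpha = 1/10 in the examples below.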

Imports¶

import numpy as onp
import pandas as pd
import xarray as xr

from wax.accessors import register_wax_accessors
from wax.external.eagerpy import convert_to_tensors
from wax.format import format_dataframe
from wax.modules import EWMA
from wax.stream import tree_access_data
from wax.unroll import unroll

register_wax_accessors()

Performance on big dataframes¶

Generate data¶

T = 1.0e5
N = 1000
T, N = map(int, (T, N))
dataframe = pd.DataFrame(
    onp.random.normal(size=(T, N)), index=pd.date_range("1970", periods=T, freq="s")
)

pandas EWMA¶

%%time
df_ewma_pandas = dataframe.ewm(alpha=1.0 / 10.0).mean()
CPU times: user 2.03 s, sys: 167 ms, total: 2.19 s
Wall time: 2.19 s

WAX-ML EWMA¶

%%time
df_ewma_wax = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
CPU times: user 1.8 s, sys: 876 ms, total: 2.68 s
Wall time: 2.67 s

It's about the same speed (here even slightly slower): in a single call we also pay the data-encoding, tracing/compilation and formatting overheads on top of the computation itself.

WAX-ML EWMA (without format step)¶

Let’s disable the final formatting step (the output is now in raw JAX format):

%%time
df_ewma_wax_no_format = dataframe.wax.ewm(alpha=1.0 / 10.0, format_outputs=False).mean()
df_ewma_wax_no_format.block_until_ready()
CPU times: user 1.8 s, sys: 1 s, total: 2.8 s
Wall time: 3.55 s
DeviceArray([[-2.9661492e-01, -3.2103235e-01, -9.6144844e-03, ...,
               4.4276214e-01,  1.1568004e+00, -1.0724162e+00],
             [-9.5860586e-02, -2.0366390e-01, -3.7967765e-01, ...,
              -6.1594141e-01,  7.7942121e-01,  1.8111229e-02],
             [ 2.8211397e-01,  1.7749734e-01, -2.0034584e-01, ...,
              -7.6095390e-01,  5.3778893e-01,  3.1442198e-01],
             ...,
             [-1.1917123e-01,  1.3764068e-01,  2.4761766e-01, ...,
               2.0842913e-01,  2.5283977e-01, -1.1205430e-01],
             [-1.0947308e-01,  3.6484647e-01,  2.4164049e-01, ...,
               2.7038181e-01,  2.4539444e-01,  6.2920153e-05],
             [ 4.8219025e-02,  1.5648599e-01,  1.2161890e-01, ...,
               2.0765728e-01,  8.9837506e-02,  1.0943251e-01]],            dtype=float32)
type(df_ewma_wax_no_format)
jaxlib.xla_extension.DeviceArray

Let's check the device on which the calculation was performed (if a GPU is available this should be a GpuDevice, otherwise a CpuDevice):

df_ewma_wax_no_format.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>

Now we will see how to break down the WAX-ML one-liners <dataset>.ewm(...).mean() or <dataset>.stream(...).apply(...) into three steps (condensed in the sketch right after this list):

  • a preparation step, where we prepare JAX-ready data and functions;

  • a processing step, where we execute the JAX program;

  • a post-processing step, where we format the results in pandas or xarray format.
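Condensed into a single sketch assembled from the cells that follow (the functions stream.prepare, unroll and format_dataframe are all introduced below), the three steps look like this:

import numpy as onp
import pandas as pd
import xarray as xr

from wax.accessors import register_wax_accessors
from wax.format import format_dataframe
from wax.modules import EWMA
from wax.unroll import unroll

register_wax_accessors()

# toy data in dataset format
dataframe = pd.DataFrame(
    onp.random.normal(size=(100, 3)),
    index=pd.date_range("1970", periods=100, freq="s"),
)
dataset = xr.DataArray(dataframe).to_dataset(name="dataarray")

# step (1): prepare JAX-ready data and a pair of pure functions
stream = dataset.wax.stream()
transform_dataset, jxs = stream.prepare(
    dataset, lambda ds: EWMA(alpha=1.0 / 10.0, adjust=True)(ds["dataarray"])
)

# step (2): execute the JAX program
outputs = unroll(transform_dataset)(jxs)

# step (3): format the results back to pandas/xarray
y = format_dataframe(
    dataset.coords, onp.array(outputs), format_dims=dataset.dataarray.dims
)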

Generate data (in dataset format)¶

The WAX-ML Stream object works on datasets, so let's transform the DataFrame into an xarray Dataset:

dataset = xr.DataArray(dataframe).to_dataset(name="dataarray")
del dataframe

Step (1) (synchronize | data tracing | encode)¶

In this step, WAX-ML does the following:

  • "data tracing": prepare the indices for fast access in the JAX function access_data;

  • synchronize streams if there are multiple ones (this functionality has the options freq and ffills);

  • encode and convert data from numpy to JAX: encoders are used for datetime64 and string_ dtypes. Be aware that by default JAX works in float32 (see JAX's Common Gotchas for how to work in float64; a minimal sketch follows this list).
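For example, if you need float64 precision, the standard JAX switch (not WAX-ML specific) is:

import jax

# must be set before any array is created or any function is jitted
jax.config.update("jax_enable_x64", True)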

We have a function Stream.prepare that implements this step (1). It prepares a function that wraps the input function together with the actual data and indices in a pair of pure functions (a Haiku TransformedWithState tuple).

%%time
stream = dataset.wax.stream()
CPU times: user 244 µs, sys: 36 µs, total: 280 µs
Wall time: 287 µs

Define our custom function to be applied on a dict of arrays having the same structure as the original dataset:

def my_ewma_on_dataset(dataset):
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
transform_dataset, jxs = stream.prepare(dataset, my_ewma_on_dataset)

Let's define the init parameters and state of the transformation we will apply.

Step (2) (compile | code tracing | execution)¶

In this step we:

  • define a pure "transformation" function (using Haiku's transform mechanism) which:

    • accesses the data,

    • applies another transformation, here: EWMA;

  • compile it with jax.jit;

  • perform the code tracing and execution (the last line):

    • unroll the transformation on the "steps" xs (an np.arange vector).

outputs = unroll(transform_dataset)(jxs)
outputs.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>

Once it has been compiled and “traced” by JAX, the function is much faster to execute:

%%timeit
outputs = unroll(transform_dataset)(jxs)
_ = outputs.block_until_ready()
619 ms ± 9.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This is about 3x faster than the pandas implementation!

Manually prepare the data and manage the device¶

In order to manage the device on which the computations take place, we need even more control over the execution flow. Instead of calling stream.prepare to build the transform_dataset function, we can do it ourselves by:

  • using the stream.trace_dataset function,

  • converting the numpy data to JAX ourselves,

  • putting the data on the device we want.

np_data, np_index, xs = stream.trace_dataset(dataset)
jnp_data, jnp_index, jxs = convert_to_tensors((np_data, np_index, xs), "jax")

We explicitly put the data on the CPU (this is not needed if you only have CPUs):

from jax.tree_util import tree_leaves, tree_map

cpus = jax.devices("cpu")
jnp_data, jnp_index, jxs = tree_map(
    lambda x: jax.device_put(x, cpus[0]), (jnp_data, jnp_index, jxs)
)
print("data copied to CPU device.")
data copied to CPU device.

We have now “JAX-ready” data for later fast access.

Let's define the transformation that wraps the actual data and indices in a pair of pure functions:

@jax.jit
@unroll
def transform_dataset(step):
    dataset = tree_access_data(jnp_data, jnp_index, step)
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])

And we can call it as before:

%%time
outputs = transform_dataset(jxs)
_ = outputs.block_until_ready()
CPU times: user 1.72 s, sys: 1.06 s, total: 2.78 s
Wall time: 3.31 s
%%time
outputs = transform_dataset(jxs)
_ = outputs.block_until_ready()
CPU times: user 546 ms, sys: 52.8 ms, total: 599 ms
Wall time: 567 ms
outputs.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>

Step(3) (format)¶

Let’s come back to pandas/xarray:

%%time
y = format_dataframe(
    dataset.coords, onp.array(outputs), format_dims=dataset.dataarray.dims
)
CPU times: user 27.9 ms, sys: 70.4 ms, total: 98.2 ms
Wall time: 144 ms

It’s quite slow (see WEP3 enhancement proposal).

GPU execution¶

Let's look at execution on a GPU:

try:
    gpus = jax.devices("gpu")
    jnp_data, jnp_index, jxs = tree_map(
        lambda x: jax.device_put(x, gpus[0]), (jnp_data, jnp_index, jxs)
    )
    print("data copied to GPU device.")
    GPU_AVAILABLE = True
except RuntimeError as err:
    print(err)
    GPU_AVAILABLE = False
Requested backend gpu, but it failed to initialize: Not found: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host

Let's check which device our data is now on:

tree_leaves(jnp_data)[0].device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
tree_leaves(jnp_index)[0].device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
jxs.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
%%time
if GPU_AVAILABLE:
    outputs = unroll(transform_dataset)(jxs)
CPU times: user 3 µs, sys: 2 µs, total: 5 µs
Wall time: 11.2 µs

Let's redefine our transform_dataset function, explicitly passing the device option to jax.jit:

%%time
from functools import partial

if GPU_AVAILABLE:

    @partial(jax.jit, device=gpus[0])
    @unroll
    def transform_dataset(step):
        dataset = tree_access_data(jnp_data, jnp_index, step)
        return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])

    outputs = transform_dataset(jxs)
CPU times: user 8 µs, sys: 2 µs, total: 10 µs
Wall time: 15 µs
outputs.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
%%timeit
if GPU_AVAILABLE:
    outputs = unroll(transform_dataset)(jxs)
    _ = outputs.block_until_ready()
12 ns ± 0.0839 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)