🔄 Online learning for time series prediction 🔄¶

In [1], the authors develop an online learning method to predict time-series generated by an ARMA (autoregressive moving average) model.

They develop an effective online learning algorithm based on an improper learning approach: an AR model with a sufficiently long horizon is used for prediction, and its parameters are updated online using either an “online Newton” algorithm [2] or a stochastic gradient descent algorithm.

This approach addresses the prediction problem without assuming that the noise terms are Gaussian, identically distributed, or even independent. Furthermore, the authors show that their algorithm’s performance asymptotically approaches that of the best ARMA model in hindsight.
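To fix notation (our summary of the setting in [1]): an ARMA(p, q) process evolves as

\[y_t = \sum_{i=1}^{p} \alpha_i \, y_{t-i} + \sum_{j=1}^{q} \beta_j \, \varepsilon_{t-j} + \varepsilon_t,\]

where the noise terms \(\varepsilon_t\) are not observed. The improper learning approach therefore fits an AR model with a sufficiently long memory \(m\),

\[\hat{y}_t = \sum_{i=1}^{m} \gamma_i \, y_{t-i},\]

whose coefficients \(\gamma_i\) are updated online.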

We use WAX-ML to reproduce their empirical results.

We first focus on reproducing “setting 1” for sanity checks and show how to set up a training environment with WAX-ML to study improper learning of ARMA time-series models.

We then study the behavior of the method with different optimizers in the non-stationary environments proposed in [1] (settings 2, 3, and 4).

We use the following modules from WAX-ML:

  • ARMA: to generate a modeled time-series.

  • SNARIMAX: to adaptively learn to predict the generated time-series.

  • GymFeedback: to set up a training loop.

  • VMap: to add batch dimensions to the training loop.

  • optim.newton: a Newton algorithm as used in [1] and developed in [2]; it extends optax optimizers (see the update sketched after this list).

  • OnlineOptimizer: a wrapper around a model, its loss, and an optimizer for online learning.
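For reference, the “online Newton” update of [2], exposed here through optim.newton with parameters step_size (\(\eta\)) and eps (\(\epsilon\)), takes roughly the following form (our summary; see [2] for the precise statement):

\[A_t = A_{t-1} + \nabla_t \nabla_t^\top, \quad A_0 = \epsilon I, \qquad w_{t+1} = \Pi\left(w_t - \eta \, A_t^{-1} \nabla_t\right),\]

where \(\nabla_t\) is the gradient of the loss at time \(t\) and \(\Pi\) denotes a projection onto the feasible parameter set.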

References¶

[1] Anava, O., Hazan, E., Mannor, S. and Shamir, O., 2013. Online learning for time series prediction. In Conference on Learning Theory (pp. 172-184).

[2] Hazan, E., Agarwal, A. and Kale, S., 2007. Logarithmic regret algorithms for online convex optimization. Machine Learning, 69(2-3), pp. 169-192.

%pylab inline
%load_ext autoreload
%autoreload 2
Populating the interactive namespace from numpy and matplotlib
from typing import Any, NamedTuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as onp
import optax
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

from wax.modules import (
    ARMA,
    SNARIMAX,
    GymFeedback,
    Lag,
    OnlineOptimizer,
    UpdateParams,
    VMap,
)
from wax.optim import newton
from wax.unroll import unroll_transform_with_state
T = 10000
N_BATCH = 20
N_STEP_SIZE = 10
N_EPS = 5

ARMA¶

Let’s generate a sample of “setting 1” of [1]:

alpha = jnp.array([0.6, -0.5, 0.4, -0.4, 0.3])
beta = jnp.array([0.3, -0.2])

rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (1000,))
sim = unroll_transform_with_state(lambda eps: ARMA(alpha, beta)(eps))
params, state = sim.init(rng, eps)
y, state = sim.apply(params, state, rng, eps)
pd.Series(y).plot()
../_images/07_Online_Time_Series_Prediction_6_0.png

SNARIMAX¶

Let’s set up an online model that tries to learn the dynamics of the time-series.
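(The module name follows, we believe, the SNARIMAX model of the river library: Seasonal Non-linear AutoRegressive Integrated Moving-Average with eXogenous inputs. In this notebook we only use its purely autoregressive part.)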

First, let’s run the filter with its initial random weights.

def predict(y, X=None):
    # SNARIMAX(10, 0, 0): a purely autoregressive model of order p=10
    # (no differencing, no moving-average terms).
    return SNARIMAX(10, 0, 0)(y, X)


sim = unroll_transform_with_state(predict)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (10,))
params, state = sim.init(rng, y)
(y_pred, _), state = sim.apply(params, state, rng, y)

pd.Series((y - y_pred)).plot()
../_images/07_Online_Time_Series_Prediction_10_0.png
def evaluate(y_pred, y):
    return jnp.linalg.norm(y_pred - y) ** 2, {}


def lag(shift=1):
    def __call__(y, X=None):
        yp = Lag(shift)(y)
        Xp = Lag(shift)(X) if X is not None else None
        return yp, Xp

    return __call__


def predict_and_evaluate(y, X=None):
    # predict with lagged data
    y_pred, pred_info = predict(*lag(1)(y, X))

    # evaluate loss with actual data
    loss, loss_info = evaluate(y_pred, y)

    return loss, dict(pred_info=pred_info, loss_info=loss_info)
sim = unroll_transform_with_state(predict_and_evaluate)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (1000,))
params, state = sim.init(rng, y)
(loss, _), state = sim.apply(params, state, rng, y)

pd.Series(loss).expanding().mean().plot()
../_images/07_Online_Time_Series_Prediction_12_0.png

Since the model is not trained and the coefficients of the SNARIMAX filter are chosen randomly, the loss may diverge. This motivates projecting the parameters onto a bounded set, which we set up below.

params
FlatMapping({
  'snarimax/~/linear': FlatMapping({
                         'w': DeviceArray([[-0.47023755],
                                           [ 0.07070494],
                                           [-0.07388116],
                                           [ 0.13453043],
                                           [ 0.07728617],
                                           [ 0.0851969 ],
                                           [-0.03324771],
                                           [ 0.17115903],
                                           [ 0.1023274 ],
                                           [-0.4804019 ]], dtype=float32),
                         'b': DeviceArray([0.], dtype=float32),
                       }),
})

Learn¶

To learn the model parameters we will use the OnlineOptimizer of WAX-ML.
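Schematically (our reading of the module, not its exact implementation), at each time step OnlineOptimizer evaluates the loss on the new observation, takes one optimizer step on the model parameters, and optionally projects them back onto a feasible set:

\[w_{t+1} = \mathrm{project}\big(\mathrm{opt}(w_t, \nabla_w \ell_t(w_t))\big).\]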

Setup projection¶

We can setup a projection for the parameters:

def project_params(params, opt_state=None):
    w = params["snarimax/~/linear"]["w"]
    w = jnp.clip(w, -1, 1)
    params["snarimax/~/linear"]["w"] = w
    return params
project_params(hk.data_structures.to_mutable_dict(params))
{'snarimax/~/linear': {'w': DeviceArray([[-0.47023755],
               [ 0.07070494],
               [-0.07388116],
               [ 0.13453043],
               [ 0.07728617],
               [ 0.0851969 ],
               [-0.03324771],
               [ 0.17115903],
               [ 0.1023274 ],
               [-0.4804019 ]], dtype=float32),
  'b': DeviceArray([0.], dtype=float32)}}
def learn(y, X=None):
    optim_res = OnlineOptimizer(
        predict_and_evaluate, optax.sgd(1.0e-2), project_params=project_params
    )(y, X)
    return optim_res

Let’s train:

sim = unroll_transform_with_state(learn)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (1000,))
params, state = sim.init(rng, eps)
optim_res, state = sim.apply(params, state, rng, eps)

pd.Series(optim_res.loss).expanding().mean().plot()
../_images/07_Online_Time_Series_Prediction_23_0.png

Let’s look at the latest weights:

jax.tree_map(lambda x: x[-1], optim_res.updated_params)
FlatMapping({
  'snarimax/~/linear': FlatMapping({
                         'b': DeviceArray([-0.14335798], dtype=float32),
                         'w': DeviceArray([[-0.06957585],
                                           [ 0.11199051],
                                           [-0.14902674],
                                           [-0.02122167],
                                           [-0.03742805],
                                           [ 0.00436568],
                                           [-0.10548005],
                                           [-0.05702855],
                                           [-0.01185708],
                                           [-0.01168777]], dtype=float32),
                       }),
})

Learn and Forecast¶
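At each time step, the parameters just updated by OnlineOptimizer are injected into the prediction model with the UpdateParams module, so that the forecast always uses the latest learned weights.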

class ForecastInfo(NamedTuple):
    optim: Any
    forecast: Any
def learn_and_forecast(y, X=None):
    optim_res = OnlineOptimizer(
        predict_and_evaluate, optax.sgd(1.0e-3), project_params=project_params
    )(*lag(1)(y, X))

    predict_params = optim_res.updated_params

    forecast, forecast_info = UpdateParams(predict)(predict_params, y, X)
    return forecast, ForecastInfo(optim_res, forecast_info)
sim = unroll_transform_with_state(learn_and_forecast)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (1000,))
params, state = sim.init(rng, y)
(forecast, info), state = sim.apply(params, state, rng, y)

pd.Series(info.optim.loss).expanding().mean().plot()
<AxesSubplot:>
../_images/07_Online_Time_Series_Prediction_29_1.png

Gym simulation¶

Now let’s wrap up the training loop in a gym feedback loop.
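Conceptually, one step of such a feedback loop chains the environment and the agent. Here is a minimal sketch of the idea (our illustration, not WAX-ML’s actual GymFeedback implementation):

def feedback_step(agent, env, prev_action, raw_obs):
    # The environment evaluates the agent's previous action and
    # produces a reward together with the next observation.
    rw, obs, env_info = env(prev_action, raw_obs)
    # The agent observes and emits a new action (here: a forecast).
    action, agent_info = agent(obs)
    return action, (rw, env_info, agent_info)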

Environment¶

Let’s build an environment corresponding to “setting 1” in [1]:

def build_env():
    def env(action, obs):
        y_pred, eps = action, obs
        ar_coefs = jnp.array([0.6, -0.5, 0.4, -0.4, 0.3])
        ma_coefs = jnp.array([0.3, -0.2])

        y = ARMA(ar_coefs, ma_coefs)(eps)
        # prediction used on a fresh y observation.
        rw = -((y - y_pred) ** 2)

        env_info = {"y": y, "y_pred": y_pred}
        obs = y
        return rw, obs, env_info

    return env
env = build_env()

Agent¶

Let’s build an agent:

from optax import OptState  # public alias; avoids importing from optax._src
def build_agent(time_series_model=None, opt=None):
    if time_series_model is None:
        time_series_model = lambda y, X: SNARIMAX(10)(y, X)

    if opt is None:
        opt = optax.sgd(1.0e-3)

    class AgentInfo(NamedTuple):
        optim: Any
        forecast: Any

    class ModelWithLossInfo(NamedTuple):
        pred: Any
        loss: Any

    def agent(obs):
        if isinstance(obs, tuple):
            y, X = obs
        else:
            y = obs
            X = None

        def evaluate(y_pred, y):
            return jnp.linalg.norm(y_pred - y) ** 2, {}

        def model_with_loss(y, X=None):
            # predict with lagged data
            y_pred, pred_info = time_series_model(*lag(1)(y, X))

            # evaluate loss with actual data
            loss, loss_info = evaluate(y_pred, y)

            return loss, ModelWithLossInfo(pred_info, loss_info)

        def project_params(params: Any, opt_state: OptState = None):
            del opt_state
            return jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)

        def split_params(params):
            def filter_params(m, n, p):
                # print(m, n, p)
                return m.endswith("snarimax/~/linear") and n == "w"

            return hk.data_structures.partition(filter_params, params)

        def learn_and_forecast(y, X=None):
            optim_res = OnlineOptimizer(
                model_with_loss,
                opt,
                project_params=project_params,
                split_params=split_params,
            )(*lag(1)(y, X))

            predict_params = optim_res.updated_params

            y_pred, forecast_info = UpdateParams(time_series_model)(
                predict_params, y, X
            )
            return y_pred, AgentInfo(optim_res, forecast_info)

        return learn_and_forecast(y, X)

    return agent
agent = build_agent()

Gym loop¶

def gym_loop(eps):
    return GymFeedback(agent, env)(eps)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T,))
sim = unroll_transform_with_state(gym_loop)
params, state = sim.init(rng, eps)
(gym, info), state = sim.apply(params, state, rng, eps)
pd.Series(info.agent.optim.loss).expanding().mean().plot(label="agent loss")
pd.Series(-gym.reward).expanding().mean().plot(xlim=(0, 100), label="env loss")
plt.legend()
<matplotlib.legend.Legend at 0x167c42490>
../_images/07_Online_Time_Series_Prediction_43_1.png
pd.Series(-gym.reward).expanding().mean().plot()  # ylim=(0.09, 0.15)
<AxesSubplot:>
../_images/07_Online_Time_Series_Prediction_44_1.png

We see that the agent suffers the same loss as the environment but with a time lag.

Batch simulations¶

Average over 20 experiments¶

Slow version¶

First, let’s do it “naively” with a simple Python for loop.

%%time
rng = jax.random.PRNGKey(42)
sim = unroll_transform_with_state(gym_loop)

res = {}
for i in tqdm(onp.arange(N_BATCH)):
    rng, _ = jax.random.split(rng)
    eps = jax.random.normal(rng, (T,)) * 0.3
    params, state = sim.init(rng, eps)
    (gym_output, gym_info), final_state = sim.apply(params, state, rng, eps)
    res[i] = gym_info
CPU times: user 13.7 s, sys: 182 ms, total: 13.9 s
Wall time: 13.8 s
pd.DataFrame({k: pd.Series(v.agent.optim.loss) for k, v in res.items()}).mean(
    1
).expanding().mean().plot(ylim=(0.09, 0.15))
../_images/07_Online_Time_Series_Prediction_50_0.png

Fast version with vmap¶

Instead of using a for loop, we can use JAX’s vmap transformation!
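For readers new to vmap, here is a tiny standalone illustration (reusing the jax and jnp imports above): vmap lifts a function written for a single example to a whole batch, without any Python loop.

f = lambda x: jnp.dot(x, x)  # scalar result for a single vector
batched_f = jax.vmap(f)  # the same function, mapped over a leading batch axis
batched_f(jnp.ones((20, 3)))  # 20 results, each equal to 3.0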

%%time
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (N_BATCH, T)) * 0.3

rng = jax.random.PRNGKey(42)
rng = jax.random.split(rng, num=N_BATCH)
sim = unroll_transform_with_state(gym_loop)
params, state = jax.vmap(sim.init)(rng, eps)
(gym_output, gym_info), final_state = jax.vmap(sim.apply)(params, state, rng, eps)
CPU times: user 2.09 s, sys: 61 ms, total: 2.15 s
Wall time: 2.13 s

This is much faster: the batch is processed as one vectorized computation instead of 20 sequential simulations.

pd.DataFrame(gym_info.agent.optim.loss).mean().expanding().mean().plot(
    ylim=(0.09, 0.15)
)
../_images/07_Online_Time_Series_Prediction_54_0.png
i_batch = 0
w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"][i_batch, :, :, 0]
w = pd.DataFrame(w)

ax = w.plot(title=f"weights on batch {i_batch}")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight value")


plt.figure()
w.iloc[-1][::-1].plot(kind="bar")
../_images/07_Online_Time_Series_Prediction_55_0.png ../_images/07_Online_Time_Series_Prediction_55_1.png
w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"].mean(axis=0)[:, :, 0]
w = pd.DataFrame(w)
ax = w.plot(title="averaged weights (over batches)")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight")

plt.figure()
w.iloc[-1][::-1].plot(kind="bar")
../_images/07_Online_Time_Series_Prediction_56_0.png ../_images/07_Online_Time_Series_Prediction_56_1.png

With VMap module¶

We can use the WAX-ML wrapper module VMap, which permits an even simpler syntax.

Note: we have to swap the positions of the time and batch dimensions when generating the noise variable eps.

%%time
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3


def batched_gym_loop(eps):
    return VMap(gym_loop)(eps)


sim = unroll_transform_with_state(batched_gym_loop)

rng = jax.random.PRNGKey(43)
params, state = sim.init(rng, eps)
(gym_output, gym_info), final_state = sim.apply(params, state, rng, eps)
CPU times: user 1.53 s, sys: 15.5 ms, total: 1.54 s
Wall time: 1.53 s
pd.DataFrame(gym_info.agent.optim.loss).shape
(10000, 20)
pd.DataFrame(gym_info.agent.optim.loss).mean(1).expanding().mean().plot(
    ylim=(0.09, 0.15)
)
../_images/07_Online_Time_Series_Prediction_61_0.png
i_batch = 0
w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"][:, i_batch, :, 0]
w = pd.DataFrame(w)

ax = w.plot(title=f"weights on batch {i_batch}")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight value")


plt.figure()
w.iloc[-1][::-1].plot(kind="bar")
../_images/07_Online_Time_Series_Prediction_62_0.png ../_images/07_Online_Time_Series_Prediction_62_1.png
w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"].mean(axis=1)[:, :, 0]
w = pd.DataFrame(w)
ax = w.plot(title="averaged weights (over batches)")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight")

plt.figure()
w.iloc[-1][::-1].plot(kind="bar")
../_images/07_Online_Time_Series_Prediction_63_0.png ../_images/07_Online_Time_Series_Prediction_63_1.png

Taking the mean inside the simulation¶

%%time


def add_batch(fun, take_mean=True):
    def fun_batch(*args, **kwargs):
        res = VMap(fun)(*args, **kwargs)
        if take_mean:
            res = jax.tree_map(lambda x: x.mean(axis=0), res)
        return res

    return fun_batch


gym_loop_batch = add_batch(gym_loop)
sim = unroll_transform_with_state(gym_loop_batch)

rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3

params, state = sim.init(rng, eps)
(gym_output, gym_info), final_state = sim.apply(params, state, rng, eps)
CPU times: user 1.53 s, sys: 18.8 ms, total: 1.54 s
Wall time: 1.54 s
pd.Series(-gym_output.reward).expanding().mean().plot(ylim=(0.09, 0.15))
../_images/07_Online_Time_Series_Prediction_66_0.png
w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"][:, :, 0]
w = pd.DataFrame(w)
ax = w.plot(title="averaged weights (over batches)")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight")

plt.figure()
w.iloc[-1][::-1].plot(kind="bar")
../_images/07_Online_Time_Series_Prediction_67_0.png ../_images/07_Online_Time_Series_Prediction_67_1.png

Hyperparameter tuning¶

First-order optimizers¶

We will consider several first-order optimizers, namely:

  • SGD

  • ADAGRAD

  • RMSPROP

  • ADAM

For each of them, we will scan the step_size parameter \(\eta\).

We will average results over N_BATCH = 20 batches and consider trajectories of length T = 10,000.

Finally, we will pick the best parameter based on the minimum average loss over the last 5,000 time steps.

%%time

STEP_SIZE_idx = pd.Index(onp.logspace(-4, 1, 30), name="step_size")
STEP_SIZE = jax.device_put(STEP_SIZE_idx.values)
OPTIMIZERS = [optax.sgd, optax.adagrad, optax.rmsprop, optax.adam]

res = {}
for optimizer in tqdm(OPTIMIZERS):

    def gym_loop_scan_hparams(eps):
        def scan_params(step_size):
            return GymFeedback(build_agent(opt=optimizer(step_size)), env)(eps)

        res = VMap(scan_params)(STEP_SIZE)
        return res

    sim = unroll_transform_with_state(add_batch(gym_loop_scan_hparams))
    rng = jax.random.PRNGKey(42)
    eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3

    params, state = sim.init(rng, eps)
    _res, state = sim.apply(params, state, rng, eps)
    res[optimizer.__name__] = _res
CPU times: user 11.5 s, sys: 111 ms, total: 11.6 s
Wall time: 11.6 s
ax = None
BEST_STEP_SIZE = {}
BEST_GYM = {}
for name, (gym, info) in res.items():

    loss = pd.DataFrame(-gym.reward, columns=STEP_SIZE).iloc[-5000:].mean()

    BEST_STEP_SIZE[name] = loss.idxmin()
    best_idx = loss.reset_index(drop=True).idxmin()
    BEST_GYM[name] = jax.tree_map(lambda x: x[:, best_idx], gym)

    ax = loss[loss < 0.15].plot(logx=True, logy=False, ax=ax, label=name)
plt.legend()
../_images/07_Online_Time_Series_Prediction_71_0.png
for name, gym in BEST_GYM.items():
    ax = (
        pd.Series(-gym.reward)
        .expanding()
        .mean()
        .plot(
            label=f"{name}    -    $\eta$={BEST_STEP_SIZE[name]:.2e}", ylim=(0.09, 0.15)
        )
    )
ax.legend(bbox_to_anchor=(1.0, 1.0))
../_images/07_Online_Time_Series_Prediction_72_0.png
pd.Series(BEST_STEP_SIZE).plot(kind="bar", logy=True)
../_images/07_Online_Time_Series_Prediction_73_0.png

Newton algorithm¶

Now let’s consider the Newton algorithm.

First, let’s test it with one set of parameters, averaging over N_BATCH batches.

%%time


@add_batch
def gym_loop_newton(eps):
    return GymFeedback(build_agent(opt=newton(0.05, eps=20.0)), env)(eps)


sim = unroll_transform_with_state(gym_loop_newton)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3

params, state = sim.init(rng, eps)
(gym, info), state = sim.apply(params, state, rng, eps)
CPU times: user 1.88 s, sys: 43.5 ms, total: 1.92 s
Wall time: 1.73 s
pd.Series(-gym.reward).expanding().mean().plot()
../_images/07_Online_Time_Series_Prediction_78_0.png
%%time

STEP_SIZE = pd.Index(onp.logspace(-2, 3, 10), name="step_size")
EPS = pd.Index(onp.logspace(-4, 3, 5), name="eps")

HPARAMS_idx = pd.MultiIndex.from_product([STEP_SIZE, EPS])
HPARAMS = jnp.stack(list(map(onp.array, HPARAMS_idx)))


@add_batch
def gym_loop_scan_hparams(eps):
    def scan_params(hparams):
        step_size, newton_eps = hparams
        agent = build_agent(opt=newton(step_size, eps=newton_eps))
        return GymFeedback(agent, env)(eps)

    return VMap(scan_params)(HPARAMS)


sim = unroll_transform_with_state(gym_loop_scan_hparams)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3

params, state = sim.init(rng, eps)
res_newton, state = sim.apply(params, state, rng, eps)
CPU times: user 16.1 s, sys: 118 ms, total: 16.2 s
Wall time: 16 s
gym_newton, info_newton = res_newton
# Average loss over the last 5000 time steps, for each (step_size, eps) pair.
loss_newton = (
    pd.DataFrame(-gym_newton.reward, columns=HPARAMS_idx).iloc[-5000:].mean().unstack()
)
sns.heatmap(loss_newton[loss_newton < 0.4], annot=True, cmap="YlGnBu")

STEP_SIZE, NEWTON_EPS = loss_newton.stack().idxmin()

x = -gym_newton.reward[-5000:].mean(axis=0)
# Replace NaNs (diverged runs) by +inf so that argmin ignores them
# (jnp.where is used here since jax.ops.index_update was removed from recent JAX versions).
x = jnp.where(jnp.isnan(x), jnp.inf, x)
I_BEST_PARAM = jnp.argmin(x)


BEST_NEWTON_GYM = jax.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)
print("Best newton parameters: ", STEP_SIZE, NEWTON_EPS)
Best newton parameters:  0.464158883361278 0.31622776601683794
../_images/07_Online_Time_Series_Prediction_83_1.png
for name, gym in BEST_GYM.items():
    pd.Series(-gym.reward).rolling(5000, min_periods=5000).mean().plot(
        label=f"{name}    -    $\eta$={BEST_STEP_SIZE[name]:.2e}", ylim=(0.09, 0.1)
    )

gym = BEST_NEWTON_GYM
ax = (
    pd.Series(-gym.reward)
    .rolling(5000, min_periods=5000)
    .mean()
    .plot(
        label=f"Newton    -    $\eta$={STEP_SIZE:.2e},    $\epsilon$={NEWTON_EPS:.2e}"
    )
)
ax.legend(bbox_to_anchor=(1.0, 1.0))
plt.title("Rolling mean of loss (5000) time-steps")
../_images/07_Online_Time_Series_Prediction_84_0.png
for name, gym in BEST_GYM.items():
    pd.Series(-gym.reward).expanding().mean().plot(
        label=f"{name}    -    $\eta$={BEST_STEP_SIZE[name]:.2e}", ylim=(0.09, 0.15)
    )
gym = BEST_NEWTON_GYM
ax = (
    pd.Series(-gym.reward)
    .expanding()
    .mean()
    .plot(
        label=f"Newton    -    $\eta$={STEP_SIZE:.2e},    $\epsilon$={NEWTON_EPS:.2e}"
    )
)
ax.legend(bbox_to_anchor=(1.0, 1.0))
../_images/07_Online_Time_Series_Prediction_85_0.png

In agreement with results in [1], we see that Newton’s algorithm performs much better than SGD.

In addition, we note that:

  • ADAGRAD’s performance lies between that of Newton and SGD.

  • RMSPROP and ADAM do not perform well in this online setting.