Overview
import altair as alt
import arviz as az
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpyro
import pandas as pd
from jax import grad, jit
from jax import numpy as jnp
from jax import random, vmap
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS
from lqg import LQG, Actor, Dynamics, System, xcorr

# Hide the top and right axis spines in all matplotlib plots
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False
# Our DataFrames exceed Altair's default limit of 5000 rows, so we turn off the row-limit check
alt.data_transformers.disable_max_rows()
Quick intro to JAX
1. NumPy-like API: jax.numpy
JAX is a library that enables transformations of array-manipulating programs written with a NumPy-like API. You can think of JAX as differentiable NumPy that runs on accelerators such as GPUs and TPUs. Many NumPy programs would run just as well in JAX if you replace np with jnp.
a = jnp.array([[1., 2.],
               [3., 4.]])
key = random.PRNGKey(1)
b = random.normal(key, shape=(2, 3))
a @ b
Array([[ 1.6896877, -0.6240323,  1.5889112],
       [ 4.3366823, -2.2179935,  4.184889 ]], dtype=float32)
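To illustrate that NumPy code often carries over to JAX unchanged, the sketch below runs one function on both backends by passing the array module as an argument (normalize is a hypothetical helper). It also shows one common difference: JAX arrays are immutable, so in-place assignment such as y[0] = 10. is replaced by the functional .at[...].set(...) syntax.

```python
import numpy as np
from jax import numpy as jnp

def normalize(x, xp):
    # Same code for NumPy and jax.numpy: only the array module `xp` changes
    return (x - xp.mean(x)) / xp.std(x)

data = [1.0, 2.0, 3.0, 4.0]
out_np = normalize(np.asarray(data), np)
out_jnp = normalize(jnp.asarray(data), jnp)
print(np.allclose(out_np, out_jnp, atol=1e-6))  # True

# JAX arrays are immutable: .at[...].set(...) returns an updated copy
y = jnp.zeros(3)
y2 = y.at[0].set(10.0)  # y itself is unchanged
print(y, y2)
```

Note that NumPy computes in float64 by default, while JAX defaults to float32, so results agree only up to float32 precision.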
2. Automatic differentiation: grad
You can think of jax.grad by analogy to the $\nabla$ operator from vector calculus. Given a function $f(x)$, $\nabla f$ represents the function that computes $f$’s gradient, i.e.
$$ (\nabla f)(x)_i = \frac{\partial f}{\partial x_i} (x). $$
Analogously, jax.grad(f) is the function that computes the gradient, so jax.grad(f)(x) is the gradient of f at x.
def f(x):
    return jnp.sin(x)
grad(f)(jnp.pi)
Array(-1., dtype=float32, weak_type=True)
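For a scalar-valued function of a vector, jax.grad returns the full gradient vector. A minimal sketch with $f(x) = \sum_i x_i^2$, whose gradient is $2x$, checked against a central finite difference (sq_norm is a hypothetical example function):

```python
from jax import grad
from jax import numpy as jnp

def sq_norm(x):
    # f(x) = sum_i x_i^2, so the gradient is (∇f)(x)_i = 2 x_i
    return jnp.sum(x ** 2)

x = jnp.array([1.0, -2.0, 3.0])
g = grad(sq_norm)(x)  # same shape as x: one partial derivative per entry
print(g)

# Sanity check of the first component with a central finite difference
eps = 1e-2
e0 = jnp.array([1.0, 0.0, 0.0])
fd = (sq_norm(x + eps * e0) - sq_norm(x - eps * e0)) / (2 * eps)
print(fd, g[0])
```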
3. Easy vectorization: vmap
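In JAX, the jax.vmap transformation automatically generates a vectorized version of a function. Since jax.grad requires a scalar-valued function, grad(f) can only be evaluated at a single point; vmap(grad(f)) maps it over every element of x.
x = jnp.linspace(0, 2 * jnp.pi)
plt.plot(x, f(x), label="f(x) = sin(x)")
plt.plot(x, vmap(grad(f))(x), label="f'(x) = cos(x)")
plt.legend()
plt.show()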
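vmap also handles functions with several arguments. The in_axes argument states which arguments are batched and along which axis; None means the argument is shared across the whole batch. A minimal sketch comparing vmap to an explicit Python loop (weighted_sum is a hypothetical example function):

```python
from jax import vmap
from jax import numpy as jnp

def weighted_sum(w, x):
    # Written for a single example: w and x are 1-D vectors
    return jnp.dot(w, x)

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 examples

# Explicit Python loop over the batch
looped = jnp.stack([weighted_sum(w, x) for x in xs])

# vmap: share w across the batch (None), map over the first axis of xs (0)
batched = vmap(weighted_sum, in_axes=(None, 0))(w, xs)
print(batched)
```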
4. Compilation: jit
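You can use jax.jit to compile your functions with XLA (Accelerated Linear Algebra), which can fuse sequences of operations, such as the multiplications and the addition in f below, into fewer kernels. The timings below depend on your hardware.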
def f(x):
    return x * x + x * 2.0
x = jnp.ones((5000, 5000))
%timeit f(x)
63 ms ± 1.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
jit_f = jit(f)
%timeit jit_f(x)
22.8 ms ± 1.91 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
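The %timeit calls above do not wait for the results. Because JAX dispatches computations asynchronously, the JAX documentation recommends calling block_until_ready() when benchmarking, and excluding the first call of a jitted function, which includes tracing and compilation. A minimal sketch with the standard timeit module (the array size and the number of repetitions are arbitrary choices):

```python
import timeit
from jax import jit
from jax import numpy as jnp

def f(x):
    return x * x + x * 2.0

jit_f = jit(f)
x = jnp.ones((1000, 1000))

# Warm-up call: triggers tracing and XLA compilation, excluded from timing
jit_f(x).block_until_ready()

# Block until each result is computed, so we time the work, not only the dispatch
n = 10
t_eager = timeit.timeit(lambda: f(x).block_until_ready(), number=n) / n
t_jit = timeit.timeit(lambda: jit_f(x).block_until_ready(), number=n) / n
print(f"eager: {t_eager * 1e3:.2f} ms, jit: {t_jit * 1e3:.2f} ms")
```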
Putting perception into action: Inverse optimal control for continuous psychophysics
Modeling a tracking task with LQG control
The LQG control problem is defined by a linear-Gaussian stochastic dynamical system $$ \mathbf x_{t+1} = A \mathbf x_t + B \mathbf u_t + V \mathbf \epsilon_t, \; \mathbf\epsilon_t \sim \mathcal{N}(0, I), $$
a linear-Gaussian observation model $$ \mathbf y_t = C \mathbf x_t + W \mathbf \eta_t, \; \mathbf\eta_t \sim \mathcal{N}(0, I), $$
and a quadratic cost function
$$ J(\mathbf u_{1:T}) = \sum_{t=1}^T \left( \mathbf x_t^T Q \mathbf x_t + \mathbf u_t^T R \mathbf u_t \right). $$
We assume that the actor solves the linear-quadratic Gaussian problem, i.e., computes the Kalman gain $K$ and the LQR control law $L$. By the separation principle, combining the Kalman filter with the LQR controller is the optimal solution, minimizing the expected cost $\mathbb E[J]$.
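To make the LQR control law $L$ more concrete, here is a minimal sketch of the textbook finite-horizon backward Riccati recursion. It assumes a terminal cost $\mathbf x_T^T Q \mathbf x_T$ and uses the tracking-task matrices with the parameter values of this notebook; lqr_gains is a hypothetical helper and not necessarily how the lqg package computes $L$ internally. The Kalman gain $K$ follows from an analogous forward recursion for the covariance of the state estimate, which is omitted here.

```python
from jax import numpy as jnp

def lqr_gains(A, B, Q, R, T):
    """Finite-horizon LQR gains via the backward Riccati recursion.

    Assumes the cost sum_{t<T} (x_t^T Q x_t + u_t^T R u_t) + x_T^T Q x_T and
    returns L with shape (T, m, n) such that u_t = -L[t] @ x_t is optimal.
    """
    P = Q  # cost-to-go matrix at the final time step
    gains = []
    for _ in range(T):
        # L_t = (R + B^T P_{t+1} B)^{-1} B^T P_{t+1} A
        L_t = jnp.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)
        # P_t = Q + A^T P_{t+1} (A - B L_t)
        P = Q + A.T @ P @ (A - B @ L_t)
        gains.append(L_t)
    return jnp.stack(gains[::-1])  # reorder from t = 0 to t = T - 1

# Tracking-task matrices with the parameter values used in this notebook
dt = 1.0 / 60.0
A = jnp.eye(2)
B = jnp.array([[0.0], [dt]])
Q = jnp.array([[1.0, -1.0], [-1.0, 1.0]])
R = jnp.eye(1) * 0.05

L = lqr_gains(A, B, Q, R, T=500)
print(L.shape)  # (500, 1, 2)
print(L[0])
```

Because Q only penalizes the difference between target and cursor, the gains on the two positions come out equal and opposite, so the control signal $\mathbf u_t = -L_t \mathbf x_t$ depends only on the tracking error.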
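We start by defining the matrices $A, B, C, V, W, Q, R$ as jax.numpy arrays, according to our simple model of the continuous psychophysics tracking task. The state $\mathbf x_t$ contains the positions of the target and the cursor, and the control signal $\mathbf u_t$ sets the cursor velocity:
$$ \begin{gathered} A = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}, \; B = \begin{bmatrix} 0 \\ dt \end{bmatrix}, \; C = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}, \\ V = \begin{bmatrix} \sigma_\text{rw} & 0 \\ 0 & \sigma_\text{act} \end{bmatrix}, \; W = \begin{bmatrix} \sigma_\text{target} & 0 \\ 0 & \sigma_\text{cursor} \end{bmatrix}, \\ Q = \begin{bmatrix} 1 & -1 \\ -1 & 1 \end{bmatrix}, \; R = \begin{bmatrix} c \end{bmatrix}. \end{gathered} $$
In the code below, $\sigma_\text{rw} = 1$ (the standard deviation of the target's random walk), $\sigma_\text{act}$ = action_variability, $\sigma_\text{target}$ = sigma_target, $\sigma_\text{cursor}$ = sigma_cursor, and $c$ = action_cost.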
action_variability = 0.5
sigma_target = 6.
sigma_cursor = 1.
action_cost = .05
dt = 1. / 60.
# dynamical system
A = jnp.eye(2)
B = jnp.array([[0.], 
              [dt]])
# noise
V = jnp.diag(jnp.array([1., action_variability]))
# observation model
C = jnp.eye(2)
W = jnp.diag(jnp.array([sigma_target, sigma_cursor]))
# cost function
Q = jnp.array([[1., -1.],
              [-1., 1]])
R = jnp.eye(1) * action_cost
T = 500
model = LQG(A, B, C, V, W, Q, R, T=T)
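With this choice of Q, the state cost is the squared tracking error: $\mathbf x_t^T Q \mathbf x_t = x_{\text{target},t}^2 - 2 x_{\text{target},t} x_{\text{cursor},t} + x_{\text{cursor},t}^2 = (x_{\text{target},t} - x_{\text{cursor},t})^2$. The sketch below checks this numerically and evaluates the cost J for a hypothetical trajectory of random states and controls (not simulated tracking data); quadratic_cost is a hypothetical helper.

```python
from jax import numpy as jnp
from jax import random

def quadratic_cost(x, u, Q, R):
    """J = sum_t x_t^T Q x_t + u_t^T R u_t for states x (T, n) and controls u (T, m)."""
    state_term = jnp.einsum("ti,ij,tj->", x, Q, x)
    control_term = jnp.einsum("ti,ij,tj->", u, R, u)
    return state_term + control_term

Q = jnp.array([[1.0, -1.0], [-1.0, 1.0]])
R = jnp.eye(1) * 0.05

# Hypothetical trajectory: random (target, cursor) states and random controls
key_x, key_u = random.split(random.PRNGKey(0))
x = random.normal(key_x, (100, 2))
u = random.normal(key_u, (100, 1))

# The state cost equals the summed squared tracking error (target - cursor)^2
tracking_error = jnp.sum((x[:, 0] - x[:, 1]) ** 2)
state_cost = jnp.einsum("ti,ij,tj->", x, Q, x)
print(jnp.allclose(state_cost, tracking_error, rtol=1e-4))  # True
J = quadratic_cost(x, u, Q, R)
```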
The System class, which LQG extends, can be pretty-printed: calling display(model) renders the matrices of your model as $\LaTeX$ formulas. The same happens when the model is the last expression in a cell. Standard (i.e., not so pretty) printing is still available through print(model).
model  # equivalent to `display(model)`
Let's simulate some tracking data by applying the Kalman filter and the linear-quadratic regulator. This is implemented in the method simulate(rng_key, n, T), where n is the number of trials and T is the number of time steps. Since JAX does not have a global random number generator state, we need to pass a PRNGKey object explicitly.
x = model.simulate(random.PRNGKey(0), n=100)
plt.plot(jnp.arange(T) * dt, x[0, :, 0], label="target")
plt.plot(jnp.arange(T) * dt, x[0, :, 1], label="cursor")
plt.legend()
plt.xlabel("Time [s]")
plt.ylabel("Position [arcmin]")
plt.show()
x.shape
(100, 500, 2)
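Because JAX has no global random state, reusing a key reproduces the same numbers. The usual pattern, sketched below with illustrative variable names, is to split a key whenever fresh randomness is needed, for example to get one key per independent simulation run.

```python
from jax import random

key = random.PRNGKey(0)

# Reusing a key reproduces exactly the same samples
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# Split to obtain a fresh, independent key; do not reuse a key after splitting it
key, subkey = random.split(key)
c = random.normal(subkey, (3,))

# One key per independent run, e.g. for several simulations with different seeds
run_keys = random.split(key, 4)
```

Each entry of run_keys could then be passed as the rng_key argument of model.simulate to obtain independent batches of trials.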
Cross-correlograms
We can also look at the correlation between the velocities of the target and the cursor at different time lags. This analysis is known as a cross-correlogram (Mulligan et al., 2013): for each lag, it computes the cross-correlation between the velocities of target and response, here averaged across trials.
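vels = jnp.diff(x, axis=1)  # frame-to-frame position changes (velocities)
lags, correls = xcorr(vels[..., 1], vels[..., 0], maxlags=120)  # lags in frames
plt.plot(lags * dt, correls.mean(axis=0))  # lags in seconds, averaged over trials
plt.xlabel("Lag [s]")
plt.ylabel("Cross-correlation")
plt.show()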
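A minimal from-scratch sketch of a cross-correlogram for a single trial, written with NumPy. This is not the implementation of lqg.xcorr: the normalization (z-scoring each signal) and the lag sign convention (a positive lag means the first signal lags behind the second) are assumptions, and xcorr_sketch is a hypothetical name. On synthetic data where the response is a delayed, noisy copy of the target velocity, the peak appears at the delay.

```python
import numpy as np

def xcorr_sketch(a, b, maxlags):
    """Cross-correlation of 1-D signals a and b at lags -maxlags, ..., maxlags.

    Both signals are z-scored first. At a positive lag k, a[t + k] is paired
    with b[t], so a peak at k > 0 means that a lags behind b by k samples.
    """
    a = (a - a.mean()) / a.std()
    b = (b - b.mean()) / b.std()
    n = len(a)
    lags = np.arange(-maxlags, maxlags + 1)
    correls = np.empty(len(lags))
    for i, k in enumerate(lags):
        if k >= 0:
            correls[i] = np.mean(a[k:] * b[:n - k])
        else:
            correls[i] = np.mean(a[:n + k] * b[-k:])
    return lags, correls

# Synthetic check: response velocity = target velocity delayed by 10 frames, plus noise
rng = np.random.default_rng(0)
target_vel = rng.normal(size=1000)
response_vel = np.roll(target_vel, 10) + 0.3 * rng.normal(size=1000)
lags, correls = xcorr_sketch(response_vel, target_vel, maxlags=60)
print(lags[np.argmax(correls)])  # 10
```

Applying such a function to every trial and averaging the results corresponds to correls.mean(axis=0) above.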
Influence of model parameters
To look at the influence of the different model parameters, we define a class that inherits from the LQG base class and constructs the matrices from the four parameters.
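class BoundedActor(LQG):
    """LQG model of the tracking task, parameterized by four behavioral parameters."""

    def __init__(self, sigma_target, action_variability, action_cost, sigma_cursor):
        dt = 1. / 60.
        # dynamical system: target follows a random walk, control moves the cursor
        A = jnp.eye(2)
        B = jnp.array([[0.],
                       [dt]])
        # dynamics noise: target random walk and action variability
        V = jnp.diag(jnp.array([1., action_variability]))
        # observation model (F is the observation matrix, C in the equations above)
        F = jnp.eye(2)
        W = jnp.diag(jnp.array([sigma_target, sigma_cursor]))
        # cost function: squared target-cursor distance plus action cost
        Q = jnp.array([[1., -1.],
                       [-1., 1.]])
        R = jnp.eye(1) * action_cost
        # T is the global number of time steps defined above
        super().__init__(A=A, B=B, F=F, V=V, W=W, Q=Q, R=R, T=T)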
We can now simulate data from the model given the four parameters. To do this efficiently, we jit-compile the simulation function.
Some observations from varying the parameters:
- An increase in action cost leads to an increased lag and a decreased maximum correlation.
- An increase in perceptual uncertainty about the target also decreases the correlation and increases the lag, but it changes the shape of the curves differently than the action cost does.
- Action variability does not change the lag, but decreases the correlation overall.
- Perceptual uncertainty about the cursor does not change the shape of the cross-correlograms (CCGs) at all, but does increase the mean squared error between target and response.
@jit  # jit-compile the simulation to speed up the data generation
def simulate_trajectories(sigma_target, action_cost, action_variability, sigma_cursor):
    model = BoundedActor(
        sigma_target=sigma_target,
        action_variability=action_variability,
        action_cost=action_cost,
        sigma_cursor=sigma_cursor,
    )
    x = model.simulate(random.PRNGKey(0), n=100)
    return x
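The jit-compiled simulate_trajectories is called with many different parameter values in the loop that follows. This is efficient because jax.jit caches the compiled function per input shape and dtype, not per value, so scalar parameters of the same type reuse a single compilation. The sketch below makes this visible by counting how often a Python side effect runs, which only happens while JAX traces the function (scaled_sum and trace_count are hypothetical names).

```python
from jax import jit
from jax import numpy as jnp

trace_count = 0

@jit
def scaled_sum(x, scale):
    global trace_count
    trace_count += 1  # Python side effect: executed only during tracing
    return jnp.sum(x) * scale

x = jnp.ones(100)
for scale in [0.2, 1.0, 5.0]:
    scaled_sum(x, scale)
print(trace_count)  # 1: new scalar values reuse the compiled function

scaled_sum(jnp.ones(200), 1.0)
print(trace_count)  # 2: a new input shape triggers retracing
```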
# Simulate data and store it in a DataFrame
data_trajectory = []
data_ccg = []
sigma_target_list = [1.0, 10.0, 100.0]
action_cost_list = [0.2, 1.0, 5.0]
action_variability_list = [0.25, 0.5, 1.0]
sigma_cursor_list = [1.0, 10.0, 100.0]
time = jnp.arange(T) * dt  # time axis in seconds for the T = 500 simulated steps
for sigma_target in sigma_target_list:
    for action_cost in action_cost_list:
        for action_variability in action_variability_list:
            for sigma_cursor in sigma_cursor_list:
                x = simulate_trajectories(
                    sigma_target, action_cost, action_variability, sigma_cursor
                )
                for i, step in enumerate(x[0]):
                    data_trajectory.append(
                        [
                            sigma_target,
                            action_cost,
                            action_variability,
                            sigma_cursor,
                            "target",
                            time[i].item(),
                            step[0].item(),
                        ]
                    )
                    data_trajectory.append(
                        [
                            sigma_target,
                            action_cost,
                            action_variability,
                            sigma_cursor,
                            "cursor",
                            time[i].item(),
                            step[1].item(),
                        ]
                    )
                vels = jnp.diff(x, axis=1)
                lags, correls = xcorr(vels[..., 1], vels[..., 0], maxlags=120)
                lags = lags * dt  # convert lags from frames to seconds (dt = 1/60 s)
                correls = correls.mean(axis=0)
                for lag, correl in zip(lags, correls):
                    data_ccg.append(
                        [
                            sigma_target,
                            action_cost,
                            action_variability,
                            sigma_cursor,
                            lag.item(),
                            correl.item(),
                        ]
                    )
df_trajectory = pd.DataFrame(
    data_trajectory,
    columns=[
        "sigma_target",
        "action_cost",
        "action_variability",
        "sigma_cursor",
        "value",
        "time",
        "position",
    ],
)
df_ccg = pd.DataFrame(
    data_ccg,
    columns=["sigma_target", "action_cost", "action_variability", "sigma_cursor", "lag", "correl"],
)
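The loop above appends one row per time step and signal, which is slow for long trials. A vectorized alternative, sketched below under the assumption that each trial is first converted to a NumPy array of shape (T, 2), builds the same long-format columns with pandas' melt. The trial here is a hypothetical random walk, not simulated LQG data.

```python
import numpy as np
import pandas as pd

dt = 1.0 / 60.0  # frame duration in seconds, as in the notebook

# Hypothetical stand-in for one simulated trial: 500 steps x (target, cursor)
trial = np.cumsum(np.random.default_rng(0).normal(size=(500, 2)), axis=0)
params = {
    "sigma_target": 10.0,
    "action_cost": 1.0,
    "action_variability": 0.5,
    "sigma_cursor": 10.0,
}

# Wide format: one row per time step, one column per signal
wide = pd.DataFrame(trial, columns=["target", "cursor"])
wide["time"] = np.arange(len(trial)) * dt

# Long format: one row per (time step, signal), with df_trajectory's column order
df_long = wide.melt(id_vars="time", var_name="value", value_name="position")
for name, val in params.items():
    df_long[name] = val
df_long = df_long[[*params, "value", "time", "position"]]
```

Inside the parameter loop, one would build such a frame for each parameter combination and combine them with pd.concat.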
# Plot the data as interactive Altair chart
radio1 = alt.binding_radio(
    options=sigma_target_list,
    name="Perceptual uncertainty (target): ",
)
selection1 = alt.selection_point(
    value=10.0,
    fields=["sigma_target"],
    bind=radio1,
)
radio2 = alt.binding_radio(options=action_cost_list, name="Behavioral costs: ")
selection2 = alt.selection_point(
    value=1.0,
    fields=["action_cost"],
    bind=radio2,
)
radio3 = alt.binding_radio(options=action_variability_list, name="Action variability: ")
selection3 = alt.selection_point(
    value=0.5,
    fields=["action_variability"],
    bind=radio3,
)
radio4 = alt.binding_radio(options=sigma_cursor_list, name="Perceptual uncertainty (cursor): ")
selection4 = alt.selection_point(
    value=10.0,
    fields=["sigma_cursor"],
    bind=radio4,
)
lines_trajectory = (
    alt.Chart(df_trajectory)
    .mark_line()
    .encode(
        x="time:Q",
        y=alt.Y("position").scale(domain=(-30, 30)),
        color=alt.Color("value").sort("descending"),
        tooltip=["time", "position"],
    )
    .add_params(selection4, selection3, selection2, selection1)
    .transform_filter(selection1 & selection2 & selection3 & selection4)
    .properties(title="Trajectory")
)
lines_ccg = (
    alt.Chart(df_ccg)
    .mark_line()
    .encode(
        x="lag:Q",
        y=alt.Y("correl:Q").scale(domain=(-0.02, 0.1)),
        color=alt.value("#2CA02C"),
        tooltip=["lag", "correl"],
    )
    .add_params(selection4, selection3, selection2, selection1)
    .transform_filter(selection1 & selection2 & selection3 & selection4)
    .properties(title="Cross-correlogram")
)
chart = lines_trajectory | lines_ccg
display(chart)