Interpolation and Extrapolation

In dynamic programming, the value function $V(x)$ is computed on a discrete grid but must be evaluated at arbitrary points during the solution process. State transitions $x' = g(x, a)$ can produce next-period states that fall between grid points (requiring interpolation) or outside the grid (requiring extrapolation).

pylcm does not use RegularGridInterpolator or similar pre-built tools because it needs to support grids whose points are only known at runtime — for example, shock grids whose locations depend on distributional parameters supplied via params.

This notebook explains how pylcm handles interpolation and extrapolation, using a CRRA utility function $u(w) = \frac{w^{1-\gamma}}{1-\gamma}$ on a coarse wealth grid as a running example.

pylcm’s two-step design

pylcm evaluates functions on arbitrary points in two steps:

  1. Coordinate finder: Convert a physical value (e.g., wealth = 150) to a generalized coordinate, that is, a fractional index into the grid. Values inside the grid produce coordinates in $[0, n-1]$; values outside produce coordinates outside this range.

  2. map_coordinates: Take the generalized coordinates and the pre-computed array of function values. Perform linear interpolation (for coordinates inside $[0, n-1]$) or linear extrapolation (for coordinates outside).

Each grid type provides its own coordinate finder, optimized for its spacing pattern. The map_coordinates function is the same for all grid types.
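The two steps can be sketched in plain Python for a one-dimensional, uniformly spaced grid. This is only an illustration of the idea, not pylcm's JAX implementation: the function names `find_coordinate` and `interpolate_at` and the tabulated function are made up for this example.

```python
import math


def find_coordinate(value, start, stop, n_points):
    """Step 1: map a physical value to a fractional index on a uniform grid."""
    step = (stop - start) / (n_points - 1)
    return (value - start) / step


def interpolate_at(values, coordinate):
    """Step 2: linear interpolation or extrapolation at a fractional index."""
    # Clip the lower index so that lower + 1 is always a valid position.
    lower = min(max(math.floor(coordinate), 0), len(values) - 2)
    # The weight is not clipped; outside [0, 1] it produces extrapolation.
    upper_weight = coordinate - lower
    return (1 - upper_weight) * values[lower] + upper_weight * values[lower + 1]


# Hypothetical example: f(x) = 2x + 1 tabulated on the grid 0, 5, 10, 15, 20.
grid_values = [1.0, 11.0, 21.0, 31.0, 41.0]

inside = interpolate_at(grid_values, find_coordinate(7.5, 0.0, 20.0, 5))
outside = interpolate_at(grid_values, find_coordinate(25.0, 0.0, 20.0, 5))
print(inside, outside)  # f is linear, so both equal 2x + 1: 16.0 and 51.0
```

Only `find_coordinate` knows about the grid layout; `interpolate_at` works on fractional indices alone, which is why the second step can be shared across grid types.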

LinSpacedGrid

For a linearly spaced grid with start, stop, and n_points, the coordinate finder uses the O(1) formula:

$$\text{coordinate} = \frac{\text{value} - \text{start}}{\text{step\_length}}, \quad \text{step\_length} = \frac{\text{stop} - \text{start}}{n_\text{points} - 1}$$

Values inside the grid produce coordinates in $[0, n_\text{points}-1]$. Values below start produce negative coordinates; values above stop produce coordinates above $n_\text{points}-1$. Both cases lead to linear extrapolation in map_coordinates.
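As a worked instance of the formula, take the grid used in the example that follows (start = 1, stop = 400, n_points = 10) and three of its query points:

```latex
\begin{aligned}
\text{step\_length} &= \frac{400 - 1}{10 - 1} = \frac{399}{9} \approx 44.333 \\
\text{coordinate}(25) &= \frac{25 - 1}{44.333} \approx 0.5414 \\
\text{coordinate}(450) &= \frac{450 - 1}{44.333} \approx 10.1278 \\
\text{coordinate}(0.5) &= \frac{0.5 - 1}{44.333} \approx -0.0113
\end{aligned}
```

The first value lies inside $[0, 9]$ and is interpolated; the other two fall outside that range and are extrapolated.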

import jax.numpy as jnp
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from lcm import LinSpacedGrid, LogSpacedGrid, Piece, PiecewiseLinSpacedGrid
from lcm.ndimage import map_coordinates

blue, orange, green = "#4C78A8", "#F58518", "#54A24B"


def crra(wealth, gamma=1.5):
    """CRRA utility: u(w) = w^(1-gamma) / (1-gamma)."""
    return wealth ** (1 - gamma) / (1 - gamma)


# Coarse linearly spaced wealth grid (10 points)
lin_grid = LinSpacedGrid(start=1, stop=400, n_points=10)
grid_points = lin_grid.to_jax()

# CRRA values on the grid
V = crra(grid_points)

print("Grid points:", grid_points)
print("V on grid:  ", V)
Grid points: [  1.        45.333336  89.66667  134.       178.33334  222.66667
 267.       311.33334  355.6667   400.      ]
V on grid:   [-2.         -0.29704425 -0.21121    -0.17277369 -0.14976616 -0.13403012
 -0.12239801 -0.11334886 -0.10604945 -0.1       ]
# Query points: some inside the grid, some outside
query_inside = jnp.array([25.0, 100.0, 250.0])
query_outside = jnp.array([0.5, 450.0])

coords_inside = lin_grid.get_coordinate(query_inside)
coords_outside = lin_grid.get_coordinate(query_outside)

print("Inside grid:")
for w, c in zip(query_inside, coords_inside, strict=True):
    print(f"  wealth = {w:.1f}  →  coordinate = {c:.4f}  (in [0, 9])")

print("\nOutside grid:")
for w, c in zip(query_outside, coords_outside, strict=True):
    in_range = "< 0" if c < 0 else "> 9"
    print(f"  wealth = {w:.1f}  →  coordinate = {c:.4f}  ({in_range})")
Inside grid:
  wealth = 25.0  →  coordinate = 0.5414  (in [0, 9])
  wealth = 100.0  →  coordinate = 2.2331  (in [0, 9])
  wealth = 250.0  →  coordinate = 5.6165  (in [0, 9])

Outside grid:
  wealth = 0.5  →  coordinate = -0.0113  (< 0)
  wealth = 450.0  →  coordinate = 10.1278  (> 9)
# Interpolate and extrapolate using map_coordinates
query_all = jnp.concatenate([query_inside, query_outside])
coords_all = lin_grid.get_coordinate(query_all)

V_approx = map_coordinates(input=V, coordinates=[coords_all])
V_true = crra(query_all)

print(
    f"{'Wealth':>8}  {'Coordinate':>11}  "
    f"{'Approximated':>12}  {'True':>12}  {'Error':>10}"
)
print("-" * 60)
for w, c, va, vt in zip(query_all, coords_all, V_approx, V_true, strict=True):
    print(f"{w:8.1f}  {c:11.4f}  {va:12.6f}  {vt:12.6f}  {va - vt:10.6f}")
  Wealth   Coordinate  Approximated          True       Error
------------------------------------------------------------
    25.0       0.5414     -1.078099     -0.400000   -0.678099
   100.0       2.2331     -0.202251     -0.200000   -0.002251
   250.0       5.6165     -0.126858     -0.126491   -0.000367
     0.5      -0.0113     -2.019206     -2.828427    0.809221
   450.0      10.1278     -0.093177     -0.094281    0.001104
# Dense points for smooth curves (extending beyond the grid for extrapolation)
x_dense = jnp.linspace(0.5, 450, 500)
coords_dense = lin_grid.get_coordinate(x_dense)
V_interp_dense = map_coordinates(input=V, coordinates=[coords_dense])
V_true_dense = crra(x_dense)

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=x_dense,
        y=V_true_dense,
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True CRRA",
    )
)
fig.add_trace(
    go.Scatter(
        x=x_dense,
        y=V_interp_dense,
        mode="lines",
        line={"color": orange, "width": 2},
        name="Interpolation / Extrapolation",
    )
)
fig.add_trace(
    go.Scatter(
        x=grid_points,
        y=V,
        mode="markers",
        marker={"color": blue, "size": 8},
        name="Grid points",
    )
)
fig.add_vline(x=1, line={"color": "gray", "dash": "dot", "width": 1})
fig.add_vline(x=400, line={"color": "gray", "dash": "dot", "width": 1})
fig.update_layout(
    title="LinSpacedGrid: Interpolation and Extrapolation",
    xaxis_title="Wealth",
    yaxis_title="u(w)",
)
fig.show()

LogSpacedGrid

For a logarithmically spaced grid, points are denser at lower values — ideal for functions with high curvature near zero (like CRRA utility).

The coordinate finder get_logspace_coordinate works by:

  1. Transforming value, start, stop to log space

  2. Finding the bounding grid points (via their ranks in log space)

  3. Linearly interpolating between the ranks in the original (physical) space

This gives a coordinate that accounts for the non-uniform spacing. With the same number of grid points, a log-spaced grid captures the curvature of CRRA much better than a linearly spaced one.
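The three steps above can be sketched in plain Python as follows. This follows the description in this section rather than pylcm's source; the function name `logspace_coordinate` is an assumption for this example, and the sketch requires `value > 0`.

```python
import math


def logspace_coordinate(value, start, stop, n_points):
    """Fractional index of `value` on a log-spaced grid (illustrative sketch)."""
    # Step 1: move to log space, where the grid is uniformly spaced.
    log_start, log_stop = math.log(start), math.log(stop)
    log_step = (log_stop - log_start) / (n_points - 1)
    coord_in_log_space = (math.log(value) - log_start) / log_step

    # Step 2: rank of the lower bounding grid point, and the physical
    # locations of the two grid points that bracket `value`.
    lower_rank = math.floor(coord_in_log_space)
    lower_point = math.exp(log_start + log_step * lower_rank)
    upper_point = math.exp(log_start + log_step * (lower_rank + 1))

    # Step 3: interpolate linearly between the two ranks in physical space.
    return lower_rank + (value - lower_point) / (upper_point - lower_point)


# Grid from 1 to 400 with 10 points, as in the example in this section.
points = [math.exp(k * math.log(400.0) / 9) for k in range(10)]
midpoint = 0.5 * (points[3] + points[4])

coord_mid = logspace_coordinate(midpoint, 1.0, 400.0, 10)
coord_grid = logspace_coordinate(points[6], 1.0, 400.0, 10)
print(coord_mid, coord_grid)
```

Because the last step works in physical space, the physical midpoint between grid points 3 and 4 maps to a coordinate of about 3.5, which is what a linear interpolation of the tabulated values needs.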

log_grid = LogSpacedGrid(start=1, stop=400, n_points=10)
log_points = log_grid.to_jax()
V_log = crra(log_points)

print("LinSpaced grid:", jnp.round(grid_points, 1))
print("LogSpaced grid:", jnp.round(log_points, 1))
LinSpaced grid: [  1.        45.3       89.700005 134.       178.3      222.7
 267.       311.30002  355.7      400.      ]
LogSpaced grid: [  1.    1.9   3.8   7.4  14.3  27.9  54.3 105.6 205.6 400. ]
x_eval = jnp.linspace(1, 400, 500)
V_true_eval = crra(x_eval)

# LinSpaced interpolation
lin_coords = lin_grid.get_coordinate(x_eval)
V_lin_interp = map_coordinates(input=V, coordinates=[lin_coords])

# LogSpaced interpolation
log_coords = log_grid.get_coordinate(x_eval)
V_log_interp = map_coordinates(input=V_log, coordinates=[log_coords])

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=(
        "Interpolated value functions",
        "Interpolation error",
    ),
)

# Left: Interpolated curves
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=V_true_eval,
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True CRRA",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=V_lin_interp,
        mode="lines",
        line={"color": blue, "width": 2},
        name="LinSpaced (10 pts)",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=V_log_interp,
        mode="lines",
        line={"color": orange, "width": 2},
        name="LogSpaced (10 pts)",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=grid_points,
        y=V,
        mode="markers",
        marker={"color": blue, "size": 6},
        name="LinSpaced grid",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=log_points,
        y=V_log,
        mode="markers",
        marker={"color": orange, "size": 6},
        name="LogSpaced grid",
    ),
    row=1,
    col=1,
)

# Right: Absolute errors
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=jnp.abs(V_lin_interp - V_true_eval),
        mode="lines",
        line={"color": blue},
        name="LinSpaced error",
        showlegend=False,
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=jnp.abs(V_log_interp - V_true_eval),
        mode="lines",
        line={"color": orange},
        name="LogSpaced error",
        showlegend=False,
    ),
    row=1,
    col=2,
)

fig.update_xaxes(title_text="Wealth", row=1, col=1)
fig.update_xaxes(title_text="Wealth", row=1, col=2)
fig.update_yaxes(title_text="u(w)", row=1, col=1)
fig.update_yaxes(title_text="|Error|", row=1, col=2)
fig.update_layout(height=450, width=900)
fig.show()

ShockGrid (Normal)

ShockGrids discretize continuous distributions. Their grid points depend on distributional parameters (e.g., mu, sigma for a Normal shock) which may only be known at runtime (when supplied via params). Because the points are determined dynamically, ShockGrids cannot use the O(1) linspace formula — instead, they use get_irreg_coordinate internally.

get_irreg_coordinate handles arbitrary point sequences:

  1. Use jnp.searchsorted to find the bounding grid points (O(log n))

  2. Linearly interpolate between the bounding points to get the fractional coordinate

  3. For values outside the grid, extrapolate using the slope of the nearest segment
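A minimal plain-Python sketch of these three steps is shown below. It uses `bisect_right` from the standard library as a stand-in for `jnp.searchsorted`; the function name `irreg_coordinate` is hypothetical, and the grid is simply seven equally spaced points on $[-2.5, 2.5]$, chosen to mirror the shock grid printed in the example that follows.

```python
from bisect import bisect_right


def irreg_coordinate(value, points):
    """Fractional index of `value` in a sorted list of grid points (sketch)."""
    # Step 1: binary search for the segment; clip so the segment always exists.
    lower = min(max(bisect_right(points, value) - 1, 0), len(points) - 2)
    # Steps 2 and 3: relative position within the segment. Outside the grid
    # the ratio leaves [0, 1], which extrapolates along the nearest segment.
    return lower + (value - points[lower]) / (points[lower + 1] - points[lower])


# Seven equally spaced points on [-2.5, 2.5] (assumed for illustration).
points = [-2.5 + 5.0 * k / 6 for k in range(7)]
queries = [-3.0, -1.0, 0.5, 2.0, 3.0]
coords = [irreg_coordinate(q, points) for q in queries]

for q, c in zip(queries, coords):
    print(f"{q:5.1f} -> {c:6.3f}")
```

Clipping the segment index rather than the coordinate is what makes out-of-grid values yield coordinates below 0 or above n - 1.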

import lcm.shocks.iid

shock = lcm.shocks.iid.Normal(
    mu=0.0, sigma=1.0, n_std=2.5, n_points=7, gauss_hermite=False
)
shock_points = shock.to_jax()

print("Shock grid points:", shock_points)

# CRRA of (base wealth + shock)
base_wealth = 100.0
V_shock = crra(base_wealth + shock_points)

# Query points inside and outside the shock grid
shock_query = jnp.array([-3.0, -1.0, 0.5, 2.0, 3.0])
shock_coords = shock.get_coordinate(shock_query)

n_last = len(shock_points) - 1

print("\nShock query points and coordinates:")
for q, c in zip(shock_query, shock_coords, strict=True):
    in_out = "inside" if 0 <= c <= n_last else "outside"
    print(f"  ε = {q:5.1f}  →  coordinate = {c:6.3f}  ({in_out})")
Shock grid points: [-2.5000000e+00 -1.6666666e+00 -8.3333313e-01  5.9604645e-08
  8.3333349e-01  1.6666669e+00  2.5000000e+00]

Shock query points and coordinates:
  ε =  -3.0  →  coordinate = -0.600  (outside)
  ε =  -1.0  →  coordinate =  1.800  (inside)
  ε =   0.5  →  coordinate =  3.600  (inside)
  ε =   2.0  →  coordinate =  5.400  (inside)
  ε =   3.0  →  coordinate =  6.600  (outside)
eps_dense = jnp.linspace(-3.5, 3.5, 300)
shock_coords_dense = shock.get_coordinate(eps_dense)
V_shock_interp = map_coordinates(input=V_shock, coordinates=[shock_coords_dense])
V_shock_true = crra(base_wealth + eps_dense)

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=eps_dense,
        y=V_shock_true,
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True CRRA(100 + ε)",
    )
)
fig.add_trace(
    go.Scatter(
        x=eps_dense,
        y=V_shock_interp,
        mode="lines",
        line={"color": orange, "width": 2},
        name="Interpolation / Extrapolation",
    )
)
fig.add_trace(
    go.Scatter(
        x=shock_points,
        y=V_shock,
        mode="markers",
        marker={"color": blue, "size": 8},
        name="Grid points",
    )
)
fig.add_vline(
    x=float(shock_points[0]),
    line={"color": "gray", "dash": "dot", "width": 1},
)
fig.add_vline(
    x=float(shock_points[-1]),
    line={"color": "gray", "dash": "dot", "width": 1},
)
fig.update_layout(
    title="Normal ShockGrid: Interpolation and Extrapolation",
    xaxis_title="Shock (ε)",
    yaxis_title="u(100 + ε)",
)
fig.show()

PiecewiseLinSpacedGrid

Some models feature eligibility thresholds — e.g., a means-tested program that applies only below a wealth cutoff. The value function may jump at the threshold (eligible households receive a transfer, ineligible ones do not). To ensure the threshold is a grid point (avoiding interpolation across the discontinuity), PiecewiseLinSpacedGrid lets you place a breakpoint at the threshold.

For a grid with pieces $[a, b)$ and $[b, c]$:

  1. Piece selection: jnp.searchsorted on breakpoints determines which piece a value belongs to

  2. Local coordinate: get_linspace_coordinate within the piece

  3. Global coordinate: offset by the cumulative number of points in preceding pieces

The breakpoint $b$ is the first point of the second piece, guaranteeing it is a grid point. Because map_coordinates interpolates linearly between adjacent grid points, and the two points straddling the breakpoint ($b$ and the last point of the previous piece) are adjacent in the array, the interpolation never crosses the discontinuity; this is guaranteed by the implementation. In the example below, the last point of the first piece lies just below 50 and therefore prints as 50.0 after rounding.
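The piece selection, local coordinate and global offset can be sketched in plain Python. This is a simplified illustration, not pylcm's code: `bisect_right` stands in for `jnp.searchsorted`, the names `PIECES`, `BREAKPOINTS` and `piecewise_coordinate` are made up, and the half-open piece $[1, 50)$ is treated as if its last point sat exactly at 50.

```python
from bisect import bisect_right

# Each piece is (start, stop, n_points). Assumption: the open upper bound of
# the first piece is approximated by the breakpoint itself.
PIECES = [(1.0, 50.0, 5), (50.0, 400.0, 7)]
BREAKPOINTS = [50.0]


def piecewise_coordinate(value):
    """Global fractional index of `value` on the piecewise grid (sketch)."""
    # 1. Piece selection: values at or above a breakpoint go to the next piece.
    piece = bisect_right(BREAKPOINTS, value)
    start, stop, n_points = PIECES[piece]
    # 2. Local coordinate within the piece (uniform-grid formula).
    local = (value - start) / ((stop - start) / (n_points - 1))
    # 3. Global coordinate: shift by the number of points in preceding pieces.
    offset = sum(n for _, _, n in PIECES[:piece])
    return offset + local


just_below = piecewise_coordinate(49.999)
at_breakpoint = piecewise_coordinate(50.0)
print(just_below, at_breakpoint, piecewise_coordinate(225.0))
```

In this sketch a value just below the breakpoint gets a coordinate just below 4, while the breakpoint itself gets exactly 5, so no in-grid value lands strictly between indices 4 and 5, the two array entries on either side of the jump.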

pw_grid = PiecewiseLinSpacedGrid(
    pieces=(
        Piece(interval="[1, 50)", n_points=5),
        Piece(interval="[50, 400]", n_points=7),
    )
)
pw_points = pw_grid.to_jax()

# Value function with a jump: means-tested transfer of 0.5 for wealth below threshold
transfer = 0.5
threshold = 50.0

V_pw = crra(pw_points) + transfer * jnp.where(pw_points < threshold, 1.0, 0.0)

print(f"Total grid points: {pw_grid.n_points}")
print(f"Grid: {jnp.round(pw_points, 1)}")
print(f"\nBreakpoint at wealth = 50 is at index 5: grid[5] = {pw_points[5]:.1f}")
print(f"\nV just below threshold (grid[4]): {V_pw[4]:.4f}")
print(f"V at threshold        (grid[5]): {V_pw[5]:.4f}")
Total grid points: 12
Grid: [  1.       13.2      25.5      37.7      50.       50.      108.3
 166.7     225.      283.30002 341.7     400.     ]

Breakpoint at wealth = 50 is at index 5: grid[5] = 50.0

V just below threshold (grid[4]): 0.2172
V at threshold        (grid[5]): -0.2828
x_dense_pw = jnp.linspace(1, 400, 500)
pw_coords_dense = pw_grid.get_coordinate(x_dense_pw)
V_pw_interp = map_coordinates(input=V_pw, coordinates=[pw_coords_dense])

# Split traces at the threshold to avoid connecting lines across the jump
mask_below = x_dense_pw < threshold
x_below = x_dense_pw[mask_below]
x_above = x_dense_pw[~mask_below]

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=x_below,
        y=crra(x_below) + transfer,
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True function",
    )
)
fig.add_trace(
    go.Scatter(
        x=x_above,
        y=crra(x_above),
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True function",
        showlegend=False,
    )
)
fig.add_trace(
    go.Scatter(
        x=x_below,
        y=V_pw_interp[mask_below],
        mode="lines",
        line={"color": orange, "width": 2},
        name="Piecewise interpolation",
    )
)
fig.add_trace(
    go.Scatter(
        x=x_above,
        y=V_pw_interp[~mask_below],
        mode="lines",
        line={"color": orange, "width": 2},
        name="Piecewise interpolation",
        showlegend=False,
    )
)
fig.add_trace(
    go.Scatter(
        x=pw_points,
        y=V_pw,
        mode="markers",
        marker={"color": blue, "size": 8},
        name="Grid points",
    )
)
fig.update_layout(
    title="PiecewiseLinSpacedGrid: capturing a discontinuity",
    xaxis_title="Wealth",
    yaxis_title="V(w)",
)
fig.show()

map_coordinates internals

pylcm’s map_coordinates (src/lcm/ndimage.py) is a modified version of JAX’s jax.scipy.ndimage.map_coordinates. The key difference is in how it handles values outside the grid.

The function _compute_indices_and_weights implements this behavior:

lower_index = jnp.clip(jnp.floor(coordinate), 0, input_size - 2)
upper_weight = coordinate - lower_index
lower_weight = 1 - upper_weight
  • The lower index is clipped to $[0, n-2]$, ensuring valid array access

  • But the weight is not clipped — it is simply coordinate - lower_index

For coordinates inside $[0, n-1]$, the weight falls in $[0, 1]$, giving standard linear interpolation. For coordinates outside this range:

| Coordinate | Lower index | Weight | Result |
| --- | --- | --- | --- |
| $c \in [0, n-1]$ | $\lfloor c \rfloor$ | $c - \lfloor c \rfloor \in [0, 1]$ | Linear interpolation |
| $c < 0$ | $0$ | $c < 0$ | Linear extrapolation using first segment |
| $c > n-1$ | $n-2$ | $c - (n-2) > 1$ | Linear extrapolation using last segment |

JAX’s version, by contrast, clips or fills values outside the grid (depending on mode), which does not give linear extrapolation.
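To make the contrast concrete, the plain-Python sketch below compares the unclipped-weight rule with a variant that first clamps the coordinate to the grid, which is similar in spirit to a "nearest" boundary mode. Both functions are illustrative stand-ins written for this example, not the pylcm or JAX implementations.

```python
import math


def interp_extrapolating(values, coordinate):
    """Clip the index but not the weight: extrapolates linearly off-grid."""
    lower = min(max(math.floor(coordinate), 0), len(values) - 2)
    upper_weight = coordinate - lower
    return (1 - upper_weight) * values[lower] + upper_weight * values[lower + 1]


def interp_clamped(values, coordinate):
    """Clamp the coordinate to [0, n - 1] first: off-grid queries return the edge value."""
    clamped = min(max(coordinate, 0.0), len(values) - 1.0)
    return interp_extrapolating(values, clamped)


values = [10.0, 20.0, 30.0, 40.0, 50.0]
for coord in (-0.5, 1.7, 4.5):
    print(coord, interp_extrapolating(values, coord), interp_clamped(values, coord))
```

Inside the grid the two rules agree (about 27 at coordinate 1.7). Outside, the extrapolating rule continues the boundary segment (5 and 55), whereas the clamped rule stays at the edge values (10 and 50).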

from lcm.ndimage import _compute_indices_and_weights

# A simple grid with 5 points
n = 5
V_demo = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])

# Three cases: below grid, inside grid, above grid
cases = [
    (-0.5, "below grid"),
    (1.7, "inside grid"),
    (4.5, "above grid"),
]

for coord, label in cases:
    coord_arr = jnp.array(coord)
    [(lo_idx, lo_wt), (hi_idx, hi_wt)] = _compute_indices_and_weights(coord_arr, n)
    result = float(lo_wt * V_demo[lo_idx] + hi_wt * V_demo[hi_idx])
    print(f"coordinate = {coord:5.1f} ({label})")
    print(f"  lower_index = {int(lo_idx)}, upper_index = {int(hi_idx)}")
    print(f"  lower_weight = {float(lo_wt):.2f}, upper_weight = {float(hi_wt):.2f}")
    print(
        f"  result = {float(lo_wt):.2f} \u00d7 V[{int(lo_idx)}]"
        f" + {float(hi_wt):.2f} \u00d7 V[{int(hi_idx)}]"
    )
    print(
        f"         = {float(lo_wt):.2f} \u00d7 {float(V_demo[lo_idx]):.0f}"
        f" + {float(hi_wt):.2f} \u00d7 {float(V_demo[hi_idx]):.0f}"
    )
    print(f"         = {result:.1f}\n")
coordinate =  -0.5 (below grid)
  lower_index = 0, upper_index = 1
  lower_weight = 1.50, upper_weight = -0.50
  result = 1.50 × V[0] + -0.50 × V[1]
         = 1.50 × 10 + -0.50 × 20
         = 5.0

coordinate =   1.7 (inside grid)
  lower_index = 1, upper_index = 2
  lower_weight = 0.30, upper_weight = 0.70
  result = 0.30 × V[1] + 0.70 × V[2]
         = 0.30 × 20 + 0.70 × 30
         = 27.0

coordinate =   4.5 (above grid)
  lower_index = 3, upper_index = 4
  lower_weight = -0.50, upper_weight = 1.50
  result = -0.50 × V[3] + 1.50 × V[4]
         = -0.50 × 40 + 1.50 × 50
         = 55.0

Multi-dimensional interpolation

When the value function depends on multiple continuous states, each dimension gets its own coordinate finder. The map_coordinates function then performs multi-linear interpolation (or extrapolation) by combining the per-dimension coordinates.

For a 2D case (e.g., wealth $\times$ income shock), this is bilinear interpolation: the function value at a query point is a weighted average of the 4 nearest grid points, with weights determined by the per-dimension fractional coordinates.
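The weighting scheme can be sketched in plain Python. The helper names `axis_weights` and `bilinear` and the test table are hypothetical; the per-axis rule is the same clipped-index, unclipped-weight rule described in the previous section.

```python
import math


def axis_weights(coordinate, size):
    """Return [(index, weight), (index, weight)] for one dimension."""
    lower = min(max(math.floor(coordinate), 0), size - 2)
    upper_weight = coordinate - lower
    return [(lower, 1 - upper_weight), (lower + 1, upper_weight)]


def bilinear(values, coord_row, coord_col):
    """Weighted sum over the 4 grid points surrounding a 2D query."""
    total = 0.0
    for i, w_i in axis_weights(coord_row, len(values)):
        for j, w_j in axis_weights(coord_col, len(values[0])):
            # Each corner's weight is the product of its per-axis weights.
            total += w_i * w_j * values[i][j]
    return total


# Hypothetical 8 x 5 table with values[i][j] = 2*i + 3*j. The function is
# linear in both indices, so bilinear interpolation reproduces it exactly.
table = [[2.0 * i + 3.0 * j for j in range(5)] for i in range(8)]

inside = bilinear(table, 2.5128, 2.3)
outside = bilinear(table, -0.5, 4.5)
print(inside, outside)
```

The four product weights always sum to one. Inside the grid they are all non-negative, giving a weighted average; outside, some become negative, which yields multi-linear extrapolation.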

# 2D value function: V(wealth, shock) = CRRA(wealth + shock)
wealth_grid = LinSpacedGrid(start=10, stop=400, n_points=8)
w_points = wealth_grid.to_jax()

shock_2d = lcm.shocks.iid.Normal(
    mu=0.0, sigma=1.0, n_std=2.0, n_points=5, gauss_hermite=False
)
s_points = shock_2d.to_jax()

# Evaluate on the 2D grid (shape: 8 x 5)
W, S = jnp.meshgrid(w_points, s_points, indexing="ij")
V_2d = crra(W + S)

print(f"Wealth grid: {jnp.round(w_points, 1)}")
print(f"Shock grid:  {jnp.round(s_points, 2)}")
print(f"V shape:     {V_2d.shape} (wealth \u00d7 shock)")

# Query point
w_query = jnp.array(150.0)
s_query = jnp.array(0.3)

# Per-dimension coordinates
w_coord = wealth_grid.get_coordinate(w_query)
s_coord = shock_2d.get_coordinate(s_query)

# 2D interpolation
V_interp_2d = map_coordinates(input=V_2d, coordinates=[w_coord, s_coord])
V_true_2d = crra(w_query + s_query)

print(f"\nQuery: wealth = {float(w_query)}, shock = {float(s_query)}")
print(f"Wealth coordinate: {float(w_coord):.4f}")
print(f"Shock coordinate:  {float(s_coord):.4f}")
print(f"Interpolated V:    {float(V_interp_2d):.6f}")
print(f"True V:            {float(V_true_2d):.6f}")
print(f"Error:             {float(V_interp_2d - V_true_2d):.6f}")
Wealth grid: [ 10.        65.700005 121.4      177.1      232.90001  288.6
 344.30002  400.      ]
Shock grid:  [-2. -1.  0.  1.  2.]
V shape:     (8, 5) (wealth × shock)

Query: wealth = 150.0, shock = 0.30000001192092896
Wealth coordinate: 2.5128
Shock coordinate:  2.3000
Interpolated V:    -0.165309
True V:            -0.163136
Error:             -0.002173

Summary

| Grid type | Coordinate finder | Complexity | When to use |
| --- | --- | --- | --- |
| LinSpacedGrid | get_linspace_coordinate | O(1) | Uniformly spaced state variables |
| LogSpacedGrid | get_logspace_coordinate | O(1) | States with high curvature at low values (e.g., wealth with CRRA) |
| PiecewiseLinSpacedGrid | searchsorted + get_linspace_coordinate | O(log k) + O(1) | States with breakpoints (e.g., eligibility thresholds) |
| IrregSpacedGrid | get_irreg_coordinate | O(log n) | Arbitrary point placement |
| ShockGrids | get_irreg_coordinate | O(log n) | Stochastic shocks with runtime-determined points |

All coordinate finders produce generalized coordinates that map_coordinates uses for linear interpolation (inside the grid) or linear extrapolation (outside the grid). The two-step design keeps the interpolation logic generic while letting each grid type optimize its coordinate mapping.