In dynamic programming, the value function is computed on a discrete grid but must be evaluated at arbitrary points during the solution process. State transitions can produce next-period states that fall between grid points (requiring interpolation) or outside the grid (requiring extrapolation).
pylcm does not use RegularGridInterpolator or similar pre-built tools because it
needs to support grids whose points are only known at runtime — for example, shock
grids whose locations depend on distributional parameters supplied via params.
This notebook explains how pylcm handles interpolation and extrapolation, using a CRRA utility function on a coarse wealth grid as a running example.
pylcm’s two-step design
pylcm evaluates functions at arbitrary points in two steps:
1. Coordinate finder: Convert a physical value (e.g., wealth = 150) to a generalized coordinate, that is, a fractional index into the grid. For a grid with $n$ points, values inside the grid produce coordinates in $[0, n-1]$; values outside produce coordinates outside this range.
2. map_coordinates: Take the generalized coordinates and the precomputed array of function values, and perform linear interpolation (for coordinates inside $[0, n-1]$) or linear extrapolation (for coordinates outside).
Each grid type provides its own coordinate finder, optimized for its spacing pattern.
The map_coordinates function is the same for all grid types.
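Below is a minimal NumPy sketch of the generic second step. It is not pylcm's implementation, and the name `map_coordinates_1d` is hypothetical. It shows that this step needs only an array of values and generalized coordinates, so the grid's physical spacing never enters it.

```python
import numpy as np


def map_coordinates_1d(values, coordinate):
    """Linear interpolation/extrapolation at generalized coordinates.

    Hypothetical stand-in for step 2: it only sees array indices, never the
    physical grid, so any coordinate finder can feed it.
    """
    values = np.asarray(values, dtype=float)
    coordinate = np.asarray(coordinate, dtype=float)
    n = values.shape[0]
    # Clip the lower index so that both neighbours are valid array positions...
    lower = np.clip(np.floor(coordinate), 0, n - 2).astype(int)
    # ...but leave the weight unclipped, which yields linear extrapolation.
    upper_weight = coordinate - lower
    return (1.0 - upper_weight) * values[lower] + upper_weight * values[lower + 1]


# Function values on some grid; the physical grid points are irrelevant here.
values = np.array([0.0, 1.0, 4.0, 9.0])
print(map_coordinates_1d(values, [0.5, 2.25, -1.0, 4.0]))
```

Any coordinate finder can supply the coordinates. This is what lets each grid type provide its own finder while sharing one interpolation routine.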
LinSpacedGrid
For a linearly spaced grid with start, stop, and n_points ($n$), the coordinate finder
uses the O(1) formula

$$\text{coordinate} = (n - 1)\,\frac{\text{value} - \text{start}}{\text{stop} - \text{start}}.$$

Values inside the grid produce coordinates in $[0, n-1]$. Values below
start produce negative coordinates; values above stop produce coordinates above
$n-1$. Both cases lead to linear extrapolation in map_coordinates.
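The formula can be transcribed directly into NumPy. The helper name `linspace_coordinate` is hypothetical; pylcm's own `get_linspace_coordinate` may have a different signature. This sketch reproduces the coordinates printed by the example that follows.

```python
import numpy as np


def linspace_coordinate(value, start, stop, n_points):
    """Generalized coordinate on a uniform grid (sketch of the O(1) formula)."""
    step = (stop - start) / (n_points - 1)
    # No clipping: values outside [start, stop] map outside [0, n_points - 1].
    return (np.asarray(value, dtype=float) - start) / step


coords = linspace_coordinate(
    [25.0, 100.0, 250.0, 0.5, 450.0], start=1, stop=400, n_points=10
)
print(np.round(coords, 4))
```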
import jax.numpy as jnp
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from lcm import LinSpacedGrid, LogSpacedGrid, Piece, PiecewiseLinSpacedGrid
from lcm.ndimage import map_coordinates
blue, orange, green = "#4C78A8", "#F58518", "#54A24B"
def crra(wealth, gamma=1.5):
    """CRRA utility: u(w) = w^(1-gamma) / (1-gamma)."""
    return wealth ** (1 - gamma) / (1 - gamma)
# Coarse linearly spaced wealth grid (10 points)
lin_grid = LinSpacedGrid(start=1, stop=400, n_points=10)
grid_points = lin_grid.to_jax()
# CRRA values on the grid
V = crra(grid_points)
print("Grid points:", grid_points)
print("V on grid: ", V)
Grid points: [ 1. 45.333336 89.66667 134. 178.33334 222.66667
267. 311.33334 355.6667 400. ]
V on grid: [-2. -0.29704425 -0.21121 -0.17277369 -0.14976616 -0.13403012
-0.12239801 -0.11334886 -0.10604945 -0.1 ]
# Query points: some inside the grid, some outside
query_inside = jnp.array([25.0, 100.0, 250.0])
query_outside = jnp.array([0.5, 450.0])
coords_inside = lin_grid.get_coordinate(query_inside)
coords_outside = lin_grid.get_coordinate(query_outside)
print("Inside grid:")
for w, c in zip(query_inside, coords_inside, strict=True):
    print(f" wealth = {w:.1f} → coordinate = {c:.4f} (in [0, 9])")
print("\nOutside grid:")
for w, c in zip(query_outside, coords_outside, strict=True):
    in_range = "< 0" if c < 0 else "> 9"
    print(f" wealth = {w:.1f} → coordinate = {c:.4f} ({in_range})")
Inside grid:
wealth = 25.0 → coordinate = 0.5414 (in [0, 9])
wealth = 100.0 → coordinate = 2.2331 (in [0, 9])
wealth = 250.0 → coordinate = 5.6165 (in [0, 9])
Outside grid:
wealth = 0.5 → coordinate = -0.0113 (< 0)
wealth = 450.0 → coordinate = 10.1278 (> 9)
# Interpolate and extrapolate using map_coordinates
query_all = jnp.concatenate([query_inside, query_outside])
coords_all = lin_grid.get_coordinate(query_all)
V_approx = map_coordinates(input=V, coordinates=[coords_all])
V_true = crra(query_all)
print(
    f"{'Wealth':>8} {'Coordinate':>11} "
    f"{'Approximated':>12} {'True':>12} {'Error':>10}"
)
print("-" * 60)
for w, c, va, vt in zip(query_all, coords_all, V_approx, V_true, strict=True):
    print(f"{w:8.1f} {c:11.4f} {va:12.6f} {vt:12.6f} {va - vt:10.6f}")
  Wealth  Coordinate  Approximated  True  Error
------------------------------------------------------------
25.0 0.5414 -1.078099 -0.400000 -0.678099
100.0 2.2331 -0.202251 -0.200000 -0.002251
250.0 5.6165 -0.126858 -0.126491 -0.000367
0.5 -0.0113 -2.019206 -2.828427 0.809221
450.0 10.1278 -0.093177 -0.094281 0.001104
# Dense points for smooth curves (extending beyond the grid for extrapolation)
x_dense = jnp.linspace(0.5, 450, 500)
coords_dense = lin_grid.get_coordinate(x_dense)
V_interp_dense = map_coordinates(input=V, coordinates=[coords_dense])
V_true_dense = crra(x_dense)
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=x_dense,
        y=V_true_dense,
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True CRRA",
    )
)
fig.add_trace(
    go.Scatter(
        x=x_dense,
        y=V_interp_dense,
        mode="lines",
        line={"color": orange, "width": 2},
        name="Interpolation / Extrapolation",
    )
)
fig.add_trace(
    go.Scatter(
        x=grid_points,
        y=V,
        mode="markers",
        marker={"color": blue, "size": 8},
        name="Grid points",
    )
)
fig.add_vline(x=1, line={"color": "gray", "dash": "dot", "width": 1})
fig.add_vline(x=400, line={"color": "gray", "dash": "dot", "width": 1})
fig.update_layout(
    title="LinSpacedGrid: Interpolation and Extrapolation",
    xaxis_title="Wealth",
    yaxis_title="u(w)",
)
fig.show()
LogSpacedGrid
For a logarithmically spaced grid, points are denser at lower values — ideal for functions with high curvature near zero (like CRRA utility).
The coordinate finder get_logspace_coordinate works by:
1. Transforming value, start, and stop to log space
2. Finding the bounding grid points (via their ranks in log space)
3. Linearly interpolating between the ranks in the original (physical) space
This gives a coordinate that accounts for the non-uniform spacing. With the same number of grid points, a log-spaced grid captures the curvature of CRRA much better than a linearly spaced one.
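The three steps can be sketched in NumPy as follows. The function name `logspace_coordinate` is hypothetical, and this is not pylcm's `get_logspace_coordinate` verbatim. The sketch assumes positive values and is checked only for values inside [start, stop]. Inside the grid, its result coincides with piecewise-linear interpolation of the ranks 0 to n-1 against the grid points, which `np.interp` computes directly.

```python
import numpy as np


def logspace_coordinate(value, start, stop, n_points):
    """Sketch: generalized coordinate on a log-spaced grid (positive values)."""
    value = np.asarray(value, dtype=float)
    # Step 1: move value, start and stop to log space, where the grid is uniform.
    log_start, log_stop = np.log(start), np.log(stop)
    log_step = (log_stop - log_start) / (n_points - 1)
    log_coordinate = (np.log(value) - log_start) / log_step
    # Step 2: ranks of the bounding grid points, found in log space.
    lower_rank = np.floor(log_coordinate)
    lower_point = np.exp(log_start + lower_rank * log_step)
    upper_point = np.exp(log_start + (lower_rank + 1) * log_step)
    # Step 3: interpolate linearly between the ranks in physical space.
    fraction = (value - lower_point) / (upper_point - lower_point)
    return lower_rank + fraction


grid = np.geomspace(1, 400, 10)
x = np.linspace(1, 400, 50)
coords = logspace_coordinate(x, 1, 400, 10)
# Inside the grid this equals linear interpolation of the ranks over the points.
print(np.allclose(coords, np.interp(x, grid, np.arange(10))))
```

Because the fraction in step 3 is measured in physical units, the coordinate, and hence the interpolant, is piecewise linear in wealth between neighbouring log-spaced points.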
log_grid = LogSpacedGrid(start=1, stop=400, n_points=10)
log_points = log_grid.to_jax()
V_log = crra(log_points)
print("LinSpaced grid:", jnp.round(grid_points, 1))
print("LogSpaced grid:", jnp.round(log_points, 1))
LinSpaced grid: [ 1. 45.3 89.700005 134. 178.3 222.7
267. 311.30002 355.7 400. ]
LogSpaced grid: [ 1. 1.9 3.8 7.4 14.3 27.9 54.3 105.6 205.6 400. ]
x_eval = jnp.linspace(1, 400, 500)
V_true_eval = crra(x_eval)
# LinSpaced interpolation
lin_coords = lin_grid.get_coordinate(x_eval)
V_lin_interp = map_coordinates(input=V, coordinates=[lin_coords])
# LogSpaced interpolation
log_coords = log_grid.get_coordinate(x_eval)
V_log_interp = map_coordinates(input=V_log, coordinates=[log_coords])
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=(
        "Interpolated value functions",
        "Interpolation error",
    ),
)
# Left: Interpolated curves
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=V_true_eval,
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True CRRA",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=V_lin_interp,
        mode="lines",
        line={"color": blue, "width": 2},
        name="LinSpaced (10 pts)",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=V_log_interp,
        mode="lines",
        line={"color": orange, "width": 2},
        name="LogSpaced (10 pts)",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=grid_points,
        y=V,
        mode="markers",
        marker={"color": blue, "size": 6},
        name="LinSpaced grid",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=log_points,
        y=V_log,
        mode="markers",
        marker={"color": orange, "size": 6},
        name="LogSpaced grid",
    ),
    row=1,
    col=1,
)
# Right: Absolute errors
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=jnp.abs(V_lin_interp - V_true_eval),
        mode="lines",
        line={"color": blue},
        name="LinSpaced error",
        showlegend=False,
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Scatter(
        x=x_eval,
        y=jnp.abs(V_log_interp - V_true_eval),
        mode="lines",
        line={"color": orange},
        name="LogSpaced error",
        showlegend=False,
    ),
    row=1,
    col=2,
)
fig.update_xaxes(title_text="Wealth", row=1, col=1)
fig.update_xaxes(title_text="Wealth", row=1, col=2)
fig.update_yaxes(title_text="u(w)", row=1, col=1)
fig.update_yaxes(title_text="|Error|", row=1, col=2)
fig.update_layout(height=450, width=900)
fig.show()
ShockGrid (Normal)
ShockGrids discretize continuous distributions. Their grid points depend on
distributional parameters (e.g., mu, sigma for a Normal shock) which may only be
known at runtime (when supplied via params). Because the points are determined
dynamically, ShockGrids cannot use the O(1) linspace formula — instead, they use
get_irreg_coordinate internally.
get_irreg_coordinate handles arbitrary point sequences:
1. Use jnp.searchsorted to find the bounding grid points (O(log n) for n grid points)
2. Linearly interpolate between the bounding points to get the fractional coordinate
3. For values outside the grid, extrapolate using the slope of the nearest segment
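Here is a NumPy sketch of these three steps, with `np.searchsorted` in place of `jnp.searchsorted`. The name `irreg_coordinate` is hypothetical, and pylcm's `get_irreg_coordinate` may differ in details such as tie handling. Clipping the searchsorted index to [1, n-1] is what makes out-of-grid values reuse the first or last segment. On a uniform 7-point grid from -2.5 to 2.5, like the Normal shock grid below, the sketch yields the same coordinates as the printed output.

```python
import numpy as np


def irreg_coordinate(value, points):
    """Sketch: generalized coordinate on an arbitrary increasing grid."""
    points = np.asarray(points, dtype=float)
    value = np.asarray(value, dtype=float)
    n = points.shape[0]
    # Step 1: binary search (O(log n)) for the upper bounding point. Clipping to
    # [1, n - 1] makes out-of-grid values reuse the first or last segment.
    upper = np.clip(np.searchsorted(points, value, side="right"), 1, n - 1)
    lower = upper - 1
    # Steps 2 and 3: fractional position between the bounding points. Outside
    # the grid the fraction is < 0 or > 1, i.e. extrapolation along that segment.
    fraction = (value - points[lower]) / (points[upper] - points[lower])
    return lower + fraction


shock_points = np.linspace(-2.5, 2.5, 7)
print(irreg_coordinate([-3.0, -1.0, 0.5, 2.0, 3.0], shock_points))
# Genuinely irregular spacing works the same way.
print(irreg_coordinate([5.0, 9.0], [0.0, 1.0, 3.0, 7.0]))
```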
import lcm.shocks.iid
shock = lcm.shocks.iid.Normal(
    mu=0.0, sigma=1.0, n_std=2.5, n_points=7, gauss_hermite=False
)
shock_points = shock.to_jax()
print("Shock grid points:", shock_points)
# CRRA of (base wealth + shock)
base_wealth = 100.0
V_shock = crra(base_wealth + shock_points)
# Query points inside and outside the shock grid
shock_query = jnp.array([-3.0, -1.0, 0.5, 2.0, 3.0])
shock_coords = shock.get_coordinate(shock_query)
n_last = len(shock_points) - 1
print("\nShock query points and coordinates:")
for q, c in zip(shock_query, shock_coords, strict=True):
    in_out = "inside" if 0 <= c <= n_last else "outside"
    print(f" ε = {q:5.1f} → coordinate = {c:6.3f} ({in_out})")
Shock grid points: [-2.5000000e+00 -1.6666666e+00 -8.3333313e-01 5.9604645e-08
8.3333349e-01 1.6666669e+00 2.5000000e+00]
Shock query points and coordinates:
ε = -3.0 → coordinate = -0.600 (outside)
ε = -1.0 → coordinate = 1.800 (inside)
ε = 0.5 → coordinate = 3.600 (inside)
ε = 2.0 → coordinate = 5.400 (inside)
ε = 3.0 → coordinate = 6.600 (outside)
eps_dense = jnp.linspace(-3.5, 3.5, 300)
shock_coords_dense = shock.get_coordinate(eps_dense)
V_shock_interp = map_coordinates(input=V_shock, coordinates=[shock_coords_dense])
V_shock_true = crra(base_wealth + eps_dense)
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=eps_dense,
        y=V_shock_true,
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True CRRA(100 + ε)",
    )
)
fig.add_trace(
    go.Scatter(
        x=eps_dense,
        y=V_shock_interp,
        mode="lines",
        line={"color": orange, "width": 2},
        name="Interpolation / Extrapolation",
    )
)
fig.add_trace(
    go.Scatter(
        x=shock_points,
        y=V_shock,
        mode="markers",
        marker={"color": blue, "size": 8},
        name="Grid points",
    )
)
fig.add_vline(
    x=float(shock_points[0]),
    line={"color": "gray", "dash": "dot", "width": 1},
)
fig.add_vline(
    x=float(shock_points[-1]),
    line={"color": "gray", "dash": "dot", "width": 1},
)
fig.update_layout(
    title="Normal ShockGrid: Interpolation and Extrapolation",
    xaxis_title="Shock (ε)",
    yaxis_title="u(100 + ε)",
)
fig.show()
PiecewiseLinSpacedGrid
Some models feature eligibility thresholds — e.g., a means-tested program that applies
only below a wealth cutoff. The value function may jump at the threshold (eligible
households receive a transfer, ineligible ones do not). To ensure the threshold is a
grid point (avoiding interpolation across the discontinuity),
PiecewiseLinSpacedGrid lets you place a breakpoint at the threshold.
For a grid with pieces $[a, b)$ and $[b, c]$:
1. Piece selection: jnp.searchsorted on the breakpoints determines which piece a value belongs to
2. Local coordinate: get_linspace_coordinate within the piece
3. Global coordinate: offset by the cumulative number of points in preceding pieces
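The breakpoint $b$ is the first point of the second piece, guaranteeing it is a grid
point. Because map_coordinates interpolates linearly only between adjacent grid points,
the only segment that spans the discontinuity is the one between $b$ and the last point
of the previous piece. In the example below, both of these points print as 50.0, yet only
the lower one receives the transfer. It therefore lies strictly below the threshold, and
the gap between the two points is too small to show at the printed precision.
Interpolation across the jump is confined to this negligible interval.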
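The three steps for piecewise grids can be sketched in NumPy as below. The function name `piecewise_coordinate` and the representation of pieces as `(start, stop, n_points)` tuples are hypothetical and are not pylcm's `Piece` API. For simplicity, the sketch uses 50 as the first piece's stop, whereas in pylcm's printed grid below, the last point of "[1, 50)" appears to sit just below 50. Either way, the coordinate jumps from just under 4 to exactly 5 at the breakpoint, so no query below the threshold lands between those two indices.

```python
import numpy as np


def piecewise_coordinate(value, pieces):
    """Sketch: generalized coordinate on a piecewise linearly spaced grid.

    `pieces` is a hypothetical list of (start, stop, n_points) tuples whose
    starts (after the first) act as breakpoints.
    """
    value = np.asarray(value, dtype=float)
    starts = np.array([p[0] for p in pieces], dtype=float)
    stops = np.array([p[1] for p in pieces], dtype=float)
    n_points = np.array([p[2] for p in pieces])
    # Global index of each piece's first point: number of points in earlier pieces.
    offsets = np.concatenate([[0], np.cumsum(n_points)[:-1]])
    # Piece selection: a value equal to a breakpoint belongs to the later piece.
    piece = np.searchsorted(starts[1:], value, side="right")
    # Local coordinate via the uniform-grid formula inside the selected piece.
    step = (stops[piece] - starts[piece]) / (n_points[piece] - 1)
    local = (value - starts[piece]) / step
    # Global coordinate: shift by the points in preceding pieces.
    return offsets[piece] + local


pieces = [(1.0, 50.0, 5), (50.0, 400.0, 7)]
print(piecewise_coordinate([25.5, 49.99, 50.0, 108.0], pieces))
```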
pw_grid = PiecewiseLinSpacedGrid(
    pieces=(
        Piece(interval="[1, 50)", n_points=5),
        Piece(interval="[50, 400]", n_points=7),
    )
)
pw_points = pw_grid.to_jax()
# Value function with a jump: means-tested transfer of 0.5 for wealth below threshold
transfer = 0.5
threshold = 50.0
V_pw = crra(pw_points) + transfer * jnp.where(pw_points < threshold, 1.0, 0.0)
print(f"Total grid points: {pw_grid.n_points}")
print(f"Grid: {jnp.round(pw_points, 1)}")
print(f"\nBreakpoint at wealth = 50 is at index 5: grid[5] = {pw_points[5]:.1f}")
print(f"\nV just below threshold (grid[4]): {V_pw[4]:.4f}")
print(f"V at threshold (grid[5]): {V_pw[5]:.4f}")
Total grid points: 12
Grid: [ 1. 13.2 25.5 37.7 50. 50. 108.3
166.7 225. 283.30002 341.7 400. ]
Breakpoint at wealth = 50 is at index 5: grid[5] = 50.0
V just below threshold (grid[4]): 0.2172
V at threshold (grid[5]): -0.2828
x_dense_pw = jnp.linspace(1, 400, 500)
pw_coords_dense = pw_grid.get_coordinate(x_dense_pw)
V_pw_interp = map_coordinates(input=V_pw, coordinates=[pw_coords_dense])
# Split traces at the threshold to avoid connecting lines across the jump
mask_below = x_dense_pw < threshold
x_below = x_dense_pw[mask_below]
x_above = x_dense_pw[~mask_below]
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=x_below,
        y=crra(x_below) + transfer,
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True function",
    )
)
fig.add_trace(
    go.Scatter(
        x=x_above,
        y=crra(x_above),
        mode="lines",
        line={"color": "gray", "width": 1},
        name="True function",
        showlegend=False,
    )
)
fig.add_trace(
    go.Scatter(
        x=x_below,
        y=V_pw_interp[mask_below],
        mode="lines",
        line={"color": orange, "width": 2},
        name="Piecewise interpolation",
    )
)
fig.add_trace(
    go.Scatter(
        x=x_above,
        y=V_pw_interp[~mask_below],
        mode="lines",
        line={"color": orange, "width": 2},
        name="Piecewise interpolation",
        showlegend=False,
    )
)
fig.add_trace(
    go.Scatter(
        x=pw_points,
        y=V_pw,
        mode="markers",
        marker={"color": blue, "size": 8},
        name="Grid points",
    )
)
fig.update_layout(
    title="PiecewiseLinSpacedGrid: capturing a discontinuity",
    xaxis_title="Wealth",
    yaxis_title="V(w)",
)
fig.show()
map_coordinates internals
pylcm’s map_coordinates (src/lcm/ndimage.py) is a modified version of JAX’s
jax.scipy.ndimage.map_coordinates. The key difference is in how it handles values
outside the grid.
The key logic lives in the helper _compute_indices_and_weights:
lower_index = jnp.clip(jnp.floor(coordinate), 0, input_size - 2)
upper_weight = coordinate - lower_index
lower_weight = 1 - upper_weight
The lower index is clipped to $[0, n-2]$, where $n$ is input_size, so both lower_index
and lower_index + 1 are valid array positions.
The weight, however, is not clipped: upper_weight is simply coordinate - lower_index.
For coordinates inside $[0, n-1]$, the upper weight falls in $[0, 1]$, giving standard linear interpolation. For coordinates outside this range, the weights extrapolate:

| Coordinate $c$ | Lower index | Upper weight | Result |
|---|---|---|---|
| $0 \le c \le n-1$ | $\min(\lfloor c \rfloor, n-2)$ | in $[0, 1]$ | Linear interpolation |
| $c < 0$ | $0$ | $c$ (negative) | Linear extrapolation using first segment |
| $c > n-1$ | $n-2$ | $c - (n-2)$ (greater than 1) | Linear extrapolation using last segment |
JAX’s version, by contrast, clips or fills values outside the grid (depending on
mode), which does not give linear extrapolation.
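NumPy alone is enough to show the contrast. In the sketch below, `np.interp` stands in for a clamping boundary mode, since it holds the first and last values constant outside the data range. The second computation applies the clipped-index, unclipped-weight rule described above. Both use the same five-point array as the demonstration that follows.

```python
import numpy as np

V_demo = np.array([10.0, 20.0, 30.0, 40.0, 50.0])
indices = np.arange(V_demo.size)
coords = np.array([-0.5, 1.7, 4.5])

# Clamping: np.interp returns the edge values for coordinates off the grid.
clamped = np.interp(coords, indices, V_demo)

# Linear extrapolation: clip the lower index, keep the weight unclipped.
lower = np.clip(np.floor(coords), 0, V_demo.size - 2).astype(int)
upper_weight = coords - lower
extrapolated = (1 - upper_weight) * V_demo[lower] + upper_weight * V_demo[lower + 1]

print("clamped:     ", clamped)
print("extrapolated:", extrapolated)
```

Inside the grid, both approaches agree. Outside, clamping flattens the function, whereas the unclipped weights continue the first or last segment.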
from lcm.ndimage import _compute_indices_and_weights
# A simple grid with 5 points
n = 5
V_demo = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])
# Three cases: below grid, inside grid, above grid
cases = [
    (-0.5, "below grid"),
    (1.7, "inside grid"),
    (4.5, "above grid"),
]
for coord, label in cases:
    coord_arr = jnp.array(coord)
    [(lo_idx, lo_wt), (hi_idx, hi_wt)] = _compute_indices_and_weights(coord_arr, n)
    result = float(lo_wt * V_demo[lo_idx] + hi_wt * V_demo[hi_idx])
    print(f"coordinate = {coord:5.1f} ({label})")
    print(f" lower_index = {int(lo_idx)}, upper_index = {int(hi_idx)}")
    print(f" lower_weight = {float(lo_wt):.2f}, upper_weight = {float(hi_wt):.2f}")
    print(
        f" result = {float(lo_wt):.2f} \u00d7 V[{int(lo_idx)}]"
        f" + {float(hi_wt):.2f} \u00d7 V[{int(hi_idx)}]"
    )
    print(
        f" = {float(lo_wt):.2f} \u00d7 {float(V_demo[lo_idx]):.0f}"
        f" + {float(hi_wt):.2f} \u00d7 {float(V_demo[hi_idx]):.0f}"
    )
    print(f" = {result:.1f}\n")
coordinate = -0.5 (below grid)
lower_index = 0, upper_index = 1
lower_weight = 1.50, upper_weight = -0.50
result = 1.50 × V[0] + -0.50 × V[1]
= 1.50 × 10 + -0.50 × 20
= 5.0
coordinate = 1.7 (inside grid)
lower_index = 1, upper_index = 2
lower_weight = 0.30, upper_weight = 0.70
result = 0.30 × V[1] + 0.70 × V[2]
= 0.30 × 20 + 0.70 × 30
= 27.0
coordinate = 4.5 (above grid)
lower_index = 3, upper_index = 4
lower_weight = -0.50, upper_weight = 1.50
result = -0.50 × V[3] + 1.50 × V[4]
= -0.50 × 40 + 1.50 × 50
= 55.0
Multi-dimensional interpolation
When the value function depends on multiple continuous states, each dimension gets its
own coordinate finder. The map_coordinates function then performs multi-linear
interpolation (or extrapolation) by combining the per-dimension coordinates.
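For a 2D case (e.g., wealth × income shock), this is bilinear interpolation. The function value at a query point is a weighted combination of the 4 grid points at the corners of the cell containing it (a weighted average when the point lies inside the grid). Each corner's weight is the product of the per-dimension weights determined by the fractional coordinates.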
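The sketch below spells out the four corner weights in plain NumPy. The function name `bilinear` is hypothetical and not part of pylcm. It reuses the coordinates printed in the example that follows, but applies them to a toy array that is bilinear in the indices, f(i, j) = 2i + 3j + ij. Bilinear interpolation and extrapolation reproduce such a function exactly, so the result can be checked by hand.

```python
import numpy as np


def bilinear(values, coord_x, coord_y):
    """Sketch: bilinear interpolation/extrapolation at generalized coordinates."""
    nx, ny = values.shape
    # Per-dimension lower indices (clipped) and unclipped upper weights.
    ix = int(np.clip(np.floor(coord_x), 0, nx - 2))
    iy = int(np.clip(np.floor(coord_y), 0, ny - 2))
    wx, wy = coord_x - ix, coord_y - iy
    # Each corner's weight is the product of its per-dimension weights.
    return (
        (1 - wx) * (1 - wy) * values[ix, iy]
        + (1 - wx) * wy * values[ix, iy + 1]
        + wx * (1 - wy) * values[ix + 1, iy]
        + wx * wy * values[ix + 1, iy + 1]
    )


# 8 x 5 array, like the wealth x shock grid below, holding f(i, j) = 2i + 3j + ij.
i, j = np.meshgrid(np.arange(8), np.arange(5), indexing="ij")
values = 2.0 * i + 3.0 * j + i * j
print(bilinear(values, 2.5128, 2.3))
print(2.0 * 2.5128 + 3.0 * 2.3 + 2.5128 * 2.3)
```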
# 2D value function: V(wealth, shock) = CRRA(wealth + shock)
wealth_grid = LinSpacedGrid(start=10, stop=400, n_points=8)
w_points = wealth_grid.to_jax()
shock_2d = lcm.shocks.iid.Normal(
    mu=0.0, sigma=1.0, n_std=2.0, n_points=5, gauss_hermite=False
)
s_points = shock_2d.to_jax()
# Evaluate on the 2D grid (shape: 8 x 5)
W, S = jnp.meshgrid(w_points, s_points, indexing="ij")
V_2d = crra(W + S)
print(f"Wealth grid: {jnp.round(w_points, 1)}")
print(f"Shock grid: {jnp.round(s_points, 2)}")
print(f"V shape: {V_2d.shape} (wealth \u00d7 shock)")
# Query point
w_query = jnp.array(150.0)
s_query = jnp.array(0.3)
# Per-dimension coordinates
w_coord = wealth_grid.get_coordinate(w_query)
s_coord = shock_2d.get_coordinate(s_query)
# 2D interpolation
V_interp_2d = map_coordinates(input=V_2d, coordinates=[w_coord, s_coord])
V_true_2d = crra(w_query + s_query)
print(f"\nQuery: wealth = {float(w_query)}, shock = {float(s_query)}")
print(f"Wealth coordinate: {float(w_coord):.4f}")
print(f"Shock coordinate: {float(s_coord):.4f}")
print(f"Interpolated V: {float(V_interp_2d):.6f}")
print(f"True V: {float(V_true_2d):.6f}")
print(f"Error: {float(V_interp_2d - V_true_2d):.6f}")
Wealth grid: [ 10. 65.700005 121.4 177.1 232.90001 288.6
344.30002 400. ]
Shock grid: [-2. -1. 0. 1. 2.]
V shape: (8, 5) (wealth × shock)
Query: wealth = 150.0, shock = 0.30000001192092896
Wealth coordinate: 2.5128
Shock coordinate: 2.3000
Interpolated V: -0.165309
True V: -0.163136
Error: -0.002173
Summary
| Grid type | Coordinate finder | Complexity per query | When to use |
|---|---|---|---|
| LinSpacedGrid | get_linspace_coordinate | O(1) | Uniformly spaced state variables |
| LogSpacedGrid | get_logspace_coordinate | O(1) | States with high curvature at low values (e.g., wealth with CRRA) |
| PiecewiseLinSpacedGrid | searchsorted + get_linspace_coordinate | O(log k) + O(1) | States with breakpoints (e.g., eligibility thresholds) |
| IrregSpacedGrid | get_irreg_coordinate | O(log n) | Arbitrary point placement |
| ShockGrids | get_irreg_coordinate | O(log n) | Stochastic shocks with runtime-determined points |

Here n is the number of grid points and k the number of pieces.
All coordinate finders produce generalized coordinates that map_coordinates uses for
linear interpolation (inside the grid) or linear extrapolation (outside the grid). The
two-step design keeps the interpolation logic generic while letting each grid type
optimize its coordinate mapping.