pylcm models involve two independent layers of transitions that compose during backward induction:
- **State transitions** — how a state variable arrived at its current value (attached to grids)
- **Regime transitions** — which regime the agent enters next period (attached to the regime itself)
The two layers have opposite orientations:
Regime transitions are forward-looking. The transition field on a Regime answers “where does the agent go next?” It lives on the source regime and points toward the future.

State transitions are backward-looking. The transition parameter on a grid answers “how did this state variable reach its current value?” It lives on the grid that receives the value. In multi-regime models, per-boundary mappings are placed on the target regime’s grid.
This notebook explains how each layer works, how per-boundary transitions resolve cross-regime state mismatches, and how everything composes in the Bellman equation.
import jax
import jax.numpy as jnp
from lcm import (
    DiscreteGrid,
    DiscreteMarkovGrid,
    LinSpacedGrid,
    Regime,
    RegimeTransition,
    categorical,
)
from lcm.typing import (
    ContinuousState,
    DiscreteState,
    FloatND,
    ScalarInt,
)

State Transition Mechanics¶
State transitions are attached directly to grid objects via the transition
parameter. There are four cases:
| Grid configuration | Behavior |
|---|---|
| transition=some_func | Deterministic: the function computes the next value |
| transition=None | Fixed: an identity transition is auto-generated |
| DiscreteMarkovGrid(transition=func) | Stochastic: probability-weighted expectation |
| Shock grids (lcm.shocks.*) | Intrinsic transitions with interpolated weights |
Deterministic state transitions¶
A state transition function defines how the current state value was determined from last period’s states, actions, and parameters. The function’s argument names are resolved from the regime’s namespace.
def next_wealth(
    wealth: ContinuousState,
    consumption: ContinuousState,
    interest_rate: float,
) -> ContinuousState:
    return (1 + interest_rate) * (wealth - consumption)

# Attached to the grid:
wealth_grid = LinSpacedGrid(start=0, stop=100, n_points=50, transition=next_wealth)

Fixed states and identity transitions¶
When transition=None, pylcm auto-generates an _IdentityTransition internally.
This shows up in regime.get_all_functions() under the key "next_<state_name>".
@categorical
class EducationLevel:
    low: int
    high: int

@categorical
class RegimeId:
    working: int
    retired: int

def utility(wealth: ContinuousState) -> FloatND:
    return jnp.log(wealth + 1.0)

def next_regime() -> ScalarInt:
    return RegimeId.retired

regime = Regime(
    transition=RegimeTransition(next_regime),
    states={
        "education": DiscreteGrid(EducationLevel, transition=None),
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=10, transition=None),
    },
    functions={"utility": utility},
)

all_funcs = regime.get_all_functions()
print("Function keys:", list(all_funcs.keys()))
print("next_education type:", type(all_funcs["next_education"]).__name__)
print("next_wealth type:   ", type(all_funcs["next_wealth"]).__name__)

Function keys: ['utility', 'H', 'next_education', 'next_wealth', 'next_regime']
next_education type: _IdentityTransition
next_wealth type:    _IdentityTransition
Both fixed states produce _IdentityTransition objects. These are marked with
_is_auto_identity = True so that validation can distinguish them from
user-provided transitions.
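As a rough sketch (this is not pylcm's internal code; the helper name is illustrative), an auto-generated identity transition could be built and flagged like this:

```python
def make_identity_transition(state_name: str):
    """Illustrative sketch of an auto-generated identity transition."""

    def identity(**kwargs):
        # The state's current value simply carries over unchanged.
        return kwargs[state_name]

    # Flag the function so validation can tell it apart from user code.
    identity._is_auto_identity = True
    return identity

next_education = make_identity_transition("education")
print(next_education(education=1))  # the value passes through unchanged
```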
Stochastic state transitions (DiscreteMarkovGrid)¶
For DiscreteMarkovGrid, the transition function returns a probability array over
the categories. During the solve step, pylcm computes a probability-weighted
expectation over next-period states:
@categorical
class Health:
    bad: int
    good: int

def health_transition(health: DiscreteState) -> FloatND:
    return jnp.where(
        health == Health.good,
        jnp.array([0.1, 0.9]),  # good → 90% stay good
        jnp.array([0.6, 0.4]),  # bad → 40% recover
    )

health_grid = DiscreteMarkovGrid(Health, transition=health_transition)

# Inspect the transition probabilities
for state_name, code in [("bad", Health.bad), ("good", Health.good)]:
    probs = health_transition(jnp.array(code))
    print(f"P(next | {state_name}) = {probs}")

P(next | bad) = [0.6 0.4]
P(next | good) = [0.1 0.9]
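The probability-weighted expectation that pylcm computes during the solve step can be illustrated with plain JAX arithmetic (the continuation values below are made-up numbers):

```python
import jax.numpy as jnp

# Hypothetical next-period continuation values, indexed [bad, good]
continuation_values = jnp.array([2.0, 5.0])

# Transition rows from health_transition above
p_from_bad = jnp.array([0.6, 0.4])   # bad  → [P(bad), P(good)]
p_from_good = jnp.array([0.1, 0.9])  # good → [P(bad), P(good)]

# Expected continuation value, weighted by the transition probabilities
ev_bad = jnp.dot(p_from_bad, continuation_values)    # 0.6*2 + 0.4*5 = 3.2
ev_good = jnp.dot(p_from_good, continuation_values)  # 0.1*2 + 0.9*5 = 4.7
print(float(ev_bad), float(ev_good))
```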
Shock grids¶
Shock grids (from lcm.shocks.iid and lcm.shocks.ar1) have intrinsic
transitions computed from the distribution. For IID shocks, the transition
probabilities are the same regardless of the current value. For AR(1) shocks,
probabilities depend on the current state.
Shock grids do not accept a transition parameter — their transitions are
built-in.
import lcm.shocks.iid
shock = lcm.shocks.iid.Normal(
    n_points=5, gauss_hermite=False, mu=0.0, sigma=1.0, n_std=2.5
)
print("Grid points:", shock.to_jax())

Grid points: [-2.5  -1.25  0.    1.25   2.5 ]
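To see why IID transition probabilities are state-independent, here is a deliberately simplified weighting scheme (normalized standard-normal densities at the nodes; pylcm's actual interpolated weights will differ):

```python
import jax.numpy as jnp

points = jnp.array([-2.5, -1.25, 0.0, 1.25, 2.5])  # the grid printed above

# Simplified illustration: weight each node by the standard-normal density,
# then normalize so the weights sum to one.
density = jnp.exp(-0.5 * points**2)
weights = density / density.sum()

# For an IID shock these weights apply from *every* current value; for an
# AR(1) shock the conditional distribution shifts with the current state,
# so each current value gets its own weight vector.
print(weights)
```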
Per-Boundary State Transitions¶
When a discrete state has different categories across regimes, a simple callable transition is not enough — you need to map from one category set to another at the regime boundary.
The solution: a mapping transition keyed by (source_regime, target_regime)
pairs, placed on the target regime’s grid.
Example: different health categories¶
Suppose working life has three health states (disabled, bad, good) but retirement only has two (bad, good). The transition from working to retired needs an explicit mapping.
@categorical
class HealthWorking:
    disabled: int
    bad: int
    good: int

@categorical
class HealthRetired:
    bad: int
    good: int

def map_working_to_retired(health: DiscreteState) -> DiscreteState:
    """Map 3-category working health to 2-category retired health."""
    return jnp.where(
        health == HealthWorking.good,
        HealthRetired.good,
        HealthRetired.bad,
    )

# Verify the mapping
for name, code in [("disabled", 0), ("bad", 1), ("good", 2)]:
    result = map_working_to_retired(jnp.array(code))
    print(f"working {name} ({code}) → retired code {int(result)}")

working disabled (0) → retired code 0
working bad (1) → retired code 0
working good (2) → retired code 1
The mapping is placed on the target regime’s grid:
health_retired_grid = DiscreteGrid(
    HealthRetired,
    transition={
        ("working", "retired"): map_working_to_retired,
    },
)

Resolution priority¶
When resolving which transition function to use at a regime boundary (source, target), pylcm checks (in order):
1. Target grid mapping for (source, target)
2. Source grid mapping for (source, target)
3. Source grid’s callable transition
4. Target grid’s callable transition
5. Auto-generated identity (if categories match)
If the categories differ across regimes and no explicit mapping is found,
ModelInitializationError is raised.
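A sketch of this lookup order in plain Python (illustrative only; pylcm's internal resolution code will differ, and the stand-in grid objects below are not real lcm grids):

```python
from types import SimpleNamespace

def resolve_transition(source, target, source_grid, target_grid):
    """Sketch of the per-boundary resolution order (not pylcm's real code)."""
    key = (source, target)
    if isinstance(target_grid.transition, dict) and key in target_grid.transition:
        return target_grid.transition[key]  # 1. target grid mapping
    if isinstance(source_grid.transition, dict) and key in source_grid.transition:
        return source_grid.transition[key]  # 2. source grid mapping
    if callable(source_grid.transition):
        return source_grid.transition       # 3. source grid's callable
    if callable(target_grid.transition):
        return target_grid.transition       # 4. target grid's callable
    return "identity"                       # 5. auto-identity fallback

# The mapping on the target grid wins over a callable on the source grid:
src = SimpleNamespace(transition=lambda x: x)
tgt = SimpleNamespace(transition={("working", "retired"): "mapping_fn"})
print(resolve_transition("working", "retired", src, tgt))
```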
Parameterized per-boundary transitions¶
Per-boundary mapping functions can take parameters beyond the state variable itself. A common use case is a continuous state whose transition law differs across regime boundaries — for example, wealth that grows at a rate specific to the target regime.
When pylcm resolves a per-boundary transition from the target grid’s mapping
(priority 1 above), any parameters in that function are looked up in the
target regime’s parameter template. This means the user specifies the
parameter value under the target regime in the params dict, and pylcm
automatically routes it to the transition function at the boundary.
The rule is simple: whoever owns the mapping owns the parameters. Since per-boundary mappings live on the target regime’s grid, their parameters come from the target regime.
Example: regime-specific growth rate¶
Consider a two-regime model (phase 1 → phase 2) where wealth grows at a rate
that is specific to phase 2. The transition function on phase 2’s wealth grid
takes a growth_rate parameter:
def next_wealth_at_boundary(
    wealth: ContinuousState,
    growth_rate: float,
) -> ContinuousState:
    """Wealth transition at the phase1 → phase2 boundary.

    The growth_rate parameter is resolved from phase2's params template,
    because the mapping lives on phase2's grid.
    """
    return (1 + growth_rate) * wealth

# Phase 2's wealth grid declares the per-boundary mapping:
phase2_wealth_grid = LinSpacedGrid(
    start=0,
    stop=100,
    n_points=20,
    transition={
        ("phase1", "phase2"): next_wealth_at_boundary,
    },
)

Because the mapping {("phase1", "phase2"): next_wealth_at_boundary} lives on
phase 2’s grid, the growth_rate parameter appears in phase 2’s parameter
template. The user supplies it under "phase2" in the params dict:
params = {
    "phase1": {...},
    "phase2": {
        "next_wealth": {"growth_rate": 0.05},
        ...
    },
}

Internally, pylcm detects that the transition was resolved from the target
grid’s mapping and renames the parameter to a cross-boundary qualified name
(e.g., phase2__next_wealth__growth_rate). At solve and simulation time, the
value is looked up from internal_params["phase2"] — not from "phase1" —
even though the transition is evaluated as part of phase 1’s backward induction
step.
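The routing can be pictured with a small dictionary sketch (the qualified-name pattern follows the example above; the exact internal representation may differ):

```python
target_regime, transition_name, param = "phase2", "next_wealth", "growth_rate"

# Cross-boundary qualified name, following the pattern in the text
qualified = f"{target_regime}__{transition_name}__{param}"
print(qualified)

# At solve/simulation time the value comes from the *target* regime's entry,
# even though the boundary is crossed during phase 1's backward induction step.
internal_params = {
    "phase1": {},
    "phase2": {"next_wealth": {"growth_rate": 0.05}},
}
growth_rate = internal_params["phase2"]["next_wealth"]["growth_rate"]
print(growth_rate)
```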
Regime Transition Mechanics¶
Regime transitions determine which regime the agent enters next period. Internally, both deterministic and stochastic transitions are converted to a uniform probability array format.
Deterministic transitions → one-hot encoding¶
A RegimeTransition wraps a function that returns an integer regime ID.
Internally, _wrap_deterministic_regime_transition converts this to a one-hot
probability array using jax.nn.one_hot:
@categorical
class RegimeIdExample:
    working: int
    retired: int
    dead: int

# Deterministic: retire at age 65
def next_regime_det(age: float, retirement_age: float) -> ScalarInt:
    return jnp.where(
        age >= retirement_age, RegimeIdExample.retired, RegimeIdExample.working
    )

# What pylcm does internally:
regime_idx = next_regime_det(age=50.0, retirement_age=65.0)
one_hot = jax.nn.one_hot(regime_idx, num_classes=3)
print(f"Regime index: {int(regime_idx)}")
print(f"One-hot: {one_hot} (= [P(working), P(retired), P(dead)])")

Regime index: 0
One-hot: [1. 0. 0.] (= [P(working), P(retired), P(dead)])
Stochastic transitions → probability array¶
A MarkovRegimeTransition wraps a function that directly returns a probability
array. No conversion is needed — the array is used as-is.
def next_regime_stoch(survival_prob: float) -> FloatND:
    """Alive → [P(working), P(retired), P(dead)]."""
    return jnp.array([survival_prob, 0.0, 1 - survival_prob])

probs = next_regime_stoch(survival_prob=0.98)
print(f"Probabilities: {probs} (= [P(working), P(retired), P(dead)])")

Probabilities: [0.98 0.   0.02] (= [P(working), P(retired), P(dead)])
After wrapping, the probability array is further converted to a dictionary keyed by
regime name (via _wrap_regime_transition_probs), giving a uniform internal
representation regardless of whether the original transition was deterministic or
stochastic.
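A sketch of that uniform format (the regime names and helper function are illustrative, not pylcm's internal code):

```python
import jax
import jax.numpy as jnp

regime_names = ["working", "retired", "dead"]

def to_prob_dict(probs):
    """Map a probability array to a dict keyed by regime name (sketch)."""
    return {name: float(p) for name, p in zip(regime_names, probs)}

# Deterministic transition: regime index → one-hot → dict
det = to_prob_dict(jax.nn.one_hot(0, num_classes=3))

# Stochastic transition: probability array → dict
stoch = to_prob_dict(jnp.array([0.98, 0.0, 0.02]))

# Both cases now share the same {regime_name: probability} shape
print(det)
print(stoch)
```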
How Transitions Compose in the Bellman Equation¶
The value function computation depends on the regime type:
Non-terminal with deterministic regime transition¶
The continuation value comes from a single next-period regime:

$$
V_t(s) = \max_{a} \left[ u(s, a) + \beta \, V_{t+1}^{r'}(s') \right]
$$

where $r'$ is the deterministically chosen next regime and $s'$ is the next-period state.
Non-terminal with stochastic regime transition¶
The continuation value is an expectation over possible next regimes:

$$
V_t(s) = \max_{a} \left[ u(s, a) + \beta \sum_{r'} P(r' \mid s, a) \, V_{t+1}^{r'}(s') \right]
$$

where $P(r' \mid s, a)$ is the probability of transitioning to regime $r'$.
Adding stochastic state transitions¶
When a state has a Markov transition (DiscreteMarkovGrid) or shock grid, an
additional layer of expectation is added inside the max:

$$
V_t(s) = \max_{a} \left[ u(s, a) + \beta \sum_{r'} P(r' \mid s, a) \sum_{s'} P(s' \mid s, a) \, V_{t+1}^{r'}(s') \right]
$$

The inner sum handles the stochastic state transition; the outer sum handles the stochastic regime transition. When either is deterministic, its corresponding sum collapses to a single term.
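A small numeric example of the two nested sums (all values are made up):

```python
import jax.numpy as jnp

# Hypothetical continuation values V_{t+1}^{r'}(s'), indexed [regime, state]
V_next = jnp.array([
    [3.0, 6.0],  # next regime 0: values at states [bad, good]
    [2.0, 4.0],  # next regime 1: values at states [bad, good]
])

p_regime = jnp.array([0.7, 0.3])  # outer sum: P(r' | s, a)
p_state = jnp.array([0.1, 0.9])   # inner sum: P(s' | s, a)

# Inner expectation over next states, then outer expectation over regimes
inner = V_next @ p_state                 # [5.7, 3.8]
continuation = jnp.dot(p_regime, inner)  # 0.7*5.7 + 0.3*3.8 = 5.13
print(float(continuation))
```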
Summary¶
| Component | Deterministic | Stochastic |
|---|---|---|
| State transition | Function returns the next value directly | Probability-weighted expectation over next values |
| Regime transition | One-hot array with a single 1 | Probability array over regimes |
| Internal format | Both converted to probability arrays | — |
The uniform probability format means the backward induction algorithm treats all transitions the same way — deterministic transitions are just the special case where one probability is 1 and the rest are 0.