
The Transition System

pylcm models involve two independent layers of transitions that compose during backward induction:

  1. State transitions — how a state variable arrived at its current value (attached to grids)

  2. Regime transitions — which regime the agent enters next period (attached to the regime itself)

The two layers have opposite orientations:

  • Regime transitions are forward-looking. The transition field on a Regime answers “where does the agent go next?” It lives on the source regime and points toward the future.

  • State transitions are backward-looking. The transition parameter on a grid answers “how did this state variable reach its current value?” It lives on the grid that receives the value. In multi-regime models, per-boundary mappings are placed on the target regime’s grid.

This notebook explains how each layer works, how per-boundary transitions resolve cross-regime state mismatches, and how everything composes in the Bellman equation.

import jax
import jax.numpy as jnp

from lcm import (
    DiscreteGrid,
    DiscreteMarkovGrid,
    LinSpacedGrid,
    Regime,
    RegimeTransition,
    categorical,
)
from lcm.typing import (
    ContinuousState,
    DiscreteState,
    FloatND,
    ScalarInt,
)

State Transition Mechanics

State transitions are attached directly to grid objects via the transition parameter. There are four cases:

| Grid configuration | Behavior |
| --- | --- |
| `transition=some_func` | Deterministic: $s' = f(s, a, \ldots)$ |
| `transition=None` | Fixed: $s' = s$ (identity auto-generated) |
| `DiscreteMarkovGrid(transition=func)` | Stochastic: probability-weighted expectation |
| Shock grids (`lcm.shocks.*`) | Intrinsic transitions with interpolated weights |

Deterministic state transitions

A state transition function defines how the current state value was determined from last period’s states, actions, and parameters. The function’s argument names are resolved from the regime’s namespace.

def next_wealth(
    wealth: ContinuousState,
    consumption: ContinuousState,
    interest_rate: float,
) -> ContinuousState:
    return (1 + interest_rate) * (wealth - consumption)


# Attached to the grid:
wealth_grid = LinSpacedGrid(start=0, stop=100, n_points=50, transition=next_wealth)

Fixed states and identity transitions

When transition=None, pylcm auto-generates an _IdentityTransition internally. This shows up in regime.get_all_functions() under the key "next_<state_name>".

@categorical
class EducationLevel:
    low: int
    high: int


@categorical
class RegimeId:
    working: int
    retired: int


def utility(wealth: ContinuousState) -> FloatND:
    return jnp.log(wealth + 1.0)


def next_regime() -> ScalarInt:
    return RegimeId.retired


regime = Regime(
    transition=RegimeTransition(next_regime),
    states={
        "education": DiscreteGrid(EducationLevel, transition=None),
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=10, transition=None),
    },
    functions={"utility": utility},
)

all_funcs = regime.get_all_functions()
print("Function keys:", list(all_funcs.keys()))
print("next_education type:", type(all_funcs["next_education"]).__name__)
print("next_wealth type:   ", type(all_funcs["next_wealth"]).__name__)
Function keys: ['utility', 'H', 'next_education', 'next_wealth', 'next_regime']
next_education type: _IdentityTransition
next_wealth type:    _IdentityTransition

Both fixed states produce _IdentityTransition objects. These are marked with _is_auto_identity = True so that validation can distinguish them from user-provided transitions.

Stochastic state transitions (DiscreteMarkovGrid)

For DiscreteMarkovGrid, the transition function returns a probability array over the categories. During the solve step, pylcm computes a probability-weighted expectation over next-period states:

$$\mathbb{E}[V(s')] = \sum_{s'} P(s' \mid s) \, V(s')$$

@categorical
class Health:
    bad: int
    good: int


def health_transition(health: DiscreteState) -> FloatND:
    return jnp.where(
        health == Health.good,
        jnp.array([0.1, 0.9]),  # good → 90% stay good
        jnp.array([0.6, 0.4]),  # bad  → 40% recover
    )


health_grid = DiscreteMarkovGrid(Health, transition=health_transition)

# Inspect the transition probabilities
for state_name, code in [("bad", Health.bad), ("good", Health.good)]:
    probs = health_transition(jnp.array(code))
    print(f"P(next | {state_name}) = {probs}")
P(next | bad) = [0.6 0.4]
P(next | good) = [0.1 0.9]
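To see what the probability-weighted expectation does numerically, here is a small sketch using the transition rows above together with hypothetical next-period values (the `v_next` numbers are made up for illustration):

```python
import jax.numpy as jnp

# Hypothetical next-period values V(s') for [bad, good] health.
v_next = jnp.array([2.0, 5.0])

# Transition rows from health_transition above.
p_from_bad = jnp.array([0.6, 0.4])   # P(next | bad)
p_from_good = jnp.array([0.1, 0.9])  # P(next | good)

# E[V(s')] = sum over s' of P(s' | s) * V(s')
ev_bad = jnp.dot(p_from_bad, v_next)
ev_good = jnp.dot(p_from_good, v_next)
print(ev_bad, ev_good)
```

Because recovery from bad health is likely, the expected continuation value from `bad` (3.2) is already much closer to the `good` value (4.7) than the raw `v_next` gap suggests.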

Shock grids

Shock grids (from lcm.shocks.iid and lcm.shocks.ar1) have intrinsic transitions computed from the distribution. For IID shocks, the transition probabilities are the same regardless of the current value. For AR(1) shocks, probabilities depend on the current state.

Shock grids do not accept a transition parameter — their transitions are built-in.

import lcm.shocks.iid

shock = lcm.shocks.iid.Normal(
    n_points=5, gauss_hermite=False, mu=0.0, sigma=1.0, n_std=2.5
)
print("Grid points:", shock.to_jax())
Grid points: [-2.5  -1.25  0.    1.25  2.5 ]
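The IID property can be illustrated with a rough weighting sketch: evaluate the normal density at each grid point and normalize. This is only an approximation for illustration; pylcm's actual interpolated weighting scheme may differ.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# The equally spaced grid from the Normal shock above (mu=0, sigma=1, n_std=2.5).
points = jnp.linspace(-2.5, 2.5, 5)

# Rough weight sketch: normal density at each point, normalized to sum to 1.
weights = norm.pdf(points, loc=0.0, scale=1.0)
weights = weights / weights.sum()

# IID: these weights are the same regardless of the current shock value.
print(weights)
```

An AR(1) shock would instead produce a different weight vector for each current grid point, since the conditional mean shifts with the current state.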

Per-Boundary State Transitions

When a discrete state has different categories across regimes, a simple callable transition is not enough — you need to map from one category set to another at the regime boundary.

The solution: a mapping transition keyed by (source_regime, target_regime) pairs, placed on the target regime’s grid.

Example: different health categories

Suppose working life has three health states (disabled, bad, good) but retirement only has two (bad, good). The transition from working to retired needs an explicit mapping.

@categorical
class HealthWorking:
    disabled: int
    bad: int
    good: int


@categorical
class HealthRetired:
    bad: int
    good: int


def map_working_to_retired(health: DiscreteState) -> DiscreteState:
    """Map 3-category working health to 2-category retired health."""
    return jnp.where(
        health == HealthWorking.good,
        HealthRetired.good,
        HealthRetired.bad,
    )


# Verify the mapping
for name, code in [("disabled", 0), ("bad", 1), ("good", 2)]:
    result = map_working_to_retired(jnp.array(code))
    print(f"working {name} ({code}) → retired code {int(result)}")
working disabled (0) → retired code 0
working bad (1) → retired code 0
working good (2) → retired code 1

The mapping is placed on the target regime’s grid:

health_retired_grid = DiscreteGrid(
    HealthRetired,
    transition={
        ("working", "retired"): map_working_to_retired,
    },
)

Resolution priority

When resolving which transition function to use at a regime boundary (source, target), pylcm checks (in order):

  1. Target grid mapping for (source, target)

  2. Source grid mapping for (source, target)

  3. Source grid’s callable transition

  4. Target grid’s callable transition

  5. Auto-generated identity (if categories match)

If the categories differ across regimes and no explicit mapping is found, ModelInitializationError is raised.
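The priority order can be sketched as a plain function. The grids here are stand-in dicts with a `"transition"` entry, and `resolve_boundary_transition` is a hypothetical helper written for illustration, not pylcm's internal API:

```python
def resolve_boundary_transition(boundary, source_grid, target_grid, categories_match):
    """Hypothetical sketch of the resolution order; not pylcm's internals."""
    # 1-2: per-boundary mappings, target grid first.
    for grid in (target_grid, source_grid):
        t = grid.get("transition")
        if isinstance(t, dict) and boundary in t:
            return t[boundary]
    # 3-4: plain callable transitions, source grid first.
    for grid in (source_grid, target_grid):
        t = grid.get("transition")
        if callable(t):
            return t
    # 5: auto-generated identity, only if the category sets match.
    if categories_match:
        return lambda s: s
    raise ValueError("categories differ and no explicit mapping was found")


# A mapping on the target grid wins over everything else.
mapping = {("working", "retired"): lambda h: h}
chosen = resolve_boundary_transition(
    ("working", "retired"),
    source_grid={"transition": None},
    target_grid={"transition": mapping},
    categories_match=False,
)
print(chosen is mapping[("working", "retired")])
```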

Parameterized per-boundary transitions

Per-boundary mapping functions can take parameters beyond the state variable itself. A common use case is a continuous state whose transition law differs across regime boundaries — for example, wealth that grows at a rate specific to the target regime.

When pylcm resolves a per-boundary transition from the target grid’s mapping (priority 1 above), any parameters in that function are looked up in the target regime’s parameter template. This means the user specifies the parameter value under the target regime in the params dict, and pylcm automatically routes it to the transition function at the boundary.

The rule is simple: whoever owns the mapping owns the parameters. Since per-boundary mappings live on the target regime’s grid, their parameters come from the target regime.

Example: regime-specific growth rate

Consider a two-regime model (phase 1 → phase 2) where wealth grows at a rate that is specific to phase 2. The transition function on phase 2’s wealth grid takes a growth_rate parameter:

def next_wealth_at_boundary(
    wealth: ContinuousState,
    growth_rate: float,
) -> ContinuousState:
    """Wealth transition at the phase1 → phase2 boundary.

    The growth_rate parameter is resolved from phase2's params template,
    because the mapping lives on phase2's grid.
    """
    return (1 + growth_rate) * wealth


# Phase 2's wealth grid declares the per-boundary mapping:
phase2_wealth_grid = LinSpacedGrid(
    start=0,
    stop=100,
    n_points=20,
    transition={
        ("phase1", "phase2"): next_wealth_at_boundary,
    },
)

Because the mapping {("phase1", "phase2"): next_wealth_at_boundary} lives on phase 2’s grid, the growth_rate parameter appears in phase 2’s parameter template. The user supplies it under "phase2" in the params dict:

params = {
    "phase1": {...},
    "phase2": {
        "next_wealth": {"growth_rate": 0.05},
        ...
    },
}

Internally, pylcm detects that the transition was resolved from the target grid’s mapping and renames the parameter to a cross-boundary qualified name (e.g., phase2__next_wealth__growth_rate). At solve and simulation time, the value is looked up from internal_params["phase2"] — not from "phase1" — even though the transition is evaluated as part of phase 1’s backward induction step.
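The lookup logic described above can be sketched with hypothetical data structures (the real internal representation in pylcm may differ):

```python
# Stand-in for pylcm's internal params; structure is illustrative only.
internal_params = {
    "phase1": {},
    "phase2": {"next_wealth": {"growth_rate": 0.05}},
}

# The cross-boundary qualified name, built from target regime, function,
# and parameter, mirroring the example in the text.
target_regime, func_name, param = "phase2", "next_wealth", "growth_rate"
qualified = f"{target_regime}__{func_name}__{param}"
print(qualified)

# The value is resolved from the *target* regime's params, even though the
# transition is evaluated during phase1's backward induction step.
value = internal_params[target_regime][func_name][param]
print(value)
```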

Regime Transition Mechanics

Regime transitions determine which regime the agent enters next period. Internally, both deterministic and stochastic transitions are converted to a uniform probability array format.

Deterministic transitions → one-hot encoding

A RegimeTransition wraps a function that returns an integer regime ID. Internally, _wrap_deterministic_regime_transition converts this to a one-hot probability array using jax.nn.one_hot:

@categorical
class RegimeIdExample:
    working: int
    retired: int
    dead: int


# Deterministic: retire at age 65
def next_regime_det(age: float, retirement_age: float) -> ScalarInt:
    return jnp.where(
        age >= retirement_age, RegimeIdExample.retired, RegimeIdExample.working
    )


# What pylcm does internally:
regime_idx = next_regime_det(age=50.0, retirement_age=65.0)
one_hot = jax.nn.one_hot(regime_idx, num_classes=3)
print(f"Regime index: {int(regime_idx)}")
print(f"One-hot:      {one_hot}  (= [P(working), P(retired), P(dead)])")
Regime index: 0
One-hot:      [1. 0. 0.]  (= [P(working), P(retired), P(dead)])

Stochastic transitions → probability array

A MarkovRegimeTransition wraps a function that directly returns a probability array. No conversion is needed — the array is used as-is.

def next_regime_stoch(survival_prob: float) -> FloatND:
    """Alive → [P(working), P(retired), P(dead)]."""
    return jnp.array([survival_prob, 0.0, 1 - survival_prob])


probs = next_regime_stoch(survival_prob=0.98)
print(f"Probabilities: {probs}  (= [P(working), P(retired), P(dead)])")
Probabilities: [0.98 0.   0.02]  (= [P(working), P(retired), P(dead)])

After wrapping, the probability array is further converted to a dictionary keyed by regime name (via _wrap_regime_transition_probs), giving a uniform internal representation regardless of whether the original transition was deterministic or stochastic.
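The uniform representation can be sketched like this; the zip-into-dict step mirrors the description above, though the exact internals of `_wrap_regime_transition_probs` may differ:

```python
import jax
import jax.numpy as jnp

regime_names = ("working", "retired", "dead")

# Stochastic case: the probability array is used as-is.
probs_stoch = jnp.array([0.98, 0.0, 0.02])
# Deterministic case: a one-hot array built from the regime index.
probs_det = jax.nn.one_hot(0, num_classes=3)

# Both end up as a dict keyed by regime name.
stoch_by_name = dict(zip(regime_names, probs_stoch))
det_by_name = dict(zip(regime_names, probs_det))
print(det_by_name)
print(stoch_by_name)
```

From this point on, backward induction no longer needs to know whether the original transition was deterministic or stochastic.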

How Transitions Compose in the Bellman Equation

The value function computation depends on the regime type:

Terminal regimes

No continuation value. The value function equals the utility directly:

$$V_T(s) = U(s)$$

Non-terminal with deterministic regime transition

The continuation value comes from a single next-period regime:

$$V_t(s) = \max_a \left\{ U(s, a) + \beta \, V_{t+1}^{r'}(s') \right\}$$

where $r'$ is the deterministically chosen next regime and $s' = g(s, a)$ is the next-period state.

Non-terminal with stochastic regime transition

The continuation value is an expectation over possible next regimes:

$$V_t(s) = \max_a \left\{ U(s, a) + \beta \sum_r p_r \, V_{t+1}^{r}(s') \right\}$$

where $p_r$ is the probability of transitioning to regime $r$.

Adding stochastic state transitions

When a state has a Markov transition (DiscreteMarkovGrid) or shock grid, an additional layer of expectation is added inside the max:

$$V_t(s) = \max_a \left\{ U(s, a) + \beta \sum_r p_r \sum_{s'} P(s' \mid s) \, V_{t+1}^{r}(s') \right\}$$

The inner sum handles the stochastic state transition; the outer sum handles the stochastic regime transition. When either is deterministic, its corresponding sum collapses to a single term.
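A numeric sketch of the composed expectation, with made-up probabilities and continuation values (two next regimes, two next states per regime):

```python
import jax.numpy as jnp

beta = 0.95
utility = 1.0
p_regime = jnp.array([0.9, 0.1])    # p_r: stochastic regime transition
p_state = jnp.array([0.6, 0.4])     # P(s' | s): stochastic state transition

# Hypothetical V_{t+1}^{r}(s'), one row per next regime r.
v_next = jnp.array([[3.0, 5.0],
                    [1.0, 2.0]])

# Inner sum over s' (per regime), then outer sum over r.
ev_state = v_next @ p_state                 # [3.8, 1.4]
continuation = jnp.dot(p_regime, ev_state)  # 3.56
value = utility + beta * continuation
print(float(value))
```

If the regime transition were deterministic, `p_regime` would be one-hot and the outer dot product would collapse to a single row of `ev_state`.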

Summary

| Component | Deterministic | Stochastic |
| --- | --- | --- |
| State transition | $s' = g(s, a)$ | $\sum_{s'} P(s' \mid s) \, V(s')$ |
| Regime transition | One-hot $\rightarrow$ single $V^{r'}$ | $\sum_r p_r \, V^r$ |
| Internal format | probability array (one-hot) | probability array (used as-is) |

The uniform probability format means the backward induction algorithm treats all transitions the same way — deterministic transitions are just the special case where one probability is 1 and the rest are 0.
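This special case can be checked directly: feeding a one-hot vector through the generic probability-weighted sum picks out exactly one continuation value (the `v_next` numbers are hypothetical).

```python
import jax
import jax.numpy as jnp

# Hypothetical next-period regime values V^r for three regimes.
v_next = jnp.array([3.8, 1.4, 0.0])

# A deterministic "go to regime 1" transition as a one-hot probability vector.
probs = jax.nn.one_hot(1, num_classes=3)

# The generic weighted sum collapses to v_next[1].
expected = jnp.dot(probs, v_next)
print(float(expected))
```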