Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

A Tiny Example

A three-period consumption-savings model with two regimes:

  • Working life (ages 25 and 45): The agent chooses whether to work and how much to consume. A simple tax-and-transfer system guarantees a consumption floor. Savings earn interest.

  • Retirement (age 65): Terminal regime. The agent consumes out of remaining wealth.

Model

An agent lives for three periods (ages 25, 45, and 65). In the first two periods (working life), the agent chooses whether to work, $d_t \in \{0, 1\}$, and how much to consume, $c_t$. In the final period (retirement), the agent consumes out of remaining wealth.

Working life (ages 25 and 45):

$$
V_t(w_t) = \max_{d_t,\, c_t} \left\{ \frac{c_t^{1-\sigma}}{1-\sigma} - \phi \, d_t + \beta \, V_{t+1}(w_{t+1}) \right\}
$$

subject to

\begin{align}
e_t &= d_t \cdot \bar{w} \\[4pt]
\tau(e_t, w_t) &= \begin{cases}
  \theta\,(e_t - \underline{c}) & \text{if } e_t \geq \underline{c} \\
  \min(0,\; w_t + e_t - \underline{c}) & \text{otherwise}
\end{cases} \\[4pt]
a_t &= w_t + e_t - \tau(e_t, w_t) - c_t \\[4pt]
w_{t+1} &= (1 + r)\, a_t \\[4pt]
a_t &\geq 0
\end{align}

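where $w_t$ is wealth, $e_t$ earnings, $\bar{w}$ the wage, $\underline{c}$ a consumption floor guaranteed by transfers, $\theta$ the tax rate, and $a_t$ end-of-period wealth. The remaining parameters are:

- $\sigma$, the coefficient of relative risk aversion;
- $\phi$, the disutility of work;
- $\beta$, the discount factor;
- $r$, the interest rate.

The transfer only kicks in when the agent's resources ($w_t + e_t$) fall below the consumption floor; it then tops resources up to exactly $\underline{c}$.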

Retirement (age 65, terminal):

$$
V_2(w_2) = \max_{c_2 \leq w_2} \frac{c_2^{1-\sigma}}{1-\sigma}
$$

Because utility is increasing in consumption, the retiree consumes all remaining wealth, $c_2 = w_2$.
from pprint import pprint

import jax.numpy as jnp
import pandas as pd
import plotly.express as px

from lcm import (
    AgeGrid,
    DiscreteGrid,
    LinSpacedGrid,
    LogSpacedGrid,
    Model,
    Regime,
    RegimeTransition,
    categorical,
)
from lcm.typing import (
    BoolND,
    ContinuousAction,
    ContinuousState,
    DiscreteAction,
    FloatND,
    ScalarInt,
)

Categorical Variables

@categorical
class Work:
    """Discrete work choice; used as the ``work`` action of the working-life regime."""

    no: int  # not working: earnings() returns 0.0 and utility() adds no work cost
    yes: int  # working: earnings() returns the wage and utility() subtracts disutility_of_work


@categorical
class RegimeId:
    """Regime identifiers returned by ``next_regime`` and passed to ``Model``.

    The field names match the keys of the ``regimes`` dict given to ``Model``
    (presumably required by lcm -- confirm against the lcm documentation).
    """

    working_life: int  # regime active while age < retirement_age
    retirement: int  # terminal regime, active from retirement_age on

Model Functions

# Utility


def utility(
    consumption: ContinuousAction,
    work: DiscreteAction,
    disutility_of_work: float,
    risk_aversion: float,
) -> FloatND:
    """Per-period utility during working life.

    CRRA utility of consumption minus a fixed utility cost whenever the agent
    works: ``c**(1 - sigma) / (1 - sigma) - phi * [work == Work.yes]``.

    Args:
        consumption: Consumption chosen this period.
        work: Work decision; the cost applies when it equals ``Work.yes``.
        disutility_of_work: Utility cost ``phi`` incurred when working.
        risk_aversion: CRRA coefficient ``sigma``. It must differ from 1
            because ``1 - risk_aversion`` is used as a divisor.

    Returns:
        Period utility.
    """
    exponent = 1 - risk_aversion
    consumption_utility = consumption**exponent / exponent
    # The boolean comparison acts as a 0/1 indicator for working.
    work_penalty = disutility_of_work * (work == Work.yes)
    return consumption_utility - work_penalty


def utility_retirement(wealth: ContinuousState, risk_aversion: float) -> FloatND:
    """Terminal-period utility in retirement.

    The retiree consumes all remaining wealth, so this is CRRA utility
    evaluated at ``wealth``.

    Args:
        wealth: Wealth at the start of retirement, all of which is consumed.
        risk_aversion: CRRA coefficient ``sigma``. It must differ from 1
            because ``1 - risk_aversion`` is used as a divisor.

    Returns:
        Utility from consuming ``wealth``.
    """
    exponent = 1 - risk_aversion
    return wealth**exponent / exponent


# Auxiliary functions


def earnings(work: DiscreteAction, wage: float) -> FloatND:
    """Labor earnings: ``wage`` if the agent works, ``0.0`` otherwise.

    Args:
        work: Work decision, compared against ``Work.yes``.
        wage: Earnings from working one period.

    Returns:
        Earnings for the period.
    """
    is_working = work == Work.yes
    return jnp.where(is_working, wage, 0.0)


def taxes_transfers(
    earnings: FloatND,
    wealth: ContinuousState,
    consumption_floor: float,
    tax_rate: float,
) -> FloatND:
    """Net taxes paid; negative values are transfers received.

    Earnings at or above the consumption floor are taxed at ``tax_rate`` on
    the amount exceeding the floor. Otherwise, a transfer tops total
    resources (``wealth + earnings``) up to the floor. No transfer is paid
    when resources already reach it.

    Args:
        earnings: Labor earnings this period.
        wealth: Wealth at the start of the period.
        consumption_floor: Guaranteed level of resources.
        tax_rate: Proportional tax on earnings above the floor.

    Returns:
        Taxes (positive) or transfers (negative).
    """
    resources = wealth + earnings
    income_tax = tax_rate * (earnings - consumption_floor)
    # Non-positive by construction: a transfer equal to the resource shortfall.
    transfer = jnp.minimum(0.0, resources - consumption_floor)
    earns_above_floor = earnings >= consumption_floor
    return jnp.where(earns_above_floor, income_tax, transfer)


def end_of_period_wealth(
    wealth: ContinuousState,
    earnings: FloatND,
    taxes_transfers: FloatND,
    consumption: ContinuousAction,
) -> FloatND:
    """Wealth left after earnings, net taxes/transfers and consumption.

    Args:
        wealth: Wealth at the start of the period.
        earnings: Labor earnings this period.
        taxes_transfers: Net taxes (negative for transfers).
        consumption: Consumption chosen this period.

    Returns:
        End-of-period wealth (savings) before interest.
    """
    disposable = wealth + earnings - taxes_transfers
    return disposable - consumption


# State transition


def next_wealth(end_of_period_wealth: FloatND, interest_rate: float) -> ContinuousState:
    """Next period's wealth: savings grown by the gross interest rate.

    Args:
        end_of_period_wealth: Savings carried into the next period.
        interest_rate: Net interest rate ``r``.

    Returns:
        ``(1 + r) * end_of_period_wealth``.
    """
    gross_return = 1 + interest_rate
    return gross_return * end_of_period_wealth


# Constraints


def borrowing_constraint_working(end_of_period_wealth: FloatND) -> BoolND:
    """Feasibility mask: the agent may not end a period with negative wealth.

    Args:
        end_of_period_wealth: Savings implied by the chosen actions.

    Returns:
        True where the choice is feasible (savings are non-negative).
    """
    return 0 <= end_of_period_wealth


# Regime transition


def next_regime(age: float, last_working_age: float) -> ScalarInt:
    """Pick next period's regime from the current age.

    Args:
        age: Current age.
        last_working_age: Final age spent in working life.

    Returns:
        ``RegimeId.retirement`` once ``age`` has reached ``last_working_age``,
        otherwise ``RegimeId.working_life``.
    """
    is_retiring = age >= last_working_age
    return jnp.where(is_retiring, RegimeId.retirement, RegimeId.working_life)

Regimes and Model

# Ages 25, 45 and 65, twenty years apart; the last one is retirement.
age_grid = AgeGrid(start=25, stop=65, step="20Y")
retirement_age = age_grid.precise_values[-1]

# Both regimes use the same wealth grid; only its transition differs.
_wealth_grid_spec = {"start": 0, "stop": 50, "n_points": 25}

working_life = Regime(
    active=lambda age: age < retirement_age,
    transition=RegimeTransition(next_regime),
    states={"wealth": LinSpacedGrid(**_wealth_grid_spec, transition=next_wealth)},
    actions={
        "work": DiscreteGrid(Work),
        "consumption": LogSpacedGrid(start=4, stop=50, n_points=100),
    },
    functions={
        "utility": utility,
        "earnings": earnings,
        "taxes_transfers": taxes_transfers,
        "end_of_period_wealth": end_of_period_wealth,
    },
    constraints={"borrowing_constraint_working": borrowing_constraint_working},
)

# Terminal regime: no regime transition and no wealth transition.
retirement = Regime(
    active=lambda age: age >= retirement_age,
    transition=None,
    states={"wealth": LinSpacedGrid(**_wealth_grid_spec, transition=None)},
    functions={"utility": utility_retirement},
)

model = Model(
    description="A tiny three-period consumption-savings model.",
    ages=age_grid,
    regime_id_class=RegimeId,
    regimes={"working_life": working_life, "retirement": retirement},
)

Parameters

Use model.params_template to see what parameters the model expects, organized by regime and function.

# Show the nested template: regime -> function -> parameter name -> type.
pprint({**model.params_template})
{'retirement': mappingproxy({'next_wealth': mappingproxy({}),
                             'utility': mappingproxy({'risk_aversion': <class 'float'>})}),
 'working_life': mappingproxy({'H': mappingproxy({'discount_factor': <class 'float'>}),
                               'borrowing_constraint_working': mappingproxy({}),
                               'earnings': mappingproxy({'wage': <class 'float'>}),
                               'end_of_period_wealth': mappingproxy({}),
                               'next_regime': mappingproxy({'last_working_age': <class 'float'>}),
                               'next_wealth': mappingproxy({'interest_rate': <class 'float'>}),
                               'taxes_transfers': mappingproxy({'consumption_floor': <class 'float'>,
                                                                'tax_rate': <class 'float'>}),
                               'utility': mappingproxy({'disutility_of_work': <class 'float'>,
                                                        'risk_aversion': <class 'float'>})})}

Parameters can be given at the model level, as done below for `risk_aversion` (used by the utility functions of both regimes), `discount_factor`, and `interest_rate`. They can also be nested under a regime name, as done for the remaining `working_life` parameters.

# Parameters given under the working_life regime, keyed by function name.
working_life_params = {
    "utility": {"disutility_of_work": 1.0},
    "earnings": {"wage": 20.0},
    "taxes_transfers": {"consumption_floor": 2.0, "tax_rate": 0.2},
    # Second-to-last grid age: next_regime moves agents into retirement
    # from this age on.
    "next_regime": {"last_working_age": age_grid.precise_values[-2]},
}

params = {
    "discount_factor": 0.95,
    "risk_aversion": 1.5,
    "interest_rate": 0.03,
    "working_life": working_life_params,
}

Solve and Simulate

n_agents = 100

# Every agent starts at the first grid age, with wealth spread evenly over [1, 40].
starting_states = {
    "age": jnp.full(n_agents, age_grid.values[0]),
    "wealth": jnp.linspace(1, 40, n_agents),
}

result = model.solve_and_simulate(
    params=params,
    initial_regimes=["working_life" for _ in range(n_agents)],
    initial_states=starting_states,
)
INFO:lcm:Starting solution
INFO:lcm:Age: 65.0
INFO:lcm:Age: 45.0
INFO:lcm:Age: 25.0
INFO:lcm:Starting simulation
INFO:lcm:Age: 25.0
INFO:lcm:Age: 45.0
INFO:lcm:Age: 65.0
df = result.to_dataframe(additional_targets="all")
df["age"] = df["age"].astype(int)

# Retirees consume all remaining wealth, so fill their consumption column from
# wealth to make it comparable with working-life rows.
is_retired = df["age"] == retirement_age
df.loc[is_retired, "consumption"] = df.loc[is_retired, "wealth"]

columns = [
    "regime",
    "work",
    "consumption",
    "wealth",
    "earnings",
    "taxes_transfers",
    "end_of_period_wealth",
    "value",
]
preview = df.set_index(["subject_id", "age"])[columns].head(20)
preview.style.format(precision=1, na_rep="")
Loading...
# Classify agents by their work pattern across the two working-life periods.
first_working_age = age_grid.precise_values[0]
last_working_age = age_grid.precise_values[-2]

df_working = df[df["regime"] == "working_life"]
work_by_age = df_working.pivot_table(
    index="subject_id",
    columns="age",
    values="work",
    aggfunc="first",
)
# One string per agent, e.g. "yes, no" = worked at the first age, not the second.
work_pattern = (
    work_by_age[first_working_age].astype(str)
    + ", "
    + work_by_age[last_working_age].astype(str)
)

label_map = {
    "yes, no": "low",  # work early, not later
    "no, yes": "medium",  # coast early, work later
    "no, no": "high",  # never work
}
# Every observed pattern must have a label: an unmapped one (e.g. "yes, yes",
# or "nan, ..." for a missing period) would silently become NaN in `groups`
# and drop those agents from the tables and plots below. Raise explicitly
# rather than `assert`, which is stripped under `python -O`.
unexpected_patterns = sorted(set(work_pattern) - set(label_map))
if unexpected_patterns:
    msg = (
        "Plotting assumes every agent follows one of the work patterns "
        f"{list(label_map)}; found {unexpected_patterns}."
    )
    raise ValueError(msg)
groups = work_pattern.map(label_map).rename("initial_wealth")

# Combined descriptives and work decisions table
initial_wealth = df[df["age"] == first_working_age].set_index("subject_id")["wealth"]
group_desc = initial_wealth.groupby(groups).agg(["min", "max"]).round(1)

df_groups = df.copy()
df_groups["initial_wealth"] = df_groups["subject_id"].map(groups)
# Group means by age; reused by the consumption and wealth plots.
df_mean = df_groups.groupby(["initial_wealth", "age"], as_index=False).mean(
    numeric_only=True,
)
work_table = df_mean[df_mean["age"] < retirement_age].pivot_table(
    index="initial_wealth",
    columns="age",
    values="earnings",
)
# Positive mean earnings means the group works at that age.
work_table = (work_table > 0).astype(int)
work_table.columns = [f"works {c}" for c in work_table.columns]

summary = pd.concat([group_desc, work_table], axis=1)
summary.index.name = "initial_wealth"
summary.loc[["low", "medium", "high"]].style.format(precision=1, na_rep="")
Loading...
# Mean consumption over the life cycle, one line per initial-wealth group.
fig = px.line(
    data_frame=df_mean,
    x="age",
    y="consumption",
    color="initial_wealth",
    template="plotly_dark",
    title="Consumption by Age",
)
fig.show()
Loading...
# Mean wealth over the life cycle, one line per initial-wealth group.
fig = px.line(
    data_frame=df_mean,
    x="age",
    y="wealth",
    color="initial_wealth",
    template="plotly_dark",
    title="Wealth by Age",
)
fig.show()
Loading...