A three-period consumption-savings model with two regimes:
Working life (ages 25 and 45): The agent chooses whether to work and how much to consume. A simple tax-and-transfer system guarantees a consumption floor. Savings earn interest.
Retirement (age 65): Terminal regime. The agent consumes out of remaining wealth.
## Model
An agent lives for three periods (ages 25, 45, and 65). In the first two periods (working life), the agent chooses whether to work ($d_t \in \{0, 1\}$) and how much to consume ($c_t$). In the final period (retirement), the agent consumes out of remaining wealth.
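Working life (ages 25 and 45):

$$
V_t(w_t) = \max_{d_t \in \{0, 1\},\; c_t} \left\{ \frac{c_t^{1-\gamma}}{1-\gamma} - \phi\, d_t + \beta\, V_{t+1}(w_{t+1}) \right\}
$$

subject to

$$
\begin{aligned}
e_t &= d_t\, \bar{w}, \\
\tau_t &= \begin{cases}
  \theta\,(e_t - \underline{c}) & \text{if } e_t \geq \underline{c}, \\
  \min\{0,\; w_t + e_t - \underline{c}\} & \text{otherwise},
\end{cases} \\
a_t &= w_t + e_t - \tau_t - c_t \geq 0, \\
w_{t+1} &= (1 + r)\, a_t,
\end{aligned}
$$

where $w_t$ is wealth, $e_t$ earnings, $\bar{w}$ the wage, $\underline{c}$ a consumption floor guaranteed by transfers, $\theta$ the tax rate, $\tau_t$ net taxes (negative values are transfers), and $a_t$ end-of-period wealth. The choices are work $d_t$ and consumption $c_t$. The remaining symbols are:

- $\gamma$: the coefficient of relative risk aversion;
- $\phi$: the disutility of work;
- $\beta$: the discount factor;
- $r$: the interest rate.

The subscript $t+1$ denotes the next period (20 years later). The transfer only kicks in when the agent's resources ($w_t + e_t$) fall below the consumption floor.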
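Retirement (age 65, terminal): the agent consumes all remaining wealth, so

$$
V_{65}(w_{65}) = \frac{w_{65}^{1-\gamma}}{1-\gamma},
$$

where $\gamma$ is the coefficient of relative risk aversion.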
from pprint import pprint
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
from lcm import (
AgeGrid,
DiscreteGrid,
LinSpacedGrid,
LogSpacedGrid,
Model,
Regime,
RegimeTransition,
categorical,
)
from lcm.typing import (
BoolND,
ContinuousAction,
ContinuousState,
DiscreteAction,
FloatND,
ScalarInt,
)Categorical Variables¶
@categorical
class Work:
    """Binary labor-supply choice; the category set of the ``work`` action."""

    # Compared against the ``work`` action in ``utility`` and ``earnings``.
    # NOTE(review): category codes presumably follow declaration order —
    # do not reorder these fields without confirming with ``lcm.categorical``.
    no: int
    yes: int
@categorical
class RegimeId:
    """Regime identifiers returned by ``next_regime`` and passed to ``Model``."""

    # Field names mirror the regime keys given to ``Model`` ("working_life",
    # "retirement"); presumably lcm relies on this correspondence — TODO confirm.
    working_life: int
    retirement: int


# Model Functions
# Utility
def utility(
    consumption: ContinuousAction,
    work: DiscreteAction,
    disutility_of_work: float,
    risk_aversion: float,
) -> FloatND:
    """Per-period utility during working life.

    CRRA utility of consumption, minus a fixed utility cost in periods in which
    the agent works.

    Args:
        consumption: Consumption chosen this period.
        work: Work choice; the cost applies when it equals ``Work.yes``.
        disutility_of_work: Utility cost subtracted when working.
        risk_aversion: CRRA coefficient. Must differ from 1, because the
            expression divides by ``1 - risk_aversion``.

    Returns:
        ``consumption**(1 - risk_aversion) / (1 - risk_aversion)`` minus
        ``disutility_of_work`` if working.
    """
    curvature = 1 - risk_aversion
    consumption_utility = consumption**curvature / curvature
    # The boolean comparison acts as a 0/1 multiplier on the work cost.
    work_cost = disutility_of_work * (work == Work.yes)
    return consumption_utility - work_cost
def utility_retirement(wealth: ContinuousState, risk_aversion: float) -> FloatND:
    """Terminal-period utility: CRRA utility of consuming all remaining wealth.

    Args:
        wealth: Wealth at retirement, which is consumed in full.
        risk_aversion: CRRA coefficient. Must differ from 1, because the
            expression divides by ``1 - risk_aversion``.

    Returns:
        ``wealth**(1 - risk_aversion) / (1 - risk_aversion)``.
    """
    # NOTE(review): for risk_aversion > 1 this diverges as wealth -> 0, and the
    # retirement wealth grid starts at 0 — confirm the solver copes with the
    # resulting non-finite value.
    curvature = 1 - risk_aversion
    return wealth**curvature / curvature
# Auxiliary functions
def earnings(work: DiscreteAction, wage: float) -> FloatND:
    """Labor earnings: the full ``wage`` when working, ``0.0`` otherwise."""
    is_working = work == Work.yes
    return jnp.where(is_working, wage, 0.0)
def taxes_transfers(
    earnings: FloatND,
    wealth: ContinuousState,
    consumption_floor: float,
    tax_rate: float,
) -> FloatND:
    """Net taxes paid (positive) or transfers received (negative).

    Earnings at or above the consumption floor are taxed at ``tax_rate`` on
    the part exceeding the floor. Below the floor no tax is levied. If, in
    addition, wealth plus earnings falls short of the floor, the result is the
    (negative) shortfall, i.e. a transfer that lifts ``wealth + earnings -
    taxes_transfers`` up to exactly the floor.

    Args:
        earnings: Labor earnings this period.
        wealth: Beginning-of-period wealth.
        consumption_floor: Resource level guaranteed by transfers.
        tax_rate: Marginal tax rate on earnings above the floor.

    Returns:
        Net taxes; negative values are transfers.
    """
    tax = tax_rate * (earnings - consumption_floor)
    # Capped at zero from above: agents whose resources already reach the
    # floor receive no transfer.
    transfer = jnp.minimum(0.0, wealth + earnings - consumption_floor)
    earns_above_floor = earnings >= consumption_floor
    return jnp.where(earns_above_floor, tax, transfer)
def end_of_period_wealth(
    wealth: ContinuousState,
    earnings: FloatND,
    taxes_transfers: FloatND,
    consumption: ContinuousAction,
) -> FloatND:
    """Wealth left after earning, paying net taxes, and consuming.

    Args:
        wealth: Beginning-of-period wealth.
        earnings: Labor earnings this period.
        taxes_transfers: Net taxes (negative values are transfers received).
        consumption: Consumption chosen this period.

    Returns:
        ``wealth + earnings - taxes_transfers - consumption``.
    """
    resources = wealth + earnings
    after_taxes = resources - taxes_transfers
    return after_taxes - consumption
# State transition
def next_wealth(
    end_of_period_wealth: FloatND, interest_rate: float
) -> ContinuousState:
    """Next period's wealth: end-of-period wealth grown by ``1 + interest_rate``.

    NOTE(review): the rate is applied once per model period (20 years with the
    age grid used in this model) — confirm it is meant as a per-period rather
    than an annual rate.
    """
    gross_return = 1 + interest_rate
    return end_of_period_wealth * gross_return
# Constraints
def borrowing_constraint_working(end_of_period_wealth: FloatND) -> BoolND:
    """Feasibility: end-of-period wealth may not be negative (no borrowing)."""
    return 0 <= end_of_period_wealth
# Regime transition
def next_regime(age: float, last_working_age: float) -> ScalarInt:
    """Regime for the next period.

    Returns ``RegimeId.retirement`` once ``age`` has reached
    ``last_working_age``; otherwise the agent stays in ``RegimeId.working_life``.
    """
    retiring = age >= last_working_age
    return jnp.where(retiring, RegimeId.retirement, RegimeId.working_life)


# Regimes and Model
age_grid = AgeGrid(start=25, stop=65, step="20Y")
# The last grid age (65) is the single, terminal retirement period.
retirement_age = age_grid.precise_values[-1]

# Working life: choose whether to work and how much to consume. Wealth evolves
# via ``next_wealth``; the regime switch is governed by ``next_regime``.
working_life = Regime(
    active=lambda age: age < retirement_age,
    transition=RegimeTransition(next_regime),
    states=dict(
        wealth=LinSpacedGrid(
            start=0, stop=50, n_points=25, transition=next_wealth
        ),
    ),
    actions=dict(
        work=DiscreteGrid(Work),
        # NOTE(review): this grid starts at 4, above the consumption floor of
        # 2.0 set in ``params``. Transfers lift a non-working agent's resources
        # only to the floor, so not working is infeasible with wealth below 4 —
        # confirm this is intended.
        consumption=LogSpacedGrid(start=4, stop=50, n_points=100),
    ),
    functions=dict(
        utility=utility,
        earnings=earnings,
        taxes_transfers=taxes_transfers,
        end_of_period_wealth=end_of_period_wealth,
    ),
    constraints=dict(
        borrowing_constraint_working=borrowing_constraint_working,
    ),
)

# Retirement: terminal regime (no transition, no actions); utility depends on
# wealth alone.
retirement = Regime(
    active=lambda age: age >= retirement_age,
    transition=None,
    states=dict(
        wealth=LinSpacedGrid(start=0, stop=50, n_points=25, transition=None),
    ),
    functions=dict(utility=utility_retirement),
)

model = Model(
    description="A tiny three-period consumption-savings model.",
    ages=age_grid,
    regime_id_class=RegimeId,
    regimes=dict(
        working_life=working_life,
        retirement=retirement,
    ),
)


# Parameters
Use model.params_template to see what parameters the model expects, organized
by regime and function.
# Show the parameters the model expects, nested as
# regime -> function -> {parameter name: expected type} (output below).
pprint(dict(model.params_template)){'retirement': mappingproxy({'next_wealth': mappingproxy({}),
'utility': mappingproxy({'risk_aversion': <class 'float'>})}),
'working_life': mappingproxy({'H': mappingproxy({'discount_factor': <class 'float'>}),
'borrowing_constraint_working': mappingproxy({}),
'earnings': mappingproxy({'wage': <class 'float'>}),
'end_of_period_wealth': mappingproxy({}),
'next_regime': mappingproxy({'last_working_age': <class 'float'>}),
'next_wealth': mappingproxy({'interest_rate': <class 'float'>}),
'taxes_transfers': mappingproxy({'consumption_floor': <class 'float'>,
'tax_rate': <class 'float'>}),
'utility': mappingproxy({'disutility_of_work': <class 'float'>,
'risk_aversion': <class 'float'>})})}
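Parameters can be specified at the model level (as top-level keys). This is convenient for parameters used in more than one place, such as `risk_aversion`, which both regimes' utility functions use. It is also convenient for parameters that describe the whole economy (`discount_factor`, `interest_rate`). Parameters unique to one regime go under the regime name, keyed by the function that uses them.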
# Top-level entries are model-level parameters. Regime-specific ones are nested
# as regime -> function -> parameter, mirroring ``model.params_template``.
params = dict(
    discount_factor=0.95,
    risk_aversion=1.5,
    interest_rate=0.03,
    working_life=dict(
        utility=dict(disutility_of_work=1.0),
        earnings=dict(wage=20.0),
        taxes_transfers=dict(consumption_floor=2.0, tax_rate=0.2),
        # Second-to-last grid age: agents of this age retire next period.
        next_regime=dict(last_working_age=age_grid.precise_values[-2]),
    ),
)


# Solve and Simulate
# Simulate 100 agents who all start in working life at the first grid age,
# with initial wealth spread evenly between 1 and 40.
n_agents = 100
# Solves the model (backward from the last age, per the log below) and then
# simulates the agents forward.
result = model.solve_and_simulate(
    params=params,
    initial_regimes=["working_life"] * n_agents,
    initial_states={
        "age": jnp.full(n_agents, age_grid.values[0]),
        "wealth": jnp.linspace(1, 40, n_agents),
    },
)INFO:lcm:Starting solution
INFO:lcm:Age: 65.0
INFO:lcm:Age: 45.0
INFO:lcm:Age: 25.0
INFO:lcm:Starting simulation
INFO:lcm:Age: 25.0
INFO:lcm:Age: 45.0
INFO:lcm:Age: 65.0
df = result.to_dataframe(additional_targets="all")
df["age"] = df["age"].astype(int)

# Retirement has no consumption action: the agent consumes all remaining
# wealth, so report wealth as consumption at the retirement age.
df["consumption"] = df["consumption"].mask(
    df["age"] == retirement_age, df["wealth"]
)

columns = (
    "regime work consumption wealth earnings taxes_transfers "
    "end_of_period_wealth value"
).split()

# First 20 rows of the simulated panel, indexed by agent and age.
df.set_index(["subject_id", "age"]).loc[:, columns].head(20).style.format(
    precision=1,
    na_rep="",
)
# Classify agents by work pattern across the two working-life periods
first_working_age = age_grid.precise_values[0]
last_working_age = age_grid.precise_values[-2]

# One row per agent and one column per working-life age, holding the work choice.
df_working = df[df["regime"] == "working_life"]
work_by_age = df_working.pivot_table(
    index="subject_id",
    columns="age",
    values="work",
    aggfunc="first",
)
work_pattern = (
    work_by_age[first_working_age].astype(str)
    + ", "
    + work_by_age[last_working_age].astype(str)
)

label_map = {
    "yes, no": "low",  # work early, not later
    "no, yes": "medium",  # coast early, work later
    "no, no": "high",  # never work
}
group_order = ["low", "medium", "high"]

# Every observed pattern must have a group label. This uses an explicit check
# because ``python -O`` strips ``assert``. It also catches patterns other than
# "yes, yes" (e.g. a missing choice, which ``astype(str)`` renders as "nan").
# ``map`` would otherwise turn those into NaN groups that silently vanish from
# the tables and plots below.
unlabeled_patterns = sorted(set(work_pattern) - set(label_map))
if unlabeled_patterns:
    raise ValueError(
        "Plotting assumes that no agent works in both periods of working life "
        f"and that every work pattern has a group label; got {unlabeled_patterns}."
    )
groups = work_pattern.map(label_map).rename("initial_wealth")

# Combined descriptives and work decisions table
initial_wealth = df[df["age"] == first_working_age].set_index("subject_id")["wealth"]
group_desc = initial_wealth.groupby(groups).agg(["min", "max"]).round(1)

df_groups = df.copy()
df_groups["initial_wealth"] = df_groups["subject_id"].map(groups)
df_mean = df_groups.groupby(["initial_wealth", "age"], as_index=False).mean(
    numeric_only=True,
)

work_table = df_mean[df_mean["age"] < retirement_age].pivot_table(
    index="initial_wealth",
    columns="age",
    values="earnings",
)
# Earnings are zero when not working, and all agents in a group share one
# work pattern. A positive group mean therefore means "the group works".
work_table = (work_table > 0).astype(int)
work_table.columns = [f"works {c}" for c in work_table.columns]

summary = pd.concat([group_desc, work_table], axis=1)
summary.index.name = "initial_wealth"
# ``reindex`` (unlike ``.loc``) tolerates a group with no agents: its row is
# shown blank instead of raising a KeyError.
summary.reindex(group_order).style.format(precision=1, na_rep="")

# Fix the legend order to ``group_order`` so it matches the summary table.
# Without this, it follows the sorted order in which ``groupby`` emits groups.
fig = px.line(
    df_mean,
    x="age",
    y="consumption",
    color="initial_wealth",
    category_orders={"initial_wealth": group_order},
    title="Consumption by Age",
    template="plotly_dark",
)
fig.show()

fig = px.line(
    df_mean,
    x="age",
    y="wealth",
    color="initial_wealth",
    category_orders={"initial_wealth": group_order},
    title="Wealth by Age",
    template="plotly_dark",
)
fig.show()