-
Notifications
You must be signed in to change notification settings - Fork 3
ENH: Add taste shocks #247
Copy link
Copy link
Open
Description
Motivation
pylcm currently has no support for taste shocks (additive utility shocks with extreme value distribution). The infrastructure for this feature was partially sketched in two modules that have been removed as dead code in the cleanup (#243):
- `src/lcm/max_Q_over_c.py` — maximization of Q over continuous actions
- `src/lcm/max_Qc_over_d.py` — maximization of Qc over discrete actions (includes a stub for extreme value shocks via `logsumexp`)
- `tests/test_max_Qc_over_d.py` — tests for the discrete maximization module
What needs to happen
- Design how taste shocks integrate with the current Regime-based API
- Implement the `logsumexp` aggregation over discrete actions (the `_max_Qc_over_d_extreme_value_shocks` function was a starting point)
- Add a `scale` parameter for the shock distribution
- Create a test model with taste shocks and verify against an analytical or known-good solution
Archived code
The deleted modules are preserved below for reference. They will need significant adaptation to the current codebase.
src/lcm/max_Q_over_c.py
import functools
from collections.abc import Callable
from types import MappingProxyType
import jax.numpy as jnp
from jax import Array
from lcm.argmax import argmax_and_max
from lcm.dispatchers import productmap
from lcm.typing import (
ArgmaxQOverCFunction,
BoolND,
FloatND,
IntND,
MaxQOverCFunction,
Period,
RegimeName,
)
def get_max_Q_over_c(
    *,
    Q_and_F: Callable[..., tuple[FloatND, BoolND]],
    continuous_action_names: tuple[str, ...],
    state_and_discrete_action_names: tuple[str, ...],
) -> MaxQOverCFunction:
    r"""Get the function returning the maximum of Q over continuous actions.

    The state-action value function $Q$ is defined as:

    ```{math}
    Q(x, a) = H(U(x, a), \mathbb{E}[V(x', a') | x, a]),
    ```

    with $H(U, v) = U + \beta \cdot v$ as the leading case (which is the only one
    that is pre-implemented in LCM).

    Fixing a state and discrete action, maximizing over the feasible continuous
    actions, we get the $Q^c$ function:

    ```{math}
    Q^{c}(x, a^d) = \max_{a^c} Q(x, a^d, a^c).
    ```

    This last step is handled by the function returned here.

    Args:
        Q_and_F: A function that takes a state-action combination and returns the
            action value of that combination and whether the state-action
            combination is feasible.
        continuous_action_names: Tuple of action variable names that are
            continuous.
        state_and_discrete_action_names: Tuple of state and discrete action
            variable names.

    Returns:
        Qc, i.e., the function that calculates the maximum of the Q-function over
        the feasible continuous actions.

    """
    # Vectorize Q_and_F over the product grid of continuous actions so the max
    # below reduces over those axes. Skipped when there are no continuous
    # actions (the arrays then carry no continuous-action axes).
    if continuous_action_names:
        Q_and_F = productmap(
            func=Q_and_F,
            variables=continuous_action_names,
        )
    @functools.wraps(Q_and_F)
    def max_Q_over_c(
        next_V_arr: MappingProxyType[RegimeName, FloatND],
        period: Period,
        **states_actions_params: Array,
    ) -> FloatND:
        Q_arr, F_arr = Q_and_F(
            next_V_arr=next_V_arr,
            period=period,
            **states_actions_params,
        )
        # Mask out infeasible combinations; `initial=-inf` keeps the result
        # well-defined (-inf) even when no combination is feasible.
        return Q_arr.max(where=F_arr, initial=-jnp.inf)
    # Vectorize the scalar maximizer over the full state / discrete-action grid.
    return productmap(func=max_Q_over_c, variables=state_and_discrete_action_names)
def get_argmax_and_max_Q_over_c(
    *,
    Q_and_F: Callable[..., tuple[FloatND, BoolND]],
    continuous_action_names: tuple[str, ...],
) -> ArgmaxQOverCFunction:
    r"""Get the function returning the arguments maximizing Q over continuous actions.

    The state-action value function $Q$ is defined as:

    ```{math}
    Q(x, a) = H(U(x, a), \mathbb{E}[V(x', a') | x, a]),
    ```

    with $H(U, v) = U + \beta \cdot v$ as the leading case (which is the only one
    that is pre-implemented in LCM).

    Fixing a state and discrete action but choosing the feasible continuous
    actions that maximize Q, we get

    ```{math}
    \pi^{c}(x, a^d) = \argmax_{a^c} Q(x, a^d, a^c).
    ```

    This last step is handled by the function returned here.

    Args:
        Q_and_F: A function that takes a state-action combination and returns the
            action value of that combination and whether the state-action
            combination is feasible.
        continuous_action_names: Tuple of action variable names that are
            continuous.

    Returns:
        Function that calculates the argument maximizing Q over the feasible
        continuous actions and the maximum itself. The argument maximizing Q is
        the policy function of the continuous actions, conditional on the states
        and discrete actions. The maximum corresponds to the Qc-function.

    """
    # Vectorize Q_and_F over the product grid of continuous actions (if any) so
    # the argmax below reduces over those axes.
    if continuous_action_names:
        Q_and_F = productmap(
            func=Q_and_F,
            variables=continuous_action_names,
        )
    @functools.wraps(Q_and_F)
    def argmax_and_max_Q_over_c(
        next_V_arr: MappingProxyType[RegimeName, FloatND],
        period: Period,
        **states_actions_params: Array,
    ) -> tuple[IntND, FloatND]:
        Q_arr, F_arr = Q_and_F(
            next_V_arr=next_V_arr,
            period=period,
            **states_actions_params,
        )
        # Mask out infeasible combinations; `initial=-inf` keeps the result
        # well-defined even when no combination is feasible.
        return argmax_and_max(Q_arr, where=F_arr, initial=-jnp.inf)
    return argmax_and_max_Q_over_c

src/lcm/max_Qc_over_d.py
from functools import partial
from typing import Any
import jax
import pandas as pd
from lcm.argmax import argmax_and_max
from lcm.interfaces import ShockType
from lcm.typing import (
ArgmaxQcOverDFunction,
FloatND,
IntND,
MaxQcOverDFunction,
)
def get_max_Qc_over_d(
    *,
    random_utility_shock_type: ShockType,
    variable_info: pd.DataFrame,
    is_terminal: bool,
) -> MaxQcOverDFunction:
    r"""Get the function returning the maximum of Qc over discrete actions.

    The state-action value function $Q$ is defined as:

    ```{math}
    Q(x, a) = H(U(x, a), \mathbb{E}[V(x', a') | x, a]),
    ```

    with $H(U, v) = U + \beta \cdot v$ as the leading case (which is the only one
    that is pre-implemented in LCM).

    Fixing a state and discrete action, maximizing over the feasible continuous
    actions, we get the $Q^c$ function:

    ```{math}
    Q^{c}(x, a^d) = \max_{a^c} Q(x, a^d, a^c).
    ```

    And maximizing over the discrete actions, we get the value function:

    ```{math}
    V(x) = \max_{a^d} Q^{c}(x, a^d).
    ```

    This last step is handled by the function returned here.

    Args:
        random_utility_shock_type: Type of action shock. Currently only
            ShockType.NONE is supported. Work for ShockType.EXTREME_VALUE is in
            progress.
        variable_info: DataFrame with information about the variables.
        is_terminal: Whether the function is created for a terminal regime.

    Returns:
        Function that returns the maximum of the Qc-function over the discrete
        actions (i.e., the value function).

    Raises:
        NotImplementedError: If extreme value shocks are requested.
        ValueError: If the shock type is not a known ShockType member.

    """
    # In a terminal regime only variables that enter the concurrent valuation
    # span the value-function array, so restrict the axis bookkeeping to them.
    if is_terminal:
        variable_info = variable_info.query("enters_concurrent_valuation")
    discrete_action_axes = _determine_discrete_action_axes_solution(variable_info)
    if random_utility_shock_type == ShockType.NONE:
        func = _max_Qc_over_d_no_shocks
    elif random_utility_shock_type == ShockType.EXTREME_VALUE:
        # TODO: wire up _max_Qc_over_d_extreme_value_shocks (defined below) once
        # taste shocks are supported end to end (see issue on taste shocks).
        raise NotImplementedError("Extreme value shocks are not yet implemented.")
    else:
        raise ValueError(f"Invalid shock_type: {random_utility_shock_type}.")
    # Bind the axes now; the returned function only needs the Qc array.
    return partial(func, discrete_action_axes=discrete_action_axes)
def get_argmax_and_max_Qc_over_d(
    *,
    variable_info: pd.DataFrame,
) -> ArgmaxQcOverDFunction:
    r"""Get the function returning the arguments maximizing Qc over discrete actions.

    The state-action value function $Q$ is defined as:

    ```{math}
    Q(x, a) = H(U(x, a), \mathbb{E}[V(x', a') | x, a]),
    ```

    with $H(U, v) = U + \beta \cdot v$ as the leading case (which is the only one
    that is pre-implemented in LCM). Maximizing $Q$ over the feasible continuous
    actions (for a fixed state and discrete action) yields the $Q^c$ function:

    ```{math}
    Q^{c}(x, a^d) = \max_{a^c} Q(x, a^d, a^c).
    ```

    Taking the argmax over the discrete actions then yields the discrete policy
    function:

    ```{math}
    \pi^{d}(x) = \argmax_{a^d} Q^{c}(x, a^d).
    ```

    This last step is handled by the function returned here.

    Args:
        variable_info: DataFrame with information about the variables.

    Returns:
        Function that returns the arguments maximizing the Qc-function over the
        discrete actions together with the maximum itself; the argmax is the
        discrete policy function and the maximum is the value function.

    """
    action_axes = _determine_discrete_action_axes_simulation(variable_info)

    def _argmax_over_discrete(
        Qc_arr: FloatND,
        discrete_action_axes: tuple[int, ...],
    ) -> tuple[IntND, FloatND]:
        # Joint argmax over all discrete-action axes at once.
        return argmax_and_max(Qc_arr, axis=discrete_action_axes)

    return partial(_argmax_over_discrete, discrete_action_axes=action_axes)
# ======================================================================================
# Discrete problem with no shocks
# ======================================================================================
def _max_Qc_over_d_no_shocks(
    Qc_arr: FloatND,
    discrete_action_axes: tuple[int, ...],
) -> FloatND:
    """Maximize the Qc-function over the discrete actions (no taste shocks).

    Args:
        Qc_arr: The maximum of the state-action value function (Q) over the
            continuous actions, conditional on the discrete action. One axis per
            state and discrete action variable.
        discrete_action_axes: Indices of the axes in Qc_arr that correspond to
            discrete actions.

    Returns:
        The maximum of Qc_arr over the discrete action axes.

    """
    # A hard max: without utility shocks the agent simply picks the best
    # discrete action.
    return jax.numpy.max(Qc_arr, axis=discrete_action_axes)
# ======================================================================================
# Discrete problem with extreme value shocks
# --------------------------------------------------------------------------------------
# The following is currently *NOT* supported.
# ======================================================================================
def _max_Qc_over_d_extreme_value_shocks(
    Qc_arr: FloatND,
    discrete_action_axes: tuple[int, ...],
    **kwargs: Any,
) -> FloatND:
    """Take the expected maximum of the Qc-function over the discrete actions.

    With i.i.d. type-I extreme value (Gumbel) taste shocks of scale `sigma`
    added to each discrete action's value, the expected maximum has the closed
    form `sigma * logsumexp(Qc / sigma)` over the discrete-action axes.

    Args:
        Qc_arr: The maximum of the state-action value function (Q) over the
            continuous actions, conditional on the discrete action. One axis per
            state and discrete action variable.
        discrete_action_axes: Indices of the axes in Qc_arr that correspond to
            discrete actions.
        **kwargs: Flat regime params; must contain the key
            "additive_utility_shock__scale".

    Returns:
        The expected maximum of Qc_arr over the discrete action axes.

    Raises:
        KeyError: If "additive_utility_shock__scale" is not present in kwargs.

    """
    try:
        scale = kwargs["additive_utility_shock__scale"]
    except KeyError:
        # A bare kwargs lookup would raise an uninformative KeyError; name the
        # missing parameter explicitly.
        raise KeyError(
            "Extreme value shocks require the regime parameter "
            "'additive_utility_shock__scale'."
        ) from None
    return scale * jax.scipy.special.logsumexp(
        Qc_arr / scale,
        axis=discrete_action_axes,
    )
# ======================================================================================
# Auxiliary functions
# ======================================================================================
def _determine_discrete_action_axes_solution(
    variable_info: pd.DataFrame,
) -> tuple[int, ...]:
    """Get axes of state-action-space that correspond to discrete actions in solution.

    Args:
        variable_info: DataFrame with information about the variables.

    Returns:
        A tuple of indices representing the axes' positions in the value function
        that correspond to discrete actions.

    """
    discrete_action_names = set(
        variable_info.query("is_action & is_discrete").index
    )
    # During solution the value-function array carries one axis per variable, in
    # the row order of variable_info, so the axis position is the row position.
    return tuple(
        position
        for position, var_name in enumerate(variable_info.index)
        if var_name in discrete_action_names
    )
def _determine_discrete_action_axes_simulation(
    variable_info: pd.DataFrame,
) -> tuple[int, ...]:
    """Get axes of state-action-space that correspond to discrete actions in simulation.

    Args:
        variable_info: DataFrame with information about the variables.

    Returns:
        A tuple of indices representing the axes' positions in the value function
        that correspond to discrete actions.

    """
    # Only the *number* of discrete action variables matters here; the
    # intermediate set in the original was used solely for its length.
    n_discrete_actions = len(
        set(variable_info.query("is_action & is_discrete").index)
    )
    # The first dimension corresponds to the simulated states, so add 1.
    return tuple(range(1, n_discrete_actions + 1))

tests/test_max_Qc_over_d.py
import jax.numpy as jnp
import pandas as pd
import pytest
from numpy.testing import assert_array_almost_equal as aaae
from lcm.interfaces import ShockType
from lcm.max_Qc_over_d import (
_determine_discrete_action_axes_simulation,
_determine_discrete_action_axes_solution,
_max_Qc_over_d_extreme_value_shocks,
_max_Qc_over_d_no_shocks,
get_max_Qc_over_d,
)
# ======================================================================================
# Illustrative
# ======================================================================================
@pytest.mark.illustrative
def test_get_solve_discrete_problem_illustrative():
    # One discrete state variable (axis 0) and one discrete action variable
    # (axis 1), so the factory should reduce over axis 1 only.
    variable_info = pd.DataFrame(
        {
            "is_action": [False, True],
            "is_state": [True, False],
            "is_discrete": [True, True],
            "is_continuous": [False, False],
        },
    )
    solve_discrete = get_max_Qc_over_d(
        random_utility_shock_type=ShockType.NONE,
        variable_info=variable_info,
        is_terminal=False,
    )
    values = jnp.arange(6).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
    aaae(solve_discrete(values), jnp.array([1, 3, 5]))
@pytest.mark.illustrative
def test_solve_discrete_problem_no_shocks_illustrative_single_action_axis():
    # Reducing over axis 0 keeps the column-wise maxima.
    values = jnp.arange(6).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
    result = _max_Qc_over_d_no_shocks(values, discrete_action_axes=(0,))
    aaae(result, jnp.array([4, 5]))
@pytest.mark.illustrative
def test_solve_discrete_problem_no_shocks_illustrative_multiple_action_axes():
    # Reducing over both axes collapses the array to its global maximum.
    values = jnp.arange(6).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
    result = _max_Qc_over_d_no_shocks(values, discrete_action_axes=(0, 1))
    aaae(result, 5)
@pytest.mark.illustrative
def test_max_Qc_over_d_extreme_value_shocks_illustrative_single_action_axis():
    # With a small shock scale the smoothed (logsumexp) max is numerically
    # close to the hard max.
    values = jnp.arange(6).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
    result = _max_Qc_over_d_extreme_value_shocks(
        values,
        discrete_action_axes=(0,),
        additive_utility_shock__scale=0.1,
    )
    aaae(result, jnp.array([4, 5]), decimal=5)
@pytest.mark.illustrative
def test_max_Qc_over_d_extreme_value_shocks_illustrative_multiple_action_axes():
    # Smoothing over both axes approximates the global maximum for a small
    # shock scale.
    values = jnp.arange(6).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
    result = _max_Qc_over_d_extreme_value_shocks(
        values,
        discrete_action_axes=(0, 1),
        additive_utility_shock__scale=0.1,
    )
    aaae(result, 5, decimal=5)
# ======================================================================================
# Determine discrete action axes
# ======================================================================================
@pytest.mark.illustrative
def test_determine_discrete_action_axes_illustrative_one_var():
    variable_info = pd.DataFrame(
        {
            "is_action": [False, True],
            "is_state": [True, False],
            "is_discrete": [True, True],
            "is_continuous": [False, False],
        },
    )
    # Row 1 is the only variable that is both an action and discrete.
    got = _determine_discrete_action_axes_solution(variable_info)
    assert got == (1,)
@pytest.mark.illustrative
def test_determine_discrete_action_axes_illustrative_three_var():
    variable_info = pd.DataFrame(
        {
            "is_action": [False, True, True, True],
            "is_state": [True, False, False, False],
            "is_discrete": [True, True, True, True],
            "is_continuous": [False, False, False, False],
        },
    )
    # Rows 1-3 are discrete action variables.
    got = _determine_discrete_action_axes_solution(variable_info)
    assert got == (1, 2, 3)
def test_determine_discrete_action_axes():
    # Three variables are both actions and discrete (rows 2-4), so simulation
    # axes are (1, 2, 3): axis 0 is reserved for the simulated states.
    variable_info = pd.DataFrame(
        {
            "is_state": [True, True, False, True, False, False],
            "is_action": [False, False, True, True, True, True],
            "is_discrete": [True, True, True, True, True, False],
            "is_continuous": [False, True, False, False, False, True],
        },
    )
    got = _determine_discrete_action_axes_simulation(variable_info)
    assert got == (1, 2, 3)

Related
- Remove dead code? #244 (dead code identification and removal)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels