Skip to content

ENH: Add taste shocks #247

@hmgaudecker

Description

@hmgaudecker

Motivation

pylcm currently has no support for taste shocks (additive utility shocks with an extreme value distribution). The infrastructure for this feature was partially sketched in two modules (and an accompanying test file) that were removed as dead code in the cleanup (#243):

  • src/lcm/max_Q_over_c.py — maximization of Q over continuous actions
  • src/lcm/max_Qc_over_d.py — maximization of Qc over discrete actions (includes a stub for extreme value shocks via logsumexp)
  • tests/test_max_Qc_over_d.py — tests for the discrete maximization module

What needs to happen

  1. Design how taste shocks integrate with the current Regime-based API
  2. Implement the logsumexp aggregation over discrete actions (the _max_Qc_over_d_extreme_value_shocks function was a starting point)
  3. Add a scale parameter for the shock distribution
  4. Create a test model with taste shocks and verify against an analytical or known-good solution

Archived code

The deleted modules are preserved below for reference. They will need significant adaptation to the current codebase.

src/lcm/max_Q_over_c.py
import functools
from collections.abc import Callable
from types import MappingProxyType

import jax.numpy as jnp
from jax import Array

from lcm.argmax import argmax_and_max
from lcm.dispatchers import productmap
from lcm.typing import (
    ArgmaxQOverCFunction,
    BoolND,
    FloatND,
    IntND,
    MaxQOverCFunction,
    Period,
    RegimeName,
)


def get_max_Q_over_c(
    *,
    Q_and_F: Callable[..., tuple[FloatND, BoolND]],
    continuous_action_names: tuple[str, ...],
    state_and_discrete_action_names: tuple[str, ...],
) -> MaxQOverCFunction:
    r"""Get the function returning the maximum of Q over continuous actions.

    The state-action value function $Q$ is defined as:

    ```{math}
    Q(x, a) =  H(U(x, a), \mathbb{E}[V(x', a') | x, a]),
    ```
    with $H(U, v) = u + \beta \cdot v$ as the leading case (which is the only one that
    is pre-implemented in LCM).

    For a fixed state and discrete action, maximizing over the feasible continuous
    actions yields the $Q^c$ function:

    ```{math}
    Q^{c}(x, a^d) = \max_{a^c} Q(x, a^d, a^c).
    ```

    The function returned here performs exactly this maximization step.

    Args:
        Q_and_F: A function that takes a state-action combination and returns the action
            value of that combination and whether the state-action combination is
            feasible.
        continuous_action_names: Tuple of action variable names that are continuous.
        state_and_discrete_action_names: Tuple of state and discrete action variable
            names.

    Returns:
        Qc, i.e., the function that calculates the maximum of the Q-function over the
        feasible continuous actions.

    """
    # Vectorize Q_and_F over the continuous-action grids (skip if there are none).
    if continuous_action_names:
        Q_and_F = productmap(func=Q_and_F, variables=continuous_action_names)

    @functools.wraps(Q_and_F)
    def max_Q_over_c(
        next_V_arr: MappingProxyType[RegimeName, FloatND],
        period: Period,
        **states_actions_params: Array,
    ) -> FloatND:
        values, feasible = Q_and_F(
            next_V_arr=next_V_arr,
            period=period,
            **states_actions_params,
        )
        # Infeasible combinations are masked out; `initial=-inf` guarantees a
        # well-defined result even when nothing is feasible.
        return values.max(where=feasible, initial=-jnp.inf)

    # Finally vectorize over all state and discrete-action grids.
    return productmap(func=max_Q_over_c, variables=state_and_discrete_action_names)


def get_argmax_and_max_Q_over_c(
    *,
    Q_and_F: Callable[..., tuple[FloatND, BoolND]],
    continuous_action_names: tuple[str, ...],
) -> ArgmaxQOverCFunction:
    r"""Get the function returning the arguments maximizing Q over continuous actions.

    The state-action value function $Q$ is defined as:

    ```{math}
    Q(x, a) =  H(U(x, a), \mathbb{E}[V(x', a') | x, a]),
    ```
    with $H(U, v) = u + \beta \cdot v$ as the leading case (which is the only one that
    is pre-implemented in LCM).

    For a fixed state and discrete action, choosing the feasible continuous action
    that maximizes Q yields

    ```{math}
    \pi^{c}(x, a^d) = \argmax_{a^c} Q(x, a^d, a^c).
    ```

    The function returned here performs exactly this argmax step.

    Args:
        Q_and_F: A function that takes a state-action combination and returns the action
            value of that combination and whether the state-action combination is
            feasible.
        continuous_action_names: Tuple of action variable names that are continuous.

    Returns:
        Function that calculates the argument maximizing Q over the feasible continuous
        actions and the maximum itself. The argument maximizing Q is the policy
        function of the continuous actions, conditional on the states and discrete
        actions. The maximum corresponds to the Qc-function.

    """
    # Vectorize Q_and_F over the continuous-action grids (skip if there are none).
    if continuous_action_names:
        Q_and_F = productmap(func=Q_and_F, variables=continuous_action_names)

    @functools.wraps(Q_and_F)
    def argmax_and_max_Q_over_c(
        next_V_arr: MappingProxyType[RegimeName, FloatND],
        period: Period,
        **states_actions_params: Array,
    ) -> tuple[IntND, FloatND]:
        values, feasible = Q_and_F(
            next_V_arr=next_V_arr,
            period=period,
            **states_actions_params,
        )
        # Infeasible combinations are masked out; `initial=-inf` keeps the
        # result well-defined even when nothing is feasible.
        return argmax_and_max(values, where=feasible, initial=-jnp.inf)

    return argmax_and_max_Q_over_c
src/lcm/max_Qc_over_d.py
from functools import partial
from typing import Any

import jax
import pandas as pd

from lcm.argmax import argmax_and_max
from lcm.interfaces import ShockType
from lcm.typing import (
    ArgmaxQcOverDFunction,
    FloatND,
    IntND,
    MaxQcOverDFunction,
)


def get_max_Qc_over_d(
    *,
    random_utility_shock_type: ShockType,
    variable_info: pd.DataFrame,
    is_terminal: bool,
) -> MaxQcOverDFunction:
    r"""Get the function returning the maximum of Qc over discrete actions.

    The state-action value function $Q$ is defined as:

    ```{math}
    Q(x, a) =  H(U(x, a), \mathbb{E}[V(x', a') | x, a]),
    ```
    with $H(U, v) = u + \beta \cdot v$ as the leading case (which is the only one that
    is pre-implemented in LCM).

    Fixing a state and discrete action, maximizing over the feasible continuous actions,
    we get the $Q^c$ function:

    ```{math}
    Q^{c}(x, a^d) = \max_{a^c} Q(x, a^d, a^c).
    ```

    And maximizing over the discrete actions, we get the value function:

    ```{math}
    V(x) = \max_{a^d} Q^{c}(x, a^d).
    ```

    This last step is handled by the function returned here.

    Args:
        random_utility_shock_type: Type of action shock. Currently only ShockType.NONE
            is supported. Work for ShockType.EXTREME_VALUE is in progress.
        variable_info: DataFrame with information about the variables.
        is_terminal: Whether the function is created for a terminal regime.

    Returns:
        Function that returns the maximum of the Qc-function over the discrete
        actions. The maximum corresponds to the value function V.

    """
    # In a terminal regime only variables entering concurrent valuation are
    # relevant, so restrict variable_info before computing axis positions.
    if is_terminal:
        variable_info = variable_info.query("enters_concurrent_valuation")

    discrete_action_axes = _determine_discrete_action_axes_solution(variable_info)

    if random_utility_shock_type == ShockType.NONE:
        func = _max_Qc_over_d_no_shocks
    elif random_utility_shock_type == ShockType.EXTREME_VALUE:
        raise NotImplementedError("Extreme value shocks are not yet implemented.")
    else:
        raise ValueError(f"Invalid shock_type: {random_utility_shock_type}.")

    # Bind the axes now; the returned function only takes the Qc array.
    return partial(func, discrete_action_axes=discrete_action_axes)


def get_argmax_and_max_Qc_over_d(
    *,
    variable_info: pd.DataFrame,
) -> ArgmaxQcOverDFunction:
    r"""Get the function returning the arguments maximizing Qc over discrete actions.

    The state-action value function $Q$ is defined as:

    ```{math}
    Q(x, a) =  H(U(x, a), \mathbb{E}[V(x', a') | x, a]),
    ```
    with $H(U, v) = u + \beta \cdot v$ as the leading case (which is the only one that
    is pre-implemented in LCM).

    For a fixed state, maximizing over the feasible continuous actions yields the
    $Q^c$ function:

    ```{math}
    Q^{c}(x, a^d) = \max_{a^c} Q(x, a^d, a^c).
    ```

    Taking the argmax over the discrete actions then yields the policy function of
    the discrete actions:

    ```{math}
    \pi^{d}(x) = \argmax_{a^d} Q^{c}(x, a^d).
    ```

    The function returned here performs exactly this argmax step.

    Args:
        variable_info: DataFrame with information about the variables.

    Returns:
        Function that returns the arguments that maximize the Qc-function over the
        discrete actions and the maximum itself, i.e., policy function of the discrete
        actions. The maximum corresponds to the value function.

    """
    axes_for_simulation = _determine_discrete_action_axes_simulation(variable_info)

    def argmax_and_max_Qc_over_d(
        Qc_arr: FloatND,
        discrete_action_axes: tuple[int, ...],
    ) -> tuple[IntND, FloatND]:
        return argmax_and_max(Qc_arr, axis=discrete_action_axes)

    # Bind the axes now; the returned function only takes the Qc array.
    return partial(argmax_and_max_Qc_over_d, discrete_action_axes=axes_for_simulation)


# ======================================================================================
# Discrete problem with no shocks
# ======================================================================================


def _max_Qc_over_d_no_shocks(
    Qc_arr: FloatND,
    discrete_action_axes: tuple[int, ...],
) -> FloatND:
    """Take the maximum of the Qc-function over the discrete actions.

    Args:
        Qc_arr: The maximum of the state-action value function (Q) over the continuous
            actions, conditional on the discrete action. This has one axis for each
            state and discrete action variable.
        discrete_action_axes: Tuple of indices representing the axes in the value
            function that correspond to discrete actions.
        **kwargs: Flat regime params (including additive_utility_shock__scale).

    Returns:
        The maximum of Qc_arr over the discrete action axes.

    """
    return Qc_arr.max(axis=discrete_action_axes)


# ======================================================================================
# Discrete problem with extreme value shocks
# --------------------------------------------------------------------------------------
# The following is currently *NOT* supported.
# ======================================================================================


def _max_Qc_over_d_extreme_value_shocks(
    Qc_arr: FloatND,
    discrete_action_axes: tuple[int, ...],
    **kwargs: Any,
) -> FloatND:
    """Take the expected maximum of the Qc-function over the discrete actions.

    Args:
        Qc_arr: The maximum of the state-action value function (Q) over the continuous
            actions, conditional on the discrete action. This has one axis for each
            state and discrete action variable.
        discrete_action_axes: Tuple of indices representing the axes in the value
            function that correspond to discrete actions.
        **kwargs: Flat regime params (including additive_utility_shock__scale).

    Returns:
        The expected maximum of Qc_arr over the discrete action axes.

    """
    scale = kwargs["additive_utility_shock__scale"]
    return scale * jax.scipy.special.logsumexp(
        Qc_arr / scale, axis=discrete_action_axes
    )


# ======================================================================================
# Auxiliary functions
# ======================================================================================


def _determine_discrete_action_axes_solution(
    variable_info: pd.DataFrame,
) -> tuple[int, ...]:
    """Get axes of state-action-space that correspond to discrete actions in solution.

    Args:
        variable_info: DataFrame with information about the variables.

    Returns:
        A tuple of indices representing the axes' positions in the value function that
        correspond to discrete actions.

    """
    discrete_action_vars = set(
        variable_info.query("is_action & is_discrete").index.tolist()
    )
    return tuple(
        i for i, ax in enumerate(variable_info.index) if ax in discrete_action_vars
    )


def _determine_discrete_action_axes_simulation(
    variable_info: pd.DataFrame,
) -> tuple[int, ...]:
    """Get axes of state-action-space that correspond to discrete actions in simulation.

    Args:
        variable_info: DataFrame with information about the variables.

    Returns:
        A tuple of indices representing the axes' positions in the value function that
        correspond to discrete actions.

    """
    discrete_action_vars = set(
        variable_info.query("is_action & is_discrete").index.tolist()
    )

    # The first dimension corresponds to the simulated states, so add 1.
    return tuple(1 + i for i in range(len(discrete_action_vars)))
tests/test_max_Qc_over_d.py
import jax.numpy as jnp
import pandas as pd
import pytest
from numpy.testing import assert_array_almost_equal as aaae

from lcm.interfaces import ShockType
from lcm.max_Qc_over_d import (
    _determine_discrete_action_axes_simulation,
    _determine_discrete_action_axes_solution,
    _max_Qc_over_d_extreme_value_shocks,
    _max_Qc_over_d_no_shocks,
    get_max_Qc_over_d,
)

# ======================================================================================
# Illustrative
# ======================================================================================


@pytest.mark.illustrative
def test_get_solve_discrete_problem_illustrative():
    # One discrete state (axis 0) and one discrete action (axis 1), so the
    # maximization runs over axis 1 only.
    variable_info = pd.DataFrame(
        {
            "is_action": [False, True],
            "is_state": [True, False],
            "is_discrete": [True, True],
            "is_continuous": [False, False],
        },
    )

    max_Qc_over_d = get_max_Qc_over_d(
        random_utility_shock_type=ShockType.NONE,
        variable_info=variable_info,
        is_terminal=False,
    )

    Qc_arr = jnp.arange(6).reshape(3, 2)

    aaae(max_Qc_over_d(Qc_arr), jnp.array([1, 3, 5]))


@pytest.mark.illustrative
def test_solve_discrete_problem_no_shocks_illustrative_single_action_axis():
    # Maximizing over axis 0 keeps the column-wise maxima.
    Qc_arr = jnp.arange(6).reshape(3, 2)
    result = _max_Qc_over_d_no_shocks(Qc_arr, discrete_action_axes=(0,))
    aaae(result, jnp.array([4, 5]))


@pytest.mark.illustrative
def test_solve_discrete_problem_no_shocks_illustrative_multiple_action_axes():
    # Maximizing over both axes collapses the array to its global maximum.
    Qc_arr = jnp.arange(6).reshape(3, 2)
    result = _max_Qc_over_d_no_shocks(Qc_arr, discrete_action_axes=(0, 1))
    aaae(result, 5)


@pytest.mark.illustrative
def test_max_Qc_over_d_extreme_value_shocks_illustrative_single_action_axis():
    Qc_arr = jnp.arange(6).reshape(3, 2)

    result = _max_Qc_over_d_extreme_value_shocks(
        Qc_arr,
        discrete_action_axes=(0,),
        additive_utility_shock__scale=0.1,
    )
    # With a small scale the smoothed maximum is numerically close to the hard max.
    aaae(result, jnp.array([4, 5]), decimal=5)


@pytest.mark.illustrative
def test_max_Qc_over_d_extreme_value_shocks_illustrative_multiple_action_axes():
    Qc_arr = jnp.arange(6).reshape(3, 2)
    result = _max_Qc_over_d_extreme_value_shocks(
        Qc_arr,
        discrete_action_axes=(0, 1),
        additive_utility_shock__scale=0.1,
    )
    # With a small scale the smoothed maximum is numerically close to the hard max.
    aaae(result, 5, decimal=5)


# ======================================================================================
# Determine discrete action axes
# ======================================================================================


@pytest.mark.illustrative
def test_determine_discrete_action_axes_illustrative_one_var():
    # Row 1 is the only discrete action, so its axis position is 1.
    variable_info = pd.DataFrame(
        {
            "is_action": [False, True],
            "is_state": [True, False],
            "is_discrete": [True, True],
            "is_continuous": [False, False],
        },
    )

    got = _determine_discrete_action_axes_solution(variable_info)
    assert got == (1,)


@pytest.mark.illustrative
def test_determine_discrete_action_axes_illustrative_three_var():
    # Rows 1-3 are discrete actions, so their axis positions are 1, 2, 3.
    variable_info = pd.DataFrame(
        {
            "is_action": [False, True, True, True],
            "is_state": [True, False, False, False],
            "is_discrete": [True, True, True, True],
            "is_continuous": [False, False, False, False],
        },
    )

    got = _determine_discrete_action_axes_solution(variable_info)
    assert got == (1, 2, 3)


def test_determine_discrete_action_axes():
    # Three discrete actions -> simulation axes (1, 2, 3); axis 0 is reserved
    # for the simulated states.
    variable_info = pd.DataFrame(
        {
            "is_state": [True, True, False, True, False, False],
            "is_action": [False, False, True, True, True, True],
            "is_discrete": [True, True, True, True, True, False],
            "is_continuous": [False, True, False, False, False, True],
        },
    )
    assert _determine_discrete_action_axes_simulation(variable_info) == (1, 2, 3)

Related

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions