@.ai-instructions/profiles/tier-a.md @.ai-instructions/modules/jax.md @.ai-instructions/modules/optimagic.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
skillmodels is a Python implementation of estimators for nonlinear dynamic latent factor models, primarily used for skill formation research in economics. It implements Kalman filter-based maximum likelihood estimation following Cunha, Heckman, and Schennach (2010).
Used as the core estimation engine by sibling application projects (skane-struct-bw, health-cognition) in the parent workspace.
```bash
# Run tests
pixi run -e tests-cpu tests

# Run tests with coverage
pixi run -e tests-cpu tests-with-cov

# Run a single test file
pixi run -e tests-cpu pytest tests/test_kalman_filters.py

# Run a single test
pixi run -e tests-cpu pytest tests/test_kalman_filters.py::test_function_name

# Type checking
pixi run ty

# Quality checks (linting, formatting)
prek run --all-files

# Build documentation (mystmd, from docs/ directory)
myst build
```

Always use these command mappings:
- Python: Use `pixi run python` instead of `python` or `python3`
- Type checker: Use `pixi run ty` instead of running ty/mypy/pyright directly
- Tests: Use `pixi run -e tests-cpu tests` instead of `pytest` directly
- Linting/formatting: Use `prek run --all-files` instead of `ruff` directly
- All quality checks: Use `prek run --all-files`
Before finishing any task that modifies code, always run:

- `pixi run ty` (type checker)
- `pixi run -e tests-cpu tests` (tests)
- `prek run --all-files` (quality checks)
```
ModelSpec + Data
        ↓
process_model() → Validates/extends model specification → ProcessedModel
        ↓
get_maximization_inputs() → Creates optimization problem (likelihood, gradients,
                            constraints, params_template)
        ↓
[optimagic maximize / estimagic estimate_ml with fides algorithm]
        ↓
get_filtered_states() → Extract estimated latent factors
simulate_dataset()   → Simulate states (with optional policy effects)
```
- model_spec.py: User-facing frozen dataclasses (`ModelSpec`, `FactorSpec`, `AnchoringSpec`). Re-exports `EstimationOptions` and `Normalizations` from types.py. `ModelSpec` supports construction via `__init__` or `ModelSpec.from_dict()`, and fluent builder methods: `with_transition_functions()`, `with_added_factor()`, `with_added_observed_factors()`, `with_estimation_options()`, `with_anchoring()`, `with_controls()`, `with_stagemap()`.
- types.py: Internal frozen dataclasses (`ProcessedModel`, `Labels`, `Dimensions`, `Anchoring`, `ParsingInfo`, `ParsedParams`, `EndogenousFactorsInfo`, etc.), `EstimationOptions`, `Normalizations`, and immutability utilities.
- process_model.py: Model specification validation and preprocessing. Converts `ModelSpec` into `ProcessedModel`.
- kalman_filters.py: Core Kalman filter implementation (predict/update steps). Uses square-root form for numerical stability.
- likelihood_function.py / likelihood_function_debug.py: Log-likelihood computation using Kalman filtering. The debug variant is not jitted and returns intermediate results (residuals, contributions, filtered states).
- constraints.py: Generates parameter constraints (bounds, equalities from stagemap, fixed values) for optimization. Exports `get_constraints()`, `enforce_fixed_constraints()`, `add_bounds()`, `FixedConstraintWithValue`.
- parse_params.py: Converts flat parameter vectors to structured model parameters. Exports `create_parsing_info()` and `parse_params()`.
- params_index.py: Builds the `pd.MultiIndex` for the params DataFrame via `get_params_index()`.
- transition_functions.py: Pre-built transition equations: `linear`, `translog`, `robust_translog`, `linear_and_squares`, `log_ces`, `log_ces_general`, `constant`.
- decorators.py: `register_params` decorator for custom transition functions. Tags a callable with `__registered_params__` so skillmodels knows its parameter names.
- process_data.py: `process_data()` for internal estimation format, `pre_process_data()` for reshaping data to long format with period indexing.
- simulate_data.py: `simulate_dataset()` and `simulate_policy_effect()`.
- diagnostic_plots.py: `plot_likelihood_contributions()` and `plot_residual_boxplots()` (Plotly-based).
- variance_decomposition.py: `decompose_measurement_variance()` and `summarize_measurement_reliability()`.
- process_debug_data.py: `create_state_ranges()` and `process_debug_data()` for converting raw debug output into DataFrames.
- utilities.py: Model manipulation helpers (`extract_factors`, `remove_factors`, `update_parameter_values`, `switch_translog_to_linear`, etc.).
- Visualization modules (not in `__all__`, imported by module path): correlation_heatmap.py, visualize_factor_distributions.py, visualize_transition_equations.py, utils_plotting.py.
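The `register_params` mechanism can be illustrated with a self-contained sketch. The decorator below is a simplified stand-in for `skillmodels.decorators.register_params` (whose exact signature may differ), and the transition function is invented for illustration:

```python
def register_params(params):
    """Simplified stand-in for skillmodels.decorators.register_params."""

    def decorator(func):
        # Tag the callable so the estimation machinery can look up
        # the names of its free parameters.
        func.__registered_params__ = params
        return func

    return decorator


@register_params(params=["fac1", "fac2", "constant"])
def my_transition(states, params):
    # Toy linear transition equation: next-period factor as a
    # weighted sum of current factors plus a constant.
    return params[0] * states[0] + params[1] * states[1] + params[2]


# The tag is now available for parameter-index construction.
names = my_transition.__registered_params__
```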
Returns a dict with 6 keys:

- `"loglike"`: `(params: pd.DataFrame) -> float` — jitted scalar log-likelihood
- `"loglikeobs"`: `(params: pd.DataFrame) -> NDArray` — jitted per-observation log-likelihood
- `"debug_loglike"`: `(params: pd.DataFrame) -> dict` — non-jitted, returns dict with keys `value`, `contributions`, `residuals`, `residual_sds`, `filtered_states`, `state_ranges`, etc.
- `"loglike_and_gradient"`: `(params: pd.DataFrame) -> tuple[float, NDArray]`
- `"constraints"`: list of optimagic constraint objects
- `"params_template"`: `pd.DataFrame` with correct MultiIndex and bounds; fixed constraints pre-applied
Applications frequently access these after calling `process_model(model_spec)`:

- `processed.labels` — `.latent_factors`, `.observed_factors`, `.all_factors`, `.controls`, `.stagemap`, `.stages`, `.aug_periods_to_periods`
- `processed.dimensions` — `.n_periods`, `.n_latent_factors`
- `processed.update_info` — DataFrame indexed by `(aug_period, variable)` with factor columns and a `purpose` column
- `processed.endogenous_factors_info` — `.has_endogenous_factors`, `.aug_periods_from_period(period)`, `.factor_info`
- `processed.normalizations` — dict of factor name to `Normalizations`
- `processed.transition_info` — `.func` (vectorized), `.individual_functions[factor]`
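For orientation, `processed.update_info` has roughly the following shape. The measurement names, factor columns, and values below are made up for illustration; only the index structure and the `purpose` column follow the description above:

```python
import pandas as pd

# Hypothetical two-measurement example: which factors load on which
# measurement in aug_period 0, plus the purpose of each update.
update_info = pd.DataFrame(
    {
        "fac1": [True, False],
        "fac2": [False, True],
        "purpose": ["measurement", "measurement"],
    },
    index=pd.MultiIndex.from_tuples(
        [(0, "m1"), (0, "m2")], names=["aug_period", "variable"]
    ),
)
```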
All computation-heavy code uses JAX for automatic differentiation and JIT compilation. The codebase uses:

- `jax.vmap` for vectorization across observations
- `jax.jit` for compilation
- JAX arrays throughout the estimation pipeline
- Optional GPU support via CUDA or Metal
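The `jax.vmap` plus `jax.jit` pattern looks like this in miniature. This is an illustrative Gaussian measurement log-density, not the library's actual filter code:

```python
import jax
import jax.numpy as jnp


@jax.jit
def loglike_contribution(state, loading, measurement, sd):
    # Log-density of one measurement given one latent state.
    resid = measurement - loading * state
    return -0.5 * (jnp.log(2 * jnp.pi * sd**2) + (resid / sd) ** 2)


# Vectorize across observations: states and measurements are batched
# along axis 0; loading and sd are shared across observations.
batched = jax.vmap(loglike_contribution, in_axes=(0, None, 0, None))

states = jnp.array([0.0, 1.0, 2.0])
measurements = jnp.array([0.1, 0.9, 2.2])
contribs = batched(states, 1.0, measurements, 1.0)
```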
Model specification classes:
`ModelSpec`, `FactorSpec`, `AnchoringSpec`, `EstimationOptions`, `Normalizations`
Core estimation:

- `get_maximization_inputs(model_spec, data, split_dataset=1)` — prepare optimization problem
- `get_filtered_states(model_spec, data, params)` — returns nested dict with `"anchored_states"` and `"unanchored_states"`, each containing `"states"` (DataFrame) and `"state_ranges"`
Simulation:

- `simulate_dataset(model_spec, params, n_obs=None, data=None, policies=None, seed=None)` — returns dict with `"unanchored_states"`, `"anchored_states"`
- `simulate_policy_effect(model_spec, params, data, policies, seed=None)` — returns DataFrame of factor mean differences between policy and baseline
Diagnostics and visualization:

- `plot_likelihood_contributions(model_spec, data, params, period=None)`
- `plot_residual_boxplots(model_spec, data, params, period=None)`
- `decompose_measurement_variance(model_spec, params, data)` — returns DataFrame indexed by `(period, measurement, factor)` with signal/noise columns
- `summarize_measurement_reliability(variance_decomposition)`
- `create_state_ranges(filtered_states, factors, quantile_cutoff=None)`
These are not in `__all__` but are imported directly by application projects:

- `skillmodels.process_model.process_model` — central to all application code
- `skillmodels.types.ProcessedModel`, `EndogenousFactorsInfo`
- `skillmodels.decorators.register_params` — essential for custom transition functions
- `skillmodels.constraints.get_constraints`, `enforce_fixed_constraints`, `FixedConstraintWithValue`, `select_by_loc`
- `skillmodels.utilities.extract_factors`, `update_parameter_values`
- `skillmodels.process_data.pre_process_data`
- `skillmodels.correlation_heatmap.get_measurements_corr`, `get_quasi_scores_corr`, `get_scores_corr`, `plot_correlation_heatmap`
- `skillmodels.visualize_factor_distributions.univariate_densities`, `bivariate_density_contours`, `combine_distribution_plots`
- `skillmodels.visualize_transition_equations.get_transition_plots`, `combine_transition_plots`
- `skillmodels.parse_params.create_parsing_info`, `parse_params`
- `skillmodels.params_index.get_params_index`
- Requires Python 3.14
- Uses Ruff for linting (target: Python 3.14, line length: 88)
- Google-style docstrings with imperative mood ("Return" not "Returns")
- Use MyST syntax in docstrings (single backticks `like this`), not reStructuredText (no double backticks, no `:ref:`, `:func:`, etc.)
- Dataclass attributes use inline docstrings (docstring on the line after the field):

  ```python
  name: str
  """Description of name."""
  ```
- Pre-commit hooks enforce formatting and linting
- Type checking via `ty` with strict rules
- Do not use `from __future__ import annotations`
- Use modern numpy random API: `rng = np.random.default_rng(seed)` instead of `np.random.seed()` or legacy functions like `np.random.randn()`
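Under the modern API a seeded generator is created once and passed around; for example (seed and shapes chosen arbitrarily):

```python
import numpy as np

# Preferred: a local Generator instance instead of global np.random state.
rng = np.random.default_rng(seed=1234)
draws = rng.normal(loc=0.0, scale=1.0, size=(5, 2))
ints = rng.integers(low=0, high=10, size=3)
```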
- All model configuration and internal data structures use frozen dataclasses
- Dict fields on internal dataclasses use `MappingProxyType` (not `Mapping`); wrap at the call site with `MappingProxyType(...)`
- Dict fields on user-facing dataclasses (`AnchoringSpec`, `Normalizations`) use `Mapping` with `__post_init__` conversion via `ensure_containers_are_immutable()`
- List fields use `tuple`, set fields use `frozenset`
- `ensure_containers_are_immutable()` recursively converts dict→`MappingProxyType`, list→`tuple`, set→`frozenset`
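A minimal sketch of the user-facing pattern. The helper below is a simplified stand-in for the real `ensure_containers_are_immutable()`, and `DemoSpec` is a hypothetical class, not one of the library's dataclasses:

```python
from dataclasses import dataclass
from types import MappingProxyType


def ensure_containers_are_immutable(value):
    """Simplified stand-in: recursively freeze dict/list/set containers."""
    if isinstance(value, dict):
        return MappingProxyType(
            {k: ensure_containers_are_immutable(v) for k, v in value.items()}
        )
    if isinstance(value, list):
        return tuple(ensure_containers_are_immutable(v) for v in value)
    if isinstance(value, set):
        return frozenset(value)
    return value


@dataclass(frozen=True)
class DemoSpec:
    """Hypothetical user-facing spec with a dict field frozen on init."""

    loadings: dict

    def __post_init__(self):
        # frozen=True blocks normal assignment, so use object.__setattr__.
        object.__setattr__(
            self, "loadings", ensure_containers_are_immutable(self.loadings)
        )


spec = DemoSpec(loadings={"fac1": ["m1", "m2"]})
```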
Models with endogenous factors split each calendar period into multiple augmented periods (`aug_period`). The public API uses `period` (user-facing); `aug_period` is strictly internal. All public functions return `period`:

- `ModelSpec` — clean, no `aug_period` exposure.
- `get_transition_plots()` — clean, accepts `period`/`periods`.
- `get_filtered_states()` — clean, returns `period` column.
- `simulate_dataset()` — clean, returns `period` in states DataFrames.
- `plot_residual_boxplots()` / `plot_likelihood_contributions()` — clean, accept and return `period`.
- `decompose_measurement_variance()` — clean, indexed by `(period, measurement, factor)`.
- `simulate_policy_effect()` / `simulate_dataset()` policies — accept `"period"` key.
- `ProcessedModel.labels` — exposes `aug_periods_to_periods` mapping (acceptable for internal/advanced use).

When writing new public-facing code, always accept and return `period`. Convert to `aug_period` internally using `ProcessedModel.labels.aug_periods_to_periods`.
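The conversion can be pictured with a toy mapping. The dict below is invented for illustration; in real code it comes from `ProcessedModel.labels.aug_periods_to_periods`:

```python
# Hypothetical mapping: calendar periods 0 and 1 each split into two
# augmented periods in an endogenous-factor model.
aug_periods_to_periods = {0: 0, 1: 0, 2: 1, 3: 1}


def to_period(aug_period):
    """Translate an internal aug_period into the user-facing period."""
    return aug_periods_to_periods[aug_period]


# Internal results indexed by aug_period collapse to user-facing periods.
periods = [to_period(a) for a in range(4)]
```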
- pytest with markers: `wip`, `unit`, `integration`, `end_to_end`
- Test files mirror source structure in `tests/`
tests/ - Memory profiling available via pytest-memray (Unix only)