CSI-4CAST is a framework for generating and evaluating Channel State Information (CSI) prediction models using 3GPP TR 38.901 channel models. The repository provides tools for large-scale dataset generation, model training, and comprehensive evaluation, with support for both high-performance computing environments (Phoenix HPC) and direct execution on local machines.
This framework was developed as part of our research paper *CSI-4CAST: A Hybrid Deep Learning Model for CSI Prediction with Comprehensive Robustness and Generalization Testing* (a BibTeX entry for citation is provided at the end of this page). The corresponding datasets are publicly available on our Hugging Face Dataset.
```
CSI-4CAST/
├── README.md                        # Project documentation
├── LICENSE                          # License information
├── env.yml                          # Conda environment configuration
├── pyproject.toml                   # Python project configuration and linting rules
├── scripts/                         # SLURM job scripts and templates
│   ├── data_gen_template.sh         # Template for data generation jobs
│   ├── cp_template.sh               # Template for model training jobs
│   ├── testing_template.sh          # Template for testing jobs
│   └── outs/                        # Job output logs
├── src/                             # Source code
│   ├── data/                        # Data generation module
│   │   ├── csi_simulator.py         # CSI simulation using Sionna
│   │   └── generator.py             # Dataset generation pipeline
│   ├── cp/                          # Channel Prediction (model training) module
│   │   ├── main.py                  # Training entry point
│   │   ├── config/                  # Training configuration management
│   │   │   └── config.py            # Configuration file generator
│   │   ├── dataset/                 # PyTorch Lightning data modules
│   │   │   └── data_module.py       # Data loading and preprocessing
│   │   ├── models/                  # Model architectures
│   │   │   ├── __init__.py          # Model registry (PREDICTORS class)
│   │   │   ├── common/              # Shared model components
│   │   │   │   └── base.py          # BaseCSIModel class
│   │   │   └── baseline_models/     # Baseline model implementations
│   │   │       ├── np.py            # No-prediction baseline
│   │   │       └── rnn.py           # RNN-based predictor
│   │   └── loss/                    # Loss functions
│   │       └── loss.py              # Custom loss implementations
│   ├── noise/                       # Noise modeling and testing module
│   │   ├── noise.py                 # Noise generation functions
│   │   ├── noise_degree.py          # Noise parameter calibration
│   │   ├── noise_testing.py         # Noise testing utilities
│   │   └── results/                 # Noise calibration results
│   │       ├── decide_nd.json       # Noise degree mapping
│   │       └── snr.csv              # SNR measurement results
│   ├── testing/                     # Model evaluation module
│   │   ├── config.py                # Testing configuration
│   │   ├── get_models.py            # Model loading utilities
│   │   ├── computational_overhead/  # Performance profiling
│   │   │   ├── main.py              # Computational overhead testing
│   │   │   └── utils.py             # Profiling utilities
│   │   ├── prediction_performance/  # Prediction accuracy testing
│   │   │   ├── main.py              # Performance testing entry point
│   │   │   └── test_unit.py         # Individual test units
│   │   ├── results/                 # Result processing and analysis
│   │   │   ├── main.py              # Results processing pipeline
│   │   │   ├── analysis_df.py       # Statistical analysis
│   │   │   ├── check_completion.py  # Test completion verification
│   │   │   └── gather_results.py    # Result aggregation
│   │   └── vis/                     # Visualization module
│   │       ├── main.py              # Visualization entry point
│   │       ├── line.py              # Line plot generation
│   │       ├── radar.py             # Radar plot generation
│   │       ├── table.py             # Table generation
│   │       └── violin.py            # Violin plot generation
│   └── utils/                       # Utility functions
│       ├── data_utils.py            # Constants and data handling utilities
│       ├── dirs.py                  # Directory path management
│       ├── norm_utils.py            # Data normalization utilities
│       ├── main_utils.py            # General utilities
│       ├── model_utils.py           # Model-related utilities
│       ├── real_n_complex.py        # Complex number handling
│       ├── time_utils.py            # Time formatting utilities
│       └── vis_utils.py             # Visualization utilities
└── z_artifacts/                     # Generated artifacts and outputs
    ├── config/                      # Generated configuration files
    │   └── cp/                      # Channel prediction configurations
    ├── data/                        # Generated datasets (created during data generation)
    ├── outputs/                     # Training and testing outputs
    │   ├── [TDD/FDD]/               # Training outputs by scenario
    │   ├── noise/                   # Noise calibration results
    │   └── testing/                 # Testing results and analysis
    │       ├── computational_overhead/  # Performance profiling results
    │       ├── prediction_performance/  # Accuracy testing results
    │       ├── results/             # Processed analysis results
    │       └── vis/                 # Generated visualizations
    └── weights/                     # Trained model checkpoints
        ├── fdd/                     # FDD scenario model weights
        └── tdd/                     # TDD scenario model weights
```
The data generation module provides a complete pipeline for creating realistic CSI datasets using 3GPP channel models.
- `csi_simulator.py`: Configures and implements the CSI simulator based on Sionna's 3GPP TR 38.901 channel model implementation. The simulator generates realistic channel responses for various propagation scenarios, covering different channel models, delay spreads, and mobility conditions.
- `data_utils.py`: Defines all simulation parameters and constants following the specifications detailed in the research paper, including antenna configurations, OFDM parameters, subcarrier arrangements, and dataset organization structures.
- `generator.py`: Employs the CSI simulator to generate comprehensive datasets, including:
  - Training datasets for model development
  - Regular testing datasets for standard and robustness evaluation
  - Generalization testing datasets for out-of-distribution evaluation
The generator creates three types of CSI data files for each channel configuration:
- `H_U_hist.pt`: Uplink historical CSI data (model input)
- `H_U_pred.pt`: Uplink prediction target CSI data
- `H_D_pred.pt`: Downlink prediction target CSI data (for cross-link scenarios)
Data Dimensions:
- Antennas: 32 (4×4×2 dual-polarized BS antenna array)
- Time slots: 20 total (16 historical + 4 prediction)
- Subcarriers: 300 each for uplink and downlink (750 total with gap)
- Channel models: A, C, D (regular) / A, B, C, D, E (generalization)
- Delay spreads: 30-400 nanoseconds
- Mobility scenarios: 1-45 m/s
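For a quick sanity check after generation, the saved tensors can be loaded with PyTorch and their shapes inspected. The sketch below is illustrative only: `[dataset_dir]` is a placeholder in the repo's bracket convention, and the exact dimension ordering of the saved tensors should be verified against `src/utils/data_utils.py`.

```python
# Illustrative sanity check on generated CSI tensors; [dataset_dir] is a
# placeholder, and the dimension ordering in the comments is an assumption
# to verify against src/utils/data_utils.py.
import torch

h_hist = torch.load("z_artifacts/data/[dataset_dir]/H_U_hist.pt")
h_pred = torch.load("z_artifacts/data/[dataset_dir]/H_U_pred.pt")

# Per the dimensions listed above, expect something like:
#   H_U_hist -> (samples, 16 historical slots, 32 antennas, 300 subcarriers)
#   H_U_pred -> (samples, 4 prediction slots, 32 antennas, 300 subcarriers)
print("hist:", h_hist.shape, h_hist.dtype)
print("pred:", h_pred.shape, h_pred.dtype)
```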
The channel prediction module provides a comprehensive framework for training CSI prediction models using PyTorch Lightning.
- `main.py`: Training entry point that orchestrates the entire training process
- `config/config.py`: Configuration management system for training parameters, model settings, and hyperparameters
- `dataset/data_module.py`: PyTorch Lightning data modules for efficient data loading and preprocessing
- `models/`: Model architectures, including:
  - `__init__.py`: PREDICTORS registry for model selection
  - `common/base.py`: BaseCSIModel class that all models inherit from
  - `baseline_models/`: Implementations of the baseline models (NP, RNN)
- `loss/loss.py`: Custom loss functions optimized for CSI prediction tasks
The noise module handles realistic noise modeling and parameter calibration for comprehensive testing scenarios.
- `noise.py`: Core noise generation functions implementing various realistic noise types
- `noise_degree.py`: Noise parameter calibration system that maps target SNRs to appropriate noise parameters
- `noise_testing.py`: Noise testing utilities and configurations
- `results/decide_nd.json`: Pre-calibrated noise degree mapping for different noise types
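For intuition on what the calibration does: for plain additive white Gaussian noise, the noise variance that achieves a target SNR has a closed form, as in the sketch below (`awgn_at_snr` is an illustrative helper, not the repo's API). The more structured noise types implemented in `noise.py` have no such closed form, which is why `noise_degree.py` calibrates their parameters against measured SNRs.

```python
# Illustrative helper, not the repo's API: for AWGN the mapping from a
# target SNR (dB) to noise variance is closed-form:
#   noise_power = signal_power / 10**(snr_db / 10)
import torch

def awgn_at_snr(h: torch.Tensor, snr_db: float) -> torch.Tensor:
    """Add complex AWGN to CSI tensor `h` at the given average SNR."""
    signal_power = h.abs().pow(2).mean()
    noise_power = signal_power / (10 ** (snr_db / 10))
    std = torch.sqrt(noise_power / 2)  # split between real and imaginary parts
    noise = std * (torch.randn_like(h.real) + 1j * torch.randn_like(h.real))
    return h + noise
```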
The testing module provides comprehensive evaluation frameworks for CSI prediction models across multiple dimensions.
- `config.py`: Testing configuration including model lists, scenarios, noise types, and job allocation settings
- `get_models.py`: Model loading utilities with checkpoint path management
- `computational_overhead/`: Performance profiling for measuring model computational requirements
- `prediction_performance/`: Accuracy evaluation across thousands of testing scenarios
- `results/`: Result processing pipeline including completion checking, data aggregation, and statistical analysis
- `vis/`: Comprehensive visualization suite generating line plots, radar charts, violin plots, and tables
The CSI-4CAST framework is designed to be flexible and compatible with various computing environments, from local development machines to large-scale HPC clusters.
Create and activate the conda environment (on HPC systems, load mamba first):

```
module load mamba/[mamba_version]
mamba env create -f env.yml
mamba activate csi-4cast-env
```

The code related to data generation lives in the `src/data` folder and the `src/utils/data_utils.py` file.
The `data_utils.py` file defines all constants that configure the Sionna simulator and the data generation process. It is critical to understand and adjust these constants for your setup before running any code.
For high-performance computing, use the template in scripts/data_gen_template.sh:
```
python3 -m src.data.generator --is_train   # Generate training data, typical array size is 1-9
python3 -m src.data.generator              # Generate regular test data, typical array size is 1
python3 -m src.data.generator --is_gen     # Generate generalization test data, typical array size is 1-20
```

For local/single-node execution, use debug mode for minimal datasets:

```
python3 -m src.data.generator --debug --is_train   # Debug mode: minimal training data
python3 -m src.data.generator --debug              # Debug mode: minimal test data
python3 -m src.data.generator --debug --is_gen     # Debug mode: minimal generalization data
```

After data generation, compute normalization statistics using `src/utils/norm_utils.py`:

```
python3 -m src.utils.norm_utils
```

The normalization stats will be saved in `z_artifacts/data/stats/[fdd/tdd]/normalization_stats.pkl`.
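The statistics can then be inspected directly; the structure of the pickled object (e.g., per-link mean and standard deviation) is an assumption to verify against `src/utils/norm_utils.py`:

```python
# Inspect the saved normalization statistics; the contents of `stats`
# (e.g., mean/std tensors) should be verified against src/utils/norm_utils.py.
import pickle

with open("z_artifacts/data/stats/tdd/normalization_stats.pkl", "rb") as f:
    stats = pickle.load(f)
print(stats)
```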
The model training framework is built on PyTorch Lightning and located in the src/cp folder.
Models should be defined under the `src/cp/models` folder, inherit from `BaseCSIModel` in `src/cp/models/common/base.py`, and be registered in the `PREDICTORS` class in `src/cp/models/__init__.py`. See `src/cp/models/baseline_models/rnn.py` for an example implementation.
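For illustration, a minimal sketch of a new predictor is shown below. The constructor and `forward` signatures of `BaseCSIModel` and the registration mechanism of `PREDICTORS` are assumptions here; mirror `src/cp/models/baseline_models/rnn.py` for the authoritative interface.

```python
# Hypothetical new model: signatures and shapes are assumptions; mirror
# src/cp/models/baseline_models/rnn.py for the actual interface.
import torch
import torch.nn as nn

from src.cp.models.common.base import BaseCSIModel


class GRUPredictor(BaseCSIModel):
    """Toy GRU predictor mapping 16 historical CSI slots to 4 future slots."""

    def __init__(self, feat_dim: int, hidden_dim: int = 256, **kwargs):
        super().__init__(**kwargs)
        self.gru = nn.GRU(feat_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, 4 * feat_dim)

    def forward(self, h_hist: torch.Tensor) -> torch.Tensor:
        # h_hist: (batch, 16, feat_dim) real-valued flattened CSI history
        _, last_hidden = self.gru(h_hist)
        out = self.head(last_hidden.squeeze(0))
        return out.view(h_hist.size(0), 4, -1)
```

The new class would then be added to the `PREDICTORS` registry in `src/cp/models/__init__.py` so it can be selected by name from a training configuration.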
Configure the training process in src/cp/config/config.py, then generate configuration files:
```
python3 -m src.cp.config.config --model [model_name] --output-dir [output_dir] --is_U2D [True/False] --config-file [yaml/json]
```

Default output directory: `z_artifacts/config/cp/[model_name]/`

Launch training with a generated configuration file:

```
python3 -m src.cp.main --hparams_csi_pred [config_file]
```

For HPC clusters, use `scripts/cp_template.sh`. Training outputs are saved in `z_artifacts/outputs/[TDD/FDD]/[model_name]/[date_time]/`, with checkpoints in `ckpts/` and TensorBoard logs in `tb_logs/`.
View training progress:
```
tensorboard --logdir [output_directory]/tb_logs
```

Since realistic noise types cannot be directly defined by SNRs, calibrate noise parameters first:

```
python3 -m src.noise.noise_degree
```

Results are saved in `z_artifacts/outputs/noise/noise_degree/[date_time]/decide_nd.json` and copied to `src/noise/results/decide_nd.json`.
The model evaluation framework in src/testing provides comprehensive assessment across multiple dimensions.
Configure models and checkpoint paths in `src/testing/config.py`. Ensure checkpoints conform to the `get_ckpt_path` function in `src/testing/get_models.py`. Default checkpoint path: `z_artifacts/weights/[tdd/fdd]/[model_name]/model.ckpt`.
Run computational overhead profiling for all configured models:

```
python3 -m src.testing.computational_overhead.main
```

Results are saved in `z_artifacts/outputs/testing/computational_overhead/[date_time]/`.
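For context on what such profiling typically measures (parameter counts and inference latency, cf. the FLOPs/time/parameter columns in the sample `computational_overhead.csv`), a generic sketch is given below. It is not the repo's profiler; the model and input shape are placeholders.

```python
# Generic profiling sketch, not the repo's profiler: parameter count and
# average wall-clock inference time for an arbitrary model/input pair.
import time
import torch

def profile(model: torch.nn.Module, example: torch.Tensor, runs: int = 100):
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    with torch.no_grad():
        for _ in range(10):  # warm-up before timing
            model(example)
        start = time.perf_counter()
        for _ in range(runs):
            model(example)
        avg_ms = (time.perf_counter() - start) / runs * 1e3
    return n_params, avg_ms

params, ms = profile(torch.nn.Linear(300, 300), torch.randn(1, 300))
print(f"{params} parameters, {ms:.3f} ms per inference")
```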
For HPC clusters using SLURM array jobs (recommended), use scripts/testing.slurm or scripts/testing_template.sh with array size matching JOBS_PER_MODEL in src/testing/config.py.
For local execution:
```
python3 -m src.testing.prediction_performance.main --model [model_name]
```

Results are saved in `z_artifacts/outputs/testing/prediction_performance/[model_name]/[full_test/slice_i]/[date_time]/`.
Process all testing results with comprehensive analysis:
```
python3 -m src.testing.results.main
```

This performs three steps:
- Check completion status of testing models
- Gather and aggregate all results into CSV files
- Post-process results for scenario-wise distributions based on NMSE and SE metrics
Results saved in:
- `z_artifacts/outputs/testing/results/completion_reports/[date_time]/`
- `z_artifacts/outputs/testing/results/gather/[date_time]/`
- `z_artifacts/outputs/testing/results/analysis/[nmse/se]/[date_time]/`
Generate comprehensive visualizations (line plots, radar plots, violin plots, tables):
```
python3 -m src.testing.vis.main
```

Results are saved in `z_artifacts/outputs/testing/vis/[date_time]/[line/radar/violin/table]/`.
To better illustrate the usage of the framework, sample outputs are provided in the z_artifacts/ directory. These examples demonstrate the complete workflow from configuration to final visualization results.
- `config/cp/rnn/`: Sample configuration files for RNN model training
  - `fdd_rnn.yaml`: FDD scenario RNN configuration
  - `tdd_rnn.yaml`: TDD scenario RNN configuration
- `noise/noise_degree/`: Noise parameter calibration outputs
  - `decide_nd.json`: Calibrated noise degree mappings for different noise types
  - `snr.csv`: SNR measurement results across noise parameters
- `TDD/RNN/`: Sample training output for the RNN model in the TDD scenario
  - `config_copy.yaml`: Training configuration backup
  - `tb_logs/`: TensorBoard logs for training monitoring
The testing/ directory contains comprehensive evaluation results for both NP baseline and RNN models:
- `computational_overhead/`: Performance profiling results
  - `computational_overhead.csv`: FLOPs, inference time, and parameter counts
- `prediction_performance/`: Prediction accuracy results
  - `NP/full_test/`: NP baseline results obtained via local execution mode
  - `RNN/slice_*/`: RNN results obtained via SLURM job slices (20 parallel jobs)
- `results/`: Consolidated and analyzed testing data
  - `completion_reports/`: Testing completion status verification
  - `gather/`: Consolidated raw results from all models and slices
  - `analysis/`: Statistical analysis with rankings and distributions
- `vis/`: Comprehensive visualization suite
  - `line/`: Line plots showing performance across different conditions
    - `generalization/`: Out-of-distribution performance
    - `regular/`: In-distribution performance
    - `robustness/`: Performance under noise conditions
  - `radar/`: Multi-dimensional performance comparison
    - `combined_radar_fdd.pdf`: FDD scenario radar plot
    - `combined_radar_tdd.pdf`: TDD scenario radar plot
  - `table/`: Performance summary tables by channel model and delay spread
  - `violin/`: Distribution analysis across scenarios
The provided sample outputs demonstrate:
- Execution Modes: NP baseline uses local full_test mode while RNN uses distributed SLURM slices. The current testing framework supports both modes.
- Comprehensive Evaluation: Testing covers regular, robustness, and generalization scenarios.
- Multi-Metric Analysis: Both NMSE and spectral efficiency (SE) metrics are evaluated (a reference NMSE definition is sketched after this list).
- Rich Visualizations: Multiple plot types provide different perspectives on model performance.
- Scalable Framework: The structure supports easy extension to additional models and scenarios.
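For reference, NMSE here denotes the standard normalized mean squared error between predicted and true CSI; a minimal definition is sketched below. The repo's exact reduction over antennas, subcarriers, and slots may differ (see `src/testing/prediction_performance/`).

```python
# Reference NMSE definition (standard form); the repo's exact averaging
# over antennas, subcarriers, and slots may differ.
import torch

def nmse(h_pred: torch.Tensor, h_true: torch.Tensor) -> torch.Tensor:
    """NMSE = ||H_pred - H_true||^2 / ||H_true||^2, averaged over the batch."""
    dims = tuple(range(1, h_true.dim()))  # all non-batch dimensions
    err = (h_pred - h_true).abs().pow(2).sum(dim=dims)
    ref = h_true.abs().pow(2).sum(dim=dims)
    return (err / ref).mean()
```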
For more comprehensive results and detailed analysis, please refer to the corresponding research paper.
If you use this framework in your research, please cite the corresponding paper:
```bibtex
@misc{cheng2025csi4casthybriddeeplearning,
      title={CSI-4CAST: A Hybrid Deep Learning Model for CSI Prediction with Comprehensive Robustness and Generalization Testing},
      author={Sikai Cheng and Reza Zandehshahvar and Haoruo Zhao and Daniel A. Garcia-Ulloa and Alejandro Villena-Rodriguez and Carles Navarro Manchón and Pascal Van Hentenryck},
      year={2025},
      eprint={2510.12996},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2510.12996},
}
```

This project is licensed under the terms specified in the LICENSE file.