This project implements a time series forecasting system for stock price prediction using a Segment-based Recurrent Neural Network (SegRNN) architecture. The system includes data preprocessing, model training, evaluation, and a web-based dashboard for interactive visualization of forecasts.
- Overview
- System Architecture
- Project Structure
- Model Architecture
- Installation
- Usage
- Data Format
- Training Process
- Prediction
- Performance Analysis
- Customization
- Troubleshooting
Stock Price Forecasting with SegRNN is designed to predict future stock prices based on historical data using advanced neural network techniques. The project leverages a novel segment-based approach to time series forecasting that divides input sequences into smaller segments for more effective temporal processing.
❗ Note: This project currently runs only on Linux systems.
Key Features:
- Interactive web dashboard for stock selection and prediction visualization
- Support for multiple stocks (currently 30+ stocks from the NIFTY50 index)
- Flexible training parameters and configuration
- Real-time prediction and visualization
- Comparative analysis with historical data
The codebase is organized as follows:
.
├── checkpoints/ # Saved model checkpoints
├── data_provider/ # Data loading and preprocessing
├── dataset/ # Stock data CSV files
├── exp/ # Experiment setup and execution
├── layers/ # Custom neural network layers
├── logs/ # Training and execution logs
├── models/ # Model implementations (including SegRNN)
├── results/ # Prediction results and visualizations
├── scripts/ # Automation scripts for training and evaluation
├── test_results/ # Test evaluation results
├── utils/ # Utility functions for metrics, tools, etc.
├── webapp/ # Flask web application
├── run_longExp.py # Main experiment execution script
└── README.md # This documentation
SegRNN is designed specifically for time series forecasting with the following key components:
- Segment-based Processing (see the sketch at the end of this section):
- Divides input time series into fixed-length segments
- Processes segments rather than individual time steps
- Reduces sequence length for more efficient training
- Network Architecture:
- Value Embedding Layer: Transforms segments to embedding space
- RNN Backbone: Processes embedded sequences (supports RNN, GRU, or LSTM)
- Decoding Mechanism: Two methods available:
- RMF (Recurrent Multi-step Forecasting): Autoregressive decoding
- PMF (Parallel Multi-step Forecasting): Single-step decoding with position encoding
- Hyperparameters:
- seq_len: Input sequence length (historical data window)
- pred_len: Prediction length (forecast horizon)
- seg_len: Segment length for dividing sequences
- d_model: Model dimension for embeddings
- rnn_type: Type of RNN cell (rnn, gru, lstm)
- dec_way: Decoding method (rmf, pmf)
- channel_id: Enable/disable channel position encoding
- Efficiency: Processing segments reduces computational complexity
- Long-Term Dependencies: Better capture of long-range patterns
- Scalability: Works well with different sequence lengths and prediction horizons
- Flexibility: Adaptable to various time series characteristics
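To make the segment-based processing concrete, here is a minimal, self-contained sketch of the core idea. It is not the project's actual SegRNN implementation; the layer names and dimensions are illustrative only.

```python
# Minimal sketch of segment-based processing (illustrative only, not the
# project's SegRNN code). A window of seq_len values is split into
# seq_len // seg_len segments, each segment is embedded, and a GRU processes
# the much shorter sequence of segment embeddings.
import torch
import torch.nn as nn

seq_len, seg_len, d_model = 720, 48, 512
num_segments = seq_len // seg_len                # 15 RNN steps instead of 720

value_embedding = nn.Linear(seg_len, d_model)    # one segment -> one embedding
rnn = nn.GRU(input_size=d_model, hidden_size=d_model, batch_first=True)

x = torch.randn(32, seq_len)                     # (batch, seq_len) toy input
segments = x.reshape(32, num_segments, seg_len)  # (batch, num_segments, seg_len)
embedded = value_embedding(segments)             # (batch, num_segments, d_model)
_, hidden = rnn(embedded)                        # summary of the whole window
print(hidden.shape)                              # torch.Size([1, 32, 512])
```

A decoder (RMF or PMF) then maps this hidden state to the pred_len forecast; with PMF, all forecast segments are produced in a single parallel step using position encodings.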
- Clone the repository:
git clone https://github.com/sagar7162/Stock-Market-Predictions-using-SegRNN.git
cd Stock-Market-Predictions-using-SegRNN
- Install dependencies:
pip install -r requirements.txt
Key dependencies include:
- PyTorch
- Flask
- pandas
- numpy
- plotly
- matplotlib
- Prepare your environment:
- Set the environment variable to resolve MKL threading issues:
export MKL_THREADING_LAYER=GNU
- Start the web server:
python webapp/main.py
- Access the dashboard at http://127.0.0.1:5000/
- Using the dashboard:
- Select a stock from the dropdown menu
- Click "Train Model" to train the SegRNN model for the selected stock
- After training completes, select a target date from the prediction range
- Click "Show Prediction" to see the forecast
For batch processing or automated workflows, use the provided shell scripts:
- Train a model for a specific stock:
sh scripts/SegRNN/stock_predict.sh --stock STOCKNAME --is_training 1
- Generate predictions using a trained model:
sh scripts/SegRNN/stock_predict.sh --stock STOCKNAME --is_training 0
- Custom parameter configuration:
sh scripts/SegRNN/stock_predict.sh --stock STOCKNAME --pred_len 192 --seq_len 720 --seg_len 48
- Direct use of the Python script:
python run_longExp.py --is_training 1 --model SegRNN --data custom --data_path STOCKNAME.csv --seq_len 720 --pred_len 96 --model_id STOCKNAME_720_96
Two CSV files are already provided in the dataset/ directory for illustrative purposes. You are welcome to add more CSV files to this folder, provided they adhere to the correct format described below.
You can download the dataset from the following link:
https://www.kaggle.com/datasets/rohanrao/nifty50-stock-market-data
The system requires stock data in CSV format with the following columns:
- Date: Date of trading
- Open: Opening price
- High: Highest price during the day
- Low: Lowest price during the day
- Close: Closing price (target for prediction)
CSV files should be named as STOCKNAME.csv and placed in the dataset/ directory.
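Before adding a new file, it can help to sanity-check the columns. The snippet below is a small illustrative check; the file name is a placeholder.

```python
# Quick sanity check for a new stock CSV (illustrative; adjust the path).
import pandas as pd

required = ["Date", "Open", "High", "Low", "Close"]
df = pd.read_csv("dataset/STOCKNAME.csv", parse_dates=["Date"])

missing = [c for c in required if c not in df.columns]
if missing:
    raise ValueError(f"Missing required columns: {missing}")

df = df.sort_values("Date")   # forecasting assumes chronological order
print(df[required].head())
```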
The training process involves the following steps:
- Data Preprocessing (a sliding-window sketch follows this list):
- Loading data from CSV files
- Normalization using feature-wise scaling
- Splitting into training, validation, and test sets
- Creating sliding windows for sequence input
- Model Configuration:
- Setting up hyperparameters based on command-line arguments
- Initializing model architecture
- Setting up optimizer and loss function
- Training Loop:
- Epoch-based training with early stopping
- Validation after each epoch
- Learning rate adjustment based on validation performance
- Checkpoint saving for best models
- Evaluation:
- Testing on held-out test set
- Calculating metrics (MAE, MSE, RMSE)
- Generating visualizations of predictions
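The sliding-window step referenced above can be pictured with the following simplified sketch; the project's actual preprocessing lives in data_provider/ and may differ in its details.

```python
# Illustrative sliding-window construction (simplified; see data_provider/
# for the project's actual preprocessing).
import numpy as np

def make_windows(series, seq_len=720, pred_len=96):
    """Turn a 1-D price series into (input window, target window) pairs."""
    X, y = [], []
    for start in range(len(series) - seq_len - pred_len + 1):
        X.append(series[start : start + seq_len])
        y.append(series[start + seq_len : start + seq_len + pred_len])
    return np.array(X), np.array(y)

prices = np.random.rand(2000)            # stand-in for a Close-price column
mean, std = prices.mean(), prices.std()
normalized = (prices - mean) / std       # feature-wise scaling
X, y = make_windows(normalized)
print(X.shape, y.shape)                  # (1185, 720) (1185, 96)
```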
The prediction process generates forecasts for future time points beyond the available data:
- Loading Trained Model:
- Retrieving the best checkpoint
- Initializing model with saved weights
- Generating Predictions (illustrated in the sketch after this list):
- Using the latest available data window as input
- Running forward pass through the model
- De-normalizing outputs to get actual price values
- Saving Results:
- Storing predictions in CSV format
- Generating visualization plots
- Making results available through the web dashboard
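The sketch below illustrates the prediction flow in simplified form; the real pipeline is handled by run_longExp.py and the exp/ code, and the model and scaling statistics here are toy stand-ins.

```python
# Simplified prediction flow (illustrative only; the real pipeline is in
# run_longExp.py / exp/). The linear model stands in for a trained SegRNN.
import torch
import torch.nn as nn

seq_len, pred_len = 720, 96
model = nn.Linear(seq_len, pred_len)             # stand-in for the trained model
# In the real pipeline the best checkpoint from checkpoints/ is loaded here.

mean, std = 1500.0, 120.0                        # toy scaling stats from training
recent = torch.randn(seq_len) * std + mean       # latest seq_len closing prices

with torch.no_grad():
    normalized = (recent - mean) / std           # same scaling as training
    forecast = model(normalized.unsqueeze(0)).squeeze(0)
    prices = forecast * std + mean               # de-normalize to price scale

print(prices.shape)                              # torch.Size([96])
```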
The model's performance is evaluated using several metrics:
- MAE (Mean Absolute Error): Average absolute difference between predictions and actual values
- MSE (Mean Squared Error): Average squared difference, penalizing large errors more heavily
- RMSE (Root Mean Squared Error): Square root of MSE, interpretable in the original data scale
Results from various experiments show that SegRNN generally outperforms traditional methods and other deep learning approaches, particularly for longer prediction horizons.
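As a quick worked example of these metrics with toy numbers:

```python
# Toy example of the evaluation metrics (values are arbitrary).
import numpy as np

actual    = np.array([100.0, 102.0, 105.0, 103.0])
predicted = np.array([101.0, 101.0, 106.0, 100.0])

mae  = np.mean(np.abs(predicted - actual))   # 1.5
mse  = np.mean((predicted - actual) ** 2)    # 3.0
rmse = np.sqrt(mse)                          # ~1.73, in price units
print(mae, mse, rmse)
```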
The system is highly customizable:
- Model Parameters: Adjust in the script or through command-line arguments (an example command combining these flags appears at the end of this section):
- Sequence length (--seq_len)
- Prediction length (--pred_len)
- Segment length (--seg_len)
- Model dimension (--d_model)
- RNN type (--rnn_type)
- Decoding way (--dec_way)
- Training Parameters:
- Batch size (--batch_size)
- Learning rate (--learning_rate)
- Training epochs (--train_epochs)
- Patience for early stopping (--patience)
- Feature Selection:
- Choose which columns to use as features (--features)
- Select the target variable (--target)
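As an illustrative example (the specific values are arbitrary, not recommended defaults), a fully customized training run might look like:
python run_longExp.py --is_training 1 --model SegRNN --data custom --data_path STOCKNAME.csv --model_id STOCKNAME_720_96 --seq_len 720 --pred_len 96 --seg_len 48 --d_model 512 --rnn_type gru --dec_way pmf --batch_size 16 --learning_rate 0.001 --train_epochs 30 --patience 5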
Common issues and solutions:
- CUDA Out of Memory:
- Reduce batch size
- Decrease sequence length
- Lower model dimension
- Poor Prediction Quality:
- Try different hyperparameter combinations
- Ensure sufficient training data
- Check for data quality issues
- Experiment with different RNN types (rnn, gru, lstm)
- Training Too Slow:
- Reduce sequence length or segment length
- Use GPU acceleration if available
- Decrease model complexity (e.g., lower --d_model)
- Web Dashboard Issues:
- Check logs for error messages
- Ensure model is properly trained before prediction
- Verify file permissions for results directory
For any other issues, please check the logs in the logs/ directory for detailed error messages.
Acknowledgments:
- This project builds upon research in time series forecasting and deep learning
- Utilizes PyTorch for model implementation and training
- Flask for web dashboard development
