Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,5 @@ coverage.xml
*.log
/weights/
!/weights/.gitkeep
CLAUDE.md
temp/
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ If you find tfts project useful in your research, please consider cite:
```
@misc{tfts2020,
author = {Longxing Tan},
title = {Time series prediction},
title = {TFTS: Time series prediction},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
Expand Down
2 changes: 1 addition & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def build_model():
```
@misc{tfts2020,
author = {Longxing Tan},
title = {Time series prediction},
title = {TFTS: Time series prediction},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
Expand Down
15 changes: 1 addition & 14 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ TFTS: TensorFlow Time Series

<a class="github-button" href="https://github.com/LongxingTan/Time-series-prediction" data-icon="octicon-star" data-size="large" data-show-count="true" aria-label="Star LongxingTan/Time-series-prediction on GitHub">GitHub</a>

Welcome to TFTS (TensorFlow Time Series), a comprehensive Python library for state-of-the-art deep learning time series analysis. TFTS provides production-ready implementations of cutting-edge models for forecasting, classification, and anomaly detection tasks.
Welcome to TFTS (TensorFlow Time Series), a Python library for state-of-the-art deep learning time series analysis. TFTS provides production-ready implementations of cutting-edge models for forecasting, classification, and anomaly detection tasks.

.. image:: https://img.shields.io/badge/License-MIT-blue.svg
:target: https://opensource.org/licenses/MIT
Expand Down Expand Up @@ -341,11 +341,6 @@ Community and Support
**Contributing**
We welcome contributions! See our `Contributing Guide <https://github.com/LongxingTan/Time-series-prediction/blob/master/CONTRIBUTING.md>`_ for details.

**Stay Updated**
- ⭐ Star the `GitHub repository <https://github.com/LongxingTan/Time-series-prediction>`_
- 📰 Check the `changelog <./CHANGELOG.md>`_ for latest updates
- 🐦 Follow updates on social media


Citation
--------
Expand All @@ -368,11 +363,3 @@ License
-------

TFTS is released under the MIT License. See `LICENSE <https://github.com/LongxingTan/Time-series-prediction/blob/master/LICENSE>`_ for details.


Indices and Tables
------------------

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
1 change: 0 additions & 1 deletion docs/source/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,5 @@ Getting Help

If you encounter installation issues:

- 📖 Check the `FAQ <./faq.html>`_
- 💬 Ask in `GitHub Discussions <https://github.com/LongxingTan/Time-series-prediction/discussions>`_
- 🐛 Report bugs in `GitHub Issues <https://github.com/LongxingTan/Time-series-prediction/issues>`_
169 changes: 168 additions & 1 deletion tests/test_data/test_get_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import unittest
from unittest.mock import MagicMock, patch

from tfts.data.get_data import get_air_passengers, get_data, get_sine
import numpy as np
import pandas as pd

from tfts.data.get_data import get_air_passengers, get_ar_data, get_data, get_sine, get_stock_data


class GetDataTest(unittest.TestCase):
Expand Down Expand Up @@ -35,3 +39,166 @@ def test_get_air_passenger_data(self):
self.assertEqual(train[1].shape[1:], (predict_sequence_length, 1))
self.assertEqual(valid[0].shape[1:], (train_length, 1))
self.assertEqual(valid[1].shape[1:], (predict_sequence_length, 1))

def test_get_sine_no_test_split(self):
"""Test get_sine with test_size=0 returns single tuple"""
train_length = 10
predict_sequence_length = 4
n_examples = 50
x, y = get_sine(train_length, predict_sequence_length, test_size=0, n_examples=n_examples)
self.assertEqual(x.shape, (n_examples, train_length, 1))
self.assertEqual(y.shape, (n_examples, predict_sequence_length, 1))
self.assertIsInstance(x, np.ndarray)
self.assertIsInstance(y, np.ndarray)

def test_get_air_passengers_no_test_split(self):
"""Test get_air_passengers with test_size=0"""
train_length = 10
predict_sequence_length = 4
x, y = get_air_passengers(train_length, predict_sequence_length, test_size=0)
self.assertEqual(x.shape[1:], (train_length, 1))
self.assertEqual(y.shape[1:], (predict_sequence_length, 1))

def test_get_data_invalid_name(self):
"""Test get_data raises ValueError for unsupported dataset"""
with self.assertRaises(ValueError) as context:
get_data("invalid_dataset", 10, 4, 0.2)
self.assertIn("unsupported data", str(context.exception))

def test_get_data_test_size_validation(self):
"""Test get_data validates test_size parameter"""
with self.assertRaises(AssertionError):
get_data("sine", 10, 4, test_size=-0.1)

with self.assertRaises(AssertionError):
get_data("sine", 10, 4, test_size=1.5)

def test_get_data_airpassengers(self):
"""Test get_data dispatcher for airpassengers dataset"""
train_length = 12
predict_length = 6
train, valid = get_data("airpassengers", train_length, predict_length, test_size=0.15)
self.assertIsNotNone(train)
self.assertIsNotNone(valid)
self.assertEqual(len(train), 2)
self.assertEqual(len(valid), 2)

def test_get_sine_data_values_in_range(self):
"""Test that sine wave values are in expected range"""
train_length = 20
predict_length = 5
x, y = get_sine(train_length, predict_length, test_size=0, n_examples=10)

# Sine values should be roughly in [-1, 1] range
self.assertTrue(np.all(x >= -1.5))
self.assertTrue(np.all(x <= 1.5))
self.assertTrue(np.all(y >= -1.5))
self.assertTrue(np.all(y <= 1.5))

def test_get_ar_data_basic(self):
"""Test basic AR data generation"""
df = get_ar_data(n_series=5, timesteps=100)

self.assertIsInstance(df, pd.DataFrame)
self.assertIn("series", df.columns)
self.assertIn("time_idx", df.columns)
self.assertIn("value", df.columns)
self.assertEqual(len(df), 5 * 100) # n_series * timesteps

def test_get_ar_data_with_covariates(self):
"""Test AR data generation with covariates"""
df = get_ar_data(n_series=3, timesteps=50, add_covariates=True)

self.assertIn("day_of_week", df.columns)
self.assertIn("month", df.columns)
self.assertIn("category", df.columns)
self.assertIn("special_event", df.columns)

# Check value ranges
self.assertTrue(df["day_of_week"].between(0, 6).all())
self.assertTrue(df["month"].between(1, 13).all())
self.assertTrue(df["special_event"].isin([0, 1]).all())

def test_get_ar_data_with_components(self):
"""Test AR data generation returning components"""
df, components = get_ar_data(n_series=3, timesteps=50, return_components=True)

self.assertIsInstance(components, dict)
self.assertIn("linear_trends", components)
self.assertIn("quadratic_trends", components)
self.assertIn("seasonalities", components)
self.assertIn("levels", components)
self.assertIn("series", components)

def test_get_ar_data_exponential(self):
"""Test AR data with exponential transformation"""
df = get_ar_data(n_series=2, timesteps=50, exp=True)

# Exponential values should all be positive
self.assertTrue((df["value"] > 0).all())

def test_get_ar_data_seeded_reproducibility(self):
"""Test that same seed produces same data"""
df1 = get_ar_data(n_series=3, timesteps=50, seed=42)
df2 = get_ar_data(n_series=3, timesteps=50, seed=42)

pd.testing.assert_frame_equal(df1, df2)

def test_get_ar_data_different_seeds(self):
"""Test that different seeds produce different data"""
df1 = get_ar_data(n_series=3, timesteps=50, seed=42)
df2 = get_ar_data(n_series=3, timesteps=50, seed=123)

# Values should be different
self.assertFalse(df1["value"].equals(df2["value"]))

def test_get_ar_data_invalid_params(self):
"""Test AR data validation for invalid parameters"""
with self.assertRaises(ValueError):
get_ar_data(n_series=0, timesteps=100)

with self.assertRaises(ValueError):
get_ar_data(n_series=5, timesteps=-10)

with self.assertRaises(ValueError):
get_ar_data(n_series=5, timesteps=100, noise=-0.5)

def test_get_ar_data_parameter_effects(self):
"""Test that parameters affect data as expected"""
# High noise should create more variance
df_low_noise = get_ar_data(n_series=5, timesteps=100, noise=0.01, seed=42)
df_high_noise = get_ar_data(n_series=5, timesteps=100, noise=1.0, seed=42)

# Not directly comparing variance due to random effects,
# but shapes should match
self.assertEqual(len(df_low_noise), len(df_high_noise))

def test_get_data_ar_dispatch(self):
"""Test get_data dispatcher for AR data"""
result = get_data("ar", train_length=10, predict_sequence_length=5, test_size=0, n_series=3, timesteps=50)

self.assertIsInstance(result, pd.DataFrame)
self.assertEqual(len(result), 3 * 50)

def test_sine_data_sequence_continuity(self):
"""Test that sine data maintains temporal continuity"""
train_length = 10
predict_length = 5
x, y = get_sine(train_length, predict_length, test_size=0, n_examples=1)

# x and y should form a continuous sequence
# This is a shape test since exact continuity depends on implementation
self.assertEqual(x.shape[1], train_length)
self.assertEqual(y.shape[1], predict_length)

def test_air_passengers_normalization(self):
"""Test that air passengers data is properly normalized"""
x, y = get_air_passengers(10, 4, test_size=0)

# Values should be normalized (roughly between -1 and 1 after normalization)
self.assertTrue(np.all(x >= -1.5))
self.assertTrue(np.all(x <= 1.5))


if __name__ == "__main__":
unittest.main()
Loading
Loading