Skip to content

Commit e12b508

Browse files
committed
Provide a cleaner interface.
1 parent 7867ad1 commit e12b508

File tree

5 files changed

+90
-26
lines changed

5 files changed

+90
-26
lines changed

examples/main.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torchqtm.utils.universe import StaticUniverse, IndexComponents
1212
from torchqtm.utils.warnings import catch_warnings
1313
from torchqtm.utils.benchmark import BenchMark
14-
from torchqtm.vbt.backtest import GroupTester01
14+
from torchqtm.vbt.backtest import GroupTester01, GroupTester02
1515
from torchqtm.alphas.alpha101 import *
1616
import torchqtm.op as op
1717
import torchqtm.op.functional as F
@@ -34,7 +34,6 @@ def __init__(self, env):
3434

3535
def forward(self):
3636
self.data = F.divide(1, self.env.PE)
37-
self.data = self.data.astype(np.float64)
3837
self.data = F.winsorize(self.data, 'std', 4)
3938
self.data = F.normalize(self.data)
4039
self.data = F.group_neutralize(self.data, self.env.Sector)
@@ -103,8 +102,8 @@ def aux_func(data_slice):
103102
self.data = F.winsorize(self.data, 'std', 4)
104103
self.data = F.normalize(self.data)
105104
self.data = pd.DataFrame(self.data, index=self.env.Close.index, columns=self.env.Close.columns)
106-
cond = F.geq(F.ts_mean(np.squeeze(Close, -1), 5), F.ts_mean(np.squeeze(Close, -1), 22))
107-
self.data = F.trade_when(cond, self.data, False)
105+
cond = F.geq(F.ts_mean(self.close, 5), F.ts_mean(self.close, 22))
106+
# self.data = F.trade_when(cond, self.data, False)
108107
self.data = F.group_neutralize(self.data, self.env.Sector)
109108
self.data = F.regression_neut(self.data, self.env.MktVal)
110109
return self.data
@@ -134,15 +133,18 @@ def load_data():
134133
# Create alpha
135134
# alphas = Momentum01(env=btEnv0)
136135
# alphas = NeutralizePE(env=btEnv0)
137-
alphas = Alpha060(env=btEnv0)
136+
alphas = Alpha055(env=btEnv0)
138137
# alphas = Ross(env=btEnv0)
139138
# alphas.forward(btEnv.match_env(dfs['PE']))
140139
with Timer():
141140
with catch_warnings():
142141
alphas.forward()
143142
# run backtest
144-
bt = GroupTester01(env=btEnv,
145-
n_groups=5)
143+
bt = GroupTester02(env=btEnv,
144+
n_groups=5,
145+
weighting='equal',
146+
exclude_suspended=False,
147+
exclude_limits=False)
146148

147149
with Timer():
148150
bt.run_backtest(bt.env.match_env(F.purify(alphas.data)))

quant/backtest.py

Whitespace-only changes.

torchqtm/base.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Dict, Hashable
66
from abc import ABCMeta, abstractmethod
77
from typing import Iterable
8+
from collections import OrderedDict
89

910

1011
class BackTestEnv(object):
@@ -39,9 +40,10 @@ def __init__(self,
3940
self.MktVal = None
4041
self.PE = None
4142
self.Sector = None
42-
self._FutureReturn = None
43+
self.forward_returns = None
4344
self._create_datas()
4445
self._create_features()
46+
self._create_forward_returns()
4547

4648
def _check_dfs(self):
4749
"""
@@ -53,30 +55,32 @@ def _check_dfs(self):
5355
assert 'Sector' in self.dfs
5456

5557
def _create_datas(self):
56-
self.datas = {}
58+
self.data = {}
5759
for key in self.dfs:
5860
if isinstance(self.dfs[key], pd.DataFrame):
59-
self.datas[key] = self.dfs[key].loc[self.dates, self.symbols]
61+
self.data[key] = self.dfs[key].loc[self.dates, self.symbols]
6062
else:
61-
self.datas[key] = self.dfs[key]
62-
self.datas['_FutureReturn'] = self.datas['Close'].pct_change().shift(-1)
63+
self.data[key] = self.dfs[key]
64+
65+
def _create_forward_returns(self, D=1):
66+
forward_returns = self.data['Close'].pct_change().shift(-1)
67+
setattr(self, 'forward_returns', forward_returns)
6368

6469
def _create_features(self):
6570
"""
6671
Create the reference to the dict values
6772
:return:
6873
"""
69-
for key in self.datas.keys():
70-
setattr(self, key, self.datas[key])
71-
setattr(self, '_FutureReturn', self.datas['_FutureReturn'])
74+
for key in self.data.keys():
75+
setattr(self, key, self.data[key])
7276

7377
def __getitem__(self, item):
7478
"""
7579
Keep the operator[]
7680
:param item:
7781
:return:
7882
"""
79-
return self.datas[item]
83+
return self.data[item]
8084

8185
def __setitem__(self, item, value):
8286
"""
@@ -85,15 +89,15 @@ def __setitem__(self, item, value):
8589
:return:
8690
"""
8791
assert isinstance(value, pd.DataFrame)
88-
self.datas[item] = value
92+
self.data[item] = value
8993
setattr(self, item, value)
9094

9195
def __delitem__(self, item):
92-
del self.datas[item]
96+
del self.data[item]
9397
delattr(self, item)
9498

9599
def __contains__(self, item):
96-
return item in self.datas
100+
return item in self.data
97101

98102
def match_env(self, factor):
99103
return factor.loc[self.dates, self.symbols]

torchqtm/op/functional.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,8 @@ def tail(x, lower=0, upper=0.5, newval=np.nan):
812812
return if_else(cond, x, np.nan)
813813

814814

815-
# TODO: 保持变量类型封闭
815+
# TODO: 保持变量类型封闭, 目前有一堆BUG
816+
# BUG:
816817
def trade_when(trigger: np.ndarray[bool],
817818
alpha: np.ndarray[np.float64],
818819
exit_cond: np.ndarray[bool]):

torchqtm/vbt/backtest.py

+63-6
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,12 @@ class BaseGroupTester(BaseTester, TesterMixin):
4747
def __init__(self,
4848
env: BackTestEnv = None,
4949
n_groups: int = 5,
50+
weighting: str = 'equal',
5051
exclude_suspended: bool = False,
5152
exclude_limits: bool = False):
5253
super().__init__(env)
5354
self.n_groups = n_groups
55+
self.weighting = weighting
5456
self.exclude_limits = exclude_limits
5557
self.exclude_suspended = exclude_suspended
5658
self.returns = None
@@ -105,9 +107,12 @@ def plot(self):
105107

106108
class GroupTester01(BaseGroupTester):
107109
def __init__(self,
108-
env: BackTestEnv,
109-
n_groups: int = 5):
110-
super().__init__(env, n_groups)
110+
env: BackTestEnv = None,
111+
n_groups: int = 5,
112+
weighting: str = 'equal',
113+
exclude_suspended: bool = False,
114+
exclude_limits: bool = False):
115+
super().__init__(env, n_groups, weighting, exclude_suspended, exclude_limits)
111116

112117
def run_backtest(self, modified_factor) -> None:
113118
assert modified_factor.shape == self.env['Close'].shape
@@ -133,9 +138,12 @@ def run_backtest(self, modified_factor) -> None:
133138

134139
def temp(x):
135140
# TODO: develop a weight_scheme class
136-
weight = x['MktVal'] / x['MktVal'].sum()
137-
# weight = 1 / len(x['MktVal'])
138-
# weights.append(weight)
141+
if self.weighting == 'equal':
142+
weight = 1 / len(x['MktVal'])
143+
elif self.weighting == 'market_cap':
144+
weight = x['MktVal'] / x['MktVal'].sum()
145+
else:
146+
raise ValueError('Invalid weight scheme')
139147
ret = x['_FutureReturn']
140148
return (weight * ret).sum()
141149
group_return = temp_data.groupby('group').apply(temp)
@@ -148,6 +156,55 @@ def temp(x):
148156
self.returns.columns.name = "group"
149157

150158

159+
class GroupTester02(BaseGroupTester):
160+
def __init__(self,
161+
env: BackTestEnv = None,
162+
n_groups: int = 5,
163+
weighting: str = 'equal',
164+
exclude_suspended: bool = False,
165+
exclude_limits: bool = False):
166+
super().__init__(env, n_groups, weighting, exclude_suspended, exclude_limits)
167+
168+
def run_backtest(self, modified_factor) -> None:
169+
assert modified_factor.shape == self.env['Close'].shape
170+
self._reset()
171+
labels = ["group_" + str(i + 1) for i in range(self.n_groups)]
172+
returns = []
173+
for i in range(len(modified_factor)-1):
174+
# If you are confused about concat series, you apply use the following way
175+
# 1. series.unsqueeze(1) to generate an additional axes
176+
# 2. concat these series along axis1
177+
temp_data = pd.concat([self.env.forward_returns.iloc[i],
178+
self.env.MktVal.iloc[i],
179+
modified_factor.iloc[i]], axis=1)
180+
temp_data.columns = ['forward_returns', 'MktVal', 'modified_factor']
181+
# na stands for stocks that we you not insterested in
182+
# We can develop a class to better represent this process.
183+
temp_data = temp_data.loc[~np.isnan(temp_data['modified_factor'])]
184+
if len(temp_data) == 0:
185+
group_return = pd.Series(0, index=labels)
186+
else:
187+
temp_data['group'] = pd.qcut(temp_data['modified_factor'], self.n_groups, labels=labels)
188+
189+
def temp(x):
190+
# TODO: develop a weight_scheme class
191+
if self.weighting == 'equal':
192+
weight = 1 / len(x['MktVal'])
193+
elif self.weighting == 'market_cap':
194+
weight = x['MktVal'] / x['MktVal'].sum()
195+
else:
196+
raise ValueError('Invalid weight scheme')
197+
ret = x['forward_returns']
198+
return (weight * ret).sum()
199+
group_return = temp_data.groupby('group').apply(temp)
200+
returns.append(group_return)
201+
returns.append(pd.Series(np.repeat(0, self.n_groups), index=group_return.index))
202+
self.returns = pd.concat(returns, axis=1).T
203+
# Here we need to transpose the return, since the rows are stocks.
204+
self.returns.index = self.rebalance_dates
205+
self.returns.index.name = "trade_date"
206+
self.returns.columns.name = "group"
207+
151208
# class QuickBackTesting02(BaseTester):
152209
# def __init__(self,
153210
# env: BackTestEnv = None,

0 commit comments

Comments
 (0)