@@ -47,10 +47,12 @@ class BaseGroupTester(BaseTester, TesterMixin):
47
47
def __init__ (self ,
48
48
env : BackTestEnv = None ,
49
49
n_groups : int = 5 ,
50
+ weighting : str = 'equal' ,
50
51
exclude_suspended : bool = False ,
51
52
exclude_limits : bool = False ):
52
53
super ().__init__ (env )
53
54
self .n_groups = n_groups
55
+ self .weighting = weighting
54
56
self .exclude_limits = exclude_limits
55
57
self .exclude_suspended = exclude_suspended
56
58
self .returns = None
@@ -105,9 +107,12 @@ def plot(self):
105
107
106
108
class GroupTester01 (BaseGroupTester ):
107
109
def __init__ (self ,
108
- env : BackTestEnv ,
109
- n_groups : int = 5 ):
110
- super ().__init__ (env , n_groups )
110
+ env : BackTestEnv = None ,
111
+ n_groups : int = 5 ,
112
+ weighting : str = 'equal' ,
113
+ exclude_suspended : bool = False ,
114
+ exclude_limits : bool = False ):
115
+ super ().__init__ (env , n_groups , weighting , exclude_suspended , exclude_limits )
111
116
112
117
def run_backtest (self , modified_factor ) -> None :
113
118
assert modified_factor .shape == self .env ['Close' ].shape
@@ -133,9 +138,12 @@ def run_backtest(self, modified_factor) -> None:
133
138
134
139
def temp (x ):
135
140
# TODO: develop a weight_scheme class
136
- weight = x ['MktVal' ] / x ['MktVal' ].sum ()
137
- # weight = 1 / len(x['MktVal'])
138
- # weights.append(weight)
141
+ if self .weighting == 'equal' :
142
+ weight = 1 / len (x ['MktVal' ])
143
+ elif self .weighting == 'market_cap' :
144
+ weight = x ['MktVal' ] / x ['MktVal' ].sum ()
145
+ else :
146
+ raise ValueError ('Invalid weight scheme' )
139
147
ret = x ['_FutureReturn' ]
140
148
return (weight * ret ).sum ()
141
149
group_return = temp_data .groupby ('group' ).apply (temp )
@@ -148,6 +156,55 @@ def temp(x):
148
156
self .returns .columns .name = "group"
149
157
150
158
159
+ class GroupTester02 (BaseGroupTester ):
160
+ def __init__ (self ,
161
+ env : BackTestEnv = None ,
162
+ n_groups : int = 5 ,
163
+ weighting : str = 'equal' ,
164
+ exclude_suspended : bool = False ,
165
+ exclude_limits : bool = False ):
166
+ super ().__init__ (env , n_groups , weighting , exclude_suspended , exclude_limits )
167
+
168
+ def run_backtest (self , modified_factor ) -> None :
169
+ assert modified_factor .shape == self .env ['Close' ].shape
170
+ self ._reset ()
171
+ labels = ["group_" + str (i + 1 ) for i in range (self .n_groups )]
172
+ returns = []
173
+ for i in range (len (modified_factor )- 1 ):
174
+ # If you are confused about concat series, you apply use the following way
175
+ # 1. series.unsqueeze(1) to generate an additional axes
176
+ # 2. concat these series along axis1
177
+ temp_data = pd .concat ([self .env .forward_returns .iloc [i ],
178
+ self .env .MktVal .iloc [i ],
179
+ modified_factor .iloc [i ]], axis = 1 )
180
+ temp_data .columns = ['forward_returns' , 'MktVal' , 'modified_factor' ]
181
+ # na stands for stocks that we you not insterested in
182
+ # We can develop a class to better represent this process.
183
+ temp_data = temp_data .loc [~ np .isnan (temp_data ['modified_factor' ])]
184
+ if len (temp_data ) == 0 :
185
+ group_return = pd .Series (0 , index = labels )
186
+ else :
187
+ temp_data ['group' ] = pd .qcut (temp_data ['modified_factor' ], self .n_groups , labels = labels )
188
+
189
+ def temp (x ):
190
+ # TODO: develop a weight_scheme class
191
+ if self .weighting == 'equal' :
192
+ weight = 1 / len (x ['MktVal' ])
193
+ elif self .weighting == 'market_cap' :
194
+ weight = x ['MktVal' ] / x ['MktVal' ].sum ()
195
+ else :
196
+ raise ValueError ('Invalid weight scheme' )
197
+ ret = x ['forward_returns' ]
198
+ return (weight * ret ).sum ()
199
+ group_return = temp_data .groupby ('group' ).apply (temp )
200
+ returns .append (group_return )
201
+ returns .append (pd .Series (np .repeat (0 , self .n_groups ), index = group_return .index ))
202
+ self .returns = pd .concat (returns , axis = 1 ).T
203
+ # Here we need to transpose the return, since the rows are stocks.
204
+ self .returns .index = self .rebalance_dates
205
+ self .returns .index .name = "trade_date"
206
+ self .returns .columns .name = "group"
207
+
151
208
# class QuickBackTesting02(BaseTester):
152
209
# def __init__(self,
153
210
# env: BackTestEnv = None,
0 commit comments