Skip to content

Commit 0e1448b

Browse files
committed
add tests for MetaEpiModel
1 parent 9d63e16 commit 0e1448b

File tree

3 files changed

+61
-53
lines changed

3 files changed

+61
-53
lines changed

src/epidemik/MetaEpiModel.py

+21-52
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class MetaEpiModel:
2222
2323
Provides a way to implement and numerically integrate
2424
"""
25-
def __init__(self, travel_graph, populations):
25+
def __init__(self, travel_graph, populations, population='Population'):
2626
"""
2727
Initialize the EpiModel object
2828
@@ -35,6 +35,7 @@ def __init__(self, travel_graph, populations):
3535
"""
3636
self.travel_graph = travel_graph
3737
self.populations = populations
38+
self.population = population
3839

3940
models = {}
4041

@@ -122,6 +123,10 @@ def add_vaccination(self, source, target, rate, start):
122123
for state in self.models:
123124
self.models[state].add_vaccination(source, target, rate, start)
124125

126+
def R0(self):
127+
key = list(self.models.keys())[0]
128+
return self.models[key].R0()
129+
125130
def get_state(self, state):
126131
"""
127132
Return a reference to a state EpiModel object
@@ -133,14 +138,16 @@ def get_state(self, state):
133138

134139
return self.models[state]
135140

136-
def _initialize_populations(self, susceptible):
141+
def _initialize_populations(self, susceptible, population=None):
137142
columns = list(self.transitions.nodes())
138143
self.compartments_ = pd.DataFrame(np.zeros((self.travel_graph.shape[0], len(columns)), dtype='int'), columns=columns)
139144
self.compartments_.index = self.populations.index
140145

141-
susceptible = list(self.prototype._get_susceptible())[0]
146+
if population is None:
147+
population = self.population
142148

143-
self.compartments_.loc[:, susceptible] = self.populations['Population']
149+
for state in self.compartments_.index:
150+
self.compartments_.loc[state, susceptible] = self.populations.loc[state, population]
144151

145152
def _run_travel(self, compartments_, travel):
146153
def travel_step(x, populations):
@@ -160,9 +167,9 @@ def travel_step(x, populations):
160167

161168
return new_compartments
162169

163-
def _run_spread(self, models, compartments_, seasonality):
170+
def _run_spread(self):
164171
for state in self.compartments_.index:
165-
pop = dict(self.compartments_.loc[state].to_dict())
172+
pop = self.compartments_.loc[state].to_dict()
166173
self.models[state].single_step(**pop)
167174
self.compartments_.loc[state] = self.models[state].values_.iloc[[-1]].values[0]
168175

@@ -180,9 +187,15 @@ def simulate(self, timestamp, t_min=1, seasonality=None, seed_state=None, suscep
180187
self.compartments_.loc[seed_state, susceptible] -= kwargs[comp]
181188

182189
for t in tqdm(range(t_min, timestamp+1), total=timestamp):
183-
self._run_spread(self.models, self.compartments_, self.seasonality)
190+
self._run_spread()
184191
self.compartments_ = self._run_travel(self.compartments_, self.travel_graph)
185192

193+
def integrate(self, **kwargs):
194+
raise NotImplementedError("MetaEpiModel doesn't support direct integration of the ODE")
195+
196+
def draw_model(self):
197+
return self.models.iloc[0].draw_model()
198+
186199
def plot(self, title=None, normed=True, layout=None, **kwargs):
187200
if layout is None:
188201
n_pop = self.travel_graph.shape[0]
@@ -288,48 +301,4 @@ def plot_peaks(self):
288301
ax.set_xticks(np.arange(0, peaks.shape[1], 3))
289302
ax.set_xticklabels(np.arange(0, peaks.shape[1], 3), fontsize=10)
290303
# ax.set_aspect(1)
291-
fig.patch.set_facecolor('#FFFFFF')
292-
293-
if __name__ == '__main__':
294-
295-
Nk_uk = pd.read_csv("data/United Kingdom-2020.csv", index_col=0)
296-
Nk_ke = pd.read_csv("data/Kenya-2020.csv", index_col=0)
297-
298-
contacts_uk = pd.read_excel("data/MUestimates_all_locations_2.xlsx", sheet_name="United Kingdom of Great Britain", header=None)
299-
contacts_ke = pd.read_excel("data/MUestimates_all_locations_1.xlsx", sheet_name="Kenya")
300-
301-
beta = 0.05
302-
mu = 0.1
303-
304-
SIR_uk = EpiModel()
305-
SIR_uk.add_interaction('S', 'I', 'I', beta)
306-
SIR_uk.add_spontaneous('I', 'R', mu)
307-
308-
309-
SIR_ke = EpiModel()
310-
SIR_ke.add_interaction('S', 'I', 'I', beta)
311-
SIR_ke.add_spontaneous('I', 'R', mu)
312-
313-
N_uk = int(Nk_uk.sum())
314-
N_ke = int(Nk_ke.sum())
315-
316-
317-
SIR_uk.add_age_structure(contacts_uk, Nk_uk)
318-
SIR_ke.add_age_structure(contacts_ke, Nk_ke)
319-
320-
SIR_uk.integrate(100, S=N_uk*.99, I=N_uk*.01, R=0)
321-
SIR_ke.integrate(100, S=N_ke*.99, I=N_ke*.01, R=0)
322-
323-
fig, ax = plt.subplots(1)
324-
SIR_uk.draw_model(ax)
325-
fig.savefig('SIR_model.png', dpi=300, facecolor='white')
326-
327-
fig, ax = plt.subplots(1)
328-
329-
(SIR_uk['I']*100/N_uk).plot(ax=ax)
330-
(SIR_ke['I']*100/N_ke).plot(ax=ax)
331-
ax.legend(['UK', 'Kenya'])
332-
ax.set_xlabel('Time')
333-
ax.set_ylabel('Population (%)')
334-
335-
fig.savefig('SIR_age.png', dpi=300, facecolor='white')
304+
fig.patch.set_facecolor('#FFFFFF')

src/epidemik/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
from .NetworkEpiModel import NetworkEpiModel
1313
from .MetaEpiModel import MetaEpiModel
1414

15-
__version__ = "0.0.18"
15+
__version__ = "0.0.19"

tests/tests_MetaEpiModel.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import unittest
2+
import pandas as pd
3+
from epidemik import MetaEpiModel
4+
from epidemik.utils import NotInitialized
5+
6+
class MetaEpiModelTestCase(unittest.TestCase):
7+
def setUp(self):
8+
self.travel = pd.DataFrame({'A': [0.9, 0.1], 'B':[0.1, 0.9]}, index=["A", "B"])
9+
self.population = pd.DataFrame({'Population':[100000, 10000]}, index=["A", "B"])
10+
11+
self.SIR = MetaEpiModel(self.travel, self.population)
12+
self.beta = 0.3
13+
self.mu = 0.1
14+
self.SIR.add_interaction('S', 'I', 'I', self.beta)
15+
self.SIR.add_spontaneous('I', 'R', self.mu)
16+
17+
def test_number_populations(self):
18+
self.assertEqual(self.SIR.travel_graph.shape[0], 2)
19+
20+
def test_R0(self):
21+
self.assertEqual(self.SIR.R0(), 3.0, 'incorrect R0')
22+
23+
def test_simulate(self):
24+
with self.assertRaises(NotInitialized) as _:
25+
self.SIR.simulate(10)
26+
27+
def test_initialize_populations(self):
28+
self.SIR._initialize_populations('S')
29+
self.assertEqual(self.SIR.compartments_['S'].sum(), self.population['Population'].sum())
30+
31+
def test_travel(self):
32+
self.SIR._initialize_populations('S')
33+
new_compartments = self.SIR._run_travel(self.SIR.compartments_, self.travel)
34+
35+
self.assertEqual(new_compartments.sum().sum(), self.population['Population'].sum())
36+
37+
def test_integrate(self):
38+
with self.assertRaises(NotImplementedError) as _:
39+
self.SIR.integrate()

0 commit comments

Comments
 (0)