Skip to content

Commit

Permalink
update plotting functions for EpiModel
Browse files Browse the repository at this point in the history
  • Loading branch information
bmtgoncalves committed Apr 12, 2024
1 parent c9e705d commit c7afeaf
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 30 deletions.
Binary file modified images/SIR.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/epidemik2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
65 changes: 37 additions & 28 deletions src/epidemik/EpiModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,19 @@ def _new_cases(self, population, time, pos):

return diff

def plot(self, title=None, normed=True, show=True, **kwargs):
def plot(self, title=None, normed=True, show=True, ax=None, **kwargs):
"""
Convenience function for plotting
Parameters:
- title: string, optional
Title of the plot
- normed: bool, optional
- normed: bool, default=True
Whether to normalize the values or not
- ax: matplotlib Axes object, default=None
The Axes object to plot to. If None, a new figure is created.
- show: bool, default=True
Whether to call plt.show() or not
- kwargs: keyword arguments
Additional arguments to pass to the plot function
Expand All @@ -207,10 +211,16 @@ def plot(self, title=None, normed=True, show=True, **kwargs):
try:
if normed:
N = self.values_.iloc[0].sum()
ax = (self.values_/N).plot(**kwargs)
else:
ax = self.values_.plot(**kwargs)

N = 1

if ax is None:
ax = plt.gca()

for comp in self.values_.columns:
(self.values_[comp]/N).plot(c=epi_colors[comp[0]], **kwargs)

ax.legend(self.values_.columns)
ax.set_xlabel('Time')
ax.set_ylabel('Population')

Expand All @@ -221,7 +231,8 @@ def plot(self, title=None, normed=True, show=True, **kwargs):
plt.show()

return ax
except:
except Exception as e:
print(e)
raise NotInitialized('You must call integrate() or simulate() first')

def __getattr__(self, name):
Expand Down Expand Up @@ -462,36 +473,24 @@ def _get_infections(self):
return inf

def draw_model(self, ax=None, show=True):
"""
Plot the model structure
- ax: matplotlib Axes object, default=None
The Axes object to plot to. If None, a new figure is created.
- show: bool, default=True
Whether to call plt.show() or not
"""
try:
from networkx.drawing.nx_agraph import graphviz_layout
pos=graphviz_layout(self.transitions, prog='dot', args='-Grankdir="LR"')
except:
pos=nx.layout.spectral_layout(self.transitions)

colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

S_color = colors[0]
E_color = colors[4]
I_color = colors[1]
R_color = colors[2]
D_color = colors[7]
default_color = colors[3]

node_colors = []

for node in self.transitions.nodes():
if node[0] == 'S':
node_colors.append(S_color)
elif node[0] == 'E':
node_colors.append(E_color)
elif node[0] == 'I':
node_colors.append(I_color)
elif node[0] == 'R':
node_colors.append(R_color)
elif node[0] == 'D':
node_colors.append(D_color)
else:
node_colors.append(default_color)
node_colors.append(epi_colors[node[0]])

edge_labels = {}

Expand All @@ -508,7 +507,7 @@ def draw_model(self, ax=None, show=True):


if ax is None:
fig, ax = plt.subplots(1)
fig, ax = plt.subplots(1, figsize=(10, 2))

nx.draw(self.transitions, pos, with_labels=True, arrows=True, node_shape='H',
font_color='k', node_color=node_colors, node_size=1000, ax=ax)
Expand All @@ -518,6 +517,16 @@ def draw_model(self, ax=None, show=True):
plt.show()

def R0(self):
"""
Return the value of the basic reproductive ratio, $R_0$, for the model as defined
The calculation is completely generic as it uses the Next-Generation matrix approach
defined in J. R. Soc Interface 7, 873 (2010)
Returns:
R0 - the value of the largest eigenvalue of the next generation matrix
"""

infected = set()

susceptible = self._get_susceptible()
Expand Down
2 changes: 1 addition & 1 deletion src/epidemik/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
from .NetworkEpiModel import NetworkEpiModel
from .MetaEpiModel import MetaEpiModel

__version__ = "0.0.17"
__version__ = "0.0.18"
11 changes: 10 additions & 1 deletion src/epidemik/utils.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,11 @@
from collections import defaultdict

class NotInitialized(Exception):
pass
pass

epi_colors = defaultdict(lambda :'#f39019')
epi_colors['S'] = '#51a7f9'
epi_colors['E'] = '#f9e351'
epi_colors['I'] = '#cf51f9'
epi_colors['R'] = '#70bf41'
epi_colors['D'] = '#8b8b8b'

0 comments on commit c7afeaf

Please sign in to comment.