1"""Example of using the workbench together with mesa."""
2
# Packages needed by the model itself; the EMA Workbench imports follow below.
5import math
6import sys
7from enum import Enum
8
9import mesa
10import networkx as nx
11import numpy as np
12
13from ema_workbench import (
14 ArrayOutcome,
15 Constant,
16 IntegerParameter,
17 RealParameter,
18 ReplicatorModel,
19 ema_logging,
20 perform_experiments,
21)
22
23# MESA demo model "Virus on a Network", from https://github.com/projectmesa/mesa-examples/blob/d16736778fdb500a3e5e05e082b27db78673b562/examples/virus_on_network/virus_on_network/model.py
24
25
class State(Enum):
    """Health status an agent can be in during the simulation."""

    SUSCEPTIBLE = 0  # not infected; can be infected by an infected neighbour
    INFECTED = 1  # carries the virus and may pass it to susceptible neighbours
    RESISTANT = 2  # recovered and immune; no rule in this file changes it again
32
33
def number_state(model, state):
    """Count the agents on ``model.grid`` whose ``state`` is identical to *state*.

    Args:
        model: object exposing ``grid.agents`` (an iterable of agents).
        state: the state to count; compared by identity (``is``).

    Returns:
        The number of matching agents as an ``int``.
    """
    # Booleans are ints, so summing the identity checks yields the count.
    return sum(agent.state is state for agent in model.grid.agents)
37
38
def number_infected(model):
    """Return how many agents of *model* are currently ``State.INFECTED``.

    Used as a model reporter by the ``DataCollector``.
    """
    return number_state(model, state=State.INFECTED)
42
43
def number_susceptible(model):
    """Return how many agents of *model* are currently ``State.SUSCEPTIBLE``.

    Used as a model reporter by the ``DataCollector``.
    """
    return number_state(model, state=State.SUSCEPTIBLE)
47
48
def number_resistant(model):
    """Return how many agents of *model* are currently ``State.RESISTANT``.

    Used as a model reporter by the ``DataCollector``.
    """
    return number_state(model, state=State.RESISTANT)
52
53
class VirusOnNetwork(mesa.Model):
    """Virus spreading over a random (Erdős–Rényi) network of agents.

    Every network node holds one ``VirusAgent``. A few agents start infected;
    each step, infected agents may infect susceptible neighbours and may
    recover, possibly becoming resistant. The numbers of infected,
    susceptible and resistant agents are recorded once after construction
    and once after every step.
    """

    def __init__(
        self,
        num_nodes: int = 10,
        avg_node_degree: int = 3,
        initial_outbreak_size: int = 1,
        virus_spread_chance: float = 0.4,
        virus_check_frequency: float = 0.4,
        recovery_chance: float = 0.3,
        gain_resistance_chance: float = 0.5,
        rng: int | None = None,
    ):
        """Build the network, populate it with agents and seed the outbreak.

        Args:
            num_nodes: number of network nodes, and thus agents; must be >= 1.
            avg_node_degree: expected node degree; the edge probability of the
                random graph is ``avg_node_degree / num_nodes``.
            initial_outbreak_size: number of agents infected at the start,
                capped at ``num_nodes``.
            virus_spread_chance: chance that an infected agent infects a given
                susceptible neighbour during its step.
            virus_check_frequency: chance per step that an agent checks its
                infection.
            recovery_chance: chance that such a check clears the infection.
            gain_resistance_chance: chance that a recovered agent becomes
                resistant.
            rng: seed for the model's random number generators. It is also
                used to generate the network, so equal seeds give equal
                networks.

        Raises:
            ValueError: if ``num_nodes`` is smaller than 1.
        """
        if num_nodes < 1:
            # Guard explicitly instead of failing with a ZeroDivisionError below.
            raise ValueError(f"num_nodes must be at least 1, got {num_nodes}")
        super().__init__(rng=rng)
        self.num_nodes = num_nodes
        prob = avg_node_degree / num_nodes
        # Draw the topology from the model's own RNG. Without ``seed`` networkx
        # uses its global, unseeded random state and replications with the
        # same ``rng`` would run on different networks.
        graph = nx.erdos_renyi_graph(n=num_nodes, p=prob, seed=self.random)

        self.grid = mesa.discrete_space.Network(graph, capacity=1, random=self.random)

        self.initial_outbreak_size = min(initial_outbreak_size, num_nodes)

        self.datacollector = mesa.DataCollector(
            {
                "Infected": number_infected,
                "Susceptible": number_susceptible,
                "Resistant": number_resistant,
            }
        )

        cells = list(self.grid.all_cells)

        # Create one susceptible agent per network cell.
        VirusAgent.create_agents(
            self,
            num_nodes,
            State.SUSCEPTIBLE,
            virus_spread_chance,
            virus_check_frequency,
            recovery_chance,
            gain_resistance_chance,
            cells,
        )

        # Infect the agents on a random sample of cells.
        infected_nodes = mesa.discrete_space.CellCollection(
            self.random.sample(cells, self.initial_outbreak_size),
            random=self.random,
        )
        for agent in infected_nodes.agents:
            agent.state = State.INFECTED

        self.running = True
        # Record the initial state, so a run of n steps yields n + 1 records.
        self.datacollector.collect(self)

    def resistant_susceptible_ratio(self):
        """Return resistant agents per susceptible agent.

        Returns ``math.inf`` when no agent is susceptible.
        """
        resistant = number_state(self, State.RESISTANT)
        susceptible = number_state(self, State.SUSCEPTIBLE)
        if susceptible == 0:
            return math.inf
        return resistant / susceptible

    def step(self):
        """Advance the model by one step.

        Every agent acts once, in shuffled order; afterwards the agent
        counts are recorded.
        """
        self.agents.shuffle_do("step")
        self.datacollector.collect(self)

    def run_model(self, n):
        """Run the model for ``n`` steps."""
        for _ in range(n):
            self.step()
130
131
class VirusAgent(mesa.discrete_space.FixedAgent):
    """Agent fixed to one network cell, carrying a ``State`` in ``self.state``.

    When infected, the agent tries to infect its susceptible neighbours each
    step. Every step it may also check its infection; a successful check makes
    it susceptible again and possibly resistant.
    """

    def __init__(
        self,
        model,
        initial_state,
        virus_spread_chance,
        virus_check_frequency,
        recovery_chance,
        gain_resistance_chance,
        cell,
    ):
        """Create the agent and attach it to ``cell``.

        Args:
            model: the model the agent belongs to.
            initial_state: the agent's starting ``State``.
            virus_spread_chance: chance of infecting each susceptible neighbour.
            virus_check_frequency: chance per step of checking the infection.
            recovery_chance: chance that a check clears the infection.
            gain_resistance_chance: chance of becoming resistant on recovery.
            cell: the network cell this agent occupies.
        """
        super().__init__(model)

        self.state = initial_state

        # Transition probabilities used by the step rules below.
        self.virus_spread_chance = virus_spread_chance
        self.virus_check_frequency = virus_check_frequency
        self.recovery_chance = recovery_chance
        self.gain_resistance_chance = gain_resistance_chance

        self.cell = cell

    def try_infect_neighbors(self):
        """Give every susceptible neighbour a chance to become infected."""
        for neighbour in self.cell.neighborhood.agents:
            # Only susceptible neighbours consume a random draw.
            if neighbour.state is not State.SUSCEPTIBLE:
                continue
            if self.random.random() < self.virus_spread_chance:
                neighbour.state = State.INFECTED

    def try_gain_resistance(self):
        """With probability ``gain_resistance_chance``, become resistant."""
        roll = self.random.random()
        if roll < self.gain_resistance_chance:
            self.state = State.RESISTANT

    def try_remove_infection(self):
        """Attempt to clear the infection.

        On success the agent becomes susceptible and then gets a chance to
        become resistant. On failure it stays infected.
        """
        recovered = self.random.random() < self.recovery_chance
        if not recovered:
            self.state = State.INFECTED
            return
        self.state = State.SUSCEPTIBLE
        self.try_gain_resistance()

    def check_situation(self):
        """Maybe check the agent's health and, if infected, try to recover.

        The random draw happens on every call, whatever the current state.
        """
        due_for_check = self.random.random() < self.virus_check_frequency
        if due_for_check and self.state is State.INFECTED:
            self.try_remove_infection()

    def step(self):
        """Spread the virus if infected, then run the health check."""
        is_infected = self.state is State.INFECTED
        if is_infected:
            self.try_infect_neighbors()
        self.check_situation()
192
193
194# Setting up the model as a function
def model_virus_on_network(
    num_nodes=1,
    avg_node_degree=1,
    initial_outbreak_size=1,
    virus_spread_chance=1,
    virus_check_frequency=1,
    recovery_chance=1,
    gain_resistance_chance=1,
    steps=10,
    rng=None,
):
    """Build a ``VirusOnNetwork`` model, run it, and return its recorded counts.

    This is the function handed to the EMA Workbench ``ReplicatorModel``.
    Every argument except ``steps`` is forwarded to ``VirusOnNetwork``.

    Returns:
        A dict mapping "Infected", "Susceptible" and "Resistant" to lists with
        one count per data collection: the initial state plus one per step.
    """
    simulation = VirusOnNetwork(
        num_nodes=num_nodes,
        avg_node_degree=avg_node_degree,
        initial_outbreak_size=initial_outbreak_size,
        virus_spread_chance=virus_spread_chance,
        virus_check_frequency=virus_check_frequency,
        recovery_chance=recovery_chance,
        gain_resistance_chance=gain_resistance_chance,
        rng=rng,
    )
    simulation.run_model(steps)

    history = simulation.datacollector.get_model_vars_dataframe()
    return {
        series: history[series].tolist()
        for series in ("Infected", "Susceptible", "Resistant")
    }
231
232
if __name__ == "__main__":
    # Log the progress of the experiments to stderr
    ema_logging.log_to_stderr(ema_logging.INFO)

    # Wrap the model function so the workbench can run replications of it
    model = ReplicatorModel("VirusOnNetwork", function=model_virus_on_network)

    # Uncertain model parameters and the ranges they are sampled from
    model.uncertainties = [
        IntegerParameter("num_nodes", 10, 100),
        IntegerParameter("avg_node_degree", 2, 8),
        RealParameter("virus_spread_chance", 0.1, 1),
        RealParameter("virus_check_frequency", 0.1, 1),
        RealParameter("recovery_chance", 0.1, 1),
        RealParameter("gain_resistance_chance", 0.1, 1),
    ]

    # Model parameters that stay fixed across all experiments
    model.constants = [Constant("initial_outbreak_size", 1), Constant("steps", 30)]

    # Time series returned by model_virus_on_network
    model.outcomes = [
        ArrayOutcome("Infected"),
        ArrayOutcome("Susceptible"),
        ArrayOutcome("Resistant"),
    ]

    # Define the number of replications and the seed for each of them.
    # Seeds are converted to plain Python ints: the model expects
    # ``rng: int | None``, and since Python 3.11 the standard library's
    # ``random.Random`` rejects numpy integer scalars as seeds.
    n_replications = 10
    seed_generator = np.random.default_rng()
    model.replications = [
        {"rng": int(seed)}
        for seed in seed_generator.integers(sys.maxsize, size=n_replications)
    ]

    # Run experiments with the aforementioned parameters and outcomes
    experiments, outcomes = perform_experiments(models=model, scenarios=20)