1"""
2
3This example is a proof of principle for how MESA models can be
4controlled using the ema_workbench.
5
6"""
7
8# Import EMA Workbench modules
9from ema_workbench import (
10 ReplicatorModel,
11 RealParameter,
12 BooleanParameter,
13 IntegerParameter,
14 Constant,
15 ArrayOutcome,
16 perform_experiments,
17 save_results,
18 ema_logging,
19)
20
21# Necessary packages for the model
22import math
23from enum import Enum
24import mesa
25import networkx as nx
26
27# MESA demo model "Virus on a Network", from https://github.com/projectmesa/mesa-examples/blob/d16736778fdb500a3e5e05e082b27db78673b562/examples/virus_on_network/virus_on_network/model.py
28
29
class State(Enum):
    """Health status of a single agent in the virus model.

    Members are compared by identity (``agent.state is State.X``) throughout
    the model; the integer values are only labels.
    """

    # Can be infected by an infected neighbour.
    SUSCEPTIBLE = 0
    # Spreads the virus to susceptible neighbours each step.
    INFECTED = 1
    # Reached after recovering and passing the resistance check.
    RESISTANT = 2
34
35
def number_state(model, state):
    """Count the agents on ``model.grid`` whose ``state`` is ``state``.

    Args:
        model: Object exposing ``grid.get_all_cell_contents()``.
        state: The state to match (compared by identity).

    Returns:
        int: Number of matching agents (0 when the grid is empty).
    """
    count = 0
    for agent in model.grid.get_all_cell_contents():
        if agent.state is state:
            count += 1
    return count
38
39
def number_infected(model):
    """Return how many agents in ``model`` are currently infected."""
    return number_state(model, state=State.INFECTED)
42
43
def number_susceptible(model):
    """Return how many agents in ``model`` are currently susceptible."""
    return number_state(model, state=State.SUSCEPTIBLE)
46
47
def number_resistant(model):
    """Return how many agents in ``model`` are currently resistant."""
    return number_state(model, state=State.RESISTANT)
50
51
class VirusOnNetwork(mesa.Model):
    """A virus model with some number of agents placed on a random network.

    The network is an Erdos-Renyi graph built with ``nx.erdos_renyi_graph``;
    one ``VirusAgent`` is placed on every node. A random subset of agents
    starts infected, and the counts of infected / susceptible / resistant
    agents are collected at construction time and after every step.

    Parameters
    ----------
    num_nodes : int
        Number of nodes (and therefore agents) in the network.
    avg_node_degree : float
        Used to derive the edge probability ``avg_node_degree / num_nodes``.
    initial_outbreak_size : int
        Number of agents infected at start; capped at ``num_nodes``.
    virus_spread_chance, virus_check_frequency, recovery_chance,
    gain_resistance_chance : float
        Probabilities handed unchanged to every ``VirusAgent``.
    """

    def __init__(
        self,
        num_nodes=10,
        avg_node_degree=3,
        initial_outbreak_size=1,
        virus_spread_chance=0.4,
        virus_check_frequency=0.4,
        recovery_chance=0.3,
        gain_resistance_chance=0.5,
    ):
        # Mesa requires subclasses to initialise the base Model. This must
        # happen before ``self.schedule`` is assigned below, because the base
        # __init__ resets ``schedule``.
        super().__init__()

        self.num_nodes = num_nodes
        # Edge probability of the random graph. Guard the empty network,
        # which would otherwise raise ZeroDivisionError.
        prob = avg_node_degree / num_nodes if num_nodes > 0 else 0
        self.G = nx.erdos_renyi_graph(n=num_nodes, p=prob)
        self.grid = mesa.space.NetworkGrid(self.G)
        self.schedule = mesa.time.RandomActivation(self)
        # Cannot infect more agents than exist.
        self.initial_outbreak_size = min(initial_outbreak_size, num_nodes)
        self.virus_spread_chance = virus_spread_chance
        self.virus_check_frequency = virus_check_frequency
        self.recovery_chance = recovery_chance
        self.gain_resistance_chance = gain_resistance_chance

        self.datacollector = mesa.DataCollector(
            {
                "Infected": number_infected,
                "Susceptible": number_susceptible,
                "Resistant": number_resistant,
            }
        )

        # Create one susceptible agent per node.
        for i, node in enumerate(self.G.nodes()):
            agent = VirusAgent(
                i,
                self,
                State.SUSCEPTIBLE,
                self.virus_spread_chance,
                self.virus_check_frequency,
                self.recovery_chance,
                self.gain_resistance_chance,
            )
            self.schedule.add(agent)
            self.grid.place_agent(agent, node)

        # Infect a random sample of nodes, drawn from the model's own RNG.
        infected_nodes = self.random.sample(list(self.G), self.initial_outbreak_size)
        for agent in self.grid.get_cell_list_contents(infected_nodes):
            agent.state = State.INFECTED

        self.running = True
        # Record the initial state so the series has one entry before any step.
        self.datacollector.collect(self)

    def resistant_susceptible_ratio(self):
        """Return resistant / susceptible counts, or ``math.inf`` if none are susceptible."""
        try:
            return number_state(self, State.RESISTANT) / number_state(self, State.SUSCEPTIBLE)
        except ZeroDivisionError:
            return math.inf

    def step(self):
        """Advance every agent once (random order), then collect the counts."""
        self.schedule.step()
        self.datacollector.collect(self)

    def run_model(self, n):
        """Run ``n`` model steps."""
        for _ in range(n):
            self.step()
123
124
class VirusAgent(mesa.Agent):
    """A single network node that can be susceptible, infected or resistant.

    Each agent stores the probabilities it was constructed with and draws
    all of its random numbers from ``self.random``.
    """

    def __init__(
        self,
        unique_id,
        model,
        initial_state,
        virus_spread_chance,
        virus_check_frequency,
        recovery_chance,
        gain_resistance_chance,
    ):
        super().__init__(unique_id, model)

        self.state = initial_state

        # Per-agent copies of the epidemic probabilities.
        self.virus_spread_chance = virus_spread_chance
        self.virus_check_frequency = virus_check_frequency
        self.recovery_chance = recovery_chance
        self.gain_resistance_chance = gain_resistance_chance

    def try_to_infect_neighbors(self):
        """Give each susceptible neighbour one chance to become infected."""
        grid = self.model.grid
        adjacent = grid.get_neighborhood(self.pos, include_center=False)
        # Snapshot the susceptible neighbours before infecting any of them;
        # one random draw is made per snapshot entry, in order.
        targets = [
            neighbour
            for neighbour in grid.get_cell_list_contents(adjacent)
            if neighbour.state is State.SUSCEPTIBLE
        ]
        for target in targets:
            if self.random.random() < self.virus_spread_chance:
                target.state = State.INFECTED

    def try_gain_resistance(self):
        """With probability ``gain_resistance_chance``, become resistant."""
        if self.random.random() < self.gain_resistance_chance:
            self.state = State.RESISTANT

    def try_remove_infection(self):
        """Attempt recovery. On success, also get a chance to gain resistance."""
        if self.random.random() < self.recovery_chance:
            self.state = State.SUSCEPTIBLE
            self.try_gain_resistance()
            return
        # Recovery failed: the agent stays infected.
        self.state = State.INFECTED

    def try_check_situation(self):
        """Occasionally check for infection and, if infected, try to recover."""
        # The random draw happens for every agent, infected or not, so the
        # RNG stream advances identically regardless of state.
        roll = self.random.random()
        if roll < self.virus_check_frequency and self.state is State.INFECTED:
            self.try_remove_infection()

    def step(self):
        """Spread the virus if infected, then run the infection check."""
        if self.state is State.INFECTED:
            self.try_to_infect_neighbors()
        self.try_check_situation()
178
179
180# Setting up the model as a function
def model_virus_on_network(
    num_nodes=1,
    avg_node_degree=1,
    initial_outbreak_size=1,
    virus_spread_chance=1,
    virus_check_frequency=1,
    recovery_chance=1,
    gain_resistance_chance=1,
    steps=10,
):
    """Build a ``VirusOnNetwork``, run it, and return its per-step counts.

    This is the plain-function entry point handed to the EMA Workbench.
    Every argument except ``steps`` is forwarded to ``VirusOnNetwork``.
    ``steps`` is the number of model steps to run.

    Returns:
        dict: ``{"Infected": [...], "Susceptible": [...], "Resistant": [...]}``.
        Each list holds one count per data-collection point: the initial state
        followed by one entry per step.
    """
    simulation = VirusOnNetwork(
        num_nodes=num_nodes,
        avg_node_degree=avg_node_degree,
        initial_outbreak_size=initial_outbreak_size,
        virus_spread_chance=virus_spread_chance,
        virus_check_frequency=virus_check_frequency,
        recovery_chance=recovery_chance,
        gain_resistance_chance=gain_resistance_chance,
    )
    simulation.run_model(steps)

    frame = simulation.datacollector.get_model_vars_dataframe()
    return {
        column: frame[column].tolist()
        for column in ("Infected", "Susceptible", "Resistant")
    }
214
215
if __name__ == "__main__":
    # Report experiment progress on stderr.
    ema_logging.log_to_stderr(ema_logging.INFO)

    # Wrap the model function so the workbench can run replications of it.
    model = ReplicatorModel("VirusOnNetwork", function=model_virus_on_network)

    # Inputs sampled across scenarios, with their (lower, upper) bounds.
    model.uncertainties = [
        IntegerParameter("num_nodes", 10, 100),
        IntegerParameter("avg_node_degree", 2, 8),
        RealParameter("virus_spread_chance", 0.1, 1),
        RealParameter("virus_check_frequency", 0.1, 1),
        RealParameter("recovery_chance", 0.1, 1),
        RealParameter("gain_resistance_chance", 0.1, 1),
    ]

    # Inputs held fixed in every experiment.
    model.constants = [
        Constant("initial_outbreak_size", 1),
        Constant("steps", 30),
    ]

    # One array-valued outcome per key returned by model_virus_on_network.
    model.outcomes = [
        ArrayOutcome(outcome_name)
        for outcome_name in ("Infected", "Susceptible", "Resistant")
    ]

    # Each scenario is run this many times.
    model.replications = 5

    # 20 sampled scenarios x 5 replications.
    results = perform_experiments(models=model, scenarios=20)
    experiments, outcomes = results