example_mesa.py

  1"""
  2
  3This example is a proof of principle for how MESA models can be
  4controlled using the ema_workbench.
  5
  6"""
  7
  8# Import EMA Workbench modules
  9from ema_workbench import (
 10    ReplicatorModel,
 11    RealParameter,
 12    BooleanParameter,
 13    IntegerParameter,
 14    Constant,
 15    ArrayOutcome,
 16    perform_experiments,
 17    save_results,
 18    ema_logging,
 19)
 20
 21# Necessary packages for the model
 22import math
 23from enum import Enum
 24import mesa
 25import networkx as nx
 26
 27# MESA demo model "Virus on a Network", from https://github.com/projectmesa/mesa-examples/blob/d16736778fdb500a3e5e05e082b27db78673b562/examples/virus_on_network/virus_on_network/model.py
 28
 29
 30class State(Enum):
 31    SUSCEPTIBLE = 0
 32    INFECTED = 1
 33    RESISTANT = 2
 34
 35
 36def number_state(model, state):
 37    return sum(1 for a in model.grid.get_all_cell_contents() if a.state is state)
 38
 39
 40def number_infected(model):
 41    return number_state(model, State.INFECTED)
 42
 43
 44def number_susceptible(model):
 45    return number_state(model, State.SUSCEPTIBLE)
 46
 47
 48def number_resistant(model):
 49    return number_state(model, State.RESISTANT)
 50
 51
 52class VirusOnNetwork(mesa.Model):
 53    """A virus model with some number of agents"""
 54
 55    def __init__(
 56        self,
 57        num_nodes=10,
 58        avg_node_degree=3,
 59        initial_outbreak_size=1,
 60        virus_spread_chance=0.4,
 61        virus_check_frequency=0.4,
 62        recovery_chance=0.3,
 63        gain_resistance_chance=0.5,
 64    ):
 65        self.num_nodes = num_nodes
 66        prob = avg_node_degree / self.num_nodes
 67        self.G = nx.erdos_renyi_graph(n=self.num_nodes, p=prob)
 68        self.grid = mesa.space.NetworkGrid(self.G)
 69        self.schedule = mesa.time.RandomActivation(self)
 70        self.initial_outbreak_size = (
 71            initial_outbreak_size if initial_outbreak_size <= num_nodes else num_nodes
 72        )
 73        self.virus_spread_chance = virus_spread_chance
 74        self.virus_check_frequency = virus_check_frequency
 75        self.recovery_chance = recovery_chance
 76        self.gain_resistance_chance = gain_resistance_chance
 77
 78        self.datacollector = mesa.DataCollector(
 79            {
 80                "Infected": number_infected,
 81                "Susceptible": number_susceptible,
 82                "Resistant": number_resistant,
 83            }
 84        )
 85
 86        # Create agents
 87        for i, node in enumerate(self.G.nodes()):
 88            a = VirusAgent(
 89                i,
 90                self,
 91                State.SUSCEPTIBLE,
 92                self.virus_spread_chance,
 93                self.virus_check_frequency,
 94                self.recovery_chance,
 95                self.gain_resistance_chance,
 96            )
 97            self.schedule.add(a)
 98            # Add the agent to the node
 99            self.grid.place_agent(a, node)
100
101        # Infect some nodes
102        infected_nodes = self.random.sample(list(self.G), self.initial_outbreak_size)
103        for a in self.grid.get_cell_list_contents(infected_nodes):
104            a.state = State.INFECTED
105
106        self.running = True
107        self.datacollector.collect(self)
108
109    def resistant_susceptible_ratio(self):
110        try:
111            return number_state(self, State.RESISTANT) / number_state(self, State.SUSCEPTIBLE)
112        except ZeroDivisionError:
113            return math.inf
114
115    def step(self):
116        self.schedule.step()
117        # collect data
118        self.datacollector.collect(self)
119
120    def run_model(self, n):
121        for i in range(n):
122            self.step()
123
124
class VirusAgent(mesa.Agent):
    """A single network node that can be susceptible, infected or resistant.

    Every transition probability is stored on the agent itself. All random
    draws go through ``self.random``.
    """

    def __init__(
        self,
        unique_id,
        model,
        initial_state,
        virus_spread_chance,
        virus_check_frequency,
        recovery_chance,
        gain_resistance_chance,
    ):
        super().__init__(unique_id, model)

        self.state = initial_state

        # Per-agent transition probabilities.
        self.virus_spread_chance = virus_spread_chance
        self.virus_check_frequency = virus_check_frequency
        self.recovery_chance = recovery_chance
        self.gain_resistance_chance = gain_resistance_chance

    def try_to_infect_neighbors(self):
        """Infect each susceptible neighbour with probability virus_spread_chance."""
        grid = self.model.grid
        neighbour_cells = grid.get_neighborhood(self.pos, include_center=False)
        # Take a snapshot of the susceptible neighbours before any of them
        # is infected.
        targets = list(
            filter(
                lambda neighbour: neighbour.state is State.SUSCEPTIBLE,
                grid.get_cell_list_contents(neighbour_cells),
            )
        )
        for target in targets:
            if self.random.random() < self.virus_spread_chance:
                target.state = State.INFECTED

    def try_gain_resistance(self):
        """Become RESISTANT with probability gain_resistance_chance."""
        roll = self.random.random()
        if roll < self.gain_resistance_chance:
            self.state = State.RESISTANT

    def try_remove_infection(self):
        """Attempt recovery. On success, possibly gain resistance."""
        recovered = self.random.random() < self.recovery_chance
        if not recovered:
            # Recovery failed: the agent remains infected.
            self.state = State.INFECTED
            return
        self.state = State.SUSCEPTIBLE
        self.try_gain_resistance()

    def try_check_situation(self):
        """With probability virus_check_frequency, an infected agent tries to recover."""
        # The random number is drawn before the state test, so every call
        # consumes exactly one draw regardless of the agent's state.
        wants_check = self.random.random() < self.virus_check_frequency
        if wants_check and self.state is State.INFECTED:
            self.try_remove_infection()

    def step(self):
        """Spread the virus if infected, then possibly attempt recovery."""
        infected = self.state is State.INFECTED
        if infected:
            self.try_to_infect_neighbors()
        self.try_check_situation()
178
179
180# Setting up the model as a function
def model_virus_on_network(
    num_nodes=1,
    avg_node_degree=1,
    initial_outbreak_size=1,
    virus_spread_chance=1,
    virus_check_frequency=1,
    recovery_chance=1,
    gain_resistance_chance=1,
    steps=10,
):
    """Run one ``VirusOnNetwork`` simulation and return its time series.

    This is a plain function so the ema_workbench ``ReplicatorModel`` can
    call the Mesa model with keyword arguments.

    Returns:
        dict mapping "Infected", "Susceptible" and "Resistant" to lists of
        agent counts. Each list has ``steps + 1`` entries: the initial
        state plus one per simulated step.
    """
    simulation = VirusOnNetwork(
        num_nodes=num_nodes,
        avg_node_degree=avg_node_degree,
        initial_outbreak_size=initial_outbreak_size,
        virus_spread_chance=virus_spread_chance,
        virus_check_frequency=virus_check_frequency,
        recovery_chance=recovery_chance,
        gain_resistance_chance=gain_resistance_chance,
    )
    simulation.run_model(steps)

    frame = simulation.datacollector.get_model_vars_dataframe()
    return {
        column: frame[column].tolist()
        for column in ("Infected", "Susceptible", "Resistant")
    }
214
215
if __name__ == "__main__":
    # Report experiment progress on stderr.
    ema_logging.log_to_stderr(ema_logging.INFO)

    # Wrap the simulation function so the workbench can drive it.
    model = ReplicatorModel("VirusOnNetwork", function=model_virus_on_network)

    # Sampled inputs: network size and density, plus the four transition
    # probabilities (each sampled from [0.1, 1]).
    model.uncertainties = [
        IntegerParameter("num_nodes", 10, 100),
        IntegerParameter("avg_node_degree", 2, 8),
        *(
            RealParameter(name, 0.1, 1)
            for name in (
                "virus_spread_chance",
                "virus_check_frequency",
                "recovery_chance",
                "gain_resistance_chance",
            )
        ),
    ]

    # Inputs held fixed across all experiments.
    model.constants = [
        Constant("initial_outbreak_size", 1),
        Constant("steps", 30),
    ]

    # Each outcome is a per-step time series returned by the function.
    model.outcomes = [
        ArrayOutcome(name) for name in ("Infected", "Susceptible", "Resistant")
    ]

    # The model is stochastic, so every scenario is replicated.
    model.replications = 5

    # Sample 20 scenarios and run them all.
    results = perform_experiments(models=model, scenarios=20)

    # Split the return value into the experiments and their outcomes.
    experiments, outcomes = results