#924 Re-Implementation of BatchRunner

Open Corvince
Showing 1 of 2 files from the diff.

@@ -6,11 +6,120 @@
Loading
6 6
7 7
"""
8 8
import copy
9 -
from itertools import product, count
9 +
import itertools
10 +
import multiprocessing as mp
11 +
import random
12 +
from functools import partial
13 +
from itertools import count, product
14 +
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
15 +
10 16
import pandas as pd
11 17
from tqdm import tqdm
12 18
13 -
import random
19 +
from mesa.datacollection import DataCollector
20 +
from mesa.model import Model
21 +
22 +
ParameterName = str
23 +
ParameterValue = Any
24 +
ModelParameters = Dict[ParameterName, ParameterValue]
25 +
ModelData = Any
26 +
27 +
28 +
def batch_run(
29 +
    model_cls: Type[Model],
30 +
    parameters: Dict[ParameterName, Union[ParameterValue, Iterable[ParameterValue]]],
31 +
    model_reporters: Any = None,
32 +
    agent_reporters: Any = None,
33 +
    nr_processes: Optional[int] = None,
34 +
    iterations: int = 1,
35 +
    max_steps: int = 1000,
36 +
    display_progress: bool = True,
37 +
) -> Dict[Tuple[ParameterValue], ModelData]:
38 +
    """Batch run a model."""
39 +
40 +
    kwargs_list = _make_model_kwargs(parameters)
41 +
    process_func = partial(
42 +
        _model_run_func,
43 +
        model_cls,
44 +
        max_steps=max_steps,
45 +
        model_reporters=model_reporters,
46 +
        agent_reporters=agent_reporters,
47 +
    )
48 +
49 +
    total_iterations = len(kwargs_list) * iterations
50 +
51 +
    results = []
52 +
53 +
    with tqdm(total_iterations, disable=not display_progress) as pbar:
54 +
        if nr_processes == 1:
55 +
            for kwargs in kwargs_list:
56 +
                data = process_func(kwargs)
57 +
                results.extend(data)
58 +
                pbar.update()
59 +
60 +
        else:
61 +
            with mp.Pool(nr_processes) as p:
62 +
                for data in p.imap_unordered(process_func, kwargs_list):
63 +
                    results.extend(data)
64 +
                    pbar.update()
65 +
66 +
    return results
67 +
68 +
69 +
def _make_model_kwargs(
70 +
    parameters: Dict[ParameterName, Union[ParameterValue, Iterable[ParameterValue]]]
71 +
) -> List[ModelParameters]:
72 +
    """Create model kwargs from parameters dictionary."""
73 +
    parameter_list = []
74 +
    for param, values in parameters.items():
75 +
        try:
76 +
            all_values = [(param, value) for value in values]
77 +
        except TypeError:
78 +
            all_values = [(param, values)]
79 +
        parameter_list.append(all_values)
80 +
    all_kwargs = itertools.product(*parameter_list)
81 +
    kwargs_list = [dict(kwargs) for kwargs in all_kwargs]
82 +
    return kwargs_list
83 +
84 +
85 +
def _model_run_func(
86 +
    model_cls: Type[Model],
87 +
    kwargs: ModelParameters,
88 +
    max_steps: int,
89 +
    model_reporters: Any,
90 +
    agent_reporters: Any,
91 +
) -> List[Dict[str, Any]]:
92 +
    """Run a single model run."""
93 +
    model = model_cls(**kwargs)
94 +
    while model.running and model.schedule.steps < max_steps:
95 +
        model.step()
96 +
97 +
    model_data, agent_data = _collect_data(model, model_reporters, agent_reporters)
98 +
99 +
    data = [{**kwargs, **model_data, **agent_datum} for agent_datum in agent_data]
100 +
101 +
    return data
102 +
103 +
104 +
def _collect_data(
105 +
    model: Model,
106 +
    model_reporters: Dict[str, Any] = None,
107 +
    agent_reporters: Dict[str, Any] = None,
108 +
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
109 +
    """Collect model and agent data from a model using mesas datacollector."""
110 +
    dc = DataCollector(model_reporters=model_reporters, agent_reporters=agent_reporters)
111 +
    dc.collect(model)
112 +
113 +
    model_data = {key: value[0] for key, value in dc.model_vars.items()}
114 +
115 +
    agent_data = []
116 +
    raw_agent_data = dc._agent_records.get(model.schedule.steps, [])
117 +
    for data in raw_agent_data:
118 +
        agent_dict = dict(zip(dc.agent_reporters, data[2:]))
119 +
        agent_dict["AgentID"] = data[1]
120 +
        agent_data.append(agent_dict)
121 +
    return model_data, agent_data
122 +
14 123
15 124
try:
16 125
    from pathos.multiprocessing import ProcessPool

Everything is accounted for!

No changes detected that need to be reviewed.
What changes does Codecov check for?
Lines, not adjusted in diff, that have changed coverage data.
Files that introduced coverage data that had none before.
Files that have missing coverage data that once were tracked.
Files Coverage
mesa -2.97% 83.07%
Project Totals (19 files) 83.07%
Loading