#815 A cleaner Grid implementation

Open Corvince

@@ -1,81 +1,97 @@
Loading
1 -
# -*- coding: utf-8 -*-
2 1
"""
3 2
Mesa Space Module
4 3
=================
5 4
6 -
Objects used to add a spatial component to a model.
5 +
Classes used to add a spatial component to a model.
6 +
7 +
Grids
8 +
-----
7 9
8 10
Grid: base grid, a simple list-of-lists.
9 11
SingleGrid: grid which strictly enforces one object per cell.
10 12
MultiGrid: extension to Grid where each cell is a set of objects.
11 13
14 +
HexGrid: Extends Grid to handle hexagonal neighbors.
15 +
16 +
Other Spaces
17 +
------------
18 +
19 +
ContinuousSpace: Continuous space where each agent has an arbitrary position.
20 +
NetworkGrid: A Network of nodes based on networkx
21 +
12 22
"""
13 23
# Instruction for PyLint to suppress variable name errors, since we have a
14 24
# good reason to use one-character variable names for x and y.
15 25
# pylint: disable=invalid-name
16 26
17 27
import itertools
28 +
import warnings
29 +
from typing import (
30 +
    Any,
31 +
    Callable,
32 +
    Dict,
33 +
    Iterable,
34 +
    Iterator,
35 +
    List,
36 +
    Optional,
37 +
    Set,
38 +
    Tuple,
39 +
    TypeVar,
40 +
    Union,
41 +
    cast,
42 +
)
18 43
19 44
import numpy as np
20 45
21 -
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
22 -
from .agent import Agent
46 +
from mesa.agent import Agent
23 47
24 48
Coordinate = Tuple[int, int]
25 -
GridContent = Union[Optional[Agent], Set[Agent]]
49 +
GridContent = List[Agent]
26 50
# used in ContinuousSpace
27 51
FloatCoordinate = Union[Tuple[float, float], np.ndarray]
28 52
53 +
F = TypeVar("F", bound=Callable[..., Any])
54 +
29 55
30 -
def accept_tuple_argument(wrapped_function):
56 +
def accept_tuple_argument(wrapped_function: F) -> F:
31 57
    """ Decorator to allow grid methods that take a list of (x, y) coord tuples
32 58
    to also handle a single position, by automatically wrapping tuple in
33 59
    single-item list rather than forcing user to do it.
34 60
35 61
    """
36 62
37 -
    def wrapper(*args: Any):
63 +
    def wrapper(*args: Any) -> Any:
38 64
        if isinstance(args[1], tuple) and len(args[1]) == 2:
39 65
            return wrapped_function(args[0], [args[1]])
40 -
        else:
41 -
            return wrapped_function(*args)
66 +
        return wrapped_function(*args)
42 67
43 -
    return wrapper
68 +
    return cast(F, wrapper)
44 69
45 70
46 -
class Grid:
47 -
    """ Base class for a square grid.
71 +
class MultiGrid:
72 +
    """Grid where each cell can contain more than one object.
48 73
49 74
    Grid cells are indexed by [x][y], where [0][0] is assumed to be the
50 75
    bottom-left and [width-1][height-1] is the top-right. If a grid is
51 76
    toroidal, the top and bottom, and left and right, edges wrap to each other
77 +
    Each position of the grid is referred to by a (x, y) coordinate tuple.
78 +
    You may access the content of a single cell by calling Grid[x, y].
52 79
53 80
    Properties:
54 81
        width, height: The grid's width and height.
55 82
        torus: Boolean which determines whether to treat the grid as a torus.
56 -
        grid: Internal list-of-lists which holds the grid cells themselves.
83 +
        empties: List of currently empty cells.
57 84
58 85
    Methods:
59 86
        get_neighbors: Returns the objects surrounding a given cell.
60 87
        get_neighborhood: Returns the cells surrounding a given cell.
61 -
        get_cell_list_contents: Returns the contents of a list of cells
62 -
            ((x,y) tuples)
63 -
        neighbor_iter: Iterates over position neightbors.
88 +
        get_contents: Returns the contents of a list of cells.
64 89
        coord_iter: Returns coordinates as well as cell contents.
65 90
        place_agent: Positions an agent on the grid, and set its pos variable.
91 +
        remove_agent: Removes an agent from the grid.
66 92
        move_agent: Moves an agent from its current position to a new position.
67 -
        iter_neighborhood: Returns an iterator over cell coordinates that are
68 -
        in the neighborhood of a certain point.
69 93
        torus_adj: Converts coordinate, handles torus looping.
70 -
        out_of_bounds: Determines whether position is off the grid, returns
71 -
        the out of bounds coordinate.
72 -
        iter_cell_list_contents: Returns an iterator of the contents of the
73 -
        cells identified in cell_list.
74 -
        get_cell_list_contents: Returns a list of the contents of the cells
75 -
        identified in cell_list.
76 -
        remove_agent: Removes an agent from the grid.
77 -
        is_cell_empty: Returns a bool of the contents of a cell.
78 -
94 +
        out_of_bounds: Determines whether position is off the grid
79 95
    """
80 96
81 97
    def __init__(self, width: int, height: int, torus: bool) -> None:
@@ -90,37 +106,74 @@
Loading
90 106
        self.width = width
91 107
        self.torus = torus
92 108
93 -
        self.grid = []  # type: List[List[GridContent]]
109 +
        self._grid = []  # type: List[List[GridContent]]
94 110
95 -
        for x in range(self.width):
111 +
        for _ in range(self.width):
96 112
            col = []  # type: List[GridContent]
97 -
            for y in range(self.height):
98 -
                col.append(self.default_val())
99 -
            self.grid.append(col)
113 +
            for _ in range(self.height):
114 +
                col.append([])
115 +
            self._grid.append(col)
100 116
101 117
        # Add all cells to the empties list.
102 -
        self.empties = set(itertools.product(*(range(self.width), range(self.height))))
118 +
        self._empties = set(itertools.product(range(self.width), range(self.height)))
119 +
        self._all_cells = frozenset(self._empties)
120 +
121 +
        # Neighborhood Cache
122 +
        self._neighborhood_cache = dict()  # type: Dict[Any, List[Coordinate]]
103 123
104 124
    @staticmethod
105 125
    def default_val() -> None:
106 -
        """ Default value for new cell elements. """
126 +
        """Default value for new cell elements. """
127 +
        warnings.warn("Not supported anymore", DeprecationWarning)
107 128
        return None
108 129
109 -
    def __getitem__(self, index: int) -> List[GridContent]:
110 -
        return self.grid[index]
130 +
    def __getitem__(self, pos: Coordinate) -> GridContent:
131 +
        """Access contents of a given position."""
132 +
        if isinstance(pos, int):
133 +
            warnings.warn(
134 +
                """Accesing the grid via `grid[x][y]` is deprecated.
135 +
                Use `grid[x, y]` instead.""",
136 +
                category=DeprecationWarning,
137 +
            )
138 +
            return self._grid[pos]
139 +
        return self._get(*pos)
140 +
141 +
    def __setitem__(self, pos: Coordinate, agent: Agent) -> None:
142 +
        """Add agents to a position."""
143 +
        self._get(*pos).append(agent)
144 +
145 +
    def _get(self, row: int, col: int) -> GridContent:
146 +
        """Access content of a given position.
147 +
148 +
        Since we overwrite __getitem__ for SingleGrid and Grid,
149 +
        we have to use this function to always get the internal list.
150 +
        """
151 +
        return self._grid[row][col]
111 152
112 153
    def __iter__(self) -> Iterator[GridContent]:
113 154
        """
114 155
        create an iterator that chains the
115 156
        rows of grid together as if one list:
116 157
        """
117 -
        return itertools.chain(*self.grid)
158 +
        return itertools.chain.from_iterable(self._grid)
118 159
119 160
    def coord_iter(self) -> Iterator[Tuple[GridContent, int, int]]:
120 161
        """ An iterator that returns coordinates as well as cell contents. """
121 162
        for row in range(self.width):
122 163
            for col in range(self.height):
123 -
                yield self.grid[row][col], row, col  # agent, x, y
164 +
                yield self[row, col], row, col  # agent, x, y
165 +
166 +
    @accept_tuple_argument
167 +
    def get_contents(self, cell_list: Iterable[Coordinate]) -> List[GridContent]:
168 +
        """Return a list of the cell contents for a given cell list."""
169 +
        return [self[pos] for pos in cell_list]
170 +
171 +
    @accept_tuple_argument
172 +
    def get_agents(self, cell_list: Iterable[Coordinate]) -> List[Agent]:
173 +
        """Return a list of agents from the given cell list."""
174 +
        contents = self.get_contents(cell_list)
175 +
        agents = itertools.chain.from_iterable(contents)
176 +
        return list(agents)
124 177
125 178
    def neighbor_iter(
126 179
        self, pos: Coordinate, moore: bool = True
@@ -133,8 +186,11 @@
Loading
133 186
                   diagonals) or Von Neumann (only up/down/left/right).
134 187
135 188
        """
136 -
        neighborhood = self.iter_neighborhood(pos, moore=moore)
137 -
        return self.iter_cell_list_contents(neighborhood)
189 +
        warnings.warn(
190 +
            "`neighbor_iter` is deprecated, use `get_neighbors` instead",
191 +
            DeprecationWarning,
192 +
        )
193 +
        yield from self.get_neighbors(pos, moore=moore)
138 194
139 195
    def iter_neighborhood(
140 196
        self,
@@ -163,31 +219,10 @@
Loading
163 219
            including the center).
164 220
165 221
        """
166 -
        x, y = pos
167 -
        coordinates = set()  # type: Set[Coordinate]
168 -
        for dy in range(-radius, radius + 1):
169 -
            for dx in range(-radius, radius + 1):
170 -
                if dx == 0 and dy == 0 and not include_center:
171 -
                    continue
172 -
                # Skip coordinates that are outside manhattan distance
173 -
                if not moore and abs(dx) + abs(dy) > radius:
174 -
                    continue
175 -
                # Skip if not a torus and new coords out of bounds.
176 -
                if not self.torus and (
177 -
                    not (0 <= dx + x < self.width) or not (0 <= dy + y < self.height)
178 -
                ):
179 -
                    continue
180 -
181 -
                px, py = self.torus_adj((x + dx, y + dy))
182 -
183 -
                # Skip if new coords out of bounds.
184 -
                if self.out_of_bounds((px, py)):
185 -
                    continue
186 -
187 -
                coords = (px, py)
188 -
                if coords not in coordinates:
189 -
                    coordinates.add(coords)
190 -
                    yield coords
222 +
        warnings.warn(
223 +
            "`iter_neighborhood` is deprecated, use `get_neighborhood` instead."
224 +
        )
225 +
        yield from self.get_neighborhood(pos, moore, include_center, radius)
191 226
192 227
    def get_neighborhood(
193 228
        self,
@@ -215,15 +250,40 @@
Loading
215 250
            if not including the center).
216 251
217 252
        """
218 -
        return list(self.iter_neighborhood(pos, moore, include_center, radius))
253 +
        cache_key = (pos, moore, include_center, radius)
254 +
        neighborhood = self._neighborhood_cache.get(cache_key, None)
255 +
        if neighborhood is None:
256 +
            x, y = pos
257 +
            coordinates = set()  # type: Set[Coordinate]
258 +
            for dy in range(-radius, radius + 1):
259 +
                for dx in range(-radius, radius + 1):
260 +
                    if dx == 0 and dy == 0 and not include_center:
261 +
                        continue
262 +
                    # Skip coordinates that are outside manhattan distance
263 +
                    if not moore and abs(dx) + abs(dy) > radius:
264 +
                        continue
265 +
                    # Skip if not a torus and new coords out of bounds.
266 +
                    coord = (x + dx, y + dy)
267 +
268 +
                    if self.out_of_bounds(coord):
269 +
                        if not self.torus:
270 +
                            continue
271 +
                        coord = self.torus_adj(coord)
272 +
273 +
                    if coord not in coordinates:
274 +
                        coordinates.add(coord)
275 +
276 +
            neighborhood = sorted(coordinates)
277 +
            self._neighborhood_cache[cache_key] = neighborhood
278 +
        return neighborhood
219 279
220 280
    def iter_neighbors(
221 281
        self,
222 282
        pos: Coordinate,
223 283
        moore: bool,
224 284
        include_center: bool = False,
225 285
        radius: int = 1,
226 -
    ) -> Iterator[GridContent]:
286 +
    ) -> Iterator[Agent]:
227 287
        """ Return an iterator over neighbors to a certain point.
228 288
229 289
        Args:
@@ -243,16 +303,20 @@
Loading
243 303
            (8 and 4 if not including the center).
244 304
245 305
        """
246 -
        neighborhood = self.iter_neighborhood(pos, moore, include_center, radius)
247 -
        return self.iter_cell_list_contents(neighborhood)
306 +
        warnings.warn(
307 +
            "`iter_neighbors` is deprecated, use `get_neighbors` instead",
308 +
            DeprecationWarning,
309 +
        )
310 +
        neighborhood = self.get_neighborhood(pos, moore, include_center, radius)
311 +
        yield from self.get_agents(neighborhood)
248 312
249 313
    def get_neighbors(
250 314
        self,
251 315
        pos: Coordinate,
252 -
        moore: bool,
316 +
        moore: bool = True,
253 317
        include_center: bool = False,
254 318
        radius: int = 1,
255 -
    ) -> List[Coordinate]:
319 +
    ) -> List[Agent]:
256 320
        """ Return a list of neighbors to a certain point.
257 321
258 322
        Args:
@@ -272,17 +336,16 @@
Loading
272 336
            (8 and 4 if not including the center).
273 337
274 338
        """
275 -
        return list(self.iter_neighbors(pos, moore, include_center, radius))
339 +
        neighborhood = self.get_neighborhood(pos, moore, include_center, radius)
340 +
        return self.get_agents(neighborhood)
276 341
277 342
    def torus_adj(self, pos: Coordinate) -> Coordinate:
278 343
        """ Convert coordinate, handling torus looping. """
279 344
        if not self.out_of_bounds(pos):
280 345
            return pos
281 -
        elif not self.torus:
346 +
        if not self.torus:
282 347
            raise Exception("Point out of bounds, and space non-toroidal.")
283 -
        else:
284 -
            x, y = pos[0] % self.width, pos[1] % self.height
285 -
        return x, y
348 +
        return pos[0] % self.width, pos[1] % self.height
286 349
287 350
    def out_of_bounds(self, pos: Coordinate) -> bool:
288 351
        """
@@ -304,7 +367,11 @@
Loading
304 367
            An iterator of the contents of the cells identified in cell_list
305 368
306 369
        """
307 -
        return (self[x][y] for x, y in cell_list if not self.is_cell_empty((x, y)))
370 +
        warnings.warn(
371 +
            "`iter_cell_list_contents is deprecated, use `get_agents` instead",
372 +
            DeprecationWarning,
373 +
        )
374 +
        yield from self.get_agents(cell_list)
308 375
309 376
    @accept_tuple_argument
310 377
    def get_cell_list_contents(
@@ -318,9 +385,13 @@
Loading
318 385
            A list of the contents of the cells identified in cell_list
319 386
320 387
        """
321 -
        return list(self.iter_cell_list_contents(cell_list))
388 +
        warnings.warn(
389 +
            "`iter_cell_list_contents is deprecated, use `get_agents` instead",
390 +
            DeprecationWarning,
391 +
        )
392 +
        return self.get_agents(cell_list)
322 393
323 -
    def move_agent(self, agent: Agent, pos: Coordinate) -> None:
394 +
    def move_agent(self, agent: Agent, pos: Coordinate) -> Agent:
324 395
        """
325 396
        Move an agent from its current position to a new position.
326 397
@@ -331,47 +402,37 @@
Loading
331 402
332 403
        """
333 404
        pos = self.torus_adj(pos)
334 -
        self._remove_agent(agent.pos, agent)
335 -
        self._place_agent(pos, agent)
336 -
        agent.pos = pos
405 +
        self.remove_agent(agent)
406 +
        self.place_agent(agent, pos)
407 +
        return agent
337 408
338 409
    def place_agent(self, agent: Agent, pos: Coordinate) -> None:
339 410
        """ Position an agent on the grid, and set its pos variable. """
340 -
        self._place_agent(pos, agent)
341 -
        agent.pos = pos
342 -
343 -
    def _place_agent(self, pos: Coordinate, agent: Agent) -> None:
344 -
        """ Place the agent at the correct location. """
345 -
        x, y = pos
346 -
        self.grid[x][y] = agent
347 -
        self.empties.discard(pos)
411 +
        self._get(*pos).append(agent)
412 +
        self._empties.discard(pos)
413 +
        setattr(agent, "pos", pos)
414 +
        return agent
348 415
349 416
    def remove_agent(self, agent: Agent) -> None:
350 417
        """ Remove the agent from the grid and set its pos variable to None. """
351 -
        pos = agent.pos
352 -
        self._remove_agent(pos, agent)
353 -
        agent.pos = None
354 -
355 -
    def _remove_agent(self, pos: Coordinate, agent: Agent) -> None:
356 -
        """ Remove the agent from the given location. """
357 -
        x, y = pos
358 -
        self.grid[x][y] = None
359 -
        self.empties.add(pos)
418 +
        pos = getattr(agent, "pos")
419 +
        content = self._get(*pos)
420 +
        content.remove(agent)
421 +
        if not content:
422 +
            self._empties.add(pos)
423 +
        setattr(agent, "pos", None)
424 +
        return agent
360 425
361 426
    def is_cell_empty(self, pos: Coordinate) -> bool:
362 427
        """ Returns a bool of the contents of a cell. """
363 -
        x, y = pos
364 -
        return self.grid[x][y] == self.default_val()
428 +
        return not bool(self._get(*pos))
365 429
366 430
    def move_to_empty(self, agent: Agent) -> None:
367 431
        """ Moves agent to a random empty cell, vacating agent's old cell. """
368 -
        pos = agent.pos
369 -
        if len(self.empties) == 0:
432 +
        if len(self._empties) == 0:
370 433
            raise Exception("ERROR: No empty cells")
371 -
        new_pos = agent.random.choice(sorted(self.empties))
372 -
        self._place_agent(new_pos, agent)
373 -
        agent.pos = new_pos
374 -
        self._remove_agent(pos, agent)
434 +
        new_pos = agent.random.choice(self.empties)
435 +
        self.move_agent(agent, new_pos)
375 436
376 437
    def find_empty(self) -> Optional[Coordinate]:
377 438
        """ Pick a random empty cell. """
@@ -389,7 +450,7 @@
Loading
389 450
        )
390 451
391 452
        if self.exists_empty_cells():
392 -
            pos = random.choice(sorted(self.empties))
453 +
            pos = random.choice(self.empties)
393 454
            return pos
394 455
        else:
395 456
            return None
@@ -398,11 +459,17 @@
Loading
398 459
        """ Return True if any cells empty else False. """
399 460
        return len(self.empties) > 0
400 461
462 +
    @property
463 +
    def empties(self) -> List[Coordinate]:
464 +
        return sorted(self._empties)
401 465
402 -
class SingleGrid(Grid):
403 -
    """ Grid where each cell contains exactly at most one object. """
466 +
    @property
467 +
    def all_cells(self) -> List[Coordinate]:
468 +
        return sorted(self._all_cells)
404 469
405 -
    empties = set()  # type: Set[Coordinate]
470 +
471 +
class SingleGrid(MultiGrid):
472 +
    """ Grid where each cell contains exactly at most one object. """
406 473
407 474
    def __init__(self, width: int, height: int, torus: bool) -> None:
408 475
        """ Create a new single-item grid.
@@ -414,87 +481,51 @@
Loading
414 481
        """
415 482
        super().__init__(width, height, torus)
416 483
484 +
    def __getitem__(self, pos: Coordinate) -> Optional[Agent]:
485 +
        if isinstance(pos, int):
486 +
            warnings.warn(
487 +
                """Accesing the grid via `grid[x][y]` is deprecated.
488 +
                Use `grid[x, y]` instead.""",
489 +
                category=DeprecationWarning,
490 +
            )
491 +
        content = self._get(*pos)
492 +
        return content[0] if content else None
493 +
417 494
    def position_agent(
418 -
        self, agent: Agent, x: Union[int, str] = "random", y: Union[int, str] = "random"
495 +
        self, agent: Agent, x: Union[str, int] = "random", y: Union[str, int] = "random"
419 496
    ) -> None:
420 497
        """ Position an agent on the grid.
421 498
        This is used when first placing agents! Use 'move_to_empty()'
422 499
        when you want agents to jump to an empty cell.
423 -
        Use 'swap_pos()' to swap agents positions.
424 500
        If x or y are positive, they are used, but if "random",
425 501
        we get a random position.
426 502
        Ensure this random position is not occupied (in Grid).
427 503
428 504
        """
505 +
        # TODO: Allow to use only one random value
429 506
        if x == "random" or y == "random":
430 -
            if len(self.empties) == 0:
507 +
            if len(self._empties) == 0:
431 508
                raise Exception("ERROR: Grid full")
432 -
            coords = agent.random.choice(sorted(self.empties))
509 +
            coords = agent.random.choice(self.empties)  # type: Tuple[int, int]
433 510
        else:
434 -
            coords = (x, y)
511 +
            coords = (int(x), int(y))
435 512
        agent.pos = coords
436 -
        self._place_agent(coords, agent)
437 -
438 -
    def _place_agent(self, pos: Coordinate, agent: Agent) -> None:
439 -
        if self.is_cell_empty(pos):
440 -
            super()._place_agent(pos, agent)
441 -
        else:
442 -
            raise Exception("Cell not empty")
443 -
444 -
445 -
class MultiGrid(Grid):
446 -
    """ Grid where each cell can contain more than one object.
447 -
448 -
    Grid cells are indexed by [x][y], where [0][0] is assumed to be at
449 -
    bottom-left and [width-1][height-1] is the top-right. If a grid is
450 -
    toroidal, the top and bottom, and left and right, edges wrap to each other.
451 -
452 -
    Each grid cell holds a set object.
453 -
454 -
    Properties:
455 -
        width, height: The grid's width and height.
456 -
457 -
        torus: Boolean which determines whether to treat the grid as a torus.
458 -
459 -
        grid: Internal list-of-lists which holds the grid cells themselves.
460 -
461 -
    Methods:
462 -
        get_neighbors: Returns the objects surrounding a given cell.
463 -
    """
464 -
465 -
    @staticmethod
466 -
    def default_val() -> Set[Agent]:
467 -
        """ Default value for new cell elements. """
468 -
        return set()
469 -
470 -
    def _place_agent(self, pos: Coordinate, agent: Agent) -> None:
471 -
        """ Place the agent at the correct location. """
472 -
        x, y = pos
473 -
        self.grid[x][y].add(agent)
474 -
        self.empties.discard(pos)
475 -
476 -
    def _remove_agent(self, pos: Coordinate, agent: Agent) -> None:
477 -
        """ Remove the agent from the given location. """
478 -
        x, y = pos
479 -
        self.grid[x][y].remove(agent)
480 -
        if self.is_cell_empty(pos):
481 -
            self.empties.add(pos)
513 +
        self.place_agent(agent, coords)
482 514
483 515
    @accept_tuple_argument
484 -
    def iter_cell_list_contents(
485 -
        self, cell_list: Iterable[Coordinate]
486 -
    ) -> Iterator[GridContent]:
487 -
        """
488 -
        Args:
489 -
            cell_list: Array-like of (x, y) tuples, or single tuple.
516 +
    def get_agents(self, cell_list: Iterable[Coordinate]) -> List[Agent]:
517 +
        """Return a list of agents from the given cell list."""
518 +
        return list(filter(None, self.get_contents(cell_list)))
490 519
491 -
        Returns:
492 -
            A iterator of the contents of the cells identified in cell_list
493 520
494 -
        """
495 -
        return itertools.chain.from_iterable(
496 -
            self[x][y] for x, y in cell_list if not self.is_cell_empty((x, y))
497 -
        )
521 +
class Grid(SingleGrid):
522 +
    """ Grid where each cell contains exactly at most one object."""
523 +
524 +
    def place_agent(self, agent: Agent, pos: Coordinate) -> Agent:
525 +
        if not self.is_cell_empty(pos):
526 +
            self._get(*pos).clear()
527 +
            self._empties.add(pos)
528 +
        return super().place_agent(agent, pos)
498 529
499 530
500 531
class HexGrid(Grid):

Everything is accounted for!

No changes detected that need to be reviewed.
What changes does Codecov check for?
Lines, not adjusted in diff, that have changed coverage data.
Files that introduced coverage data that had none before.
Files that have missing coverage data that once were tracked.
Files Coverage
mesa -0.47% 84.25%
Project Totals (17 files) 84.25%
Loading