"""
Mesa Space Module
=================

Objects used to add a spatial component to a model.

Grid: base grid, a simple list-of-lists.
SingleGrid: grid which strictly enforces one object per cell.
MultiGrid: extension to Grid where each cell is a list of objects.
HexGrid: extension to Grid which handles hexagonal neighbors.
ContinuousSpace: space where each agent can have an arbitrary position.
NetworkGrid: network where each node contains zero or more agents.

"""
# Instruction for PyLint to suppress variable name errors, since we have a
# good reason to use one-character variable names for x and y.
# pylint: disable=invalid-name

import functools
import itertools
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union

import numpy as np

from mesa.agent import Agent

Coordinate = Tuple[int, int]
# Content of a single grid cell: an agent or None for Grid/SingleGrid, or a
# list of agents for MultiGrid. Set[Agent] is kept for backward compatibility.
GridContent = Union[Optional[Agent], List[Agent], Set[Agent]]
# used in ContinuousSpace
FloatCoordinate = Union[Tuple[float, float], np.ndarray]


def accept_tuple_argument(wrapped_function):
    """Decorator to allow grid methods that take a list of (x, y) coord tuples
    to also handle a single position, by automatically wrapping tuple in
    single-item list rather than forcing user to do it.

    The wrapped method's name and docstring are preserved. Keyword arguments
    and any extra positional arguments are passed through unchanged.

    Args:
        wrapped_function: A method whose first argument after ``self`` is an
            iterable of (x, y) coordinate tuples.

    Returns:
        The wrapping function.
    """

    @functools.wraps(wrapped_function)
    def wrapper(*args: Any, **kwargs: Any):
        # args[0] is the grid instance; args[1], when present, is the cell
        # list. A lone 2-tuple is treated as a single coordinate.
        if len(args) > 1 and isinstance(args[1], tuple) and len(args[1]) == 2:
            return wrapped_function(args[0], [args[1]], *args[2:], **kwargs)
        return wrapped_function(*args, **kwargs)

    return wrapper


class Grid:
    """Base class for a rectangular grid.

    Grid cells are indexed by [x][y], where [0][0] is assumed to be the
    bottom-left and [width-1][height-1] is the top-right. If a grid is
    toroidal, the top and bottom, and left and right, edges wrap to each other.

    Properties:
        width, height: The grid's width and height.
        torus: Boolean which determines whether to treat the grid as a torus.
        grid: Internal list-of-lists which holds the grid cells themselves.
        empties: Set of (x, y) coordinates of the cells that are empty.

    Methods:
        get_neighbors: Returns the objects surrounding a given cell.
        get_neighborhood: Returns the cells surrounding a given cell.
        get_cell_list_contents: Returns the contents of a list of cells
            ((x,y) tuples)
        neighbor_iter: Iterates over position neighbors.
        coord_iter: Returns coordinates as well as cell contents.
        place_agent: Positions an agent on the grid, and set its pos variable.
        move_agent: Moves an agent from its current position to a new position.
        iter_neighborhood: Returns an iterator over cell coordinates that are
            in the neighborhood of a certain point.
        torus_adj: Converts coordinate, handles torus looping.
        out_of_bounds: Determines whether a position is off the grid.
        iter_cell_list_contents: Returns an iterator of the contents of the
            cells identified in cell_list.
        get_cell_list_contents: Returns a list of the contents of the cells
            identified in cell_list.
        remove_agent: Removes an agent from the grid.
        is_cell_empty: Returns a bool of the contents of a cell.
        move_to_empty: Moves an agent to a random empty cell.
        exists_empty_cells: Returns whether any cell is empty.

    """

    def __init__(self, width: int, height: int, torus: bool) -> None:
        """Create a new grid.

        Args:
            width, height: The width and height of the grid
            torus: Boolean whether the grid wraps or not.

        """
        self.height = height
        self.width = width
        self.torus = torus

        # One column per x value; default_val() is called once per cell so
        # that subclasses returning mutable defaults get a fresh object each.
        self.grid = []  # type: List[List[GridContent]]
        for _ in range(self.width):
            self.grid.append([self.default_val() for _ in range(self.height)])

        # Add all cells to the empties set.
        self.empties = set(itertools.product(range(self.width), range(self.height)))

    @staticmethod
    def default_val() -> None:
        """ Default value for new cell elements. """
        return None

    def __getitem__(self, index: int) -> List[GridContent]:
        """Return column ``index`` of the grid, so ``grid[x][y]`` works."""
        return self.grid[index]

    def __iter__(self) -> Iterator[GridContent]:
        """
        create an iterator that chains the
        columns of grid together as if one list:
        """
        return itertools.chain(*self.grid)

    def coord_iter(self) -> Iterator[Tuple[GridContent, int, int]]:
        """ An iterator that returns coordinates as well as cell contents. """
        for row in range(self.width):
            for col in range(self.height):
                yield self.grid[row][col], row, col  # agent, x, y

    def neighbor_iter(
        self, pos: Coordinate, moore: bool = True
    ) -> Iterator[GridContent]:
        """Iterate over position neighbors.

        Args:
            pos: (x,y) coords tuple for the position to get the neighbors of.
            moore: Boolean for whether to use Moore neighborhood (including
                   diagonals) or Von Neumann (only up/down/left/right).

        """
        neighborhood = self.iter_neighborhood(pos, moore=moore)
        return self.iter_cell_list_contents(neighborhood)

    def iter_neighborhood(
        self,
        pos: Coordinate,
        moore: bool,
        include_center: bool = False,
        radius: int = 1,
    ) -> Iterator[Coordinate]:
        """Return an iterator over cell coordinates that are in the
        neighborhood of a certain point.

        Args:
            pos: Coordinate tuple for the neighborhood to get.
            moore: If True, return Moore neighborhood
                        (including diagonals)
                   If False, return Von Neumann neighborhood
                        (exclude diagonals)
            include_center: If True, return the (x, y) cell as well.
                            Otherwise, return surrounding cells only.
            radius: radius, in cells, of neighborhood to get.

        Returns:
            An iterator of unique coordinate tuples representing the
            neighborhood. For example with radius 1 it yields at most 9 (8)
            cells if Moore, 5 (4) if Von Neumann (if not including the center).

        """
        x, y = pos
        seen = set()  # type: Set[Coordinate]
        for dy in range(-radius, radius + 1):
            for dx in range(-radius, radius + 1):
                if dx == 0 and dy == 0 and not include_center:
                    continue
                # Skip coordinates that are outside manhattan distance
                if not moore and abs(dx) + abs(dy) > radius:
                    continue
                nx, ny = x + dx, y + dy
                # Skip if not a torus and new coords out of bounds.
                if not self.torus and self.out_of_bounds((nx, ny)):
                    continue
                # Wraps on a torus; after the check above every coordinate
                # returned here is in bounds.
                coords = self.torus_adj((nx, ny))
                # A large radius on a small torus can wrap onto the same cell
                # several times; yield each cell only once.
                if coords not in seen:
                    seen.add(coords)
                    yield coords

    def get_neighborhood(
        self,
        pos: Coordinate,
        moore: bool,
        include_center: bool = False,
        radius: int = 1,
    ) -> List[Coordinate]:
        """Return a list of cells that are in the neighborhood of a
        certain point.

        Args:
            pos: Coordinate tuple for the neighborhood to get.
            moore: If True, return Moore neighborhood
                   (including diagonals)
                   If False, return Von Neumann neighborhood
                   (exclude diagonals)
            include_center: If True, return the (x, y) cell as well.
                            Otherwise, return surrounding cells only.
            radius: radius, in cells, of neighborhood to get.

        Returns:
            A list of coordinate tuples representing the neighborhood;
            With radius 1, at most 9 if Moore, 5 if Von Neumann (8 and 4
            if not including the center).

        """
        return list(self.iter_neighborhood(pos, moore, include_center, radius))

    def iter_neighbors(
        self,
        pos: Coordinate,
        moore: bool,
        include_center: bool = False,
        radius: int = 1,
    ) -> Iterator[GridContent]:
        """Return an iterator over neighbors to a certain point.

        Args:
            pos: Coordinates for the neighborhood to get.
            moore: If True, return Moore neighborhood
                    (including diagonals)
                   If False, return Von Neumann neighborhood
                     (exclude diagonals)
            include_center: If True, return the (x, y) cell as well.
                            Otherwise,
                            return surrounding cells only.
            radius: radius, in cells, of neighborhood to get.

        Returns:
            An iterator of the contents of the non-empty cells in the given
            neighborhood; with radius 1, at most 9 cells if Moore, 5 if
            Von Neumann (8 and 4 if not including the center).

        """
        neighborhood = self.iter_neighborhood(pos, moore, include_center, radius)
        return self.iter_cell_list_contents(neighborhood)

    def get_neighbors(
        self,
        pos: Coordinate,
        moore: bool,
        include_center: bool = False,
        radius: int = 1,
    ) -> List[GridContent]:
        """Return a list of neighbors to a certain point.

        Args:
            pos: Coordinate tuple for the neighborhood to get.
            moore: If True, return Moore neighborhood
                    (including diagonals)
                   If False, return Von Neumann neighborhood
                     (exclude diagonals)
            include_center: If True, return the (x, y) cell as well.
                            Otherwise,
                            return surrounding cells only.
            radius: radius, in cells, of neighborhood to get.

        Returns:
            A list of the contents of the non-empty cells in the given
            neighborhood; with radius 1, at most 9 cells if Moore, 5 if
            Von Neumann (8 and 4 if not including the center).

        """
        return list(self.iter_neighbors(pos, moore, include_center, radius))

    def torus_adj(self, pos: Coordinate) -> Coordinate:
        """Convert coordinate, handling torus looping.

        In-bounds positions are returned unchanged. Out-of-bounds positions
        are wrapped on a torus.

        Raises:
            Exception: if pos is out of bounds and the grid is not toroidal.
        """
        if not self.out_of_bounds(pos):
            return pos
        if not self.torus:
            raise Exception("Point out of bounds, and space non-toroidal.")
        return pos[0] % self.width, pos[1] % self.height

    def out_of_bounds(self, pos: Coordinate) -> bool:
        """Return True if the (x, y) position lies outside the grid."""
        x, y = pos
        return x < 0 or x >= self.width or y < 0 or y >= self.height

    @accept_tuple_argument
    def iter_cell_list_contents(
        self, cell_list: Iterable[Coordinate]
    ) -> Iterator[GridContent]:
        """
        Args:
            cell_list: Array-like of (x, y) tuples, or single tuple.

        Returns:
            An iterator of the contents of the non-empty cells identified in
            cell_list

        """
        return (self[x][y] for x, y in cell_list if not self.is_cell_empty((x, y)))

    @accept_tuple_argument
    def get_cell_list_contents(
        self, cell_list: Iterable[Coordinate]
    ) -> List[GridContent]:
        """
        Args:
            cell_list: Array-like of (x, y) tuples, or single tuple.

        Returns:
            A list of the contents of the non-empty cells identified in
            cell_list

        """
        return list(self.iter_cell_list_contents(cell_list))

    def move_agent(self, agent: Agent, pos: Coordinate) -> None:
        """
        Move an agent from its current position to a new position.

        Args:
            agent: Agent object to move. Assumed to have its current location
                   stored in a 'pos' tuple.
            pos: Tuple of new position to move the agent to. It is wrapped
                 with torus_adj, so it may be out of bounds on a torus.

        """
        pos = self.torus_adj(pos)
        self._remove_agent(agent.pos, agent)
        self._place_agent(pos, agent)
        agent.pos = pos

    def place_agent(self, agent: Agent, pos: Coordinate) -> None:
        """ Position an agent on the grid, and set its pos variable. """
        self._place_agent(pos, agent)
        agent.pos = pos

    def _place_agent(self, pos: Coordinate, agent: Agent) -> None:
        """ Place the agent at the correct location. """
        x, y = pos
        self.grid[x][y] = agent
        self.empties.discard(pos)

    def remove_agent(self, agent: Agent) -> None:
        """ Remove the agent from the grid and set its pos variable to None. """
        pos = agent.pos
        self._remove_agent(pos, agent)
        agent.pos = None

    def _remove_agent(self, pos: Coordinate, agent: Agent) -> None:
        """ Remove the agent from the given location. """
        x, y = pos
        self.grid[x][y] = None
        self.empties.add(pos)

    def is_cell_empty(self, pos: Coordinate) -> bool:
        """Return True if the cell at pos holds the default (empty) value."""
        x, y = pos
        return self.grid[x][y] == self.default_val()

    def move_to_empty(self, agent: Agent) -> None:
        """Move agent to a random empty cell, vacating agent's old cell.

        The cell is chosen with the agent's own random generator. Empties are
        sorted first so the choice is reproducible for a given seed.

        Raises:
            Exception: if there are no empty cells.
        """
        pos = agent.pos
        if not self.empties:
            raise Exception("ERROR: No empty cells")
        new_pos = agent.random.choice(sorted(self.empties))
        self._place_agent(new_pos, agent)
        agent.pos = new_pos
        self._remove_agent(pos, agent)

    def find_empty(self) -> Optional[Coordinate]:
        """Pick a random empty cell (deprecated: uses the global RNG).

        Returns:
            A random empty coordinate, or None if the grid is full.
        """
        from warnings import warn
        import random

        warn(
            (
                "`find_empty` is being phased out since it uses the global "
                "`random` instead of the model-level random-number generator. "
                "Consider replacing it with having a model or agent object "
                "explicitly pick one of the grid's list of empty cells."
            ),
            DeprecationWarning,
            # Attribute the warning to the caller, not to this module.
            stacklevel=2,
        )

        if self.exists_empty_cells():
            return random.choice(sorted(self.empties))
        return None

    def exists_empty_cells(self) -> bool:
        """ Return True if any cells empty else False. """
        return len(self.empties) > 0


class SingleGrid(Grid):
    """ Grid where each cell contains at most one object. """

    # NOTE: this class-level set is shadowed by the per-instance `empties`
    # created in Grid.__init__. It is kept only for backward compatibility;
    # do not mutate it.
    empties = set()  # type: Set[Coordinate]

    def __init__(self, width: int, height: int, torus: bool) -> None:
        """Create a new single-item grid.

        Args:
            width, height: The width and height of the grid
            torus: Boolean whether the grid wraps or not.

        """
        super().__init__(width, height, torus)

    def position_agent(
        self, agent: Agent, x: Union[int, str] = "random", y: Union[int, str] = "random"
    ) -> None:
        """Position an agent on the grid.
        This is used when first placing agents! Use 'move_to_empty()'
        when you want agents to jump to an empty cell.
        Use 'swap_pos()' to swap agents positions.
        If x and y are integers, they are used as the position; if either is
        "random", a random empty cell is chosen with the agent's RNG.

        Raises:
            Exception: if a random cell is requested and the grid is full, or
                if the requested cell is already occupied. In either case
                agent.pos is left unchanged.
        """
        if x == "random" or y == "random":
            if not self.empties:
                raise Exception("ERROR: Grid full")
            coords = agent.random.choice(sorted(self.empties))
        else:
            coords = (x, y)
        # Place first so that a failed placement (occupied cell) does not
        # leave the agent pointing at a cell it never entered.
        self._place_agent(coords, agent)
        agent.pos = coords

    def _place_agent(self, pos: Coordinate, agent: Agent) -> None:
        """Place the agent at pos, refusing to overwrite an occupied cell.

        Raises:
            Exception: if the cell is not empty.
        """
        if not self.is_cell_empty(pos):
            raise Exception("Cell not empty")
        super()._place_agent(pos, agent)


class MultiGrid(Grid):
    """Grid where each cell can contain more than one object.

    Grid cells are indexed by [x][y], where [0][0] is assumed to be at
    bottom-left and [width-1][height-1] is the top-right. If a grid is
    toroidal, the top and bottom, and left and right, edges wrap to each other.

    Each grid cell holds a list of agents, in placement order. An agent
    appears at most once per cell.

    Properties:
        width, height: The grid's width and height.

        torus: Boolean which determines whether to treat the grid as a torus.

        grid: Internal list-of-lists which holds the grid cells themselves.

    Methods:
        get_neighbors: Returns the objects surrounding a given cell.
    """

    @staticmethod
    def default_val() -> List[Agent]:
        """Default value for new cell elements: a fresh empty list."""
        return []

    def _place_agent(self, pos: Coordinate, agent: Agent) -> None:
        """Add the agent to the cell at pos (no-op if it is already there)."""
        x, y = pos
        cell = self.grid[x][y]
        if agent not in cell:
            cell.append(agent)
        self.empties.discard(pos)

    def _remove_agent(self, pos: Coordinate, agent: Agent) -> None:
        """Remove the agent from the cell at pos, marking it empty if vacated.

        Raises:
            ValueError: if the agent is not in that cell (from list.remove).
        """
        x, y = pos
        self.grid[x][y].remove(agent)
        if self.is_cell_empty(pos):
            self.empties.add(pos)

    @accept_tuple_argument
    def iter_cell_list_contents(
        self, cell_list: Iterable[Coordinate]
    ) -> Iterator[GridContent]:
        """
        Args:
            cell_list: Array-like of (x, y) tuples, or single tuple.

        Returns:
            An iterator over the agents in the cells identified in cell_list,
            flattened across cells.

        """
        return itertools.chain.from_iterable(
            self[x][y] for x, y in cell_list if not self.is_cell_empty((x, y))
        )


class HexGrid(Grid):
    """Hexagonal Grid: Extends Grid to handle hexagonal neighbors.

    Functions according to odd-q rules.
    See http://www.redblobgames.com/grids/hexagons/#coordinates for more.

    Properties:
        width, height: The grid's width and height.
        torus: Boolean which determines whether to treat the grid as a torus.

    Methods:
        get_neighbors: Returns the objects surrounding a given cell.
        get_neighborhood: Returns the cells surrounding a given cell.
        neighbor_iter: Iterates over position neighbors.
        iter_neighborhood: Returns an iterator over cell coordinates that are
            in the neighborhood of a certain point.

    """

    def iter_neighborhood(
        self, pos: Coordinate, include_center: bool = False, radius: int = 1
    ) -> Iterator[Coordinate]:
        """Return an iterator over cell coordinates that are in the
        neighborhood of a certain point.

        Args:
            pos: Coordinate tuple for the neighborhood to get.
            include_center: If True, return the (x, y) cell as well.
                            Otherwise, return surrounding cells only.
            radius: radius, in cells, of neighborhood to get. A radius below
                    1 yields only the center (if include_center).

        Returns:
            An iterator of unique coordinate tuples within `radius` hex steps
            of pos. With radius 1 it yields at most 6 cells (7 including the
            center). The order is unspecified.

        """

        def adjacent_cells(cell: Coordinate) -> List[Coordinate]:
            """Return the valid cells adjacent to cell (odd-q layout).

            Both: (0,-), (0,+)

            Even: (-,+), (-,0), (+,+), (+,0)
            Odd:  (-,0), (-,-), (+,0), (+,-)
            """
            x, y = cell
            cells = [(x, y - 1), (x, y + 1)]
            if x % 2 == 0:
                cells += [(x - 1, y + 1), (x - 1, y), (x + 1, y + 1), (x + 1, y)]
            else:
                cells += [(x - 1, y), (x - 1, y - 1), (x + 1, y), (x + 1, y - 1)]
            if self.torus:
                return [(cx % self.width, cy % self.height) for cx, cy in cells]
            return [c for c in cells if not self.out_of_bounds(c)]

        coordinates = set()  # type: Set[Coordinate]

        # Breadth-first expansion: each cell is expanded at most once, at its
        # shortest distance from pos. This yields the same cells as walking
        # every path of length <= radius, without exponential blow-up.
        expanded = {pos}  # type: Set[Coordinate]
        frontier = [pos]
        for _ in range(radius):
            next_frontier = []  # type: List[Coordinate]
            for cell in frontier:
                for neighbor in adjacent_cells(cell):
                    coordinates.add(neighbor)
                    if neighbor not in expanded:
                        expanded.add(neighbor)
                        next_frontier.append(neighbor)
            frontier = next_frontier

        if include_center:
            # The center gets the same wrap / bounds treatment as neighbors.
            if self.torus:
                coordinates.add((pos[0] % self.width, pos[1] % self.height))
            elif not self.out_of_bounds(pos):
                coordinates.add(pos)
        else:
            # pos is reachable again via back-and-forth paths when radius > 1.
            coordinates.discard(pos)

        yield from coordinates

    def neighbor_iter(self, pos: Coordinate) -> Iterator[GridContent]:
        """Iterate over position neighbors.

        Args:
            pos: (x,y) coords tuple for the position to get the neighbors of.

        """
        neighborhood = self.iter_neighborhood(pos)
        return self.iter_cell_list_contents(neighborhood)

    def get_neighborhood(
        self, pos: Coordinate, include_center: bool = False, radius: int = 1
    ) -> List[Coordinate]:
        """Return a list of cells that are in the neighborhood of a
        certain point.

        Args:
            pos: Coordinate tuple for the neighborhood to get.
            include_center: If True, return the (x, y) cell as well.
                            Otherwise, return surrounding cells only.
            radius: radius, in cells, of neighborhood to get.

        Returns:
            A list of coordinate tuples representing the neighborhood;
            With radius 1, at most 6 cells (7 including the center).

        """
        return list(self.iter_neighborhood(pos, include_center, radius))

    def iter_neighbors(
        self, pos: Coordinate, include_center: bool = False, radius: int = 1
    ) -> Iterator[GridContent]:
        """Return an iterator over neighbors to a certain point.

        Args:
            pos: Coordinates for the neighborhood to get.
            include_center: If True, return the (x, y) cell as well.
                            Otherwise,
                            return surrounding cells only.
            radius: radius, in cells, of neighborhood to get.

        Returns:
            An iterator of the contents of the non-empty cells in the given
            neighborhood

        """
        neighborhood = self.iter_neighborhood(pos, include_center, radius)
        return self.iter_cell_list_contents(neighborhood)

    def get_neighbors(
        self, pos: Coordinate, include_center: bool = False, radius: int = 1
    ) -> List[GridContent]:
        """Return a list of neighbors to a certain point.

        Args:
            pos: Coordinate tuple for the neighborhood to get.
            include_center: If True, return the (x, y) cell as well.
                            Otherwise,
                            return surrounding cells only.
            radius: radius, in cells, of neighborhood to get.

        Returns:
            A list of the contents of the non-empty cells in the given
            neighborhood

        """
        return list(self.iter_neighbors(pos, include_center, radius))


class ContinuousSpace:
    """Continuous space where each agent can have an arbitrary position.

    Assumes that all agents are point objects, and have a pos property storing
    their position as an (x, y) tuple. This class uses a numpy array internally
    to store agent positions, to speed up neighborhood lookups.

    """

    _grid = None

    def __init__(
        self,
        x_max: float,
        y_max: float,
        torus: bool,
        x_min: float = 0,
        y_min: float = 0,
    ) -> None:
        """Create a new continuous space.

        Args:
            x_max, y_max: Maximum x and y coordinates for the space.
            torus: Boolean for whether the edges loop around.
            x_min, y_min: (default 0) If provided, set the minimum x and y
                          coordinates for the space. Below them, values loop to
                          the other edge (if torus=True) or raise an exception.

        """
        self.x_min = x_min
        self.x_max = x_max
        self.width = x_max - x_min
        self.y_min = y_min
        self.y_max = y_max
        self.height = y_max - y_min
        self.center = np.array(((x_max + x_min) / 2, (y_max + y_min) / 2))
        self.size = np.array((self.width, self.height))
        self.torus = torus

        # Row i of _agent_points is the position of _index_to_agent[i];
        # None until the first agent is placed.
        self._agent_points = None
        self._index_to_agent = {}  # type: Dict[int, Agent]
        self._agent_to_index = {}  # type: Dict[Agent, int]

    def place_agent(self, agent: Agent, pos: FloatCoordinate) -> None:
        """Place a new agent in the space.

        Args:
            agent: Agent object to place.
            pos: Coordinate tuple for where to place the agent.

        """
        pos = self.torus_adj(pos)
        if self._agent_points is None:
            self._agent_points = np.array([pos])
        else:
            self._agent_points = np.append(self._agent_points, np.array([pos]), axis=0)
        new_index = self._agent_points.shape[0] - 1
        self._index_to_agent[new_index] = agent
        self._agent_to_index[agent] = new_index
        agent.pos = pos

    def move_agent(self, agent: Agent, pos: FloatCoordinate) -> None:
        """Move an agent from its current position to a new position.

        Args:
            agent: The agent object to move.
            pos: Coordinate tuple to move the agent to.

        """
        pos = self.torus_adj(pos)
        idx = self._agent_to_index[agent]
        self._agent_points[idx, 0] = pos[0]
        self._agent_points[idx, 1] = pos[1]
        agent.pos = pos

    def remove_agent(self, agent: Agent) -> None:
        """Remove an agent from the simulation.

        Args:
            agent: The agent object to remove

        Raises:
            Exception: if the agent is not in the space.
        """
        if agent not in self._agent_to_index:
            raise Exception("Agent does not exist in the space")
        idx = self._agent_to_index[agent]
        del self._agent_to_index[agent]
        max_idx = max(self._index_to_agent.keys())
        # Delete the agent's position and decrement the index/agent mapping
        # of every agent stored after it, keeping rows and indices aligned.
        self._agent_points = np.delete(self._agent_points, idx, axis=0)
        for a, index in self._agent_to_index.items():
            if index > idx:
                self._agent_to_index[a] = index - 1
                self._index_to_agent[index - 1] = a
        # The largest index is now redundant
        del self._index_to_agent[max_idx]
        agent.pos = None

    def get_neighbors(
        self, pos: FloatCoordinate, radius: float, include_center: bool = True
    ) -> List[GridContent]:
        """Get all objects within a certain radius.

        Args:
            pos: (x,y) coordinate tuple to center the search at.
            radius: Get all the objects within this distance of the center.
            include_center: If True, include an object at the *exact* provided
                            coordinates. i.e. if you are searching for the
                            neighbors of a given agent, True will include that
                            agent in the results.

        Returns:
            A list of agents within radius of pos (empty if the space has
            never held an agent).

        """
        if self._agent_points is None:
            # No agent has ever been placed, so there is nothing to search.
            return []
        deltas = np.abs(self._agent_points - np.array(pos))
        if self.torus:
            deltas = np.minimum(deltas, self.size - deltas)
        dists = deltas[:, 0] ** 2 + deltas[:, 1] ** 2

        (idxs,) = np.where(dists <= radius ** 2)
        return [
            self._index_to_agent[x] for x in idxs if include_center or dists[x] > 0
        ]

    def get_heading(
        self, pos_1: FloatCoordinate, pos_2: FloatCoordinate
    ) -> FloatCoordinate:
        """Get the heading vector from pos_1 to pos_2, accounting for toroidal space.

        On a torus each component is the shortest signed displacement, wrapped
        into [-size/2, size/2).

        Args:
            pos_1, pos_2: Coordinate tuples for both points.

        Returns:
            A tuple if pos_1 is a tuple, otherwise a numpy array.
        """
        one = np.array(pos_1)
        two = np.array(pos_2)
        heading = two - one
        if self.torus:
            # Shifting both points by the center before the modulo does not
            # give the shortest displacement (e.g. 4 -> 6 in a width-10 space
            # would be -8). Wrap the difference itself instead.
            half_size = self.size / 2
            heading = (heading + half_size) % self.size - half_size
        if isinstance(pos_1, tuple):
            heading = tuple(heading)
        return heading

    def get_distance(self, pos_1: FloatCoordinate, pos_2: FloatCoordinate) -> float:
        """Get the distance between two points, accounting for toroidal space.

        Args:
            pos_1, pos_2: Coordinate tuples for both points.

        """
        x1, y1 = pos_1
        x2, y2 = pos_2

        dx = np.abs(x1 - x2)
        dy = np.abs(y1 - y2)
        if self.torus:
            dx = min(dx, self.width - dx)
            dy = min(dy, self.height - dy)
        return np.sqrt(dx * dx + dy * dy)

    def torus_adj(self, pos: FloatCoordinate) -> FloatCoordinate:
        """Adjust coordinates to handle torus looping.

        If the coordinate is out-of-bounds and the space is toroidal, return
        the corresponding point within the space. If the space is not toroidal,
        raise an exception.

        Args:
            pos: Coordinate tuple to convert.

        """
        if not self.out_of_bounds(pos):
            return pos
        if not self.torus:
            raise Exception("Point out of bounds, and space non-toroidal.")
        x = self.x_min + ((pos[0] - self.x_min) % self.width)
        y = self.y_min + ((pos[1] - self.y_min) % self.height)
        if isinstance(pos, tuple):
            return (x, y)
        return np.array((x, y))

    def out_of_bounds(self, pos: FloatCoordinate) -> bool:
        """Check if a point is out of bounds (upper bounds are exclusive)."""
        x, y = pos
        return x < self.x_min or x >= self.x_max or y < self.y_min or y >= self.y_max


class NetworkGrid:
    """ Network Grid where each node contains zero or more agents. """

    def __init__(self, G: Any) -> None:
        """Wrap graph G, giving every node an empty "agent" list attribute.

        Args:
            G: A graph exposing `nodes` (iterable, indexable by node id to an
               attribute dict) and `neighbors(node_id)`; presumably a
               networkx graph — confirm against callers.
        """
        self.G = G
        for node_id in self.G.nodes:
            self.G.nodes[node_id]["agent"] = list()

    def place_agent(self, agent: Agent, node_id: int) -> None:
        """ Place an agent in a node and set its pos variable. """

        self._place_agent(agent, node_id)
        agent.pos = node_id

    def get_neighbors(self, node_id: int, include_center: bool = False) -> List[int]:
        """ Get all adjacent nodes (plus node_id itself if include_center). """

        neighbors = list(self.G.neighbors(node_id))
        if include_center:
            neighbors.append(node_id)

        return neighbors

    def move_agent(self, agent: Agent, node_id: int) -> None:
        """ Move an agent from its current node to a new node. """

        self._remove_agent(agent, agent.pos)
        self._place_agent(agent, node_id)
        agent.pos = node_id

    def remove_agent(self, agent: Agent) -> None:
        """ Remove the agent from the network and set its pos variable to None. """
        self._remove_agent(agent, agent.pos)
        agent.pos = None

    def _place_agent(self, agent: Agent, node_id: int) -> None:
        """ Place the agent at the correct node. """

        self.G.nodes[node_id]["agent"].append(agent)

    def _remove_agent(self, agent: Agent, node_id: int) -> None:
        """ Remove an agent from a node. """

        self.G.nodes[node_id]["agent"].remove(agent)

    def is_cell_empty(self, node_id: int) -> bool:
        """ Returns True if no agent is at the node. """
        return not self.G.nodes[node_id]["agent"]

    def get_cell_list_contents(self, cell_list: List[int]) -> List[GridContent]:
        """ Return a flat list of the agents at the given nodes. """
        return list(self.iter_cell_list_contents(cell_list))

    def get_all_cell_contents(self) -> List[GridContent]:
        """ Return a flat list of the agents at every node of the graph. """
        return list(self.iter_cell_list_contents(self.G))

    def iter_cell_list_contents(self, cell_list: List[int]) -> List[GridContent]:
        """Return a flat list of the agents at the given nodes.

        NOTE: despite the name this builds a list eagerly. It is kept as a
        list for backward compatibility with callers that index or len() it.
        """
        # Empty nodes contribute nothing, so no explicit emptiness check.
        return [
            agent
            for node_id in cell_list
            for agent in self.G.nodes[node_id]["agent"]
        ]

Read our documentation on viewing source code .

Loading