From d63ce0680fb4da398879252f2ebef9c560eaf2f4 Mon Sep 17 00:00:00 2001 From: rht Date: Mon, 14 Oct 2024 05:59:03 -0400 Subject: [PATCH] refactor: Simplify Schelling code (#222) * refactor: Simplify Schelling code 1. Remove unused model attributes 2. Make `similar` calculation more natural language readable * Remove unused argument doc * Add type hints to agent class * refactor: Simplify self.running expression --- examples/schelling/model.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/examples/schelling/model.py b/examples/schelling/model.py index e995f31e..b7523ef2 100644 --- a/examples/schelling/model.py +++ b/examples/schelling/model.py @@ -6,24 +6,21 @@ class SchellingAgent(mesa.Agent): Schelling segregation agent """ - def __init__(self, model, agent_type): + def __init__(self, model: mesa.Model, agent_type: int) -> None: """ Create a new Schelling agent. Args: - x, y: Agent initial location. agent_type: Indicator for the agent's type (minority=1, majority=0) """ super().__init__(model) self.type = agent_type - def step(self): - similar = 0 - for neighbor in self.model.grid.iter_neighbors( + def step(self) -> None: + neighbors = self.model.grid.iter_neighbors( self.pos, moore=True, radius=self.model.radius - ): - if neighbor.type == self.type: - similar += 1 + ) + similar = sum(1 for neighbor in neighbors if neighbor.type == self.type) # If unhappy, move: if similar < self.model.homophily: @@ -60,10 +57,6 @@ def __init__( """ super().__init__(seed=seed) - self.height = height - self.width = width - self.density = density - self.minority_pc = minority_pc self.homophily = homophily self.radius = radius @@ -79,8 +72,8 @@ def __init__( # the coordinates of a cell as well as # its contents. (coord_iter) for _, pos in self.grid.coord_iter(): - if self.random.random() < self.density: - agent_type = 1 if self.random.random() < self.minority_pc else 0 + if self.random.random() < density: + agent_type = 1 if self.random.random() < minority_pc else 0 agent = SchellingAgent(self, agent_type) self.grid.place_agent(agent, pos) @@ -95,5 +88,4 @@ def step(self): self.datacollector.collect(self) - if self.happy == len(self.agents): - self.running = False + self.running = self.happy != len(self.agents)