Skip to content

Commit

Permalink
Add select_agents method to Model
Browse files Browse the repository at this point in the history
This commit introduces a new feature to the Mesa Agent-Based Modeling framework: a select_agents method. This enhancement is inspired by NetLogo's n-of functions and a smart idea from @rht as discussed in projectmesa#1894. It aims to provide a flexible and efficient way to select and filter agents in a model, based on a wide range of criteria.

#### Features
The `select_agents` method offers several key functionalities:

1. **Selective Quantity (`n`):** Specify the number of agents to select. If `n` is omitted, all matching agents are chosen.

2. **Criteria-Based Sorting (`sort` and `direction`):** Sort agents based on one or more attributes. The sorting order can be set individually for each attribute.

3. **Functional Filtering (`filter`):** Use a lambda function or a callable to apply complex filter conditions.

4. **Type-Based Filtering (`agent_type`):** Filter agents by their class type, allowing for selection among specific subclasses.

5. **Flexible Size Handling (`up_to`):** When `True`, the method returns up to `n` agents, which is useful when the available agent count is less than `n`.
  • Loading branch information
EwoutH committed Dec 18, 2023
1 parent 9495a5a commit 3ce1afe
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 1 deletion.
73 changes: 72 additions & 1 deletion mesa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections import defaultdict

# mypy
from typing import Any
from typing import Any, Callable

from mesa.datacollection import DataCollector

Expand Down Expand Up @@ -106,3 +106,74 @@ def initialize_data_collector(
)
# Collect data for the first time during initialization.
self.datacollector.collect(self)

def select_agents(
self,
n: int | None = None,
sort: list[str] | None = None,
direction: list[str] | None = None,
filter_func: Callable[[Any], bool] | None = None,
agent_type: type[Any] | list[type[Any]] | None = None,
up_to: bool = True,
) -> list[Any]:
"""
Select agents based on various criteria including type, attributes, and custom filters.
Args:
n: Number of agents to select.
sort: Attributes to sort by.
direction: Sort direction for each attribute in `sort`.
filter_func: A callable to further filter agents.
agent_type: Type(s) of agents to include.
up_to: If True, allows returning up to `n` agents.
Returns:
A list of selected agents.
"""

# If agent_type is specified, fetch only those agents; otherwise, fetch all
if agent_type:
if not isinstance(agent_type, list):
agent_type = [agent_type]
agent_type_set = set(agent_type)
agents_iter = (
agent
for type_key, agents in self.agents.items()
if type_key in agent_type_set
for agent in agents
)
else:
agents_iter = (agent for agents in self.agents.values() for agent in agents)

# Apply functional filter if provided
if filter_func:
agents_iter = filter(filter_func, agents_iter)

# Convert to list if sorting is needed or n is specified
if sort and direction or n is not None:
agents_iter = list(agents_iter)

# If only a specific number of agents is needed without sorting, limit early
if n is not None and not (sort and direction):
agents_iter = (
agents_iter[: min(n, len(agents_iter))] if up_to else agents_iter[:n]
)

# Sort agents if needed
if sort and direction:

def sort_key(agent):
return tuple(
getattr(agent, attr)
if dir.lower() == "lowest"
else -getattr(agent, attr)
for attr, dir in zip(sort, direction)
)

agents_iter.sort(key=sort_key)

# Select the desired number of agents after sorting
if n is not None and sort and direction:
return agents_iter[: min(n, len(agents_iter))] if up_to else agents_iter[:n]

return list(agents_iter)
106 changes: 106 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from mesa.agent import Agent
from mesa.model import Model

Expand Down Expand Up @@ -51,3 +53,107 @@ class TestAgent(Agent):
test_agent = TestAgent(model.next_id(), model)
assert test_agent in model.agents[type(test_agent)]
assert type(test_agent) in model.agent_types


class TestSelectAgents:
class MockAgent(Agent):
def __init__(self, unique_id, model, type_id, age, wealth):
super().__init__(unique_id, model)
self.type_id = type_id
self.age = age
self.wealth = wealth

@pytest.fixture
def model_with_agents(self):
model = Model()
for i in range(20):
self.MockAgent(i, model, type_id=i % 2, age=i + 20, wealth=100 - i * 2)
return model

def test_basic_selection(self, model_with_agents):
selected_agents = model_with_agents.select_agents()
assert len(selected_agents) == 20

def test_selection_with_n(self, model_with_agents):
selected_agents = model_with_agents.select_agents(n=5)
assert len(selected_agents) == 5

def test_sorting_and_direction(self, model_with_agents):
selected_agents = model_with_agents.select_agents(
n=3, sort=["wealth"], direction=["highest"]
)
assert [agent.wealth for agent in selected_agents] == [100, 98, 96]

def test_functional_filtering(self, model_with_agents):
selected_agents = model_with_agents.select_agents(
filter_func=lambda agent: agent.age > 30
)
assert all(agent.age > 30 for agent in selected_agents)

def test_type_filtering(self, model_with_agents):
selected_agents = model_with_agents.select_agents(agent_type=self.MockAgent)
assert all(isinstance(agent, self.MockAgent) for agent in selected_agents)

def test_up_to_flag(self, model_with_agents):
selected_agents = model_with_agents.select_agents(n=50, up_to=True)
assert len(selected_agents) == 20

def test_edge_case_empty_model(self):
empty_model = Model()
selected_agents = empty_model.select_agents()
assert len(selected_agents) == 0

def test_error_handling_invalid_sort(self, model_with_agents):
with pytest.raises(AttributeError):
model_with_agents.select_agents(
n=3, sort=["nonexistent_attribute"], direction=["highest"]
)

def test_sorting_with_multiple_criteria(self, model_with_agents):
selected_agents = model_with_agents.select_agents(
n=3, sort=["type_id", "age"], direction=["lowest", "highest"]
)
assert [(agent.type_id, agent.age) for agent in selected_agents] == [
(0, 38),
(0, 36),
(0, 34),
]

def test_direction_with_multiple_criteria(self, model_with_agents):
selected_agents = model_with_agents.select_agents(
n=3, sort=["type_id", "wealth"], direction=["highest", "lowest"]
)
assert [(agent.type_id, agent.wealth) for agent in selected_agents] == [
(1, 62),
(1, 66),
(1, 70),
]

def test_type_filtering_with_multiple_types(self, model_with_agents):
class AnotherMockAgent(Agent):
pass

# Adding different type agents to the model
for i in range(20, 25):
AnotherMockAgent(i, model_with_agents)

selected_agents = model_with_agents.select_agents(
agent_type=[self.MockAgent, AnotherMockAgent]
)
assert len(selected_agents) == 25

def test_selection_when_n_exceeds_agent_count(self, model_with_agents):
selected_agents = model_with_agents.select_agents(n=50)
assert len(selected_agents) == 20

def test_inverse_functional_filtering(self, model_with_agents):
selected_agents = model_with_agents.select_agents(
filter_func=lambda agent: agent.age < 25
)
assert all(agent.age < 25 for agent in selected_agents)

def test_complex_lambda_in_filter(self, model_with_agents):
selected_agents = model_with_agents.select_agents(
filter_func=lambda agent: agent.age > 25 and agent.wealth > 70
)
assert all(agent.age > 25 and agent.wealth > 70 for agent in selected_agents)

0 comments on commit 3ce1afe

Please sign in to comment.