-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Let
find_nearest_nodes()
locate more than 1 nodes
- Loading branch information
Showing
2 changed files
with
91 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,49 @@ | ||
from __future__ import annotations | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
from pyposeidon.utils.cpoint import find_nearest_nodes | ||
from pyposeidon.utils.cpoint import get_ball_tree | ||
|
||
EXPECTED_COLUMNS = ["lon", "lat", "id", "mesh_index", "mesh_lon", "mesh_lat", "distance"] | ||
|
||
|
||
def test_find_nearest_nodes(): | ||
mesh_nodes = pd.DataFrame({ | ||
@pytest.fixture(scope="session") | ||
def mesh_nodes(): | ||
return pd.DataFrame({ | ||
"lon": [0, 10, 20], | ||
"lat": [0, 5, 0], | ||
}) | ||
points = pd.DataFrame({ | ||
"lon": [1, 11, 21], | ||
"lat": [1, 4, 1], | ||
"id": ["a", "b", "c"], | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def points(): | ||
return pd.DataFrame({ | ||
"lon": [1, 11, 21, 2], | ||
"lat": [1, 4, 1, 2], | ||
"id": ["a", "b", "c", "d"], | ||
}) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def ball_tree(mesh_nodes): | ||
return get_ball_tree(mesh_nodes) | ||
|
||
|
||
def test_find_nearest_nodes(mesh_nodes, points): | ||
nearest_nodes = find_nearest_nodes(mesh_nodes, points) | ||
assert isinstance(nearest_nodes, pd.DataFrame) | ||
assert len(nearest_nodes) == 3 | ||
assert nearest_nodes.columns.tolist() == ["lon", "lat", "id", "mesh_index", "mesh_lon", "mesh_lat", "distance"] | ||
assert nearest_nodes.mesh_index.tolist() == [0, 1, 2] | ||
assert len(nearest_nodes) == len(points) | ||
assert nearest_nodes.columns.tolist() == EXPECTED_COLUMNS | ||
assert nearest_nodes.mesh_index.tolist() == [0, 1, 2, 0] | ||
assert nearest_nodes.distance.min() > 150_000 | ||
assert nearest_nodes.distance.max() < 160_000 | ||
assert nearest_nodes.distance.max() < 320_000 | ||
|
||
|
||
@pytest.mark.parametrize("k", [pytest.param(2, id='2 points'), pytest.param(3, id='3 points')]) | ||
def test_find_nearest_nodes_multiple_points_and_pass_tree_as_argument(mesh_nodes, points, k, ball_tree): | ||
nearest_nodes = find_nearest_nodes(mesh_nodes, points, k=k, tree=ball_tree) | ||
assert isinstance(nearest_nodes, pd.DataFrame) | ||
assert len(nearest_nodes) == len(points) * k | ||
assert nearest_nodes.columns.tolist() == EXPECTED_COLUMNS |