Skip to content

Commit

Permalink
Fixes to script
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieltseng committed Nov 8, 2023
1 parent 758189e commit 99d89d5
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions benchmarks/train_on_subset_of_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,24 @@
from cropharvest.datasets import CropHarvest, Task, CropHarvestLabels
from cropharvest.engineer import TestInstance
from cropharvest.columns import RequiredColumns
from cropharvest.bbox import BBox

from sklearn.ensemble import RandomForestClassifier


def select_points(evaluation_dataset: CropHarvest, all_labels: CropHarvestLabels) -> pd.DataFrame:
def select_points(bounding_box: BBox, all_labels: CropHarvestLabels) -> pd.DataFrame:
"""
This is what participants would implement.
Given an evaluation dataset, they would need
Given an evaluation dataset's bounding box, they would need
to implement some method of selecting points against
which a model will be trained. In this example code, I
show two examples - one uses bounding boxes, the other directly
overwrites the CropHarvestLabels geojson
which a model will be trained
"""
# Manually select the labelled points that fall within the bounding box.
# 1. Make a new geojson. Here we filter by the bounding box, but the selection
# could be done in any way
filtered_geojson = all_labels.filter_geojson(
all_labels.as_geojson(),
evaluation_dataset.task.bounding_box,
bounding_box,
include_external_contributions=True,
)

Expand Down Expand Up @@ -91,8 +90,8 @@ def main():
results_folder = DATAFOLDER_PATH / "data_centric_test"
results_folder.mkdir(exist_ok=True)

togo_eval = [x for x in evaluation_datasets if "Togo" in x.task.name]
training_points_df = select_points(togo_eval, all_labels)
togo_eval = [x for x in evaluation_datasets if "Togo" in x.task.id][0]
training_points_df = select_points(togo_eval.task.bounding_box, all_labels)
train_and_eval(training_points_df, togo_eval, results_folder)


Expand Down

0 comments on commit 99d89d5

Please sign in to comment.