Skip to content

Commit

Permalink
Merge pull request #831 from GabrielKS/fix_confidence_threshold
Browse files Browse the repository at this point in the history
Get confidenceThreshold from config file
  • Loading branch information
shankari authored Aug 11, 2021
2 parents ed2adbb + 9206299 commit 6126201
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _mode_matches_date(mode, trip_date):

def _get_collection_mode_by_schedule(trip):
trip_tz = trip["data"]["end_local_dt"]["timezone"]
trip_date = arrow.get(trip["data"]["end_ts"], tz=trip_tz)
trip_date = arrow.get(trip["data"]["end_ts"]).to(trip_tz)
for mode in _config["modes"]:
if _mode_matches_date(mode, trip_date): return mode
raise ValueError("Trip date does not match any modes; this means the config file lacks a schedule-less mode")
Expand Down
2 changes: 2 additions & 0 deletions emission/analysis/userinput/expectations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ def _process_and_save_trip(user_id, inferred_trip, ts):
if _test_options["preprocess_trip"] is not None: _test_options["preprocess_trip"](expected_trip)

expectation = _get_expectation_for_trip(expected_trip)
confidence_threshold = eace.get_confidence_threshold(expected_trip)
# For now, I don't think it's necessary to save each expectation as its own database entry

expected_trip["data"]["inferred_trip"] = inferred_trip.get_id()
expected_trip["data"]["expectation"] = expectation
expected_trip["data"]["confidence_threshold"] = confidence_threshold
ts.insert(expected_trip)

# This is a placeholder. TODO: implement the real algorithm
Expand Down
1 change: 1 addition & 0 deletions emission/core/wrapper/confirmedtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Confirmedtrip(ecwt.Trip):
"inferred_labels": ecwb.WrapperBase.Access.WORM,
"inferred_trip": ecwb.WrapperBase.Access.WORM,
"expectation": ecwb.WrapperBase.Access.WORM,
"confidence_threshold": ecwb.WrapperBase.Access.WORM,
"expected_trip": ecwb.WrapperBase.Access.WORM,
# the confirmed section that is the "primary"
# https://github.com/e-mission/e-mission-docs/issues/476#issuecomment-738120752
Expand Down
3 changes: 2 additions & 1 deletion emission/core/wrapper/expectedtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ class Expectedtrip(ecwt.Trip):
"cleaned_trip": ecwb.WrapperBase.Access.WORM,
"inferred_labels": ecwb.WrapperBase.Access.WORM,
"inferred_trip": ecwb.WrapperBase.Access.WORM,
"expectation": ecwb.WrapperBase.Access.WORM
"expectation": ecwb.WrapperBase.Access.WORM,
"confidence_threshold": ecwb.WrapperBase.Access.WORM,
})

def _populateDependencies(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def setUp(self):
}
eace.reload_config()

self.tz = "America/Chicago"
self.tz = "Etc/GMT-8"
# Note that these depend on certain values in expectations.conf.json.sample, as will other values later
self.test_dates = {
"before_intensive": arrow.get("2021-05-01T20:00:00.000", tzinfo=self.tz), # Ensures schedules can't apply before their start_date
Expand Down
43 changes: 34 additions & 9 deletions emission/tests/pipelineTests/TestExpectationPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,19 @@ class TestExpectationPipeline(unittest.TestCase):
960: arrow.get("2023-02-12T20:00:00.000", tzinfo=tz)
}

@staticmethod
def fingerprint(trip):
# See eacilp.placeholder_predictor_2 for an explanation of the "fingerprint" technique
return trip["data"]["start_local_dt"]["hour"]*60+trip["data"]["start_local_dt"]["minute"]

def setUp(self):
self.test_options_stash = eace._test_options
eace._test_options = {
"use_sample": True,
"override_keylist": None
}
eace.reload_config()

np.random.seed(61297777)
self.reset_all()
etc.setupRealExample(self, "emission/tests/data/real_examples/shankari_2015-07-22")
Expand All @@ -38,6 +50,9 @@ def setUp(self):

def tearDown(self):
self.reset_all()

eace._test_options = self.test_options_stash
eace.reload_config()

def run_pipeline(self, algorithms):
primary_algorithms_stash = eacilp.primary_algorithms
Expand All @@ -49,9 +64,7 @@ def run_pipeline(self, algorithms):
eaue._test_options = test_options_stash

def preprocess(self, trip):
# See eacilp.placeholder_predictor_2 for an explanation of the "fingerprint" technique
fingerprint = trip["data"]["start_local_dt"]["hour"]*60+trip["data"]["start_local_dt"]["minute"]
trip["data"]["end_ts"] = self.contrived_dates[fingerprint].float_timestamp
trip["data"]["end_ts"] = self.contrived_dates[self.fingerprint(trip)].float_timestamp
trip["data"]["end_local_dt"]["timezone"] = self.tz

def reset_all(self):
Expand Down Expand Up @@ -79,8 +92,7 @@ def testRawAgainstAnswers(self):
960: {"type": "randomFraction", "value": 0.05}
}
for trip in self.expected_trips:
fingerprint = trip["data"]["start_local_dt"]["hour"]*60+trip["data"]["start_local_dt"]["minute"]
self.assertEqual(eace.get_expectation(trip), answers[fingerprint])
self.assertEqual(eace.get_expectation(trip), answers[self.fingerprint(trip)])

def testProcessedAgainstAnswers(self):
answers = {
Expand All @@ -92,21 +104,34 @@ def testProcessedAgainstAnswers(self):
960: None
}
for trip in self.expected_trips:
fingerprint = trip["data"]["start_local_dt"]["hour"]*60+trip["data"]["start_local_dt"]["minute"]
if answers[fingerprint] is not None: self.assertEqual(trip["data"]["expectation"]["to_label"], answers[fingerprint])
ans = answers[self.fingerprint(trip)]
if ans is not None: self.assertEqual(trip["data"]["expectation"]["to_label"], ans)

def testProcessedAgainstRaw(self):
for trip in self.expected_trips:
self.assertIn("expectation", trip["data"])
raw_expectation = eace.get_expectation(trip)
if raw_expectation["type"] == "none":
self.assertEqual(trip["data"]["expectation"], {"to_label": False})
self.assertEqual(trip["data"]["expectation"]["to_label"], False)
elif raw_expectation["type"] == "all":
self.assertEqual(trip["data"]["expectation"], {"to_label": True})
self.assertEqual(trip["data"]["expectation"]["to_label"], True)
else:
print("Expectation behavior for "+str(raw_expectation)+" has not been implemented yet; not testing. Value is "+str(trip["data"]["expectation"]))
# TODO: implement tests for the other configurable expectation types once they've been implemented

def testConfidenceThreshold(self):
answers = {
494: 0.55,
565: 0.65,
795: 0.55,
805: 0.65,
880: 0.65,
960: 0.55
}
for trip in self.expected_trips:
self.assertTrue("confidence_threshold" in trip["data"]) # Existence
self.assertAlmostEqual(trip["data"]["confidence_threshold"], answers[self.fingerprint(trip)]) # Correctness

def main():
etc.configLogging()
unittest.main()
Expand Down

0 comments on commit 6126201

Please sign in to comment.