Skip to content

Commit

Permalink
Rolling VersionedDataHandler back to what's in develop
Browse files Browse the repository at this point in the history
  • Loading branch information
dmnapolitano committed Dec 12, 2024
1 parent f2ca6cf commit 8ba8b37
Showing 1 changed file with 22 additions and 34 deletions.
56 changes: 22 additions & 34 deletions src/elexmodel/handlers/data/VersionedData.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import warnings
from datetime import datetime

import numpy as np
import pandas as pd
from dateutil import tz

from elexmodel.handlers import s3
from elexmodel.handlers.data.BaseDataHandler import BaseDataHandler
from elexmodel.handlers.data.Estimandizer import Estimandizer
from elexmodel.logger import getModelLogger
from elexmodel.utils.file_utils import S3_FILE_PATH, TARGET_BUCKET

LOG = getModelLogger()


class VersionedDataHandler(BaseDataHandler):
class VersionedDataHandler:
def __init__(
self,
election_id,
Expand All @@ -24,28 +23,29 @@ def __init__(
end_date=None,
sample=2,
tzinfo="America/New_York",
data=None,
):
self.election_id = election_id
self.office_id = office_id
self.geographic_unit_type = geographic_unit_type
self.estimands = estimands
self.start_date = start_date # in EST
self.end_date = end_date # in EST

if election_id.startswith("2020-11-03_USA_G"):
if self.election_id.startswith("2020-11-03_USA_G"):
target_bucket = "elex-models-prod"
else:
target_bucket = TARGET_BUCKET
start_date = datetime.fromisoformat(start_date).astimezone(tz=tz.gettz("UTC")) if start_date else None
end_date = datetime.fromisoformat(end_date).astimezone(tz=tz.gettz("UTC")) if end_date else None
# versioned results natively are in UTC but we'll convert it back to timezone in tzinfo
s3_client = s3.S3VersionUtil(target_bucket, start_date, end_date, tzinfo)
self.s3_client = s3.S3VersionUtil(target_bucket, start_date, end_date, tzinfo)

# Sample lets us skip every nth version, by default 2.
self.sample = sample

# This handles timezone conversion for us, by default to EST.
self.tz = tzinfo

super().__init__(election_id, office_id, geographic_unit_type, estimands, s3_client=s3_client, data=data)

def get_versioned_results(self, filepath=None):
if filepath is not None:
versioned_results_np = np.load(f"{filepath}/versioned_results.npy")
Expand All @@ -64,25 +64,21 @@ def get_versioned_results(self, filepath=None):

if self.election_id.startswith("2020-11-03_USA_G"):
path = "elex-models-prod/2020-general/results/pres/current.csv"
elif self.election_id.startswith("2024-11-05_USA_G"):
path = f"{S3_FILE_PATH}/{self.election_id}/results/{self.office_id}/{self.geographic_unit_type}/current_counties.csv"
else:
base_dir = f"{S3_FILE_PATH}/{self.election_id}/results/{self.office_id}/{self.geographic_unit_type}"
if self.election_id.startswith("2024-11-05_USA_G"):
path = base_dir + "/current_counties.csv"
else:
path = base_dir + "/current.csv"
path = f"{S3_FILE_PATH}/{self.election_id}/results/{self.office_id}/{self.geographic_unit_type}/current.csv"

data = self.s3_client.get(path, self.sample)
LOG.info("Loaded versioned results from S3")
if data is None:
self.data = data
return data
data, _ = self.estimandizer.add_estimand_results(data, self.estimands, False)
estimandizer = Estimandizer()
data, _ = estimandizer.add_estimand_results(data, self.estimands, False)
self.data = data.sort_values("last_modified")
return self.data

def get_data(self):
return self.get_versioned_results()

def compute_versioned_margin_estimate(self, data=None):
"""
This function imputes the margin at each percent reporting for a versioned dataset.
Expand Down Expand Up @@ -128,8 +124,7 @@ def compute_estimated_margin(df):
casting="unsafe",
)

# check if perc_expected_vote_corr is monotone increasing
# (if not, give up and don't try to estimate a margin)
# check if perc_expected_vote_corr is monotone increasing (if not, give up and don't try to estimate a margin)
if not np.all(np.diff(perc_expected_vote_corr) >= 0):
return pd.DataFrame(
{
Expand All @@ -148,18 +143,15 @@ def compute_estimated_margin(df):
# Compute batch_margin using NumPy
# this is the difference in dem_votes - the difference in gop_votes divided by the difference in total votes
# that is, this is the normalized margin in the batch of votes recorded between versions
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
batch_margin = (
np.diff(results_dem, append=results_dem[-1]) - np.diff(results_gop, append=results_gop[-1])
) / np.diff(results_weights, append=results_weights[-1])
batch_margin = (
np.diff(results_dem, append=results_dem[-1]) - np.diff(results_gop, append=results_gop[-1])
) / np.diff(results_weights, append=results_weights[-1])

# nan values in batch_margin are due to div-by-zero since there's no change in votes
batch_margin[np.isnan(batch_margin)] = 0 # Set NaN margins to 0
df["batch_margin"] = batch_margin

# batch_margins should be between -1 and 1
# (otherwise, there was a data entry issue and we will not use this unit)
# batch_margins should be between -1 and 1 (otherwise, there was a data entry issue and we will not use this unit)
if np.abs(batch_margin).max() > 1:
return pd.DataFrame(
{
Expand Down Expand Up @@ -216,9 +208,7 @@ def compute_estimated_margin(df):
}
)

results = (
results.groupby("geographic_unit_fips").apply(compute_estimated_margin, include_groups=False).reset_index()
)
results = results.groupby("geographic_unit_fips").apply(compute_estimated_margin).reset_index()

for error_type in sorted(set(results["error_type"])):
if error_type == "none":
Expand All @@ -227,16 +217,14 @@ def compute_estimated_margin(df):
LOG.info(f"# of versioned units with {error_type} error: {len(category_error_type)}")
return results

def load_data(self, data):
return self.compute_versioned_margin_estimate(data=data)

def get_versioned_predictions(self, filepath=None):
if filepath is not None:
return pd.read_csv(filepath)

if self.election_id.startswith("2020-11-03_USA_G"):
path = "elex-models-prod/2020-general/prediction/pres/current.csv"
raise ValueError("No versioned predictions available for this election.")

path = f"{S3_FILE_PATH}/{self.election_id}/predictions/{self.office_id}/{self.geographic_unit_type}/current.csv"
else:
path = f"{S3_FILE_PATH}/{self.election_id}/predictions/{self.office_id}/{self.geographic_unit_type}/current.csv"

return self.s3_client.get(path, self.sample)

0 comments on commit 8ba8b37

Please sign in to comment.