Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
skadio committed Jan 24, 2024
1 parent 1adfb48 commit 4830f9b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 39 deletions.
23 changes: 14 additions & 9 deletions jurity/mitigation/equalized_odds.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,24 @@ def fit(self,
# Solve
prob.solve()

# Save fairness probabilities
# Save fairness probabilities (cvxpy value is a numpy array or None)
self.p2p_prob_0 = variables_0["p2p"].value
if isinstance(self.p2p_prob_0,np.ndarray):
self.p2p_prob_0=self.p2p_prob_0.item()
self.n2p_prob_0 = variables_0["n2p"].value
if isinstance(self.n2p_prob_0,np.ndarray):
self.n2p_prob_0=self.n2p_prob_0.item()
self.p2p_prob_1 = variables_1["p2p"].value
if isinstance(self.p2p_prob_1,np.ndarray):
self.p2p_prob_1=self.p2p_prob_1.item()
self.n2p_prob_1 = variables_1["n2p"].value
if isinstance(self.n2p_prob_1,np.ndarray):
self.n2p_prob_1=self.n2p_prob_1.item()

# Get the scalar/primitive value unless it is None
if isinstance(self.p2p_prob_0, np.ndarray):
self.p2p_prob_0 = self.p2p_prob_0[0]

if isinstance(self.n2p_prob_0, np.ndarray):
self.n2p_prob_0 = self.n2p_prob_0[0]

if isinstance(self.p2p_prob_1, np.ndarray):
self.p2p_prob_1 = self.p2p_prob_1[0]

if isinstance(self.n2p_prob_1, np.ndarray):
self.n2p_prob_1 = self.n2p_prob_1[0]

def fit_transform(self,
labels: Union[List, np.ndarray, pd.Series],
Expand Down
38 changes: 8 additions & 30 deletions tests/test_mitigation_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,21 +150,10 @@ def test_numerical_stability_mixing_rate_small(self):

mitigation.fit(labels, predictions, likelihoods, is_member)

p2p_prob_0 = mitigation.p2p_prob_0
n2p_prob_0 = mitigation.n2p_prob_0
p2p_prob_1 = mitigation.p2p_prob_1
n2p_prob_1 = mitigation.n2p_prob_1

# # Convert types
# p2p_prob_0 = p2p_prob_0
# n2p_prob_0 = n2p_prob_0
# p2p_prob_1 = p2p_prob_1
# n2p_prob_1 = n2p_prob_1

self.assertAlmostEqual(p2p_prob_0, 0.8429378)
self.assertAlmostEqual(n2p_prob_0, 1.)
self.assertAlmostEqual(p2p_prob_1, 1.)
self.assertAlmostEqual(n2p_prob_1, 0.8893096)
self.assertAlmostEqual(mitigation.p2p_prob_0, 0.8429378)
self.assertAlmostEqual(mitigation.n2p_prob_0, 1.)
self.assertAlmostEqual(mitigation.p2p_prob_1, 1.)
self.assertAlmostEqual(mitigation.n2p_prob_1, 0.8893096)

def test_numerical_stability_mixing_rate_large(self):

Expand All @@ -183,21 +172,10 @@ def test_numerical_stability_mixing_rate_large(self):

mitigation.fit(labels, predictions, likelihoods, is_member)

p2p_prob_0 = mitigation.p2p_prob_0
n2p_prob_0 = mitigation.n2p_prob_0
p2p_prob_1 = mitigation.p2p_prob_1
n2p_prob_1 = mitigation.n2p_prob_1

# # Convert types
# p2p_prob_0 = p2p_prob_0
# n2p_prob_0 = n2p_prob_0
# p2p_prob_1 = p2p_prob_1
# n2p_prob_1 = n2p_prob_1

self.assertAlmostEqual(p2p_prob_0, 0.819513)
self.assertAlmostEqual(n2p_prob_0, 1.)
self.assertAlmostEqual(p2p_prob_1, 0.644566)
self.assertAlmostEqual(n2p_prob_1, 1.)
self.assertAlmostEqual(mitigation.p2p_prob_0, 0.819513)
self.assertAlmostEqual(mitigation.n2p_prob_0, 1.)
self.assertAlmostEqual(mitigation.p2p_prob_1, 0.644566)
self.assertAlmostEqual(mitigation.n2p_prob_1, 1.)

def test_numerical_stability_bias_mitigation(self):

Expand Down

0 comments on commit 4830f9b

Please sign in to comment.