-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #32 from Stranger469/main
Add label model: ibcc, ebcc and their corresponding demo.
- Loading branch information
Showing
6 changed files
with
359 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import logging

from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.labelmodel import EBCC, IBCC

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

logger = logging.getLogger(__name__)

#### Load dataset
dataset_path = '../datasets/'
data = 'agnews'
train_data, valid_data, test_data = load_dataset(
    dataset_path,
    data,
    extract_feature=True,
    extract_fn='bert',  # extract bert embedding
    model_name='bert-base-cased',
    cache_name='bert'
)
# Label models operate on the items covered by at least one labeling function.
train_data_c = train_data.get_covered_subset()

#### Inference for EBCC and IBCC
logger.info('============inference============')
ebcc = EBCC(
    num_groups=5,
    repeat=100,
    inference_iter=100,
    empirical_prior=True,
)
ibcc = IBCC()

ebcc.fit(
    dataset_train=train_data_c
)
ibcc.fit(
    dataset_train=train_data_c
)

#### Test for EBCC and IBCC
logger.info('============test============')
acc_ebcc = ebcc.test(train_data_c, 'acc')
acc_test_ebcc = ebcc.test(test_data, 'acc')

acc_ibcc = ibcc.test(train_data_c, 'acc')
acc_test_ibcc = ibcc.test(test_data, 'acc')

logger.info(f'label model train/test acc on ebcc: {acc_ebcc}, {acc_test_ebcc}, seed={ebcc.params["seed"]}')
logger.info(f'label model train/test acc on ibcc: {acc_ibcc}, {acc_test_ibcc}')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
from typing import Optional, Any, Union | ||
import numpy as np | ||
import scipy.sparse as ssp | ||
from scipy.special import digamma, gammaln | ||
from scipy.stats import entropy, dirichlet | ||
|
||
from wrench.basemodel import BaseLabelModel | ||
from wrench.dataset import BaseDataset | ||
from ..utils import create_tuples | ||
|
||
|
||
def ebcc_vb(tuples,
            num_items,
            num_workers,
            num_classes,
            a_pi=0.1,
            num_groups=10,  # M: number of latent worker groups per class
            alpha=1,  # alpha_k, it can be 1 or \sum_i gamma_ik
            a_v=4,  # beta_kk: diagonal (correct-vote) confusion prior
            b_v=1,  # beta_kk', k neq k': off-diagonal confusion prior
            eta_km=None,
            nu_k=None,
            mu_jkml=None,
            eval=False,  # NOTE(review): shadows the builtin; name kept for caller compatibility
            seed=1234,
            inference_iter=500,
            empirical_prior=False):
    """Variational-Bayes inference for the EBCC label model.

    Parameters
    ----------
    tuples : (n_votes, 3) int array
        Rows of ``(item, worker, class)`` vote observations.
    num_items, num_workers, num_classes : int
        Problem dimensions.
    a_pi, alpha, a_v, b_v : float
        Dirichlet prior hyper-parameters.
    num_groups : int
        Number of latent worker sub-groups per class (``M``).
    eta_km, nu_k, mu_jkml : ndarray or None
        Previously fitted variational parameters; only read when ``eval`` is
        True, otherwise re-estimated every iteration.
    eval : bool
        If True, hold the supplied variational parameters fixed and only
        update the item posteriors (prediction mode).
    seed : int
        RNG seed for the Dirichlet initialization of ``zg_ikm``.
    inference_iter : int
        Maximum number of coordinate-ascent iterations.
    empirical_prior : bool
        If True, replace ``alpha`` with the soft empirical class counts from
        the majority-vote initialization.

    Returns
    -------
    z_ik : (num_items, num_classes) ndarray
        Posterior class probabilities per item.
    ELBO : float
        Evidence lower bound of the final iterate.
    eta_km, nu_k, mu_jkml : ndarray
        Variational parameters of the final iterate.
    """
    # One sparse indicator matrix per class: y_is_one_lij[l][i, j] is True iff
    # worker j voted class l on item i (plus transposed copies for the j-major
    # products below).
    y_is_one_lij = []
    y_is_one_lji = []
    for k in range(num_classes):
        selected = (tuples[:, 2] == k)
        coo_ij = ssp.coo_matrix((np.ones(selected.sum()),
                                 tuples[selected, :2].T),
                                shape=(num_items, num_workers),
                                dtype=bool)  # np.bool alias was removed in NumPy 1.24
        y_is_one_lij.append(coo_ij.tocsr())
        y_is_one_lji.append(coo_ij.T.tocsr())

    # Confusion prior: a_v on the diagonal, b_v off-diagonal.
    beta_kl = np.eye(num_classes) * (a_v - b_v) + b_v

    # Initialize z_ik from smoothed per-item vote counts (soft majority vote).
    z_ik = np.zeros((num_items, num_classes))
    for l in range(num_classes):
        z_ik[:, [l]] += y_is_one_lij[l].sum(axis=-1) + 1e-8
    z_ik /= z_ik.sum(axis=-1, keepdims=True)

    if empirical_prior:
        alpha = z_ik.sum(axis=0)

    np.random.seed(seed)
    # zg_ikm: joint responsibility of item i for (class k, group m).
    zg_ikm = np.random.dirichlet(np.ones(num_groups), z_ik.shape) * z_ik[:, :, None]
    for it in range(inference_iter):
        if eval is False:
            # Update the variational Dirichlet parameters of pi, tau and v.
            eta_km = a_pi / num_groups + zg_ikm.sum(axis=0)
            nu_k = alpha + z_ik.sum(axis=0)
            mu_jkml = np.zeros((num_workers, num_classes, num_groups, num_classes)) + beta_kl[None, :, None, :]
            for l in range(num_classes):
                for k in range(num_classes):
                    mu_jkml[:, k, :, l] += y_is_one_lji[l].dot(zg_ikm[:, k, :])

        # Expected log parameters under the current Dirichlet posteriors.
        Eq_log_pi_km = digamma(eta_km) - digamma(eta_km.sum(axis=-1, keepdims=True))
        Eq_log_tau_k = digamma(nu_k) - digamma(nu_k.sum())
        Eq_log_v_jkml = digamma(mu_jkml) - digamma(mu_jkml.sum(axis=-1, keepdims=True))

        # Recompute the (class, group) responsibilities in the log domain.
        zg_ikm[:] = Eq_log_pi_km[None, :, :] + Eq_log_tau_k[None, :, None]
        for l in range(num_classes):
            for k in range(num_classes):
                zg_ikm[:, k, :] += y_is_one_lij[l].dot(Eq_log_v_jkml[:, k, :, l])

        # Max-shift per item before exponentiating to avoid overflow; the
        # per-item constant cancels in the normalization below (same trick
        # already used in ibcc).
        zg_ikm -= zg_ikm.reshape(num_items, -1).max(axis=-1)[:, None, None]
        zg_ikm = np.exp(zg_ikm)
        zg_ikm /= zg_ikm.reshape(num_items, -1).sum(axis=-1)[:, None, None]

        last_z_ik = z_ik
        z_ik = zg_ikm.sum(axis=-1)

        # Stop once the item posteriors have converged.
        if np.allclose(last_z_ik, z_ik, atol=1e-3):
            break

    # Evidence lower bound: expected log-joint plus posterior entropies.
    ELBO = ((eta_km - 1) * Eq_log_pi_km).sum() + ((nu_k - 1) * Eq_log_tau_k).sum() + (
            (mu_jkml - 1) * Eq_log_v_jkml).sum()
    ELBO += dirichlet.entropy(nu_k)
    for k in range(num_classes):
        ELBO += dirichlet.entropy(eta_km[k])
    ELBO += (gammaln(mu_jkml) - (mu_jkml - 1) * digamma(mu_jkml)).sum()
    alpha0_jkm = mu_jkml.sum(axis=-1)
    ELBO += ((alpha0_jkm - num_classes) * digamma(alpha0_jkm) - gammaln(alpha0_jkm)).sum()
    ELBO += entropy(zg_ikm.reshape(num_items, -1).T).sum()
    return z_ik, ELBO, eta_km, nu_k, mu_jkml
|
||
|
||
class EBCC(BaseLabelModel):
    """Enhanced Bayesian Classifier Combination label model.

    Runs ``ebcc_vb`` with ``repeat`` random restarts and keeps the
    variational parameters of the restart with the highest ELBO.
    """

    def __init__(self,
                 num_groups: Optional[int] = 10,
                 a_pi: Optional[float] = 0.1,
                 alpha: Optional[float] = 1,
                 a_v: Optional[float] = 4,
                 b_v: Optional[float] = 1,
                 repeat: Optional[int] = 1000,
                 inference_iter: Optional[int] = 500,
                 empirical_prior: bool = False,
                 **kwargs: Any):
        super().__init__()
        # Hyper-parameters forwarded verbatim to ebcc_vb on every restart.
        self.hyperparas = {
            'num_groups': num_groups,
            'a_pi': a_pi,
            'alpha': alpha,
            'a_v': a_v,
            'b_v': b_v,
            'empirical_prior': empirical_prior,
            'inference_iter': inference_iter,
            **kwargs
        }
        # Fitted variational parameters of the best restart; filled by fit().
        self.params = {
            'seed': None,
            'eta_km': None,
            'nu_k': None,
            'mu_jkml': None,
        }
        # Number of random restarts tried during fit().
        self.repeat = repeat

    def fit(self,
            dataset_train: Union[BaseDataset, np.ndarray],
            dataset_valid: Optional[Union[BaseDataset, np.ndarray]] = None,
            y_valid: Optional[np.ndarray] = None,
            n_class: Optional[int] = None,
            verbose: Optional[bool] = False,
            *args: Any,
            **kwargs: Any):
        """Fit EBCC on the weak labels of ``dataset_train``.

        Performs ``self.repeat`` randomly-seeded runs of ``ebcc_vb`` and
        stores the parameters (and seed) of the run with the best ELBO in
        ``self.params``.
        """
        tuples = create_tuples(dataset_train)
        num_items, _, num_classes = tuples.max(axis=0) + 1
        num_workers = len(dataset_train.weak_labels[0])
        max_elbo = float('-inf')

        for infer in range(self.repeat):
            # randint expects an int bound (a float bound is deprecated).
            seed = np.random.randint(10 ** 8)
            prediction, elbo, p1, p2, p3 = ebcc_vb(tuples,
                                                   num_items, num_workers, num_classes,
                                                   seed=seed,
                                                   **self.hyperparas)
            if elbo > max_elbo:
                if verbose:
                    # Only report restart progress when the caller asked for it.
                    print(f'update elbo: new: {elbo}, old: {max_elbo}')
                self.params = {
                    'seed': seed,
                    'eta_km': p1,
                    'nu_k': p2,
                    'mu_jkml': p3
                }
                max_elbo = elbo

    def predict_proba(self,
                      dataset: Union[BaseDataset, np.ndarray],
                      **kwargs: Any):
        """Return posterior class probabilities for each item in ``dataset``.

        Re-runs ``ebcc_vb`` in eval mode with the fitted parameters held
        fixed; requires fit() to have been called first.
        """
        tuples = create_tuples(dataset)
        num_items, _, num_classes = tuples.max(axis=0) + 1
        num_workers = len(dataset.weak_labels[0])
        pred, elbo, _, _, _ = ebcc_vb(tuples,
                                      num_items, num_workers, num_classes,
                                      eval=True,
                                      **self.hyperparas, **self.params)
        return pred
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from typing import Optional, Any, Union | ||
import numpy as np | ||
import scipy.sparse as ssp | ||
from scipy.special import digamma | ||
|
||
from wrench.basemodel import BaseLabelModel | ||
from wrench.dataset import BaseDataset | ||
from ..utils import create_tuples | ||
|
||
|
||
def ibcc(tuples,
         num_items,
         num_workers,
         num_classes,
         a_v=4,  # diagonal (correct-vote) confusion prior
         b_v=1,  # off-diagonal confusion prior
         alpha=1,  # Dirichlet prior on the class proportions
         n_jkl=None,
         eval=False):  # NOTE(review): shadows the builtin; name kept for caller compatibility
    """Mean-field variational inference for the independent BCC label model.

    Parameters
    ----------
    tuples : (n_votes, 3) int array
        Rows of ``(item, worker, class)`` vote observations.
    num_items, num_workers, num_classes : int
        Problem dimensions.
    a_v, b_v, alpha : float
        Prior hyper-parameters.
    n_jkl : (num_workers, num_classes, num_classes) ndarray or None
        Previously fitted expected confusion counts; only held fixed when
        ``eval`` is True, otherwise (re)computed every iteration.
    eval : bool
        If True, keep ``n_jkl`` fixed and only update the item posteriors.

    Returns
    -------
    z_ik : (num_items, num_classes) ndarray
        Posterior class probabilities per item.
    n_jkl : ndarray
        Expected per-worker confusion counts (fitted parameters).
    """
    # One sparse indicator matrix per class: y_is_one_kij[k][i, j] is True iff
    # worker j voted class k on item i (plus transposed copies).
    y_is_one_kij = []
    y_is_one_kji = []
    for k in range(num_classes):
        selected = (tuples[:, 2] == k)
        coo_ij = ssp.coo_matrix((np.ones(selected.sum()), tuples[selected, :2].T),
                                shape=(num_items, num_workers),
                                dtype=bool)  # np.bool alias was removed in NumPy 1.24
        y_is_one_kij.append(coo_ij.tocsr())
        y_is_one_kji.append(coo_ij.T.tocsr())

    # Confusion prior: a_v on the diagonal, b_v off-diagonal.
    prior_kl = np.eye(num_classes) * (a_v - b_v) + b_v
    if n_jkl is None:
        n_jkl = np.empty((num_workers, num_classes, num_classes))

    # Initialize z_ik from smoothed vote counts (soft majority vote).
    z_ik = np.zeros((num_items, num_classes))
    for l in range(num_classes):
        z_ik[:, [l]] += y_is_one_kij[l].sum(axis=-1) + 1e-8
    z_ik /= z_ik.sum(axis=-1, keepdims=True)
    last_z_ik = z_ik.copy()

    for iteration in range(500):
        # E step
        Eq_log_pi_k = digamma(z_ik.sum(axis=0) + alpha)  # - digamma(num_items + num_classes * alpha)

        if eval is False:
            # Expected confusion counts given the current item posteriors.
            for l in range(num_classes):
                n_jkl[:, :, l] = y_is_one_kji[l].dot(z_ik)

        Eq_log_v_jkl = digamma(n_jkl + prior_kl[None, :, :]) - \
                       digamma(n_jkl.sum(axis=-1) + prior_kl.sum(axis=-1))[:, :, None]

        # M step: recompute item posteriors in the log domain,
        # max-shifted per item for numerical stability.
        last_z_ik[:] = z_ik

        z_ik[:] = Eq_log_pi_k
        for l in range(num_classes):
            z_ik += y_is_one_kij[l].dot(Eq_log_v_jkl[:, :, l])
        z_ik -= z_ik.max(axis=-1, keepdims=True)
        z_ik = np.exp(z_ik)
        z_ik /= z_ik.sum(axis=-1, keepdims=True)

        # Stop once the item posteriors have converged.
        if np.allclose(last_z_ik, z_ik, atol=1e-3):
            break
    return z_ik, n_jkl
|
||
|
||
class IBCC(BaseLabelModel):
    """Independent Bayesian Classifier Combination label model.

    Thin wrapper around the ``ibcc`` variational inference routine: fit()
    learns the per-worker expected confusion counts, predict_proba() reuses
    them in eval mode.
    """

    def __init__(self,
                 alpha: Optional[float] = 1,
                 a_v: Optional[float] = 4,
                 b_v: Optional[float] = 1,
                 **kwargs: Any):
        super().__init__()
        # Hyper-parameters forwarded verbatim to ibcc().
        self.hyperparas = {
            'alpha': alpha,
            'a_v': a_v,
            'b_v': b_v,
            **kwargs
        }
        # Fitted expected confusion counts; filled by fit().
        self.params = {
            'n_jkl': None,
        }

    def fit(self,
            dataset_train: Union[BaseDataset, np.ndarray],
            dataset_valid: Optional[Union[BaseDataset, np.ndarray]] = None,
            y_valid: Optional[np.ndarray] = None,
            n_class: Optional[int] = None,
            verbose: Optional[bool] = False,
            *args: Any,
            **kwargs: Any):
        """Fit IBCC on the weak labels of ``dataset_train``."""
        tuples = create_tuples(dataset_train)
        num_items, _, num_classes = tuples.max(axis=0) + 1
        # Take the worker count from the dataset rather than the largest
        # worker id seen in the votes, so workers that never vote are still
        # counted (consistent with predict_proba below and with EBCC.fit).
        num_workers = len(dataset_train.weak_labels[0])
        _, param = ibcc(tuples, num_items, num_workers, num_classes, **self.hyperparas)
        self.params['n_jkl'] = param

    def predict_proba(self,
                      dataset: Union[BaseDataset, np.ndarray],
                      **kwargs: Any):
        """Return posterior class probabilities for each item in ``dataset``.

        Runs ``ibcc`` in eval mode with the fitted confusion counts held
        fixed; requires fit() to have been called first.
        """
        tuples = create_tuples(dataset)
        num_items, _, num_classes = tuples.max(axis=0) + 1
        num_workers = len(dataset.weak_labels[0])
        pred, _ = ibcc(tuples, num_items, num_workers, num_classes,
                       eval=True,
                       **self.hyperparas, **self.params)
        return pred
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters