From 86cd423567e3cc93388424abc58e3f03d7701079 Mon Sep 17 00:00:00 2001
From: Alex Song
Date: Wed, 18 May 2022 14:53:51 +0800
Subject: [PATCH 1/3] Add label models: IBCC, EBCC and a corresponding demo.

---
 examples/vi_demo.py           |  93 +++++++++++++++++++++++
 wrench/labelmodel/__init__.py |   2 +
 wrench/labelmodel/ebcc.py     | 137 ++++++++++++++++++++++++++++++++++
 wrench/labelmodel/ibcc.py     |  88 ++++++++++++++++++++++
 wrench/utils.py               |  15 +++-
 5 files changed, 334 insertions(+), 1 deletion(-)
 create mode 100644 examples/vi_demo.py
 create mode 100644 wrench/labelmodel/ebcc.py
 create mode 100644 wrench/labelmodel/ibcc.py

diff --git a/examples/vi_demo.py b/examples/vi_demo.py
new file mode 100644
index 0000000..477f6b5
--- /dev/null
+++ b/examples/vi_demo.py
@@ -0,0 +1,93 @@
+import copy
+import logging
+import torch
+from wrench.dataset import load_dataset, TextDataset
+from wrench._logging import LoggingHandler
+from wrench.labelmodel import EBCC, IBCC, Snorkel
+
+
+def concat(d1: TextDataset, d2: TextDataset) -> TextDataset:
+    dataset = TextDataset()
+    dataset.ids = d1.ids + d2.ids
+    dataset.labels = d1.labels + d2.labels
+    dataset.examples = d1.examples + d2.examples
+    dataset.weak_labels = d1.weak_labels + d2.weak_labels
+    dataset.n_class = d1.n_class
+    dataset.n_lf = d1.n_lf
+
+    return dataset
+
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(format='%(asctime)s - %(message)s',
+                    datefmt='%Y-%m-%d %H:%M:%S',
+                    level=logging.INFO,
+                    handlers=[LoggingHandler()])
+
+logger = logging.getLogger(__name__)
+
+device = torch.device('cuda')
+
+#### Load dataset
+dataset_path = '../datasets/'
+data = 'youtube'
+train_data, valid_data, test_data = load_dataset(
+    dataset_path,
+    data,
+    extract_feature=True,
+    extract_fn='bert',  # extract bert embedding
+    model_name='bert-base-cased',
+    cache_name='bert'
+)
+train_data_c = train_data.get_covered_subset()
+#### Run label model: Snorkel
+
+# print('============inductive============')
+# label_model_generate = EBCC(
+#     num_groups=5,
+#     inference_iter=1,
+#     max_iter=1,
+#     empirical_prior=True,
+#     kernel_function=RBF(length_scale=1)
+# )
+#
+
+# label_model_generate.seed = 12345
+# train_data_c.labels = probs_to_preds(label_model_generate.predict_proba(train_data_c))
+
+print('============inference============')
+ebcc = EBCC(
+    num_groups=5,
+    repeat=10,
+    inference_iter=100,
+    empirical_prior=True,
+)
+ibcc = IBCC()
+snorkel = Snorkel(
+    lr=0.01,
+    l2=0.0,
+    n_epochs=10
+)
+
+ebcc.fit(
+    dataset_train=train_data_c
+)
+snorkel.fit(
+    dataset_train=train_data_c,
+    dataset_valid=valid_data
+)
+
+print('============test============')
+ebcc.predict_proba(train_data_c)
+acc_ebcc = ebcc.test(train_data_c, 'acc')
+acc_test_ebcc = ebcc.test(test_data, 'acc')
+
+acc_ibcc = ibcc.test(train_data_c, 'acc')
+acc_test_ibcc = ibcc.test(test_data, 'acc')
+
+acc_s = snorkel.test(train_data_c, 'acc')
+acc_test_s = snorkel.test(test_data, 'acc')
+
+logger.info(f'label model train/test acc on gp-ebcc: {acc_ebcc}, {acc_test_ebcc}, seed={ebcc.seed}')
+logger.info(f'label model train/test acc on gp-ibcc: {acc_ibcc}, {acc_test_ibcc}')
+logger.info(f'label model train/test acc on snorkel: {acc_s}, {acc_test_s}')
diff --git a/wrench/labelmodel/__init__.py b/wrench/labelmodel/__init__.py
index f716e10..ca7ac55 100755
--- a/wrench/labelmodel/__init__.py
+++ b/wrench/labelmodel/__init__.py
@@ -6,3 +6,5 @@
 from .metal import MeTaL
 from .naive_bayes import NaiveBayesModel
 from .snorkel import Snorkel
+from .ebcc import EBCC
+from .ibcc import IBCC
diff --git a/wrench/labelmodel/ebcc.py b/wrench/labelmodel/ebcc.py
new file mode 100644
index 0000000..63e527b
--- /dev/null
+++ b/wrench/labelmodel/ebcc.py
@@ -0,0 +1,137 @@
+from typing import Optional, Any, Union
+import numpy as np
+import scipy.sparse as ssp
+from scipy.special import digamma, gammaln
+from scipy.stats import entropy, dirichlet
+
+from wrench.basemodel import BaseLabelModel
+from wrench.dataset import BaseDataset
+from ..utils import create_tuples
+
+
+def ebcc_vb(tuples,
+            a_pi=0.1,
+            num_groups=10,  # M
+            alpha=1,  # alpha_k, it can be 1 or \sum_i gamma_ik
+            a_v=4,  # beta_kk
+            b_v=1,  # beta_kk', k neq k'
+            seed=1234,
+            inference_iter=500,
+            empirical_prior=False):
+    num_items, num_workers, num_classes = tuples.max(axis=0) + 1
+
+    y_is_one_lij = []
+    y_is_one_lji = []
+    for k in range(num_classes):
+        selected = (tuples[:, 2] == k)
+        coo_ij = ssp.coo_matrix((np.ones(selected.sum()),
+                                 tuples[selected, :2].T),
+                                shape=(num_items, num_workers),
+                                dtype=bool)
+        y_is_one_lij.append(coo_ij.tocsr())
+        y_is_one_lji.append(coo_ij.T.tocsr())
+
+    beta_kl = np.eye(num_classes) * (a_v - b_v) + b_v
+
+    # initialize z_ik, zg_ikm, c_ik, gamma_ik, sigma_ik
+    z_ik = np.zeros((num_items, num_classes))
+    for l in range(num_classes):
+        z_ik[:, [l]] += y_is_one_lij[l].sum(axis=-1) + 1e-8
+    z_ik /= z_ik.sum(axis=-1, keepdims=True)
+
+    if empirical_prior:
+        alpha = z_ik.sum(axis=0)
+
+    np.random.seed(seed)
+    zg_ikm = np.random.dirichlet(np.ones(num_groups), z_ik.shape) * z_ik[:, :, None]
+    for it in range(inference_iter):
+        eta_km = a_pi / num_groups + zg_ikm.sum(axis=0)
+        nu_k = alpha + z_ik.sum(axis=0)
+
+        mu_jkml = np.zeros((num_workers, num_classes, num_groups, num_classes)) + beta_kl[None, :, None, :]
+        for l in range(num_classes):
+            for k in range(num_classes):
+                mu_jkml[:, k, :, l] += y_is_one_lji[l].dot(zg_ikm[:, k, :])
+
+        Eq_log_pi_km = digamma(eta_km) - digamma(eta_km.sum(axis=-1, keepdims=True))
+        Eq_log_tau_k = digamma(nu_k) - digamma(nu_k.sum())
+        Eq_log_v_jkml = digamma(mu_jkml) - digamma(mu_jkml.sum(axis=-1, keepdims=True))
+
+        zg_ikm[:] = Eq_log_pi_km[None, :, :] + Eq_log_tau_k[None, :, None]
+        for l in range(num_classes):
+            for k in range(num_classes):
+                zg_ikm[:, k, :] += y_is_one_lij[l].dot(Eq_log_v_jkml[:, k, :, l])
+
+        zg_ikm = np.exp(zg_ikm)
+        zg_ikm /= zg_ikm.reshape(num_items, -1).sum(axis=-1)[:, None, None]
+
+        last_z_ik = z_ik
+        z_ik = zg_ikm.sum(axis=-1)
+
+        if np.allclose(last_z_ik, z_ik, atol=1e-3):
+            break
+
+    ELBO = ((eta_km - 1) * Eq_log_pi_km).sum() + ((nu_k - 1) * Eq_log_tau_k).sum() + (
+            (mu_jkml - 1) * Eq_log_v_jkml).sum()
+    ELBO += dirichlet.entropy(nu_k)
+    for k in range(num_classes):
+        ELBO += dirichlet.entropy(eta_km[k])
+    ELBO += (gammaln(mu_jkml) - (mu_jkml - 1) * digamma(mu_jkml)).sum()
+    alpha0_jkm = mu_jkml.sum(axis=-1)
+    ELBO += ((alpha0_jkm - num_classes) * digamma(alpha0_jkm) - gammaln(alpha0_jkm)).sum()
+    ELBO += entropy(zg_ikm.reshape(num_items, -1).T).sum()
+    return z_ik, ELBO
+
+
+class EBCC(BaseLabelModel):
+    def __init__(self,
+                 num_groups: Optional[int] = 10,
+                 a_pi: Optional[float] = 0.1,
+                 alpha: Optional[float] = 1,
+                 a_v: Optional[float] = 4,
+                 b_v: Optional[float] = 1,
+                 repeat: Optional[int] = 1000,
+                 inference_iter: Optional[int] = 500,
+                 empirical_prior=False,
+                 **kwargs: Any):
+        super().__init__()
+        self.hyperparas = {
+            'num_groups': num_groups,
+            'a_pi': a_pi,
+            'alpha': alpha,
+            'a_v': a_v,
+            'b_v': b_v,
+            'empirical_prior': empirical_prior,
+            'inference_iter': inference_iter,
+            **kwargs
+        }
+        self.repeat = repeat
+        self.seed = None
+
+    def fit(self,
+            dataset_train: Union[BaseDataset, np.ndarray],
+            dataset_valid: Optional[Union[BaseDataset, np.ndarray]] = None,
+            y_valid: Optional[np.ndarray] = None,
+            n_class: Optional[int] = None,
+            verbose: Optional[bool] = False,
+            *args: Any,
+            **kwargs: Any):
+        train_tuples = create_tuples(dataset_train)
+        max_elbo = float('-inf')
+
+        self.seed = None
+
+        for infer in range(self.repeat):
+            seed = np.random.randint(1e8)
+            self.seed = seed
+            prediction, elbo = ebcc_vb(train_tuples, seed=seed, **self.hyperparas)
+            if elbo > max_elbo:
+                self.seed = seed
+
+    def predict_proba(self,
+                      dataset: Union[BaseDataset, np.ndarray],
+                      **kwargs: Any):
+        tuples = create_tuples(dataset)
+        prediction, elbo = ebcc_vb(tuples, seed=self.seed, **self.hyperparas)
+
+        return prediction
diff --git a/wrench/labelmodel/ibcc.py b/wrench/labelmodel/ibcc.py
new file mode 100644
index 0000000..c6f8e09
--- /dev/null
+++ b/wrench/labelmodel/ibcc.py
@@ -0,0 +1,88 @@
+from typing import Optional, Any, Union
+import numpy as np
+import scipy.sparse as ssp
+from scipy.special import digamma
+
+from wrench.basemodel import BaseLabelModel
+from wrench.dataset import BaseDataset
+from ..utils import create_tuples
+
+
+def ibcc(tuples, a_v=4, b_v=1, alpha=1):
+    num_items, num_workers, num_classes = tuples.max(axis=0) + 1
+
+    y_is_one_kij = []
+    y_is_one_kji = []
+    for k in range(num_classes):
+        selected = (tuples[:, 2] == k)
+        coo_ij = ssp.coo_matrix((np.ones(selected.sum()), tuples[selected, :2].T), shape=(num_items, num_workers),
+                                dtype=bool)
+        y_is_one_kij.append(coo_ij.tocsr())
+        y_is_one_kji.append(coo_ij.T.tocsr())
+
+    # initialization
+    prior_kl = np.eye(num_classes) * (a_v - b_v) + b_v
+    n_jkl = np.empty((num_workers, num_classes, num_classes))
+
+    # MV initialize Z
+    z_ik = np.zeros((num_items, num_classes))
+    for l in range(num_classes):
+        z_ik[:, [l]] += y_is_one_kij[l].sum(axis=-1) + 1e-8
+    z_ik /= z_ik.sum(axis=-1, keepdims=True)
+    last_z_ik = z_ik.copy()
+
+    for iteration in range(500):
+        # E step
+        Eq_log_pi_k = digamma(z_ik.sum(axis=0) + alpha)  # - digamma(num_items + num_classes * alpha)
+
+        for l in range(num_classes):
+            n_jkl[:, :, l] = y_is_one_kji[l].dot(z_ik)
+
+        Eq_log_v_jkl = digamma(n_jkl + prior_kl[None, :, :]) - \
+                       digamma(n_jkl.sum(axis=-1) + prior_kl.sum(axis=-1))[:, :, None]
+
+        # M step
+        last_z_ik[:] = z_ik
+
+        z_ik[:] = Eq_log_pi_k
+        for l in range(num_classes):
+            z_ik += y_is_one_kij[l].dot(Eq_log_v_jkl[:, :, l])
+        z_ik -= z_ik.max(axis=-1, keepdims=True)
+        z_ik = np.exp(z_ik)
+        z_ik /= z_ik.sum(axis=-1, keepdims=True)
+
+        if np.allclose(last_z_ik, z_ik, atol=1e-3):
+            break
+    return z_ik
+
+
+class IBCC(BaseLabelModel):
+    def __init__(self,
+                 alpha: Optional[float] = 1,
+                 a_v: Optional[float] = 4,
+                 b_v: Optional[float] = 1,
+                 **kwargs: Any):
+        super().__init__()
+        self.hyperparas = {
+            'alpha': alpha,
+            'a_v': a_v,
+            'b_v': b_v,
+            **kwargs
+        }
+
+    def fit(self,
+            dataset_train: Union[BaseDataset, np.ndarray],
+            dataset_valid: Optional[Union[BaseDataset, np.ndarray]] = None,
+            y_valid: Optional[np.ndarray] = None,
+            n_class: Optional[int] = None,
+            verbose: Optional[bool] = False,
+            *args: Any,
+            **kwargs: Any):
+        pass
+
+    def predict_proba(self,
+                      dataset: Union[BaseDataset, np.ndarray],
+                      **kwargs: Any):
+        tuples = create_tuples(dataset)
+
+        return ibcc(tuples, **self.hyperparas)
diff --git a/wrench/utils.py b/wrench/utils.py
index d769167..894de66 100755
--- a/wrench/utils.py
+++ b/wrench/utils.py
@@ -1,6 +1,6 @@
 import random
 from collections import Counter
-from typing import Dict, Optional
+from typing import Dict, Optional, Union
 
 import numpy as np
 import torch
@@ -181,3 +181,16 @@ def collate_fn_trunc_pad(batch: Dict):
         return batch
 
     return collate_fn_trunc_pad
+
+
+def create_tuples(dataset: Union[BaseDataset, np.ndarray]):
+    # Flatten the item x worker grid of weak labels into (item, worker, class) triples.
+    n_items, n_workers = len(dataset), len(dataset.weak_labels[0])
+    ids = np.repeat(np.arange(n_items), n_workers)
+    workers = np.tile(np.arange(n_workers), n_items)
+    classes = np.array(dataset.weak_labels).reshape(-1)
+
+    tuples = np.vstack((ids, workers, classes))
+    tuples = tuples[:, tuples[2, :] != -1]  # drop abstentions (label == -1)
+
+    return tuples.T
From c31faff6bdd2dfb09eaddd5bb3e150eefc637c65 Mon Sep 17 00:00:00 2001
From: Alex Song
Date: Wed, 18 May 2022 15:38:53 +0800
Subject: [PATCH 2/3] Rename the demo script; update the README.

---
 README.md                                 | 30 ++++++++++++-----------
 examples/{vi_demo.py => run_ibcc_ebcc.py} |  0
 2 files changed, 16 insertions(+), 14 deletions(-)
 rename examples/{vi_demo.py => run_ibcc_ebcc.py} (100%)

diff --git a/README.md b/README.md
index 5f874c8..2d279b7 100755
--- a/README.md
+++ b/README.md
@@ -137,20 +137,22 @@ The detailed documentation is coming soon.
 TODO-list: check [this](https://github.com/JieyuZ2/wrench/wiki/TODO-List) out!
 
 ### classification:
-| Model | Model Type | Reference | Link to Wrench |
-|:--------|:---------|:------|:------|
-| Majority Voting | Label Model | -- | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/majority_voting.py#L44) |
-| Weighted Majority Voting | Label Model | -- | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/majority_voting.py#L14) |
-| Dawid-Skene | Label Model | [link](https://www.jstor.org/stable/2346806) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/dawid_skene.py#L15) |
-| Data Progamming | Label Model | [link](https://arxiv.org/abs/1605.07723) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/generative_model.py#L18) |
-| MeTaL | Label Model | [link](https://arxiv.org/abs/1810.02840) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/snorkel.py#L17) |
-| FlyingSquid | Label Model | [link](https://arxiv.org/pdf/2002.11955) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/flyingsquid.py#L16) |
-| Logistic Regression | End Model | -- | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/endmodel/linear_model.py#L52) |
-| MLP | End Model | -- | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/endmodel/neural_model.py#L21) |
-| BERT | End Model | [link](https://huggingface.co/models) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/endmodel/bert_model.py#L23) |
-| COSINE | End Model | [link](https://arxiv.org/abs/2010.07835) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/endmodel/cosine.py#L68) |
-| Denoise | Joint Model | [link](https://arxiv.org/abs/2010.04582) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/classification/denoise.py#L72) |
-| WeaSEL | Joint Model | [link](https://arxiv.org/abs/2107.02233) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/classification/weasel.py#L72) |
+| Model                    | Model Type  | Reference                                            | Link to Wrench                                                                                 |
+|:-------------------------|:------------|:-----------------------------------------------------|:-----------------------------------------------------------------------------------------------|
+| Majority Voting          | Label Model | --                                                   | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/majority_voting.py#L44)   |
+| Weighted Majority Voting | Label Model | --                                                   | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/majority_voting.py#L14)   |
+| Dawid-Skene              | Label Model | [link](https://www.jstor.org/stable/2346806)         | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/dawid_skene.py#L15)       |
+| Data Programming         | Label Model | [link](https://arxiv.org/abs/1605.07723)             | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/generative_model.py#L18)  |
+| MeTaL                    | Label Model | [link](https://arxiv.org/abs/1810.02840)             | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/snorkel.py#L17)           |
+| FlyingSquid              | Label Model | [link](https://arxiv.org/pdf/2002.11955)             | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/flyingsquid.py#L16)       |
+| EBCC                     | Label Model | [link](https://proceedings.mlr.press/v97/li19i.html) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/ebcc.py#L12)              |
+| IBCC                     | Label Model | [link](https://proceedings.mlr.press/v97/li19i.html) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/ibcc.py#L11)              |
+| Logistic Regression      | End Model   | --                                                   | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/endmodel/linear_model.py#L52)        |
+| MLP                      | End Model   | --                                                   | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/endmodel/neural_model.py#L21)        |
+| BERT                     | End Model   | [link](https://huggingface.co/models)                | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/endmodel/bert_model.py#L23)          |
+| COSINE                   | End Model   | [link](https://arxiv.org/abs/2010.07835)             | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/endmodel/cosine.py#L68)              |
+| Denoise                  | Joint Model | [link](https://arxiv.org/abs/2010.04582)             | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/classification/denoise.py#L72)       |
+| WeaSEL                   | Joint Model | [link](https://arxiv.org/abs/2107.02233)             | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/classification/weasel.py#L72)        |
 
 ### sequence tagging:
 | Model | Model Type | Reference | Link to Wrench |
diff --git a/examples/vi_demo.py b/examples/run_ibcc_ebcc.py
similarity index 100%
rename from examples/vi_demo.py
rename to examples/run_ibcc_ebcc.py
From 4ff5d50027944f024b2ce2079dd8881406025d93 Mon Sep 17 00:00:00 2001
From: Alex Song
Date: Wed, 18 May 2022 17:50:59 +0800
Subject: [PATCH 3/3] Make fit() available for EBCC and IBCC.

---
 examples/run_ibcc_ebcc.py | 54 +++++++---------------------
 wrench/labelmodel/ebcc.py | 62 +++++++++++++++++++++++++----------
 wrench/labelmodel/ibcc.py | 41 +++++++++++++++++------
 3 files changed, 83 insertions(+), 74 deletions(-)

diff --git a/examples/run_ibcc_ebcc.py b/examples/run_ibcc_ebcc.py
index 477f6b5..4221ada 100644
--- a/examples/run_ibcc_ebcc.py
+++ b/examples/run_ibcc_ebcc.py
@@ -6,18 +6,6 @@
 from wrench.labelmodel import EBCC, IBCC, Snorkel
 
 
-def concat(d1: TextDataset, d2: TextDataset) -> TextDataset:
-    dataset = TextDataset()
-    dataset.ids = d1.ids + d2.ids
-    dataset.labels = d1.labels + d2.labels
-    dataset.examples = d1.examples + d2.examples
-    dataset.weak_labels = d1.weak_labels + d2.weak_labels
-    dataset.n_class = d1.n_class
-    dataset.n_lf = d1.n_lf
-
-    return dataset
-
-
 #### Just some code to print debug information to stdout
 logging.basicConfig(format='%(asctime)s - %(message)s',
                     datefmt='%Y-%m-%d %H:%M:%S',
@@ -30,7 +18,7 @@ def concat(d1: TextDataset, d2: TextDataset) -> TextDataset:
 
 #### Load dataset
 dataset_path = '../datasets/'
-data = 'youtube'
+data = 'agnews'
 train_data, valid_data, test_data = load_dataset(
     dataset_path,
     data,
@@ -40,44 +28,26 @@ def concat(d1: TextDataset, d2: TextDataset) -> TextDataset:
     cache_name='bert'
 )
 train_data_c = train_data.get_covered_subset()
-#### Run label model: Snorkel
-
-# print('============inductive============')
-# label_model_generate = EBCC(
-#     num_groups=5,
-#     inference_iter=1,
-#     max_iter=1,
-#     empirical_prior=True,
-#     kernel_function=RBF(length_scale=1)
-# )
-#
-
-# label_model_generate.seed = 12345
-# train_data_c.labels = probs_to_preds(label_model_generate.predict_proba(train_data_c))
 
-print('============inference============')
+#### Inference for EBCC and IBCC
+logger.info('============inference============')
 ebcc = EBCC(
     num_groups=5,
-    repeat=10,
+    repeat=100,
     inference_iter=100,
     empirical_prior=True,
 )
 ibcc = IBCC()
-snorkel = Snorkel(
-    lr=0.01,
-    l2=0.0,
-    n_epochs=10
-)
 
 ebcc.fit(
     dataset_train=train_data_c
 )
-snorkel.fit(
-    dataset_train=train_data_c,
-    dataset_valid=valid_data
+ibcc.fit(
+    dataset_train=train_data_c
 )
 
-print('============test============')
+#### Test for EBCC and IBCC
+logger.info('============test============')
 ebcc.predict_proba(train_data_c)
 acc_ebcc = ebcc.test(train_data_c, 'acc')
 acc_test_ebcc = ebcc.test(test_data, 'acc')
@@ -85,9 +55,5 @@ def concat(d1: TextDataset, d2: TextDataset) -> TextDataset:
 acc_ibcc = ibcc.test(train_data_c, 'acc')
 acc_test_ibcc = ibcc.test(test_data, 'acc')
 
-acc_s = snorkel.test(train_data_c, 'acc')
-acc_test_s = snorkel.test(test_data, 'acc')
-
-logger.info(f'label model train/test acc on gp-ebcc: {acc_ebcc}, {acc_test_ebcc}, seed={ebcc.seed}')
-logger.info(f'label model train/test acc on gp-ibcc: {acc_ibcc}, {acc_test_ibcc}')
-logger.info(f'label model train/test acc on snorkel: {acc_s}, {acc_test_s}')
+logger.info(f'label model train/test acc on ebcc: {acc_ebcc}, {acc_test_ebcc}, seed={ebcc.params["seed"]}')
+logger.info(f'label model train/test acc on ibcc: {acc_ibcc}, {acc_test_ibcc}')
diff --git a/wrench/labelmodel/ebcc.py b/wrench/labelmodel/ebcc.py
index 63e527b..61aebed 100644
--- a/wrench/labelmodel/ebcc.py
+++ b/wrench/labelmodel/ebcc.py
@@ -10,15 +10,21 @@
 
 
 def ebcc_vb(tuples,
+            num_items,
+            num_workers,
+            num_classes,
             a_pi=0.1,
             num_groups=10,  # M
             alpha=1,  # alpha_k, it can be 1 or \sum_i gamma_ik
             a_v=4,  # beta_kk
             b_v=1,  # beta_kk', k neq k'
+            eta_km=None,
+            nu_k=None,
+            mu_jkml=None,
+            eval=False,
             seed=1234,
             inference_iter=500,
             empirical_prior=False):
-    num_items, num_workers, num_classes = tuples.max(axis=0) + 1
 
     y_is_one_lij = []
     y_is_one_lji = []
@@ -45,13 +51,13 @@ def ebcc_vb(tuples,
     np.random.seed(seed)
     zg_ikm = np.random.dirichlet(np.ones(num_groups), z_ik.shape) * z_ik[:, :, None]
     for it in range(inference_iter):
-        eta_km = a_pi / num_groups + zg_ikm.sum(axis=0)
-        nu_k = alpha + z_ik.sum(axis=0)
-
-        mu_jkml = np.zeros((num_workers, num_classes, num_groups, num_classes)) + beta_kl[None, :, None, :]
-        for l in range(num_classes):
-            for k in range(num_classes):
-                mu_jkml[:, k, :, l] += y_is_one_lji[l].dot(zg_ikm[:, k, :])
+        if eval is False:
+            eta_km = a_pi / num_groups + zg_ikm.sum(axis=0)
+            nu_k = alpha + z_ik.sum(axis=0)
+            mu_jkml = np.zeros((num_workers, num_classes, num_groups, num_classes)) + beta_kl[None, :, None, :]
+            for l in range(num_classes):
+                for k in range(num_classes):
+                    mu_jkml[:, k, :, l] += y_is_one_lji[l].dot(zg_ikm[:, k, :])
 
         Eq_log_pi_km = digamma(eta_km) - digamma(eta_km.sum(axis=-1, keepdims=True))
         Eq_log_tau_k = digamma(nu_k) - digamma(nu_k.sum())
@@ -80,7 +86,7 @@ def ebcc_vb(tuples,
     alpha0_jkm = mu_jkml.sum(axis=-1)
     ELBO += ((alpha0_jkm - num_classes) * digamma(alpha0_jkm) - gammaln(alpha0_jkm)).sum()
     ELBO += entropy(zg_ikm.reshape(num_items, -1).T).sum()
-    return z_ik, ELBO
+    return z_ik, ELBO, eta_km, nu_k, mu_jkml
 
 
 class EBCC(BaseLabelModel):
@@ -105,8 +111,13 @@ def __init__(self,
             'inference_iter': inference_iter,
             **kwargs
         }
+        self.params = {
+            'seed': None,
+            'eta_km': None,
+            'nu_k': None,
+            'mu_jkml': None,
+        }
         self.repeat = repeat
-        self.seed = None
 
     def fit(self,
             dataset_train: Union[BaseDataset, np.ndarray],
@@ -116,22 +127,35 @@ def fit(self,
             verbose: Optional[bool] = False,
             *args: Any,
             **kwargs: Any):
-        train_tuples = create_tuples(dataset_train)
+        tuples = create_tuples(dataset_train)
+        num_items, _, num_classes = tuples.max(axis=0) + 1
+        num_workers = len(dataset_train.weak_labels[0])
         max_elbo = float('-inf')
 
-        self.seed = None
-
         for infer in range(self.repeat):
             seed = np.random.randint(1e8)
-            self.seed = seed
-            prediction, elbo = ebcc_vb(train_tuples, seed=seed, **self.hyperparas)
+            prediction, elbo, p1, p2, p3 = ebcc_vb(tuples,
+                                                   num_items, num_workers, num_classes,
+                                                   seed=seed,
+                                                   **self.hyperparas)
             if elbo > max_elbo:
-                self.seed = seed
+                print(f'update elbo: new: {elbo}, old: {max_elbo}')
+                self.params = {
+                    'seed': seed,
+                    'eta_km': p1,
+                    'nu_k': p2,
+                    'mu_jkml': p3
+                }
+                max_elbo = elbo
 
     def predict_proba(self,
                       dataset: Union[BaseDataset, np.ndarray],
                       **kwargs: Any):
         tuples = create_tuples(dataset)
-        prediction, elbo = ebcc_vb(tuples, seed=self.seed, **self.hyperparas)
-
-        return prediction
+        num_items, _, num_classes = tuples.max(axis=0) + 1
+        num_workers = len(dataset.weak_labels[0])
+        pred, elbo, _, _, _ = ebcc_vb(tuples,
+                                      num_items, num_workers, num_classes,
+                                      eval=True,
+                                      **self.hyperparas, **self.params)
+        return pred
diff --git a/wrench/labelmodel/ibcc.py b/wrench/labelmodel/ibcc.py
index c6f8e09..4bc2daa 100644
--- a/wrench/labelmodel/ibcc.py
+++ b/wrench/labelmodel/ibcc.py
@@ -8,21 +8,29 @@
 from ..utils import create_tuples
 
 
-def ibcc(tuples, a_v=4, b_v=1, alpha=1):
-    num_items, num_workers, num_classes = tuples.max(axis=0) + 1
+def ibcc(tuples,
+         num_items,
+         num_workers,
+         num_classes,
+         a_v=4,
+         b_v=1,
+         alpha=1,
+         n_jkl=None,
+         eval=False):
 
     y_is_one_kij = []
     y_is_one_kji = []
     for k in range(num_classes):
         selected = (tuples[:, 2] == k)
-        coo_ij = ssp.coo_matrix((np.ones(selected.sum()), tuples[selected, :2].T), shape=(num_items, num_workers),
-                                dtype=bool)
+        coo_ij = ssp.coo_matrix((np.ones(selected.sum()), tuples[selected, :2].T),
+                                shape=(num_items, num_workers), dtype=bool)
         y_is_one_kij.append(coo_ij.tocsr())
         y_is_one_kji.append(coo_ij.T.tocsr())
 
     # initialization
     prior_kl = np.eye(num_classes) * (a_v - b_v) + b_v
-    n_jkl = np.empty((num_workers, num_classes, num_classes))
+    if n_jkl is None:
+        n_jkl = np.empty((num_workers, num_classes, num_classes))
 
     # MV initialize Z
     z_ik = np.zeros((num_items, num_classes))
@@ -35,8 +43,9 @@ def ibcc(tuples, a_v=4, b_v=1, alpha=1):
         # E step
         Eq_log_pi_k = digamma(z_ik.sum(axis=0) + alpha)  # - digamma(num_items + num_classes * alpha)
 
-        for l in range(num_classes):
-            n_jkl[:, :, l] = y_is_one_kji[l].dot(z_ik)
+        if eval is False:
+            for l in range(num_classes):
+                n_jkl[:, :, l] = y_is_one_kji[l].dot(z_ik)
 
         Eq_log_v_jkl = digamma(n_jkl + prior_kl[None, :, :]) - \
                        digamma(n_jkl.sum(axis=-1) + prior_kl.sum(axis=-1))[:, :, None]
@@ -53,7 +62,7 @@ def ibcc(tuples, a_v=4, b_v=1, alpha=1):
 
         if np.allclose(last_z_ik, z_ik, atol=1e-3):
             break
-    return z_ik
+    return z_ik, n_jkl
 
 
 class IBCC(BaseLabelModel):
@@ -69,6 +78,9 @@ def __init__(self,
             'b_v': b_v,
             **kwargs
         }
+        self.params = {
+            'n_jkl': None,
+        }
 
     def fit(self,
             dataset_train: Union[BaseDataset, np.ndarray],
@@ -78,11 +90,18 @@ def fit(self,
             verbose: Optional[bool] = False,
             *args: Any,
             **kwargs: Any):
-        pass
+        tuples = create_tuples(dataset_train)
+        num_items, num_workers, num_classes = tuples.max(axis=0) + 1
+        _, param = ibcc(tuples, num_items, num_workers, num_classes, **self.hyperparas)
+        self.params['n_jkl'] = param
 
     def predict_proba(self,
                       dataset: Union[BaseDataset, np.ndarray],
                       **kwargs: Any):
         tuples = create_tuples(dataset)
-
-        return ibcc(tuples, **self.hyperparas)
+        num_items, _, num_classes = tuples.max(axis=0) + 1
+        num_workers = len(dataset.weak_labels[0])
+        pred, _ = ibcc(tuples, num_items, num_workers, num_classes,
+                       eval=True,
+                       **self.hyperparas, **self.params)
+        return pred
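
---

Reviewer note: below is a minimal end-to-end sketch of the API these patches add, using only
calls that appear above (load_dataset, get_covered_subset, EBCC/IBCC fit, predict_proba, test).
The dataset path and name are assumptions carried over from the demo script, and the bare
load_dataset call assumes feature extraction is off by default; any dataset in wrench's
weak-supervision format should work.

    # sketch: fit the new label models on the covered train split, score accuracy on test
    from wrench.dataset import load_dataset
    from wrench.labelmodel import EBCC, IBCC

    train_data, valid_data, test_data = load_dataset('../datasets/', 'youtube')
    train_data_c = train_data.get_covered_subset()  # items with at least one weak label

    ebcc = EBCC(num_groups=5, repeat=100, inference_iter=100, empirical_prior=True)
    ebcc.fit(dataset_train=train_data_c)  # restarts `repeat` times, keeps best-ELBO params

    ibcc = IBCC()
    ibcc.fit(dataset_train=train_data_c)  # caches the worker confusion counts n_jkl

    print('ebcc test acc:', ebcc.test(test_data, 'acc'))
    print('ibcc test acc:', ibcc.test(test_data, 'acc'))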