Merge pull request #32 from Stranger469/main
Add label model: ibcc, ebcc and their corresponding demo.
JieyuZ2 authored May 21, 2022
2 parents 63a2178 + 4ff5d50 commit ab717ac
Showing 6 changed files with 359 additions and 15 deletions.
30 changes: 16 additions & 14 deletions README.md
@@ -137,20 +137,22 @@ The detailed documentation is coming soon.
TODO-list: check [this](https://github.com/JieyuZ2/wrench/wiki/TODO-List) out!

### classification:
| Model | Model Type | Reference | Link to Wrench |
|:-------------------------|:------------|:-----------------------------------------------------|:----------------------------------------------------------------------------------------------|
| Majority Voting | Label Model | -- | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/majority_voting.py#L44) |
| Weighted Majority Voting | Label Model | -- | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/majority_voting.py#L14) |
| Dawid-Skene | Label Model | [link](https://www.jstor.org/stable/2346806) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/dawid_skene.py#L15) |
| Data Programming         | Label Model | [link](https://arxiv.org/abs/1605.07723)             | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/generative_model.py#L18) |
| MeTaL | Label Model | [link](https://arxiv.org/abs/1810.02840) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/snorkel.py#L17) |
| FlyingSquid | Label Model | [link](https://arxiv.org/pdf/2002.11955) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/flyingsquid.py#L16) |
| EBCC | Label Model | [link](https://proceedings.mlr.press/v97/li19i.html) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/ebcc.py#L12) |
| IBCC | Label Model | [link](https://proceedings.mlr.press/v97/li19i.html) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/ibcc.py#L11) |
| Logistic Regression | End Model | -- | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/endmodel/linear_model.py#L52) |
| MLP | End Model | -- | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/endmodel/neural_model.py#L21) |
| BERT | End Model | [link](https://huggingface.co/models) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/endmodel/bert_model.py#L23) |
| COSINE | End Model | [link](https://arxiv.org/abs/2010.07835) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/endmodel/cosine.py#L68) |
| Denoise | Joint Model | [link](https://arxiv.org/abs/2010.04582) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/classification/denoise.py#L72) |
| WeaSEL | Joint Model | [link](https://arxiv.org/abs/2107.02233) | [link](https://github.com/JieyuZ2/wrench/blob/main/wrench/classification/weasel.py#L72) |
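
All label models in the table share a `fit`/`predict_proba` interface over a weak-label matrix. As a point of reference for the first entry, here is a minimal stand-alone sketch of plain majority voting (`majority_vote_proba` is a hypothetical helper written for illustration, not the Wrench API):

```python
import numpy as np

def majority_vote_proba(weak_labels: np.ndarray, n_class: int) -> np.ndarray:
    # weak_labels: (n_items, n_lfs) matrix of votes; -1 means the LF abstained.
    counts = np.stack([(weak_labels == k).sum(axis=1) for k in range(n_class)],
                      axis=1).astype(float)
    counts[counts.sum(axis=1) == 0] = 1.0  # all-abstain rows fall back to uniform
    return counts / counts.sum(axis=1, keepdims=True)

L = np.array([[0, 0, 1],
              [1, -1, 1],
              [-1, -1, -1]])
proba = majority_vote_proba(L, 2)
print(proba)  # [[0.667, 0.333], [0., 1.], [0.5, 0.5]]
```

The weighted variant in the second row differs only in giving each labeling function a non-uniform weight when counting votes.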

### sequence tagging:
| Model | Model Type | Reference | Link to Wrench |
59 changes: 59 additions & 0 deletions examples/run_ibcc_ebcc.py
@@ -0,0 +1,59 @@
import logging

import torch

from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.labelmodel import EBCC, IBCC


#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

logger = logging.getLogger(__name__)

device = torch.device('cuda')

#### Load dataset
dataset_path = '../datasets/'
data = 'agnews'
train_data, valid_data, test_data = load_dataset(
    dataset_path,
    data,
    extract_feature=True,
    extract_fn='bert',  # extract bert embedding
    model_name='bert-base-cased',
    cache_name='bert'
)
train_data_c = train_data.get_covered_subset()

#### Inference for EBCC and IBCC
logger.info('============inference============')
ebcc = EBCC(
    num_groups=5,
    repeat=100,
    inference_iter=100,
    empirical_prior=True,
)
ibcc = IBCC()

ebcc.fit(dataset_train=train_data_c)
ibcc.fit(dataset_train=train_data_c)

#### Test for EBCC and IBCC
logger.info('============test============')
acc_ebcc = ebcc.test(train_data_c, 'acc')
acc_test_ebcc = ebcc.test(test_data, 'acc')

acc_ibcc = ibcc.test(train_data_c, 'acc')
acc_test_ibcc = ibcc.test(test_data, 'acc')

logger.info(f'label model train/test acc of EBCC: {acc_ebcc}, {acc_test_ebcc}, seed={ebcc.params["seed"]}')
logger.info(f'label model train/test acc of IBCC: {acc_ibcc}, {acc_test_ibcc}')
2 changes: 2 additions & 0 deletions wrench/labelmodel/__init__.py
@@ -6,3 +6,5 @@
from .metal import MeTaL
from .naive_bayes import NaiveBayesModel
from .snorkel import Snorkel
from .ebcc import EBCC
from .ibcc import IBCC
161 changes: 161 additions & 0 deletions wrench/labelmodel/ebcc.py
@@ -0,0 +1,161 @@
from typing import Optional, Any, Union

import numpy as np
import scipy.sparse as ssp
from scipy.special import digamma, gammaln
from scipy.stats import entropy, dirichlet

from wrench.basemodel import BaseLabelModel
from wrench.dataset import BaseDataset
from ..utils import create_tuples


def ebcc_vb(tuples,
            num_items,
            num_workers,
            num_classes,
            a_pi=0.1,
            num_groups=10,  # M
            alpha=1,  # alpha_k; it can be 1 or \sum_i gamma_ik
            a_v=4,  # beta_kk
            b_v=1,  # beta_kk', k != k'
            eta_km=None,
            nu_k=None,
            mu_jkml=None,
            eval=False,
            seed=1234,
            inference_iter=500,
            empirical_prior=False):
    # Per-class sparse vote indicators: y_is_one_lij[l][i, j] is True
    # iff worker j labeled item i with class l.
    y_is_one_lij = []
    y_is_one_lji = []
    for k in range(num_classes):
        selected = (tuples[:, 2] == k)
        coo_ij = ssp.coo_matrix((np.ones(selected.sum()),
                                 tuples[selected, :2].T),
                                shape=(num_items, num_workers),
                                dtype=bool)  # np.bool is removed in recent NumPy; use the builtin
        y_is_one_lij.append(coo_ij.tocsr())
        y_is_one_lji.append(coo_ij.T.tocsr())

    beta_kl = np.eye(num_classes) * (a_v - b_v) + b_v

    # initialize z_ik with smoothed, normalized majority-vote counts
    z_ik = np.zeros((num_items, num_classes))
    for l in range(num_classes):
        z_ik[:, [l]] += y_is_one_lij[l].sum(axis=-1) + 1e-8
    z_ik /= z_ik.sum(axis=-1, keepdims=True)

    if empirical_prior:
        alpha = z_ik.sum(axis=0)

    np.random.seed(seed)
    zg_ikm = np.random.dirichlet(np.ones(num_groups), z_ik.shape) * z_ik[:, :, None]
    for it in range(inference_iter):
        if eval is False:
            # update the Dirichlet posterior parameters
            eta_km = a_pi / num_groups + zg_ikm.sum(axis=0)
            nu_k = alpha + z_ik.sum(axis=0)
            mu_jkml = np.zeros((num_workers, num_classes, num_groups, num_classes)) + beta_kl[None, :, None, :]
            for l in range(num_classes):
                for k in range(num_classes):
                    mu_jkml[:, k, :, l] += y_is_one_lji[l].dot(zg_ikm[:, k, :])

        # expected log-parameters under the Dirichlet posteriors
        Eq_log_pi_km = digamma(eta_km) - digamma(eta_km.sum(axis=-1, keepdims=True))
        Eq_log_tau_k = digamma(nu_k) - digamma(nu_k.sum())
        Eq_log_v_jkml = digamma(mu_jkml) - digamma(mu_jkml.sum(axis=-1, keepdims=True))

        zg_ikm[:] = Eq_log_pi_km[None, :, :] + Eq_log_tau_k[None, :, None]
        for l in range(num_classes):
            for k in range(num_classes):
                zg_ikm[:, k, :] += y_is_one_lij[l].dot(Eq_log_v_jkml[:, k, :, l])

        zg_ikm = np.exp(zg_ikm)
        zg_ikm /= zg_ikm.reshape(num_items, -1).sum(axis=-1)[:, None, None]

        last_z_ik = z_ik
        z_ik = zg_ikm.sum(axis=-1)

        if np.allclose(last_z_ik, z_ik, atol=1e-3):
            break

    ELBO = ((eta_km - 1) * Eq_log_pi_km).sum() + ((nu_k - 1) * Eq_log_tau_k).sum() + (
        (mu_jkml - 1) * Eq_log_v_jkml).sum()
    ELBO += dirichlet.entropy(nu_k)
    for k in range(num_classes):
        ELBO += dirichlet.entropy(eta_km[k])
    ELBO += (gammaln(mu_jkml) - (mu_jkml - 1) * digamma(mu_jkml)).sum()
    alpha0_jkm = mu_jkml.sum(axis=-1)
    ELBO += ((alpha0_jkm - num_classes) * digamma(alpha0_jkm) - gammaln(alpha0_jkm)).sum()
    ELBO += entropy(zg_ikm.reshape(num_items, -1).T).sum()
    return z_ik, ELBO, eta_km, nu_k, mu_jkml
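
Each expectation in the loop above is the standard mean-field Dirichlet identity, E_q[log pi_k] = digamma(eta_k) - digamma(sum_k eta_k) for pi ~ Dir(eta), applied in turn to eta_km, nu_k, and mu_jkml. A quick numerical sanity check of that identity against Monte Carlo:

```python
import numpy as np
from scipy.special import digamma

eta = np.array([2.0, 1.0, 1.0])
# Closed-form E[log pi] for pi ~ Dirichlet(eta), as used in ebcc_vb above.
e_log_pi = digamma(eta) - digamma(eta.sum())

# Monte Carlo estimate of the same expectation.
rng = np.random.default_rng(0)
samples = rng.dirichlet(eta, size=200_000)
mc_estimate = np.log(samples).mean(axis=0)

print(np.abs(e_log_pi - mc_estimate).max())  # close to 0
```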


class EBCC(BaseLabelModel):
    def __init__(self,
                 num_groups: Optional[int] = 10,
                 a_pi: Optional[float] = 0.1,
                 alpha: Optional[float] = 1,
                 a_v: Optional[float] = 4,
                 b_v: Optional[float] = 1,
                 repeat: Optional[int] = 1000,
                 inference_iter: Optional[int] = 500,
                 empirical_prior: Optional[bool] = False,
                 **kwargs: Any):
        super().__init__()
        self.hyperparas = {
            'num_groups': num_groups,
            'a_pi': a_pi,
            'alpha': alpha,
            'a_v': a_v,
            'b_v': b_v,
            'empirical_prior': empirical_prior,
            'inference_iter': inference_iter,
            **kwargs
        }
        self.params = {
            'seed': None,
            'eta_km': None,
            'nu_k': None,
            'mu_jkml': None,
        }
        self.repeat = repeat

    def fit(self,
            dataset_train: Union[BaseDataset, np.ndarray],
            dataset_valid: Optional[Union[BaseDataset, np.ndarray]] = None,
            y_valid: Optional[np.ndarray] = None,
            n_class: Optional[int] = None,
            verbose: Optional[bool] = False,
            *args: Any,
            **kwargs: Any):
        tuples = create_tuples(dataset_train)
        num_items, _, num_classes = tuples.max(axis=0) + 1
        num_workers = len(dataset_train.weak_labels[0])
        max_elbo = float('-inf')

        # random restarts: keep the parameters of the highest-ELBO run
        for infer in range(self.repeat):
            seed = np.random.randint(int(1e8))
            prediction, elbo, p1, p2, p3 = ebcc_vb(tuples,
                                                   num_items, num_workers, num_classes,
                                                   seed=seed,
                                                   **self.hyperparas)
            if elbo > max_elbo:
                if verbose:
                    print(f'update elbo: new: {elbo}, old: {max_elbo}')
                self.params = {
                    'seed': seed,
                    'eta_km': p1,
                    'nu_k': p2,
                    'mu_jkml': p3
                }
                max_elbo = elbo

    def predict_proba(self,
                      dataset: Union[BaseDataset, np.ndarray],
                      **kwargs: Any):
        tuples = create_tuples(dataset)
        num_items, _, num_classes = tuples.max(axis=0) + 1
        num_workers = len(dataset.weak_labels[0])
        pred, elbo, _, _, _ = ebcc_vb(tuples,
                                      num_items, num_workers, num_classes,
                                      eval=True,
                                      **self.hyperparas, **self.params)
        return pred
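
`EBCC.fit` is a random-restart scheme: each call to `ebcc_vb` is an independent variational run under a fresh seed, and only the parameters of the highest-ELBO run are kept. The pattern in isolation (`noisy_objective` is a stand-in for a real inference run, not part of Wrench):

```python
import numpy as np

def noisy_objective(seed):
    # Stand-in for one variational run: returns (elbo, params) for a given seed.
    rng = np.random.default_rng(seed)
    return -rng.exponential(), {'seed': seed}

rng = np.random.default_rng(42)
best_elbo, best_params = float('-inf'), None
for _ in range(20):
    seed = int(rng.integers(1e8))
    elbo, params = noisy_objective(seed)
    if elbo > best_elbo:  # keep the run with the highest objective so far
        best_elbo, best_params = elbo, params
print(best_elbo, best_params)
```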
107 changes: 107 additions & 0 deletions wrench/labelmodel/ibcc.py
@@ -0,0 +1,107 @@
from typing import Optional, Any, Union

import numpy as np
import scipy.sparse as ssp
from scipy.special import digamma

from wrench.basemodel import BaseLabelModel
from wrench.dataset import BaseDataset
from ..utils import create_tuples


def ibcc(tuples,
         num_items,
         num_workers,
         num_classes,
         a_v=4,
         b_v=1,
         alpha=1,
         n_jkl=None,
         eval=False):
    # Per-class sparse vote indicators: y_is_one_kij[k][i, j] is True
    # iff worker j labeled item i with class k.
    y_is_one_kij = []
    y_is_one_kji = []
    for k in range(num_classes):
        selected = (tuples[:, 2] == k)
        coo_ij = ssp.coo_matrix((np.ones(selected.sum()), tuples[selected, :2].T),
                                shape=(num_items, num_workers),
                                dtype=bool)  # np.bool is removed in recent NumPy; use the builtin
        y_is_one_kij.append(coo_ij.tocsr())
        y_is_one_kji.append(coo_ij.T.tocsr())

    # initialization
    prior_kl = np.eye(num_classes) * (a_v - b_v) + b_v
    if n_jkl is None:
        n_jkl = np.empty((num_workers, num_classes, num_classes))

    # MV initialize Z
    z_ik = np.zeros((num_items, num_classes))
    for l in range(num_classes):
        z_ik[:, [l]] += y_is_one_kij[l].sum(axis=-1) + 1e-8
    z_ik /= z_ik.sum(axis=-1, keepdims=True)
    last_z_ik = z_ik.copy()

    for iteration in range(500):
        # E step
        Eq_log_pi_k = digamma(z_ik.sum(axis=0) + alpha)  # - digamma(num_items + num_classes * alpha)

        if eval is False:
            for l in range(num_classes):
                n_jkl[:, :, l] = y_is_one_kji[l].dot(z_ik)

        Eq_log_v_jkl = digamma(n_jkl + prior_kl[None, :, :]) - \
                       digamma(n_jkl.sum(axis=-1) + prior_kl.sum(axis=-1))[:, :, None]

        # M step
        last_z_ik[:] = z_ik

        z_ik[:] = Eq_log_pi_k
        for l in range(num_classes):
            z_ik += y_is_one_kij[l].dot(Eq_log_v_jkl[:, :, l])
        z_ik -= z_ik.max(axis=-1, keepdims=True)  # stabilize before exponentiating
        z_ik = np.exp(z_ik)
        z_ik /= z_ik.sum(axis=-1, keepdims=True)

        if np.allclose(last_z_ik, z_ik, atol=1e-3):
            break
    return z_ik, n_jkl
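
Both `ibcc()` and `ebcc_vb()` start from the same majority-vote initialization: `z_ik` begins as each item's smoothed, normalized per-class vote count, built from per-class sparse vote matrices. The same construction on a toy set of (item, worker, class) tuples:

```python
import numpy as np
import scipy.sparse as ssp

tuples = np.array([[0, 0, 1],   # item 0: both workers vote class 1
                   [0, 1, 1],
                   [1, 0, 0],   # item 1: workers split between classes 0 and 1
                   [1, 1, 1]])
num_items, num_workers, num_classes = tuples.max(axis=0) + 1

z_ik = np.zeros((num_items, num_classes))
for k in range(num_classes):
    selected = (tuples[:, 2] == k)
    votes_ij = ssp.coo_matrix((np.ones(selected.sum()), tuples[selected, :2].T),
                              shape=(num_items, num_workers), dtype=bool).tocsr()
    z_ik[:, [k]] += votes_ij.sum(axis=-1) + 1e-8  # smoothed per-class vote counts
z_ik /= z_ik.sum(axis=-1, keepdims=True)
print(z_ik)  # item 0 ~ [0, 1]; item 1 = [0.5, 0.5]
```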


class IBCC(BaseLabelModel):
    def __init__(self,
                 alpha: Optional[float] = 1,
                 a_v: Optional[float] = 4,
                 b_v: Optional[float] = 1,
                 **kwargs: Any):
        super().__init__()
        self.hyperparas = {
            'alpha': alpha,
            'a_v': a_v,
            'b_v': b_v,
            **kwargs
        }
        self.params = {
            'n_jkl': None,
        }

    def fit(self,
            dataset_train: Union[BaseDataset, np.ndarray],
            dataset_valid: Optional[Union[BaseDataset, np.ndarray]] = None,
            y_valid: Optional[np.ndarray] = None,
            n_class: Optional[int] = None,
            verbose: Optional[bool] = False,
            *args: Any,
            **kwargs: Any):
        tuples = create_tuples(dataset_train)
        num_items, _, num_classes = tuples.max(axis=0) + 1
        # take the worker count from the LF list, not from the tuple maximum,
        # so a worker that always abstains still counts
        num_workers = len(dataset_train.weak_labels[0])
        _, param = ibcc(tuples, num_items, num_workers, num_classes, **self.hyperparas)
        self.params['n_jkl'] = param

    def predict_proba(self,
                      dataset: Union[BaseDataset, np.ndarray],
                      **kwargs: Any):
        tuples = create_tuples(dataset)
        num_items, _, num_classes = tuples.max(axis=0) + 1
        num_workers = len(dataset.weak_labels[0])
        pred, _ = ibcc(tuples, num_items, num_workers, num_classes,
                       eval=True,
                       **self.hyperparas, **self.params)
        return pred
15 changes: 14 additions & 1 deletion wrench/utils.py
@@ -1,6 +1,6 @@
import random
from collections import Counter
from typing import Dict, Optional
from typing import Dict, Optional, Union

import numpy as np
import torch
@@ -181,3 +181,16 @@ def collate_fn_trunc_pad(batch: Dict):
        return batch

    return collate_fn_trunc_pad


def create_tuples(dataset: Union[BaseDataset, np.ndarray]):
    # Flatten the (n_items x n_lfs) weak-label matrix into one
    # (item, worker, class) triple per non-abstaining vote.
    n_items = len(dataset)
    n_workers = len(dataset.weak_labels[0])
    ids = np.repeat(np.arange(n_items), n_workers)
    workers = np.tile(np.arange(n_workers), n_items)
    classes = np.array(dataset.weak_labels).reshape(-1)

    tuples = np.vstack((ids, workers, classes))
    tuples = tuples[:, tuples[2, :] != -1]  # drop abstains (-1)

    return tuples.T
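
`create_tuples` is what feeds both EBCC and IBCC: every non-abstaining vote becomes one (item, worker, class) row. The same transform on a toy matrix, using a plain NumPy array in place of a Wrench dataset object:

```python
import numpy as np

weak_labels = np.array([[0, 1],    # item 0: LF0 -> class 0, LF1 -> class 1
                        [-1, 1],   # item 1: LF0 abstains
                        [2, -1]])  # item 2: LF1 abstains
n_items, n_workers = weak_labels.shape

ids = np.repeat(np.arange(n_items), n_workers)    # item index of each vote
workers = np.tile(np.arange(n_workers), n_items)  # worker index of each vote
classes = weak_labels.reshape(-1)                 # voted class of each vote

tuples = np.vstack((ids, workers, classes))
tuples = tuples[:, tuples[2, :] != -1].T          # drop abstains (-1)
print(tuples)  # rows: (item, worker, class)
```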
