Skip to content

Commit

Permalink
Add unit test for TextNumrEncoder and fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Wei-Cheng Chang committed Oct 17, 2024
1 parent e8ed2e7 commit b23c141
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 13 deletions.
6 changes: 3 additions & 3 deletions pecos/xmr/reranker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ def __init__(self, config: TextNumrEncoderConfig):
if config.text_config:
text_encoder = AutoModel.from_pretrained(
config.text_config._name_or_path,
attn_implementation=config.text_config._attn_implementation,
trust_remote_code=config.text_config.trust_remote_code,
attn_implementation=getattr(config.text_config, "._attn_implementation", "eager"),
trust_remote_code=getattr(config.text_config, "trust_remote_code", None),
token=getattr(config.text_config, "token", None),
)
text_encoder.config.pad_token_id = (
Expand Down Expand Up @@ -242,7 +242,7 @@ def forward(
text_emb = None
if self.text_encoder:
text_input_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
if token_type_ids:
if token_type_ids is not None:
text_input_dict["token_type_ids"] = token_type_ids
text_outputs = self.text_encoder(**text_input_dict, return_dict=True)
if hasattr(text_outputs, "pooler_output"):
Expand Down
93 changes: 83 additions & 10 deletions test/pecos/xmr/test_reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
# and limitations under the License.
import pytest # noqa: F401; pylint: disable=unused-variable
from pytest import approx


def test_importable():
Expand All @@ -19,16 +20,88 @@ def test_importable():
from pecos.xmr.reranker.trainer import RankingTrainer # noqa: F401


def test_model():
def test_numr_encoder():
import torch
from pecos.xmr.reranker.model import NumrMLPEncoderConfig
from pecos.xmr.reranker.model import NumrMLPEncoder

mlp_config = NumrMLPEncoderConfig(
inp_feat_dim=5,
inp_dropout_prob=0.5,
hid_actv_type="gelu",
hid_size_list=[8, 16],
numr_config = NumrMLPEncoderConfig(
inp_feat_dim=2,
inp_dropout_prob=0.0,
hid_dropout_prob=0.0,
hid_actv_type="identity",
hid_size_list=[2],
)
assert mlp_config.inp_feat_dim == 5
assert mlp_config.inp_dropout_prob == 0.5
assert mlp_config.hid_actv_type == "gelu"
assert mlp_config.hid_size_list == [8, 16]
assert numr_config.inp_feat_dim == 2
assert numr_config.inp_dropout_prob == 0.0
assert numr_config.hid_dropout_prob == 0.0
assert numr_config.hid_actv_type == "identity"
assert numr_config.hid_size_list == [2]

numr_encoder = NumrMLPEncoder(numr_config)
linear_layer = numr_encoder.mlp_block.mlp_layers[0]
linear_layer.bias.data.fill_(0.0)
linear_layer.weight.data.fill_(0.0)
linear_layer.weight.data.fill_diagonal_(1.0)
with torch.no_grad():
inp_feat = torch.tensor([-1, 1]).float()
out_feat = numr_encoder(inp_feat)
assert out_feat.numpy() == approx(
out_feat.numpy(),
abs=0.0,
), f"Enc(inp_feat) != inp_feat, given Enc is identity"


def test_textnumr_encoder():
import torch
from transformers import set_seed
from transformers import AutoConfig, AutoTokenizer
from pecos.xmr.reranker.model import TextNumrEncoderConfig
from pecos.xmr.reranker.model import TextNumrEncoder

enc_list = [
"prajjwal1/bert-tiny",
"sentence-transformers/all-MiniLM-L6-v2",
"intfloat/multilingual-e5-small",
]
ans_list = [
0.007879042997956276,
0.0035168465692549944,
-0.0047034271992743015,
]
set_seed(1234)

for idx, enc_name in enumerate(enc_list):
text_config = AutoConfig.from_pretrained(
enc_name,
hidden_dropout_prob=0.0,
)
textnumr_config = TextNumrEncoderConfig(
text_config=text_config,
numr_config=None,
text_pooling_type="cls",
head_actv_type="identity",
head_dropout_prob=0.0,
head_size_list=[1],
)
textnumr_encoder = TextNumrEncoder(textnumr_config)
linear_layer = textnumr_encoder.head_layers.mlp_layers[0]
linear_layer.bias.data.fill_(0.0)
linear_layer.weight.data.fill_(0.0)
linear_layer.weight.data.fill_diagonal_(1.0)
textnumr_encoder.scorer.bias.data.fill_(0.0)
textnumr_encoder.scorer.weight.data.fill_(1.0)

# obtained from bert-tiny tokenizer("I Like coffee")
tokenizer = AutoTokenizer.from_pretrained(enc_name)
input_dict = tokenizer("I Like coffee", return_tensors="pt")
outputs = textnumr_encoder(**input_dict)
assert outputs.text_emb is not None
assert outputs.numr_emb is None

text_emb = outputs.text_emb
mu = torch.mean(text_emb).item()
assert mu == approx(
ans_list[idx],
abs=1e-3,
), f"mu(text_emb)={mu} != {ans_list[idx]}"

0 comments on commit b23c141

Please sign in to comment.