Fix typos
pomonam committed Jul 14, 2024
1 parent 78adca8 commit e69dee8
Showing 6 changed files with 48 additions and 14 deletions.
16 changes: 8 additions & 8 deletions examples/README.md
@@ -18,14 +18,14 @@ Our examples cover the following tasks:

<div align="center">

-| Task                 | Example Datasets         |
-|----------------------|:------------------------:|
-| Regression           | UCI                      |
-| Image Classification | CIFAR-10 & ImageNet      |
-| Text Classification  | GLUE                     |
-| Multiple-Choice      | SWAG                     |
-| Summarization        | DNN/DailyMail            |
-| Language Modeling    | WikiText-2 & OpenWebText |
+| Task                 | Example Datasets |
+|----------------------|:----------------:|
+| Regression           | [UCI](https://github.com/pomonam/kronfluence/tree/main/examples/uci) |
+| Image Classification | [CIFAR-10](https://github.com/pomonam/kronfluence/tree/main/examples/cifar) & [ImageNet](https://github.com/pomonam/kronfluence/tree/main/examples/imagenet) |
+| Text Classification  | [GLUE](https://github.com/pomonam/kronfluence/tree/main/examples/glue) |
+| Multiple-Choice      | [SWAG](https://github.com/pomonam/kronfluence/tree/main/examples/swag) |
+| Summarization        | [CNN/DailyMail](https://github.com/pomonam/kronfluence/tree/main/examples/dailymail) |
+| Language Modeling    | [WikiText-2](https://github.com/pomonam/kronfluence/tree/main/examples/wikitext) & [OpenWebText](https://github.com/pomonam/kronfluence/tree/main/examples/openwebtext) |

</div>

13 changes: 12 additions & 1 deletion examples/openwebtext/compute_scores.py
@@ -11,8 +11,9 @@
    get_custom_dataset,
    get_openwebtext_dataset,
)
-from examples.openwebtext.task import LanguageModelingTask
+from examples.openwebtext.task import LanguageModelingTask, LanguageModelingWithMarginMeasurementTask
from kronfluence.analyzer import Analyzer, prepare_model
+from kronfluence.utils.common.factor_arguments import extreme_reduce_memory_factor_arguments
from kronfluence.utils.common.score_arguments import (
    extreme_reduce_memory_score_arguments,
)
@@ -28,13 +29,21 @@ def parse_args():
    parser.add_argument(
        "--factors_name",
        type=str,
+        required=True,
        help="Name of the factor.",
    )
    parser.add_argument(
        "--scores_name",
        type=str,
+        required=True,
        help="Name of the score.",
    )
+    parser.add_argument(
+        "--use_margin_for_measurement",
+        action="store_true",
+        default=False,
+        help="Boolean flag whether to use margin for measurement.",
+    )
    parser.add_argument(
        "--query_gradient_rank",
        type=int,
@@ -71,6 +80,8 @@ def main():

    # Define task and prepare model.
    task = LanguageModelingTask()
+    if args.use_margin_for_measurement:
+        task = LanguageModelingWithMarginMeasurementTask()
    model = prepare_model(model, task)

    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))  # 1.5 hours.
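
The hunk above wires the new `--use_margin_for_measurement` flag to the margin-based task before the model is prepared. As a rough, hypothetical sketch (not part of this commit), the same toggle could be used when driving Kronfluence directly from Python; the `Analyzer` calls and argument names below are assumptions drawn from the library's other examples rather than from this diff, and the factors are assumed to have been fitted in an earlier run.

```python
# Hypothetical sketch only: the Analyzer method and argument names are assumptions,
# not taken from this commit. `model`, `train_dataset`, and `query_dataset` are
# expected to be built the same way compute_scores.py builds them.
from kronfluence.analyzer import Analyzer, prepare_model

from examples.openwebtext.task import (
    LanguageModelingTask,
    LanguageModelingWithMarginMeasurementTask,
)


def compute_scores_sketch(model, train_dataset, query_dataset, use_margin_for_measurement=False):
    # Mirror the toggle added above: use the margin measurement when requested.
    task = LanguageModelingWithMarginMeasurementTask() if use_margin_for_measurement else LanguageModelingTask()
    model = prepare_model(model, task)

    analyzer = Analyzer(analysis_name="openwebtext_sketch", model=model, task=task)
    analyzer.compute_pairwise_scores(
        scores_name="margin_scores",         # assumed name
        factors_name="openwebtext_factors",  # assumed to exist from a prior factor-fitting run
        query_dataset=query_dataset,
        train_dataset=train_dataset,
        per_device_query_batch_size=1,
    )
    return analyzer.load_pairwise_scores(scores_name="margin_scores")
```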
1 change: 0 additions & 1 deletion examples/openwebtext/generate.py
@@ -10,7 +10,6 @@
# prompt = "Machine learning can be defined as"
# prompt = "Using a distributed database has many advantages."
# prompt = "Inflation is typically measured by"
-# prompt = "The prime minister of Canada is definitely Justin Bieber. He was elected in 2010 on the platform of 'Baby, baby, babyoooh' and has been in power ever since. Some of Bieber’s key accomplishments as prime minister include:"

outputs = pipeline(prompt)
print("Prompt:")
6 changes: 3 additions & 3 deletions examples/openwebtext/inpsect_factors.py
@@ -11,11 +11,11 @@ def main():
    plt.rcParams.update(markers.with_edge())
    plt.rcParams["axes.axisbelow"] = True

-    layer_num = 18
+    layer_num = 31
    module_name = f"model.layers.{layer_num}.mlp.down_proj"
    # module_name = f"model.layers.{layer_num}.mlp.up_proj"
-    lambda_processed = Analyzer.load_file("influence_results/num_lambda_processed.safetensors")[module_name]
-    lambda_matrix = Analyzer.load_file("influence_results/lambda_matrix.safetensors")[module_name]
+    lambda_processed = Analyzer.load_file("num_lambda_processed.safetensors")[module_name]
+    lambda_matrix = Analyzer.load_file("lambda_matrix.safetensors")[module_name]
    lambda_matrix.div_(lambda_processed)
    lambda_matrix = lambda_matrix.float()
    plt.matshow(lambda_matrix, cmap="PuBu", norm=LogNorm())
24 changes: 24 additions & 0 deletions examples/openwebtext/task.py
@@ -69,3 +69,27 @@ def get_influence_tracked_modules(self) -> List[str]:

    def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor:
        return batch["attention_mask"]
+
+
+class LanguageModelingWithMarginMeasurementTask(LanguageModelingTask):
+    def compute_measurement(
+        self,
+        batch: BATCH_TYPE,
+        model: nn.Module,
+    ) -> torch.Tensor:
+        logits = model(
+            input_ids=batch["input_ids"],
+            attention_mask=batch["attention_mask"],
+        ).logits.float()
+        labels = batch["labels"][..., 1:].contiguous().view(-1)
+        masks = labels != -100
+        logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
+
+        bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False)
+        logits_correct = logits[bindex, labels]
+
+        cloned_logits = logits.clone()
+        cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype)
+
+        margins = logits_correct - cloned_logits.logsumexp(dim=-1)
+        return -margins[masks].sum()
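
What the new `compute_measurement` returns: for every non-padding token (label != -100), the margin between the correct token's logit and the log-sum-exp of all other logits, summed over tokens and negated. The toy snippet below is purely illustrative (not part of the commit; the numbers are made up) and reproduces that computation for a single token.

```python
# Toy illustration only (not from the commit): margin for one token with a 4-word vocabulary.
import torch

logits = torch.tensor([[2.0, 0.5, -1.0, 0.3]])  # shape: (1 token, vocab size 4)
label = torch.tensor([0])                       # correct token id

logit_correct = logits[0, label]                # logit assigned to the correct token
masked = logits.clone()
masked[0, label] = float("-inf")                # exclude the correct token from the log-sum-exp
margin = logit_correct - masked.logsumexp(dim=-1)

# compute_measurement sums such margins over all tokens with labels != -100 and negates the sum.
print(margin)  # ~0.79 here; a larger margin means higher confidence in the correct token
```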
2 changes: 1 addition & 1 deletion kronfluence/version.py
@@ -1 +1 @@
-__version__ = "1.0.0"
+__version__ = "1.0.1"
