AI/ML, with a similar logging interface? Try out LogIX, which is built upon our cutting-edge research (with
[Huggingface Transformers](https://github.com/logix-project/logix/tree/main?tab=readme-ov-file#huggingface-integration) and
[PyTorch Lightning](https://github.com/logix-project/logix/tree/main?tab=readme-ov-file#pytorch-lightning-integration) integrations)!

- **PyPI**
```bash
pip install logix-ai
```
- **From source**
```bash
git clone https://github.com/logix-project/logix.git
cd logix
pip install -e .
```


## Easy to Integrate

Our software design allows for seamless integration with popular high-level frameworks, including
[HuggingFace Transformers](https://github.com/huggingface/transformers/tree/main) and
[PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning), which conveniently handle
distributed training, data loading, etc. Advanced users who don't use high-level frameworks can
still integrate LogIX into their existing training code much like any traditional logging software
(see our Tutorial).

### HuggingFace Integration
Our software design allows for seamless integration with HuggingFace's Transformers library.
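A minimal sketch of the patched-`Trainer` workflow is shown below; the `logix.huggingface` helper names
(`LogIXArguments`, `patch_trainer`) and the exact argument layout are assumptions here rather than a
confirmed API.

```python
# Hedged sketch, not a verbatim example: patch HuggingFace's Trainer so that
# LogIX hooks are installed automatically. `model`, `training_args`, and
# `train_dataset` are the usual HuggingFace objects and are assumed to exist.
from transformers import Trainer
from logix.huggingface import LogIXArguments, patch_trainer  # assumed helpers

logix_args = LogIXArguments(project="my_project", config="config.yaml")
LogIXTrainer = patch_trainer(Trainer)

trainer = LogIXTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    logix_args=logix_args,
)

# Instead of trainer.train(), extract logs and run attribution.
trainer.extract_log()
trainer.influence()
trainer.self_influence()
```

The appeal of this design is that existing `Trainer`-based scripts only need the patched class and one
extra arguments object.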

### PyTorch Lightning Integration
Similarly, we also support seamless integration with PyTorch Lightning. The code example
is provided below.

```python
trainer.extract_log(module, train_loader)
trainer.influence(module, train_loader)
```
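For context, the `module` and `trainer` objects above come from a setup step; below is a minimal sketch
that assumes a `logix.lightning` patch helper analogous to the HuggingFace one (`patch`, `LogIXArguments`,
`MyLitModule`, and `train_loader` are assumptions, not a confirmed API).

```python
# Hedged sketch: wrap an existing LightningModule and Trainer so that LogIX
# hooks and data_id tracking are installed automatically. The helper names
# below are assumed by analogy with the HuggingFace integration.
from lightning import Trainer
from logix.lightning import patch, LogIXArguments  # assumed module/helpers

logix_args = LogIXArguments(project="my_project", config="config.yaml")
module, trainer = patch(MyLitModule(), Trainer(), logix_args=logix_args)

# Replace trainer.fit(module, train_loader) with log extraction + attribution.
trainer.extract_log(module, train_loader)
trainer.influence(module, train_loader)
```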

## Getting Started
### Logging
Training log extraction with LogIX is as simple as adding one `with` statement to your existing
training code. LogIX automatically extracts user-specified logs using PyTorch hooks and stores
them as a tuple of `([data_ids], log[module_name][log_type])`. If needed, LogIX efficiently writes
these logs to disk with memory-mapped files.

```python
import logix

# Initialize LogIX
run = logix.init(project="my_project")

# Specify modules to be tracked for logging
run.watch(model, name_filter=["mlp"], type_filter=[nn.Linear])

# Specify plugins to be used in logging
run.setup({"grad": ["log", "covariance"]})
run.save(True)

for batch in data_loader:
# Set `data_id` (and optionally `mask`) for the current batch
with run(data_id=batch["input_ids"], mask=batch["attention_mask"]):
model.zero_grad()
loss = model(batch)
loss.backward()
# Synchronize statistics (e.g. covariance) and write logs to disk
run.finalize()
```
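To make the stored format above concrete, here is a hedged sketch of reading the saved logs back after
`run.finalize()`; the `(data_ids, log)` unpacking follows the tuple description above rather than a
documented schema, and `"model.mlp"` is a hypothetical module name.

```python
# Hedged sketch: iterate over the saved logs. The batch layout mirrors the
# "([data_ids], log[module_name][log_type])" description above; the module
# name "model.mlp" and the returned tensor shape are illustrative only.
log_loader = run.build_log_dataloader()
for data_ids, log in log_loader:
    grad = log["model.mlp"]["grad"]
    print(len(data_ids), grad.shape)
```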

### Training Data Attribution
As part of our initial research, we implemented influence functions using LogIX. We plan to provide
more pre-implemented interpretability algorithms if there is demand.

```python
# Build PyTorch DataLoader from saved log data
log_loader = run.build_log_dataloader()

with run(data_id=test_batch["input_ids"]):
test_loss = model(test_batch)
test_loss.backward()

test_log = run.get_log()
run.influence.compute_influence_all(test_log, log_loader) # Data attribution
run.influence.compute_self_influence(test_log) # Uncertainty estimation
```
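To show how the attribution scores might be consumed, here is a hedged sketch; the return structure of
`compute_influence_all` (a dict exposing a `(num_test, num_train)` score tensor under an `"influence"`
key) is an assumption, not confirmed behavior.

```python
import torch

# Hedged sketch: rank training examples by their total influence on the test
# batch. The "influence" key and the score-tensor layout are assumptions.
result = run.influence.compute_influence_all(test_log, log_loader)
scores = result["influence"]  # assumed shape: (num_test, num_train)

k = min(10, scores.size(1))
top_scores, top_indices = torch.topk(scores.sum(dim=0), k=k)
print(top_indices.tolist())
```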

Please check out [Examples](/examples) for more detailed examples!

