-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #24 from nestauk/23-wandb-example-code
23 wandb example code
- Loading branch information
Showing
19 changed files
with
801 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,6 @@ __pycache__/ | |
|
||
*.env | ||
.metaflow | ||
|
||
wandb/wandb/ | ||
wandb/inputs/*.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Weights and biases | ||
|
||
## Setup | ||
|
||
### Wandb set up | ||
|
||
#### Sign up to weights and biases and create a new project | ||
|
||
From [this page](https://wandb.ai/site), click "Sign up" and follow instructions there. **Sign up with your Nesta email.** | ||
|
||
In your *User settings*: **change the default location of new projects to your personal account.** This is so that the Nesta org space doesn't get clogged up with all the projects we're going to create in this tutorial :) | ||
|
||
<img src="screenshots/default_location.jpg" width="500"> | ||
|
||
#### Create a new project | ||
|
||
Create a new project, and let's call it Titanic because that's the data we'll be using. | ||
|
||
<img src="screenshots/create_new_project.jpg" width="200"> | ||
|
||
<img src="screenshots/create_new_project_2.png" width="300"> | ||
|
||
### Clone this repo | ||
|
||
You've probably done this already if you've attended a have-a-go before :) | ||
|
||
### Environment | ||
Create a conda environment: | ||
``` | ||
conda create -n wandb_demo python=3.10 | ||
``` | ||
|
||
Activate the environment: | ||
``` | ||
conda activate wandb_demo | ||
``` | ||
|
||
Install pip: | ||
``` | ||
conda install pip | ||
``` | ||
Install requirements: | ||
``` | ||
pip install -r wandb/requirements.txt | ||
``` | ||
|
||
Create an ipykernel for your conda environment: | ||
``` | ||
python -m ipykernel install --user --name=wandb_demo | ||
``` | ||
### `.env` | ||
|
||
Create a file `dap_tutorials/wandb/.env` with one variable: | ||
|
||
``` | ||
wandb_username = "yourusernamehere" | ||
``` | ||
|
||
Actually this is not super necessary but somehow seems nicer than writing your username in the code :) And the have-a-go code depends on it being set up this way! | ||
|
||
### Download data | ||
Download the `train.csv` Titanic data from Kaggle [here](https://www.kaggle.com/competitions/titanic/data) and store it in `wandb/inputs/`. | ||
|
||
## How to use | ||
|
||
The scripts in `wandb_demo/` show some example Weights and Biases workflows at different levels of complexity. In all of them, the scenario is that based on the available training data, we are trying to predict whether a passenger survived the Titanic disaster. | ||
|
||
1. `baseline_classifier.py`: this script sets up a simple dummy classifier and logs a single run on weights and biases. It guesses whether a passenger survived based on the survival rate calculated from the training data. Depending on the random seed chosen, it should give accuracy of 50-60%. You can see this by navigating to your project on the weights and biases interface, finding the "Titanic" project, and looking under "Runs": | ||
|
||
<img src="screenshots/dummy_classifier.png" width="600"> | ||
|
||
2. `logistic_regression.py`: this script sets up a logistic regression model and again, logs this as a run on weights and biases. | ||
|
||
3. `sweep_log_reg.py`: now that we have progressed from a dummy classifier to an ML model, we might want to improve the performance of that model. This script uses a [sweep](https://docs.wandb.ai/guides/sweeps) to find optimal hyperparameter values for the logistic regression model. There are two important parameters at the top of the script that you can change to control the sweep: | ||
- `sweep_config`: the hyperparameters to sweep over as well as the method to use (random search, grid search, or Bayes). | ||
- `N_RUNS`: we are using Random Search, so once we specify how many runs we would like to execute, the sweep agent will try this many random combinations of hyperparameters from our sweep config. | ||
|
||
Once your sweep is done, navigate to the "Sweeps" view of your "Titanic" project and click on the most recent sweep. You will be able to see how different parameter combinations led to different model performance. Wandb provides handy parallel coordinates interactive plots, where each run within a sweep is a single line, model performance is the yaxis at the right of the graph, and different parameter settings are columns within the graph: | ||
|
||
<img src="screenshots/sweep_viz.png" width="600"> | ||
|
||
4. `sweep_different_classifiers.py`: it's worth knowing that you are not restricted to trying out different hyperparameter values within one sweep - you could also try out different models. In this sweep, we compare Random Forest, SVM and Logistic Regression, and also vary some of the important hyperparameters of those models. You could extend this approach to try out, for example, different embeddings, different feature selection methods, etc. You can customise the sweep config and the `train()` function to account for anything you want to control or vary. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
Download the data from [https://www.kaggle.com/competitions/titanic/data](https://www.kaggle.com/competitions/titanic/data). | ||
|
||
You should have the following files: | ||
- train.csv | ||
- test.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
ipykernel==6.29.4 | ||
ipython==8.23.0 | ||
numpy==1.26.4 | ||
pandas==2.2.2 | ||
python-dotenv==1.0.1 | ||
scikit-learn==1.4.2 | ||
wandb==0.16.6 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import logging | ||
from pathlib import Path | ||
|
||
PROJECT_DIR = Path(__file__).resolve().parents[1] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
""" | ||
Runs a dummy classifier that predicts survival based on the survival rate in the training data. | ||
Logs this as a run on wandb, including accuracy and a confusion matrix. | ||
""" | ||
from dotenv import load_dotenv | ||
import logging | ||
import numpy as np | ||
import os | ||
import pandas as pd | ||
import numpy as np | ||
import pandas as pd | ||
from sklearn.metrics import accuracy_score, confusion_matrix | ||
import wandb | ||
|
||
import utils | ||
from utils import PROJECT_DIR, load_data | ||
|
||
load_dotenv() | ||
|
||
os.chdir(PROJECT_DIR) | ||
|
||
WANDB_PROJ = "Titanic" # Change this to your project name! | ||
WANDB_USER = os.environ.get("wandb_username") # Change this to your username! | ||
JOB = "predict survival" | ||
|
||
if __name__ == "__main__": | ||
|
||
# # Split the data | ||
X_train, X_test, y_train, y_test = load_data() | ||
|
||
run = wandb.init( | ||
project=WANDB_PROJ, | ||
job_type=JOB, | ||
save_code=True, | ||
tags=["baseline_model"], | ||
) | ||
|
||
survival_rate = y_train.value_counts(normalize=True)[1] | ||
|
||
y_pred_baseline = np.random.binomial(1, survival_rate, len(y_test)) | ||
|
||
wandb.run.summary["accuracy"] = accuracy_score(y_test, y_pred_baseline) | ||
|
||
cm = confusion_matrix(y_test, y_pred_baseline) | ||
cm = pd.DataFrame(cm) | ||
logging.info(f"Confusion matrix:\n{cm}") | ||
|
||
# Log confusion matrix | ||
wb_confusion_matrix = wandb.Table(data=cm, columns=["0", "1"]) | ||
run.log({"confusion_matrix": wb_confusion_matrix}) | ||
|
||
# End the weights and biases run | ||
wandb.finish() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
""" | ||
Runs a logistic regression to predict survival. | ||
Logs this as a run on wandb, including accuracy and a confusion matrix. | ||
""" | ||
from dotenv import load_dotenv | ||
import logging | ||
import numpy as np | ||
import os | ||
import pandas as pd | ||
import numpy as np | ||
import pandas as pd | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.metrics import accuracy_score, confusion_matrix | ||
import wandb | ||
|
||
import utils | ||
from utils import PROJECT_DIR, load_data, SEED | ||
|
||
load_dotenv() | ||
|
||
os.chdir(PROJECT_DIR) | ||
|
||
WANDB_PROJ = "Titanic" # Change this to your project name! | ||
WANDB_USER = os.environ.get("wandb_username") # Change this to your username! | ||
JOB = "predict survival" | ||
|
||
if __name__ == "__main__": | ||
|
||
# # Split the data | ||
X_train, X_test, y_train, y_test = load_data() | ||
|
||
run = wandb.init( | ||
project=WANDB_PROJ, | ||
job_type=JOB, | ||
save_code=True, | ||
tags=["logistic regression"], | ||
) | ||
|
||
log_reg_config = {'penalty': 'l2', | ||
'C': 1.0, | ||
'random_state': SEED, | ||
'solver':'lbfgs', | ||
'max_iter':100} | ||
|
||
model = LogisticRegression(penalty=log_reg_config['penalty'], | ||
C=log_reg_config['C'], | ||
solver=log_reg_config['solver'], | ||
max_iter=log_reg_config['max_iter'], | ||
random_state=log_reg_config['random_state']) | ||
model.fit(X_train, y_train) | ||
|
||
# Predict and evaluate | ||
preds = model.predict(X_test) | ||
accuracy = accuracy_score(y_test, preds) | ||
logging.info(f"Accuracy: {accuracy}") | ||
|
||
wandb.run.summary["accuracy"] = accuracy | ||
|
||
cm = confusion_matrix(y_test, preds) | ||
cm = pd.DataFrame(cm) | ||
logging.info(f"Confusion matrix:\n{cm}") | ||
|
||
# Log confusion matrix | ||
wb_confusion_matrix = wandb.Table(data=cm, columns=["0", "1"]) | ||
run.log({"confusion_matrix": wb_confusion_matrix}) | ||
|
||
# End the weights and biases run | ||
wandb.finish() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
""" | ||
Runs a sweep across 3 different models and various hyperparameters to find the combination that gives the best accuracy. | ||
The number of hyperparameter combinations tried is controlled by N_RUNS. | ||
The different hyperparameter values to try is controlled by sweep_configuration. | ||
""" | ||
|
||
from dotenv import load_dotenv | ||
import pandas as pd | ||
import os | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.svm import SVC | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.metrics import accuracy_score, confusion_matrix | ||
import wandb | ||
|
||
import utils | ||
from utils import PROJECT_DIR, load_data, SEED | ||
|
||
load_dotenv() | ||
|
||
os.chdir(PROJECT_DIR) | ||
|
||
WANDB_PROJ = "Titanic" # Change this to your project name! | ||
WANDB_USER = os.environ.get("wandb_username") # Change this to your username! | ||
JOB = "predict survival" | ||
|
||
N_RUNS = 5 # How many different hyperparameter combinations to try aka how many different runs | ||
|
||
sweep_configuration = { | ||
'method': 'random', # Choose from 'grid', 'random', or 'bayes' | ||
'metric': { | ||
'name': 'accuracy', | ||
'goal': 'maximize' | ||
}, | ||
'parameters': { | ||
'model_type': { | ||
'values': ['logistic_regression', 'svm', 'random_forest'] | ||
}, | ||
'C': { | ||
'values': [0.01, 0.1, 1] | ||
}, | ||
'max_iter': { | ||
'values': [10, 100, 1000] | ||
}, | ||
'penalty': { | ||
'values': ['l2', None] | ||
}, | ||
'solver': { | ||
'values': ['lbfgs'] | ||
}, | ||
# SVM hyperparams | ||
'kernel': { | ||
'values': ['linear', 'rbf'] | ||
}, | ||
'gamma': { | ||
'values': ['auto'] | ||
}, | ||
# RF hyperparams | ||
'n_estimators': { | ||
'values': [10, 100] | ||
}, | ||
'max_depth': { | ||
'values': [5, 10, 50] | ||
}, | ||
'min_samples_split': { | ||
'values': [2, 5, 10] | ||
}, | ||
'min_samples_leaf': { | ||
'values': [1, 2, 4] | ||
} | ||
} | ||
} | ||
|
||
|
||
def train(config, X_train, X_test, y_train, y_test): | ||
if config.model_type == 'logistic_regression': | ||
model = LogisticRegression(C=config['C'], max_iter=config['max_iter'], | ||
penalty=config['penalty'], solver=config['solver'], random_state=SEED) | ||
elif config.model_type == 'svm': | ||
model = SVC(C=config['C'], kernel=config['kernel'], gamma=config['gamma'], random_state=SEED) | ||
elif config.model_type == 'random_forest': | ||
model = RandomForestClassifier(n_estimators=config['n_estimators'], max_depth=config['max_depth'], | ||
min_samples_split=config['min_samples_split'], | ||
min_samples_leaf=config['min_samples_leaf'], random_state=SEED) | ||
|
||
model.fit(X_train, y_train) | ||
preds = model.predict(X_test) | ||
accuracy = accuracy_score(y_test, preds) | ||
cm = confusion_matrix(y_test, preds) | ||
cm = pd.DataFrame(cm) | ||
|
||
return accuracy, cm | ||
|
||
def main(): | ||
wandb.init(project=WANDB_PROJ, job_type=JOB, save_code=True) | ||
X_train, X_test, y_train, y_test = load_data() | ||
accuracy, cm = train(wandb.config, X_train, X_test, y_train, y_test) | ||
wandb.log({"accuracy": accuracy}) | ||
|
||
# Log confusion matrix | ||
wb_confusion_matrix = wandb.Table(data=cm, columns=["0", "1"]) | ||
wandb.log({"confusion_matrix": wb_confusion_matrix}) | ||
|
||
wandb.finish() | ||
|
||
if __name__ == "__main__": | ||
sweep_id = wandb.sweep(sweep=sweep_configuration, project=WANDB_PROJ) | ||
wandb.agent(sweep_id, entity=WANDB_USER, function=main, count=N_RUNS) |
Oops, something went wrong.