Skip to content

Commit

Permalink
Merge branch 'dev' into refactor_tox21MolNet
Browse files Browse the repository at this point in the history
  • Loading branch information
aditya0by0 committed Nov 7, 2024
2 parents 29cff11 + 7480783 commit 4102fe9
Show file tree
Hide file tree
Showing 95 changed files with 15,038 additions and 200 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/export_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import json

from chebai.preprocessing.reader import (
    CLS_TOKEN,
    EMBEDDING_OFFSET,
    MASK_TOKEN_INDEX,
    PADDING_TOKEN_INDEX,
)

# Constants exported for CI verification.
# The key names are a contract with verify_constants.yml -- any rename here
# must be mirrored there as well.
constants = {
    "EMBEDDING_OFFSET": EMBEDDING_OFFSET,
    "CLS_TOKEN": CLS_TOKEN,
    "PADDING_TOKEN_INDEX": PADDING_TOKEN_INDEX,
    "MASK_TOKEN_INDEX": MASK_TOKEN_INDEX,
}


def _dump_constants(path: str = "constants.json") -> None:
    """Serialize the ``constants`` mapping to *path* as JSON."""
    with open(path, "w") as out_file:
        json.dump(constants, out_file)


if __name__ == "__main__":
    # Write constants to a JSON file for the workflow to pick up
    _dump_constants()
27 changes: 27 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Unittests

on: [pull_request]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.9", "3.10", "3.11"]

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        # Install CPU-only torch wheels explicitly so `pip install -e .`
        # does not pull the much larger default (CUDA) builds on CI.
        run: |
          python -m pip install --upgrade pip setuptools wheel
          python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
          python -m pip install -e .
      - name: Run unit tests
        run: python -m unittest discover -s tests/unit
128 changes: 128 additions & 0 deletions .github/workflows/token_consistency.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
name: Check consistency of tokens.txt file

# Define the file paths under `paths` to trigger this check only when specific files are modified.
# This script will then execute checks only on files that have changed, rather than all files listed in `paths`.

# **Note** : To add a new token file for checks, include its path in:
# - `on` -> `push` and `pull_request` sections
# - `jobs` -> `check_tokens` -> `steps` -> Set global variable for multiple tokens.txt paths -> `TOKENS_FILES`

on:
  push:
    paths:
      - "chebai/preprocessing/bin/smiles_token/tokens.txt"
      - "chebai/preprocessing/bin/smiles_token_unlabeled/tokens.txt"
      - "chebai/preprocessing/bin/selfies/tokens.txt"
      - "chebai/preprocessing/bin/protein_token/tokens.txt"
      - "chebai/preprocessing/bin/graph_properties/tokens.txt"
      - "chebai/preprocessing/bin/graph/tokens.txt"
      - "chebai/preprocessing/bin/deepsmiles_token/tokens.txt"
      - "chebai/preprocessing/bin/protein_token_3_gram/tokens.txt"
  pull_request:
    paths:
      - "chebai/preprocessing/bin/smiles_token/tokens.txt"
      - "chebai/preprocessing/bin/smiles_token_unlabeled/tokens.txt"
      - "chebai/preprocessing/bin/selfies/tokens.txt"
      - "chebai/preprocessing/bin/protein_token/tokens.txt"
      - "chebai/preprocessing/bin/graph_properties/tokens.txt"
      - "chebai/preprocessing/bin/graph/tokens.txt"
      - "chebai/preprocessing/bin/deepsmiles_token/tokens.txt"
      - "chebai/preprocessing/bin/protein_token_3_gram/tokens.txt"

jobs:
  check_tokens:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        # v4 keeps this workflow consistent with the repository's other workflows.
        uses: actions/checkout@v4

      - name: Get list of changed files
        id: changed_files
        run: |
          git fetch origin dev
          # Get the list of changed files compared to origin/dev and save them to a file
          git diff --name-only origin/dev > changed_files.txt
          # Print the names of changed files on separate lines
          echo "Changed files:"
          while read -r line; do
            echo "Changed File name : $line"
          done < changed_files.txt
      - name: Set global variable for multiple tokens.txt paths
        run: |
          # All token files that need to be checked must be included here too, same as in `paths`.
          TOKENS_FILES=(
            "chebai/preprocessing/bin/smiles_token/tokens.txt"
            "chebai/preprocessing/bin/smiles_token_unlabeled/tokens.txt"
            "chebai/preprocessing/bin/selfies/tokens.txt"
            "chebai/preprocessing/bin/protein_token/tokens.txt"
            "chebai/preprocessing/bin/graph_properties/tokens.txt"
            "chebai/preprocessing/bin/graph/tokens.txt"
            "chebai/preprocessing/bin/deepsmiles_token/tokens.txt"
            "chebai/preprocessing/bin/protein_token_3_gram/tokens.txt"
          )
          echo "TOKENS_FILES=${TOKENS_FILES[*]}" >> $GITHUB_ENV
      - name: Process only changed tokens.txt files
        run: |
          # Convert the TOKENS_FILES environment variable into an array
          TOKENS_FILES=(${TOKENS_FILES})
          # Iterate over each token file path
          for TOKENS_FILE_PATH in "${TOKENS_FILES[@]}"; do
            # Check if the current token file path is in the list of changed files.
            # -F: fixed-string match — the '.' in the paths would otherwise act
            # as a regex metacharacter and could match unrelated files.
            if grep -qF "$TOKENS_FILE_PATH" changed_files.txt; then
              echo "----------------------- Processing $TOKENS_FILE_PATH -----------------------"
              # Get previous tokens.txt version
              git fetch origin dev
              git diff origin/dev -- "$TOKENS_FILE_PATH" > tokens_diff.txt || echo "No previous tokens.txt found for $TOKENS_FILE_PATH"
              # Check for deleted or added lines in tokens.txt
              if [ -f tokens_diff.txt ]; then
                # Check for deleted lines (lines starting with '-')
                deleted_lines=$(grep '^-' tokens_diff.txt | grep -v '^---' | sed 's/^-//' || true)
                if [ -n "$deleted_lines" ]; then
                  echo "Error: Lines have been deleted from $TOKENS_FILE_PATH."
                  echo -e "Deleted Lines: \n$deleted_lines"
                  exit 1
                fi
                # Check for added lines (lines starting with '+')
                added_lines=$(grep '^+' tokens_diff.txt | grep -v '^+++' | sed 's/^+//' || true)
                if [ -n "$added_lines" ]; then
                  # Count how many lines have been added
                  num_added_lines=$(echo "$added_lines" | wc -l)
                  # Get last `n` lines (equal to num_added_lines) of tokens.txt
                  last_lines=$(tail -n "$num_added_lines" "$TOKENS_FILE_PATH")
                  # Check if the added lines are at the end of the file
                  if [ "$added_lines" != "$last_lines" ]; then
                    # Find lines that were added but not appended at the end of the file
                    non_appended_lines=$(diff <(echo "$added_lines") <(echo "$last_lines") | grep '^<' | sed 's/^< //')
                    echo "Error: New lines have been added to $TOKENS_FILE_PATH, but they are not at the end of the file."
                    echo -e "Added lines that are not at the end of the file: \n$non_appended_lines"
                    exit 1
                  fi
                fi
                if [ "$added_lines" == "" ]; then
                  echo "$TOKENS_FILE_PATH validation successful: No lines were deleted, and no new lines were added."
                else
                  echo "$TOKENS_FILE_PATH validation successful: No lines were deleted, and new lines were correctly appended at the end."
                fi
              else
                echo "No previous version of $TOKENS_FILE_PATH found."
              fi
            else
              echo "$TOKENS_FILE_PATH was not changed, skipping."
            fi
          done
116 changes: 116 additions & 0 deletions .github/workflows/verify_constants.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
name: Verify Constants

# Define the file paths under `paths` to trigger this check only when specific files are modified.
# This script will then execute checks only on files that have changed, rather than all files listed in `paths`.

# **Note** : To add a new file for checks, include its path in:
# - `on` -> `push` and `pull_request` sections
# - `jobs` -> `verify-constants` -> `steps` -> Verify constants -> Add a new if else for your file, with check logic inside it.


on:
  push:
    paths:
      - "chebai/preprocessing/reader.py"
  pull_request:
    paths:
      - "chebai/preprocessing/reader.py"

jobs:
  verify-constants:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: [
            # Only use 3.10 as of now
            # "3.9",
            "3.10",
            # "3.11"
          ]

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      # Make the repo root importable so export_constants.py can do
      # `from chebai.preprocessing.reader import ...` before installation.
      - name: Set PYTHONPATH
        run: echo "PYTHONPATH=$PWD" >> $GITHUB_ENV

      - name: Get list of changed files
        id: changed_files
        run: |
          git fetch origin dev
          # Get the list of changed files compared to origin/dev and save them to a file
          git diff --name-only origin/dev > changed_files.txt
          # Print the names of changed files on separate lines
          echo "Changed files:"
          while read -r line; do
            echo "Changed File name : $line"
          done < changed_files.txt
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        # Setting a fix version for torch due to an error with latest version (2.5.1)
        # ImportError: cannot import name 'T_co' from 'torch.utils.data.dataset'
        run: |
          python -m pip install --upgrade pip
          python -m pip install --upgrade pip setuptools wheel
          python -m pip install torch==2.4.1 --index-url https://download.pytorch.org/whl/cpu
          python -m pip install -e .
      # Runs the helper script, which writes the current values of the
      # reader constants to constants.json in the workspace.
      - name: Export constants
        run: python .github/workflows/export_constants.py

      - name: Load constants into environment variables
        id: load_constants
        # "E_" is prepended as a prefix to every constant, to protect overwriting other sys env variables with same name
        run: |
          constants=$(cat constants.json)
          # jq turns {"KEY": val, ...} into lines "E_KEY=val" for $GITHUB_ENV
          echo "$constants" | jq -r 'to_entries|map("E_\(.key)=\(.value|tostring)")|.[]' >> $GITHUB_ENV
      - name: Print all environment variables
        run: printenv

      - name: Verify constants
        # Fails the job if any exported constant differs from the expected
        # value hard-coded below; only runs when reader.py actually changed.
        run: |
          file_name="chebai/preprocessing/reader.py"
          if grep -q "$file_name" changed_files.txt; then
            echo "----------------------- Checking file : $file_name ----------------------- "
            # Define expected values for constants
            exp_embedding_offset="10"
            exp_cls_token="2"
            exp_padding_token_index="0"
            exp_mask_token_index="1"
            # Debugging output to check environment variables
            echo "Current Environment Variables:"
            echo "E_EMBEDDING_OFFSET = $E_EMBEDDING_OFFSET"
            echo "Expected: $exp_embedding_offset"
            # Verify constants match expected values
            if [ "$E_EMBEDDING_OFFSET" != "$exp_embedding_offset" ]; then
              echo "EMBEDDING_OFFSET ($E_EMBEDDING_OFFSET) does not match expected value ($exp_embedding_offset)!"
              exit 1
            fi
            if [ "$E_CLS_TOKEN" != "$exp_cls_token" ]; then
              echo "CLS_TOKEN ($E_CLS_TOKEN) does not match expected value ($exp_cls_token)!"
              exit 1
            fi
            if [ "$E_PADDING_TOKEN_INDEX" != "$exp_padding_token_index" ]; then
              echo "PADDING_TOKEN_INDEX ($E_PADDING_TOKEN_INDEX) does not match expected value ($exp_padding_token_index)!"
              exit 1
            fi
            if [ "$E_MASK_TOKEN_INDEX" != "$exp_mask_token_index" ]; then
              echo "MASK_TOKEN_INDEX ($E_MASK_TOKEN_INDEX) does not match expected value ($exp_mask_token_index)!"
              exit 1
            fi
          else
            echo "$file_name not found in changed_files.txt; skipping check."
          fi
12 changes: 7 additions & 5 deletions chebai/loss/bce_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
self.beta is not None
and self.data_extractor is not None
and all(
os.path.exists(os.path.join(self.data_extractor.raw_dir, raw_file))
for raw_file in self.data_extractor.raw_file_names
os.path.exists(
os.path.join(self.data_extractor.processed_dir_main, file_name)
)
for file_name in self.data_extractor.processed_main_file_names
)
and self.pos_weight is None
):
Expand All @@ -53,13 +55,13 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
pd.read_pickle(
open(
os.path.join(
self.data_extractor.raw_dir,
raw_file_name,
self.data_extractor.processed_dir_main,
file_name,
),
"rb",
)
)
for raw_file_name in self.data_extractor.raw_file_names
for file_name in self.data_extractor.processed_main_file_names
]
)
value_counts = []
Expand Down
8 changes: 6 additions & 2 deletions chebai/models/electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,9 @@ def __init__(
# Load pretrained checkpoint if provided
if pretrained_checkpoint:
with open(pretrained_checkpoint, "rb") as fin:
model_dict = torch.load(fin, map_location=self.device)
model_dict = torch.load(
fin, map_location=self.device, weights_only=False
)
if load_prefix:
state_dict = filter_dict(model_dict["state_dict"], load_prefix)
else:
Expand Down Expand Up @@ -414,7 +416,9 @@ def __init__(self, cone_dimensions=20, **kwargs):
model_prefix = kwargs.get("load_prefix", None)
if pretrained_checkpoint:
with open(pretrained_checkpoint, "rb") as fin:
model_dict = torch.load(fin, map_location=self.device)
model_dict = torch.load(
fin, map_location=self.device, weights_only=False
)
if model_prefix:
state_dict = {
str(k)[len(model_prefix) :]: v
Expand Down
Loading

0 comments on commit 4102fe9

Please sign in to comment.