
SASRec Experiments #4

Merged: 5 commits merged into main, Sep 8, 2024
Conversation

haru-256 (Owner) commented Sep 8, 2024

User description

Summary

  • Conducted experiments with SASRec.
  • TODO:
    • Evaluate ranking metrics (e.g., NDCG, MRR) on the validation set.

PR Type

enhancement, bug fix, tests


Description

  • Enhanced the SASRec model with support for float16 training.
  • Improved validation data handling and added NaN detection in attention layers.
  • Updated README with detailed model architecture and results.
  • Added utility functions and optimized training configurations.
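One subtlety behind float16 training is that additive attention masks built with float("-inf") can turn into NaN once fully masked rows pass through softmax in half precision. A minimal sketch of a float16-safe causal mask, with the function name and shape conventions assumed rather than taken from the PR, is:

```python
import torch

def make_causal_attn_mask(seq_len: int, dtype: torch.dtype = torch.float16) -> torch.Tensor:
    """Additive causal mask that is safe under half precision.

    Using the dtype's minimum finite value instead of float("-inf")
    avoids NaN when fully masked rows go through softmax in float16.
    """
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    return torch.triu(mask, diagonal=1)  # zero on/below the diagonal
```

Adding this mask to attention scores before softmax keeps masked positions at (near) zero probability without overflowing the float16 range.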

Changes walkthrough 📝

Relevant files

Enhancement (6 files)

  dataset.py (sasrec/data/dataset.py, +1/-1): Update validation data loader
    • Changed the data loader to use val_dataset instead of test_dataset.

  sasrec.py (sasrec/models/sasrec.py, +16/-11): Add float16 support for model training
    • Added a float16 parameter for model optimization.
    • Updated attention mask creation to support float16.

  train.py (sasrec/train.py, +15/-4): Optimize training configuration
    • Increased the batch size for training.
    • Configured the model to use the GPU if available.

  utils.py (sasrec/utils/utils.py, +31/-3): Add utility function for CPU core count
    • Added a cpu_count function to get the number of CPU cores.

  Makefile (sasrec/Makefile, +4/-1): Update Makefile for training command
    • Added a train target to the Makefile for easier execution.

  pyproject.toml (sasrec/pyproject.toml, +4/-1): Update project dependencies
    • Updated dependencies for cross-platform support.

Bug fix (3 files)

  transformer_embedding.py (sasrec/models/modules/transformer_embedding.py, +2/-2): Fix embedding layer issues
    • Fixed the embedding lookup to use id_embeddings.
    • Updated position IDs to move to the correct device.

  transformer_encoder_block.py (sasrec/models/modules/transformer_encoder_block.py, +6/-2): Enhance NaN detection in attention block
    • Improved NaN detection in the MultiheadAttention output.

  test_utils.py (sasrec/tests/test_utils/test_utils.py, +3/-2): Fix padding mask creation in tests
    • Updated padding mask creation to use the system minimum value.

Documentation (1 file)

  README.md (sasrec/README.md, +108/-1): Improve documentation with detailed model info
    • Expanded the dataset description and model architecture.
    • Added results and usage instructions.
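The walkthrough does not show the body of the new cpu_count helper in sasrec/utils/utils.py; a plausible stdlib-only sketch of such a utility (the implementation is assumed, not taken from the diff) is:

```python
import os

def cpu_count() -> int:
    """Number of CPU cores usable by this process.

    os.sched_getaffinity respects cgroup/taskset limits on Linux;
    elsewhere we fall back to the machine-wide core count.
    """
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:  # not available on macOS/Windows
        return os.cpu_count() or 1
```

A helper like this is typically used to size DataLoader num_workers without oversubscribing a constrained container.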

💡 PR-Agent usage: Comment /help on the PR to get a list of all available PR-Agent tools and their descriptions.

github-actions bot commented Sep 8, 2024

PR Reviewer Guide 🔍

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
🔒 No security concerns identified
⚡ Key issues to review

Possible Bug
The assertion for NaN detection in the MultiheadAttention output has been moved to a new location. Ensure that this does not affect the logic of the model, especially if the previous assertion was critical for debugging.

Code Smell
The function create_attn_padding_mask has complex logic that could be simplified. Consider breaking it down into smaller functions for better readability and maintainability.

github-actions bot commented Sep 8, 2024

PR Code Suggestions ✨

Performance

Use torch.zeros_like for creating the padding mask to match the input tensor's properties

Consider using torch.zeros_like instead of torch.zeros(batch_size, seq_len) to ensure the padding mask has the same data type and device as x.

sasrec/utils/utils.py [36-39]

    -padding_mask = torch.zeros(batch_size, seq_len).to(device)
    +padding_mask = torch.zeros_like(x, dtype=torch.float32)

Suggestion importance [1-10]: 8

Why: This suggestion enhances the code's robustness by ensuring that the padding mask matches the input tensor's properties, which is important for maintaining consistency in tensor operations.
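To illustrate the suggestion with a hypothetical input tensor x (the shape is made up for the example):

```python
import torch

x = torch.randn(4, 10, dtype=torch.float32)  # stand-in for the real input

# torch.zeros requires dtype and device to be forwarded by hand:
manual = torch.zeros(x.shape[0], x.shape[1], dtype=x.dtype, device=x.device)

# torch.zeros_like inherits shape, dtype, and device from x in one call:
padding_mask = torch.zeros_like(x)

assert padding_mask.shape == x.shape
assert padding_mask.dtype == x.dtype and padding_mask.device == x.device
```

The zeros_like form cannot silently drift to the wrong device when x moves between CPU and GPU, which is the robustness the bot is pointing at.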
Add a check for empty masks before performing assertions

Consider adding a check to ensure that attn_mask and key_padding_mask are not empty before proceeding with the assertion to avoid unnecessary computations.

sasrec/models/modules/transformer_encoder_block.py [59-61]

    -assert (
    -    torch.isnan(mha_out).sum() == 0
    +if attn_mask.numel() > 0 and key_padding_mask.numel() > 0:
    +    assert (
    +        torch.isnan(mha_out).sum() == 0

Suggestion importance [1-10]: 6

Why: Adding a check for empty masks is a reasonable performance improvement, but it does not address a critical issue. The assertion itself is important for debugging.
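As a standalone sketch, the guarded check the bot proposes could read as follows (the wrapper function and message are hypothetical, not from the PR):

```python
import torch

def check_mha_output(mha_out: torch.Tensor,
                     attn_mask: torch.Tensor,
                     key_padding_mask: torch.Tensor) -> None:
    """Scan the attention output for NaN only when both masks are non-empty."""
    if attn_mask.numel() > 0 and key_padding_mask.numel() > 0:
        assert torch.isnan(mha_out).sum() == 0, "NaN in MultiheadAttention output"
```

With empty masks the isnan scan is skipped entirely, which is the "avoid unnecessary computations" part of the suggestion.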
Best practice

Set float16 based on the availability of a GPU

Ensure that float16 is only set to True when the model is actually running on a GPU to prevent unnecessary precision loss on CPU.

sasrec/models/sasrec.py [23]

    -float16: bool = False,
    +float16: bool = torch.cuda.is_available(),

Suggestion importance [1-10]: 7

Why: This suggestion improves the model's efficiency by conditionally setting float16, which is a good practice when dealing with hardware capabilities.
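A more explicit variant of the same idea resolves the flag at construction time instead of baking torch.cuda.is_available() into a default argument, which is evaluated once at import time. The helper name here is hypothetical:

```python
from typing import Optional

import torch

def resolve_float16(requested: Optional[bool] = None) -> bool:
    """float16 only makes sense on CUDA; never enable it on CPU.

    None means "use float16 if a GPU is available"; an explicit True
    is still downgraded to False when no CUDA device exists.
    """
    if requested is None:
        return torch.cuda.is_available()
    return requested and torch.cuda.is_available()
```

This keeps the constructor signature stable while still preventing half-precision runs on CPU-only machines.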
Possible issue

Add a check to ensure val_dataset is initialized before use

Ensure that self.val_dataset is properly initialized before being used in val_dataloader to avoid potential runtime errors.

sasrec/data/dataset.py [318]

     return DataLoader(
    -    self.val_dataset,
    +    self.val_dataset if self.val_dataset is not None else raise ValueError("val_dataset must be initialized"),

Suggestion importance [1-10]: 3

Why: While it's good practice to ensure that val_dataset is initialized, the suggested diff is not valid Python (raise is a statement and cannot appear inside a conditional expression), so applying it as written would itself cause an error.
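Because raise is a statement, the suggested one-liner cannot compile; a syntactically valid version of the guard might look like this, where the class and attribute names are hypothetical stand-ins for the real data module:

```python
from dataclasses import dataclass
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset

@dataclass
class ValDataModule:  # hypothetical stand-in for the real data module
    val_dataset: Optional[Dataset] = None
    batch_size: int = 32

    def val_dataloader(self) -> DataLoader:
        # raise is a statement, so the guard needs its own if-block
        # rather than a conditional expression:
        if self.val_dataset is None:
            raise ValueError("val_dataset must be initialized before use")
        return DataLoader(self.val_dataset, batch_size=self.batch_size)
```

An initialized module returns a usable loader, while an uninitialized one fails immediately with a clear message instead of a confusing error inside DataLoader.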

sasrec/utils/utils.py: outdated review thread (resolved)
@haru-256 merged commit 38e913e into main on Sep 8, 2024; 1 check passed.
@haru-256 deleted the feat/sasrec-training branch on September 8, 2024 at 08:16.