From 4628c2894ccec69d83b1b8e9793e4ec5a2cee951 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 20 Nov 2024 08:08:48 -0800 Subject: [PATCH] add linting and fix score mod function in flex attention pathway --- .github/workflows/lint.yaml | 21 +++++++++++++++++++++ pi_zero_pytorch/pi_zero.py | 5 ++--- pyproject.toml | 16 +++++++++++++++- 3 files changed, 38 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/lint.yaml diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..ee2afc6 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,21 @@ +name: Ruff +on: [push, pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install uv + python -m uv pip install ruff + - name: Lint with Ruff + run: | + ruff check pi_zero_pytorch/ diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py index ed7e442..5e27f8c 100644 --- a/pi_zero_pytorch/pi_zero.py +++ b/pi_zero_pytorch/pi_zero.py @@ -810,7 +810,7 @@ def forward( images, inverse_pack_image_frames = pack_with_inverse([images], '* c h w') with torch.no_grad(): - self.vit.eval() + self.vit.eval() visual_tokens = self.vit(images) if is_multiple_images: @@ -880,7 +880,6 @@ def forward( else: state_length = state_tokens.shape[-2] - total_seq_length = action_with_registers_length + state_length mask = F.pad(language_mask, (state_length - command_length - 1, 1 + action_with_registers_length), value = True) # assume fixed number of images for now, but address variable length modality states later # rotary embeddings @@ -911,7 +910,7 @@ def forward( flex_attn_fn = partial( flex_attention, block_mask = block_mask, - score_mod = score_mod + score_mod = score_mod_fn ) # state keys and values for caching during inference diff --git a/pyproject.toml b/pyproject.toml index 064ca3c..0177303 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pi-zero-pytorch" -version = "0.0.26" +version = "0.0.27" description = "π0 in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } @@ -45,6 +45,7 @@ Repository = "https://github.com/lucidrains/pi-zero-pytorch" examples = [] test = [ "pytest", + "ruff>=0.4.2", "vit-pytorch>=1.8.7" ] @@ -53,6 +54,19 @@ pythonpath = [ "." ] +[tool.ruff] +line-length = 1000 + +lint.ignore = [ + "F722", # for jaxtyping shape annotation + "F401", + "F821" +] + +lint.extend-select = [ + "W291" +] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build"