Skip to content

Commit

Permalink
Fix workflow (#345)
Browse files Browse the repository at this point in the history
  • Loading branch information
casper-hansen authored Feb 15, 2024
1 parent 171c7af commit be0f9e0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ jobs:
}
python setup.py sdist bdist_wheel
- name: Upload Assets
uses: shogo82148/actions-upload-release-asset@v1
with:
Expand Down
2 changes: 1 addition & 1 deletion scripts/download_wheels.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

# Set variables
AWQ_VERSION="0.1.6"
AWQ_VERSION="0.2.0"
RELEASE_URL="https://api.github.com/repos/casper-hansen/AutoAWQ/releases/tags/v${AWQ_VERSION}"

# Create a directory to download the wheels
Expand Down
20 changes: 17 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import torch
import platform
import requests
import importlib_metadata
from pathlib import Path
from setuptools import setup, find_packages
from torch.utils.cpp_extension import CUDAExtension


def get_latest_kernels_version(repo):
Expand Down Expand Up @@ -88,15 +88,20 @@ def get_kernels_whl_url(
"torch>=2.0.1",
"transformers>=4.35.0",
"tokenizers>=0.12.1",
"typing_extensions>=4.8.0"
"accelerate",
"datasets",
"zstandard",
]

try:
importlib_metadata.version("autoawq-kernels")
if ROCM_VERSION:
import exlv2_ext
else:
import awq_ext

KERNELS_INSTALLED = True
except importlib_metadata.PackageNotFoundError:
except ImportError:
KERNELS_INSTALLED = False

# kernels can be downloaded from pypi for cuda+121 only
Expand Down Expand Up @@ -133,5 +138,14 @@ def get_kernels_whl_url(
"eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"],
"dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"]
},
# NOTE: We create an empty CUDAExtension because torch helps us with
# creating the right boilerplate to enable correct targeting of
# the autoawq-kernels package
ext_modules=[
CUDAExtension(
name="__build_artifact_for_awq_kernel_targeting",
sources=[],
)
],
**common_setup_kwargs,
)

0 comments on commit be0f9e0

Please sign in to comment.