Skip to content

Commit

Permalink
Merge pull request #33 from line/add_test
Browse files Browse the repository at this point in the history
Type annotation & add tests; merged it, but tests need to be implemented for future.
  • Loading branch information
awkrail authored Sep 12, 2024
2 parents 931908b + 32b09a3 commit 985aa5a
Show file tree
Hide file tree
Showing 45 changed files with 1,466 additions and 1,145 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/mypy_ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: Run mypy and ruff

on:
push:
branches:
- '**'
pull_request:
branches:
- '**'

jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Python Setup
uses: actions/setup-python@v3
with:
python-version: 3.9

- name: Run dependency libraries
run: |
pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1 torchtext==0.15.1
pip install easydict pandas tqdm pyyaml scikit-learn ffmpeg-python ftfy regex einops fvcore gradio torchlibrosa librosa
pip install 'clip@git+https://github.com/openai/CLIP.git'
pip install mypy ruff
- name: Run mypy
run: find lighthouse -type f -name "*.py" -not -path 'lighthouse/common/*' | xargs mypy

- name: Run ruff
run: find lighthouse -type f -name "*.py" -not -path 'lighthouse/common/*' | xargs -I {} sh -c 'ruff check "{}"'
8 changes: 4 additions & 4 deletions .github/workflows/mypy.yml → .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Run mypy
name: Run pytest

on:
push:
Expand All @@ -24,7 +24,7 @@ jobs:
pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1 torchtext==0.15.1
pip install easydict pandas tqdm pyyaml scikit-learn ffmpeg-python ftfy regex einops fvcore gradio torchlibrosa librosa
pip install 'clip@git+https://github.com/openai/CLIP.git'
pip install mypy
pip install pytest
- name: Run mypy
run: find lighthouse -type f -name "*.py" -not -path 'lighthouse/common/*' | xargs mypy
- name: Run pytest
run: pytest tests/test_models.py
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ venv.bak/
.dmypy.json
dmypy.json

# ruff
.ruff_cache/

# Pyre type checker
.pyre/

Expand Down
46 changes: 19 additions & 27 deletions api_example/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,33 @@
"""
import os
import subprocess
import pprint
import torch
from lighthouse.models import CGDETRPredictor

# use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

weight_dir = 'gradio_demo/weights'
if not os.path.exists(os.path.join(weight_dir, 'clip_cg_detr_qvhighlight.ckpt')):
command = 'wget -P gradio_demo/weights/ https://zenodo.org/records/13363606/files/clip_cg_detr_qvhighlight.ckpt'
subprocess.run(command, shell=True)

if not os.path.exists('SLOWFAST_8x8_R50.pkl'):
subprocess.run('wget https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/kinetics400/SLOWFAST_8x8_R50.pkl', shell=True)
from lighthouse.models import CGDETRPredictor
from typing import Dict, List, Optional

if not os.path.exists('Cnn14_mAP=0.431.pth'):
subprocess.run('wget https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth', shell=True)
def load_weights(weight_dir: str) -> None:
if not os.path.exists(os.path.join(weight_dir, 'clip_slowfast_pann_cg_detr_qvhighlight.ckpt')):
command = 'wget -P gradio_demo/weights/ https://zenodo.org/records/13363606/files/clip_slowfast_pann_cg_detr_qvhighlight.ckpt'
subprocess.run(command, shell=True)

weight_path = os.path.join(weight_dir, 'clip_cg_detr_qvhighlight.ckpt')
model = CGDETRPredictor(weight_path, device=device, feature_name='clip', slowfast_path=None, pann_path=None)
if not os.path.exists('SLOWFAST_8x8_R50.pkl'):
subprocess.run('wget https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/kinetics400/SLOWFAST_8x8_R50.pkl', shell=True)

"""
# slowfast_path is necesary if you use clip_slowfast features
weight_path = os.path.join(weight_dir, 'clip_slowfast_cg_detr_qvhighlight.ckpt')
model = CGDETRPredictor(weight_path, device=device, feature_name='clip_slowfast', slowfast_path='SLOWFAST_8x8_R50.pkl', pann_path=None)
if not os.path.exists('Cnn14_mAP=0.431.pth'):
subprocess.run('wget https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth', shell=True)

# slowfast_path and pann_path are necesary if you use clip_slowfast_pann features
weight_path = os.path.join(weight_dir, 'clip_slowfast_cg_detr_qvhighlight.ckpt')
model = CGDETRPredictor(weight_path, device=device, feature_name='clip_slowfast_pann', slowfast_path='SLOWFAST_8x8_R50.pkl', pann_path='Cnn14_mAP=0.431.pth')
"""
# use GPU if available
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
weight_dir: str = 'gradio_demo/weights'
weight_path: str = os.path.join(weight_dir, 'clip_cg_detr_qvhighlight.ckpt')
model: CGDETRPredictor = CGDETRPredictor(weight_path, device=device, feature_name='clip',
slowfast_path=None, pann_path=None)

# encode video features
model.encode_video('api_example/RoripwjYFp8_60.0_210.0.mp4')

# moment retrieval & highlight detection
query = 'A woman wearing a glass is speaking in front of the camera'
prediction = model.predict(query)
pprint.pp(prediction)
query: str = 'A woman wearing a glass is speaking in front of the camera'
prediction: Optional[Dict[str, List[float]]] = model.predict(query)
print(prediction)
3 changes: 1 addition & 2 deletions lighthouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations
under the License.
"""
from lighthouse import *
"""
40 changes: 0 additions & 40 deletions lighthouse/audio_extractor.py

This file was deleted.

Loading

0 comments on commit 985aa5a

Please sign in to comment.