Skip to content

Commit

Permalink
[Idefics -> BaseMultiModalModel] [Vilt => BaseMultiModalModel]
Browse files: browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 13, 2023
1 parent 4bef09a commit 79d8f14
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 10 deletions.
8 changes: 6 additions & 2 deletions swarms/models/idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,13 @@ def __init__(
model_name, torch_dtype=torch_dtype, *args, **kwargs
).to(self.device)

self.processor = AutoProcessor.from_pretrained(model_name)
self.processor = AutoProcessor.from_pretrained(
model_name, *args, **kwargs
)

def run(self, task: str, *args, **kwargs) -> str:
def run(
self, task: str = None, img: str = None, *args, **kwargs
) -> str:
"""
Generates text based on the provided prompts.
Expand Down
2 changes: 2 additions & 0 deletions swarms/models/openai_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def __init__(
chunk_size=1024 * 1024,
autosave: bool = False,
saved_filepath: str = None,
*args,
**kwargs,
):
super().__init__()
self.model_name = model_name
Expand Down
24 changes: 16 additions & 8 deletions swarms/models/vilt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from transformers import ViltProcessor, ViltForQuestionAnswering
import requests
from PIL import Image
from transformers import ViltForQuestionAnswering, ViltProcessor

from swarms.models.base_multimodal_model import BaseMultiModalModel

class Vilt:

class Vilt(BaseMultiModalModel):
"""
Vision-and-Language Transformer (ViLT) model fine-tuned on VQAv2.
It was introduced in the paper ViLT: Vision-and-Language Transformer Without
Expand All @@ -21,15 +23,21 @@ class Vilt:
"""

def __init__(self):
def __init__(
self,
model_name: str = "dandelin/vilt-b32-finetuned-vqa",
*args,
**kwargs,
):
super().__init__(model_name, *args, **kwargs)
self.processor = ViltProcessor.from_pretrained(
"dandelin/vilt-b32-finetuned-vqa"
model_name, *args, **kwargs
)
self.model = ViltForQuestionAnswering.from_pretrained(
"dandelin/vilt-b32-finetuned-vqa"
model_name, *args, **kwargs
)

def __call__(self, text: str, image_url: str):
def run(self, task: str = None, img: str = None, *args, **kwargs):
"""
Run the model
Expand All @@ -38,9 +46,9 @@ def __call__(self, text: str, image_url: str):
"""
# Download the image
image = Image.open(requests.get(image_url, stream=True).raw)
image = Image.open(requests.get(img, stream=True).raw)

encoding = self.processor(image, text, return_tensors="pt")
encoding = self.processor(image, task, return_tensors="pt")

# Forward pass
outputs = self.model(**encoding)
Expand Down
46 changes: 46 additions & 0 deletions swarms/structs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,52 @@


class BaseStructure(ABC):
"""Base structure.
Attributes:
name (Optional[str]): _description_
description (Optional[str]): _description_
save_metadata (bool): _description_
save_artifact_path (Optional[str]): _description_
save_metadata_path (Optional[str]): _description_
save_error_path (Optional[str]): _description_
Methods:
run: _description_
save_to_file: _description_
load_from_file: _description_
save_metadata: _description_
load_metadata: _description_
log_error: _description_
save_artifact: _description_
load_artifact: _description_
log_event: _description_
run_async: _description_
save_metadata_async: _description_
load_metadata_async: _description_
log_error_async: _description_
save_artifact_async: _description_
load_artifact_async: _description_
log_event_async: _description_
asave_to_file: _description_
aload_from_file: _description_
run_in_thread: _description_
save_metadata_in_thread: _description_
run_concurrent: _description_
compress_data: _description_
decompres_data: _description_
run_batched: _description_
load_config: _description_
backup_data: _description_
monitor_resources: _description_
run_with_resources: _description_
run_with_resources_batched: _description_
Examples:
"""

def __init__(
self,
name: Optional[str] = None,
Expand Down

0 comments on commit 79d8f14

Please sign in to comment.