-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcomponents.py
61 lines (50 loc) · 1.78 KB
/
components.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import abc
import gradio
from lightning import LightningWork
from functools import partial
from types import ModuleType
from typing import Any, List, Optional
from lightning import LightningWork
class MyServeGradioComponent(LightningWork, abc.ABC):
    """Lightning work that serves a model behind a Gradio ``Interface``.

    Subclasses must set the ``inputs`` and ``outputs`` class attributes and
    implement ``predict``, ``build_preprocessor`` and ``build_model``. The
    preprocessor and model are built lazily on the first call to ``run``.
    """

    # Input component spec(s) forwarded to ``gradio.Interface(inputs=...)``;
    # must be set (non-empty) by subclasses.
    inputs: Any
    # Output component spec(s) forwarded to ``gradio.Interface(outputs=...)``;
    # must be set (non-empty) by subclasses.
    outputs: Any
    # Optional example inputs forwarded to ``gradio.Interface(examples=...)``.
    examples: Optional[List] = None
    # Forwarded to ``launch(enable_queue=...)``.
    enable_queue: bool = False

    def __init__(self, *args, **kwargs):
        """Initialise the work and check that ``inputs``/``outputs`` are configured.

        All arguments are forwarded unchanged to ``LightningWork.__init__``.

        Raises:
            ValueError: if ``inputs`` or ``outputs`` is missing or falsy.
        """
        super().__init__(*args, **kwargs)
        # Explicit checks rather than ``assert`` so the validation is not
        # stripped when Python runs with ``-O``; ``getattr`` with a default
        # turns a forgotten class attribute into a clear error message.
        for attr in ("inputs", "outputs"):
            if not getattr(self, attr, None):
                raise ValueError(
                    f"{type(self).__name__} must define a non-empty `{attr}` class attribute."
                )
        # Built lazily on the first ``run`` call.
        self._preprocessor = None
        self._model = None

    @property
    def model(self):
        """The object returned by ``build_model``, or ``None`` before the first ``run``."""
        return self._model

    @property
    def preprocessor(self):
        """The object returned by ``build_preprocessor``, or ``None`` before the first ``run``."""
        return self._preprocessor

    @abc.abstractmethod
    def predict(self, *args, **kwargs):
        """Override with your logic to make a prediction.

        This is the function handed to ``gradio.Interface``; any arguments
        passed to ``run`` are pre-bound in front of the Gradio-supplied ones.
        """

    @abc.abstractmethod
    def build_preprocessor(self) -> Any:
        """Override to instantiate and return your preprocessing pipeline.

        The preprocessor will be accessible under ``self.preprocessor``.
        """

    @abc.abstractmethod
    def build_model(self) -> Any:
        """Override to instantiate and return your model.

        The model will be accessible under ``self.model``.
        """

    def run(self, *args, **kwargs):
        """Build the preprocessor and model if needed, then launch the Gradio app.

        Positional and keyword arguments are pre-bound to ``predict`` with
        ``functools.partial``; the resulting callable is used as the
        Interface's ``fn``.
        """
        if self._preprocessor is None:
            self._preprocessor = self.build_preprocessor()
        if self._model is None:
            self._model = self.build_model()

        predict_fn = partial(self.predict, *args, **kwargs)
        # ``functools.partial`` objects have no ``__name__``; copy the predict
        # method's name onto the wrapper, presumably because Gradio reads
        # ``fn.__name__``.
        predict_fn.__name__ = self.predict.__name__

        interface = gradio.Interface(
            fn=predict_fn,
            inputs=self.inputs,
            outputs=self.outputs,
            examples=self.examples,
        )
        # ``host``/``port`` are presumably provided by ``LightningWork``.
        # NOTE(review): ``enable_queue`` as a ``launch()`` keyword is deprecated
        # in later Gradio 3.x releases and removed in 4.x (``interface.queue()``
        # replaces it) — confirm against the pinned Gradio version.
        interface.launch(
            server_name=self.host,
            server_port=self.port,
            enable_queue=self.enable_queue,
        )