forked from dusty-nv/jetson-containers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
35 lines (26 loc) · 1.04 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#!/usr/bin/env python3
# https://github.com/dusty-nv/openvla/blob/main/vla-scripts/extern/verify_openvla.py
#
# Smoke test for OpenVLA: load the model, run one vision-language inference on a
# single image, and print the predicted 7-DoF robot action.
import argparse

import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor


def main() -> None:
    """Load an OpenVLA checkpoint, predict one action, and print it.

    All settings are CLI flags whose defaults reproduce the original
    hard-coded script behavior, so running with no arguments is unchanged.
    """
    parser = argparse.ArgumentParser(
        description="Verify OpenVLA inference on a single image")
    parser.add_argument("--model", default="openvla/openvla-7b",
                        help="HuggingFace model id or local checkpoint path")
    parser.add_argument("--image", default="/data/images/lake.jpg",
                        help="path to the input image")
    parser.add_argument("--prompt",
                        default="In: What action should the robot take to stop?\nOut:",
                        help="instruction prompt in the OpenVLA In:/Out: format")
    parser.add_argument("--unnorm-key", default="bridge_orig",
                        help="dataset statistics key used to un-normalize the action")
    args = parser.parse_args()

    print('loading', args.model)
    processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
    vla = AutoModelForVision2Seq.from_pretrained(
        args.model,
        attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    ).to("cuda:0")
    print(vla.config)

    # Grab image input & format prompt
    image = Image.open(args.image).convert('RGB')
    print('prompt:', args.prompt)

    # Predict Action (7-DoF; un-normalize for BridgeData V2 by default)
    inputs = processor(args.prompt, image).to("cuda:0", dtype=torch.bfloat16)
    print('inputs:', list(inputs.keys()))
    action = vla.predict_action(**inputs, unnorm_key=args.unnorm_key, do_sample=False)
    # Execute...
    print('action:', action)


if __name__ == "__main__":
    # Guard so importing this module does not trigger a model download
    # and CUDA inference as a side effect.
    main()