sdxl.py
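"""Text-to-image generation with SDXL Turbo.

Loads stabilityai/sdxl-turbo via diffusers, picks the best available device
(CUDA, MPS, or CPU), and saves each generated image under a prompt-derived,
timestamped filename in OUTPUT_DIR.
"""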
from datetime import datetime
import os

import torch
from diffusers import AutoPipelineForText2Image, DiffusionPipeline

# Constants
DEFAULT_WIDTH = 1024
DEFAULT_HEIGHT = 1024
DEFAULT_STEPS = 1
DEFAULT_GUIDANCE = 0.0
OUTPUT_DIR = "output/generated_images/sdxl-turbo"
MODEL_CACHE_DIR = "models/sdxl-turbo"

# Global pipeline, lazily initialized on first use
pipeline = None

def get_best_device():
    """Pick the best available torch device and a matching dtype."""
    try:
        # ROCm builds of torch expose the CUDA API, so torch.cuda.is_available()
        # can return True under HIP; check for a ROCm/HIP environment first.
        if "ROCM_PATH" in os.environ or "HIP_PATH" in os.environ or getattr(torch.version, "hip", None):
            print("ROCm/HIP environment detected, forcing CPU usage")
            return "cpu", torch.float32
        elif torch.cuda.is_available():
            print("CUDA is available")
            return "cuda", torch.float16
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            print("MPS is available")
            return "mps", torch.float32
    except Exception:
        pass
    print("Using CPU backend")
    return "cpu", torch.float32

def initialize_pipeline() -> DiffusionPipeline:
    """Initialize the SDXL Turbo pipeline, falling back to CPU on failure."""
    device, dtype = get_best_device()
    try:
        pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=dtype,
            variant="fp16" if dtype == torch.float16 else None,
            cache_dir=MODEL_CACHE_DIR,
        )
        pipe.to(device)
        print(f"Pipeline initialized on {device}")
        return pipe
    except Exception as e:
        print(f"Failed to initialize on {device}, falling back to CPU: {e}")
        # CPU fallback: full precision, no fp16 variant
        pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float32,
            cache_dir=MODEL_CACHE_DIR,
        )
        pipe.to("cpu")
        return pipe

def sanitize_filename(prompt: str, max_length: int = 50) -> str:
    """Create a safe filename stem from the prompt (alphanumerics and spaces only)."""
    safe_prompt = "".join(x for x in prompt if x.isalnum() or x.isspace())
    # Fall back to a generic stem if the prompt contains no safe characters
    return safe_prompt[:max_length].strip() or "image"

def generate_image(prompt: str, output_dir: str = OUTPUT_DIR) -> str:
    """Generate an image from a text prompt and return the saved file path."""
    global pipeline
    try:
        if pipeline is None:
            pipeline = initialize_pipeline()

        # Create the output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # SDXL Turbo is distilled for few-step sampling without classifier-free
        # guidance, hence num_inference_steps=1 and guidance_scale=0.0
        output = pipeline(
            prompt,
            num_inference_steps=DEFAULT_STEPS,
            guidance_scale=DEFAULT_GUIDANCE,
            width=DEFAULT_WIDTH,
            height=DEFAULT_HEIGHT,
        )

        # Save the image under a prompt-derived, timestamped filename
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        safe_prompt = sanitize_filename(prompt)
        filename = f"{safe_prompt}_{timestamp}.png"
        filepath = os.path.join(output_dir, filename)
        output.images[0].save(filepath)
        print(f"\nImage saved to: {filepath}")
        return filepath
    except Exception as e:
        if "cuda" in str(e).lower() and pipeline is not None:
            print("CUDA error detected, falling back to CPU")
            # Move to CPU in float32; fp16 inference is poorly supported on CPU
            pipeline.to("cpu", torch.float32)
            return generate_image(prompt, output_dir)  # Retry on CPU
        raise RuntimeError(f"Failed to generate image: {e}") from e

if __name__ == "__main__":
    try:
        prompt = input("Enter prompt: ").strip()
        generate_image(prompt)
    except KeyboardInterrupt:
        print("\nGeneration cancelled by user")
    except Exception as e:
        print(f"Error: {e}")
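
# A minimal sketch of programmatic use (the prompt and resulting path below are
# hypothetical; the actual filename depends on the prompt and timestamp):
#
#   from sdxl import generate_image
#   path = generate_image("a watercolor fox in the snow")
#   # e.g. output/generated_images/sdxl-turbo/a watercolor fox in the snow_20250101-120000.png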