diff --git a/.github/workflows/build-push-images.yml b/.github/workflows/build-push-images.yml index 277a23f..9a14c0d 100644 --- a/.github/workflows/build-push-images.yml +++ b/.github/workflows/build-push-images.yml @@ -18,6 +18,7 @@ jobs: include: - component: chat - component: image-analysis + - component: flux-image-gen permissions: contents: read id-token: write # needed for signing the images with GitHub OIDC Token diff --git a/.gitignore b/.gitignore index 84f43a8..9372765 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,7 @@ __pycache__/ **/.ruff_cache # Ignore local dev helpers -test-values.y[a]ml +**/dev-values.yml **venv*/ # Helm chart stuff diff --git a/charts/flux-image-gen/.helmignore b/charts/flux-image-gen/.helmignore new file mode 100644 index 0000000..0e8a0eb --- /dev/null +++ b/charts/flux-image-gen/.helmignore @@ -0,0 +1,23 @@ +# Patterns to ignore when building packages. +# This supports shell glob matching, relative path matching, and +# negation (prefixed with !). Only one pattern per line. 
+.DS_Store +# Common VCS dirs +.git/ +.gitignore +.bzr/ +.bzrignore +.hg/ +.hgignore +.svn/ +# Common backup files +*.swp +*.bak +*.tmp +*.orig +*~ +# Various IDEs +.project +.idea/ +*.tmproj +.vscode/ diff --git a/charts/flux-image-gen/Chart.yaml b/charts/flux-image-gen/Chart.yaml new file mode 100644 index 0000000..aaa8943 --- /dev/null +++ b/charts/flux-image-gen/Chart.yaml @@ -0,0 +1,9 @@ +apiVersion: v2 +name: flux-image-gen +description: A Helm chart for running Flux image generation models on Kubernetes + +type: application + +# The version and appVersion are updated by the chart build script +version: 0.1.0 +appVersion: local diff --git a/charts/flux-image-gen/ci/test-values.yaml b/charts/flux-image-gen/ci/test-values.yaml new file mode 100644 index 0000000..b1606cb --- /dev/null +++ b/charts/flux-image-gen/ci/test-values.yaml @@ -0,0 +1,13 @@ +models: + - flux-dev + - flux-schnell +api: + # Run in dev mode so that we skip + # the image gen step and can therefore + # test in a kind cluster + commandOverride: + - fastapi + - dev + - api_server.py + - --host + - "0.0.0.0" diff --git a/charts/flux-image-gen/templates/NOTES.txt b/charts/flux-image-gen/templates/NOTES.txt new file mode 100644 index 0000000..e69de29 diff --git a/charts/flux-image-gen/templates/_helpers.tpl b/charts/flux-image-gen/templates/_helpers.tpl new file mode 100644 index 0000000..17221fc --- /dev/null +++ b/charts/flux-image-gen/templates/_helpers.tpl @@ -0,0 +1,77 @@ +{{/* +Expand the name of the chart. +*/}} +{{- define "flux-image-gen.name" -}} +{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Create a default fully qualified app name. +We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). +If release name contains chart name it will be used as a full name. 
+*/}} +{{- define "flux-image-gen.fullname" -}} +{{- if .Values.fullnameOverride }} +{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- $name := default .Chart.Name .Values.nameOverride }} +{{- if contains $name .Release.Name }} +{{- .Release.Name | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} +{{- end }} +{{- end }} +{{- end }} + +{{/* +Create chart name and version as used by the chart label. +*/}} +{{- define "flux-image-gen.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Common labels +*/}} +{{- define "flux-image-gen.labels" -}} +helm.sh/chart: {{ include "flux-image-gen.chart" . }} +{{ include "flux-image-gen.selectorLabels" . }} +{{- if .Chart.AppVersion }} +app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} +{{- end }} +app.kubernetes.io/managed-by: {{ .Release.Service }} +{{- end }} + +{{/* +Selector labels +*/}} +{{- define "flux-image-gen.selectorLabels" -}} +app.kubernetes.io/name: {{ include "flux-image-gen.name" . }} +app.kubernetes.io/instance: {{ .Release.Name }} +{{- end }} + +{{/* +Model selector labels +*/}} +{{- define "flux-image-gen.modelLabels" -}} +app.kubernetes.io/component: {{ . }}-api +{{- end }} + +{{/* +UI selector labels +*/}} +{{- define "flux-image-gen.uiLabels" -}} +app.kubernetes.io/component: {{ .Release.Name }}-ui +{{- end }} + + +{{/* +Create the name of the service account to use +*/}} +{{- define "flux-image-gen.serviceAccountName" -}} +{{- if .Values.serviceAccount.create }} +{{- default (include "flux-image-gen.fullname" .) 
.Values.serviceAccount.name }} +{{- else }} +{{- default "default" .Values.serviceAccount.name }} +{{- end }} +{{- end }} diff --git a/charts/flux-image-gen/templates/api/deployment.yaml b/charts/flux-image-gen/templates/api/deployment.yaml new file mode 100644 index 0000000..5492473 --- /dev/null +++ b/charts/flux-image-gen/templates/api/deployment.yaml @@ -0,0 +1,104 @@ +{{- range $model := .Values.models }} +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ printf "%s-%s-api" (include "flux-image-gen.fullname" $) $model }} + labels: + {{- include "flux-image-gen.labels" $ | nindent 4 }} + {{- include "flux-image-gen.modelLabels" . | nindent 4 }} +spec: + replicas: {{ $.Values.api.replicaCount }} + {{- with $.Values.api.deploymentStrategy }} + strategy: + {{- toYaml . | nindent 4 }} + {{- end }} + selector: + matchLabels: + {{- include "flux-image-gen.selectorLabels" $ | nindent 6 }} + {{- include "flux-image-gen.modelLabels" . | nindent 6 }} + template: + metadata: + {{- with $.Values.api.podAnnotations }} + annotations: + {{- toYaml . | nindent 8 }} + {{- end }} + labels: + {{- include "flux-image-gen.labels" $ | nindent 8 }} + {{- include "flux-image-gen.modelLabels" . | nindent 8 }} + {{- with $.Values.api.podLabels }} + {{- toYaml . | nindent 8 }} + {{- end }} + spec: + {{- with $.Values.api.imagePullSecrets | default $.Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + securityContext: + {{- toYaml $.Values.api.podSecurityContext | nindent 8 }} + containers: + - name: {{ $.Chart.Name }} + securityContext: + {{- toYaml $.Values.api.securityContext | nindent 12 }} + image: "{{ $.Values.image.repository }}:{{ $.Values.image.tag | default $.Chart.AppVersion }}" + imagePullPolicy: {{ $.Values.image.pullPolicy }} + {{- with $.Values.api.commandOverride }} + {{- if . }} + command: + {{- .
| toYaml | nindent 12 }} + {{- end }} + {{- end }} + ports: + - name: http + containerPort: {{ $.Values.api.service.port }} + protocol: TCP + {{- if $.Values.api.startupProbe }} + startupProbe: + {{- toYaml $.Values.api.startupProbe | nindent 12 }} + {{- end }} + {{- if $.Values.api.livenessProbe }} + livenessProbe: + {{- toYaml $.Values.api.livenessProbe | nindent 12 }} + {{- end }} + {{- if $.Values.api.readinessProbe }} + readinessProbe: + {{- toYaml $.Values.api.readinessProbe | nindent 12 }} + {{- end }} + resources: + {{- toYaml $.Values.api.resources | nindent 12 }} + {{- with $.Values.api.volumeMounts }} + volumeMounts: + {{- toYaml . | nindent 12 }} + {{- end }} + # Make stdout from python visible in k8s logs + tty: true + env: + - name: FLUX_MODEL_NAME + value: {{ $model }} + - name: PYTHONUNBUFFERED + value: "1" + {{- if $.Values.api.huggingfaceToken }} + - name: HUGGING_FACE_HUB_TOKEN + value: {{ quote $.Values.api.huggingfaceToken }} + {{- end }} + {{- with $.Values.api.envFrom }} + envFrom: + {{- toYaml . | nindent 12 }} + {{- end }} + {{- with $.Values.api.volumes }} + volumes: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with $.Values.api.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with $.Values.api.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with $.Values.api.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} +{{- end -}} diff --git a/charts/flux-image-gen/templates/api/service.yaml b/charts/flux-image-gen/templates/api/service.yaml new file mode 100644 index 0000000..d81499f --- /dev/null +++ b/charts/flux-image-gen/templates/api/service.yaml @@ -0,0 +1,20 @@ +{{- range $model := .Values.models }} +--- +apiVersion: v1 +kind: Service +metadata: + name: {{ printf "%s-%s-api" (include "flux-image-gen.fullname" $) $model }} + labels: + {{- include "flux-image-gen.labels" $ | nindent 4 }} + {{- include "flux-image-gen.modelLabels" . 
| nindent 4 }} +spec: + type: {{ $.Values.api.service.type }} + ports: + - port: {{ $.Values.api.service.port }} + targetPort: http + protocol: TCP + name: http + selector: + {{- include "flux-image-gen.selectorLabels" $ | nindent 4 }} + {{- include "flux-image-gen.modelLabels" . | nindent 4 }} +{{- end -}} diff --git a/charts/flux-image-gen/templates/tests/gradio-api.yaml b/charts/flux-image-gen/templates/tests/gradio-api.yaml new file mode 100644 index 0000000..7d90c06 --- /dev/null +++ b/charts/flux-image-gen/templates/tests/gradio-api.yaml @@ -0,0 +1,22 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: gradio-client-test + annotations: + "helm.sh/hook": test +spec: + template: + spec: + containers: + - name: gradio-client + image: "{{ $.Values.image.repository }}:{{ $.Values.image.tag | default $.Chart.AppVersion }}" + command: + - python + - test_client.py + env: + - name: GRADIO_HOST + value: {{ printf "http://%s-ui.%s.svc:%v" (include "flux-image-gen.fullname" .) .Release.Namespace .Values.ui.service.port }} + - name: FLUX_MODEL + value: {{ .Values.models | first }} + restartPolicy: Never + backoffLimit: 3 diff --git a/charts/flux-image-gen/templates/ui/configmap.yaml b/charts/flux-image-gen/templates/ui/configmap.yaml new file mode 100644 index 0000000..4773ef9 --- /dev/null +++ b/charts/flux-image-gen/templates/ui/configmap.yaml @@ -0,0 +1,16 @@ +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ .Release.Name }}-ui-config + labels: + {{- include "flux-image-gen.labels" . | nindent 4 }} +data: + gradio_config.yaml: | + models: + {{- range $model := .Values.models }} + - name: {{ . }} + address: {{ printf "http://%s.%s.svc:%v" ( printf "%s-%s-api" (include "flux-image-gen.fullname" $) . 
) $.Release.Namespace $.Values.api.service.port }} + {{- end }} + example_prompt: | + {{- .Values.examplePrompt | nindent 6 -}} diff --git a/charts/flux-image-gen/templates/ui/deployment.yaml b/charts/flux-image-gen/templates/ui/deployment.yaml new file mode 100644 index 0000000..20fa41b --- /dev/null +++ b/charts/flux-image-gen/templates/ui/deployment.yaml @@ -0,0 +1,94 @@ +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "flux-image-gen.fullname" . }}-ui + labels: + {{- include "flux-image-gen.labels" . | nindent 4 }} +spec: + selector: + matchLabels: + {{- include "flux-image-gen.selectorLabels" . | nindent 6 }} + {{- include "flux-image-gen.uiLabels" . | nindent 6 }} + template: + metadata: + annotations: + # Recreate pods if settings config map changes + checksum/config: {{ include (print $.Template.BasePath "/ui/configmap.yaml") $ | sha256sum }} + {{- with .Values.ui.podAnnotations }} + {{- toYaml . | nindent 8 }} + {{- end }} + labels: + {{- include "flux-image-gen.labels" . | nindent 8 }} + {{- include "flux-image-gen.uiLabels" . | nindent 8 }} + {{- with .Values.ui.podLabels }} + {{- toYaml . | nindent 8 }} + {{- end }} + spec: + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml .
| nindent 8 }} + {{- end }} + securityContext: + {{- toYaml .Values.ui.podSecurityContext | nindent 8 }} + containers: + - name: {{ $.Chart.Name }} + securityContext: + {{- toYaml .Values.ui.securityContext | nindent 12 }} + image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + command: + - python + - gradio_ui.py + ports: + - name: http + containerPort: {{ .Values.ui.service.port }} + protocol: TCP + {{- if .Values.ui.startupProbe }} + startupProbe: + {{- toYaml .Values.ui.startupProbe | nindent 12 }} + {{- end }} + {{- if .Values.ui.livenessProbe }} + livenessProbe: + {{- toYaml .Values.ui.livenessProbe | nindent 12 }} + {{- end }} + {{- if .Values.ui.readinessProbe }} + readinessProbe: + {{- toYaml .Values.ui.readinessProbe | nindent 12 }} + {{- end }} + volumeMounts: + - name: app-config + mountPath: /etc/gradio-app/ + {{- with .Values.ui.volumeMounts }} + {{- if . -}} + {{- toYaml . | nindent 12 }} + {{- end -}} + {{- end }} + # Make stdout from python visible in k8s logs + tty: true + env: + - name: PYTHONUNBUFFERED + value: "1" + - name: GRADIO_SERVER_NAME + value: 0.0.0.0 + volumes: + - name: app-config + configMap: + name: {{ .Release.Name }}-ui-config + {{- with .Values.ui.volumes }} + {{- if . -}} + {{- toYaml . | nindent 8 }} + {{- end -}} + {{- end }} + {{- with .Values.ui.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.ui.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.ui.tolerations }} + tolerations: + {{- toYaml .
| nindent 8 }} + {{- end }} diff --git a/charts/flux-image-gen/templates/ui/ingress.yaml b/charts/flux-image-gen/templates/ui/ingress.yaml new file mode 100644 index 0000000..0f45b7b --- /dev/null +++ b/charts/flux-image-gen/templates/ui/ingress.yaml @@ -0,0 +1,43 @@ +{{- if .Values.ui.ingress.enabled -}} +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: {{ include "flux-image-gen.fullname" . }} + labels: + {{- include "flux-image-gen.labels" . | nindent 4 }} + {{- with .Values.ui.ingress.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + {{- with .Values.ui.ingress.className }} + ingressClassName: {{ . }} + {{- end }} + {{- if .Values.ui.ingress.tls }} + tls: + {{- range .Values.ui.ingress.tls }} + - hosts: + {{- range .hosts }} + - {{ . | quote }} + {{- end }} + secretName: {{ .secretName }} + {{- end }} + {{- end }} + rules: + {{- range .Values.ui.ingress.hosts }} + - host: {{ .host | quote }} + http: + paths: + {{- range .paths }} + - path: {{ .path }} + {{- with .pathType }} + pathType: {{ . }} + {{- end }} + backend: + service: + name: {{ include "flux-image-gen.fullname" $ }}-ui + port: + number: {{ $.Values.ui.service.port }} + {{- end }} + {{- end }} +{{- end }} diff --git a/charts/flux-image-gen/templates/ui/service.yaml b/charts/flux-image-gen/templates/ui/service.yaml new file mode 100644 index 0000000..dffb900 --- /dev/null +++ b/charts/flux-image-gen/templates/ui/service.yaml @@ -0,0 +1,18 @@ +--- +apiVersion: v1 +kind: Service +metadata: + name: {{ include "flux-image-gen.fullname" . }}-ui + labels: + {{- include "flux-image-gen.labels" $ | nindent 4 }} + {{- include "flux-image-gen.uiLabels" . | nindent 4 }} +spec: + type: {{ .Values.ui.service.type }} + ports: + - port: {{ .Values.ui.service.port }} + targetPort: http + protocol: TCP + name: http + selector: + {{- include "flux-image-gen.selectorLabels" . | nindent 4 }} + {{- include "flux-image-gen.uiLabels" . 
| nindent 4 }} diff --git a/charts/flux-image-gen/values.yaml b/charts/flux-image-gen/values.yaml new file mode 100644 index 0000000..b5d97b4 --- /dev/null +++ b/charts/flux-image-gen/values.yaml @@ -0,0 +1,176 @@ +# Default values for flux-image-gen. +# This is a YAML-formatted file. +# Declare variables to be passed into your templates. + +models: + - flux-schnell + # - flux-dev + +examplePrompt: | + Yoda riding a tiny unicorn through space. + +# This sets the container image; more information can be found here: https://kubernetes.io/docs/concepts/containers/images/ +image: + repository: ghcr.io/stackhpc/azimuth-llm-flux-image-gen-ui + # This sets the pull policy for images. + pullPolicy: IfNotPresent + # Overrides the image tag whose default is the chart appVersion. + tag: "" + +# This is for the secrets used to pull an image from a private repository; more information can be found here: https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/ +imagePullSecrets: [] +# This is to override the chart name. +nameOverride: "" +fullnameOverride: "" + +ui: + # This is for setting Kubernetes Annotations to a Pod. + # For more information check out: https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/ + podAnnotations: {} + # This is for setting Kubernetes Labels to a Pod.
+ # For more information checkout: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/ + podLabels: {} + + podSecurityContext: {} + # fsGroup: 2000 + + securityContext: {} + # capabilities: + # drop: + # - ALL + # readOnlyRootFilesystem: true + # runAsNonRoot: true + # runAsUser: 1000 + + # This is for setting up a service more information can be found here: https://kubernetes.io/docs/concepts/services-networking/service/ + service: + # This sets the service type more information can be found here: https://kubernetes.io/docs/concepts/services-networking/service/#publishing-services-service-types + type: ClusterIP + # This sets the ports more information can be found here: https://kubernetes.io/docs/concepts/services-networking/service/#field-spec-ports + port: 7860 + + # This block is for setting up the ingress for more information can be found here: https://kubernetes.io/docs/concepts/services-networking/ingress/ + ingress: + enabled: false + annotations: {} + # kubernetes.io/ingress.class: nginx + # kubernetes.io/tls-acme: "true" + hosts: + - host: chart-example.local + paths: + - path: / + pathType: ImplementationSpecific + tls: [] + # - secretName: chart-example-tls + # hosts: + # - chart-example.local + +api: + + resources: {} + # We usually recommend not to specify default resources and to leave this as a conscious + # choice for the user. This also increases chances charts run on environments with little + # resources, such as Minikube. If you do want to specify resources, uncomment the following + # lines, adjust them as necessary, and remove the curly braces after 'resources:'. 
+ # limits: + # cpu: 100m + # memory: 128Mi + # requests: + # cpu: 100m + # memory: 128Mi + + # This sets the replica count; more information can be found here: https://kubernetes.io/docs/concepts/workloads/controllers/replicaset/ + replicaCount: 1 + + deploymentStrategy: + type: Recreate + + # Downloading 100GB+ of model weights can take a long time so + # it's difficult to give these probes sensible default values... + startupProbe: + # httpGet: + # path: / + # port: http + # Is 30 minutes long enough...? + # failureThreshold: 180 + # periodSeconds: 10 + livenessProbe: + # httpGet: + # path: / + # port: http + readinessProbe: + # httpGet: + # path: / + # port: http + + # Additional volumes on the output Deployment definition. + volumes: [] + # - name: foo + # secret: + # secretName: mysecret + # optional: false + + # Additional volumeMounts on the output Deployment definition. + volumeMounts: [] + # - name: foo + # mountPath: "/etc/foo" + # readOnly: true + + + # To supply a HF hub token for flux-dev model either use + huggingfaceToken: + # OR + # envFrom: + # - secretRef: + # name: + + envFrom: [] + + nodeSelector: {} + + tolerations: [] + + affinity: {} + + commandOverride: + + # This is for setting Kubernetes Annotations to a Pod. + # For more information check out: https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/ + podAnnotations: {} + # This is for setting Kubernetes Labels to a Pod.
+ # For more information checkout: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/ + podLabels: {} + + podSecurityContext: {} + # fsGroup: 2000 + + securityContext: {} + # capabilities: + # drop: + # - ALL + # readOnlyRootFilesystem: true + # runAsNonRoot: true + # runAsUser: 1000 + + # This is for setting up a service more information can be found here: https://kubernetes.io/docs/concepts/services-networking/service/ + service: + # This sets the service type more information can be found here: https://kubernetes.io/docs/concepts/services-networking/service/#publishing-services-service-types + type: ClusterIP + # This sets the ports more information can be found here: https://kubernetes.io/docs/concepts/services-networking/service/#field-spec-ports + port: 8000 + + # This block is for setting up the ingress for more information can be found here: https://kubernetes.io/docs/concepts/services-networking/ingress/ + ingress: + enabled: false + annotations: {} + # kubernetes.io/ingress.class: nginx + # kubernetes.io/tls-acme: "true" + hosts: + - host: chart-example.local + paths: + - path: / + pathType: ImplementationSpecific + tls: [] + # - secretName: chart-example-tls + # hosts: + # - chart-example.local diff --git a/web-apps/flux-image-gen/.gitignore b/web-apps/flux-image-gen/.gitignore new file mode 100644 index 0000000..ea1472e --- /dev/null +++ b/web-apps/flux-image-gen/.gitignore @@ -0,0 +1 @@ +output/ diff --git a/web-apps/flux-image-gen/Dockerfile b/web-apps/flux-image-gen/Dockerfile new file mode 100644 index 0000000..ba48dbe --- /dev/null +++ b/web-apps/flux-image-gen/Dockerfile @@ -0,0 +1,23 @@ +FROM python:3.11 + +# https://stackoverflow.com/questions/55313610/importerror-libgl-so-1-cannot-open-shared-object-file-no-such-file-or-directo +RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y + + +ARG DIR=flux-image-gen + +COPY $DIR/requirements.txt requirements.txt +RUN pip install --no-cache-dir -r requirements.txt + 
+COPY purge-google-fonts.sh . +RUN bash purge-google-fonts.sh + +WORKDIR /app + +COPY $DIR/*.py . + +COPY $DIR/gradio_config.yaml . + +COPY $DIR/test-image.jpg . + +ENTRYPOINT ["fastapi", "run", "api_server.py"] diff --git a/web-apps/flux-image-gen/api_server.py b/web-apps/flux-image-gen/api_server.py new file mode 100644 index 0000000..777857a --- /dev/null +++ b/web-apps/flux-image-gen/api_server.py @@ -0,0 +1,66 @@ +import io +import os +import sys +import torch + +from fastapi import FastAPI +from fastapi.responses import Response, JSONResponse +from PIL import Image +from pydantic import BaseModel + +from image_gen import FluxGenerator + +# Detect if app is run using `fastapi dev ...` (argv may hold no sub-command when the module is imported directly) +DEV_MODE = len(sys.argv) > 1 and sys.argv[1] == "dev" + +app = FastAPI() + +device = "cuda" if torch.cuda.is_available() else "cpu" +model = os.environ.get("FLUX_MODEL_NAME", "flux-schnell") +if not DEV_MODE: + print("Loading model", model) + generator = FluxGenerator(model, device, offload=False) + + +class ImageGenInput(BaseModel): + width: int + height: int + num_steps: int + guidance: float + seed: int + prompt: str + add_sampling_metadata: bool + + +@app.get("/model") +async def get_model(): + return {"model": model} + + +@app.post("/generate") +async def generate_image(input: ImageGenInput): + if DEV_MODE: + # For quicker testing or when GPU hardware not available + fn = "test-image.jpg" + seed = "dev" + image = Image.open(fn) + # Uncomment to test error handling + # return JSONResponse({"error": {"message": "Dev mode error test", "seed": "not-so-random"}}, status_code=400) + else: + # Main image generation functionality + image, seed, msg = generator.generate_image( + input.width, + input.height, + input.num_steps, + input.guidance, + input.seed, + input.prompt, + add_sampling_metadata=input.add_sampling_metadata, + ) + if not image: + return JSONResponse({"error": {"message": msg, "seed": seed}}, status_code=400) + # Convert image to bytes response + buffer = io.BytesIO() + image.save(buffer,
format="jpeg") + bytes = buffer.getvalue() + return Response(bytes, media_type="image/jpeg", headers={"x-flux-seed": seed}) diff --git a/web-apps/flux-image-gen/gradio_config.yaml b/web-apps/flux-image-gen/gradio_config.yaml new file mode 100644 index 0000000..e6ebc7c --- /dev/null +++ b/web-apps/flux-image-gen/gradio_config.yaml @@ -0,0 +1,5 @@ +models: + - name: flux-schnell + address: http://localhost:8000 +example_prompt: | + Yoda riding a skateboard. diff --git a/web-apps/flux-image-gen/gradio_ui.py b/web-apps/flux-image-gen/gradio_ui.py new file mode 100644 index 0000000..95f09e4 --- /dev/null +++ b/web-apps/flux-image-gen/gradio_ui.py @@ -0,0 +1,127 @@ +import io +import os +import httpx +import uuid +import pathlib +import yaml + +import gradio as gr +from pydantic import BaseModel, HttpUrl +from PIL import Image, ExifTags +from typing import List +from urllib.parse import urljoin + + +class Model(BaseModel): + name: str + address: HttpUrl + +class AppSettings(BaseModel): + models: List[Model] + example_prompt: str + + +settings_path = pathlib.Path("/etc/gradio-app/gradio_config.yaml") +if not settings_path.exists(): + print("No settings overrides found at", settings_path) + settings_path = "./gradio_config.yaml" +print("Using settings from", settings_path) +with open(settings_path, "r") as file: + settings = AppSettings(**yaml.safe_load(file)) +print("App config:", settings.model_dump()) + +MODELS = {m.name: m.address for m in settings.models} +MODEL_NAMES = list(MODELS.keys()) + +# Disable analytics for GDPR compliance +os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" + +def save_image(model_name: str, prompt: str, seed: int, add_sampling_metadata: bool, image: Image.Image): + filename = f"output/gradio/{uuid.uuid4()}.jpg" + os.makedirs(os.path.dirname(filename), exist_ok=True) + exif_data = Image.Exif() + exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] 
= model_name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + image.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0) + return filename + + +async def generate_image( + model_name: str, + width: int, + height: int, + num_steps: int, + guidance: float, + seed: int, + prompt: str, + add_sampling_metadata: bool, +): + url = urljoin(str(MODELS[model_name]), "/generate") + data = { + "width": width, + "height": height, + "num_steps": num_steps, + "guidance": guidance, + "seed": seed, + "prompt": prompt, + "add_sampling_metadata": add_sampling_metadata, + } + async with httpx.AsyncClient(timeout=60) as client: + try: + response = await client.post(url, json=data) + except httpx.ConnectError: + raise gr.Error("Model backend unavailable") + if response.status_code == 400: + data = response.json() + if "error" in data and "message" in data["error"]: + message = data["error"]["message"] + if "seed" in data["error"]: + message += f" (seed: {data['error']['seed']})" + raise gr.Error(message) + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: + # Raise a generic error message to avoid leaking unwanted details + # Admin should consult API logs for more info + raise gr.Error(f"Backend error (HTTP {err.response.status_code})") + image = Image.open(io.BytesIO(response.content)) + seed = response.headers.get("x-flux-seed", "unknown") + filename = save_image(model_name, prompt, seed, add_sampling_metadata, image) + + return image, seed, filename, None + + +with gr.Blocks() as demo: + gr.Markdown("# Flux Image Generation Demo") + + with gr.Row(): + with gr.Column(): + model = gr.Dropdown(MODEL_NAMES, value=MODEL_NAMES[0], label="Model", interactive=len(MODEL_NAMES) > 1) + prompt = gr.Textbox(label="Prompt", value=settings.example_prompt) + + with gr.Accordion("Advanced Options", open=False): + # TODO: Make min/max slide values configurable + width = gr.Slider(128, 8192, 1360, step=16, label="Width") + 
height = gr.Slider(128, 8192, 768, step=16, label="Height") + num_steps = gr.Slider(1, 50, 4 if model.value == "flux-schnell" else 50, step=1, label="Number of steps") + guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not model.value == "flux-schnell") + seed = gr.Textbox("-1", label="Seed (-1 for random)") + add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True) + + generate_btn = gr.Button("Generate") + + with gr.Column(): + output_image = gr.Image(label="Generated Image") + seed_output = gr.Textbox(label="Used Seed") + warning_text = gr.Textbox(label="Warning", visible=False) + download_btn = gr.File(label="Download full-resolution") + + generate_btn.click( + fn=generate_image, + inputs=[model, width, height, num_steps, guidance, seed, prompt, add_sampling_metadata], + outputs=[output_image, seed_output, download_btn, warning_text], + ) + demo.launch(enable_monitoring=False) diff --git a/web-apps/flux-image-gen/image_gen.py b/web-apps/flux-image-gen/image_gen.py new file mode 100644 index 0000000..28585d1 --- /dev/null +++ b/web-apps/flux-image-gen/image_gen.py @@ -0,0 +1,155 @@ +##### +# Based on demo_gr.py in repo root +##### + +import time + +import torch +import numpy as np +from einops import rearrange +from PIL import Image, ExifTags +from transformers import pipeline + +from flux.cli import SamplingOptions +from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack +from flux.util import embed_watermark, load_ae, load_clip, load_flow_model, load_t5 + +NSFW_THRESHOLD = 0.85 + +def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool): + t5 = load_t5(device, max_length=256 if is_schnell else 512) + clip = load_clip(device) + model = load_flow_model(name, device="cpu" if offload else device) + ae = load_ae(name, device="cpu" if offload else device) + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", 
device=device) + return model, ae, t5, clip, nsfw_classifier + +class FluxGenerator: + def __init__(self, model_name: str, device: str, offload: bool): + self.device = torch.device(device) + self.offload = offload + self.model_name = model_name + self.is_schnell = model_name == "flux-schnell" + self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models( + model_name, + device=self.device, + offload=self.offload, + is_schnell=self.is_schnell, + ) + + @torch.inference_mode() + def generate_image( + self, + width, + height, + num_steps, + guidance, + seed, + prompt, + init_image=None, + image2image_strength=0.0, + add_sampling_metadata=True, + ): + seed = int(seed) + if seed == -1: + seed = None + + opts = SamplingOptions( + prompt=prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + ) + + if opts.seed is None: + opts.seed = torch.Generator(device="cpu").seed() + t0 = time.perf_counter() + + if init_image is not None: + if isinstance(init_image, np.ndarray): + init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0 + init_image = init_image.unsqueeze(0) + init_image = init_image.to(self.device) + init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width)) + if self.offload: + self.ae.encoder.to(self.device) + init_image = self.ae.encode(init_image.to()) + if self.offload: + self.ae = self.ae.cpu() + torch.cuda.empty_cache() + + # prepare input + x = get_noise( + 1, + opts.height, + opts.width, + device=self.device, + dtype=torch.bfloat16, + seed=opts.seed, + ) + timesteps = get_schedule( + opts.num_steps, + x.shape[-1] * x.shape[-2] // 4, + shift=(not self.is_schnell), + ) + if init_image is not None: + t_idx = int((1 - image2image_strength) * num_steps) + t = timesteps[t_idx] + timesteps = timesteps[t_idx:] + x = t * x + (1.0 - t) * init_image.to(x.dtype) + + if self.offload: + self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) + inp = 
prepare(t5=self.t5, clip=self.clip, img=x, prompt=opts.prompt) + + # offload TEs to CPU, load model to gpu + if self.offload: + self.t5, self.clip = self.t5.cpu(), self.clip.cpu() + torch.cuda.empty_cache() + self.model = self.model.to(self.device) + + # denoise initial noise + x = denoise(self.model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if self.offload: + self.model.cpu() + torch.cuda.empty_cache() + self.ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), opts.height, opts.width) + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): + x = self.ae.decode(x) + + if self.offload: + self.ae.decoder.cpu() + torch.cuda.empty_cache() + + t1 = time.perf_counter() + + print(f"Done in {t1 - t0:.1f}s.") + # bring into PIL format + x = x.clamp(-1, 1) + x = embed_watermark(x.float()) + x = rearrange(x[0], "c h w -> h w c") + + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + nsfw_score = [x["score"] for x in self.nsfw_classifier(img) if x["label"] == "nsfw"][0] # type: ignore + + if nsfw_score < NSFW_THRESHOLD: + exif_data = Image.Exif() + if init_image is None: + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + else: + exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = self.model_name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + + return img, str(opts.seed), None + else: + return None, str(opts.seed), "Your generated image may contain NSFW content." 
diff --git a/web-apps/flux-image-gen/requirements.txt b/web-apps/flux-image-gen/requirements.txt new file mode 100644 index 0000000..069c65c --- /dev/null +++ b/web-apps/flux-image-gen/requirements.txt @@ -0,0 +1,4 @@ +flux[gradio] @ git+https://github.com/black-forest-labs/flux@478338d +fastapi[standard] +httpx +# ../utils diff --git a/web-apps/flux-image-gen/test-image.jpg b/web-apps/flux-image-gen/test-image.jpg new file mode 100644 index 0000000..e9d9033 Binary files /dev/null and b/web-apps/flux-image-gen/test-image.jpg differ diff --git a/web-apps/flux-image-gen/test_client.py b/web-apps/flux-image-gen/test_client.py new file mode 100644 index 0000000..21d8a69 --- /dev/null +++ b/web-apps/flux-image-gen/test_client.py @@ -0,0 +1,20 @@ +import os +from gradio_client import Client + +address = os.environ.get("GRADIO_HOST", "http://localhost:7860/") +model = os.environ.get("FLUX_MODEL", "flux-schnell") +client = Client(address) +web_page, seed, file_name, err = client.predict( + model_name=model, + # width=1360, + width=3888, + # height=768, + height=2544, + num_steps=4, + guidance=3.5, + seed="-1", + prompt="Yoda riding a skateboard", + add_sampling_metadata=True, + api_name="/generate_image" +) +print('Result saved to:', file_name) diff --git a/web-apps/utils/pyproject.toml b/web-apps/utils/pyproject.toml new file mode 100644 index 0000000..0924410 --- /dev/null +++ b/web-apps/utils/pyproject.toml @@ -0,0 +1,7 @@ +[project] +name = "utils" +version = "0.0.1" +dependencies = [ + "pydantic", + "structlog", +] diff --git a/web-apps/utils/setup.py b/web-apps/utils/setup.py index 515d709..6068493 100644 --- a/web-apps/utils/setup.py +++ b/web-apps/utils/setup.py @@ -1,8 +1,3 @@ -from setuptools import setup, find_packages +from setuptools import setup -setup( - name='web-app-utils', - version='0.0.1', - py_modules=["utils"], - requires=["pydantic"] -) +setup()