Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: client file upload #18

Merged
merged 10 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 42 additions & 21 deletions apps/demo-nextjs-app-router/app/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
file_size: number;
};
type Result = {
images: Image[];
image: Image;
};
// @snippet:end

type ErrorProps = {
error: any;

Check warning on line 27 in apps/demo-nextjs-app-router/app/page.tsx

View workflow job for this annotation

GitHub Actions / build

Unexpected any. Specify a different type
};

function Error(props: ErrorProps) {
Expand All @@ -42,12 +42,13 @@
}

const DEFAULT_PROMPT =
'a city landscape of a cyberpunk metropolis, raining, purple, pink and teal neon lights, highly detailed, uhd';
'(masterpiece:1.4), (best quality), (detailed), Medieval village scene with busy streets and castle in the distance';

export default function Home() {
// @snippet:start("client.ui.state")
// Input state
const [prompt, setPrompt] = useState<string>(DEFAULT_PROMPT);
const [imageFile, setImageFile] = useState<File | null>(null);
// Result state
const [loading, setLoading] = useState(false);
const [error, setError] = useState<Error | null>(null);
Expand All @@ -59,7 +60,10 @@
if (!result) {
return null;
}
return result.images[0];
if (result.image) {
return result.image;
}
return null;
}, [result]);

const reset = () => {
Expand All @@ -76,26 +80,29 @@
setLoading(true);
const start = Date.now();
try {
const result: Result = await fal.subscribe('110602490-lora', {
input: {
prompt,
model_name: 'stabilityai/stable-diffusion-xl-base-1.0',
image_size: 'square_hd',
},
pollInterval: 5000, // Default is 1000 (every 1s)
logs: true,
onQueueUpdate(update) {
setElapsedTime(Date.now() - start);
if (
update.status === 'IN_PROGRESS' ||
update.status === 'COMPLETED'
) {
setLogs((update.logs || []).map((log) => log.message));
}
},
});
const result: Result = await fal.subscribe(
'54285744-illusion-diffusion',
{
input: {
prompt,
image_url: imageFile,
image_size: 'square_hd',
},
pollInterval: 5000, // Default is 1000 (every 1s)
logs: true,
onQueueUpdate(update) {
setElapsedTime(Date.now() - start);
if (
update.status === 'IN_PROGRESS' ||
update.status === 'COMPLETED'
) {
setLogs((update.logs || []).map((log) => log.message));
}
},
}
);
setResult(result);
} catch (error: any) {

Check warning on line 105 in apps/demo-nextjs-app-router/app/page.tsx

View workflow job for this annotation

GitHub Actions / build

Unexpected any. Specify a different type
setError(error);
} finally {
setLoading(false);
Expand All @@ -109,6 +116,20 @@
<h1 className="text-4xl font-bold mb-8">
Hello <code className="font-light text-pink-600">fal</code>
</h1>
<div className="text-lg w-full">
<label htmlFor="prompt" className="block mb-2 text-current">
Image
</label>
<input
className="w-full text-lg p-2 rounded bg-black/10 dark:bg-white/5 border border-black/20 dark:border-white/10"
id="image_url"
name="image_url"
type="file"
placeholder="Choose a file"
accept="image/*"
onChange={(e) => setImageFile(e.target.files?.[0] ?? null)}
/>
</div>
<div className="text-lg w-full">
<label htmlFor="prompt" className="block mb-2 text-current">
Prompt
Expand Down
2 changes: 1 addition & 1 deletion libs/client/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client",
"version": "0.4.2",
"version": "0.5.0",
"license": "MIT",
"repository": {
"type": "git",
Expand Down
51 changes: 10 additions & 41 deletions libs/client/src/function.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { getConfig } from './config';
import { getUserAgent, isBrowser } from './runtime';
import { storageImpl } from './storage';
import { dispatchRequest } from './request';
import { EnqueueResult, QueueStatus } from './types';
import { isUUIDv4, isValidUrl } from './utils';

Expand Down Expand Up @@ -62,7 +63,6 @@ export function buildUrl<Input>(

/**
* Runs a fal serverless function identified by its `id`.
* TODO: expand documentation and provide examples
*
* @param id the registered function revision id or alias.
* @returns the remote function output
Expand All @@ -71,45 +71,14 @@ export async function run<Input, Output>(
id: string,
options: RunOptions<Input> = {}
): Promise<Output> {
const {
credentials: credentialsValue,
requestMiddleware,
responseHandler,
} = getConfig();
const method = (options.method ?? 'post').toLowerCase();
const userAgent = isBrowser() ? {} : { 'User-Agent': getUserAgent() };
const credentials =
typeof credentialsValue === 'function'
? credentialsValue()
: credentialsValue;

const { url, headers } = await requestMiddleware({
url: buildUrl(id, options),
});
const authHeader = credentials ? { Authorization: `Key ${credentials}` } : {};
if (typeof window !== 'undefined' && credentials) {
console.warn(
"The fal credentials are exposed in the browser's environment. " +
"That's not recommended for production use cases."
);
}
const requestHeaders = {
...authHeader,
Accept: 'application/json',
'Content-Type': 'application/json',
...userAgent,
...(headers ?? {}),
} as HeadersInit;
const response = await fetch(url, {
method,
headers: requestHeaders,
mode: 'cors',
body:
method !== 'get' && options.input
? JSON.stringify(options.input)
: undefined,
});
return await responseHandler(response);
const input = options.input
? await storageImpl.transformInput(options.input)
: options.input;
return dispatchRequest<Input, Output>(
options.method ?? 'post',
buildUrl(id, options),
input as Input
);
}

/**
Expand Down
1 change: 1 addition & 0 deletions libs/client/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export { config, getConfig } from './config';
export { storageImpl as storage } from './storage';
export { queue, run, subscribe } from './function';
export { withMiddleware, withProxy } from './middleware';
export type { RequestMiddleware } from './middleware';
Expand Down
47 changes: 47 additions & 0 deletions libs/client/src/request.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import { getConfig } from './config';
import { getUserAgent, isBrowser } from './runtime';

export async function dispatchRequest<Input, Output>(
method: string,
targetUrl: string,
input: Input
): Promise<Output> {
const {
credentials: credentialsValue,
requestMiddleware,
responseHandler,
} = getConfig();
const userAgent = isBrowser() ? {} : { 'User-Agent': getUserAgent() };
const credentials =
typeof credentialsValue === 'function'
? credentialsValue()
: credentialsValue;

const { url, headers } = await requestMiddleware({
url: targetUrl,
});
const authHeader = credentials ? { Authorization: `Key ${credentials}` } : {};
if (typeof window !== 'undefined' && credentials) {
console.warn(
"The fal credentials are exposed in the browser's environment. " +
"That's not recommended for production use cases."
);
}
const requestHeaders = {
...authHeader,
Accept: 'application/json',
'Content-Type': 'application/json',
...userAgent,
...(headers ?? {}),
} as HeadersInit;
const response = await fetch(url, {
method,
headers: requestHeaders,
mode: 'cors',
body:
method.toLowerCase() !== 'get' && input
? JSON.stringify(input)
: undefined,
});
return await responseHandler(response);
}
108 changes: 108 additions & 0 deletions libs/client/src/storage.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import { getConfig } from './config';
import { dispatchRequest } from './request';

/**
* File support for the client. This interface establishes the contract for
* uploading files to the server and transforming the input to replace file
* objects with URLs.
*/
export interface StorageSupport {
/**
* Upload a file to the server. Returns the URL of the uploaded file.
* @param file the file to upload
* @param options optional parameters, such as custom file name
* @returns the URL of the uploaded file
*/
upload: (file: Blob) => Promise<string>;

/**
* Transform the input to replace file objects with URLs. This is used
* to transform the input before sending it to the server and ensures
* that the server receives URLs instead of file objects.
*
* @param input the input to transform.
* @returns the transformed input.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
transformInput: (input: Record<string, any>) => Promise<Record<string, any>>;
}

function isDataUri(uri: string): boolean {
// avoid uri parsing if it doesn't start with data:
if (!uri.startsWith('data:')) {
return false;
}
try {
const url = new URL(uri);
return url.protocol === 'data:';
} catch (_) {
return false;
}
}

type InitiateUploadResult = {
file_url: string;
upload_url: string;
};

type InitiateUploadData = {
file_name: string;
content_type: string | null;
};

function getRestApiUrl(): string {
const { host } = getConfig();
return host.replace('gateway', 'rest');
}

async function initiateUpload(file: Blob): Promise<InitiateUploadResult> {
return await dispatchRequest<InitiateUploadData, InitiateUploadResult>(
'POST',
`https://${getRestApiUrl()}/storage/upload/initiate`,
{
file_name: file.name,
content_type: file.type || 'application/octet-stream',
}
);
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type KeyValuePair = [string, any];

export const storageImpl: StorageSupport = {
upload: async (file: Blob) => {
const { upload_url: uploadUrl, file_url: url } = await initiateUpload(file);
const response = await fetch(uploadUrl, {
method: 'PUT',
body: file,
headers: {
'Content-Type': file.type || 'application/octet-stream',
},
});
const { responseHandler } = getConfig();
await responseHandler(response);
return url;
},

// eslint-disable-next-line @typescript-eslint/no-explicit-any
transformInput: async (input: Record<string, any>) => {
const promises = Object.entries(input).map(async ([key, value]) => {
if (
value instanceof Blob ||
(typeof value === 'string' && isDataUri(value))
) {
let blob = value;
// if string is a data uri, convert to blob
if (typeof value === 'string' && isDataUri(value)) {
const response = await fetch(value);
blob = await response.blob();
}
const url = await storageImpl.upload(blob as Blob);
return [key, url];
}
return [key, value] as KeyValuePair;
});
const results = await Promise.all(promises);
return Object.fromEntries(results);
},
};
Loading