From ba1cd298e793be759fdcb2fe9312c52ef8d8aeba Mon Sep 17 00:00:00 2001
From: matt-bornstein <26883865+matt-bornstein@users.noreply.github.com>
Date: Mon, 11 Sep 2023 09:44:45 -0700
Subject: [PATCH] first working version

---
 .env.example                           |  15 -
 .vscode/launch.json                    |  28 ++
 components/pills.js                    |   1 +
 lib/promptFormatter.js                 |   5 +
 lib/sdxlFineTunes.js                   |  95 +++++
 pages/api/predictions/sdxlfinetunes.js | 207 +++++++++
 pages/api/sdxlfinetunes/index.js       |   8 +
 pages/sdxlfinetunes.js                 | 556 +++++++++++++++++++++++++
 yarn.lock                              | 189 ++++-----
 9 files changed, 990 insertions(+), 114 deletions(-)
 delete mode 100644 .env.example
 create mode 100644 .vscode/launch.json
 create mode 100644 lib/promptFormatter.js
 create mode 100644 lib/sdxlFineTunes.js
 create mode 100644 pages/api/predictions/sdxlfinetunes.js
 create mode 100644 pages/api/sdxlfinetunes/index.js
 create mode 100644 pages/sdxlfinetunes.js

diff --git a/.env.example b/.env.example
deleted file mode 100644
index 78973de..0000000
--- a/.env.example
+++ /dev/null
@@ -1,15 +0,0 @@
-# create a https://replicate.com/ account and set the api token
-REPLICATE_API_TOKEN=
-
-# https://platform.openai.com/onboarding
-OPENAI_API_KEY=
-
-# setup a supabase account and create a new project https://supabase.com/docs/guides/getting-started/quickstarts/nextjs
-# this is only required if you want your outputs to persist, or if you're using controlnet (input image uploads need to be saved to a DB)
-SUPABASE_SERVICE_ROLE=
-SUPABASE_URL=
-SUPABASE_JWT=
-SUPABASE_KEY=
-
-# install https://ngrok.com/, run it, and set the host here
-NGROK_HOST=
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..7a04397
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,28 @@
+{
+  "version": "0.2.0",
+  "configurations": [
+    {
+      "name": "Next.js: debug server-side",
+      "type": "node-terminal",
+      "request": "launch",
+      "command": "npm run dev"
+    },
+    {
+      "name": "Next.js: debug client-side",
+      "type": "chrome",
+      "request": "launch",
+      "url": "http://localhost:3000"
+    },
+    {
+      "name": "Next.js: debug full stack",
+      "type": "node-terminal",
+      "request": "launch",
+      "command": "npm run dev",
+      "serverReadyAction": {
+        "pattern": "started server on .+, url: (https?://.+)",
+        "uriFormat": "%s",
+        "action": "debugWithChrome"
+      }
+    }
+  ]
+}
\ No newline at end of file
diff --git a/components/pills.js b/components/pills.js
index b5fdbe1..95a0af5 100644
--- a/components/pills.js
+++ b/components/pills.js
@@ -3,6 +3,7 @@ import { useRouter } from "next/router";
 
 const tabs = [
   { name: "Text to Image", href: "/", current: true },
+  { name: "SDXL Fine-Tunes", href: "/sdxlfinetunes", current: false },
   { name: "ControlNet", href: "/controlnet", current: false },
   { name: "X/Y plot", href: "/xyplot", current: false },
 ];
diff --git a/lib/promptFormatter.js b/lib/promptFormatter.js
new file mode 100644
index 0000000..154aa20
--- /dev/null
+++ b/lib/promptFormatter.js
@@ -0,0 +1,5 @@
+export const LLAMA2_SYSTEM_PROMPT =
+  `[INST] <<SYS>>
+  You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+  If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
+  <</SYS>>`
\ No newline at end of file
diff --git a/lib/sdxlFineTunes.js b/lib/sdxlFineTunes.js
new file mode 100644
index 0000000..c38018e
--- /dev/null
+++ b/lib/sdxlFineTunes.js
@@ -0,0 +1,95 @@
+import Replicate from "replicate";
+
+export default async function getFineTunes() {
+  const replicate = new Replicate({
+    auth: process.env.REPLICATE_API_TOKEN,
+  });
+
+  const promptTemplates = new Map([
+    [
+      "sdxl-tron",
+      "{prompt} in the style of TRN"
+    ],
+    [
+      "sdxl-barbie",
+      "{prompt} in the style of TOK"
+    ],
+    [
+      "sdxl-woolitize",
+      "{prompt} in the style of TOK, made of wool, focus blur"
+    ],
+    [
+      "sdxl-sonic-2",
+      "A screenshot in the style of TOK, pixel art, 2d platform game, sharp, {prompt}"
+    ],
+    [
+      "sdxl-70s-scifi",
+      "{prompt}, in the style of TOK"
+    ],
+    [
+      "sdxl-gta-v",
+      "video game art, in the style of TOK, {prompt}"
+    ],
+    [
+      "iwan-baan-sdxl",
+      "{prompt} in the style of TOK"
+    ],
+    [
+      "loteria",
+      "a {prompt} card in the style of TOK"
+    ],
+    [
+      "sdxl-vision-pro",
+      "a photo of {prompt} wearing a TOK VR headset, faces visible"
+    ],
+    [
+      "nammeh",
+      "Photo in style of NAMMEH, {prompt}"
+    ],
+    [
+      "sdxl-davinci",
+      "a drawing of {prompt} in the style of TOK"
+    ],
+    [
+      "sdxl-cross-section",
+      "A cross section TOK of {prompt}"
+    ],
+    [
+      "sdxl-money",
+      "{prompt} on a bank note in the style of TOK"
+    ],
+    [
+      "sdxl-illusions",
+      "{prompt} in the style of TOK"
+    ]
+  ]);
+
+  const collectionResponse = await replicate.collections.get("sdxl-fine-tunes")
+  return collectionResponse.models.map((m, i) => {
+    const model = {
+      id: i,
+      owner: m.owner,
+      name: m.name,
+      default_params: m.default_example.input,
+      prompt_template: promptTemplates.get(m.name),
+      prompt_example: m.default_example.input.prompt,
+      version: m.latest_version.id,
+      checked: false,
+      source: "replicate",
+      url: m.url,
+      description: m.description,
+      links: [
+        {
+          name: "replicate",
+          url: m.url
+        }
+      ],
+      // new fields
+      cover_image_url: m.cover_image_url,
+    }
+
+    delete model.default_params.prompt
+
+    return model
+  })
+}
\ No newline at end of file
diff --git a/pages/api/predictions/sdxlfinetunes.js b/pages/api/predictions/sdxlfinetunes.js
new file mode 100644
index 0000000..db71132
--- /dev/null
+++ b/pages/api/predictions/sdxlfinetunes.js
@@ -0,0 +1,207 @@
+import { Configuration, OpenAIApi } from "openai";
+import upsertPrediction from "../../../lib/upsertPrediction";
+import getFineTunes from "../../../lib/sdxlFineTunes"
+import packageData from "../../../package.json";
+import fetch from "node-fetch";
+import Replicate from "replicate";
+import { LLAMA2_SYSTEM_PROMPT } from "../../../lib/promptFormatter";
+
+const REPLICATE_API_HOST = "https://api.replicate.com";
+const STABILITY_API_HOST = "https://api.stability.ai";
+
+const WEBHOOK_HOST = process.env.VERCEL_URL
+  ? `https://${process.env.VERCEL_URL}`
+  : process.env.NGROK_HOST;
+
+const configuration = new Configuration({
+  apiKey: process.env.OPENAI_API_KEY,
+});
+
+const replicate = new Replicate({
+  auth: process.env.REPLICATE_API_TOKEN,
+});
+
+const openai = new OpenAIApi(configuration);
+
+export default async function handler(req, res) {
+  if (!process.env.REPLICATE_API_TOKEN) {
+    throw new Error(
+      "The REPLICATE_API_TOKEN environment variable is not set. See README.md for instructions on how to set it."
+    );
+  }
+
+  if (!WEBHOOK_HOST) {
+    throw new Error(
+      "WEBHOOK_HOST is not set. If you're running locally, set it to your ngrok URL; without it, Replicate predictions won't be saved to the DB."
+    );
+  }
+
+  const availableModels = await getFineTunes()
+
+  if (req.body.source == "replicate") {
+    console.log("host", WEBHOOK_HOST);
+
+    const searchParams = new URLSearchParams({
+      submission_id: req.body.submission_id,
+      model: req.body.model,
+      anon_id: req.body.anon_id,
+      source: req.body.source,
+    });
+
+    let renderedPrompt
+    if (req.body.prompt_template) {
+      renderedPrompt = req.body.prompt_template.replace("{prompt}", req.body.prompt_raw)
+    }
+    else if (req.body.prompt_example) {
+      const llamaModel = "meta/llama-2-70b-chat:35042c9a33ac8fd5e29e27fb3197f33aa483f72c2ce3b0b9d201155c7fd2a287"
+      const llamaInputs = {
+        input: {
+          prompt:
+            `This is a prompt for an image generation model:
+${req.body.prompt_example}
+
+Please generate a new prompt, where the subject is replaced with the new subject "${req.body.prompt_raw}".
+
+It's important to keep the fine-tuning token identifier.
+
+Just provide the answer, not any extra commentary! And put the answer between "~" symbols.`,
+          system_prompt: LLAMA2_SYSTEM_PROMPT,
+          temperature: 0.01,
+          max_new_tokens: 500,
+          top_p: 1
+        }
+      }
+      try {
+        const llamaResponse = await replicate.run(llamaModel, llamaInputs)
+        renderedPrompt = llamaResponse.join('').match(/~([^~]+)~/)[1]
+        renderedPrompt = renderedPrompt.replace("\\", "").trim()
+      }
+      catch (e) {
+        console.log("error rendering prompt using llama2", e)
+        renderedPrompt = req.body.prompt_raw
+      }
+    }
+    else {
+      renderedPrompt = req.body.prompt_raw;
+    }
+
+    const input = req.body.input;
+    const body = JSON.stringify({
+      input: {
+        prompt: renderedPrompt,
+        ...req.body.default_params,
+        ...input,
+      },
+      version: req.body.version,
+      webhook: `${WEBHOOK_HOST}/api/replicate-webhook?${searchParams}`,
+      webhook_events_filter: ["start", "completed"],
+    });
+
+    const headers = {
+      Authorization: `Token ${process.env.REPLICATE_API_TOKEN}`,
+      "Content-Type": "application/json",
+      "User-Agent": `${packageData.name}/${packageData.version}`,
+    };
+
+    const response = await fetch(`${REPLICATE_API_HOST}/v1/predictions`, {
+      method: "POST",
+      headers,
+      body,
+    });
+
+    if (response.status !== 201) {
+      let error = await response.json();
+      res.statusCode = 500;
+      res.end(JSON.stringify({ detail: error.detail }));
+      return;
+    }
+
+    const prediction = await response.json();
+    res.statusCode = 201;
+    res.end(JSON.stringify(prediction));
+  } else if (req.body.source == "openai") {
+    const response = await openai.createImage({
+      prompt: req.body.prompt,
+      ...req.body.default_params,
+    });
+
+    const prediction = {
+      id: req.body.id,
+      status: "succeeded",
+      version: "dall-e",
+      output: [response.data.data[0].url],
+      input: { prompt: req.body.prompt },
+      model: req.body.model,
+      inserted_at: new Date(),
+      created_at: new Date(),
+      submission_id: req.body.submission_id,
+      source: req.body.source,
+      anon_id: req.body.anon_id,
+    };
+
+    await upsertPrediction(prediction);
+
+    res.statusCode = 201;
+    res.end(JSON.stringify(prediction));
+  } else if (req.body.source == "stability") {
+    const apiKey = process.env.STABILITY_API_KEY;
+    if (!apiKey) throw new Error("Missing Stability API key.");
+
+    const engineId = "stable-diffusion-xl-1024-v0-9";
+    const seed = Math.floor(Math.random() * 1000000);
+    const prompt = req.body.prompt;
+
+    const response = await fetch(
+      `${STABILITY_API_HOST}/v1/generation/${engineId}/text-to-image`,
+      {
+        method: "POST",
+        headers: {
+          "Content-Type": "application/json",
+          Accept: "application/json",
+          Authorization: `Bearer ${apiKey}`,
+        },
+        body: JSON.stringify({
+          text_prompts: [
+            {
+              text: prompt,
+            },
+          ],
+          ...req.body.default_params,
+          seed,
+        }),
+      }
+    );
+
+    // check the HTTP status before consuming the body, so the error path can still read it
+    if (!response.ok) {
+      throw new Error(`Non-200 response: ${await response.text()}`);
+    }
+
+    const responseJSON = await response.json();
+
+    console.log(
+      `data is ${JSON.stringify(Object.keys(responseJSON.artifacts[0]))}`
+    );
+
+    const prediction = {
+      id: req.body.id,
+      status: "succeeded",
+      version: "stability",
+      output: [responseJSON.artifacts[0].base64],
+      input: { prompt: req.body.prompt },
+      model: req.body.model,
+      inserted_at: new Date(),
+      created_at: new Date(),
+      submission_id: req.body.submission_id,
+      source: req.body.source,
+      anon_id: req.body.anon_id,
+      seed: seed,
+    };
+    await upsertPrediction(prediction);
+
+    // ask charlie or jesse about this
+    delete prediction.output;
+
+    res.statusCode = 201;
+    res.end(JSON.stringify(prediction));
+  }
+}
diff --git a/pages/api/sdxlfinetunes/index.js b/pages/api/sdxlfinetunes/index.js
new file mode 100644
index 0000000..ebfc124
--- /dev/null
+++ b/pages/api/sdxlfinetunes/index.js
@@ -0,0 +1,8 @@
+// import Replicate from "replicate";
+import getFineTunes from "../../../lib/sdxlFineTunes"
+
+export default async function handler(req, res) {
+  const models = await getFineTunes()
+
+  res.end(JSON.stringify(models))
+}
\ No newline at end of file
diff --git a/pages/sdxlfinetunes.js b/pages/sdxlfinetunes.js
new file mode 100644
index 0000000..b942d89
--- /dev/null
+++ b/pages/sdxlfinetunes.js
@@ -0,0 +1,556 @@
+import { useState, useEffect } from "react";
+import Prediction from "../components/prediction";
+import Popup from "../components/popup";
+import ZooHead from "../components/zoo-head";
+import ExternalLink from "../components/external-link";
+import promptmaker from "promptmaker";
+import Link from "next/link";
+import { v4 as uuidv4 } from "uuid";
+import { useRouter } from "next/router";
+import Pills from "../components/pills";
+
+const HOST = process.env.VERCEL_URL
+  ? `https://${process.env.VERCEL_URL}`
+  : "http://localhost:3000";
+
+import seeds from "../lib/seeds.js";
+
+const sleep = (ms) => new Promise((r) => setTimeout(r, ms));
+
+export default function Home({ baseUrl, submissionPredictions, availableModels }) {
+  const router = useRouter();
+  const { id } = router.query;
+  const [prompt, setPrompt] = useState("");
+  const [predictions, setPredictions] = useState([]);
+  const [error, setError] = useState(null);
+  const [numOutputs, setNumOutputs] = useState(3);
+  const [firstTime, setFirstTime] = useState(false);
+  const [models, setModels] = useState([]);
+  const [anonId, setAnonId] = useState(null);
+  const [loading, setLoading] = useState(true);
+  const [numRuns, setNumRuns] = useState(1);
+  const [popupOpen, setPopupOpen] = useState(false);
+
+  async function getPredictionsFromSeed(seed) {
+    const response = await fetch(`/api/submissions/${seed}`, {
+      method: "GET",
+    });
+    submissionPredictions = await response.json();
+    setPredictions(submissionPredictions);
+
+    // get the model names from the predictions, and update which ones are checked
+    const modelNames = getModelsFromPredictions(submissionPredictions);
+    updateCheckedModels(modelNames);
+
+    // get the prompt from the predictions, and update the prompt
+    const submissionPrompt = getPromptFromPredictions(submissionPredictions);
+    setPrompt(submissionPrompt);
+    setLoading(false);
+  }
+
+  function getPromptFromPredictions(predictions) {
+    if (predictions.length == 0) {
+      return "";
+    }
+    return predictions[0].input.prompt;
+  }
+
+  function getModelsFromPredictions(predictions) {
+    return predictions.map((p) => p.model);
+  }
+
+  function predictionsStillRunning(predictions) {
+    return predictions.some((p) => p.status != "succeeded");
+  }
+
+  const updateCheckedModels = (modelNames) => {
+    // Create a new array where each model's `checked` value is updated
+    const updatedModels = availableModels.map((model) => {
+      // If the model's name is in the list of names, set `checked` to true, else set it to false
+      return {
+        ...model,
+        checked: modelNames.includes(model.name),
+      };
+    });
+
+    // Update the state with the new array
+    setModels(updatedModels);
+  };
+
+  function getSelectedModels() {
+    return models.filter((m) => m.checked);
+  }
+
+  function getPredictionsByVersion(version) {
+    return predictions.filter((p) => p.version === version);
+  }
+
+  const handleCheckboxChange = (e) => {
+    const modelId = parseInt(e.target.value, 10);
+
+    // Update the checked flag for the model with the matching modelId
+    const updatedModels = models.map((model) => {
+      if (model.id === modelId) {
+        return {
+          ...model,
+          checked: e.target.checked,
+        };
+      }
+      return model;
+    });
+
+    // Set the new models array
+    setModels(updatedModels);
+
+    // save to local storage
+    localStorage.setItem("models", JSON.stringify(updatedModels));
+  };
+
+  // cmd + enter to submit
+  const onKeyDown = (e) => {
+    if (e.metaKey && e.which === 13) {
+      handleSubmit(e, prompt);
+    }
+  };
+
+  function ogParams() {
+    return new URLSearchParams({
+      done: !predictionsStillRunning(predictions),
+      prompt: getPromptFromPredictions(submissionPredictions),
+      ids: submissionPredictions.map((prediction) => prediction.id).join(","),
+    });
+  }
+
+  async function postPrediction(prompt, model, submissionId) {
+    model.default_params.width = 1024
+    model.default_params.height = 1024
+    model.default_params.num_outputs = 1
+    return fetch("/api/predictions/sdxlfinetunes", {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+      },
+      body: JSON.stringify({
+        prompt_raw: prompt,
+        prompt_template: model.prompt_template,
+        prompt_example: model.prompt_example,
+        version: model.version,
+        source: model.source,
+        model: model.name,
+        default_params: model.default_params,
+        anon_id: anonId,
+        submission_id: submissionId,
+        ...(model.source != "replicate" && { id: uuidv4() }),
+        ...(model.source != "replicate" && {
+          created_at: new Date().toISOString(),
+        }),
+      }),
+    });
+  }
+
+  async function createReplicatePrediction(prompt, model, submissionId) {
+    const response = await postPrediction(prompt, model, submissionId);
+    let prediction = await response.json();
+
+    if (response.status !== 201) {
+      throw new Error(prediction.detail);
+    }
+
+    while (
+      prediction.status !== "succeeded" &&
+      prediction.status !== "failed"
+    ) {
+      await sleep(500);
+      const response = await fetch("/api/predictions/" + prediction.id);
+      prediction = await response.json();
+      console.log(prediction);
+      if (response.status !== 200) {
+        throw new Error(prediction.detail);
+      }
+    }
+
+    prediction.model = model.name;
+    prediction.source = model.source;
+
+    return prediction;
+  }
+
+  async function createDallePrediction(prompt, model, submissionId) {
+    const response = await postPrediction(prompt, model, submissionId);
+
+    let prediction = await response.json();
+    prediction.source = model.source;
+    prediction.version = model.version;
+
+    return prediction;
+  }
+
+  const handleSubmit = async (e, prompt) => {
+    e.preventDefault();
+    setError(null);
+    setFirstTime(false);
+
+    // update num runs and save to local storage
+    const newNumRuns = Number(numRuns) + 1;
+    setNumRuns(newNumRuns);
+    localStorage.setItem("numRuns", newNumRuns);
+
+    const hasClosedPopup = localStorage.getItem("hasClosedPopup");
+
+    if (!hasClosedPopup && newNumRuns != 0 && newNumRuns % 10 == 0) {
+      setPopupOpen(true);
+    }
+
+    const submissionId = uuidv4();
+
+    for (const model of getSelectedModels()) {
+      // Use the model variable to generate predictions with the selected model
+      for (let i = 0; i < numOutputs; i++) {
+        let promise = null;
+
+        if (model.source == "replicate") {
+          promise = createReplicatePrediction(prompt, model, submissionId);
+        } else if (model.source == "openai") {
+          promise = createDallePrediction(prompt, model, submissionId);
+        } else if (model.source == "stability") {
+          promise = createDallePrediction(prompt, model, submissionId);
+        }
+
+        promise.model = model.name;
+        promise.source = model.source;
+        promise.version = model.version;
+
+        setPredictions((prev) => [...prev, promise]);
+
+        promise
+          .then((result) => {
+            setPredictions((prev) =>
+              prev.map((x) => (x === promise ? result : x))
+            );
+          })
+          .catch((error) => setError(error.message));
+      }
+    }
+
+    // push router to new page
+    router.query.id = submissionId;
+    router.push(router);
+  };
+
+  function checkOrder(list1, list2) {
+    // Check if both lists are of the same length
+    if (list1.length !== list2.length) {
+      return false;
+    }
+
+    // Check if names are in the same order
+    for (let i = 0; i < list1.length; i++) {
+      if (list1[i].name !== list2[i].name) {
+        return false;
+      }
+    }
+
+    // If we made it here, the names are in the same order
+    return true;
+  }
+
+  useEffect(() => {
+    console.log(
+      submissionPredictions.map((prediction) => prediction.id).join(",")
+    );
+    const anonId = localStorage.getItem("anonId");
+    const storedModels = localStorage.getItem("models");
+    setLoading(true);
+
+    // if the page has an id set
+    if (id) {
+      setPredictions(submissionPredictions);
+
+      // get the model names from the predictions, and update which ones are checked
+      const modelNames = getModelsFromPredictions(submissionPredictions);
+      updateCheckedModels(modelNames);
+
+      // get the prompt from the predictions, and update the prompt
+      const submissionPrompt = getPromptFromPredictions(submissionPredictions);
+      setPrompt(submissionPrompt);
+
+      setLoading(false);
+    } else {
+      // load random seed
+      if (router.isReady) {
+        const seed = seeds[Math.floor(Math.random() * seeds.length)];
+
+        getPredictionsFromSeed(seed);
+        router.query.id = seed;
+        router.push(router);
+      }
+    }
+
+    // get number of runs from local storage
+    const storedNumRuns = localStorage.getItem("numRuns");
+    if (storedNumRuns) {
+      setNumRuns(storedNumRuns);
+    } else {
+      localStorage.setItem("numRuns", numRuns);
+    }
+
+    // setup id
+    if (!anonId) {
+      const uuid = uuidv4();
+      localStorage.setItem("anonId", uuid);
+      setAnonId(uuid);
+      setFirstTime(true);
+    } else {
+      console.log("returning user: ", anonId);
+      setAnonId(anonId);
+    }
+  }, []);
+
+  console.log("predictions: ", predictions);
+
+  return (
+    <div>
+      {/* Header */}
+      <ZooHead
+        ogDescription={
+          submissionPredictions.length > 0
+            ? getPromptFromPredictions(submissionPredictions)
+            : "Compare text-to-image models like Stable Diffusion and DALL-E"
+        }
+        ogImage={`${baseUrl}/api/og?${ogParams()}`}
+      />
+
+      {/* Tabs */}
+      <Pills />
+
+      {/* Welcome message */}
+      <div>
+        <div>
+          <h5>
+            {firstTime && (
+              <span>
+                Welcome to the Zoo, a playground for text to image models.{" "}
+              </span>
+            )}
+            What do you want to see?
+          </h5>
+        </div>
+      </div>
+
+      {/* Prompt input */}
+      <form onSubmit={(e) => handleSubmit(e, prompt)}>
+        {" "}
+