diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index ef4b186b5..2041fc0c2 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -8,6 +8,7 @@ import gooey_gui as gui from bots.models import Workflow +from gooey_ui.components.modal import Modal from daras_ai_v2.base import BasePage from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.exceptions import UserError @@ -26,9 +27,13 @@ class AnimationModels(TextChoices): class _AnimationPrompt(TypedDict): frame: str prompt: str + second: float AnimationPrompts = list[_AnimationPrompt] +ZoomSettings: dict[int, float] = {0: 1.004} +HPanSettings: dict[int, float] = {0: 0} +VPanSettings: dict[int, float] = {0: 0} CREDITS_PER_FRAME = 1.5 MODEL_ESTIMATED_TIME_PER_FRAME = 2.4 # seconds @@ -38,26 +43,46 @@ def input_prompt_to_animation_prompts(input_prompt: str): animation_prompts = [] for fp in input_prompt.split("|"): split = fp.split(":") - if len(split) == 2: + if len(split) == 3: frame = int(split[0]) prompt = split[1].strip() + second = float(split[2]) else: frame = 0 prompt = fp - animation_prompts.append({"frame": frame, "prompt": prompt}) + second = 0 + animation_prompts.append({"frame": frame, "prompt": prompt, "second": second}) return animation_prompts def animation_prompts_to_st_list(animation_prompts: AnimationPrompts): - return [ - {"frame": fp["frame"], "prompt": fp["prompt"], "key": str(uuid.uuid1())} - for fp in animation_prompts - ] + if "second" in animation_prompts[0]: + return [ + { + "frame": fp["frame"], + "prompt": fp["prompt"], + "second": fp["second"], + "key": str(uuid.uuid1()), + } + for fp in animation_prompts + ] + else: + return [ + { + "frame": fp["frame"], + "prompt": fp["prompt"], + "second": frames_to_seconds( + int(fp["frame"]), st.session_state.get("fps", 12) + ), + "key": str(uuid.uuid1()), + } + for fp in animation_prompts + ] def st_list_to_animation_prompt(prompt_st_list) -> AnimationPrompts: return [ - {"frame": fp["frame"], "prompt": prompt} + {"frame": fp["frame"], "prompt": prompt, "second": fp["second"]} for fp in prompt_st_list if (prompt := fp["prompt"].strip()) ] @@ -86,38 +111,156 @@ def animation_prompts_editor( View the ‘Details’ drop down menu to get started. """ ) + st.write("#### Step 1: Draft & Refine Keyframes") updated_st_list = [] + col1, col2, col3 = st.columns([2, 9, 2], responsive=False) + max_seconds = st.session_state.get("max_seconds", 10) + with col1: + st.write("Second") + with col2: + st.write("Prompt") + with col3: + st.write("Camera") for idx, fp in enumerate(prompt_st_list): fp_key = fp["key"] frame_key = f"{st_list_key}/frame/{fp_key}" + second_key = f"{st_list_key}/seconds/{fp_key}" prompt_key = f"{st_list_key}/prompt/{fp_key}" - if frame_key not in gui.session_state: - gui.session_state[frame_key] = fp["frame"] - if prompt_key not in gui.session_state: - gui.session_state[prompt_key] = fp["prompt"] + if second_key not in st.session_state: + st.session_state[second_key] = fp["second"] + st.session_state[frame_key] = seconds_to_frames( + st.session_state[second_key], st.session_state.get("fps", 12) + ) + if prompt_key not in st.session_state: + st.session_state[prompt_key] = fp["prompt"] - col1, col2 = gui.columns([8, 3], responsive=False) + col1, col2, col3 = st.columns( + [2, 9, 2], responsive=False, style={"text-align": "center;"} + ) + fps = st.session_state.get("fps", 12) + max_seconds = st.session_state.get("max_seconds", 10) + start = fp["second"] + end = ( + prompt_st_list[idx + 1]["second"] + if idx + 1 < len(prompt_st_list) + else max_seconds + ) with col1: - gui.text_area( - label="*Prompt*", - key=prompt_key, - height=100, - ) - with col2: - gui.number_input( - label="*Frame*", - key=frame_key, + st.number_input( + label="", + key=second_key, min_value=0, - step=1, + step=0.1, + className="gui-input-smaller", ) - if gui.button("🗑️", help=f"Remove Frame {idx + 1}"): + if idx != 0 and st.button( + "🗑️", + help=f"Remove Frame {idx + 1}", + type="tertiary", + style={"float": "left;"}, + ): prompt_st_list.pop(idx) - gui.rerun() - + st.experimental_rerun() + if st.button( + '', + help=f"Insert Frame after Frame {idx + 1}", + type="tertiary", + style={"float": "left;"}, + ): + next_second = round((start + end) / 2, 2) + if next_second > max_seconds: + st.error("Please increase Frame Count") + else: + prompt_st_list.insert( + idx + 1, + { + "frame": seconds_to_frames(next_second, fps), + "prompt": prompt_st_list[idx]["prompt"], + "second": next_second, + "key": str(uuid.uuid1()), + }, + ) + st.experimental_rerun() + with col2: + st.text_area( + label="", + key=prompt_key, + height=100, + ) + with col3: + zoom_pan_modal = Modal("Zoom/Pan", key="modal-" + fp_key) + zoom_value = ZoomSettings.get(fp["frame"]) + hpan_value = HPanSettings.get(fp["frame"]) + vpan_value = VPanSettings.get(fp["frame"]) + zoom_pan_description = "" + if zoom_value: + zoom_pan_description = "Out: " if zoom_value > 1 else "In: " + zoom_pan_description += f"{round(zoom_value, 3)}\n" + if hpan_value: + zoom_pan_description += "Right: " if hpan_value > 1 else "Left: " + zoom_pan_description += f"{round(hpan_value, 3)}\n" + if vpan_value: + zoom_pan_description += "Up: " if vpan_value > 1 else "Down: " + zoom_pan_description += f"{round(vpan_value, 3)}" + if not zoom_pan_description: + zoom_pan_description = '' + if st.button( + zoom_pan_description, + key="button-" + fp_key, + type="link", + ): + zoom_pan_modal.open() + if zoom_pan_modal.is_open(): + with zoom_pan_modal.container(): + st.write( + f"#### Keyframe second {start} until {end}", + ) + st.caption( + f"Starting at second {start} and until second {end}, how do you want the camera to move? (Reasonable valuables would be ±0.005)" + ) + zoom_pan_slider = st.slider( + label=""" + #### Zoom + """, + min_value=-1.5, + max_value=1.5, + step=0.001, + value=0, + ) + hpan_slider = st.slider( + label=""" + #### Horizontal Pan + """, + min_value=-1.5, + max_value=1.5, + step=0.001, + value=0, + ) + vpan_slider = st.slider( + label=""" + #### Vertical Pan + """, + min_value=-1.5, + max_value=1.5, + step=0.001, + value=0, + ) + if st.button("Save"): + ZoomSettings.update({fp["frame"]: 1 + zoom_pan_slider}) + HPanSettings.update({fp["frame"]: hpan_slider}) + VPanSettings.update({fp["frame"]: vpan_slider}) + st.session_state["zoom"] = zoom_pan_to_string(ZoomSettings) + st.session_state["translation_x"] = zoom_pan_to_string( + HPanSettings + ) + st.session_state["translation_y"] = zoom_pan_to_string( + VPanSettings + ) updated_st_list.append( { - "frame": gui.session_state.get(frame_key), - "prompt": gui.session_state.get(prompt_key), + "frame": st.session_state.get(frame_key), + "prompt": st.session_state.get(prompt_key), + "second": st.session_state.get(second_key), "key": fp_key, } ) @@ -158,6 +301,18 @@ def get_last_frame(prompt_list: list) -> int: return max(fp["frame"] for fp in prompt_list) +def frames_to_seconds(frames: int, fps: int) -> float: + return round(frames / int(fps), 2) + + +def seconds_to_frames(seconds: float, fps: int) -> int: + return int(seconds * int(fps)) + + +def zoom_pan_to_string(zoom_dict: dict[int, float]) -> str: + return ", ".join([f"{frame}:({zoom})" for frame, zoom in zoom_dict.items()]) + + DEFAULT_ANIMATION_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/assets/cropped_animation_meta.gif" @@ -180,7 +335,9 @@ class DeforumSDPage(BasePage): selected_model=AnimationModels.protogen_2_2.name, ) - class RequestModel(BasePage.RequestModel): + ZoomSettings = {0: 1.004} + + class RequestModel(BaseModel): # input_prompt: str animation_prompts: AnimationPrompts max_frames: int | None @@ -238,14 +395,32 @@ def render_form_v2(self): """ ) + with col2: + st.write("*End of Video*") + + st.write("#### Step 2: Increase Animation Quality") + st.write( + "Once you like your keyframes, increase your frames per second for high quality" + ) + st.custom_radio( + """###### FPS (Frames per second)""", + options=[2, 10, 24], + format_func=lambda x: { + "2": "Draft: 2 FPS", + "10": "Stop-motion: 10 FPS", + "24": "Film: 24 FPS", + }[str(x)], + key="fps", + ) + def get_cost_note(self) -> str | None: - return f"{CREDITS_PER_FRAME} / frame" + return f"{st.session_state.get('max_frames')} frames @ {CREDITS_PER_FRAME} Cr /frame" def additional_notes(self) -> str | None: return "Render Time ≈ 3s / frame" def get_raw_price(self, state: dict) -> float: - max_frames = state.get("max_frames", 100) or 0 + max_frames = state.get("max_frames", 100) return max_frames * CREDITS_PER_FRAME def validate_form_v2(self):