diff --git a/README.md b/README.md index 0696423..d839f72 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ Replace `` with the path to your model directory. This repository comes with a simple, composable API, so you can programmatically call the model. You can find a full example [here](demos/api_example.py). But, roughly, it looks like this: ```python -from genmo.mochi_preview.pipelines import ( +from src.genmo.mochi_preview.pipelines import ( DecoderModelFactory, DitModelFactory, MochiSingleGPUPipeline, diff --git a/demos/__init__.py b/demos/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/demos/api_example.py b/demos/api_example.py index 27d0d9e..be8deb1 100755 --- a/demos/api_example.py +++ b/demos/api_example.py @@ -3,9 +3,9 @@ from pathlib import Path from textwrap import dedent -from genmo.lib.progress import progress_bar -from genmo.lib.utils import save_video -from genmo.mochi_preview.pipelines import ( +from src.genmo.lib.progress import progress_bar +from src.genmo.lib.utils import save_video +from src.genmo.mochi_preview.pipelines import ( DecoderModelFactory, DitModelFactory, MochiSingleGPUPipeline, diff --git a/demos/cli.py b/demos/cli.py index 76ebd28..29e297e 100755 --- a/demos/cli.py +++ b/demos/cli.py @@ -7,9 +7,9 @@ import numpy as np import torch -from genmo.lib.progress import progress_bar -from genmo.lib.utils import save_video -from genmo.mochi_preview.pipelines import ( +from src.genmo.lib.progress import progress_bar +from src.genmo.lib.utils import save_video +from src.genmo.mochi_preview.pipelines import ( DecoderModelFactory, DitModelFactory, MochiMultiGPUPipeline, diff --git a/demos/gradio_ui.py b/demos/gradio_ui.py index b750e9a..92a0b18 100755 --- a/demos/gradio_ui.py +++ b/demos/gradio_ui.py @@ -1,12 +1,13 @@ #! /usr/bin/env python -import sys +import sys, os import click import gradio as gr -sys.path.append("..") +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from cli import configure_model, generate_video with gr.Blocks() as demo: diff --git a/demos/test_encoder_decoder.py b/demos/test_encoder_decoder.py index 6401908..0b5def8 100644 --- a/demos/test_encoder_decoder.py +++ b/demos/test_encoder_decoder.py @@ -6,9 +6,9 @@ from einops import rearrange from safetensors.torch import load_file -from genmo.lib.utils import save_video -from genmo.mochi_preview.pipelines import DecoderModelFactory, decode_latents_tiled_spatial -from genmo.mochi_preview.vae.models import Encoder, add_fourier_features +from src.genmo.lib.utils import save_video +from src.genmo.mochi_preview.pipelines import DecoderModelFactory, decode_latents_tiled_spatial +from src.genmo.mochi_preview.vae.models import Encoder, add_fourier_features @click.command() diff --git a/requirements.txt b/requirements.txt index 005e447..182d56e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,7 @@ ray>=2.37.0 sentencepiece>=0.2.0 setuptools>=75.2.0 torch>=2.4.1 -transformers>=4.45.2 \ No newline at end of file +transformers>=4.45.2 +click>=8.1.7 +huggingface_hub>=0.26.2 +gradio>=5.4.0 \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/genmo/__init__.py b/src/genmo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/genmo/lib/__init__.py b/src/genmo/lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/genmo/lib/utils.py b/src/genmo/lib/utils.py index 0567c19..0939237 100644 --- a/src/genmo/lib/utils.py +++ b/src/genmo/lib/utils.py @@ -6,7 +6,7 @@ import numpy as np from PIL import Image -from genmo.lib.progress import get_new_progress_bar +from src.genmo.lib.progress import get_new_progress_bar class Timer: @@ -29,7 +29,6 @@ class TimerContextManager: def __init__(self, outer, name): self.outer = outer # Reference to the Timer instance self.name = name - self.start_time = None def __enter__(self): self.start_time = time.perf_counter() diff --git a/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py b/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py index d75946e..9d16be0 100644 --- a/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py +++ b/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py @@ -7,23 +7,23 @@ from einops import rearrange from torch.nn.attention import sdpa_kernel -import genmo.mochi_preview.dit.joint_model.context_parallel as cp -from genmo.mochi_preview.dit.joint_model.layers import ( +import src.genmo.mochi_preview.dit.joint_model.context_parallel as cp +from src.genmo.mochi_preview.dit.joint_model.layers import ( FeedForward, PatchEmbed, RMSNorm, TimestepEmbedder, ) -from genmo.mochi_preview.dit.joint_model.mod_rmsnorm import modulated_rmsnorm -from genmo.mochi_preview.dit.joint_model.residual_tanh_gated_rmsnorm import ( +from src.genmo.mochi_preview.dit.joint_model.mod_rmsnorm import modulated_rmsnorm +from src.genmo.mochi_preview.dit.joint_model.residual_tanh_gated_rmsnorm import ( residual_tanh_gated_rmsnorm, ) -from genmo.mochi_preview.dit.joint_model.rope_mixed import ( +from src.genmo.mochi_preview.dit.joint_model.rope_mixed import ( compute_mixed_rotation, create_position_matrix, ) -from genmo.mochi_preview.dit.joint_model.temporal_rope import apply_rotary_emb_qk_real -from genmo.mochi_preview.dit.joint_model.utils import ( +from src.genmo.mochi_preview.dit.joint_model.temporal_rope import apply_rotary_emb_qk_real +from src.genmo.mochi_preview.dit.joint_model.utils import ( AttentionPool, modulate, pad_and_split_xy, @@ -33,7 +33,7 @@ COMPILE_FINAL_LAYER = os.environ.get("COMPILE_DIT") == "1" COMPILE_MMDIT_BLOCK = os.environ.get("COMPILE_DIT") == "1" -from genmo.lib.attn_imports import comfy_attn, flash_varlen_qkvpacked_attn, sage_attn, sdpa_attn_ctx +from src.genmo.lib.attn_imports import comfy_attn, flash_varlen_qkvpacked_attn, sage_attn, sdpa_attn_ctx class AsymmetricAttention(nn.Module): diff --git a/src/genmo/mochi_preview/pipelines.py b/src/genmo/mochi_preview/pipelines.py index c51e926..5aaf0fd 100644 --- a/src/genmo/mochi_preview/pipelines.py +++ b/src/genmo/mochi_preview/pipelines.py @@ -30,17 +30,17 @@ from transformers import T5EncoderModel, T5Tokenizer from transformers.models.t5.modeling_t5 import T5Block -import genmo.mochi_preview.dit.joint_model.context_parallel as cp -import genmo.mochi_preview.vae.cp_conv as cp_conv -from genmo.lib.progress import get_new_progress_bar, progress_bar -from genmo.lib.utils import Timer -from genmo.mochi_preview.vae.models import ( +import src.genmo.mochi_preview.dit.joint_model.context_parallel as cp +import src.genmo.mochi_preview.vae.cp_conv as cp_conv +from src.genmo.lib.progress import get_new_progress_bar, progress_bar +from src.genmo.lib.utils import Timer +from src.genmo.mochi_preview.vae.models import ( Decoder, decode_latents, decode_latents_tiled_full, decode_latents_tiled_spatial, ) -from genmo.mochi_preview.vae.vae_stats import dit_latents_to_vae_latents +from src.genmo.mochi_preview.vae.vae_stats import dit_latents_to_vae_latents def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): @@ -121,7 +121,7 @@ def get_model(self, *, local_rank, device_id, world_size): class DitModelFactory(ModelFactory): def __init__(self, *, model_path: str, model_dtype: str, attention_mode: Optional[str] = None): if attention_mode is None: - from genmo.lib.attn_imports import flash_varlen_qkvpacked_attn # type: ignore + from src.genmo.lib.attn_imports import flash_varlen_qkvpacked_attn # type: ignore attention_mode = "sdpa" if flash_varlen_qkvpacked_attn is None else "flash" print(f"Attention mode: {attention_mode}") @@ -131,7 +131,7 @@ def __init__(self, *, model_path: str, model_dtype: str, attention_mode: Optiona def get_model(self, *, local_rank, device_id, world_size): # TODO(ved): Set flag for torch.compile - from genmo.mochi_preview.dit.joint_model.asymm_models_joint import ( + from src.genmo.mochi_preview.dit.joint_model.asymm_models_joint import ( AsymmDiTJoint, ) diff --git a/src/genmo/mochi_preview/vae/cp_conv.py b/src/genmo/mochi_preview/vae/cp_conv.py index aeab1f5..def0f51 100644 --- a/src/genmo/mochi_preview/vae/cp_conv.py +++ b/src/genmo/mochi_preview/vae/cp_conv.py @@ -4,7 +4,7 @@ import torch.distributed as dist import torch.nn.functional as F -import genmo.mochi_preview.dit.joint_model.context_parallel as cp +import src.genmo.mochi_preview.dit.joint_model.context_parallel as cp def cast_tuple(t, length=1): diff --git a/src/genmo/mochi_preview/vae/models.py b/src/genmo/mochi_preview/vae/models.py index 33681d5..c66e5ee 100644 --- a/src/genmo/mochi_preview/vae/models.py +++ b/src/genmo/mochi_preview/vae/models.py @@ -6,11 +6,11 @@ import torch.nn.functional as F from einops import rearrange -import genmo.mochi_preview.dit.joint_model.context_parallel as cp -from genmo.lib.progress import get_new_progress_bar -from genmo.mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames -from genmo.mochi_preview.vae.latent_dist import LatentDistribution -import genmo.mochi_preview.vae.cp_conv as cp_conv +import src.genmo.mochi_preview.dit.joint_model.context_parallel as cp +from src.genmo.lib.progress import get_new_progress_bar +from src.genmo.mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames +from src.genmo.mochi_preview.vae.latent_dist import LatentDistribution +import src.genmo.mochi_preview.vae.cp_conv as cp_conv def cast_tuple(t, length=1):