Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add libs click and huggingface_hub #61

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Replace `<path_to_downloaded_directory>` with the path to your model directory.
This repository comes with a simple, composable API, so you can call the model programmatically. A full example is available [here](demos/api_example.py); in outline, it looks like this:

```python
from genmo.mochi_preview.pipelines import (
from src.genmo.mochi_preview.pipelines import (
DecoderModelFactory,
DitModelFactory,
MochiSingleGPUPipeline,
Expand Down
Empty file added demos/__init__.py
Empty file.
6 changes: 3 additions & 3 deletions demos/api_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from pathlib import Path
from textwrap import dedent

from genmo.lib.progress import progress_bar
from genmo.lib.utils import save_video
from genmo.mochi_preview.pipelines import (
from src.genmo.lib.progress import progress_bar
from src.genmo.lib.utils import save_video
from src.genmo.mochi_preview.pipelines import (
DecoderModelFactory,
DitModelFactory,
MochiSingleGPUPipeline,
Expand Down
6 changes: 3 additions & 3 deletions demos/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import numpy as np
import torch

from genmo.lib.progress import progress_bar
from genmo.lib.utils import save_video
from genmo.mochi_preview.pipelines import (
from src.genmo.lib.progress import progress_bar
from src.genmo.lib.utils import save_video
from src.genmo.mochi_preview.pipelines import (
DecoderModelFactory,
DitModelFactory,
MochiMultiGPUPipeline,
Expand Down
5 changes: 3 additions & 2 deletions demos/gradio_ui.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#! /usr/bin/env python


import sys
import sys, os

import click
import gradio as gr

sys.path.append("..")
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from cli import configure_model, generate_video

with gr.Blocks() as demo:
Expand Down
6 changes: 3 additions & 3 deletions demos/test_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from einops import rearrange
from safetensors.torch import load_file

from genmo.lib.utils import save_video
from genmo.mochi_preview.pipelines import DecoderModelFactory, decode_latents_tiled_spatial
from genmo.mochi_preview.vae.models import Encoder, add_fourier_features
from src.genmo.lib.utils import save_video
from src.genmo.mochi_preview.pipelines import DecoderModelFactory, decode_latents_tiled_spatial
from src.genmo.mochi_preview.vae.models import Encoder, add_fourier_features


@click.command()
Expand Down
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ ray>=2.37.0
sentencepiece>=0.2.0
setuptools>=75.2.0
torch>=2.4.1
transformers>=4.45.2
transformers>=4.45.2
click>=8.1.7
huggingface_hub>=0.26.2
gradio>=5.4.0
Empty file added src/__init__.py
Empty file.
Empty file added src/genmo/__init__.py
Empty file.
Empty file added src/genmo/lib/__init__.py
Empty file.
3 changes: 1 addition & 2 deletions src/genmo/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from PIL import Image

from genmo.lib.progress import get_new_progress_bar
from src.genmo.lib.progress import get_new_progress_bar


class Timer:
Expand All @@ -29,7 +29,6 @@ class TimerContextManager:
def __init__(self, outer, name):
self.outer = outer # Reference to the Timer instance
self.name = name
self.start_time = None

def __enter__(self):
self.start_time = time.perf_counter()
Expand Down
16 changes: 8 additions & 8 deletions src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@
from einops import rearrange
from torch.nn.attention import sdpa_kernel

import genmo.mochi_preview.dit.joint_model.context_parallel as cp
from genmo.mochi_preview.dit.joint_model.layers import (
import src.genmo.mochi_preview.dit.joint_model.context_parallel as cp
from src.genmo.mochi_preview.dit.joint_model.layers import (
FeedForward,
PatchEmbed,
RMSNorm,
TimestepEmbedder,
)
from genmo.mochi_preview.dit.joint_model.mod_rmsnorm import modulated_rmsnorm
from genmo.mochi_preview.dit.joint_model.residual_tanh_gated_rmsnorm import (
from src.genmo.mochi_preview.dit.joint_model.mod_rmsnorm import modulated_rmsnorm
from src.genmo.mochi_preview.dit.joint_model.residual_tanh_gated_rmsnorm import (
residual_tanh_gated_rmsnorm,
)
from genmo.mochi_preview.dit.joint_model.rope_mixed import (
from src.genmo.mochi_preview.dit.joint_model.rope_mixed import (
compute_mixed_rotation,
create_position_matrix,
)
from genmo.mochi_preview.dit.joint_model.temporal_rope import apply_rotary_emb_qk_real
from genmo.mochi_preview.dit.joint_model.utils import (
from src.genmo.mochi_preview.dit.joint_model.temporal_rope import apply_rotary_emb_qk_real
from src.genmo.mochi_preview.dit.joint_model.utils import (
AttentionPool,
modulate,
pad_and_split_xy,
Expand All @@ -33,7 +33,7 @@
COMPILE_FINAL_LAYER = os.environ.get("COMPILE_DIT") == "1"
COMPILE_MMDIT_BLOCK = os.environ.get("COMPILE_DIT") == "1"

from genmo.lib.attn_imports import comfy_attn, flash_varlen_qkvpacked_attn, sage_attn, sdpa_attn_ctx
from src.genmo.lib.attn_imports import comfy_attn, flash_varlen_qkvpacked_attn, sage_attn, sdpa_attn_ctx


class AsymmetricAttention(nn.Module):
Expand Down
16 changes: 8 additions & 8 deletions src/genmo/mochi_preview/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Block

import genmo.mochi_preview.dit.joint_model.context_parallel as cp
import genmo.mochi_preview.vae.cp_conv as cp_conv
from genmo.lib.progress import get_new_progress_bar, progress_bar
from genmo.lib.utils import Timer
from genmo.mochi_preview.vae.models import (
import src.genmo.mochi_preview.dit.joint_model.context_parallel as cp
import src.genmo.mochi_preview.vae.cp_conv as cp_conv
from src.genmo.lib.progress import get_new_progress_bar, progress_bar
from src.genmo.lib.utils import Timer
from src.genmo.mochi_preview.vae.models import (
Decoder,
decode_latents,
decode_latents_tiled_full,
decode_latents_tiled_spatial,
)
from genmo.mochi_preview.vae.vae_stats import dit_latents_to_vae_latents
from src.genmo.mochi_preview.vae.vae_stats import dit_latents_to_vae_latents


def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
Expand Down Expand Up @@ -121,7 +121,7 @@ def get_model(self, *, local_rank, device_id, world_size):
class DitModelFactory(ModelFactory):
def __init__(self, *, model_path: str, model_dtype: str, attention_mode: Optional[str] = None):
if attention_mode is None:
from genmo.lib.attn_imports import flash_varlen_qkvpacked_attn # type: ignore
from src.genmo.lib.attn_imports import flash_varlen_qkvpacked_attn # type: ignore

attention_mode = "sdpa" if flash_varlen_qkvpacked_attn is None else "flash"
print(f"Attention mode: {attention_mode}")
Expand All @@ -131,7 +131,7 @@ def __init__(self, *, model_path: str, model_dtype: str, attention_mode: Optiona

def get_model(self, *, local_rank, device_id, world_size):
# TODO(ved): Set flag for torch.compile
from genmo.mochi_preview.dit.joint_model.asymm_models_joint import (
from src.genmo.mochi_preview.dit.joint_model.asymm_models_joint import (
AsymmDiTJoint,
)

Expand Down
2 changes: 1 addition & 1 deletion src/genmo/mochi_preview/vae/cp_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.distributed as dist
import torch.nn.functional as F

import genmo.mochi_preview.dit.joint_model.context_parallel as cp
import src.genmo.mochi_preview.dit.joint_model.context_parallel as cp


def cast_tuple(t, length=1):
Expand Down
10 changes: 5 additions & 5 deletions src/genmo/mochi_preview/vae/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import torch.nn.functional as F
from einops import rearrange

import genmo.mochi_preview.dit.joint_model.context_parallel as cp
from genmo.lib.progress import get_new_progress_bar
from genmo.mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
from genmo.mochi_preview.vae.latent_dist import LatentDistribution
import genmo.mochi_preview.vae.cp_conv as cp_conv
import src.genmo.mochi_preview.dit.joint_model.context_parallel as cp
from src.genmo.lib.progress import get_new_progress_bar
from src.genmo.mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
from src.genmo.mochi_preview.vae.latent_dist import LatentDistribution
import src.genmo.mochi_preview.vae.cp_conv as cp_conv


def cast_tuple(t, length=1):
Expand Down