import subprocess
import modal

# Assemble the ComfyUI container image layer by layer. Each Image method returns
# a new Image, so reassigning `image` builds the same chain of layers:
#   1. a minimal Debian base with Python 3.11,
#   2. git (comfy-cli needs it to clone ComfyUI),
#   3. the FastAPI web stack,
#   4. comfy-cli itself,
#   5. ComfyUI 0.3.47 plus its NVIDIA dependencies, installed by comfy-cli.
image = modal.Image.debian_slim(python_version="3.11")
image = image.apt_install("git")
image = image.pip_install("fastapi[standard]==0.115.4")
image = image.pip_install("comfy-cli==1.4.1")
image = image.run_commands(
    "comfy --skip-prompt install --fast-deps --nvidia --version 0.3.47"
)

# ## Downloading custom nodes
# Add custom nodes to the image. WAS Node Suite comes from the ComfyUI registry
# via comfy-cli; NAG, KJNodes and essentials are cloned from GitHub straight into
# ComfyUI's custom_nodes directory.
#
# Each command is executed by /bin/sh, so URLs must be written bare: wrapping
# them in "<...>" would be parsed as shell input/output redirections and
# `git clone` would run without a repository argument.
#
# NOTE(review): registry spec matches Modal's upstream ComfyUI example
# (was-node-suite-comfyui@1.0.2) — confirm this is the intended version.
image = image.run_commands(
    "comfy node install --fast-deps was-node-suite-comfyui@1.0.2",
    "git clone https://github.com/ChenDarYen/ComfyUI-NAG.git /root/comfy/ComfyUI/custom_nodes/ComfyUI-NAG",
    "git clone https://github.com/kijai/ComfyUI-KJNodes.git /root/comfy/ComfyUI/custom_nodes/ComfyUI-KJNodes",
    "git clone https://github.com/cubiq/ComfyUI_essentials.git /root/comfy/ComfyUI/custom_nodes/ComfyUI_essentials",
)

def hf_download():
    """Download the Wan 2.2 model files and expose them to ComfyUI.

    Each file is fetched with ``hf_hub_download`` into the Hugging Face cache
    rooted at ``/cache``. A symlink is then created at
    ``/root/comfy/ComfyUI/models/<subdir>/<file name>`` pointing at the cached
    copy, so ComfyUI finds the weights in its usual model folders.

    Raises:
        subprocess.CalledProcessError: if creating a symlink fails (e.g. the
            link already exists), because ``check=True`` is passed.
    """
    from huggingface_hub import hf_hub_download

    comfy_models_dir = "/root/comfy/ComfyUI/models"

    # (Hugging Face repo, file path inside the repo, ComfyUI model subfolder)
    models = [
        (
            "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
            "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors",
            "diffusion_models",
        ),
        (
            "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
            "split_files/vae/wan2.2_vae.safetensors",
            "vae",
        ),
        (
            # The UMT5 text encoder is published in the Wan 2.1 repackaged repo.
            "Comfy-Org/Wan_2.1_ComfyUI_repackaged",
            "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
            "text_encoders",
        ),
    ]

    for repo_id, filename, subdir in models:
        cached_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir="/cache",
        )
        # The link keeps the file's own name, dropped into ComfyUI's subfolder.
        link_path = f"{comfy_models_dir}/{subdir}/{filename.rsplit('/', 1)[-1]}"
        # Pass an argv list instead of a shell string, so paths are never
        # re-parsed by a shell.
        subprocess.run(["ln", "-s", cached_path, link_path], check=True)

# Named Modal Volume holding the Hugging Face cache; it is created on first use
# and mounted at /cache both at build time and at run time.
vol = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)

# Add huggingface_hub with the hf_transfer extra and switch the faster
# transfer backend on through the environment.
image = image.pip_install("huggingface_hub[hf_transfer]>=0.34.0,<1.0")
image = image.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
# Run hf_download during the image build. Because the cache Volume is mounted,
# the downloaded weights persist and later builds reuse them.
image = image.run_function(hf_download, volumes={"/cache": vol})

# ## Running ComfyUI interactively

# Spin up an interactive ComfyUI server by wrapping the `comfy launch` command in a Modal Function
# and serving it as a [web server](<https://modal.com/docs/guide/webhooks#non-asgi-web-servers>).

app = modal.App(name="example-comfyui", image=image)

# Port ComfyUI listens on. The same value is given to `modal.web_server` and to
# `comfy launch`; keeping it in one place stops the two from drifting apart.
COMFYUI_PORT = 8000


@app.function(
    max_containers=1,  # limit interactive session to 1 container
    gpu="L40S",  # good starter GPU for inference
    volumes={"/cache": vol},  # mounts the cached models the symlinks point to
)
@modal.concurrent(
    max_inputs=10
)  # required for UI startup process which runs several API calls concurrently
@modal.web_server(COMFYUI_PORT, startup_timeout=60)
def ui():
    """Start the ComfyUI server in the background on ``COMFYUI_PORT``.

    The process is started with ``Popen`` and not waited on: the function only
    needs to launch the server listening on all interfaces, and the
    ``web_server`` decorator serves whatever is listening on that port.
    """
    # argv list: no intermediate shell is needed to launch comfy-cli.
    subprocess.Popen(
        ["comfy", "launch", "--", "--listen", "0.0.0.0", "--port", str(COMFYUI_PORT)]
    )