HyDecodeRepo/run.py


#!/usr/bin/env python3
import torch
import safetensors.torch
import json
from vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
import time
from torch.profiler import profile, record_function, ProfilerActivity


def load_vae(path: str, compile_vae: bool):
    """Load the HunyuanVideo causal 3D VAE from a safetensors checkpoint onto GPU 0 in bf16."""
    with open("hy_vae_config.json") as f:
        vae_config = json.load(f)
    # Load the weights directly onto the GPU instead of staging them through CPU memory.
    vae_st = safetensors.torch.load_file(path, device=torch.device(0).type)
    vae = AutoencoderKLCausal3D.from_config(vae_config)
    vae.load_state_dict(vae_st)
    vae.requires_grad_(False)
    vae = vae.to(torch.device(0), torch.bfloat16)
    vae.eval()
    if compile_vae:
        vae = torch.compile(vae)
    return vae


if __name__ == "__main__":
    props = torch.cuda.get_device_properties(0)
    print(f"Device: {props.name}")
    # Dummy latents in the VAE's layout: (batch, channels, frames, height, width).
    latents = torch.randn((1, 16, 19, 120, 68)).to(torch.device(0), torch.bfloat16)
    print(f"Latent dims: {latents.size()}")
    print("load vae")
    start = time.perf_counter()
    vae = load_vae("hunyuan_video_vae_bf16.safetensors", False)
    print(f"loaded vae in {time.perf_counter() - start} seconds")
    # Tiled decoding settings: decode the latents in overlapping temporal/spatial tiles
    # to keep peak GPU memory bounded.
    vae.t_tile_overlap_factor = 0.25
    vae.tile_latent_min_tsize = 16
    vae.tile_sample_min_size = 256
    vae.tile_latent_min_size = 32
    vae.enable_tiling()
    print("decoding")
    with profile(
        activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
        record_shapes=True,
        with_flops=True,
    ) as prof:
        start = time.perf_counter()
        generator = torch.Generator(device=torch.device("cpu"))
        decoded = vae.decode(latents, return_dict=False, generator=generator)[0]
        torch.cuda.synchronize()  # wait for the asynchronous GPU work before reading the timer
        print(f"decoded in {time.perf_counter() - start} seconds")
        print(f"decoded dims: {decoded.size()}")
    print(prof.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=100))
    prof.export_chrome_trace("trace.json")