#!/bin/python3
"""Benchmark and profile a tiled VAE decode on CUDA device 0.

The script loads the VAE weights from ``hunyuan_video_vae_bf16.safetensors``
(model config from ``hy_vae_config.json``), decodes one random latent tensor
under ``torch.profiler``, prints wall-clock timings and a per-op table sorted
by total CUDA time, and writes a Chrome trace to ``trace.json``.
"""
import json
import time

import safetensors.torch
import torch
from torch.profiler import ProfilerActivity, profile, record_function

from vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D


def load_vae(path: str, compile_vae: bool, config_path: str = "hy_vae_config.json"):
    """Build the VAE from its JSON config and load weights onto device 0.

    Args:
        path: Path to the ``.safetensors`` file holding the VAE state dict.
        compile_vae: When true, wrap the model with ``torch.compile`` before
            returning it.
        config_path: Path to the JSON config passed to
            ``AutoencoderKLCausal3D.from_config``. Defaults to the file name
            the script has always used.

    Returns:
        The model in eval mode, with gradients disabled, moved to device 0 in
        ``bfloat16`` (and compiled if requested).
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(config_path, encoding="utf-8") as f:
        vae_config = json.load(f)

    device = torch.device(0)
    # safetensors takes a device *string*, hence ``.type`` rather than the
    # torch.device object itself.
    vae_st = safetensors.torch.load_file(path, device=device.type)

    vae = AutoencoderKLCausal3D.from_config(vae_config)
    vae.load_state_dict(vae_st)
    vae.requires_grad_(False)
    vae = vae.to(device, torch.bfloat16)
    vae.eval()

    if compile_vae:
        vae = torch.compile(vae)
    return vae


def main() -> None:
    """Load the VAE, decode random latents under the profiler, and report."""
    props = torch.cuda.get_device_properties(0)
    print(f"Device: {props.name}")

    latents = torch.randn((1, 16, 19, 120, 68)).to(torch.device(0), torch.bfloat16)
    print(f"Latent dims: {latents.size()}")

    print("load vae")
    start = time.perf_counter()
    vae = load_vae("hunyuan_video_vae_bf16.safetensors", False)
    # CUDA work is queued asynchronously; wait for it so the wall-clock
    # reading covers the completed transfers, not just their launch.
    torch.cuda.synchronize()
    print(f"loaded vae in {time.perf_counter() - start} seconds")

    # Tiling parameters set on the model before enabling tiling. Their exact
    # meaning is defined by AutoencoderKLCausal3D (not visible in this file).
    vae.t_tile_overlap_factor = 0.25
    vae.tile_latent_min_tsize = 16
    vae.tile_sample_min_size = 256
    vae.tile_latent_min_size = 32
    vae.enable_tiling()

    print("decodeing")
    with profile(
        activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
        record_shapes=True,
        with_flops=True,
    ) as prof:
        start = time.perf_counter()
        generator = torch.Generator(device=torch.device("cpu"))
        decoded = vae.decode(latents, return_dict=False, generator=generator)[0]
        # Without this, perf_counter may be read while kernels are still
        # running and the reported decode time would be too small.
        torch.cuda.synchronize()
        print(f"decoded in {time.perf_counter() - start} seconds")
        print(f"decoded dims: {decoded.size()}")

    # Profiler results are read after the context manager has exited.
    print(
        prof.key_averages(group_by_input_shape=True).table(
            sort_by="cuda_time_total", row_limit=100
        )
    )
    prof.export_chrome_trace("trace.json")


if __name__ == "__main__":
    main()