#!/usr/bin/env python3
"""Smoke-test script for the HunyuanVideo causal-3D VAE.

Loads bf16 VAE weights from a safetensors file onto the first CUDA device,
enables spatial/temporal tiling, and times the decode of a random latent
tensor, printing the input/output shapes and elapsed times.
"""
import json
import time

import safetensors.torch
import torch

from vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D


def load_vae(path: str, compile_vae: bool) -> "AutoencoderKLCausal3D":
    """Build the VAE from its JSON config and load weights from *path*.

    Args:
        path: Filesystem path to the safetensors weight file.
        compile_vae: When True, wrap the model with ``torch.compile``.

    Returns:
        The VAE in eval mode, on CUDA device 0, in bfloat16, with
        gradients disabled.

    Reads ``hy_vae_config.json`` from the current working directory.
    """
    with open("hy_vae_config.json") as f:
        vae_config = json.load(f)

    # torch.device(0).type is the device-type string (e.g. "cuda"), which
    # safetensors accepts as a load target — weights land directly on GPU.
    vae_st = safetensors.torch.load_file(path, device=torch.device(0).type)

    vae = AutoencoderKLCausal3D.from_config(vae_config)
    vae.load_state_dict(vae_st)
    vae.requires_grad_(False)  # inference only — no autograd bookkeeping
    vae = vae.to(torch.device(0), torch.bfloat16)
    vae.eval()
    if compile_vae:
        vae = torch.compile(vae)
    return vae


if __name__ == "__main__":
    # Random stand-in latent: (batch, channels, frames, height, width)
    # per the indexing the VAE's tiled decode expects — TODO confirm axis
    # order against the VAE implementation.
    latents = torch.randn((1, 16, 19, 120, 68)).to(torch.device(0), torch.bfloat16)
    print(f"Latent dims: {latents.size()}")

    print("load vae")
    start = time.perf_counter()
    vae = load_vae("hunyuan_video_vae_bf16.safetensors", False)
    print(f"loaded vae in {time.perf_counter() - start} seconds")

    # Tiling parameters keep peak memory bounded during decode.
    vae.t_tile_overlap_factor = 0.25
    vae.tile_latent_min_tsize = 16
    vae.tile_sample_min_size = 256
    vae.tile_latent_min_size = 32
    vae.enable_tiling()

    print("decoding")  # fixed typo ("decodeing")
    start = time.perf_counter()
    generator = torch.Generator(device=torch.device("cpu"))
    decoded = vae.decode(latents, return_dict=False, generator=generator)[0]
    print(f"decoded in {time.perf_counter() - start} seconds")
    print(f"decoded dims: {decoded.size()}")