#!/usr/bin/env python3

import json
import time

import torch
import safetensors.torch

from vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D

def load_vae(path: str, compile_vae: bool):
    # The VAE architecture/config is read from a JSON file next to this script.
    with open("hy_vae_config.json") as f:
        vae_config = json.load(f)

    # Load the safetensors weights directly onto the GPU.
    vae_st = safetensors.torch.load_file(path, device=torch.device(0).type)

    # Build the causal 3D VAE from its config, load the weights, and put it
    # into inference mode in bfloat16 on the GPU.
    vae = AutoencoderKLCausal3D.from_config(vae_config)
    vae.load_state_dict(vae_st)
    vae.requires_grad_(False)
    vae = vae.to(torch.device(0), torch.bfloat16)
    vae.eval()

    if compile_vae:
        vae = torch.compile(vae)
    return vae


if __name__ == "__main__":
    # Random latents shaped (batch, channels, frames, height, width).
    latents = torch.randn((1, 16, 19, 120, 68)).to(torch.device(0), torch.bfloat16)
    print(f"Latent dims: {latents.size()}")

    print("load vae")
    start = time.perf_counter()
    vae = load_vae("hunyuan_video_vae_bf16.safetensors", False)
    print(f"loaded vae in {time.perf_counter() - start} seconds")

    # Tiled decoding keeps peak VRAM usage down by decoding the latents in
    # overlapping temporal/spatial tiles.
    vae.t_tile_overlap_factor = 0.25
    vae.tile_latent_min_tsize = 16
    vae.tile_sample_min_size = 256
    vae.tile_latent_min_size = 32
    vae.enable_tiling()

    print("decoding")
    start = time.perf_counter()
    generator = torch.Generator(device=torch.device("cpu"))
    decoded = vae.decode(
        latents, return_dict=False, generator=generator
    )[0]
    print(f"decoded in {time.perf_counter() - start} seconds")
    print(f"decoded dims: {decoded.size()}")