50 lines
		
	
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			50 lines
		
	
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/bin/python3
 | |
| 
 | |
| import torch
 | |
| import safetensors.torch
 | |
| import json
 | |
| from vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
 | |
| import time
 | |
| 
 | |
| 
 | |
def load_vae(path: str, compile_vae: bool, config_path: str = "hy_vae_config.json"):
	"""Load the causal-3D video VAE from a safetensors checkpoint.

	Args:
		path: Path to the safetensors weight file.
		compile_vae: If True, wrap the model with ``torch.compile``.
		config_path: Path to the JSON model-config file. Defaults to the
			previously hard-coded ``"hy_vae_config.json"`` so existing
			callers are unaffected.

	Returns:
		The VAE on device 0, in bfloat16, in eval mode, with gradients
		disabled (optionally compiled).
	"""
	with open(config_path) as f:
		vae_config = json.load(f)

	# torch.device(0).type yields the accelerator type string (e.g. "cuda"),
	# so the tensors are materialized directly on the accelerator.
	vae_st = safetensors.torch.load_file(path, device=torch.device(0).type)

	vae = AutoencoderKLCausal3D.from_config(vae_config)
	vae.load_state_dict(vae_st)
	# Inference-only: no grads, bf16 weights on device 0.
	vae.requires_grad_(False)
	vae = vae.to(torch.device(0), torch.bfloat16)
	vae.eval()

	if compile_vae:
		vae = torch.compile(vae)
	return vae
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
	# Random dummy latents; presumably (batch, channels, frames, height, width)
	# — TODO confirm axis order against the VAE implementation.
	latents = torch.randn((1, 16, 19, 120, 68)).to(torch.device(0), torch.bfloat16)
	print(f"Latent dims: {latents.size()}")

	print("load vae")
	start = time.perf_counter()
	vae = load_vae("hunyuan_video_vae_bf16.safetensors", False)
	print(f"loaded vae in {time.perf_counter() - start} seconds")

	# Configure tiled decoding; tiling bounds peak memory when decoding
	# large/long latent volumes.
	vae.t_tile_overlap_factor = 0.25
	vae.tile_latent_min_tsize = 16
	vae.tile_sample_min_size = 256
	vae.tile_latent_min_size = 32
	vae.enable_tiling()

	print("decoding")  # fixed typo: was "decodeing"
	start = time.perf_counter()
	generator = torch.Generator(device=torch.device("cpu"))
	decoded = vae.decode(
		latents, return_dict=False, generator=generator
	)[0]
	print(f"decoded in {time.perf_counter() - start} seconds")
	print(f"decoded dims: {decoded.size()}")
| 
 |