From 54a585f2b9f48913b2cc618ae1e8e2e699ff80a0 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 12 Feb 2025 16:29:43 +0100 Subject: [PATCH] add profiling and small conv3d reproducer --- conv3d.py | 46 +++++++++++++++++++++++++++++++++ run.py | 24 ++++++++++++----- vae/autoencoder_kl_causal_3d.py | 3 +++ vae/unet_causal_3d_blocks.py | 1 + 4 files changed, 67 insertions(+), 7 deletions(-) create mode 100644 conv3d.py diff --git a/conv3d.py b/conv3d.py new file mode 100644 index 0000000..1315467 --- /dev/null +++ b/conv3d.py @@ -0,0 +1,46 @@ +import torch +import time + +configs = [ + [128, 128, 3, 1], + [256, 256, 3, 1], + [512, 512, 3, 1], + [128, 256, 1, 1], + [512, 512, 3, (2, 2, 2)], + [256, 256, 3, (2, 2, 2)], + [128, 3, 3, 1] +] + +inputs = [ + [1, 128, 67, 258, 258], + [1, 256, 35, 130, 130], + [1, 512, 35, 130, 130], + [1, 128, 67, 258, 258], + [1, 512, 35, 130, 130], + [1, 256, 27, 258, 258], + [1, 128, 67, 258, 258], +] + + +def conv3dbenchmark(configs: list[list[int]], inputs: list[list[int]], repeat: int, dtype: torch.dtype, device: torch.device): + modules = list() + assert len(inputs) == len(configs) + + for config in configs: + modules.append(torch.nn.Conv3d(config[0], config[1], config[2], stride=config[3]).to(device, dtype)) + + for i in range(len(modules)): + x = torch.randn(inputs[i]).to(device, dtype) + print(f"Running Conv3d config: {configs[i]} input: {inputs[i]} type: {dtype}") + start = time.perf_counter() + for n in range(repeat): + modules[i].forward(x) + torch.cuda.synchronize(device) + print(f"Time {(time.perf_counter() - start) / repeat} seconds\n") + + +if __name__ == "__main__": + device = torch.device(0) + + conv3dbenchmark(configs, inputs, 5, torch.bfloat16, device) + conv3dbenchmark(configs, inputs, 5, torch.float16, device) diff --git a/run.py b/run.py index 60850dd..9d56cd2 100644 --- a/run.py +++ b/run.py @@ -6,6 +6,8 @@ import json from vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D import time +from torch.profiler import 
profile, record_function, ProfilerActivity + def load_vae(path: str, compile_vae: bool): with open("hy_vae_config.json") as f: @@ -25,6 +27,10 @@ def load_vae(path: str, compile_vae: bool): if __name__ == "__main__": + props = torch.cuda.get_device_properties(0) + + print(f"Device: {props.name}") + latents = torch.randn((1, 16, 19, 120, 68)).to(torch.device(0), torch.bfloat16) print(f"Latent dims: {latents.size()}") @@ -40,11 +46,15 @@ if __name__ == "__main__": vae.enable_tiling() print("decodeing") - start = time.perf_counter() - generator = torch.Generator(device=torch.device("cpu")) - decoded = vae.decode( - latents, return_dict=False, generator=generator - )[0] - print(f"decoded in {time.perf_counter() - start} seconds") - print(f"decoded dims: {decoded.size()}") + with profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], record_shapes=True, with_flops=True) as prof: + start = time.perf_counter() + generator = torch.Generator(device=torch.device("cpu")) + decoded = vae.decode( + latents, return_dict=False, generator=generator + )[0] + print(f"decoded in {time.perf_counter() - start} seconds") + print(f"decoded dims: {decoded.size()}") + + print(prof.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=100)) + prof.export_chrome_trace("trace.json") diff --git a/vae/autoencoder_kl_causal_3d.py b/vae/autoencoder_kl_causal_3d.py index 045be1b..f0d30c9 100644 --- a/vae/autoencoder_kl_causal_3d.py +++ b/vae/autoencoder_kl_causal_3d.py @@ -115,6 +115,9 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): self.post_quant_conv = nn.Conv3d( latent_channels, latent_channels, kernel_size=1) + print(f"Conv3d: {2 * latent_channels}, {2 * latent_channels}, 1") + print(f"Conv3d: {latent_channels}, {latent_channels}, 1") + self.use_slicing = False self.use_spatial_tiling = False self.use_temporal_tiling = False diff --git a/vae/unet_causal_3d_blocks.py b/vae/unet_causal_3d_blocks.py index 9c5cdd4..9163813 
100644 --- a/vae/unet_causal_3d_blocks.py +++ b/vae/unet_causal_3d_blocks.py @@ -72,6 +72,7 @@ class CausalConv3d(nn.Module): self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + print(f"Conv3d: {chan_in}, {chan_out}, {kernel_size}, stride={stride}, dilation={dilation}") def forward(self, x): x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)