47 lines
1.2 KiB
Python
47 lines
1.2 KiB
Python
import torch
|
|
import time
|
|
|
|
configs = [
|
|
[128, 128, 3, 1],
|
|
[256, 256, 3, 1],
|
|
[512, 512, 3, 1],
|
|
[128, 256, 1, 1],
|
|
[512, 512, 3, (2, 2, 2)],
|
|
[256, 256, 3, (2, 2, 2)],
|
|
[128, 3, 3, 1]
|
|
]
|
|
|
|
inputs = [
|
|
[1, 128, 67, 258, 258],
|
|
[1, 256, 35, 130, 130],
|
|
[1, 512, 35, 130, 130],
|
|
[1, 128, 67, 258, 258],
|
|
[1, 512, 35, 130, 130],
|
|
[1, 256, 27, 258, 258],
|
|
[1, 128, 67, 258, 258],
|
|
]
|
|
|
|
|
|
def conv3dbenchmark(configs: list[list[int]], inputs: list[list[int]], repeat: int, dtype: torch.dtype, device: torch.device):
|
|
modules = list()
|
|
assert len(inputs) == len(configs)
|
|
|
|
for config in configs:
|
|
modules.append(torch.nn.Conv3d(config[0], config[1], config[2], stride=config[3]).to(device, dtype))
|
|
|
|
for i in range(len(modules)):
|
|
x = torch.randn(inputs[i]).to(device, dtype)
|
|
print(f"Running Conv3d config: {configs[i]} input: {inputs[i]} type: {dtype}")
|
|
start = time.perf_counter()
|
|
for n in range(repeat):
|
|
modules[i].forward(x)
|
|
torch.cuda.synchronize(device)
|
|
print(f"Time {(time.perf_counter() - start) / repeat} seconds\n")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
device = torch.device(0)
|
|
|
|
conv3dbenchmark(configs, inputs, 5, torch.bfloat16, device)
|
|
conv3dbenchmark(configs, inputs, 5, torch.float16, device)
|