From 8abea9ef89f913a17b311583c4639e40daf2e127 Mon Sep 17 00:00:00 2001 From: uvos Date: Mon, 8 Apr 2024 00:33:02 +0200 Subject: [PATCH] fix checkDistance --- modules.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/modules.py b/modules.py index 1db0dfd..6e34858 100644 --- a/modules.py +++ b/modules.py @@ -148,12 +148,9 @@ class DynamicQantizedLinear(Linear): def checkDistance(self) -> tuple[float, float]: if self.weight_quantized is None: raise RuntimeError("checkDistance() called without quantized weights avialable") - original_weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device) - quantized_original_weight, quantized_original_state = bnb.functional.quantize_4bit(original_weight, - blocksize=self.block_size, - compress_statistics=True, - quant_type=self.quant_type) - dequantized_original_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype) + original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16) + quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size) + dequantized_original_weight = bnb.functional.dequantize_blockwise(self.quantized_original_weight, self.quantized_original_state).to(original_weight.dtype) dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype) distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item() error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()