fix checkDistance

This commit is contained in:
uvos 2024-04-08 00:33:02 +02:00
parent 6b38cfabf8
commit 8abea9ef89

View File

@ -148,12 +148,9 @@ class DynamicQantizedLinear(Linear):
def checkDistance(self) -> tuple[float, float]:
if self.weight_quantized is None:
raise RuntimeError("checkDistance() called without quantized weights avialable")
original_weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
quantized_original_weight, quantized_original_state = bnb.functional.quantize_4bit(original_weight,
blocksize=self.block_size,
compress_statistics=True,
quant_type=self.quant_type)
dequantized_original_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
dequantized_original_weight = bnb.functional.dequantize_blockwise(self.quantized_original_weight, self.quantized_original_state).to(original_weight.dtype)
dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()