fix checkDistance
parent 6b38cfabf8
commit 8abea9ef89
@@ -148,12 +148,9 @@ class DynamicQantizedLinear(Linear):
     def checkDistance(self) -> tuple[float, float]:
         if self.weight_quantized is None:
             raise RuntimeError("checkDistance() called without quantized weights avialable")
-        original_weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
-        quantized_original_weight, quantized_original_state = bnb.functional.quantize_4bit(original_weight,
-            blocksize=self.block_size,
-            compress_statistics=True,
-            quant_type=self.quant_type)
-        dequantized_original_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
+        original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
+        quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
+        dequantized_original_weight = bnb.functional.dequantize_blockwise(self.quantized_original_weight, self.quantized_original_state).to(original_weight.dtype)
         dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
         distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
         error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
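After this change, checkDistance() quantizes the current full-precision weights blockwise, dequantizes both that fresh copy and the stored quantized weights, and reports two L2 norms divided by the element count: the drift between the two dequantized tensors and the round-trip quantization error itself. Below is a minimal standalone sketch of that round-trip error measurement; the helper name blockwise_roundtrip_error, the default block size of 2048, and the hard-coded "cuda" device are illustrative assumptions, not part of this commit.

import torch
import bitsandbytes as bnb

def blockwise_roundtrip_error(weight: torch.Tensor, block_size: int = 2048) -> float:
    # Quantize the weights blockwise, immediately dequantize them again, and
    # report the L2 norm of the difference divided by the element count,
    # mirroring the `error` term computed in checkDistance().
    w = weight.contiguous().to("cuda").to(torch.float16)
    quantized, state = bnb.functional.quantize_blockwise(w, blocksize=block_size)
    dequantized = bnb.functional.dequantize_blockwise(quantized, state).to(w.dtype)
    return (torch.linalg.vector_norm(dequantized - w).to(torch.float32) / w.numel()).item()

# Usage sketch (assumes a CUDA device and bitsandbytes installed):
# err = blockwise_roundtrip_error(torch.randn(4096, 4096))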