fix checkDistance
This commit is contained in:
@ -148,12 +148,9 @@ class DynamicQantizedLinear(Linear):
|
|||||||
def checkDistance(self) -> tuple[float, float]:
|
def checkDistance(self) -> tuple[float, float]:
|
||||||
if self.weight_quantized is None:
|
if self.weight_quantized is None:
|
||||||
raise RuntimeError("checkDistance() called without quantized weights avialable")
|
raise RuntimeError("checkDistance() called without quantized weights avialable")
|
||||||
original_weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
|
original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
|
||||||
quantized_original_weight, quantized_original_state = bnb.functional.quantize_4bit(original_weight,
|
quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
|
||||||
blocksize=self.block_size,
|
dequantized_original_weight = bnb.functional.dequantize_blockwise(self.quantized_original_weight, self.quantized_original_state).to(original_weight.dtype)
|
||||||
compress_statistics=True,
|
|
||||||
quant_type=self.quant_type)
|
|
||||||
dequantized_original_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
|
|
||||||
dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
|
dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
|
||||||
distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
|
distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
|
||||||
error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
|
error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
|
||||||
|
Reference in New Issue
Block a user