imlement 8bit quantization

This commit is contained in:
2024-04-07 20:52:15 +02:00
parent c33964371c
commit 6b38cfabf8

View File

@ -32,7 +32,6 @@ class Linear(torch.nn.Linear):
self.bias.requires_grad = not frozen self.bias.requires_grad = not frozen
if convert: if convert:
if frozen: if frozen:
breakpoint()
self.compress() self.compress()
else: else:
self.decompress() self.decompress()
@ -130,12 +129,10 @@ class DynamicQantizedLinear(Linear):
def compress(self) -> None: def compress(self) -> None:
weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device) weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
self.weight_quantized, self.weight_state = bnb.functional.quantize_4bit(weight, blocksize=self.block_size, self.weight_quantized, self.weight_state = bnb.functional.quantize_blockwise(weight, blocksize=self.block_size)
compress_statistics=False, quant_type=self.quant_type)
if self.bias is not None: if self.bias is not None:
bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device) bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
self.bias_quantized, self.bias_state = bnb.functional.quantize_4bit(bias, blocksize=self.block_size, self.bias_quantized, self.bias_state = bnb.functional.quantize_blockwise(bias, blocksize=self.block_size)
compress_statistics=False, quant_type=self.quant_type)
weight = torch.nn.Parameter(self.weight.to(self.cold_device)) weight = torch.nn.Parameter(self.weight.to(self.cold_device))
bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
@ -144,9 +141,9 @@ class DynamicQantizedLinear(Linear):
if self.weight_quantized is None: if self.weight_quantized is None:
raise RuntimeError("decompress() called in quantized stated before quantized weights are avialable") raise RuntimeError("decompress() called in quantized stated before quantized weights are avialable")
dtype = self.weight.dtype dtype = self.weight.dtype
self.weight = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device)) self.weight = torch.nn.Parameter(bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device))
if self.bias_quantized: if self.bias_quantized:
self.bias = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device)) self.bias = torch.nn.Parameter(bnb.functional.dequantize_blockwise(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device))
def checkDistance(self) -> tuple[float, float]: def checkDistance(self) -> tuple[float, float]:
if self.weight_quantized is None: if self.weight_quantized is None:
@ -156,8 +153,8 @@ class DynamicQantizedLinear(Linear):
blocksize=self.block_size, blocksize=self.block_size,
compress_statistics=True, compress_statistics=True,
quant_type=self.quant_type) quant_type=self.quant_type)
dequantized_original_weight = bnb.functional.dequantize_fp4(quantized_original_weight, quantized_original_state).to(original_weight.dtype) dequantized_original_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
dequantized_weight = bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(original_weight.dtype) dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item() distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item() error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
return (distance, error) return (distance, error)
@ -180,10 +177,14 @@ class DynamicQantizedLinear(Linear):
raise RuntimeError("forward() called in quantized stated before quantized weights are avialable") raise RuntimeError("forward() called in quantized stated before quantized weights are avialable")
if x.device != self.weight_quantized.device: if x.device != self.weight_quantized.device:
x = x.to(self.weight_quantized.device) x = x.to(self.weight_quantized.device)
bias = None weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(x.dtype)
out = torch.matmul(x, weight.t())
if self.bias_quantized is not None: if self.bias_quantized is not None:
bias = bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(x.dtype) bias = bnb.functional.dequantize_blockwise(self.bias_quantized, self.bias_state).to(x.dtype)
out = bnb.matmul_4bit(x, self.weight_quantized.t(), bias=bias, quant_state=self.weight_state) out = out + bias
if torch.isnan(out).sum().item() > 0:
breakpoint()
return out.to(output_device).to(output_dtype) return out.to(output_device).to(output_dtype)