diff --git a/modules.py b/modules.py
index cf50422..1db0dfd 100644
--- a/modules.py
+++ b/modules.py
@@ -32,7 +32,6 @@ class Linear(torch.nn.Linear):
         self.bias.requires_grad = not frozen
         if convert:
             if frozen:
-                breakpoint()
                 self.compress()
             else:
                 self.decompress()
@@ -130,12 +129,10 @@ class DynamicQantizedLinear(Linear):
 
     def compress(self) -> None:
         weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
-        self.weight_quantized, self.weight_state = bnb.functional.quantize_4bit(weight, blocksize=self.block_size,
-                                                                                compress_statistics=False, quant_type=self.quant_type)
+        self.weight_quantized, self.weight_state = bnb.functional.quantize_blockwise(weight, blocksize=self.block_size)
         if self.bias is not None:
             bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
-            self.bias_quantized, self.bias_state = bnb.functional.quantize_4bit(bias, blocksize=self.block_size,
-                                                                                compress_statistics=False, quant_type=self.quant_type)
+            self.bias_quantized, self.bias_state = bnb.functional.quantize_blockwise(bias, blocksize=self.block_size)
 
         weight = torch.nn.Parameter(self.weight.to(self.cold_device))
         bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
@@ -144,9 +141,9 @@ class DynamicQantizedLinear(Linear):
         if self.weight_quantized is None:
             raise RuntimeError("decompress() called in quantized stated before quantized weights are avialable")
         dtype = self.weight.dtype
-        self.weight = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device))
+        self.weight = torch.nn.Parameter(bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device))
         if self.bias_quantized:
-            self.bias = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device))
+            self.bias = torch.nn.Parameter(bnb.functional.dequantize_blockwise(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device))
 
     def checkDistance(self) -> tuple[float, float]:
         if self.weight_quantized is None:
@@ -156,8 +153,8 @@ class DynamicQantizedLinear(Linear):
                                                                       blocksize=self.block_size,
                                                                       compress_statistics=True,
                                                                       quant_type=self.quant_type)
-        dequantized_original_weight = bnb.functional.dequantize_fp4(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
-        dequantized_weight = bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(original_weight.dtype)
+        dequantized_original_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
+        dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
         distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
         error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
         return (distance, error)
@@ -180,10 +177,14 @@ class DynamicQantizedLinear(Linear):
             raise RuntimeError("forward() called in quantized stated before quantized weights are avialable")
         if x.device != self.weight_quantized.device:
             x = x.to(self.weight_quantized.device)
-        bias = None
+        weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(x.dtype)
+        out = torch.matmul(x, weight.t())
         if self.bias_quantized is not None:
-            bias = bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(x.dtype)
-        out = bnb.matmul_4bit(x, self.weight_quantized.t(), bias=bias, quant_state=self.weight_state)
+            bias = bnb.functional.dequantize_blockwise(self.bias_quantized, self.bias_state).to(x.dtype)
+            out = out + bias
+
+        if torch.isnan(out).sum().item() > 0:
+            breakpoint()
 
         return out.to(output_device).to(output_dtype)
 
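For reference, a minimal sketch (not part of the patch) of the blockwise round-trip that the new compress()/forward() path relies on. The shapes, blocksize, and device below are illustrative assumptions, not values from modules.py:

# Sketch only: exercises the bitsandbytes blockwise quantize/dequantize
# round-trip used by the patched code. Assumes a CUDA device, as in compress().
import torch
import bitsandbytes as bnb

weight = torch.randn(256, 256, dtype=torch.float16, device="cuda")
x = torch.randn(4, 256, dtype=torch.float16, device="cuda")

# quantize_blockwise packs the tensor into 8-bit blocks and returns the
# state (per-block statistics) needed to reconstruct it later.
weight_q, weight_state = bnb.functional.quantize_blockwise(weight, blocksize=128)

# dequantize_blockwise reverses the mapping; cast back to the activation
# dtype before the matmul, mirroring the patched forward().
weight_dq = bnb.functional.dequantize_blockwise(weight_q, weight_state).to(x.dtype)

out = torch.matmul(x, weight_dq.t())
assert not torch.isnan(out).any()          # same guard the patch adds to forward()
print((weight - weight_dq).abs().max())    # rough view of the quantization error

Note that, compared to the removed matmul_4bit call, this approach materializes the full-precision weight on every forward pass in exchange for a plain torch.matmul over the blockwise-dequantized tensor.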