Inactive parameter quanitzation support

This commit is contained in:
2024-04-07 19:15:42 +02:00
parent 3fa1fc254f
commit c33964371c
4 changed files with 161 additions and 78 deletions

View File

@ -20,10 +20,23 @@ class Linear(torch.nn.Linear):
new_module.bias = in_module.bias
return new_module
def setFrozen(self, frozen: bool):
def compress(self) -> None:
self.inplaceTo(torch.float16)
def decompress(self) -> None:
self.inplaceTo(torch.float32)
def setFrozen(self, frozen: bool, convert: bool = True):
self.weight.requires_grad = not frozen
if self.bias is not None:
self.bias.requires_grad = not frozen
if convert:
if frozen:
breakpoint()
self.compress()
else:
self.decompress()
self.weightStart = torch.Tensor(self.weight).clone().detach()
def isFrozen(self) -> bool:
return not self.weight.requires_grad
@ -38,7 +51,7 @@ class Linear(torch.nn.Linear):
self.weight = torch.nn.Parameter(self.weight.to(device))
if self.bias is not None:
self.bias = torch.nn.Parameter(self.bias.to(device))
Linear.setFrozen(self, frozen)
Linear.setFrozen(self, frozen, False)
def _apply(self, fn, recurse: bool = True):
if fn.__name__ == "convert":
@ -72,17 +85,12 @@ class DynamicConvertingLinear(Linear):
new_module.bias = in_module.bias
return new_module
def setFrozen(self, frozen: bool):
super().setFrozen(frozen)
if frozen:
self.inplaceTo(torch.float16)
else:
self.inplaceTo(torch.float32)
def setOutputDevice(self, output_device: torch.device):
self.output_device = output_device
def checkDistance(self) -> tuple[float, float]:
return (10.0, 0.0)
def forward(self, input: torch.Tensor):
output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
output_device = input.device if self.output_device is None else self.output_device
@ -120,7 +128,7 @@ class DynamicQantizedLinear(Linear):
new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
return new_module
def quantize(self):
def compress(self) -> None:
weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
self.weight_quantized, self.weight_state = bnb.functional.quantize_4bit(weight, blocksize=self.block_size,
compress_statistics=False, quant_type=self.quant_type)
@ -132,19 +140,15 @@ class DynamicQantizedLinear(Linear):
weight = torch.nn.Parameter(self.weight.to(self.cold_device))
bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
def dequantize(self):
def decompress(self) -> None:
if self.weight_quantized is None:
raise RuntimeError("forward() called in quantized stated before quantized weights are avialable")
raise RuntimeError("decompress() called in quantized stated before quantized weights are avialable")
dtype = self.weight.dtype
self.weight = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device))
if self.bias_quantized:
self.bias = torch.nn.Parameter(bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device))
self.weight_quantized = None
self.weight_state = None
self.bias_quantized = None
self.bias_state = None
def checkDistance(self) -> float:
def checkDistance(self) -> tuple[float, float]:
if self.weight_quantized is None:
raise RuntimeError("checkDistance() called without quantized weights avialable")
original_weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
@ -154,22 +158,13 @@ class DynamicQantizedLinear(Linear):
quant_type=self.quant_type)
dequantized_original_weight = bnb.functional.dequantize_fp4(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
dequantized_weight = bnb.functional.dequantize_fp4(self.weight_quantized, self.weight_state).to(original_weight.dtype)
return (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight) / dequantized_original_weight.numel()).item()
distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
return (distance, error)
def setOutputDevice(self, output_device: torch.device):
self.output_device = output_device
def setFrozen(self, frozen: bool) -> None:
if frozen == self.isFrozen():
return
super().setFrozen(frozen)
if frozen:
self.quantize()
else:
self.dequantize()
def forward(self, x: torch.Tensor):
output_dtype = x.dtype if self.output_dtype is None else self.output_dtype
output_device = x.device if self.output_device is None else self.output_device
@ -183,9 +178,27 @@ class DynamicQantizedLinear(Linear):
else:
if self.weight_quantized is None:
raise RuntimeError("forward() called in quantized stated before quantized weights are avialable")
if x.device != self.weight_quantized.device:
x = x.to(self.weight_quantized.device)
bias = None
if self.bias_quantized is not None:
bias = bnb.functional.dequantize_fp4(self.bias_quantized, self.bias_state).to(x.dtype)
out = bnb.matmul_4bit(x, self.weight_quantized.t(), bias=bias, quant_state=self.weight_state)
return out.to(output_device).to(output_dtype)
def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
if dtype is not None:
super().inplaceTo(dtype=dtype)
if device is not None:
frozen = self.isFrozen()
self.active_device = device
if self.weight_quantized is not None:
self.weight_quantized = self.weight_quantized.to(device)
self.weight_state = self.weight_state.to(device)
if self.bias_quantized is not None:
self.bias_quantized = self.bias_quantized.to(device)
self.bias_state = self.bias_state.to(device)
if not frozen:
super().inplaceTo(device=device)
self.setFrozen(frozen, False)