import torch
import bitsandbytes as bnb
import torch.multiprocessing as multiprocessing
from typing import overload, Optional, Union
from functools import wraps


class Linear(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)

    @classmethod
    def fromLinear(cls, in_module: torch.nn.Linear):
        new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
                                              out_features=in_module.out_features,
                                              bias=in_module.bias is not None,
                                              device=in_module.weight.device,
                                              dtype=in_module.weight.dtype)
        new_module.weight = in_module.weight
        new_module.bias = in_module.bias
        return new_module

    def setFrozen(self, frozen: bool):
        self.weight.requires_grad = not frozen
        if self.bias is not None:
            self.bias.requires_grad = not frozen

    def isFrozen(self) -> bool:
        return not self.weight.requires_grad

    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
        frozen = self.isFrozen()
        if dtype is not None:
            self.weight = torch.nn.Parameter(self.weight.to(dtype))
            if self.bias is not None:
                self.bias = torch.nn.Parameter(self.bias.to(dtype))
        if device is not None:
            self.weight = torch.nn.Parameter(self.weight.to(device))
            if self.bias is not None:
                self.bias = torch.nn.Parameter(self.bias.to(device))
        # Call Linear.setFrozen directly so subclass overrides (which may quantize
        # or convert dtypes) are not re-triggered by a plain dtype/device move.
        Linear.setFrozen(self, frozen)

    def _apply(self, fn, recurse: bool = True):
        # Module.to()/half()/float() apply a closure named "convert"; block it so
        # dtype and device are only ever changed explicitly via inplaceTo().
        if fn.__name__ == "convert":
            return self
        else:
            return super()._apply(fn, recurse)

    @wraps(torch.nn.Module.to)
    def to(self, *args, **kwargs):
        # to() is intentionally a no-op for the same reason; use inplaceTo() instead.
        return self


class DynamicConvertingLinear(Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None,
                 output_dtype=None, output_device=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.output_dtype = output_dtype
        self.output_device = output_device

    @classmethod
    def fromLinear(cls, in_module: torch.nn.Linear, output_dtype, output_device=None):
        new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
                                              out_features=in_module.out_features,
                                              bias=in_module.bias is not None,
                                              device=in_module.weight.device,
                                              dtype=in_module.weight.dtype)
        new_module.output_dtype = output_dtype
        new_module.output_device = output_device
        new_module.weight = in_module.weight
        new_module.bias = in_module.bias
        return new_module

    def setFrozen(self, frozen: bool):
        super().setFrozen(frozen)
        # Frozen layers are kept in half precision, trainable layers in full precision.
        if frozen:
            self.inplaceTo(torch.float16)
        else:
            self.inplaceTo(torch.float32)

    def setOutputDevice(self, output_device: torch.device):
        self.output_device = output_device

    def forward(self, input: torch.Tensor):
        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
        output_device = input.device if self.output_device is None else self.output_device
        if input.device != self.weight.device:
            input = input.to(self.weight.device)
        if input.dtype != self.weight.dtype:
            input = input.to(self.weight.dtype)
        output = torch.nn.Linear.forward(self, input)
        return output.to(output_device).to(output_dtype)


class DynamicQantizedLinear(Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool, active_device: torch.device,
                 cold_device: torch.device, output_dtype=None, compute_dtype=None, output_device=None):
        super().__init__(in_features, out_features, bias, cold_device, torch.float32)
        self.active_device = active_device
        self.cold_device = cold_device
        self.output_device = output_device
        self.output_dtype = output_dtype
        self.compute_dtype = compute_dtype
        self.weight_quantized = None
        self.weight_state = None
        self.bias_quantized = None
        self.bias_state = None
        self.block_size = 128
        self.quant_type = 'nf4'

    @classmethod
    def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device, cold_device: torch.device,
                   output_dtype=None, compute_dtype=torch.float16, output_device=None):
        new_module = cls(in_features=in_module.in_features, out_features=in_module.out_features,
                         bias=in_module.bias is not None, active_device=active_device, cold_device=cold_device,
                         output_dtype=output_dtype, compute_dtype=compute_dtype, output_device=output_device)
        new_module.weight = torch.nn.Parameter(in_module.weight.to(torch.float32).to(cold_device))
        new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if in_module.bias is not None else None
        return new_module

    def quantize(self):
        weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
        self.weight_quantized, self.weight_state = bnb.functional.quantize_4bit(weight, blocksize=self.block_size,
                                                                                compress_statistics=False,
                                                                                quant_type=self.quant_type)
        if self.bias is not None:
            bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
            self.bias_quantized, self.bias_state = bnb.functional.quantize_4bit(bias, blocksize=self.block_size,
                                                                                compress_statistics=False,
                                                                                quant_type=self.quant_type)
        # Keep the full-precision master copies on the cold device; the quantized
        # copies on the active device are what forward() uses while frozen.
        self.weight = torch.nn.Parameter(self.weight.to(self.cold_device), requires_grad=self.weight.requires_grad)
        if self.bias is not None:
            self.bias = torch.nn.Parameter(self.bias.to(self.cold_device), requires_grad=self.bias.requires_grad)

    def dequantize(self):
        if self.weight_quantized is None:
            raise RuntimeError("dequantize() called before quantized weights are available")
        dtype = self.weight.dtype
        self.weight = torch.nn.Parameter(bnb.functional.dequantize_4bit(
            self.weight_quantized, self.weight_state, quant_type=self.quant_type).to(dtype).to(self.active_device))
        if self.bias_quantized is not None:
            self.bias = torch.nn.Parameter(bnb.functional.dequantize_4bit(
                self.bias_quantized, self.bias_state, quant_type=self.quant_type).to(dtype).to(self.active_device))
        self.weight_quantized = None
        self.weight_state = None
        self.bias_quantized = None
        self.bias_state = None

    def checkDistance(self) -> float:
        if self.weight_quantized is None:
            raise RuntimeError("checkDistance() called without quantized weights available")
        original_weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
        quantized_original_weight, quantized_original_state = bnb.functional.quantize_4bit(
            original_weight, blocksize=self.block_size, compress_statistics=True, quant_type=self.quant_type)
        dequantized_original_weight = bnb.functional.dequantize_4bit(
            quantized_original_weight, quantized_original_state, quant_type=self.quant_type).to(original_weight.dtype)
        dequantized_weight = bnb.functional.dequantize_4bit(
            self.weight_quantized, self.weight_state, quant_type=self.quant_type).to(original_weight.dtype)
        return (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight) /
                dequantized_original_weight.numel()).item()

    def setOutputDevice(self, output_device: torch.device):
        self.output_device = output_device

    def setFrozen(self, frozen: bool) -> None:
        if frozen == self.isFrozen():
            return
        super().setFrozen(frozen)
        # Freezing quantizes the layer; unfreezing restores full-precision weights.
        if frozen:
            self.quantize()
        else:
            self.dequantize()

    def forward(self, x: torch.Tensor):
        output_dtype = x.dtype if self.output_dtype is None else self.output_dtype
        output_device = x.device if self.output_device is None else self.output_device
        if not self.isFrozen():
            if x.device != self.weight.device:
                x = x.to(self.weight.device)
            if x.dtype != self.weight.dtype:
                x = x.to(self.weight.dtype)
            return super().forward(x).to(output_device).to(output_dtype)
        else:
            if self.weight_quantized is None:
                raise RuntimeError("forward() called in quantized state before quantized weights are available")
            # Mirror the device handling of the unfrozen branch: the quantized
            # weights live on the active device, so move the input there first.
            if x.device != self.weight_quantized.device:
                x = x.to(self.weight_quantized.device)
            bias = None
            if self.bias_quantized is not None:
                bias = bnb.functional.dequantize_4bit(self.bias_quantized, self.bias_state,
                                                      quant_type=self.quant_type).to(x.dtype)
            out = bnb.matmul_4bit(x, self.weight_quantized.t(), bias=bias, quant_state=self.weight_state)
            return out.to(output_device).to(output_dtype)
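

# --- Usage sketch (illustrative only) ---
# A minimal example of how these layers might be used. The device choices
# ("cuda:0" as the active device, "cpu" as the cold device) and the layer sizes
# are assumptions for illustration, not requirements of the module. The sketch
# converts a plain torch.nn.Linear into a DynamicQantizedLinear, freezes it
# (which quantizes the weights with bitsandbytes NF4 on the active device),
# runs a forward pass through the 4-bit matmul path, and unfreezes it again.
if __name__ == "__main__":
    base = torch.nn.Linear(64, 32, bias=True)

    qlinear = DynamicQantizedLinear.fromLinear(base,
                                               active_device=torch.device("cuda:0"),
                                               cold_device=torch.device("cpu"),
                                               output_dtype=torch.float32)

    # Freezing quantizes the weights onto the active device while the
    # full-precision master copy stays on the cold device.
    qlinear.setFrozen(True)

    x = torch.randn(4, 64, dtype=torch.float16, device="cuda:0")
    y = qlinear(x)
    print(y.shape, y.dtype)  # expected: torch.Size([4, 32]) torch.float32

    # Unfreezing dequantizes back to full precision so the layer can be trained.
    qlinear.setFrozen(False)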