add support for Hugging Face Hub datasets and for specifying a prompt for eval
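The modules.py diff below does not show the dataset or prompt plumbing itself. As a rough sketch of what the feature amounts to, assuming the loader is built on the `datasets` library (the function and argument names here are hypothetical, not the commit's actual API):

```python
# Hypothetical sketch of the commit's feature, not its actual code:
# only datasets.load_dataset is real API; the rest is illustrative.
from datasets import load_dataset

def load_eval_data(dataset_name: str, eval_prompt: str, split: str = "test") -> list[str]:
    # Pull the dataset by name straight from the Hugging Face Hub
    dataset = load_dataset(dataset_name, split=split)
    # Prefix every example with the user-supplied eval prompt
    return [f"{eval_prompt}\n{example['text']}" for example in dataset]

# e.g. load_eval_data("imdb", "Rate the sentiment of the following review:")
```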
modules.py
@@ -35,7 +35,6 @@ class Linear(torch.nn.Linear):
            self.compress()
        else:
            self.decompress()
        self.weightStart = torch.Tensor(self.weight).clone().detach()

    def isFrozen(self) -> bool:
        return not self.weight.requires_grad
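isFrozen() treats requires_grad as the single source of truth for freeze state; compress() and decompress() are driven off the same flag. A minimal standalone illustration of that convention:

```python
import torch

layer = torch.nn.Linear(8, 8)
layer.weight.requires_grad_(False)       # freeze: optimizers skip this weight
frozen = not layer.weight.requires_grad  # exactly what isFrozen() computes
assert frozen
```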
@@ -60,9 +59,15 @@ class Linear(torch.nn.Linear):

    @wraps(torch.nn.Module.to)
    def to(self, *args, **kwargs):
        breakpoint()
        return self

    def check(self) -> bool:
        if self.isFrozen() and self.weight.dtype != torch.float16:
            return False
        elif not self.isFrozen() and self.weight.dtype != torch.float32:
            return False
        return True


class DynamicConvertingLinear(Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None,
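The new check() pins down a dtype invariant: a frozen layer's weight must be float16, a trainable one float32. The same rule as a free-standing predicate, for illustration:

```python
import torch

# Restates the invariant check() enforces: frozen weights are float16,
# trainable weights float32.
def dtype_ok(weight: torch.Tensor) -> bool:
    frozen = not weight.requires_grad
    return weight.dtype == (torch.float16 if frozen else torch.float32)

assert dtype_ok(torch.nn.Parameter(torch.zeros(4, 4)))   # trainable fp32: ok
assert dtype_ok(torch.zeros(4, 4, dtype=torch.float16))  # frozen fp16: ok
assert not dtype_ok(torch.zeros(4, 4))                   # frozen fp32: bad
```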
@@ -116,6 +121,7 @@ class DynamicQantizedLinear(Linear):
        self.bias_state = None
        self.block_size = 128
        self.quant_type = 'nf4'
        self.weight_start = self.weight.clone().detach()

    @classmethod
    def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device, cold_device: torch.device,
@@ -125,6 +131,7 @@ class DynamicQantizedLinear(Linear):
                         compute_dtype=compute_dtype, output_device=output_device)
        new_module.weight = torch.nn.Parameter(in_module.weight.to(torch.float32).to(cold_device))
        new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
        new_module.weight_start = new_module.weight.clone().detach()
        return new_module

    def compress(self) -> None:
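fromLinear() wraps an existing torch.nn.Linear, parking the float32 master weights on the cold device. A hedged usage sketch; the hunk truncates the signature, so anything beyond the three parameters shown above is an assumption:

```python
import torch

linear = torch.nn.Linear(1024, 1024)
# in_module, active_device and cold_device appear in the signature above;
# further keywords (compute_dtype, output_device, ...) are cut off in the hunk.
qlinear = DynamicQantizedLinear.fromLinear(
    linear,
    torch.device("cuda:0"),  # active_device: where decompressed weights run
    torch.device("cpu"),     # cold_device: where the fp32 master copy rests
)
```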
@@ -134,26 +141,27 @@ class DynamicQantizedLinear(Linear):
            bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
            self.bias_quantized, self.bias_state = bnb.functional.quantize_blockwise(bias, blocksize=self.block_size)

        weight = torch.nn.Parameter(self.weight.to(self.cold_device))
        bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
        frozen = self.isFrozen()
        self.weight = torch.nn.Parameter(self.weight.to(self.cold_device))
        self.bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
        self.setFrozen(frozen, False)

    def decompress(self) -> None:
        if self.weight_quantized is None:
            raise RuntimeError("decompress() called in quantized state before quantized weights are available")
        dtype = self.weight.dtype
        self.weight = torch.nn.Parameter(bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device))
        self.weight_quantized = None
        self.weight_state = None
        self.bias_quantized = None
        self.bias_state = None
        self.weight_start = self.weight.clone().detach().to(self.cold_device)
        self.weight = torch.nn.Parameter(self.weight.to(self.active_device))
        if self.bias_quantized:
            self.bias = torch.nn.Parameter(bnb.functional.dequantize_blockwise(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device))
        self.bias = torch.nn.Parameter(self.bias.to(self.active_device))

    def checkDistance(self) -> tuple[float, float]:
        if self.weight_quantized is None:
            raise RuntimeError("checkDistance() called without quantized weights available")
    def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
        original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
        quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
        dequantized_original_weight = bnb.functional.dequantize_blockwise(self.quantized_original_weight, self.quantized_original_state).to(original_weight.dtype)
        dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
        distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
        error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
        dequantized_original_weight = bnb.functional.dequantize_blockwise(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
        distance = (self.weight_start - self.weight.to(self.cold_device)).to(torch.float32)
        error = (dequantized_original_weight - original_weight).to(torch.float32)
        return (distance, error)

    def setOutputDevice(self, output_device: torch.device):
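compress(), decompress() and getDistanceAndError() all lean on the same bitsandbytes primitive pair. A minimal round trip showing quantize_blockwise/dequantize_blockwise and the elementwise error that getDistanceAndError() returns (CUDA assumed, as in the code above):

```python
import torch
import bitsandbytes as bnb

w = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
# Blockwise 8-bit quantization; the state holds per-block absmax scales
quantized, state = bnb.functional.quantize_blockwise(w, blocksize=128)
restored = bnb.functional.dequantize_blockwise(quantized, state).to(w.dtype)
# Per-element quantization error, the `error` tensor of getDistanceAndError()
error = (restored - w).to(torch.float32)
print(error.abs().mean().item())
```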
@@ -200,3 +208,24 @@ class DynamicQantizedLinear(Linear):
        if not frozen:
            super().inplaceTo(device=device)
        self.setFrozen(frozen, False)

    def check(self) -> bool:
        if self.isFrozen():
            if torch.device(self.weight.device) != torch.device(self.cold_device):
                breakpoint()
                print("Frozen but not cold")
                return False
            if self.weight_quantized is None:
                breakpoint()
                print("Frozen but not quanted")
                return False
        else:
            if torch.device(self.weight.device) != torch.device(self.active_device):
                breakpoint()
                print("Active but not warm")
                return False
            if self.weight_quantized is not None:
                breakpoint()
                print("Active but still quantized")
                return False
        return True
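Taken together, the new check() admits exactly two legal states: frozen weights sit quantized on the cold device, active weights sit dequantized on the active device. The same truth table as a compact predicate, for illustration:

```python
# The only (frozen, location, quantized) combinations check() accepts:
#   frozen -> weight on cold_device and weight_quantized present
#   active -> weight on active_device and weight_quantized absent
def state_ok(frozen: bool, on_cold_device: bool, quantized: bool) -> bool:
    if frozen:
        return on_cold_device and quantized
    return not on_cold_device and not quantized
```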