fix checkDistance
This commit is contained in:
		
							parent
							
								
									6b38cfabf8
								
							
						
					
					
						commit
						8abea9ef89
					
				
					 1 changed files with 3 additions and 6 deletions
				
			
		| 
						 | 
					@ -148,12 +148,9 @@ class DynamicQantizedLinear(Linear):
 | 
				
			||||||
    def checkDistance(self) -> tuple[float, float]:
 | 
					    def checkDistance(self) -> tuple[float, float]:
 | 
				
			||||||
        if self.weight_quantized is None:
 | 
					        if self.weight_quantized is None:
 | 
				
			||||||
            raise RuntimeError("checkDistance() called without quantized weights avialable")
 | 
					            raise RuntimeError("checkDistance() called without quantized weights avialable")
 | 
				
			||||||
        original_weight = self.weight.contiguous().to(torch.float16).cuda(self.active_device)
 | 
					        original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
 | 
				
			||||||
        quantized_original_weight, quantized_original_state = bnb.functional.quantize_4bit(original_weight,
 | 
					        quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
 | 
				
			||||||
                                                                                           blocksize=self.block_size,
 | 
					        dequantized_original_weight = bnb.functional.dequantize_blockwise(self.quantized_original_weight, self.quantized_original_state).to(original_weight.dtype)
 | 
				
			||||||
                                                                                           compress_statistics=True,
 | 
					 | 
				
			||||||
                                                                                           quant_type=self.quant_type)
 | 
					 | 
				
			||||||
        dequantized_original_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
 | 
					 | 
				
			||||||
        dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
 | 
					        dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
 | 
				
			||||||
        distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
 | 
					        distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
 | 
				
			||||||
        error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
 | 
					        error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue