Fix mypy warnings
This commit is contained in:
@ -39,7 +39,7 @@ class Linear(torch.nn.Linear):
|
||||
def isFrozen(self) -> bool:
|
||||
return not self.weight.requires_grad
|
||||
|
||||
def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
|
||||
def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None):
|
||||
frozen = self.isFrozen()
|
||||
if dtype is not None:
|
||||
self.weight = torch.nn.Parameter(self.weight.to(dtype))
|
||||
@ -77,7 +77,7 @@ class DynamicConvertingLinear(Linear):
|
||||
self.output_device = output_device
|
||||
|
||||
@classmethod
|
||||
def fromLinear(cls, in_module: torch.nn.Linear, output_dtype, output_device=None):
|
||||
def fromLinear(cls, in_module: torch.nn.Linear, output_dtype=torch.float32, output_device=None):
|
||||
new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
|
||||
out_features=in_module.out_features,
|
||||
bias=in_module.bias is not None,
|
||||
@ -124,7 +124,7 @@ class DynamicQantizedLinear(Linear):
|
||||
self.weight_start = self.weight.clone().detach()
|
||||
|
||||
@classmethod
|
||||
def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device, cold_device: torch.device,
|
||||
def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device = torch.device("cuda:0"), cold_device: torch.device = torch.device("cpu"),
|
||||
output_dtype=None, compute_dtype=torch.float16, output_device=None):
|
||||
new_module = cls(in_features=in_module.in_features, out_features=in_module.out_features, bias=in_module.bias is not None,
|
||||
active_device=active_device, cold_device=cold_device, output_dtype=output_dtype,
|
||||
@ -193,7 +193,7 @@ class DynamicQantizedLinear(Linear):
|
||||
|
||||
return out.to(output_device).to(output_dtype)
|
||||
|
||||
def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
|
||||
def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None):
|
||||
if dtype is not None:
|
||||
super().inplaceTo(dtype=dtype)
|
||||
if device is not None:
|
||||
|
Reference in New Issue
Block a user