Fix mypy warnings

This commit is contained in:
Carl Philipp Klemm
2024-05-07 19:48:40 +02:00
parent a74ef976e4
commit 68f748e99e
4 changed files with 34 additions and 30 deletions

View File

@ -39,7 +39,7 @@ class Linear(torch.nn.Linear):
def isFrozen(self) -> bool:
return not self.weight.requires_grad
def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None):
frozen = self.isFrozen()
if dtype is not None:
self.weight = torch.nn.Parameter(self.weight.to(dtype))
@ -77,7 +77,7 @@ class DynamicConvertingLinear(Linear):
self.output_device = output_device
@classmethod
def fromLinear(cls, in_module: torch.nn.Linear, output_dtype, output_device=None):
def fromLinear(cls, in_module: torch.nn.Linear, output_dtype=torch.float32, output_device=None):
new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
out_features=in_module.out_features,
bias=in_module.bias is not None,
@ -124,7 +124,7 @@ class DynamicQantizedLinear(Linear):
self.weight_start = self.weight.clone().detach()
@classmethod
def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device, cold_device: torch.device,
def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device = torch.device("cuda:0"), cold_device: torch.device = torch.device("cpu"),
output_dtype=None, compute_dtype=torch.float16, output_device=None):
new_module = cls(in_features=in_module.in_features, out_features=in_module.out_features, bias=in_module.bias is not None,
active_device=active_device, cold_device=cold_device, output_dtype=output_dtype,
@ -193,7 +193,7 @@ class DynamicQantizedLinear(Linear):
return out.to(output_device).to(output_dtype)
def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None):
if dtype is not None:
super().inplaceTo(dtype=dtype)
if device is not None: