wip refactor
This commit is contained in:
68
modules.py
Normal file
68
modules.py
Normal file
@ -0,0 +1,68 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Linear(torch.nn.Linear):
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
|
||||
@classmethod
|
||||
def fromLinear(cls, in_module: torch.nn.Linear):
|
||||
new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
|
||||
out_features=in_module.out_features,
|
||||
bias=in_module.bias is not None,
|
||||
device=in_module.weight.device,
|
||||
dtype=in_module.weight.dtype)
|
||||
new_module.weight = in_module.weight
|
||||
new_module.bias = in_module.bias
|
||||
return new_module
|
||||
|
||||
def setFrozen(self, frozen: bool):
|
||||
self.weight.requires_grad = not frozen
|
||||
if self.bias is not None:
|
||||
self.bias.requires_grad = not frozen
|
||||
|
||||
def isFrozen(self) -> bool:
|
||||
return not self.weight.requires_grad
|
||||
|
||||
def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
|
||||
if dtype is not None:
|
||||
self.weight = torch.nn.Parameter(self.weight.to(dtype))
|
||||
if device is not None:
|
||||
self.weight = torch.nn.Parameter(self.weight.to(device))
|
||||
|
||||
|
||||
class ConvertingLinear(Linear):
|
||||
def __init__(self,
|
||||
in_features, out_features, bias=True, device=None, dtype=None,
|
||||
output_dtype=None, output_device=None):
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
self.output_dtype = output_dtype
|
||||
self.output_device = output_device
|
||||
|
||||
@classmethod
|
||||
def fromLinear(cls, in_module: torch.nn.Linear, output_dtype, output_device=None):
|
||||
new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
|
||||
out_features=in_module.out_features,
|
||||
bias=in_module.bias is not None,
|
||||
device=in_module.weight.device,
|
||||
dtype=in_module.weight.dtype)
|
||||
new_module.output_dtype = output_dtype
|
||||
new_module.output_device = output_device
|
||||
new_module.weight = in_module.weight
|
||||
new_module.bias = in_module.bias
|
||||
return new_module
|
||||
|
||||
def setOutputDevice(self, output_device: torch.device):
|
||||
self.output_device = output_device
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
|
||||
output_device = input.device if self.output_device is None else self.output_device
|
||||
if input.device != self.weight.device:
|
||||
input = input.to(self.weight.device)
|
||||
if input.dtype != self.weight.dtype:
|
||||
input = input.to(self.weight.dtype)
|
||||
output = torch.nn.Linear.forward(self, input)
|
||||
if torch.isnan(output).any() or self.weight.dtype != torch.float32:
|
||||
breakpoint()
|
||||
return output.to(output_device).to(output_dtype)
|
Reference in New Issue
Block a user