working full training

This commit is contained in:
2024-03-09 10:03:37 +01:00
parent 7a47fcdcc0
commit 11ea9eeaa7
2 changed files with 25 additions and 7 deletions

View File

@ -2,11 +2,12 @@ import torch
class ConvertingLinear(torch.nn.Linear):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
super().__init__(in_features, out_features, bias, device, dtype)
self.output_dtype = output_dtype
def forward(self, input: torch.Tensor):
output_dtype = input.dtype
output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
output_device = input.device
if input.device != self.weight.device:
input = input.to(self.weight.device)