working full training
This commit is contained in:
@ -2,11 +2,12 @@ import torch
|
||||
|
||||
|
||||
class ConvertingLinear(torch.nn.Linear):
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
self.output_dtype = output_dtype
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
output_dtype = input.dtype
|
||||
output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
|
||||
output_device = input.device
|
||||
if input.device != self.weight.device:
|
||||
input = input.to(self.weight.device)
|
||||
|
Reference in New Issue
Block a user