add gpu memory rebalancing
This commit is contained in:
@ -25,10 +25,16 @@ class Linear(torch.nn.Linear):
|
||||
return not self.weight.requires_grad
|
||||
|
||||
def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
    """Cast and/or move this layer's parameters in place.

    Re-wrapping a tensor in ``torch.nn.Parameter`` resets ``requires_grad``,
    so the frozen state is captured via ``isFrozen()`` before the conversion
    and restored with ``setFrozen()`` afterwards.

    Args:
        dtype: target dtype, or ``None`` to keep the current dtype.
        device: target device, or ``None`` to keep the current device.
    """
    frozen = self.isFrozen()
    if dtype is not None or device is not None:
        # A single .to(dtype=..., device=...) call performs one copy per
        # tensor; the previous code converted dtype and device separately,
        # copying each tensor twice when both arguments were supplied.
        # .to() treats a None dtype/device as "leave unchanged".
        self.weight = torch.nn.Parameter(
            self.weight.to(dtype=dtype, device=device)
        )
        if self.bias is not None:
            self.bias = torch.nn.Parameter(
                self.bias.to(dtype=dtype, device=device)
            )
    self.setFrozen(frozen)
|
||||
|
||||
|
||||
class ConvertingLinear(Linear):
|
||||
@ -63,6 +69,6 @@ class ConvertingLinear(Linear):
|
||||
if input.dtype != self.weight.dtype:
|
||||
input = input.to(self.weight.dtype)
|
||||
output = torch.nn.Linear.forward(self, input)
|
||||
if torch.isnan(output).any() or self.weight.dtype != torch.float32:
|
||||
if torch.isnan(output).any():
|
||||
breakpoint()
|
||||
return output.to(output_device).to(output_dtype)
|
||||
|
Reference in New Issue
Block a user