Add GPU memory rebalancing

This commit is contained in:
2024-03-17 22:54:33 +01:00
parent 5acb6809ed
commit 38a7f7cfc4
3 changed files with 78 additions and 39 deletions

View File

@ -25,10 +25,16 @@ class Linear(torch.nn.Linear):
return not self.weight.requires_grad
def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
    """Convert this layer's parameters to ``dtype`` and/or ``device`` in place.

    ``weight`` (and ``bias``, when the layer has one) are replaced by new
    ``torch.nn.Parameter`` objects holding the converted tensors. The frozen
    state reported by ``isFrozen()`` before the conversion is re-applied
    through ``setFrozen()`` afterwards.

    Args:
        dtype: Target dtype, or ``None`` to keep the current dtype.
        device: Target device, or ``None`` to keep the current device.
    """
    frozen = self.isFrozen()
    if dtype is not None or device is not None:
        # Convert dtype and device in a single ``.to`` call so that no
        # intermediate copy (new dtype on the old device) is materialized.
        # ``Tensor.to`` treats a ``None`` dtype/device as "leave unchanged".
        #
        # The replacement parameters are created with the layer's current
        # trainability instead of the ``requires_grad=True`` default:
        # ``torch.nn.Parameter`` refuses ``requires_grad=True`` for
        # non-floating dtypes, which would make a frozen layer fail to
        # convert to e.g. an integer dtype.
        self.weight = torch.nn.Parameter(
            self.weight.to(device=device, dtype=dtype),
            requires_grad=not frozen,
        )
        if self.bias is not None:
            self.bias = torch.nn.Parameter(
                self.bias.to(device=device, dtype=dtype),
                requires_grad=not frozen,
            )
    # NOTE(review): setFrozen is defined outside this view; it is still
    # called unconditionally, as before, in case it does more than toggle
    # requires_grad on the weight.
    self.setFrozen(frozen)
class ConvertingLinear(Linear):
@ -63,6 +69,6 @@ class ConvertingLinear(Linear):
if input.dtype != self.weight.dtype:
input = input.to(self.weight.dtype)
output = torch.nn.Linear.forward(self, input)
if torch.isnan(output).any() or self.weight.dtype != torch.float32:
if torch.isnan(output).any():
breakpoint()
return output.to(output_device).to(output_dtype)