add support for Hugging Face Hub datasets and for specifying a prompt for eval
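The dataset/prompt feature named in the commit message is not visible in the hunks below; as context, here is a minimal sketch of what loading an eval set from the Hugging Face Hub with a user-supplied prompt could look like. load_dataset is the standard entry point of the datasets package; the helper name, dataset argument, and "text" column are illustrative assumptions, not code from this commit.

from datasets import load_dataset  # Hugging Face `datasets` package

def buildEvalSamples(dataset_name: str, eval_prompt: str, split: str = "test") -> list[str]:
    # Hypothetical sketch, not from this diff: prepend a user-specified prompt
    # to every row of a Hub dataset. The "text" column is an assumption.
    dataset = load_dataset(dataset_name, split=split)
    return [f"{eval_prompt}\n{row['text']}" for row in dataset]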
@@ -48,14 +48,22 @@ class LinearGroup:
         for module in self.modules:
             module.decompress()
 
-    def checkDistance(self) -> tuple[float, float]:
-        distance_accum = 0.0
-        error_accum = 0.0
+    def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
+        distance_accum = torch.Tensor()
+        error_accum = torch.Tensor()
         for module in self.modules:
-            distance, error = module.checkDistance()
-            distance_accum += distance**2
-            error_accum += error**2
-        return (math.sqrt(distance_accum) / math.sqrt(len(self.modules)), math.sqrt(error_accum) / math.sqrt(len(self.modules)))
+            distance, error = module.getDistanceAndError()
+            distance = distance.to("cpu")
+            error = error.to("cpu")
+            distance_accum = torch.cat((distance_accum, distance.reshape((distance.numel()))))
+            error_accum = torch.cat((error_accum, error.reshape((error.numel()))))
+        return (distance_accum, error_accum)
 
     def check(self) -> bool:
         for module in self.modules:
             if not module.check():
                 return False
         return True
 
 
 class DyntrainModel:
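Note on the hunk above: getDistanceAndError() now returns the raw per-element distance and error tensors for the whole group instead of the per-group RMS scalars that checkDistance() computed. A minimal sketch of how a caller could reduce the new return value back to an RMS-style summary comparable to the old one (the helper is hypothetical and not part of this commit):

import torch

def summarizeGroupDrift(group) -> tuple[float, float]:
    # Hypothetical helper, not from this diff: reduce the per-element tensors
    # returned by LinearGroup.getDistanceAndError() to RMS scalars, roughly
    # matching what the removed checkDistance() reported.
    distance, error = group.getDistanceAndError()
    rms_distance = torch.sqrt(torch.mean(distance ** 2)).item()
    rms_error = torch.sqrt(torch.mean(error ** 2)).item()
    return rms_distance, rms_error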
@@ -160,15 +168,18 @@ class DyntrainModel:
         total_params = self.dynamicParameters() + self.staticParameters()
         return sum(p.numel() for p in total_params if p.requires_grad)
 
-    def reshuffleActive(self) -> None:
+    def getDistanceAndErrorSample(self) -> tuple[torch.Tensor, torch.Tensor]:
+        index = randint(0, len(self.active_linear_groups) - 1)
+        return self.active_linear_groups[index].getDistanceAndError()
+
+    def reshuffleActive(self):
         active_count = len(self.active_linear_groups)
         index = 0
         while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
-            distance, error = self.active_linear_groups[index].checkDistance()
-            print(f"linear group has moved {distance} with an error of {error}")
             group = self.active_linear_groups.pop(index)
             group.setFrozen(True)
             self.frozen_linear_groups.append(group)
             assert group.check()
 
         params = self.activeParameterCount()
 
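The new getDistanceAndErrorSample() picks one active linear group at random and returns its distance/error tensors, replacing the per-group print removed from reshuffleActive(). A hedged sketch of how a training loop might log that sample (the helper, the step interval, and the message are illustrative assumptions, not code from this commit):

def logDriftSample(model, step: int, interval: int = 100) -> None:
    # Hypothetical logging helper, not from this diff. `model` is assumed to
    # be a DyntrainModel instance.
    if step % interval != 0:
        return
    distance, error = model.getDistanceAndErrorSample()
    print(f"step {step}: sampled group moved {distance.mean().item():.6f} "
          f"with an error of {error.mean().item():.6f}")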
@@ -180,6 +191,7 @@ class DyntrainModel:
             group.setFrozen(False)
             params += group.paramCount()
             self.active_linear_groups.append(group)
             assert group.check()
+        print(math.ceil(params / 1e6))
 
         active_params = self.activeParameterCount()
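The print added above reports the active parameter budget rounded up to whole millions, for example:

import math

params = 1_234_567
print(math.ceil(params / 1e6))  # prints 2, i.e. roughly 2M parameters active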
@@ -248,4 +260,8 @@ class DyntrainModel:
             group_index += 1
 
         for group in tqdm(linear_groups, desc="Preparing layers"):
-            group.compress()
+            if group.isFrozen():
+                group.compress()
+            else:
+                group.decompress()
+            assert group.check()
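The final hunk ties the compression state to the frozen state: frozen groups stay compressed, active groups are decompressed before training, and check() guards the invariant. A small sketch of a post-preparation sanity check one could run (purely illustrative; only isFrozen() and check() from the API shown above are assumed):

def assertPreparedState(linear_groups) -> None:
    # Hypothetical sanity check, not from this diff.
    for group in linear_groups:
        assert group.check(), "linear group in an inconsistent compression state"
    frozen = sum(1 for group in linear_groups if group.isFrozen())
    print(f"{frozen}/{len(linear_groups)} linear groups frozen (compressed)")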