add support for huggingfacehub datasets and for specifying a prompt for eval

2024-05-07 00:23:12 +02:00
parent 8abea9ef89
commit a74ef976e4
5 changed files with 183 additions and 43 deletions
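
Only the model-side diff is reproduced below; the Hugging Face Hub dataset loading and the eval prompt option named in the commit title live in the other changed files. Purely as an illustration of that feature, and not the commit's actual code, loading a Hub dataset with the `datasets` library and prepending an eval prompt could look roughly like this (function name, column name, and arguments are assumptions):

    # Illustrative sketch only -- not the code added by this commit.
    from datasets import load_dataset

    def load_eval_data(hub_id_or_path: str, split: str = "test", eval_prompt: str | None = None):
        # load_dataset() accepts a local path or a Hugging Face Hub dataset id
        dataset = load_dataset(hub_id_or_path, split=split)
        if eval_prompt is not None:
            # prepend the user-supplied prompt to each evaluation sample
            # (the "text" column name is an assumption)
            dataset = dataset.map(lambda row: {"text": eval_prompt + row["text"]})
        return dataset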


@@ -48,14 +48,22 @@ class LinearGroup:
         for module in self.modules:
             module.decompress()
 
-    def checkDistance(self) -> tuple[float, float]:
-        distance_accum = 0.0
-        error_accum = 0.0
+    def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
+        distance_accum = torch.Tensor()
+        error_accum = torch.Tensor()
         for module in self.modules:
-            distance, error = module.checkDistance()
-            distance_accum += distance**2
-            error_accum += error**2
-        return (math.sqrt(distance_accum) / math.sqrt(len(self.modules)), math.sqrt(error_accum) / math.sqrt(len(self.modules)))
+            distance, error = module.getDistanceAndError()
+            distance = distance.to("cpu")
+            error = error.to("cpu")
+            distance_accum = torch.cat((distance_accum, distance.reshape((distance.numel()))))
+            error_accum = torch.cat((error_accum, error.reshape((error.numel()))))
+        return (distance_accum, error_accum)
+
+    def check(self) -> bool:
+        for module in self.modules:
+            if not module.check():
+                return False
+        return True
 
 
 class DyntrainModel:
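
The hunk above replaces checkDistance(), which reduced each group to two scalars, with getDistanceAndError(), which concatenates the per-element distances and errors of all modules into two CPU tensors. A caller that still wants a scalar summary can reduce on its side, for example (illustrative only; `group` is an assumed LinearGroup instance):

    # Recover a scalar RMS-style summary from the per-element tensors
    # returned by LinearGroup.getDistanceAndError() (sketch, not commit code).
    distance, error = group.getDistanceAndError()
    rms_distance = distance.square().mean().sqrt().item()
    rms_error = error.square().mean().sqrt().item()
    print(f"linear group has moved {rms_distance:.4e} with an error of {rms_error:.4e}")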
@@ -160,15 +168,18 @@ class DyntrainModel:
         total_params = self.dynamicParameters() + self.staticParameters()
         return sum(p.numel() for p in total_params if p.requires_grad)
 
-    def reshuffleActive(self) -> None:
+    def getDistanceAndErrorSample(self) -> tuple[torch.Tensor, torch.Tensor]:
+        index = randint(0, len(self.active_linear_groups) - 1)
+        return self.active_linear_groups[index].getDistanceAndError()
+
+    def reshuffleActive(self):
         active_count = len(self.active_linear_groups)
         index = 0
         while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
-            distance, error = self.active_linear_groups[index].checkDistance()
-            print(f"linear group has moved {distance} with an error of {error}")
             group = self.active_linear_groups.pop(index)
             group.setFrozen(True)
             self.frozen_linear_groups.append(group)
+            assert group.check()
 
         params = self.activeParameterCount()
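
With the per-group print removed from reshuffleActive(), the new getDistanceAndErrorSample() lets the training loop report movement on its own schedule by sampling one active group at random. A possible call site (illustrative; `model`, `step`, and `log_interval` are assumed names):

    # Sketch of periodic logging from a training loop, not part of the commit.
    if step % log_interval == 0:
        distance, error = model.getDistanceAndErrorSample()
        print(f"sampled group: mean distance {distance.mean().item():.4e}, "
              f"mean error {error.mean().item():.4e}")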
@@ -180,6 +191,7 @@ class DyntrainModel:
             group.setFrozen(False)
             params += group.paramCount()
             self.active_linear_groups.append(group)
+            assert group.check()
         print(math.ceil(params / 1e6))
 
         active_params = self.activeParameterCount()
@@ -248,4 +260,8 @@ class DyntrainModel:
             group_index += 1
 
         for group in tqdm(linear_groups, desc="Perpareing layers"):
-            group.compress()
+            if group.isFrozen():
+                group.compress()
+            else:
+                group.decompress()
+            assert group.check()
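
The asserts added in this commit rely on the per-module check() that LinearGroup.check() delegates to; that module code is not part of this file's diff. A hypothetical implementation, shown only to illustrate the invariant the asserts appear to guard (isCompressed() is an assumed helper, not a method from the repository):

    # Hypothetical module-level check: a frozen module should be compressed
    # and a trainable one decompressed. Purely illustrative.
    def check(self) -> bool:
        frozen = not any(p.requires_grad for p in self.parameters())
        return frozen == self.isCompressed()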