various optimizations
@@ -68,7 +68,9 @@ class LinearGroup:
 class DyntrainModel:
     def __init__(self, model_name_or_path: str, cache_dir: str | None, quantize: bool,
-                 target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
+                 target_active_params: int, train_static_params: bool,
+                 reshuffle_fraction: float, gradient_checkpointing: bool,
+                 trust_remote_code: bool = False):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
@@ -82,6 +84,7 @@ class DyntrainModel:
             raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
         self.devices = list[torch.device]()
         self.inital_reshufle = True
+        self.train_static_params = train_static_params

         if gradient_checkpointing:
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
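
Note: the two hunks above thread a new train_static_params flag through the constructor and store it on the instance. A hypothetical construction call, assuming DyntrainModel is importable; every argument value below is invented for illustration, only the parameter names come from the signature in the diff:

    model = DyntrainModel(
        "some-org/some-model",              # placeholder model name
        cache_dir=None,
        quantize=False,
        target_active_params=1_000_000_000,
        train_static_params=True,           # new flag introduced by this commit
        reshuffle_fraction=0.1,
        gradient_checkpointing=True,
    )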
@@ -167,8 +170,14 @@ class DyntrainModel:
     def staticParameterCount(self) -> int:
         return sum(p.numel() for p in self.staticParameters())

     def activeDynamicParameterCount(self) -> int:
         return sum(p.numel() for p in self.dynamicParameters() if p.requires_grad)

     def activeParameterCount(self) -> int:
-        total_params = self.dynamicParameters() + self.staticParameters()
+        if self.train_static_params:
+            total_params = self.dynamicParameters() + self.staticParameters()
+        else:
+            total_params = self.dynamicParameters()
         return sum(p.numel() for p in total_params if p.requires_grad)

     def getDistanceAndErrorSample(self) -> (torch.Tensor, torch.Tensor):
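
Note: the reworked activeParameterCount() now folds static parameters into the total only when the flag is set; frozen tensors are excluded either way by the requires_grad filter. A minimal self-contained sketch of that counting rule, using bare torch.nn.Parameter lists in place of the model's real dynamic/static parameter groups:

    import torch

    def active_parameter_count(dynamic, static, train_static_params: bool) -> int:
        # Same rule as DyntrainModel.activeParameterCount() after this commit:
        # static parameters only enter the total when train_static_params is
        # set, and frozen tensors (requires_grad=False) never count either way.
        total_params = dynamic + static if train_static_params else dynamic
        return sum(p.numel() for p in total_params if p.requires_grad)

    dynamic = [torch.nn.Parameter(torch.zeros(10))]  # trainable, 10 params
    static = [torch.nn.Parameter(torch.zeros(7))]    # trainable, 7 params
    assert active_parameter_count(dynamic, static, train_static_params=True) == 17
    assert active_parameter_count(dynamic, static, train_static_params=False) == 10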
@@ -187,7 +196,7 @@ class DyntrainModel:
         params = self.activeParameterCount()

         if params >= self.target_active_params:
-            RuntimeError("Insuficant active parameters to suffle active")
+            raise RuntimeError("Insuficant active parameters to suffle active")
         while params < self.target_active_params and len(self.frozen_linear_groups) > 0:
             i = randint(0, len(self.frozen_linear_groups) - 1)
             group = self.frozen_linear_groups.pop(i)
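
Note: the fix above addresses a classic Python pitfall: the old line constructed a RuntimeError object and immediately discarded it, so the guard never stopped execution; adding raise makes it take effect. A tiny self-contained illustration (function names are invented):

    def guard_without_raise(params: int, target: int) -> None:
        if params >= target:
            RuntimeError("exception object is created, then silently discarded")
        print("execution continues past the guard")  # reached even when params >= target

    def guard_with_raise(params: int, target: int) -> None:
        if params >= target:
            raise RuntimeError("the guard now actually stops execution")

    guard_without_raise(10, 5)  # prints despite the condition being true
    guard_with_raise(3, 5)      # returns normally: condition not met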
@@ -199,7 +208,7 @@ class DyntrainModel:

         active_params = self.activeParameterCount()

-        assert self.target_active_params * 1.3 > active_params and self.target_active_params * 0.7 < active_params
+        assert self.target_active_params * 1.4 > active_params and self.target_active_params * 0.6 < active_params

     def activeParamtersByDevice(self) -> list[int]:
         out = [0] * len(self.devices)
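
Note: the assert above widens the acceptable band around target_active_params from ±30% to ±40%, loosening the post-reshuffle sanity check.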
@@ -213,7 +222,7 @@ class DyntrainModel:
         for i, count in enumerate(active_counts):
             memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
             if i == 0:
-                memory = int(memory * 0.8)
+                memory = int(memory * 0.5)
             bits_per_param.append(count / memory)

         max_index, max_bits_per_param = max(enumerate(active_counts), key=lambda x: x[1])
@@ -223,7 +232,7 @@ class DyntrainModel:
             if group.getDevice() is self.devices[max_index]:
                 memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
                 if max_index == 0:
-                    memory = int(memory * 0.8)
+                    memory = int(memory * 0.5)
                 swing = group.paramCount() / memory
                 if max_bits_per_param - swing > min_bits_per_param + swing:
                     group.inplaceTo(device=self.devices[min_index])
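
Note: both hunks above shrink the usable-memory estimate for device 0 from 80% to 50% of total VRAM; the diff does not state why, but a plausible reading is that the first GPU needs extra headroom for allocations beyond the parameters themselves. A self-contained sketch of the params-per-byte load metric with the new factor; the function name and memory figures are invented stand-ins for torch.cuda.get_device_properties():

    def params_per_byte(active_counts: list[int], total_memory: list[int],
                        device0_usable_fraction: float = 0.5) -> list[float]:
        # Mirrors the rebalancing metric above: device 0 is treated as having
        # only a fraction of its memory usable (0.5 after this commit, 0.8 before).
        out = []
        for i, count in enumerate(active_counts):
            memory = total_memory[i]
            if i == 0:
                memory = int(memory * device0_usable_fraction)
            out.append(count / memory)
        return out

    gib = 2**30
    print(params_per_byte([3_000_000_000, 2_000_000_000], [24 * gib, 24 * gib]))
    # Device 0 shows a higher load per usable byte even at comparable counts.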