Fix mypy warnings

Carl Philipp Klemm
2024-05-07 19:48:40 +02:00
parent a74ef976e4
commit 68f748e99e
4 changed files with 34 additions and 30 deletions


@@ -16,7 +16,7 @@ class LinearGroup:
for module in self.modules:
assert isinstance(module, Linear)
-def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None, output_device: torch.device = None) -> None:
+def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None, output_device: torch.device | None = None) -> None:
for module in self.modules:
module.inplaceTo(dtype, device)
self.modules[-1].setOutputDevice(output_device)
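Note on the signature change above: recent mypy versions no longer treat a parameter with a None default as implicitly Optional, so the annotation has to allow None explicitly. A minimal illustrative sketch (hypothetical function names, assuming Python 3.10+ for the "|" syntax):

    import torch

    def before(dtype: torch.dtype = None) -> None:
        # mypy flags an incompatible default: "None" is not a "dtype"
        pass

    def after(dtype: torch.dtype | None = None) -> None:
        # accepted: the annotation now explicitly permits None
        pass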
@@ -67,7 +67,7 @@ class LinearGroup:
class DyntrainModel:
-def __init__(self, model_name_or_path: str, cache_dir: str, quantize: bool,
+def __init__(self, model_name_or_path: str, cache_dir: str | None, quantize: bool,
target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
self.model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
@@ -80,19 +80,19 @@ class DyntrainModel:
self.reshuffle_fraction = reshuffle_fraction
if reshuffle_fraction < 0.10 or reshuffle_fraction > 1:
raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
-self.devices = list()
+self.devices = list[torch.device]()
self.inital_reshufle = True
if gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
-self.frozen_linear_groups = list()
-self.active_linear_groups = list()
+self.frozen_linear_groups = list[LinearGroup]()
+self.active_linear_groups = list[LinearGroup]()
linear_group_names = DyntrainModel._getLinearGroupNames(self.model)
for group in linear_group_names:
for key in group:
-replace_module(self.model, key, self._getModule(key, quantize, "cuda:0", "cpu"))
+replace_module(self.model, key, self._getModule(key, quantize, torch.device("cuda:0"), torch.device("cpu")))
self.frozen_linear_groups.append(LinearGroup(self.model, group))
self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
for group in self.frozen_linear_groups:
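The list and device changes in this hunk follow one pattern: a bare list() gives mypy no element type to infer (it asks for a type annotation), while calling the subscripted alias builds an empty list with a known element type; likewise, constructing torch.device objects avoids passing plain strings where _getModule presumably annotates its parameters as torch.device. A short sketch with illustrative variable names:

    import torch

    bad = list()                          # mypy: needs a type annotation for "bad"
    good = list[torch.device]()           # inferred as list[torch.device]
    good.append(torch.device("cuda:0"))

An annotation on the left-hand side (good: list[torch.device] = []) would silence the same warning; calling the parameterized alias keeps the fix inside a single expression.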
@@ -106,6 +106,7 @@ class DyntrainModel:
else:
return DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=output_dtype)
+@staticmethod
def _getNonlinearNames(layer: torch.nn.Module):
names = list()
modules = dict(layer.named_modules())
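The @staticmethod added above matters to mypy because the first parameter of an undecorated method is treated as self, and an annotation that does not match the enclosing class is reported as an invalid self type. Illustrative sketch (hypothetical class and method names):

    import torch

    class Example:
        def helper(layer: torch.nn.Module) -> list[str]:
            # mypy reports an invalid/missing self argument for this non-static method
            return [name for name, _ in layer.named_modules()]

        @staticmethod
        def helper_static(layer: torch.nn.Module) -> list[str]:
            # accepted: no implicit self parameter
            return [name for name, _ in layer.named_modules()]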
@@ -115,8 +116,9 @@ class DyntrainModel:
names.append(key)
return names
+@staticmethod
def _getLinearGroupNames(layer: torch.nn.Module) -> list[list[str]]:
-linear_groups = list()
+linear_groups = list[list[str]]()
list_counter = 0
in_sequence = False
modules = dict(layer.named_modules())
@@ -132,6 +134,7 @@ class DyntrainModel:
list_counter = list_counter + 1
return linear_groups
+@staticmethod
def isModuleIn16bitOutlist(key: str) -> bool:
key = key.split('.')[-1]
whitelist = set({
@@ -210,7 +213,7 @@ class DyntrainModel:
for i, count in enumerate(active_counts):
memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
if i == 0:
-memory = memory * 0.8
+memory = int(memory * 0.8)
bits_per_param.append(count / memory)
max_index, max_bits_per_param = max(enumerate(active_counts), key=lambda x: x[1])
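The int() casts in this and the following hunks resolve an int/float mismatch: total_memory is an integer byte count, multiplying it by 0.8 yields a float, and reassigning that float to a variable mypy already inferred as int is an incompatible assignment (the augmented total_memory -= ... a couple of hunks down trips the same check). A tiny sketch with a stand-in value:

    memory: int = 8_000_000_000   # stand-in for get_device_properties(...).total_memory
    memory = memory * 0.8         # mypy: expression has type "float", variable has type "int"
    memory = int(memory * 0.8)    # accepted; the 20% headroom stays an integer byte count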
@@ -220,7 +223,7 @@ class DyntrainModel:
if group.getDevice() is self.devices[max_index]:
memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
if max_index == 0:
-memory = memory * 0.8
+memory = int(memory * 0.8)
swing = group.paramCount() / memory
if max_bits_per_param - swing > min_bits_per_param + swing:
group.inplaceTo(device=self.devices[min_index])
@@ -230,7 +233,7 @@ class DyntrainModel:
assert len(devices) > 0
modules = dict(self.model.named_modules())
total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in devices)
-total_memory -= torch.cuda.get_device_properties(devices[0]).total_memory * 0.2
+total_memory -= int(torch.cuda.get_device_properties(devices[0]).total_memory * 0.2)
static_param_count = self.staticParameterCount()
total_parameter_count = static_param_count + self.dynamicParameterCount()
params_per_byte = total_parameter_count / float(total_memory)
@@ -245,9 +248,9 @@ class DyntrainModel:
group_index = 0
for i, device in enumerate(devices[:-1]):
-memory = torch.cuda.get_device_properties(devices).total_memory
+memory = torch.cuda.get_device_properties(device).total_memory
if i == 0:
-memory = memory * 0.8
+memory = int(memory * 0.8)
params_for_device = memory * params_per_byte
params = 0
while params_for_device > params and group_index < len(linear_groups):
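The last hunk also fixes a real bug that the type checker surfaced: inside the loop, get_device_properties was called with the whole devices list instead of the loop variable device, which is both an argument-type error for mypy and a runtime failure. A hypothetical sketch of the corrected loop:

    import torch

    devices = [torch.device("cuda:0"), torch.device("cuda:1")]
    for i, device in enumerate(devices[:-1]):
        # passing `devices` (a list) here is rejected; a single device is expected
        memory = torch.cuda.get_device_properties(device).total_memory
        if i == 0:
            memory = int(memory * 0.8)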