From 68f748e99e00714a95bc77b79c1d8e55d79b7786 Mon Sep 17 00:00:00 2001
From: Carl Philipp Klemm
Date: Tue, 7 May 2024 19:48:40 +0200
Subject: [PATCH] Fix mypy warnings

---
 arguments.py     | 13 ++++++-------
 dyntrainmodel.py | 27 +++++++++++++++------------
 modules.py       |  8 ++++----
 train_dynamic.py | 16 +++++++++-------
 4 files changed, 34 insertions(+), 30 deletions(-)

diff --git a/arguments.py b/arguments.py
index 9c35cca..ef13b1b 100644
--- a/arguments.py
+++ b/arguments.py
@@ -4,6 +4,9 @@ from typing import Optional
 
 @dataclass
 class DataArguments:
+    dataset: str = field(
+        metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
+    )
     eval_dataset_size: int = field(
         default=512, metadata={"help": "Size of validation dataset."}
     )
@@ -23,10 +26,6 @@ class DataArguments:
         default=False,
         metadata={"help": "If this is set the dataset is assumed to be a name of a hf-hub dataset"}
     )
-    dataset: str = field(
-        default=None,
-        metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
-    )
     block_size: int = field(
         default=512,
         metadata={"help": "size of the blocks the text is split into for training"},
@@ -35,7 +34,7 @@ class DataArguments:
 
 @dataclass
 class ModelArguments:
-    model_name_or_path: Optional[str] = field(
+    model_name_or_path: str = field(
         default="EleutherAI/pythia-12b"
     )
     tokenizer: Optional[str] = field(
@@ -49,10 +48,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Never resize tokenizer embeddings"}
     )
-    quantize: Optional[bool] = field (
+    quantize: bool = field(
         default=False,
         metadata={"help": "Quantize parameters not currently be actively trained"}
-        )
+    )
 
 
 @dataclass
diff --git a/dyntrainmodel.py b/dyntrainmodel.py
index 3fd6cb7..e6a1638 100644
--- a/dyntrainmodel.py
+++ b/dyntrainmodel.py
@@ -16,7 +16,7 @@ class LinearGroup:
         for module in self.modules:
             assert isinstance(module, Linear)
 
-    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None, output_device: torch.device = None) -> None:
+    def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None, output_device: torch.device | None = None) -> None:
         for module in self.modules:
             module.inplaceTo(dtype, device)
         self.modules[-1].setOutputDevice(output_device)
@@ -67,7 +67,7 @@ class LinearGroup:
 
 
 class DyntrainModel:
-    def __init__(self, model_name_or_path: str, cache_dir: str, quantize: bool,
+    def __init__(self, model_name_or_path: str, cache_dir: str | None, quantize: bool,
                  target_active_params: int, reshuffle_fraction: float, gradient_checkpointing: bool, trust_remote_code: bool = False):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name_or_path,
@@ -80,19 +80,19 @@ class DyntrainModel:
         self.reshuffle_fraction = reshuffle_fraction
         if reshuffle_fraction < 0.10 or reshuffle_fraction > 1:
             raise RuntimeError("reshuffle_percent must be between 0.1 and 1.0")
-        self.devices = list()
+        self.devices = list[torch.device]()
         self.inital_reshufle = True
 
         if gradient_checkpointing:
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
 
-        self.frozen_linear_groups = list()
-        self.active_linear_groups = list()
+        self.frozen_linear_groups = list[LinearGroup]()
+        self.active_linear_groups = list[LinearGroup]()
 
         linear_group_names = DyntrainModel._getLinearGroupNames(self.model)
         for group in linear_group_names:
             for key in group:
-                replace_module(self.model, key, self._getModule(key, quantize, "cuda:0", "cpu"))
+                replace_module(self.model, key, self._getModule(key, quantize, torch.device("cuda:0"), torch.device("cpu")))
             self.frozen_linear_groups.append(LinearGroup(self.model, group))
         self.model.model.embed_tokens = self.model.model.embed_tokens.to(torch.float16)
         for group in self.frozen_linear_groups:
@@ -106,6 +106,7 @@ class DyntrainModel:
         else:
             return DynamicConvertingLinear.fromLinear(modules[key].to(torch.float16), output_dtype=output_dtype)
 
+    @staticmethod
     def _getNonlinearNames(layer: torch.nn.Module):
         names = list()
         modules = dict(layer.named_modules())
@@ -115,8 +116,9 @@ class DyntrainModel:
                 names.append(key)
         return names
 
+    @staticmethod
     def _getLinearGroupNames(layer: torch.nn.Module) -> list[list[str]]:
-        linear_groups = list()
+        linear_groups = list[list[str]]()
         list_counter = 0
         in_sequence = False
         modules = dict(layer.named_modules())
@@ -132,6 +134,7 @@ class DyntrainModel:
                 list_counter = list_counter + 1
         return linear_groups
 
+    @staticmethod
     def isModuleIn16bitOutlist(key: str) -> bool:
         key = key.split('.')[-1]
         whitelist = set({
@@ -210,7 +213,7 @@ class DyntrainModel:
         for i, count in enumerate(active_counts):
             memory = torch.cuda.get_device_properties(self.devices[i]).total_memory
             if i == 0:
-                memory = memory * 0.8
+                memory = int(memory * 0.8)
             bits_per_param.append(count / memory)
 
         max_index, max_bits_per_param = max(enumerate(active_counts), key=lambda x: x[1])
@@ -220,7 +223,7 @@ class DyntrainModel:
             if group.getDevice() is self.devices[max_index]:
                 memory = torch.cuda.get_device_properties(self.devices[max_index]).total_memory
                 if max_index == 0:
-                    memory = memory * 0.8
+                    memory = int(memory * 0.8)
                 swing = group.paramCount() / memory
                 if max_bits_per_param - swing > min_bits_per_param + swing:
                     group.inplaceTo(device=self.devices[min_index])
@@ -230,7 +233,7 @@ class DyntrainModel:
         assert len(devices) > 0
         modules = dict(self.model.named_modules())
         total_memory = sum(torch.cuda.get_device_properties(d).total_memory for d in devices)
-        total_memory -= torch.cuda.get_device_properties(devices[0]).total_memory * 0.2
+        total_memory -= int(torch.cuda.get_device_properties(devices[0]).total_memory * 0.2)
         static_param_count = self.staticParameterCount()
         total_parameter_count = static_param_count + self.dynamicParameterCount()
         params_per_byte = total_parameter_count / float(total_memory)
@@ -245,9 +248,9 @@ class DyntrainModel:
 
         group_index = 0
         for i, device in enumerate(devices[:-1]):
-            memory = torch.cuda.get_device_properties(devices).total_memory
+            memory = torch.cuda.get_device_properties(device).total_memory
             if i == 0:
-                memory = memory * 0.8
+                memory = int(memory * 0.8)
             params_for_device = memory * params_per_byte
             params = 0
             while params_for_device > params and group_index < len(linear_groups):
diff --git a/modules.py b/modules.py
index e8a4eae..9ff9dee 100644
--- a/modules.py
+++ b/modules.py
@@ -39,7 +39,7 @@ class Linear(torch.nn.Linear):
     def isFrozen(self) -> bool:
         return not self.weight.requires_grad
 
-    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
+    def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None):
         frozen = self.isFrozen()
         if dtype is not None:
             self.weight = torch.nn.Parameter(self.weight.to(dtype))
@@ -77,7 +77,7 @@ class DynamicConvertingLinear(Linear):
         self.output_device = output_device
 
     @classmethod
-    def fromLinear(cls, in_module: torch.nn.Linear, output_dtype, output_device=None):
+    def fromLinear(cls, in_module: torch.nn.Linear, output_dtype=torch.float32, output_device=None):
         new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
                                               out_features=in_module.out_features,
                                               bias=in_module.bias is not None,
@@ -124,7 +124,7 @@ class DynamicQantizedLinear(Linear):
         self.weight_start = self.weight.clone().detach()
 
     @classmethod
-    def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device, cold_device: torch.device,
+    def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device = torch.device("cuda:0"), cold_device: torch.device = torch.device("cpu"),
                    output_dtype=None, compute_dtype=torch.float16, output_device=None):
         new_module = cls(in_features=in_module.in_features, out_features=in_module.out_features, bias=in_module.bias is not None,
                          active_device=active_device, cold_device=cold_device, output_dtype=output_dtype,
@@ -193,7 +193,7 @@ class DynamicQantizedLinear(Linear):
 
         return out.to(output_device).to(output_dtype)
 
-    def inplaceTo(self, dtype: torch.dtype = None, device: torch.device = None):
+    def inplaceTo(self, dtype: torch.dtype | None = None, device: torch.device | None = None):
         if dtype is not None:
             super().inplaceTo(dtype=dtype)
         if device is not None:
diff --git a/train_dynamic.py b/train_dynamic.py
index c04e787..96ff497 100644
--- a/train_dynamic.py
+++ b/train_dynamic.py
@@ -22,22 +22,22 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int =
     model.save_pretrained(output_dir)
 
     if max_checkpoints > 0:
-        files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.starts_with("step_")]
+        files = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.startswith("step_")]
 
         def extract_step(filename):
             tokens = filename.split('_')
             return int(tokens[1])
 
         if len(files) > max_checkpoints:
-            min_step = min(map(extract_step, extract_step))
+            min_step = min(map(extract_step, files))
             delete_checkpoit_dir = os.path.join(output_dir, f"step_{min_step}")
             print(f"there are more than {max_checkpoints} checkpints saved, deleting {delete_checkpoit_dir}")
             shutil.rmtree(delete_checkpoit_dir)
 
 
-def get_optimizer(dyamic_parameters: list[torch.nn.parameter], static_parameters: list[torch.nn.parameter], lr: float, static_lr: float,
+def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters: list[torch.nn.Parameter] | None, lr: float, static_lr: float,
                   weight_decay: float, eps: float, adam8bit: bool):
-    parameters = list()
+    parameters = list[dict]()
     parameters.extend({'params': p} for p in dyamic_parameters if p.requires_grad)
     param_ids = set([id(p['params']) for p in parameters])
     if static_parameters is not None:
@@ -71,6 +71,7 @@ def evaluate(model: DyntrainModel, tokenizer,
     loss = loss / len(dataloader)
     log_writer.add_scalar("Loss/Eval", loss, globalstep)
     print(f"Eval Loss {loss.item()}")
+    return loss.item()
 
     if eval_prompt is not None:
         input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
@@ -84,7 +85,7 @@ def evaluate(model: DyntrainModel, tokenizer,
 def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
     log_writer = tensorboard.SummaryWriter()
 
-    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=training_args.max_instant_params * 1e6,
+    model = DyntrainModel(model_args.model_name_or_path, training_args.cache_dir, target_active_params=int(training_args.max_instant_params * 1e6),
                           reshuffle_fraction=training_args.churn_percent / 100.0, gradient_checkpointing=True,
                           trust_remote_code=True, quantize=model_args.quantize)
     devices = list(torch.device(i) for i in range(0, torch.cuda.device_count()))
@@ -95,7 +96,8 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     paramter_count = sum(p.numel() for p in model.model.parameters())
     active_paramter_count = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
     static_parameter_count = model.staticParameterCount() if training_args.train_non_linear_layers else 0
-    print(f"Training model with {paramter_count/1e6}m parameters and {active_paramter_count/1e6}m instantanous active paramters of which {static_parameter_count} are static")
+    print(f"Training model with {paramter_count / 1e6}m parameters and {active_paramter_count / 1e6}m"
+          f"instantanous active paramters of which {static_parameter_count} are static")
 
     tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)
 
@@ -122,7 +124,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     ) if dataset['eval_dataset'] is not None else None
 
     dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
-    steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
+    steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
     total_steps = steps_per_epoch * training_args.epochs
 
     optimizer = get_optimizer(model.dynamicParameters(),