diff --git a/convertinglinear.py b/convertinglinear.py
index b5c494f..09bad44 100644
--- a/convertinglinear.py
+++ b/convertinglinear.py
@@ -2,11 +2,12 @@ import torch
 
 
 class ConvertingLinear(torch.nn.Linear):
-    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
+    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
         super().__init__(in_features, out_features, bias, device, dtype)
+        self.output_dtype = output_dtype
 
     def forward(self, input: torch.Tensor):
-        output_dtype = input.dtype
+        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
         output_device = input.device
         if input.device != self.weight.device:
             input = input.to(self.weight.device)
diff --git a/train_dynamic.py b/train_dynamic.py
index 0c92f39..18f5a16 100644
--- a/train_dynamic.py
+++ b/train_dynamic.py
@@ -7,9 +7,11 @@ from torch.utils import tensorboard
 import os
 import shutil
 import math
+import collections
 from tqdm.auto import tqdm
 from random import randint
-from typing import Tuple
+import collections
+
 
 from arguments import DataArguments, ModelArguments, TrainingArguments
 from datamodules import create_data_module_s2s, create_data_module
@@ -49,9 +51,12 @@ def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
         attn_implementation="flash_attention_2"
     )
 
-    # for name, module in model.named_modules():
-    #     if 'norm' in name:
-    #         module = module.to(torch.float32)
+    modules = dict(model.named_modules())
+    keys = find_all_linear_module_names(model)
+    for key in keys:
+        modules[key].weight = modules[key].weight.to(dtype)
+        if modules[key].bias is not None:
+            modules[key].bias = modules[key].bias.to(dtype)
 
     return model
 
@@ -170,12 +175,23 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int =
 
 
 def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float, weight_decay: float, eps: float, adam8bit: bool):
+    all_keys = dynamic_module_names + static_module_names
+    duplicated = [k for k, v in collections.Counter(all_keys).items() if v > 1]
+    if len(duplicated) > 0:
+        print("duplicated items:")
+        for item in duplicated:
+            print(item)
+        raise ValueError("dynamic_module_names and/or static_module_names contain duplicated parameters")
+
     parameters = list()
     modules = dict(model.named_modules())
     for key in dynamic_module_names:
         parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
+    param_ids = set([id(p['params']) for p in parameters])
     for key in static_module_names:
-        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad)
+        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad and id(p) not in param_ids)
+        for p in modules[key].parameters():
+            param_ids.add(id(p))
 
     if not adam8bit:
         optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)
@@ -200,6 +216,7 @@ def compute_dynamic_parameter_ratio(model):
 
 
 def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
     model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
+    tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
 
     if data_args.dataset.endswith("json"):