working full training

2024-03-09 10:03:37 +01:00
parent 7a47fcdcc0
commit 11ea9eeaa7
2 changed files with 25 additions and 7 deletions

View File

@@ -2,11 +2,12 @@ import torch
 class ConvertingLinear(torch.nn.Linear):
-    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
+    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
         super().__init__(in_features, out_features, bias, device, dtype)
+        self.output_dtype = output_dtype
     def forward(self, input: torch.Tensor):
-        output_dtype = input.dtype
+        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
         output_device = input.device
         if input.device != self.weight.device:
             input = input.to(self.weight.device)

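The first file extends ConvertingLinear, a torch.nn.Linear subclass that moves inputs onto the weight's device before the matmul. The new output_dtype argument lets the caller pin the dtype of the returned activations instead of always mirroring the input. The rest of forward() is not visible in this hunk, so the following is only a minimal sketch of the pattern; the tail of the method (casting the input to the weight dtype and casting the result back) is filled in as an assumption.

import torch

class ConvertingLinear(torch.nn.Linear):
    """Linear layer that accepts inputs on any device/dtype and returns them where asked."""

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        # None keeps the pre-commit behaviour: the output mirrors the input dtype.
        self.output_dtype = output_dtype

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
        output_device = input.device
        if input.device != self.weight.device:
            input = input.to(self.weight.device)
        if input.dtype != self.weight.dtype:
            input = input.to(self.weight.dtype)  # assumed: compute in the weight dtype
        output = torch.nn.functional.linear(input, self.weight, self.bias)
        return output.to(output_device, output_dtype)  # assumed: cast back for the caller

# Example: compute in bfloat16 but hand float32 activations to the next stage.
layer = ConvertingLinear(8, 4, dtype=torch.bfloat16, output_dtype=torch.float32)
out = layer(torch.randn(2, 8))
assert out.dtype == torch.float32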
View File

@@ -7,9 +7,11 @@ from torch.utils import tensorboard
 import os
 import shutil
 import math
+import collections
 from tqdm.auto import tqdm
 from random import randint
-from typing import Tuple
+import collections
 from arguments import DataArguments, ModelArguments, TrainingArguments
 from datamodules import create_data_module_s2s, create_data_module
@@ -49,9 +51,12 @@ def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
         attn_implementation="flash_attention_2"
     )
-    # for name, module in model.named_modules():
-    #     if 'norm' in name:
-    #         module = module.to(torch.float32)
+    modules = dict(model.named_modules())
+    keys = find_all_linear_module_names(model)
+    for key in keys:
+        modules[key].weight = modules[key].weight.to(dtype)
+        if modules[key].bias is not None:
+            modules[key].bias = modules[key].bias.to(dtype)
     return model
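In get_model() the commented-out cast of norm layers to float32 is replaced by a loop that casts the weight (and bias, if present) of every module named by find_all_linear_module_names to the target dtype. That helper is defined elsewhere in the repository and is not shown here; a plausible reading, sketched below as an assumption, is that it collects the dotted names of the nn.Linear modules. The cast_linear_modules wrapper and the in-place .data assignment are illustration choices, not the diff's exact code (the .data assignment avoids nn.Module's parameter type check).

import torch

def find_all_linear_module_names(model: torch.nn.Module) -> list[str]:
    # Assumed behaviour: collect the dotted names of every plain nn.Linear module.
    return [name for name, module in model.named_modules()
            if isinstance(module, torch.nn.Linear)]

def cast_linear_modules(model: torch.nn.Module, dtype: torch.dtype) -> None:
    # Mirrors the loop added to get_model(): cast each selected module's weight
    # and bias to the training dtype without replacing the Parameter objects.
    modules = dict(model.named_modules())
    for key in find_all_linear_module_names(model):
        modules[key].weight.data = modules[key].weight.data.to(dtype)
        if modules[key].bias is not None:
            modules[key].bias.data = modules[key].bias.data.to(dtype)

# Example:
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
cast_linear_modules(model, torch.bfloat16)
assert model[0].weight.dtype == torch.bfloat16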
@@ -170,12 +175,23 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int =
 def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float,
                   weight_decay: float, eps: float, adam8bit: bool):
+    all_keys = dynamic_module_names + static_module_names
+    duplicated = [k for k, v in collections.Counter(all_keys).items() if v > 1]
+    if len(duplicated) > 0:
+        print("duplicated items:")
+        for item in duplicated:
+            print(item)
+        raise ValueError("dynamic_module_names and/or static_module_names contain duplicated parameters")
     parameters = list()
     modules = dict(model.named_modules())
     for key in dynamic_module_names:
         parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
+    param_ids = set([id(p['params']) for p in parameters])
     for key in static_module_names:
-        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad)
+        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad and id(p) not in param_ids)
+        for p in modules[key].parameters():
+            param_ids.add(id(p))
     if not adam8bit:
         optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)
@@ -200,6 +216,7 @@ def compute_dynamic_parameter_ratio(model):
 def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
     model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
     tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
     if data_args.dataset.endswith("json"):
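The get_optimizer() change adds two guards before the parameter groups are built: collections.Counter rejects module names listed as both dynamic and static, and a set of id()s ensures a parameter tensor already placed in a dynamic group is not added again at the static learning rate, since torch.optim optimizers raise a ValueError when a parameter appears in more than one group. Below is a self-contained sketch of the same pattern with made-up module names; where the commit raises a ValueError on duplicated names, the sketch only prints a warning so the rest of the script can run.

import collections
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
dynamic_module_names = ["0"]        # trained at the main learning rate
static_module_names = ["0", "1"]    # "0" is deliberately duplicated for the demo

all_keys = dynamic_module_names + static_module_names
duplicated = [k for k, v in collections.Counter(all_keys).items() if v > 1]
if duplicated:
    print("duplicated module names:", duplicated)  # the commit raises ValueError here

modules = dict(model.named_modules())
parameters = [{'params': p} for key in dynamic_module_names
              for p in modules[key].parameters() if p.requires_grad]
param_ids = {id(group['params']) for group in parameters}

for key in static_module_names:
    # Skip tensors already claimed by a dynamic group; AdamW rejects parameters
    # that appear in more than one parameter group.
    parameters.extend({'params': p, 'lr': 1e-5} for p in modules[key].parameters()
                      if p.requires_grad and id(p) not in param_ids)
    param_ids.update(id(p) for p in modules[key].parameters())

optimizer = torch.optim.AdamW(parameters, lr=1e-4, weight_decay=0.0, eps=1e-8)
print(len(optimizer.param_groups), "parameter groups")  # 4: two dynamic, two static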