working full training

This commit is contained in:
uvos 2024-03-09 10:03:37 +01:00
parent 7a47fcdcc0
commit 11ea9eeaa7
2 changed files with 25 additions and 7 deletions

View File

@@ -2,11 +2,12 @@ import torch
 class ConvertingLinear(torch.nn.Linear):
-    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
+    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
         super().__init__(in_features, out_features, bias, device, dtype)
+        self.output_dtype = output_dtype
     def forward(self, input: torch.Tensor):
-        output_dtype = input.dtype
+        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
         output_device = input.device
         if input.device != self.weight.device:
             input = input.to(self.weight.device)
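
The hunk cuts off inside forward(); a minimal sketch of how the new output_dtype field is presumably used, assuming the layer casts the activation to its own weight dtype, applies the linear transform, and casts the result back (everything after the device check is an assumption, not shown in this diff):

import torch

class ConvertingLinear(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.output_dtype = output_dtype

    def forward(self, input: torch.Tensor):
        # output_dtype overrides the incoming activation dtype when set
        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
        output_device = input.device
        # move and cast the activation to match the layer's own weight (assumed completion)
        if input.device != self.weight.device:
            input = input.to(self.weight.device)
        if input.dtype != self.weight.dtype:
            input = input.to(self.weight.dtype)
        output = torch.nn.functional.linear(input, self.weight, self.bias)
        # cast back so downstream modules see the requested dtype on the original device
        return output.to(output_device, output_dtype)

Setting output_dtype=torch.float32 on selected layers would let a mostly low-precision model hand float32 activations to numerically sensitive parts without changing any caller.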

View File

@@ -7,9 +7,11 @@ from torch.utils import tensorboard
import os
import shutil
import math
import collections
from tqdm.auto import tqdm
from random import randint
from typing import Tuple
import collections
from arguments import DataArguments, ModelArguments, TrainingArguments
from datamodules import create_data_module_s2s, create_data_module
@@ -49,9 +51,12 @@ def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
        attn_implementation="flash_attention_2"
    )
    # for name, module in model.named_modules():
    #     if 'norm' in name:
    #         module = module.to(torch.float32)
    modules = dict(model.named_modules())
    keys = find_all_linear_module_names(model)
    for key in keys:
        modules[key].weight = modules[key].weight.to(dtype)
        if modules[key].bias is not None:
            modules[key].bias = modules[key].bias.to(dtype)
    return model
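
find_all_linear_module_names() is defined elsewhere in the repository and not shown in this diff; a hypothetical sketch of what the cast loop above relies on, assuming it simply collects the dotted names of every nn.Linear submodule (the lm_head exclusion is a guess):

import torch

def find_all_linear_module_names(model: torch.nn.Module) -> list:
    # hypothetical: gather the names of all Linear layers so their weights
    # (and biases) can be cast to the training dtype one module at a time
    names = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
            names.append(name)
    return names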
@@ -170,12 +175,23 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int =
 def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float,
                   weight_decay: float, eps: float, adam8bit: bool):
+    all_keys = dynamic_module_names + static_module_names
+    duplicated = [k for k, v in collections.Counter(all_keys).items() if v > 1]
+    if len(duplicated) > 0:
+        print("duplicated items:")
+        for item in duplicated:
+            print(item)
+        raise ValueError("dynamic_module_names and/or static_module_names contain duplicated parameters")
     parameters = list()
     modules = dict(model.named_modules())
     for key in dynamic_module_names:
         parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
+    param_ids = set([id(p['params']) for p in parameters])
     for key in static_module_names:
-        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad)
+        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad and id(p) not in param_ids)
+        for p in modules[key].parameters():
+            param_ids.add(id(p))
     if not adam8bit:
         optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=eps)
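
The id()-based bookkeeping matters when two module names resolve to the same underlying Parameter, for example tied weights; a toy illustration using a hypothetical Tiny module (not part of the repository):

import torch

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Linear(8, 8, bias=False)
        self.head = torch.nn.Linear(8, 8, bias=False)
        self.head.weight = self.embed.weight  # weight tying: one shared Parameter

model = Tiny()
optimizer = get_optimizer(model,
                          dynamic_module_names=["embed"],  # trained at the full lr
                          static_module_names=["head"],    # trained at static_lr
                          lr=1e-4, static_lr=1e-5,
                          weight_decay=0.0, eps=1e-8, adam8bit=False)

The shared weight lands only in the dynamic group; without the id() filter the same tensor would appear in two parameter groups and AdamW would reject them.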
@@ -200,6 +216,7 @@ def compute_dynamic_parameter_ratio(model):
def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
    model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
    tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
    if data_args.dataset.endswith("json"):