working full training
parent 7a47fcdcc0
commit 11ea9eeaa7
@@ -2,11 +2,12 @@ import torch


 class ConvertingLinear(torch.nn.Linear):
-    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
+    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
         super().__init__(in_features, out_features, bias, device, dtype)
+        self.output_dtype = output_dtype

     def forward(self, input: torch.Tensor):
-        output_dtype = input.dtype
+        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
         output_device = input.device
         if input.device != self.weight.device:
             input = input.to(self.weight.device)
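The hunk cuts off inside forward(). For reference, a minimal sketch of how a converting linear layer of this shape is typically completed; everything after the device check (the dtype cast, the linear op itself, and the cast back to output_dtype / output_device) is an assumption, not code from this commit.

# Sketch only: the diff ends before forward() finishes. The lines after the
# device check are assumptions about how such a layer is usually completed.
import torch

class ConvertingLinear(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.output_dtype = output_dtype

    def forward(self, input: torch.Tensor):
        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
        output_device = input.device
        if input.device != self.weight.device:
            input = input.to(self.weight.device)
        if input.dtype != self.weight.dtype:
            input = input.to(self.weight.dtype)
        output = torch.nn.functional.linear(input, self.weight, self.bias)
        # Hand the activation back on the caller's device, in the requested dtype.
        return output.to(device=output_device, dtype=output_dtype)

With output_dtype set, the layer can hold low-precision weights while always returning activations in a fixed dtype, regardless of the dtype the input arrived in.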
@@ -7,9 +7,11 @@ from torch.utils import tensorboard
 import os
 import shutil
 import math
+import collections
 from tqdm.auto import tqdm
 from random import randint
 from typing import Tuple
+import collections

 from arguments import DataArguments, ModelArguments, TrainingArguments
 from datamodules import create_data_module_s2s, create_data_module
@@ -49,9 +51,12 @@ def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
         attn_implementation="flash_attention_2"
     )

     # for name, module in model.named_modules():
     #     if 'norm' in name:
     #         module = module.to(torch.float32)
+    modules = dict(model.named_modules())
+    keys = find_all_linear_module_names(model)
+    for key in keys:
+        modules[key].weight = modules[key].weight.to(dtype)
+        if modules[key].bias is not None:
+            modules[key].bias = modules[key].bias.to(dtype)

     return model
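get_model relies on find_all_linear_module_names, which is not part of this hunk. A plausible sketch, assuming it mirrors the common QLoRA-style helper: walk named_modules, collect the full dotted names of Linear layers (so they can be looked up again via dict(model.named_modules())), and skip the output head.

# Hypothetical implementation of the helper referenced above; the real one is
# not shown in this diff. Assumes full dotted module names are wanted and that
# the output head ("lm_head") should be excluded, as is common practice.
import torch

def find_all_linear_module_names(model: torch.nn.Module) -> list:
    names = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
            names.append(name)
    return names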
@@ -170,12 +175,23 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int =
 def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float,
                   weight_decay: float, eps: float, adam8bit: bool):

+    all_keys = dynamic_module_names + static_module_names
+    duplicated = [k for k, v in collections.Counter(all_keys).items() if v > 1]
+    if len(duplicated) > 0:
+        print("duplicated items:")
+        for item in duplicated:
+            print(item)
+        raise ValueError("dynamic_module_names and/or static_module_names contain duplicated parameters")
+
     parameters = list()
     modules = dict(model.named_modules())
     for key in dynamic_module_names:
         parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
+    param_ids = set([id(p['params']) for p in parameters])
     for key in static_module_names:
-        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad)
+        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad and id(p) not in param_ids)
+        for p in modules[key].parameters():
+            param_ids.add(id(p))

     if not adam8bit:
         optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)
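The param_ids bookkeeping added here guards against a dynamic and a static module name reaching the same underlying parameter objects (for example a parent module listed alongside one of its children, or tied weights), which the name-level Counter check cannot catch; torch.optim rejects such duplicates with "some parameters appear in more than one parameter group". A standalone illustration of the pattern, with made-up module names and learning rates:

# Illustration only (not part of the diff): the root module "" contains the
# Linear named "0", so the two name lists overlap by object even though the
# names themselves are distinct. The id() filter keeps each parameter in
# exactly one optimizer group.
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
dynamic_module_names = ["0"]   # trained at the fast learning rate
static_module_names = [""]     # root module; overlaps with "0" by object, not by name

modules = dict(model.named_modules())
parameters = [{'params': p} for key in dynamic_module_names
              for p in modules[key].parameters() if p.requires_grad]
param_ids = {id(p['params']) for p in parameters}
for key in static_module_names:
    parameters.extend({'params': p, 'lr': 1e-5} for p in modules[key].parameters()
                      if p.requires_grad and id(p) not in param_ids)
    for p in modules[key].parameters():
        param_ids.add(id(p))

optimizer = torch.optim.AdamW(parameters, lr=1e-4, weight_decay=0.0, eps=1e-8)
print(len(optimizer.param_groups))  # one group per parameter, no duplicates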
@@ -200,6 +216,7 @@ def compute_dynamic_parameter_ratio(model):

 def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
     model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)

     tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)

     if data_args.dataset.endswith("json"):