working full training
@@ -2,11 +2,12 @@ import torch
 class ConvertingLinear(torch.nn.Linear):
-    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
+    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, output_dtype=None):
         super().__init__(in_features, out_features, bias, device, dtype)
+        self.output_dtype = output_dtype
 
     def forward(self, input: torch.Tensor):
-        output_dtype = input.dtype
+        output_dtype = input.dtype if self.output_dtype is None else self.output_dtype
         output_device = input.device
         if input.device != self.weight.device:
             input = input.to(self.weight.device)
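Note on the hunk above: ConvertingLinear now takes an optional output_dtype, so the layer can cast its result to a fixed dtype instead of always mirroring the input's dtype. The hunk cuts off before the end of forward(), so the usage below is only a minimal sketch, under the assumption that forward() casts its output back to output_device/output_dtype on the way out:

    import torch

    # assumed usage: weights kept in float16, outputs forced to float32
    layer = ConvertingLinear(16, 4, dtype=torch.float16, output_dtype=torch.float32)
    x = torch.randn(2, 16, dtype=torch.float16)
    y = layer(x)  # expected: y.dtype == torch.float32, assuming the final cast
                  # happens at the end of forward(), which this hunk does not show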
@@ -7,9 +7,11 @@ from torch.utils import tensorboard
 import os
 import shutil
 import math
+import collections
 from tqdm.auto import tqdm
 from random import randint
-from typing import Tuple
+import collections
+
 
 from arguments import DataArguments, ModelArguments, TrainingArguments
 from datamodules import create_data_module_s2s, create_data_module
@@ -49,9 +51,12 @@ def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
         attn_implementation="flash_attention_2"
     )
 
-    # for name, module in model.named_modules():
-    #     if 'norm' in name:
-    #         module = module.to(torch.float32)
+    modules = dict(model.named_modules())
+    keys = find_all_linear_module_names(model)
+    for key in keys:
+        modules[key].weight = modules[key].weight.to(dtype)
+        if modules[key].bias is not None:
+            modules[key].bias = modules[key].bias.to(dtype)
 
     return model
 
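Note on the hunk above: find_all_linear_module_names() is used here but not defined in this diff. A minimal sketch of what such a helper presumably does (collect the dotted names of every torch.nn.Linear submodule); the exact filtering in the real repository may differ:

    import torch

    def find_all_linear_module_names(model: torch.nn.Module) -> list:
        # walk the module tree and keep the names of all Linear layers
        return [name for name, module in model.named_modules()
                if isinstance(module, torch.nn.Linear)]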
@@ -170,12 +175,23 @@ def save_model(model, global_step: int, output_dir: str, max_checkpoints: int =
 def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float,
                   weight_decay: float, eps: float, adam8bit: bool):
 
+    all_keys = dynamic_module_names + static_module_names
+    duplicated = [k for k, v in collections.Counter(all_keys).items() if v > 1]
+    if len(duplicated) > 0:
+        print("duplicated items:")
+        for item in duplicated:
+            print(item)
+        raise ValueError("dynamic_module_names and or static_module_names contain duplicated paramters")
+
     parameters = list()
     modules = dict(model.named_modules())
     for key in dynamic_module_names:
         parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
+    param_ids = set([id(p['params']) for p in parameters])
     for key in static_module_names:
-        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad)
+        parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad and id(p) not in param_ids)
+        for p in modules[key].parameters():
+            param_ids.add(id(p))
 
     if not adam8bit:
         optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=training_args.adam_epsilon)
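Note on the hunk above: the new param_ids set keeps any parameter that is already covered by a dynamic group (for example a shared or tied tensor) from being added to the optimizer a second time, since torch.optim rejects a parameter that appears in more than one group. A minimal standalone sketch of the same idea with throw-away modules; the names and learning rates are illustrative only:

    import torch

    dynamic = {'a': torch.nn.Linear(4, 4)}
    static = {'a': dynamic['a'], 'b': torch.nn.Linear(4, 4)}  # 'a' is shared

    parameters = [{'params': p} for m in dynamic.values() for p in m.parameters()]
    param_ids = set(id(p['params']) for p in parameters)

    for m in static.values():
        # skip tensors already owned by a dynamic group so no parameter
        # ends up in two optimizer groups
        parameters.extend({'params': p, 'lr': 1e-5} for p in m.parameters()
                          if id(p) not in param_ids)
        param_ids.update(id(p) for p in m.parameters())

    optimizer = torch.optim.AdamW(parameters, lr=1e-4)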
@@ -200,6 +216,7 @@ def compute_dynamic_parameter_ratio(model):
 
 def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
     model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
+
     tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
 
     if data_args.dataset.endswith("json"):