Initial commit

commit 7a47fcdcc0
2024-03-06 17:50:40 +01:00
5 changed files with 716 additions and 0 deletions

arguments.py Normal file

@@ -0,0 +1,95 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class DataArguments:
eval_dataset_size: int = field(
default=512, metadata={"help": "Size of validation dataset."}
)
source_max_len: int = field(
default=512,
metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."},
)
train_on_source: Optional[bool] = field(
default=False,
metadata={"help": "Wether to train on the input in addition to the target text when in s2s mode."}
)
target_max_len: int = field(
default=256,
metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."},
)
dataset: str = field(
default=None,
metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
)
block_size: int = field(
default=512,
metadata={"help": "size of the blocks the text is split into for training"},
)
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(
default="EleutherAI/pythia-12b"
)
tokenizer: Optional[str] = field(
default=None
)
trust_remote_code: Optional[bool] = field(
default=False,
metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."}
)
max_instant_params: int = field(
default=0,
metadata={"help": "Maximum amount of paramters to optimize per step in millions"}
)
noresize: Optional[bool] = field(
default=False,
metadata={"help": "Never resize tokenizer embeddings"}
)
@dataclass
class TrainingArguments():
cache_dir: Optional[str] = field(
default=None
)
adam8bit: bool = field(
default=False,
metadata={"help": "Use 8-bit adam."}
)
report_to: str = field(
default='none',
metadata={"help": "To use wandb or something else for reporting."}
)
resume: bool = field(default=False, metadata={"help": 'Resume from previous checkpoint'})
ddp_find_unused_parameters: bool = field(default=True, metadata={"help": 'set if trainer should try to find unused parameters'})
output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'})
    gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before performing an optimizer step'})
epochs: int = field(default=3, metadata={"help": 'How many epochs to train for'})
weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'})
    learning_rate: float = field(default=0.0002, metadata={"help": 'The learning rate'})
adam_epsilon: float = field(default=1e-7, metadata={"help": 'Adam epsilon'})
    remove_unused_columns: bool = field(default=False, metadata={"help": 'Remove unused columns. Needed to make this codebase work.'})
max_grad_norm: float = field(default=0.3, metadata={"help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'})
gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing. You want to use this.'})
fp16: bool = field(default=False, metadata={"help": 'Train in 16 bit mixed precision'})
do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'})
do_eval: bool = field(default=False, metadata={"help": 'To eval or not to eval, that is the question?'})
lr_scheduler_type: str = field(default='constant',
metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'})
    warmup_steps: float = field(default=0, metadata={"help": 'Number of steps to warm up for'})
logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
group_by_length: bool = field(default=False,
metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
storage_fp16: bool = field(default=False, metadata={"help": 'Store untrained layers in 16bit'})
save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
    max_checkpoints: int = field(default=0, metadata={"help": 'The maximum number of checkpoints to save'})
save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
primary_device: str = field(default="cuda:0", metadata={"help": 'The primary device to use'})
secondary_device: str = field(default="cuda:0", metadata={"help": 'The secondary device to use'})
    train_non_linear_layers: bool = field(default=False, metadata={"help": 'Train non-linear layers'})
    flush_allocator: bool = field(default=False, metadata={"help": "Flush torch's allocator on each iteration"})
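A minimal usage sketch for these dataclasses, assuming the file sits next to the caller as in this commit; the flag values below are illustrative only:

import transformers
from arguments import DataArguments, ModelArguments, TrainingArguments

parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
# Example flags only; any field name above can be passed as --<field_name>
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    args=["--dataset", "data.json", "--epochs", "1", "--learning_rate", "1e-4"]
)
print(model_args.model_name_or_path, training_args.learning_rate)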

convertinglinear.py Normal file

@@ -0,0 +1,29 @@
import torch
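# A Linear layer that casts its input to the weight's device and dtype before the matmul,
# then casts the output back to the input's original device and dtype.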
class ConvertingLinear(torch.nn.Linear):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
super().__init__(in_features, out_features, bias, device, dtype)
def forward(self, input: torch.Tensor):
output_dtype = input.dtype
output_device = input.device
if input.device != self.weight.device:
input = input.to(self.weight.device)
if input.dtype != self.weight.dtype:
input = input.to(self.weight.dtype)
output = torch.nn.Linear.forward(self, input)
        # Debug check: halt if the output contains NaNs or the weights are unexpectedly not fp32
        if torch.isnan(output).any() or self.weight.dtype != torch.float32:
            breakpoint()
return output.to(output_device).to(output_dtype)
@classmethod
def fromLinear(cls, in_module: torch.nn.Linear):
new_module = torch.nn.utils.skip_init(cls, in_features=in_module.in_features,
out_features=in_module.out_features,
bias=in_module.bias is not None,
device=in_module.weight.device,
dtype=in_module.weight.dtype)
new_module.weight = in_module.weight
new_module.bias = in_module.bias
return new_module
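A small sketch of the conversion behaviour, assuming a CPU-only environment and fp32 weights (the debug check in forward only fires on NaNs or non-fp32 weights):

import torch
from convertinglinear import ConvertingLinear

base = torch.nn.Linear(8, 4, dtype=torch.float32)
wrapped = ConvertingLinear.fromLinear(base)   # shares the original weight and bias

x = torch.randn(2, 8, dtype=torch.float32)
y = wrapped(x)   # input is cast to the weight's device/dtype, output cast back to the input's
assert y.shape == (2, 4) and y.dtype == x.dtype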

datamodules.py Normal file

@@ -0,0 +1,192 @@
import copy
import torch
import typing
import datasets
import itertools
import transformers
from dataclasses import dataclass
from torch.nn.utils.rnn import pad_sequence
from arguments import DataArguments
IGNORE_INDEX = -100
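# Concatenate pre-tokenized examples and re-split them into fixed-size blocks; labels mirror input_ids.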
def group_texts(examples, block_size: int):
# Concatenate all texts.
concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= block_size:
total_length = (total_length // block_size) * block_size
# Split by chunks of max_len.
result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()}
result["labels"] = result["input_ids"].copy()
return result
@dataclass
class DataCollatorForCausalLM(object):
tokenizer: transformers.PreTrainedTokenizer
source_max_len: int
target_max_len: int
train_on_source: bool
predict_with_generate: bool
def __call__(self, instances: typing.Sequence[typing.Dict]) -> typing.Dict[str, torch.Tensor]:
# Extract elements
sources = [f"{self.tokenizer.bos_token}{example['input']}" for example in instances]
targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances]
# Tokenize
tokenized_sources_with_prompt = self.tokenizer(
sources,
max_length=self.source_max_len,
truncation=True,
add_special_tokens=False,
)
tokenized_targets = self.tokenizer(
targets,
max_length=self.target_max_len,
truncation=True,
add_special_tokens=False,
)
# Build the input and labels for causal LM
input_ids = []
labels = []
for tokenized_source, tokenized_target in zip(
tokenized_sources_with_prompt['input_ids'],
tokenized_targets['input_ids']
):
if not self.predict_with_generate:
input_ids.append(torch.tensor(tokenized_source + tokenized_target))
if not self.train_on_source:
labels.append(
torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target))
)
else:
labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target)))
else:
input_ids.append(torch.tensor(tokenized_source))
# Apply padding
padding_value = None
if self.tokenizer.pad_token_id is not None:
padding_value = self.tokenizer.pad_token_id
elif self.tokenizer.eos_token_id is not None:
padding_value = self.tokenizer.eos_token_id
else:
raise RuntimeError("Model dose not have a pad or eos token")
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=padding_value)
labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None
data_dict = {
'input_ids': input_ids,
'attention_mask': input_ids.ne(padding_value),
}
if labels is not None:
data_dict['labels'] = labels
return data_dict
def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train, do_eval, do_predict) -> typing.Dict:
try:
dataset = datasets.Dataset.from_json(path_or_paths=data_args.dataset)
except FileNotFoundError as ex:
raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
if do_eval or do_predict:
if 'eval' in dataset:
eval_dataset = dataset['eval']
else:
            print('Splitting train dataset into train and validation according to `eval_dataset_size`')
dataset = dataset.train_test_split(
test_size=data_args.eval_dataset_size, shuffle=True, seed=42
)
eval_dataset = dataset['test']
eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
if 'train' in dataset:
train_dataset = dataset['train']
else:
train_dataset = dataset
train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
data_collator = DataCollatorForCausalLM(
tokenizer=tokenizer,
source_max_len=data_args.source_max_len,
target_max_len=data_args.target_max_len,
train_on_source=data_args.train_on_source,
predict_with_generate=False # args.predict_with_generate,
)
return dict(
train_dataset=train_dataset if do_train else None,
eval_dataset=eval_dataset if do_eval else None,
predict_dataset=eval_dataset if do_predict else None,
data_collator=data_collator
)
def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
try:
dataset = datasets.load_dataset('text', data_files={'train': [data_args.dataset]})
except FileNotFoundError as ex:
raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
if data_args.block_size > tokenizer.model_max_length:
raise ValueError(f"Block size of {data_args.block_size} is larger than the maximum size supported by the model: {tokenizer.model_max_length}")
def add_newline_fn(example):
example['text'] = example['text'] + '\n'
return example
dataset = dataset.map(add_newline_fn)
eval_dataset = None
if do_eval or do_predict:
if 'eval' in dataset:
eval_dataset = dataset['eval']
else:
            print('Splitting train dataset into train and validation according to `eval_dataset_size`')
dataset = dataset.train_test_split(
test_size=data_args.eval_dataset_size, shuffle=True, seed=42
)
eval_dataset = dataset['test']
if 'train' in dataset:
train_dataset = dataset['train']
else:
train_dataset = dataset
train_dataset_tokenized = train_dataset.map(
lambda example: tokenizer(example['text']),
batched=True,
remove_columns='text',
num_proc=32,
load_from_cache_file=True)
train_dataset_tokenized = train_dataset_tokenized.map(
lambda example: group_texts(example, data_args.block_size),
batched=True,
num_proc=32,
load_from_cache_file=True,
desc=f"Grouping texts in chunks of {data_args.block_size}")
eval_dataset_tokenized = None
if eval_dataset is not None:
eval_dataset_tokenized = eval_dataset.map(
lambda example: tokenizer(example['text']),
batched=True,
remove_columns='text',
num_proc=32)
eval_dataset_tokenized = eval_dataset_tokenized.map(
lambda example: group_texts(example, data_args.block_size),
batched=True,
num_proc=32,
load_from_cache_file=True,
desc=f"Grouping texts in chunks of {data_args.block_size}")
return dict(
train_dataset=train_dataset_tokenized if do_train else None,
eval_dataset=eval_dataset_tokenized if do_eval else None,
predict_dataset=eval_dataset_tokenized if do_predict else None,
data_collator=transformers.default_data_collator
)
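A quick collator sketch, assuming network access for a small tokenizer (gpt2 is used purely as an example; its eos token doubles as bos and pad here):

import transformers
from datamodules import DataCollatorForCausalLM

tok = transformers.AutoTokenizer.from_pretrained("gpt2")
collator = DataCollatorForCausalLM(tokenizer=tok, source_max_len=32, target_max_len=32,
                                    train_on_source=False, predict_with_generate=False)
batch = collator([{"input": "Translate to French: hello", "output": " bonjour"}])
print(batch["input_ids"].shape, batch["labels"].shape)  # source tokens are masked with -100 in labels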

tokenizer.py Normal file

@@ -0,0 +1,62 @@
import transformers
from arguments import ModelArguments
DEFAULT_PAD_TOKEN = "[PAD]"
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings_data = model.get_input_embeddings().weight.data
output_embeddings_data = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
output_embeddings_data[-num_new_tokens:] = output_embeddings_avg
def get_tokenizer(model, cache_dir, model_args: ModelArguments):
print(f'Tokenizer: {model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path}')
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path,
cache_dir=cache_dir,
padding_side="right",
use_fast=False,
eos_token="[EOS]",
tokenizer_type='llama' if 'llama' in model_args.model_name_or_path else None,
trust_remote_code=model_args.trust_remote_code
)
if tokenizer._pad_token is None and not model_args.noresize:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
tokenizer=tokenizer,
model=model,
)
if 'llama' in model_args.model_name_or_path or isinstance(tokenizer, transformers.LlamaTokenizer):
# LLaMA tokenizer may not have correct special tokens set.
# Check and add them if missing to prevent them from being parsed into different tokens.
# Note that these are present in the vocabulary.
# Note also that `model.config.pad_token_id` is 0 which corresponds to `<unk>` token.
print('Adding special tokens.')
tokenizer.add_special_tokens({
"eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
"bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
"unk_token": tokenizer.convert_ids_to_tokens(
model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id
),
})
return tokenizer
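get_tokenizer wires this up for the configured model; the sketch below exercises the resize helper directly, assuming network access to a tiny model (sshleifer/tiny-gpt2 is just an example):

import transformers
from tokenizer import smart_tokenizer_and_embedding_resize, DEFAULT_PAD_TOKEN

model = transformers.AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
tok = transformers.AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
if tok.pad_token is None:
    smart_tokenizer_and_embedding_resize({"pad_token": DEFAULT_PAD_TOKEN}, tok, model)
print(len(tok), model.get_input_embeddings().weight.shape[0])  # both include the new [PAD] token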

train_dynamic.py Normal file

@@ -0,0 +1,338 @@
import transformers
from transformers import AutoModelForCausalLM, get_scheduler
from peft.utils import _get_submodules
import torch
from torch.utils import tensorboard
import os
import shutil
import math
from tqdm.auto import tqdm
from random import randint
from typing import Tuple
from arguments import DataArguments, ModelArguments, TrainingArguments
from datamodules import create_data_module_s2s, create_data_module
from convertinglinear import ConvertingLinear
from tokenizer import get_tokenizer
def find_all_linear_module_names(model):
module_names = set()
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear):
module_names.add(name)
if 'lm_head' in module_names: # needed for 16-bit
module_names.remove('lm_head')
return list(module_names)
def find_all_outher_module_names(model):
module_names = set()
for name, module in model.named_modules():
if not (isinstance(module, torch.nn.Linear) or isinstance(module, ConvertingLinear)):
module_names.add(name)
return list(module_names)
def get_model(model_args: ModelArguments, cache_dir, gradient_checkpointing):
    # Note: relies on the module-level `training_args` set in the __main__ block below
    dtype = torch.float16 if training_args.fp16 or (training_args.storage_fp16 and model_args.max_instant_params > 0) else torch.float32
print(f'loading base model {model_args.model_name_or_path} in {dtype}...')
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=cache_dir,
torch_dtype=dtype if model_args.max_instant_params > 0 else torch.float32,
trust_remote_code=model_args.trust_remote_code,
device_map=None,
attn_implementation="flash_attention_2"
)
# for name, module in model.named_modules():
# if 'norm' in name:
# module = module.to(torch.float32)
return model
@torch.no_grad()
def recursive_setattr(obj, attr, value):
attr = attr.split('.', 1)
if len(attr) == 1:
setattr(obj, attr[0], value)
else:
recursive_setattr(getattr(obj, attr[0]), attr[1], value)
@torch.no_grad()
def set_linear_module_frozen_simple(module, frozen: bool, dtype: torch.dtype, device: torch.device):
new_module = torch.nn.Linear(module.in_features,
module.out_features,
module.bias is not None,
module.weight.device,
dtype)
new_module.weight = torch.nn.Parameter(module.weight.detach().clone())
new_module.bias = torch.nn.Parameter(module.bias.detach().clone()) if module.bias is not None else None
new_module.weight.requires_grad = not frozen
if new_module.bias is not None:
new_module.bias.requires_grad = not frozen
return new_module
@torch.no_grad()
def set_linear_module_frozen(module, frozen: bool, dtype: torch.dtype, device: torch.device):
if type(module) is torch.nn.Linear:
if frozen:
module.weight.requires_grad = False
if module.bias is not None:
module.bias.requires_grad = False
return module.to(dtype).to(device)
else:
new_module = ConvertingLinear.fromLinear(module).to(dtype)
new_module.weight.requires_grad = True
if new_module.bias is not None:
new_module.bias.requires_grad = True
return new_module.to(device)
elif type(module) is ConvertingLinear:
if not frozen:
module.weight.requires_grad = True
if module.bias is not None:
module.bias.requires_grad = True
assert False
return module.to(dtype).to(device)
else:
new_module = torch.nn.utils.skip_init(torch.nn.Linear, in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=dtype)
new_module.weight = torch.nn.Parameter(module.weight.to(dtype))
new_module.bias = torch.nn.Parameter(module.bias.to(dtype)) if module.bias is not None else None
new_module.weight.requires_grad = False
if new_module.bias is not None:
new_module.bias.requires_grad = False
return new_module.to(device)
else:
assert False
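# Freeze all linear layers in `frozen_dtype` on `frozen_device`, then randomly un-freeze layers
# (as fp32 ConvertingLinear modules on `active_device`) until at least `target_params` parameters
# in the model are trainable. Returns the resulting trainable parameter count.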
@torch.no_grad()
def freeze_random_modules(model, target_params: int, frozen_dtype: torch.dtype, frozen_device: torch.device, active_device: torch.device):
modules = dict(model.named_modules())
linear_names = find_all_linear_module_names(model)
for key in linear_names:
        if modules[key].weight.dtype != frozen_dtype or modules[key].weight.requires_grad:
parent, target, target_name = _get_submodules(model, key)
setattr(parent, target_name, set_linear_module_frozen(modules[key], True, frozen_dtype, frozen_device))
modules = dict(model.named_modules())
    active_parameter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if active_parameter_count > target_params:
        raise RuntimeError("Enough parameters must be available to train at least one linear layer")
    while active_parameter_count < target_params and len(linear_names) > 0:
        i = randint(0, len(linear_names) - 1)
        parent, target, target_name = _get_submodules(model, linear_names[i])
        new_module = set_linear_module_frozen(modules[linear_names[i]], False, torch.float32, active_device)
        setattr(parent, target_name, new_module)
        active_parameter_count += modules[linear_names[i]].weight.numel()
        if modules[linear_names[i]].bias is not None:
            active_parameter_count += modules[linear_names[i]].bias.numel()
        linear_names.pop(i)
    modules = dict()
    assert active_parameter_count == sum(p.numel() for p in model.parameters() if p.requires_grad)
    return active_parameter_count
def save_model(model, global_step: int, output_dir: str, max_checkpoints: int = 0):
    output_chkpt_dir = f"step_{global_step}" if global_step >= 0 else ""
    model.save_pretrained(os.path.join(output_dir, output_chkpt_dir))
    if max_checkpoints > 0:
        checkpoints = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f)) and f.startswith("step_")]
        def extract_step(dirname):
            return int(dirname.split('_')[1])
        if len(checkpoints) > max_checkpoints:
            min_step = min(map(extract_step, checkpoints))
            delete_checkpoint_dir = os.path.join(output_dir, f"step_{min_step}")
            print(f"There are more than {max_checkpoints} checkpoints saved, deleting {delete_checkpoint_dir}")
            shutil.rmtree(delete_checkpoint_dir)
def get_optimizer(model, dynamic_module_names: list, static_module_names: list, lr: float, static_lr: float,
weight_decay: float, eps: float, adam8bit: bool):
parameters = list()
modules = dict(model.named_modules())
for key in dynamic_module_names:
parameters.extend({'params': p} for p in modules[key].parameters() if p.requires_grad)
for key in static_module_names:
parameters.extend({'params': p, 'lr': static_lr} for p in modules[key].parameters() if p.requires_grad)
if not adam8bit:
        optimizer = torch.optim.AdamW(parameters, weight_decay=weight_decay, lr=lr, eps=eps)
else:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("To use 8-bit Adam, bitsandbytes must be available")
optimizer = bnb.optim.AdamW8bit(parameters, weight_decay=weight_decay, lr=lr, eps=eps)
return optimizer
def compute_dynamic_parameter_ratio(model):
modules = dict(model.named_modules())
active_linear_parameters = 0
total_linear_parameters = 0
for key in find_all_linear_module_names(model):
active_linear_parameters += sum(p.numel() for p in modules[key].parameters() if p.requires_grad)
total_linear_parameters += sum(p.numel() for p in modules[key].parameters())
return math.ceil(total_linear_parameters / active_linear_parameters)
def prepare(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments, primary_device: torch.device, secondary_device: torch.device) -> tuple:
model = get_model(model_args, training_args.cache_dir, training_args.gradient_checkpointing).to(primary_device)
tokenizer = get_tokenizer(model, training_args.cache_dir, model_args)
if data_args.dataset.endswith("json"):
print("Loading dataset in s2s mode")
data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
else:
print("Loading dataset in txt mode")
data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
dataset = {k: v for k, v in data_module.items() if k != 'predict_dataset'}
train_dataloader = torch.utils.data.DataLoader(
dataset['train_dataset'],
shuffle=True,
collate_fn=dataset['data_collator'],
batch_size=training_args.per_device_train_batch_size
) if dataset['train_dataset'] is not None else None
eval_dataloader = torch.utils.data.DataLoader(
dataset['eval_dataset'],
shuffle=True,
collate_fn=dataset['data_collator'],
batch_size=training_args.per_device_train_batch_size
) if dataset['eval_dataset'] is not None else None
if model_args.max_instant_params != 0:
print(f"Target params {model_args.max_instant_params}m")
freeze_random_modules(model, model_args.max_instant_params * 1e6,
torch.float16 if training_args.storage_fp16 else torch.float32,
frozen_device=primary_device, active_device=secondary_device)
    parameter_count = sum(p.numel() for p in model.parameters())
    active_parameter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Training model with {parameter_count/1e6}m parameters and {active_parameter_count/1e6}m instantaneous active parameters")
    dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
    print(f"dynamic parameter ratio: 1/{dynamic_param_ratio}")
steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
total_steps = steps_per_epoch * training_args.epochs
optimizer = get_optimizer(model, find_all_linear_module_names(model),
find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
training_args.learning_rate,
training_args.learning_rate / dynamic_param_ratio,
training_args.weight_decay,
training_args.adam_epsilon,
training_args.adam8bit)
lr_scheduler = get_scheduler(
name=training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps,
num_training_steps=total_steps
)
return model, optimizer, lr_scheduler, train_dataloader
def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
primary_device = torch.device(training_args.primary_device)
secondary_device = torch.device(training_args.secondary_device)
log_writer = tensorboard.SummaryWriter()
model, optimizer, lr_scheduler, train_dataloader = prepare(model_args, data_args, training_args, primary_device, secondary_device)
steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
total_steps = steps_per_epoch * training_args.epochs
dynamic_param_ratio = compute_dynamic_parameter_ratio(model)
if training_args.do_train:
progress_bar = tqdm(range(total_steps))
global_step = 0
model.train()
for epoch in range(0, training_args.epochs):
print("*** Train ***")
            print(f'VRAM used by model before training starts: {torch.cuda.memory_allocated()/(1024.0*1024.0)}')
for step, batch in enumerate(train_dataloader):
for key in batch:
batch[key] = batch[key].to("cuda:0")
outputs = model(**batch)
loss = outputs.loss / training_args.gradient_accumulation_steps
log_writer.add_scalar("Loss/train", loss, global_step)
loss.backward()
if (step + 1) % training_args.gradient_accumulation_steps == 0 or step + 1 == len(train_dataloader):
optimizer.step()
lr_scheduler.step()
model.zero_grad()
if global_step % 10 == 0:
print(loss)
if global_step % 10 == 0 and model_args.max_instant_params != 0:
param_count = freeze_random_modules(model, model_args.max_instant_params * 1e6,
torch.float16 if training_args.storage_fp16 else torch.float32,
frozen_device=primary_device,
active_device=secondary_device)
log_writer.add_scalar("Parameters/train", param_count, global_step)
optimizer = get_optimizer(model, find_all_linear_module_names(model),
find_all_outher_module_names(model) if training_args.train_non_linear_layers else list(),
training_args.learning_rate,
training_args.learning_rate / dynamic_param_ratio,
training_args.weight_decay,
training_args.adam_epsilon,
training_args.adam8bit)
lr_scheduler.optimizer = optimizer
global_step += 1
progress_bar.update()
if global_step % training_args.save_steps == 0:
save_model(model, global_step, training_args.output_dir, training_args.max_checkpoints)
if training_args.flush_allocator:
torch.cuda.empty_cache()
# Evaluation
if training_args.do_eval:
print("*** Evaluate ***")
save_model(model, global_step, training_args.output_dir)
return
if __name__ == "__main__":
hfparser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args, extra_args = hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
print("Model Arguments:")
print(model_args)
print("\nData Arguments:")
print(data_args)
print("\nTraining Arguments:")
print(training_args)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
train(model_args, data_args, training_args)
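A small sketch of the layer-selection helpers on a toy model, assuming the repo's dependencies (torch, transformers, peft, tensorboard) are installed so train_dynamic imports cleanly:

import torch
from train_dynamic import find_all_linear_module_names, find_all_outher_module_names

toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2))
print(find_all_linear_module_names(toy))   # the two Linear layers, e.g. ['0', '2']
print(find_all_outher_module_names(toy))   # the root container and the ReLU

The full run is driven by the __main__ block above, e.g. python train_dynamic.py --dataset data.json --epochs 1.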