add support for huggingface hub datasets and for specifying a prompt for eval

2024-05-07 00:23:12 +02:00
parent 8abea9ef89
commit a74ef976e4
5 changed files with 183 additions and 43 deletions


@@ -7,9 +7,10 @@ import os
 import shutil
 import math
 from tqdm.auto import tqdm
+import gc
 from arguments import DataArguments, ModelArguments, TrainingArguments
-from datamodules import create_data_module_s2s, create_data_module
+from datamodules import create_data_module_s2s, create_data_module, create_data_module_hub
 from tokenizer import get_tokenizer
 from dyntrainmodel import DyntrainModel
@@ -56,7 +57,9 @@ def get_optimizer(dyamic_parameters: list[torch.nn.parameter], static_parameters
     return optimizer
 
 
-def evaluate(model: DyntrainModel, dataloader: torch.utils.data.DataLoader) -> float:
+def evaluate(model: DyntrainModel, tokenizer,
+             dataloader: torch.utils.data.DataLoader, globalstep: int,
+             log_writer: tensorboard.SummaryWriter, eval_prompt: str = None):
     print("*** Eval ***")
     loss = torch.zeros((1), device="cuda:0")
     model.model.eval()
@@ -66,8 +69,17 @@ def evaluate(model: DyntrainModel, dataloader: torch.utils.data.DataLoader) -> float:
         outputs = model.model(**batch)
         loss += outputs.loss
     loss = loss / len(dataloader)
+    log_writer.add_scalar("Loss/Eval", loss, globalstep)
     print(f"Eval Loss {loss.item()}")
 
+    if eval_prompt is not None:
+        input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
+        attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
+        outputs = model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1, max_new_tokens=100)
+        response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+        print(f"Eval generation: {response_decoded}")
+        log_writer.add_text("Text/Eval", response_decoded, globalstep)
 
 
 def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
     log_writer = tensorboard.SummaryWriter()
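Note: passing log_writer into evaluate() means evaluation now shows up in TensorBoard next to the training scalars, Loss/Eval as a scalar and, when a prompt is configured, the sampled generation as text under Text/Eval. Since the writer is constructed as tensorboard.SummaryWriter() with no arguments, events go to the default ./runs directory, so tensorboard --logdir runs will display both.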
@@ -90,6 +102,8 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
     if data_args.dataset.endswith("json"):
         print("Loading dataset in s2s mode")
         data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
+    elif data_args.data_from_hub:
+        data_module = create_data_module_hub(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
     else:
        print("Loading dataset in txt mode")
        data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
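create_data_module_hub itself lives in datamodules.py, one of the other changed files, so its body is not visible in this view. Below is a minimal sketch of what a hub-backed factory with this signature might look like, assuming it uses the datasets library and a plain text column; every internal detail (column name, max length, collator choice) is a guess, not the commit's actual code:

# Hypothetical sketch only; the real implementation is in datamodules.py.
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling


def create_data_module_hub(tokenizer, data_args, do_train: bool, do_eval: bool, do_predict: bool) -> dict:
    # data_args.dataset is assumed to hold a hub dataset id such as "user/corpus"
    dataset = load_dataset(data_args.dataset)

    def tokenize(batch):
        # assumes the dataset exposes a plain "text" column; max_length is a guess
        return tokenizer(batch["text"], truncation=True, max_length=1024)

    dataset = dataset.map(tokenize, batched=True, remove_columns=dataset["train"].column_names)
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)  # labels = input_ids for causal LM
    return dict(train_dataset=dataset["train"] if do_train else None,
                eval_dataset=dataset["test"] if do_eval and "test" in dataset else None,
                data_collator=collator)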
@@ -137,12 +151,14 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
         for step, batch in enumerate(train_dataloader):
             for key in batch:
                 batch[key] = batch[key].to("cuda:0")
             outputs = model.model(**batch)
             loss = outputs.loss / training_args.gradient_accumulation_steps
-            log_writer.add_scalar("Loss/train", loss, global_step)
             loss.backward()
 
             if (step + 1) % training_args.gradient_accumulation_steps == 0 or step + 1 == len(train_dataloader):
+                if global_step % training_args.logging_steps == 0:
+                    log_writer.add_scalar("Loss/train", loss, global_step)
                 optimizer.step()
                 lr_scheduler.step()
@@ -151,9 +167,14 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
                 if global_step % 5 == 0:
                     print(f"Train Loss {loss.item()}")
 
-                if global_step % 50 == 0 and training_args.max_instant_params != 0:
+                if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
                     print("Reshuffling")
+                    lr_scheduler.optimizer = None
+                    del optimizer
+                    # distance, error = model.getDistanceAndErrorSample()
+                    # log_writer.add_histogram("Distances/Train", distance, max_bins=50)
+                    # log_writer.add_histogram("Errors/Train", error, max_bins=50)
                     model.reshuffleActive()
                     model.balanceActive()
                     log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
@@ -173,15 +194,16 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
             if global_step % training_args.save_steps == 0:
                 save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
-            if training_args.eval_steps > 0 and global_step % training_args.save_steps == 0:
-                evaluate(model, eval_dataloader)
+            if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
+                evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
             if training_args.flush_allocator:
                 gc.collect()
                 torch.cuda.empty_cache()
 
     if training_args.do_eval and training_args.eval_steps == -1:
-        evaluate(model, eval_dataloader)
+        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
 
     # Evaluation
     if training_args.do_eval:
-        evaluate(model, eval_dataloader)
+        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
 
     save_model(model.model, global_step, training_args.output_dir)
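Usage note: assuming the three dataclasses are parsed with something like transformers.HfArgumentParser, the new features would be driven from the command line via the new fields, roughly --data_from_hub true --dataset <hub id> --eval_prompt "Once upon a time"; the exact flag syntax depends on the argument parser this repo actually uses.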