fix evaluate() invocation

This commit is contained in:
Carl Philipp Klemm 2024-05-07 15:10:02 +02:00
parent ce2ada2617
commit 2ee4f294af

View File

@ -59,7 +59,7 @@ def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters
def evaluate(model: DyntrainModel, tokenizer,
dataloader: torch.utils.data.DataLoader, globalstep: int,
log_writer: tensorboard.SummaryWriter, eval_prompt: str = None):
log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
print("*** Eval ***")
loss = torch.zeros((1), device="cuda:0")
model.model.eval()
@ -116,12 +116,13 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
collate_fn=dataset['data_collator'],
batch_size=training_args.per_device_train_batch_size
) if dataset['train_dataset'] is not None else None
eval_dataloader = torch.utils.data.DataLoader(
dataset['eval_dataset'],
shuffle=True,
collate_fn=dataset['data_collator'],
batch_size=training_args.per_device_train_batch_size
) if dataset['eval_dataset'] is not None else None
if training_args.do_eval:
eval_dataloader = torch.utils.data.DataLoader(
dataset['eval_dataset'],
shuffle=True,
collate_fn=dataset['data_collator'],
batch_size=training_args.per_device_train_batch_size
)
dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
@ -195,17 +196,17 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
if global_step > 0:
if global_step % training_args.save_steps == 0:
save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
if training_args.eval_steps > 0 and global_step % training_args.save_steps == 0:
evaluate(model, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
if training_args.flush_allocator:
gc.collect()
torch.cuda.empty_cache()
if training_args.do_eval and training_args.eval_steps == -1:
evaluate(model, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
# Evaluation
if training_args.do_eval:
evaluate(model, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
save_model(model.model, global_step, training_args.output_dir)