From 2ee4f294af46a72c81a051429422c6f710069354 Mon Sep 17 00:00:00 2001
From: Carl Philipp Klemm
Date: Tue, 7 May 2024 15:10:02 +0200
Subject: [PATCH] fix evaluate() invocation

---
 train_dynamic.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/train_dynamic.py b/train_dynamic.py
index 96ff497..818d730 100644
--- a/train_dynamic.py
+++ b/train_dynamic.py
@@ -59,7 +59,7 @@ def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters
 
 def evaluate(model: DyntrainModel, tokenizer,
              dataloader: torch.utils.data.DataLoader, globalstep: int,
-             log_writer: tensorboard.SummaryWriter, eval_prompt: str = None):
+             log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
     print("*** Eval ***")
     loss = torch.zeros((1), device="cuda:0")
     model.model.eval()
@@ -116,12 +116,13 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
         collate_fn=dataset['data_collator'],
         batch_size=training_args.per_device_train_batch_size
     ) if dataset['train_dataset'] is not None else None
-    eval_dataloader = torch.utils.data.DataLoader(
-        dataset['eval_dataset'],
-        shuffle=True,
-        collate_fn=dataset['data_collator'],
-        batch_size=training_args.per_device_train_batch_size
-    ) if dataset['eval_dataset'] is not None else None
+    if training_args.do_eval:
+        eval_dataloader = torch.utils.data.DataLoader(
+            dataset['eval_dataset'],
+            shuffle=True,
+            collate_fn=dataset['data_collator'],
+            batch_size=training_args.per_device_train_batch_size
+        )
 
     dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
     steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
@@ -195,17 +196,17 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
             if global_step > 0:
                 if global_step % training_args.save_steps == 0:
                     save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
-                if training_args.eval_steps > 0 and global_step % training_args.save_steps == 0:
-                    evaluate(model, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
+                if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
+                    evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
             if training_args.flush_allocator:
                 gc.collect()
                 torch.cuda.empty_cache()
 
         if training_args.do_eval and training_args.eval_steps == -1:
-            evaluate(model, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
+            evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
 
     # Evaluation
     if training_args.do_eval:
-        evaluate(model, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
+        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
     save_model(model.model, global_step, training_args.output_dir)
 
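
Note on the final hunk: before this patch the periodic-eval check tested
global_step % training_args.save_steps == 0, so evaluation fired on the
checkpoint cadence and the configured eval_steps was never honored; the
evaluate() calls also omitted the tokenizer argument, shifting every
following positional parameter by one (eval_dataloader was bound to the
tokenizer parameter, and so on). A standalone toy loop illustrating the
scheduling fix follows; the step counts are assumed values chosen only to
make the two cadences visibly diverge, and only the modulo logic mirrors
train_dynamic.py:

    # Hypothetical cadences; the real values come from training_args.
    save_steps = 500
    eval_steps = 200

    for global_step in range(1, 1001):
        if global_step % save_steps == 0:
            print(f"step {global_step}: save checkpoint")
        # Pre-patch this tested `% save_steps`, firing at 500, 1000, ...
        # instead of the configured 200, 400, 600, ...
        if eval_steps > 0 and global_step % eval_steps == 0:
            print(f"step {global_step}: run evaluation")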