fix evaluate() invocation
This commit is contained in:
parent
ce2ada2617
commit
2ee4f294af
@ -59,7 +59,7 @@ def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters
|
||||
|
||||
def evaluate(model: DyntrainModel, tokenizer,
|
||||
dataloader: torch.utils.data.DataLoader, globalstep: int,
|
||||
log_writer: tensorboard.SummaryWriter, eval_prompt: str = None):
|
||||
log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
|
||||
print("*** Eval ***")
|
||||
loss = torch.zeros((1), device="cuda:0")
|
||||
model.model.eval()
|
||||
@ -116,12 +116,13 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
|
||||
collate_fn=dataset['data_collator'],
|
||||
batch_size=training_args.per_device_train_batch_size
|
||||
) if dataset['train_dataset'] is not None else None
|
||||
eval_dataloader = torch.utils.data.DataLoader(
|
||||
dataset['eval_dataset'],
|
||||
shuffle=True,
|
||||
collate_fn=dataset['data_collator'],
|
||||
batch_size=training_args.per_device_train_batch_size
|
||||
) if dataset['eval_dataset'] is not None else None
|
||||
if training_args.do_eval:
|
||||
eval_dataloader = torch.utils.data.DataLoader(
|
||||
dataset['eval_dataset'],
|
||||
shuffle=True,
|
||||
collate_fn=dataset['data_collator'],
|
||||
batch_size=training_args.per_device_train_batch_size
|
||||
)
|
||||
|
||||
dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
|
||||
steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
|
||||
@ -195,17 +196,17 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
|
||||
if global_step > 0:
|
||||
if global_step % training_args.save_steps == 0:
|
||||
save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
|
||||
if training_args.eval_steps > 0 and global_step % training_args.save_steps == 0:
|
||||
evaluate(model, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
|
||||
if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
|
||||
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
|
||||
if training_args.flush_allocator:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
if training_args.do_eval and training_args.eval_steps == -1:
|
||||
evaluate(model, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
|
||||
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
evaluate(model, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
|
||||
evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
|
||||
|
||||
save_model(model.model, global_step, training_args.output_dir)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user