fix evaluate() invocation
commit 2ee4f294af
parent ce2ada2617
1 changed file with 12 additions and 11 deletions
@@ -59,7 +59,7 @@ def get_optimizer(dyamic_parameters: list[torch.nn.Parameter], static_parameters
 def evaluate(model: DyntrainModel, tokenizer,
              dataloader: torch.utils.data.DataLoader, globalstep: int,
-             log_writer: tensorboard.SummaryWriter, eval_prompt: str = None):
+             log_writer: tensorboard.SummaryWriter, eval_prompt: str | None = None):
     print("*** Eval ***")
     loss = torch.zeros((1), device="cuda:0")
     model.model.eval()
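Note (editorial, not part of the diff): with this signature, evaluate() expects tokenizer as its second positional parameter, so every call site has to pass it or the remaining arguments shift one slot to the left; the annotation change just spells the None default as str | None (Python 3.10+ union syntax). A minimal, self-contained sketch with dummy stand-ins rather than the project's classes:

    # Dummy stand-ins only; shows how positional arguments line up with the
    # signature once tokenizer is included.
    def evaluate(model, tokenizer, dataloader, globalstep: int,
                 log_writer, eval_prompt: str | None = None):
        print(f"model={model} tokenizer={tokenizer} batches={len(dataloader)} "
              f"step={globalstep} prompt={eval_prompt}")

    eval_dataloader = [1, 2, 3]  # stand-in for a torch DataLoader
    evaluate("model", "tokenizer", eval_dataloader, 0, None)  # fixed call
    # evaluate("model", eval_dataloader, 0, None, None)       # old, broken call:
    # the dataloader lands in the tokenizer slot and everything shifts left.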
@@ -116,12 +116,13 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
         collate_fn=dataset['data_collator'],
         batch_size=training_args.per_device_train_batch_size
     ) if dataset['train_dataset'] is not None else None
-    eval_dataloader = torch.utils.data.DataLoader(
-        dataset['eval_dataset'],
-        shuffle=True,
-        collate_fn=dataset['data_collator'],
-        batch_size=training_args.per_device_train_batch_size
-    ) if dataset['eval_dataset'] is not None else None
+    if training_args.do_eval:
+        eval_dataloader = torch.utils.data.DataLoader(
+            dataset['eval_dataset'],
+            shuffle=True,
+            collate_fn=dataset['data_collator'],
+            batch_size=training_args.per_device_train_batch_size
+        )
 
     dynamic_param_ratio = (model.staticParameterCount() + model.dynamicParameterCount()) / model.dynamicParameterCount()
     steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if train_dataloader is not None else 1
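For orientation (editorial note, not part of the diff): after this hunk the eval dataloader is only constructed when training_args.do_eval is set, replacing the conditional expression that could leave it as None, and the evaluate() call sites below sit behind the same flag. A minimal sketch of that control flow, with an invented SimpleNamespace standing in for the real TrainingArguments and dataset:

    from types import SimpleNamespace

    # Invented stand-ins; the real script uses TrainingArguments and
    # torch.utils.data.DataLoader here.
    training_args = SimpleNamespace(do_eval=True, eval_steps=-1)
    dataset = {'eval_dataset': list(range(8))}

    eval_dataloader = None  # pre-set only to keep this sketch standalone
    if training_args.do_eval:
        eval_dataloader = dataset['eval_dataset']  # DataLoader(...) in the script

    if training_args.do_eval and training_args.eval_steps == -1:
        print(f"end-of-epoch eval over {len(eval_dataloader)} samples")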
@@ -195,17 +196,17 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                 if global_step > 0:
                     if global_step % training_args.save_steps == 0:
                         save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
-                    if training_args.eval_steps > 0 and global_step % training_args.save_steps == 0:
-                        evaluate(model, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
+                    if training_args.eval_steps > 0 and global_step % training_args.eval_steps == 0:
+                        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
                 if training_args.flush_allocator:
                     gc.collect()
                     torch.cuda.empty_cache()
             if training_args.do_eval and training_args.eval_steps == -1:
-                evaluate(model, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
+                evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
 
     # Evaluation
     if training_args.do_eval:
-        evaluate(model, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
+        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
 
     save_model(model.model, global_step, training_args.output_dir)
 
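The first change in this hunk also decouples evaluation frequency from checkpoint frequency: the old condition tested global_step % training_args.save_steps, so mid-training evaluation could only ever fire on save boundaries. A standalone sketch with invented step counts that shows the difference:

    # Invented values; illustrates the scheduling change only.
    save_steps, eval_steps = 500, 200

    def old_should_eval(step):
        return eval_steps > 0 and step % save_steps == 0

    def new_should_eval(step):
        return eval_steps > 0 and step % eval_steps == 0

    steps = range(1, 1001)
    print([s for s in steps if old_should_eval(s)])  # [500, 1000] -> tied to saving
    print([s for s in steps if new_should_eval(s)])  # [200, 400, 600, 800, 1000]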