add support for huggingface hub datasets and for specifying a prompt for eval
This commit is contained in:
parent 8abea9ef89
commit a74ef976e4

5 changed files with 183 additions and 43 deletions

arguments.py (11 changes)
@@ -19,6 +19,10 @@ class DataArguments:
         default=256,
         metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."},
     )
+    data_from_hub: Optional[bool] = field(
+        default=False,
+        metadata={"help": "If set, the dataset is assumed to be the name of a Hugging Face Hub dataset"}
+    )
     dataset: str = field(
         default=None,
         metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
@@ -60,10 +64,6 @@ class TrainingArguments():
         default=False,
         metadata={"help": "Use 8-bit adam."}
     )
-    report_to: str = field(
-        default='none',
-        metadata={"help": "To use wandb or something else for reporting."}
-    )
     resume: bool = field(default=False, metadata={"help": 'Resume from previous checkpoint'})
     ddp_find_unused_parameters: bool = field(default=True, metadata={"help": 'set if trainer should try to find unused parameters'})
     output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
@@ -85,7 +85,6 @@ class TrainingArguments():
     logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
     group_by_length: bool = field(default=False,
                                   metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
-    storage_fp16: bool = field(default=False, metadata={"help": 'Store untrained layers in 16bit'})
     save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
     max_checkpoints: int = field(default=0, metadata={"help": 'the maximum amount of checkpoints to save'})
     save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
@@ -94,3 +93,5 @@ class TrainingArguments():
     max_instant_params: int = field(default=0, metadata={"help": "Maximum amount of paramters to optimize per step in millions"})
     churn_percent: int = field(default=100, metadata={"help": "The percentage of active parameters to replace when changeing active parameters"})
     eval_steps: int = field(default=-1, metadata={"help": "Number of optimization steps after wich to compute the evaluation loss"})
+    eval_prompt: str = field(default=None, metadata={"help": "A prompt to use during eval to check if the model is learning"})
+    reshufle_steps: int = field(default=50, metadata={"help": "Number of steps to take before changing the active parameters"})
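How the new knobs fit together, as a minimal sketch (not part of the commit; it assumes the dataclasses are exposed as CLI flags named after their fields, and the dataset name is a placeholder):

    from arguments import DataArguments, TrainingArguments

    # Roughly equivalent to flags like:
    #   --dataset openwebtext --data_from_hub True --eval_steps 100 \
    #   --eval_prompt "The quick brown fox" --reshufle_steps 50
    data_args = DataArguments(dataset="openwebtext", data_from_hub=True)
    training_args = TrainingArguments(
        eval_steps=100,                      # compute the eval loss every 100 optimizer steps
        eval_prompt="The quick brown fox",   # also sample a generation from this prompt
        reshufle_steps=50,                   # swap the set of active parameters every 50 steps
    )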

datamodules.py
@@ -27,7 +27,44 @@ def group_texts(examples, block_size: int):
 
 
 @dataclass
-class DataCollatorForCausalLM(object):
+class DataCollatorForCausalLMText(object):
+    tokenizer: transformers.PreTrainedTokenizer
+    max_len: int
+
+    def __call__(self, instances: typing.Sequence[typing.Dict]) -> typing.Dict[str, torch.Tensor]:
+        # Extract elements
+        examples = [f"{self.tokenizer.bos_token}{example['text']}{self.tokenizer.eos_token}" for example in instances]
+        # Tokenize
+        tokenized_examples = self.tokenizer(
+            examples,
+            max_length=self.max_len,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        # Build the input and labels for causal LM
+        input_ids = []
+        for tokenized_example in tokenized_examples['input_ids']:
+            input_ids.append(torch.tensor(tokenized_example))
+        # Apply padding
+        padding_value = None
+        if self.tokenizer.pad_token_id is not None:
+            padding_value = self.tokenizer.pad_token_id
+        elif self.tokenizer.eos_token_id is not None:
+            padding_value = self.tokenizer.eos_token_id
+        else:
+            raise RuntimeError("Model does not have a pad or eos token")
+        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=padding_value)
+
+        data_dict = {
+            'input_ids': input_ids,
+            'attention_mask': input_ids.ne(padding_value),
+            'labels': input_ids
+        }
+        return data_dict
+
+
+@dataclass
+class DataCollatorForCausalLMs2s(object):
     tokenizer: transformers.PreTrainedTokenizer
     source_max_len: int
     target_max_len: int
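A rough usage sketch of the new text collator (not part of the commit; "gpt2" stands in for any tokenizer that defines bos and eos tokens):

    import transformers
    from datamodules import DataCollatorForCausalLMText

    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    collator = DataCollatorForCausalLMText(tokenizer=tokenizer, max_len=32)

    # Each instance must provide a 'text' field, since __call__ reads example['text'].
    batch = collator([{"text": "hello world"}, {"text": "a somewhat longer example"}])

    # gpt2 has no pad token, so padding falls back to eos; positions equal to the
    # padding id are zeroed in the attention mask, and labels mirror input_ids.
    print(batch["input_ids"].shape, batch["labels"].equal(batch["input_ids"]))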
@@ -102,7 +139,7 @@ def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_arg
                 test_size=data_args.eval_dataset_size, shuffle=True, seed=42
             )
             eval_dataset = dataset['test']
-            eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
+        eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
 
     if 'train' in dataset:
         train_dataset = dataset['train']
@@ -111,7 +148,7 @@ def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_arg
 
     train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
 
-    data_collator = DataCollatorForCausalLM(
+    data_collator = DataCollatorForCausalLMs2s(
         tokenizer=tokenizer,
         source_max_len=data_args.source_max_len,
         target_max_len=data_args.target_max_len,
@@ -127,6 +164,40 @@ def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_arg
     )
 
 
+def create_data_module_hub(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train, do_eval, do_predict) -> typing.Dict:
+    try:
+        dataset = datasets.load_dataset(data_args.dataset)
+    except FileNotFoundError as ex:
+        raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
+
+    if do_eval or do_predict:
+        if 'eval' in dataset:
+            eval_dataset = dataset['eval']
+        else:
+            print('Splitting train dataset in train and validation according to `eval_dataset_size`')
+            dataset = dataset.train_test_split(
+                test_size=data_args.eval_dataset_size, shuffle=True, seed=42
+            )
+            eval_dataset = dataset['test']
+
+    if 'train' in dataset:
+        train_dataset = dataset['train']
+    else:
+        train_dataset = dataset
+
+    data_collator = DataCollatorForCausalLMText(
+        tokenizer=tokenizer,
+        max_len=data_args.source_max_len,
+    )
+
+    return dict(
+        train_dataset=train_dataset if do_train else None,
+        eval_dataset=eval_dataset if do_eval else None,
+        predict_dataset=eval_dataset if do_predict else None,
+        data_collator=data_collator
+    )
+
+
 def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
     try:
         dataset = datasets.load_dataset('text', data_files={'train': [data_args.dataset]})
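Set against the text-file path above, the hub path would be driven roughly like this (a sketch; the dataset name is a placeholder and must expose a 'text' column for DataCollatorForCausalLMText):

    import torch
    import transformers
    from arguments import DataArguments
    from datamodules import create_data_module_hub

    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    data_args = DataArguments(dataset="openwebtext", data_from_hub=True)

    # The returned dict plugs straight into a DataLoader via its collator.
    module = create_data_module_hub(tokenizer, data_args,
                                    do_train=True, do_eval=False, do_predict=False)
    loader = torch.utils.data.DataLoader(module["train_dataset"], batch_size=4,
                                         collate_fn=module["data_collator"])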
 | 
			
		|||
            eval_dataset = dataset['eval']
 | 
			
		||||
        else:
 | 
			
		||||
            print('Splitting train dataset in train and validation according to `eval_dataset_size`')
 | 
			
		||||
            dataset = dataset.train_test_split(
 | 
			
		||||
            breakpoint()
 | 
			
		||||
            dataset = dataset['train'].train_test_split(
 | 
			
		||||
                test_size=data_args.eval_dataset_size, shuffle=True, seed=42
 | 
			
		||||
            )
 | 
			
		||||
            eval_dataset = dataset['test']
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
@@ -48,14 +48,22 @@ class LinearGroup:
         for module in self.modules:
             module.decompress()
 
-    def checkDistance(self) -> tuple[float, float]:
-        distance_accum = 0.0
-        error_accum = 0.0
+    def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
+        distance_accum = torch.Tensor()
+        error_accum = torch.Tensor()
         for module in self.modules:
-            distance, error = module.checkDistance()
-            distance_accum += distance**2
-            error_accum += error**2
-        return (math.sqrt(distance_accum) / math.sqrt(len(self.modules)), math.sqrt(error_accum) / math.sqrt(len(self.modules)))
+            distance, error = module.getDistanceAndError()
+            distance = distance.to("cpu")
+            error = error.to("cpu")
+            distance_accum = torch.cat((distance_accum, distance.reshape((distance.numel()))))
+            error_accum = torch.cat((error_accum, error.reshape((error.numel()))))
+        return (distance_accum, error_accum)
+
+    def check(self) -> bool:
+        for module in self.modules:
+            if not module.check():
+                return False
+        return True
 
 
 class DyntrainModel:
@@ -160,15 +168,18 @@ class DyntrainModel:
         total_params = self.dynamicParameters() + self.staticParameters()
         return sum(p.numel() for p in total_params if p.requires_grad)
 
-    def reshuffleActive(self) -> None:
+    def getDistanceAndErrorSample(self) -> (torch.Tensor, torch.Tensor):
+        index = randint(0, len(self.active_linear_groups) - 1)
+        return self.active_linear_groups[index].getDistanceAndError()
+
+    def reshuffleActive(self):
         active_count = len(self.active_linear_groups)
         index = 0
         while len(self.active_linear_groups) > active_count * (1 - self.reshuffle_fraction):
-            distance, error = self.active_linear_groups[index].checkDistance()
-            print(f"linear group has moved {distance} with an error of {error}")
             group = self.active_linear_groups.pop(index)
             group.setFrozen(True)
             self.frozen_linear_groups.append(group)
+            assert group.check()
 
         params = self.activeParameterCount()
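getDistanceAndError now returns flat per-element tensors rather than two scalars, so a sampled group can feed TensorBoard histograms directly — a sketch of the intended use, mirroring the commented-out calls in the training loop further down:

    from torch.utils import tensorboard

    log_writer = tensorboard.SummaryWriter()

    # model is an already-constructed DyntrainModel; one random active group is sampled.
    distance, error = model.getDistanceAndErrorSample()
    log_writer.add_histogram("Distances/Train", distance, max_bins=50)
    log_writer.add_histogram("Errors/Train", error, max_bins=50)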
@@ -180,6 +191,7 @@ class DyntrainModel:
             group.setFrozen(False)
             params += group.paramCount()
             self.active_linear_groups.append(group)
+            assert group.check()
         print(math.ceil(params / 1e6))
 
         active_params = self.activeParameterCount()
@@ -248,4 +260,8 @@ class DyntrainModel:
             group_index += 1
 
         for group in tqdm(linear_groups, desc="Perpareing layers"):
-            group.compress()
+            if group.isFrozen():
+                group.compress()
+            else:
+                group.decompress()
+            assert group.check()

modules.py (61 changes)
@@ -35,7 +35,6 @@ class Linear(torch.nn.Linear):
                 self.compress()
             else:
                 self.decompress()
-                self.weightStart = torch.Tensor(self.weight).clone().detach()
 
     def isFrozen(self) -> bool:
         return not self.weight.requires_grad
@@ -60,9 +59,15 @@ class Linear(torch.nn.Linear):
 
     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
         breakpoint()
         return self
 
+    def check(self) -> bool:
+        if self.isFrozen() and self.weight.dtype != torch.float16:
+            return False
+        elif not self.isFrozen() and self.weight.dtype != torch.float32:
+            return False
+        return True
+
 
 class DynamicConvertingLinear(Linear):
     def __init__(self, in_features, out_features, bias=True, device=None, dtype=None,
@@ -116,6 +121,7 @@ class DynamicQantizedLinear(Linear):
         self.bias_state = None
         self.block_size = 128
         self.quant_type = 'nf4'
+        self.weight_start = self.weight.clone().detach()
 
     @classmethod
     def fromLinear(cls, in_module: torch.nn.Linear, active_device: torch.device, cold_device: torch.device,
@@ -125,6 +131,7 @@ class DynamicQantizedLinear(Linear):
                          compute_dtype=compute_dtype, output_device=output_device)
         new_module.weight = torch.nn.Parameter(in_module.weight.to(torch.float32).to(cold_device))
         new_module.bias = torch.nn.Parameter(in_module.bias.to(torch.float32).to(cold_device)) if new_module.bias is not None else None
+        new_module.weight_start = new_module.weight.clone().detach()
         return new_module
 
     def compress(self) -> None:
@@ -134,26 +141,27 @@ class DynamicQantizedLinear(Linear):
             bias = self.bias.contiguous().to(torch.float16).cuda(self.active_device)
             self.bias_quantized, self.bias_state = bnb.functional.quantize_blockwise(bias, blocksize=self.block_size)
 
-        weight = torch.nn.Parameter(self.weight.to(self.cold_device))
-        bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
+        frozen = self.isFrozen()
+        self.weight = torch.nn.Parameter(self.weight.to(self.cold_device))
+        self.bias = torch.nn.Parameter(self.bias.to(self.cold_device)) if self.bias is not None else None
+        self.setFrozen(frozen, False)
 
     def decompress(self) -> None:
-        if self.weight_quantized is None:
-            raise RuntimeError("decompress() called in quantized stated before quantized weights are avialable")
-        dtype = self.weight.dtype
-        self.weight = torch.nn.Parameter(bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(dtype).to(self.active_device))
+        self.weight_quantized = None
+        self.weight_state = None
+        self.bias_quantized = None
+        self.bias_state = None
+        self.weight_start = self.weight.clone().detach().to(self.cold_device)
+        self.weight = torch.nn.Parameter(self.weight.to(self.active_device))
         if self.bias_quantized:
-            self.bias = torch.nn.Parameter(bnb.functional.dequantize_blockwise(self.bias_quantized, self.bias_state).to(dtype).to(self.active_device))
+            self.bias = torch.nn.Parameter(self.bias.to(self.active_device))
 
-    def checkDistance(self) -> tuple[float, float]:
-        if self.weight_quantized is None:
-            raise RuntimeError("checkDistance() called without quantized weights avialable")
+    def getDistanceAndError(self) -> tuple[torch.Tensor, torch.Tensor]:
         original_weight = self.weight.contiguous().to(self.active_device).to(torch.float16)
         quantized_original_weight, quantized_original_state = bnb.functional.quantize_blockwise(original_weight, blocksize=self.block_size)
-        dequantized_original_weight = bnb.functional.dequantize_blockwise(self.quantized_original_weight, self.quantized_original_state).to(original_weight.dtype)
-        dequantized_weight = bnb.functional.dequantize_blockwise(self.weight_quantized, self.weight_state).to(original_weight.dtype)
-        distance = (torch.linalg.vector_norm(dequantized_original_weight - dequantized_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
-        error = (torch.linalg.vector_norm(dequantized_original_weight - original_weight).to(torch.float32) / dequantized_original_weight.numel()).item()
+        dequantized_original_weight = bnb.functional.dequantize_blockwise(quantized_original_weight, quantized_original_state).to(original_weight.dtype)
+        distance = (self.weight_start - self.weight.to(self.cold_device)).to(torch.float32)
+        error = (dequantized_original_weight - original_weight).to(torch.float32)
         return (distance, error)
 
     def setOutputDevice(self, output_device: torch.device):
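The error returned above is just the round-trip error of blockwise quantization, while distance is the drift from weight_start. A self-contained sketch of the error measurement (assumes CUDA and bitsandbytes are available):

    import torch
    import bitsandbytes as bnb

    weight = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
    quantized, state = bnb.functional.quantize_blockwise(weight, blocksize=128)
    dequantized = bnb.functional.dequantize_blockwise(quantized, state).to(weight.dtype)

    # Per-element quantization error, computed the same way as in getDistanceAndError.
    error = (dequantized - weight).to(torch.float32)
    print(error.abs().mean().item())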
@@ -200,3 +208,24 @@ class DynamicQantizedLinear(Linear):
             if not frozen:
                 super().inplaceTo(device=device)
             self.setFrozen(frozen, False)
+
+    def check(self) -> bool:
+        if self.isFrozen():
+            if torch.device(self.weight.device) != torch.device(self.cold_device):
+                breakpoint()
+                print("Frozen but not cold")
+                return False
+            if self.weight_quantized is None:
+                breakpoint()
+                print("Frozen but not quanted")
+                return False
+        else:
+            if torch.device(self.weight.device) != torch.device(self.active_device):
+                breakpoint()
+                print("Active but not warm")
+                return False
+            if self.weight_quantized is not None:
+                breakpoint()
+                print("Active but still quantized")
+                return False
+        return True

(training script)
@@ -7,9 +7,10 @@ import os
 import shutil
 import math
 from tqdm.auto import tqdm
+import gc
 
 from arguments import DataArguments, ModelArguments, TrainingArguments
-from datamodules import create_data_module_s2s, create_data_module
+from datamodules import create_data_module_s2s, create_data_module, create_data_module_hub
 from tokenizer import get_tokenizer
 
 from dyntrainmodel import DyntrainModel
@@ -56,7 +57,9 @@ def get_optimizer(dyamic_parameters: list[torch.nn.parameter], static_parameters
     return optimizer
 
 
-def evaluate(model: DyntrainModel, dataloader: torch.utils.data.DataLoader) -> float:
+def evaluate(model: DyntrainModel, tokenizer,
+             dataloader: torch.utils.data.DataLoader, globalstep: int,
+             log_writer: tensorboard.SummaryWriter, eval_prompt: str = None):
     print("*** Eval ***")
     loss = torch.zeros((1), device="cuda:0")
     model.model.eval()
@@ -66,8 +69,17 @@ def evaluate(model: DyntrainModel, dataloader: torch.utils.data.DataLoader) -> f
         outputs = model.model(**batch)
         loss += outputs.loss
     loss = loss / len(dataloader)
+    log_writer.add_scalar("Loss/Eval", loss, globalstep)
     print(f"Eval Loss {loss.item()}")
 
+    if eval_prompt is not None:
+        input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids.to(model.devices[0])
+        attention_mask = torch.ones(input_ids.shape, device=model.devices[0], requires_grad=False)
+        outputs = model.generate(input_ids, attention_mask=attention_mask, do_sample=True, temperature=1, max_new_tokens=100)
+        response_decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+        print(f"Eval generation: {response_decoded}")
+        log_writer.add_text("Text/Eval", response_decoded, globalstep)
+
 
 def train(model_args: ModelArguments, data_args: DataArguments, training_args: TrainingArguments):
     log_writer = tensorboard.SummaryWriter()
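Stripped of the surrounding training state, the new prompt branch amounts to sample-and-log generation; a stand-alone sketch with plain transformers ("gpt2" and the prompt are placeholders):

    import torch
    import transformers
    from torch.utils import tensorboard

    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    log_writer = tensorboard.SummaryWriter()

    eval_prompt = "The quick brown fox"
    input_ids = tokenizer(eval_prompt, return_tensors="pt").input_ids
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
    outputs = model.generate(input_ids, attention_mask=attention_mask,
                             do_sample=True, temperature=1.0, max_new_tokens=100)
    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    print(f"Eval generation: {response}")
    log_writer.add_text("Text/Eval", response, 0)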
@@ -90,6 +102,8 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     if data_args.dataset.endswith("json"):
         print("Loading dataset in s2s mode")
         data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
+    elif data_args.data_from_hub:
+        data_module = create_data_module_hub(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
     else:
         print("Loading dataset in txt mode")
         data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
@@ -137,12 +151,14 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
             for step, batch in enumerate(train_dataloader):
                 for key in batch:
                     batch[key] = batch[key].to("cuda:0")
+
                 outputs = model.model(**batch)
                 loss = outputs.loss / training_args.gradient_accumulation_steps
-                log_writer.add_scalar("Loss/train", loss, global_step)
                 loss.backward()
 
                 if (step + 1) % training_args.gradient_accumulation_steps == 0 or step + 1 == len(train_dataloader):
+                    if global_step % training_args.logging_steps == 0:
+                        log_writer.add_scalar("Loss/train", loss, global_step)
                     optimizer.step()
                     lr_scheduler.step()
 
@@ -151,9 +167,14 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                     if global_step % 5 == 0:
                         print(f"Train Loss {loss.item()}")
 
-                    if global_step % 50 == 0 and training_args.max_instant_params != 0:
+                    if global_step % training_args.reshufle_steps == 0 and training_args.max_instant_params != 0:
                         print("Reshuffleing")
                         lr_scheduler.optimizer = None
                         del optimizer
+                        # distance, error = model.getDistanceAndErrorSample()
+                        # log_writer.add_histogram("Distances/Train", distance, max_bins=50)
+                        # log_writer.add_histogram("Errors/Train", error, max_bins=50)
+
                         model.reshuffleActive()
                         model.balanceActive()
                         log_writer.add_scalar("Parameters/train", model.activeParameterCount(), global_step)
@@ -173,15 +194,16 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
                     if global_step % training_args.save_steps == 0:
                         save_model(model.model, global_step, training_args.output_dir, training_args.max_checkpoints)
                     if training_args.eval_steps > 0 and global_step % training_args.save_steps == 0:
-                        evaluate(model, eval_dataloader)
+                        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
                 if training_args.flush_allocator:
                     gc.collect()
                     torch.cuda.empty_cache()
             if training_args.do_eval and training_args.eval_steps == -1:
-                evaluate(model, eval_dataloader)
+                evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
 
     # Evaluation
     if training_args.do_eval:
-        evaluate(model, eval_dataloader)
+        evaluate(model, tokenizer, eval_dataloader, global_step, log_writer, training_args.eval_prompt)
 
     save_model(model.model, global_step, training_args.output_dir)