Fix group_texts not grouping texts to a single length when the number of samples is smaller than the number of worker processes used
parent bc5321cb33
commit 0b39ba0843

2 changed files with 10 additions and 3 deletions
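The root cause: datasets.map with batched=True and num_proc=32 splits the dataset into 32 shards and runs group_texts on each shard independently. A shard holding fewer than block_size tokens cannot fill a single block, and the usual grouping recipe then emits one undersized chunk instead of nothing, so the output rows are no longer uniformly block_size long. A minimal sketch of that failure mode follows; group_texts itself is not shown in this diff, so the sketch assumes it follows the stock Hugging Face run_clm recipe and that examples["input_ids"] is a batch of token lists:

import itertools

def group_texts(examples, block_size):
    # Concatenate every sample in the shard into one token stream.
    concatenated = list(itertools.chain(*examples["input_ids"]))
    total_length = len(concatenated)
    # Quirk of the stock recipe: the short remainder is only dropped when
    # at least one full block exists. A shard with fewer than block_size
    # tokens therefore yields a single undersized chunk.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    return {"input_ids": [concatenated[i:i + block_size]
                          for i in range(0, total_length, block_size)]}

Capping num_proc so that each worker sees a reasonably large shard avoids the undersized chunks; the asserts added below make the invariant explicit.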
				
			
@@ -4,6 +4,7 @@ import typing
 import datasets
 import itertools
 import transformers
+import os
 from dataclasses import dataclass
 from torch.nn.utils.rnn import pad_sequence
 
@@ -237,7 +238,7 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
     train_dataset_tokenized = train_dataset_tokenized.map(
         lambda example: group_texts(example, data_args.block_size),
         batched=True,
-        num_proc=32,
+        num_proc=max(1, min(os.cpu_count(), int(len(train_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
         load_from_cache_file=True,
         desc=f"Grouping texts in chunks of {data_args.block_size}")
 
@@ -251,10 +252,15 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
         eval_dataset_tokenized = eval_dataset_tokenized.map(
             lambda example: group_texts(example, data_args.block_size),
             batched=True,
-            num_proc=32,
+            num_proc=max(1, min(os.cpu_count(), int(len(eval_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
             load_from_cache_file=True,
             desc=f"Grouping texts in chunks of {data_args.block_size}")
 
+    for ids in train_dataset_tokenized['input_ids']:
+        assert len(ids) == data_args.block_size
+    for ids in eval_dataset_tokenized['input_ids']:
+        assert len(ids) == data_args.block_size
+
     return dict(
         train_dataset=train_dataset_tokenized if do_train else None,
         eval_dataset=eval_dataset_tokenized if do_eval else None,
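The replacement num_proc expression is a heuristic: roughly one worker per block_size * 10 samples, clamped to at least 1 and at most the CPU count, so small datasets are grouped in a single shard. A quick check of the arithmetic, with illustrative numbers that are not from the repo:

import os

block_size = 1024

# 500 samples: int(500 / 10240) == 0, clamped up to 1, so the whole
# dataset is grouped in one process and one shard.
assert max(1, min(os.cpu_count(), int(500 / (block_size * 10)))) == 1

# 500,000 samples: int(500000 / 10240) == 48, so up to 48 workers are
# used, but never more than os.cpu_count().
n = max(1, min(os.cpu_count(), int(500000 / (block_size * 10))))
assert 1 <= n <= os.cpu_count()

One caveat the expression does not guard against: os.cpu_count() may return None on platforms where the count is undetermined, in which case min() raises a TypeError.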
@@ -101,7 +101,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
 
     tokenizer = get_tokenizer(model.model, training_args.cache_dir, model_args)
 
-    if data_args.dataset.endswith("json"):
+    if data_args.dataset.endswith("json") or data_args.dataset.endswith("jsonl"):
         print("Loading dataset in s2s mode")
         data_module = create_data_module_s2s(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
     elif data_args.data_from_hub:
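Side note on the extended condition: since "jsonl" does not end with "json", the old check really did route .jsonl files away from s2s mode. str.endswith also accepts a tuple of suffixes, so an equivalent, more compact form of the new check would be:

# Equivalent to the two chained endswith() calls in the diff:
if data_args.dataset.endswith(("json", "jsonl")):
    print("Loading dataset in s2s mode")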
@@ -109,6 +109,7 @@ def train(model_args: ModelArguments, data_args: DataArguments, training_args: T
     else:
         print("Loading dataset in txt mode")
         data_module = create_data_module(tokenizer, data_args, training_args.do_train, training_args.do_eval, False)
+
     dataset = {k: v for k, v in data_module.items() if k != 'predict_dataset'}
     train_dataloader = torch.utils.data.DataLoader(
         dataset['train_dataset'],