Add chat datamodules

parent 0b39ba0843
commit 2f35689355

2 changed files with 148 additions and 28 deletions

arguments.py (53 changes)
@@ -1,11 +1,52 @@
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Self
+from enum import Enum
 
 
+class DatasetType(Enum):
+    TEXT = 1
+    S2S = 2
+    HUB = 3
+    CHAT = 4
+
+    @staticmethod
+    def to_string(dtype: Self) -> str:
+        if dtype == DatasetType.TEXT:
+            return "text"
+        elif dtype == DatasetType.S2S:
+            return "s2s"
+        elif dtype == DatasetType.HUB:
+            return "hub"
+        elif dtype == DatasetType.CHAT:
+            return "chat"
+        return "invalid"
+
+    @staticmethod
+    def from_string(string: str):
+        if string == str(DatasetType.TEXT):
+            return DatasetType.TEXT
+        elif string == str(DatasetType.S2S):
+            return DatasetType.S2S
+        elif string == str(DatasetType.HUB):
+            return DatasetType.HUB
+        elif string == str(DatasetType.CHAT):
+            return DatasetType.CHAT
+        return None
+
+    def __str__(self):
+        return DatasetType.to_string(self)
+
+
 @dataclass
 class DataArguments:
     dataset: str = field(
-        metadata={"help": "A json file (s2s) or text file with the dataset to train on"}
+        metadata={"help": "The dataset to train on"}
+    )
+    dataset_type: str = field(
+        default="text", metadata={"help": f"The type of dataset, set to one of {[e for e in DatasetType]}"}
+    )
+    dataset_chat_template: str | None = field(
+        default=None, metadata={"help": "overrides the chat template to be the one set here"}
     )
     eval_dataset_size: int = field(
         default=512, metadata={"help": "Size of validation dataset."}
@@ -26,10 +67,6 @@ class DataArguments:
         default=False,
         metadata={"help": "If this is set the dataset is assumed to be a name of a hf-hub dataset"}
     )
-    block_size: int = field(
-        default=512,
-        metadata={"help": "size of the blocks the text is split into for training"},
-    )
 
 
 @dataclass
@@ -65,8 +102,9 @@ class TrainingArguments():
     )
     resume: bool = field(default=False, metadata={"help": 'Resume from previous checkpoint'})
     ddp_find_unused_parameters: bool = field(default=True, metadata={"help": 'set if trainer should try to find unused parameters'})
-    output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
+    output_dir: str = field(default='./output', metadata={"help": 'The output dir for checkpoints'})
     per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'})
+    per_device_eval_batch_size: int = field(default=1, metadata={"help": 'The eval batch size per GPU. Increase for better speed.'})
     gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'})
     epochs: int = field(default=3, metadata={"help": 'How many epochs to train for'})
     weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'})
@@ -82,6 +120,7 @@ class TrainingArguments():
                                    metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'})
     warmup_steps: float = field(default=0, metadata={"help": 'number of steps to do a warmup for'})
     logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
+    logging_dir: str = field(default='./log', metadata={"help": 'The output dir for logs'})
     group_by_length: bool = field(default=False,
                                   metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
     save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
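
For reference, the new DatasetType is designed to round-trip through its string form: __str__ delegates to to_string, and from_string compares the raw string against str(...) of each member, returning None for unknown names. A minimal sketch of that contract, not part of the commit (note that importing Self from typing requires Python 3.11+):

    # Hypothetical check, assuming arguments.py from this commit is importable.
    from arguments import DatasetType

    assert str(DatasetType.CHAT) == "chat"
    assert DatasetType.from_string("chat") is DatasetType.CHAT
    assert DatasetType.from_string("bogus") is None  # falls through every branch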

datamodules.py (123 changes)
@@ -7,22 +7,23 @@ import transformers
 import os
 from dataclasses import dataclass
 from torch.nn.utils.rnn import pad_sequence
+from tqdm import tqdm
 
-from arguments import DataArguments
+from arguments import DataArguments, DatasetType
 
 IGNORE_INDEX = -100
 
 
-def group_texts(examples, block_size: int):
+def group_texts(examples, source_max_len: int):
     # Concatenate all texts.
     concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()}
     total_length = len(concatenated_examples[list(examples.keys())[0]])
     # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
     # customize this part to your needs.
-    if total_length >= block_size:
-        total_length = (total_length // block_size) * block_size
+    if total_length >= source_max_len:
+        total_length = (total_length // source_max_len) * source_max_len
     # Split by chunks of max_len.
-    result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()}
+    result = {k: [t[i: i + source_max_len] for i in range(0, total_length, source_max_len)] for k, t in concatenated_examples.items()}
     result["labels"] = result["input_ids"].copy()
     return result
 
@@ -199,14 +200,15 @@ def create_data_module_hub(tokenizer: transformers.PreTrainedTokenizer, data_arg
     )
 
 
-def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
+def create_data_module_txt(tokenizer: transformers.PreTrainedTokenizer,
+                           data_args: DataArguments, do_train: bool, do_eval: bool, do_predict: bool) -> typing.Dict:
     try:
         dataset = datasets.load_dataset('text', data_files={'train': [data_args.dataset]})
     except FileNotFoundError as ex:
         raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
 
-    if data_args.block_size > tokenizer.model_max_length:
-        raise ValueError(f"Block size of {data_args.block_size} is larger than the maximum size supported by the model: {tokenizer.model_max_length}")
+    if data_args.source_max_len > tokenizer.model_max_length:
+        raise ValueError(f"Max source length of {data_args.source_max_len} is larger than the maximum size supported by the model: {tokenizer.model_max_length}")
 
     def add_newline_fn(example):
         example['text'] = example['text'] + '\n'
@@ -219,9 +221,7 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
             eval_dataset = dataset['eval']
         else:
             print('Splitting train dataset in train and validation according to `eval_dataset_size`')
-            dataset = dataset['train'].train_test_split(
-                test_size=data_args.eval_dataset_size, shuffle=True, seed=42
-            )
+            dataset = dataset['train'].train_test_split(test_size=data_args.eval_dataset_size, shuffle=False)
             eval_dataset = dataset['test']
 
     if 'train' in dataset:
@@ -233,14 +233,14 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
         lambda example: tokenizer(example['text']),
         batched=True,
         remove_columns='text',
-        num_proc=32,
+        num_proc=os.cpu_count(),
         load_from_cache_file=True)
     train_dataset_tokenized = train_dataset_tokenized.map(
-        lambda example: group_texts(example, data_args.block_size),
+        lambda example: group_texts(example, data_args.source_max_len),
         batched=True,
-        num_proc=max(1, min(os.cpu_count(), int(len(train_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
+        num_proc=max(1, min(os.cpu_count(), int(len(train_dataset_tokenized['input_ids']) / (data_args.source_max_len * 10)))),
         load_from_cache_file=True,
-        desc=f"Grouping texts in chunks of {data_args.block_size}")
+        desc=f"Grouping texts in chunks of {data_args.source_max_len}")
 
     eval_dataset_tokenized = None
     if eval_dataset is not None:
@@ -248,18 +248,18 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
             lambda example: tokenizer(example['text']),
             batched=True,
             remove_columns='text',
-            num_proc=32)
+            num_proc=os.cpu_count())
         eval_dataset_tokenized = eval_dataset_tokenized.map(
-            lambda example: group_texts(example, data_args.block_size),
+            lambda example: group_texts(example, data_args.source_max_len),
             batched=True,
-            num_proc=max(1, min(os.cpu_count(), int(len(eval_dataset_tokenized['input_ids']) / (data_args.block_size * 10)))),
+            num_proc=max(1, min(os.cpu_count(), int(len(eval_dataset_tokenized['input_ids']) / (data_args.source_max_len * 10)))),
             load_from_cache_file=True,
-            desc=f"Grouping texts in chunks of {data_args.block_size}")
+            desc=f"Grouping texts in chunks of {data_args.source_max_len}")
 
     for ids in train_dataset_tokenized['input_ids']:
-        assert len(ids) == data_args.block_size
+        assert len(ids) == data_args.source_max_len
     for ids in eval_dataset_tokenized['input_ids']:
-        assert len(ids) == data_args.block_size
+        assert len(ids) == data_args.source_max_len
 
     return dict(
         train_dataset=train_dataset_tokenized if do_train else None,
@@ -267,3 +267,84 @@ def create_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: D
         predict_dataset=eval_dataset_tokenized if do_predict else None,
         data_collator=transformers.default_data_collator
     )
+
+
+def create_data_module_chat(tokenizer, data_args, do_train, do_eval, do_predict):
+    try:
+        dataset = datasets.Dataset.from_json(path_or_paths=data_args.dataset)
+    except FileNotFoundError as ex:
+        raise ValueError(f"Error loading dataset from {data_args.dataset}, {ex}")
+
+    if data_args.dataset_chat_template is not None:
+        tokenizer.chat_template = data_args.dataset_chat_template
+
+    target_len = data_args.source_max_len * 0.5
+    grouped_chats = list()
+    last_len = 0
+    for row in tqdm(dataset, desc="Grouping chat messages"):
+        content_length = len(tokenizer(row['content'])['input_ids'])
+        if last_len + content_length <= target_len and len(grouped_chats) > 0:
+            grouped_chats[-1]['chat'].append(row)
+            last_len += content_length
+        else:
+            last_len = 0
+            grouped_chats.append({'chat': [row]})
+    dataset = datasets.Dataset.from_list(grouped_chats)
+    dataset = dataset.map(lambda x: {"text": tokenizer.apply_chat_template(x["chat"], tokenize=False, add_generation_prompt=False)})
+    dataset.remove_columns('chat')
+
+    eval_dataset = None
+    if do_eval or do_predict:
+        print('Splitting train dataset in train and validation according to `eval_dataset_size`')
+        dataset_split = dataset.train_test_split(test_size=data_args.eval_dataset_size, shuffle=True)
+        train_dataset = dataset_split["train"]
+        eval_dataset = dataset_split["test"]
+
+    data_collator = DataCollatorForCausalLMText(
+        tokenizer=tokenizer,
+        max_len=data_args.source_max_len,
+    )
+    return dict(
+        train_dataset=train_dataset if do_train else None,
+        eval_dataset=eval_dataset,
+        predict_dataset=eval_dataset,
+        data_collator=data_collator
+    )
+
+
+def get_data_loaders(tokenizer, data_args: DataArguments, batch_size: int, eval_batch_size: int,
+                     do_train: bool, do_eval: bool, do_predict: bool = False):
+    data_type = DatasetType.from_string(data_args.dataset_type)
+    if data_type == DatasetType.S2S:
+        print("Loading dataset in s2s mode")
+        data_module = create_data_module_s2s(tokenizer, data_args, do_train, do_eval, do_predict)
+    elif data_type == DatasetType.HUB:
+        print("Loading dataset from hub, expecting alpaca style")
+        data_module = create_data_module_hub(tokenizer, data_args, do_train, do_eval, do_predict)
+    elif data_type == DatasetType.TEXT:
+        print("Loading dataset in txt mode")
+        data_module = create_data_module_txt(tokenizer, data_args, do_train, do_eval, do_predict)
+    elif data_type == DatasetType.CHAT:
+        print("Loading dataset in chat mode")
+        data_module = create_data_module_chat(tokenizer, data_args, do_train, do_eval, do_predict)
+    else:
+        raise RuntimeError("Unkown dataset type")
+
+    train_dataloader = None
+    eval_dataloader = None
+
+    if do_train:
+        train_dataloader = torch.utils.data.DataLoader(
+            data_module['train_dataset'],
+            shuffle=True,
+            collate_fn=data_module['data_collator'],
+            batch_size=batch_size
+        )
+    if do_eval:
+        eval_dataloader = torch.utils.data.DataLoader(
+            data_module['eval_dataset'],
+            shuffle=True,
+            collate_fn=data_module['data_collator'],
+            batch_size=eval_batch_size
+        )
+    return train_dataloader, eval_dataloader
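
A quick usage sketch for the new get_data_loaders entry point; the model name and dataset path below are placeholders, not part of the commit, and the remaining DataArguments fields are assumed to keep their defaults. The tokenizer must ship a chat template (or dataset_chat_template must override it), and do_eval=True is passed because create_data_module_chat only defines its train split when an eval or predict split is requested:

    # Hypothetical wiring of the chat data path.
    import transformers
    from arguments import DataArguments
    from datamodules import get_data_loaders

    tokenizer = transformers.AutoTokenizer.from_pretrained("some-chat-model")  # placeholder
    data_args = DataArguments(dataset="chats.json", dataset_type="chat")       # placeholder path
    train_loader, eval_loader = get_data_loaders(
        tokenizer, data_args, batch_size=1, eval_batch_size=1,
        do_train=True, do_eval=True)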