fix automatic train-test splitting

This commit is contained in:
2024-05-08 22:07:11 +02:00
parent 65482b55a6
commit bc5321cb33

View File

@ -135,7 +135,7 @@ def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_arg
eval_dataset = dataset['eval'] eval_dataset = dataset['eval']
else: else:
print('Splitting train dataset in train and validation according to `eval_dataset_size`') print('Splitting train dataset in train and validation according to `eval_dataset_size`')
dataset = dataset.train_test_split( dataset = dataset['train'].train_test_split(
test_size=data_args.eval_dataset_size, shuffle=True, seed=42 test_size=data_args.eval_dataset_size, shuffle=True, seed=42
) )
eval_dataset = dataset['test'] eval_dataset = dataset['test']
@ -175,7 +175,7 @@ def create_data_module_hub(tokenizer: transformers.PreTrainedTokenizer, data_arg
eval_dataset = dataset['eval'] eval_dataset = dataset['eval']
else: else:
print('Splitting train dataset in train and validation according to `eval_dataset_size`') print('Splitting train dataset in train and validation according to `eval_dataset_size`')
dataset = dataset.train_test_split( dataset = dataset['train'].train_test_split(
test_size=data_args.eval_dataset_size, shuffle=True, seed=42 test_size=data_args.eval_dataset_size, shuffle=True, seed=42
) )
eval_dataset = dataset['test'] eval_dataset = dataset['test']