diff --git a/datamodules.py b/datamodules.py index e7ac922..d14c75b 100644 --- a/datamodules.py +++ b/datamodules.py @@ -135,7 +135,7 @@ def create_data_module_s2s(tokenizer: transformers.PreTrainedTokenizer, data_arg eval_dataset = dataset['eval'] else: print('Splitting train dataset in train and validation according to `eval_dataset_size`') - dataset = dataset.train_test_split( + dataset = dataset['train'].train_test_split( test_size=data_args.eval_dataset_size, shuffle=True, seed=42 ) eval_dataset = dataset['test'] @@ -175,7 +175,7 @@ def create_data_module_hub(tokenizer: transformers.PreTrainedTokenizer, data_arg eval_dataset = dataset['eval'] else: print('Splitting train dataset in train and validation according to `eval_dataset_size`') - dataset = dataset.train_test_split( + dataset = dataset['train'].train_test_split( test_size=data_args.eval_dataset_size, shuffle=True, seed=42 ) eval_dataset = dataset['test']