Initial commit
tokenizer.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import transformers

from arguments import ModelArguments


DEFAULT_PAD_TOKEN = "[PAD]"


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

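    # Initialize the embeddings of any newly added tokens to the mean of the
    # pre-existing embeddings, so they start close to the pretrained distribution.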
    if num_new_tokens > 0:
        input_embeddings_data = model.get_input_embeddings().weight.data
        output_embeddings_data = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
        output_embeddings_data[-num_new_tokens:] = output_embeddings_avg

def get_tokenizer(model, cache_dir, model_args: ModelArguments):
    print(f'Tokenizer: {model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path}')
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.tokenizer if model_args.tokenizer is not None else model_args.model_name_or_path,
        cache_dir=cache_dir,
        padding_side="right",
        use_fast=False,
        eos_token="[EOS]",
        tokenizer_type='llama' if 'llama' in model_args.model_name_or_path else None,
        trust_remote_code=model_args.trust_remote_code,
    )
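    # If the tokenizer has no pad token (and resizing is not disabled), add the
    # default "[PAD]" token and resize the model embeddings to match.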
    if tokenizer._pad_token is None and not model_args.noresize:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if 'llama' in model_args.model_name_or_path or isinstance(tokenizer, transformers.LlamaTokenizer):
        # LLaMA tokenizer may not have correct special tokens set.
        # Check and add them if missing to prevent them from being parsed into different tokens.
        # Note that these are present in the vocabulary.
        # Note also that `model.config.pad_token_id` is 0, which corresponds to the `<unk>` token.
        print('Adding special tokens.')
        tokenizer.add_special_tokens({
            "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
            "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
            "unk_token": tokenizer.convert_ids_to_tokens(
                model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id
            ),
        })
    return tokenizer
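
A minimal usage sketch (not part of this commit), assuming the model is a causal LM loaded with transformers and that ModelArguments (defined in arguments.py) exposes the fields referenced above (tokenizer, model_name_or_path, trust_remote_code, noresize):

    import transformers

    from arguments import ModelArguments
    from tokenizer import get_tokenizer

    # Hypothetical values; ModelArguments comes from arguments.py, not this commit,
    # and is assumed here to be constructible with these keyword arguments.
    model_args = ModelArguments(
        model_name_or_path="huggyllama/llama-7b",
        tokenizer=None,
        trust_remote_code=False,
        noresize=False,
    )

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=None,
    )
    tokenizer = get_tokenizer(model, cache_dir=None, model_args=model_args)
    print(len(tokenizer))  # vocabulary size after any pad-token addition and resize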