You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
27 lines
1.1 KiB
27 lines
1.1 KiB
2 years ago
|
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
|
||
|
|
||
|
|
||
|
def get_tokenlizer(text_encoder_type):
    """Build a tokenizer via ``AutoTokenizer.from_pretrained``.

    Args:
        text_encoder_type: Either the model name/path string handed to
            ``AutoTokenizer.from_pretrained``, or a config-like object that
            carries that value -- through a ``text_encoder_type`` attribute
            (checked first) or a truthy ``"text_encoder_type"`` entry returned
            by its ``.get()`` method.

    Returns:
        Whatever ``AutoTokenizer.from_pretrained`` returns for the resolved
        value.

    Raises:
        ValueError: If a non-string argument has no ``text_encoder_type``
            attribute and no truthy ``"text_encoder_type"`` entry (including
            objects that have no ``.get`` method at all).
    """
    if not isinstance(text_encoder_type, str):
        if hasattr(text_encoder_type, "text_encoder_type"):
            text_encoder_type = text_encoder_type.text_encoder_type
        else:
            # Look the getter up defensively: calling ``.get`` unconditionally
            # turned unsupported inputs (ints, None, plain objects) into an
            # AttributeError instead of the ValueError below.
            getter = getattr(text_encoder_type, "get", None)
            resolved = getter("text_encoder_type", False) if callable(getter) else False
            if not resolved:
                raise ValueError(
                    "Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
                )
            text_encoder_type = resolved
    print("final text_encoder_type: {}".format(text_encoder_type))

    tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
    return tokenizer
|
||
|
|
||
|
|
||
|
def get_pretrained_language_model(text_encoder_type):
    """Load a pretrained text encoder model by its exact name.

    Only two names are recognised: ``"bert-base-uncased"`` (loaded with
    ``BertModel``) and ``"roberta-base"`` (loaded with ``RobertaModel``).

    Args:
        text_encoder_type: Model name, compared with ``==`` against the
            supported names above.

    Returns:
        The model returned by the selected class's ``from_pretrained``.

    Raises:
        ValueError: If ``text_encoder_type`` matches neither supported name.
    """
    # Choose the model class first, so there is a single loading call below.
    if text_encoder_type == "bert-base-uncased":
        model_cls = BertModel
    elif text_encoder_type == "roberta-base":
        model_cls = RobertaModel
    else:
        raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))
    return model_cls.from_pretrained(text_encoder_type)
|