|
|
@ -1,3 +1,4 @@ |
|
|
|
|
|
|
|
import os |
|
|
|
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast |
|
|
|
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -23,4 +24,7 @@ def get_pretrained_language_model(text_encoder_type): |
|
|
|
return BertModel.from_pretrained(text_encoder_type) |
|
|
|
return BertModel.from_pretrained(text_encoder_type) |
|
|
|
if text_encoder_type == "roberta-base": |
|
|
|
if text_encoder_type == "roberta-base": |
|
|
|
return RobertaModel.from_pretrained(text_encoder_type) |
|
|
|
return RobertaModel.from_pretrained(text_encoder_type) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if os.path.isdir(text_encoder_type): |
|
|
|
|
|
|
|
return BertModel.from_pretrained(text_encoder_type) |
|
|
|
raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type)) |
|
|
|
raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type)) |
|
|
|