支持本地的huggingface

chiebot_dev
captainfffsama 2 years ago
parent 9dac4c605b
commit 4f04acf4c8
  1. 1
      .gitignore
  2. 1
      demo/inference_on_a_image.py
  3. 4
      groundingdino/util/get_tokenlizer.py
  4. 2
      groundingdino/version.py

1
.gitignore vendored

@ -1,3 +1,4 @@
test_weight/
# IDE # IDE
.idea/ .idea/
.vscode/ .vscode/

@ -27,6 +27,7 @@ def plot_boxes_to_image(image_pil, tgt):
for box, label in zip(boxes, labels): for box, label in zip(boxes, labels):
# from 0..1 to 0..W, 0..H # from 0..1 to 0..W, 0..H
box = box * torch.Tensor([W, H, W, H]) box = box * torch.Tensor([W, H, W, H])
print("label:",label)
# from xywh to xyxy # from xywh to xyxy
box[:2] -= box[2:] / 2 box[:2] -= box[2:] / 2
box[2:] += box[:2] box[2:] += box[:2]

@ -1,3 +1,4 @@
import os
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
@ -23,4 +24,7 @@ def get_pretrained_language_model(text_encoder_type):
return BertModel.from_pretrained(text_encoder_type) return BertModel.from_pretrained(text_encoder_type)
if text_encoder_type == "roberta-base": if text_encoder_type == "roberta-base":
return RobertaModel.from_pretrained(text_encoder_type) return RobertaModel.from_pretrained(text_encoder_type)
if os.path.isdir(text_encoder_type):
return BertModel.from_pretrained(text_encoder_type)
raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type)) raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))

@ -1 +1 @@
__version__ = "0.1.0" __version__ = '0.1.0'

Loading…
Cancel
Save