|
|
@ -576,7 +576,7 @@ class WorldModel(DetectionModel): |
|
|
|
text_token = clip.tokenize(text).to(device) |
|
|
|
text_token = clip.tokenize(text).to(device) |
|
|
|
txt_feats = model.encode_text(text_token).to(dtype=torch.float32) |
|
|
|
txt_feats = model.encode_text(text_token).to(dtype=torch.float32) |
|
|
|
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True) |
|
|
|
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True) |
|
|
|
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]) |
|
|
|
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach() |
|
|
|
self.model[-1].nc = len(text) |
|
|
|
self.model[-1].nc = len(text) |
|
|
|
|
|
|
|
|
|
|
|
def init_criterion(self): |
|
|
|
def init_criterion(self): |
|
|
|