This repository has been archived by the owner on Dec 26, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
match.py
48 lines (38 loc) · 1.65 KB
/
match.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from PIL import Image
import numpy as np
import torch
import cn_clip.clip as clip
from cn_clip.clip import load_from_name
def openFromFile(paths: list[str]):
    """Open every path in *paths* with ``Image.open``.

    Args:
        paths: File paths of the images to open.

    Returns:
        A list with one opened image per path, in the same order as *paths*.
    """
    return [Image.open(path) for path in paths]
def arrayToImage(array: np.ndarray):
    """Convert a NumPy array into a PIL image via ``Image.fromarray``.

    Args:
        array: Pixel data, passed to ``Image.fromarray`` unchanged.

    Returns:
        The image built from *array*.
    """
    image = Image.fromarray(array)
    return image
    # Alternative kept for reference, assuming ``array`` is a NumPy array of
    # shape (3, H, W) (presumably float values that need scaling -- confirm):
    # return Image.fromarray((array * 255).astype(np.uint8).transpose(1, 2, 0))
# Loaded (model, preprocess) pairs keyed by device string, so the checkpoint
# is read at most once per device instead of on every call.
_CLIP_CACHE: dict[str, tuple] = {}


def _get_clip(device: str):
    """Return the cached ``(model, preprocess)`` pair for *device*, loading it on first use."""
    if device not in _CLIP_CACHE:
        model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')
        model.eval()
        _CLIP_CACHE[device] = (model, preprocess)
    return _CLIP_CACHE[device]


def _pick_device() -> str:
    """Choose ``"cuda"`` when available, otherwise ``"cpu"``, and report the choice."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using", device)
    return device


def imageTextMatch(image: Image.Image, text: list[str]):
    """Score a single image against several candidate texts.

    Args:
        image: The image to compare.
        text: Candidate texts to tokenize and compare with *image*.

    Returns:
        A NumPy array: the first row of the softmax (over the last axis) of
        the image-to-text logits. Since the image batch has size one, this is
        presumably one probability per entry of *text* -- confirm against the
        ``get_similarity`` contract.
    """
    device = _pick_device()
    model, preprocess = _get_clip(device)
    tokens = clip.tokenize(text).to(device)
    # unsqueeze(0) turns the single preprocessed image into a batch of one.
    batch = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits_per_image, _ = model.get_similarity(batch, tokens)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    return probs[0]


def TextMatchImages(text: str, images: list[Image.Image]):
    """Score a single text against several candidate images.

    Args:
        text: The prompt to compare; it is tokenized as a one-element batch.
        images: Candidate images; each is preprocessed and stacked into one
            batch tensor.

    Returns:
        A NumPy array: the first row of the softmax (over the last axis) of
        the text-to-image logits. Since only one text is supplied, this is
        presumably one probability per entry of *images* -- confirm against
        the ``get_similarity`` contract.
    """
    device = _pick_device()
    model, preprocess = _get_clip(device)
    tokens = clip.tokenize([text]).to(device)
    batch = torch.stack([preprocess(img) for img in images]).to(device)
    with torch.no_grad():
        _, logits_per_text = model.get_similarity(batch, tokens)
        probs = logits_per_text.softmax(dim=-1).cpu().numpy()
    return probs[0]
def classify(prompt: str, images: list[Image.Image]):
    """Pick the image that best matches *prompt*.

    Args:
        prompt: Text passed to ``TextMatchImages``.
        images: Candidate images passed to ``TextMatchImages``.

    Returns:
        The index (along axis 0) of the largest score returned by
        ``TextMatchImages``.
    """
    scores = TextMatchImages(prompt, images)
    print("Label probs: ", scores)
    return scores.argmax(axis=0)