Spaces:
Build error
Build error
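"""
Gradio demo that retrieves images from the "beans" dataset that are similar to an uploaded image, using an LSH index.

Thanks to Freddy Boulton (https://github.com/freddyaboulton) for helping with this.
"""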
import os
import pickle
import tempfile

import gradio as gr
from datasets import load_dataset
from transformers import AutoModel

from similarity_utils import BuildLSHTable
# Seed used to shuffle the train split below. Presumably it must match the seed
# used when lsh.pickle was built, so that candidate indices line up with the
# keys stored in the LSH table -- TODO confirm.
seed = 42

# Top-level code: runs once, when the script is first executed.
# NOTE(review): pickle.load can execute arbitrary code from the file; this is
# only acceptable because lsh.pickle is a local artifact shipped with the app.
with open("lsh.pickle", mode="rb") as lsh_file:
    loaded_lsh = pickle.load(lsh_file)

# Embedding model for query images, wrapped by the LSH helper; the restored
# table is then attached to the helper.
model_ckpt = "nateraw/vit-base-beans"
model = AutoModel.from_pretrained(model_ckpt)
lsh_builder = BuildLSHTable(model)
lsh_builder.lsh = loaded_lsh

# Candidate images; `query` indexes into this shuffled split by position.
dataset = load_dataset("beans")
train_split = dataset["train"]
candidate_dataset = train_split.shuffle(seed=seed)
def query(image, top_k):
    """Return the query image followed by its ``top_k`` most similar candidates.

    ``lsh_builder.query`` is expected to return a mapping from result keys of the
    form ``"<candidate index>_<label>"`` to a score. Results are ranked by that
    score in descending order. The score is presumably the number of
    overlapping LSH buckets -- TODO confirm in ``similarity_utils``.

    Args:
        image: The uploaded query image. It must provide a ``save(path)`` method
            (e.g. a PIL image). ``None`` (nothing uploaded) yields an empty
            gallery.
        top_k: Maximum number of similar candidates to show. Slider values may
            arrive as floats, so the value is truncated with ``int``.

    Returns:
        A list of ``(file path, caption)`` tuples for ``gr.Gallery``. The first
        entry is the query image captioned ``"Query"``. The remaining entries are
        the matching candidates, captioned with their labels.
    """
    if image is None:
        return []

    results = lsh_builder.query(image)
    ranked_keys = sorted(results, key=results.get, reverse=True)[: int(top_k)]

    # The query image is shown first with its own caption. It must not inherit
    # a candidate's label, and it must still be shown when there are no matches.
    gallery_images = [image]
    captions = ["Query"]
    for key in ranked_keys:
        # maxsplit=1 keeps labels that themselves contain underscores intact.
        image_id, label = key.split("_", 1)
        gallery_images.append(candidate_dataset[int(image_id)]["image"])
        captions.append(label)

    # Write into a fresh directory per call so that concurrent requests cannot
    # overwrite each other's files. NOTE: these directories are not cleaned up.
    output_dir = tempfile.mkdtemp(prefix="similar-images-")
    paths = []
    for position, picture in enumerate(gallery_images):
        path = os.path.join(output_dir, f"{position}.png")
        picture.save(path)
        paths.append(path)

    # gr.Gallery accepts a list of (path, caption) tuples.
    return list(zip(paths, captions))
# `query` calls `.save(path)` on the uploaded image, which is a PIL method, so
# gr.Image must deliver PIL images rather than its default numpy arrays.
# NOTE(review): assumes BuildLSHTable.query also accepts PIL images -- confirm
# in similarity_utils.
demo = gr.Interface(
    query,
    inputs=[
        gr.Image(type="pil", label="Query image"),
        gr.Slider(value=5, minimum=1, maximum=10, step=1, label="Number of results"),
    ],
    outputs=gr.Gallery(label="Query and similar images"),
)
demo.launch()