Spaces:
Running
Running
plotting a blended version of the heatmap
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ from pickle import load
|
|
| 4 |
import gradio as gr
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import numpy as np
|
|
|
|
| 7 |
import torch
|
| 8 |
|
| 9 |
from msma import ScoreFlow, config_presets
|
|
@@ -39,39 +40,50 @@ def plot_against_reference(nll, ref_nll):
|
|
| 39 |
fig.tight_layout()
|
| 40 |
return fig
|
| 41 |
|
| 42 |
-
|
| 43 |
-
def plot_heatmap(heatmap):
|
| 44 |
fig, ax = plt.subplots()
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
fig.tight_layout()
|
| 48 |
return fig
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
|
| 53 |
-
|
|
|
|
| 54 |
|
| 55 |
with torch.inference_mode():
|
|
|
|
| 56 |
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
|
| 57 |
-
img =
|
| 58 |
-
img = img.to(device)
|
| 59 |
model = load_model(modeldir='models', preset=preset, device=device)
|
|
|
|
|
|
|
|
|
|
| 60 |
x = model.scorenet(img)
|
| 61 |
x = x.square().sum(dim=(2, 3, 4)) ** 0.5
|
| 62 |
-
img_likelihood = model(img).cpu().numpy()
|
| 63 |
nll, pct, ref_nll = compute_gmm_likelihood(x.cpu(), model_dir=f"models/{preset}")
|
| 64 |
-
|
| 65 |
outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
|
| 66 |
histplot = plot_against_reference(nll, ref_nll)
|
| 67 |
-
heatmapplot = plot_heatmap(img_likelihood)
|
| 68 |
|
| 69 |
return outstr, heatmapplot, histplot
|
| 70 |
|
| 71 |
|
| 72 |
demo = gr.Interface(
|
| 73 |
fn=run_inference,
|
| 74 |
-
inputs=["
|
| 75 |
outputs=["text",
|
| 76 |
gr.Plot(label="Anomaly Heatmap"),
|
| 77 |
gr.Plot(label="Comparing to Imagenette"),
|
|
|
|
| 4 |
import gradio as gr
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import numpy as np
|
| 7 |
+
import PIL.Image as Image
|
| 8 |
import torch
|
| 9 |
|
| 10 |
from msma import ScoreFlow, config_presets
|
|
|
|
| 40 |
fig.tight_layout()
|
| 41 |
return fig
|
| 42 |
|
| 43 |
+
def plot_heatmap(img: Image, heatmap: np.array):
|
|
|
|
| 44 |
fig, ax = plt.subplots()
|
| 45 |
+
cmap = plt.get_cmap("gist_heat")
|
| 46 |
+
h = heatmap[0,0].copy()
|
| 47 |
+
qmin, qmax = np.quantile(h, 0.5), np.quantile(h, 0.999)
|
| 48 |
+
h = np.clip(h, a_min=qmin, a_max=qmax)
|
| 49 |
+
h = (h-h.min()) / (h.max() - h.min())
|
| 50 |
+
h = cmap(h, bytes=True)[:,:,:3]
|
| 51 |
+
h = Image.fromarray(h).resize(img.size, resample=Image.Resampling.BILINEAR)
|
| 52 |
+
im = Image.blend(img, h, alpha=0.6)
|
| 53 |
+
im = ax.imshow(np.array(im))
|
| 54 |
+
# fig.colorbar(im)
|
| 55 |
+
# plt.grid(False)
|
| 56 |
+
# plt.axis("off")
|
| 57 |
fig.tight_layout()
|
| 58 |
return fig
|
| 59 |
|
| 60 |
+
def run_inference(input_img, preset="edm2-img64-s-fid", device="cuda"):
|
|
|
|
| 61 |
|
| 62 |
+
# img = center_crop_imagenet(64, img)
|
| 63 |
+
input_img = input_img.resize(size=(64, 64), resample=Image.Resampling.LANCZOS)
|
| 64 |
|
| 65 |
with torch.inference_mode():
|
| 66 |
+
img = np.array(input_img)
|
| 67 |
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
|
| 68 |
+
img = img.float().to(device)
|
|
|
|
| 69 |
model = load_model(modeldir='models', preset=preset, device=device)
|
| 70 |
+
img_likelihood = model(img).cpu().numpy()
|
| 71 |
+
|
| 72 |
+
img = torch.nn.functional.interpolate(img, size=64, mode='bilinear')
|
| 73 |
x = model.scorenet(img)
|
| 74 |
x = x.square().sum(dim=(2, 3, 4)) ** 0.5
|
|
|
|
| 75 |
nll, pct, ref_nll = compute_gmm_likelihood(x.cpu(), model_dir=f"models/{preset}")
|
| 76 |
+
|
| 77 |
outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
|
| 78 |
histplot = plot_against_reference(nll, ref_nll)
|
| 79 |
+
heatmapplot = plot_heatmap(input_img, img_likelihood)
|
| 80 |
|
| 81 |
return outstr, heatmapplot, histplot
|
| 82 |
|
| 83 |
|
| 84 |
demo = gr.Interface(
|
| 85 |
fn=run_inference,
|
| 86 |
+
inputs=[gr.Image(type='pil', label="Input Image")],
|
| 87 |
outputs=["text",
|
| 88 |
gr.Plot(label="Anomaly Heatmap"),
|
| 89 |
gr.Plot(label="Comparing to Imagenette"),
|