import datetime
import json
import os
import pickle
from functools import partial, wraps
from pickle import dump, load
from typing import Literal

import click
import numpy as np
import PIL.Image
import torch
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Subset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import dnnlib
from dataset import ImageFolderDataset
from flowutils import PatchFlow, sanitize_locals
from networks_edm2 import Precond

DEVICE: Literal["cuda", "cpu"] = "cpu"

model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
config_presets = {
    "edm2-img64-s-fid": f"{model_root}/edm2-img64-s-1073741-0.075.pkl",  # fid = 1.58
    "edm2-img64-m-fid": f"{model_root}/edm2-img64-m-2147483-0.060.pkl",  # fid = 1.43
    "edm2-img64-l-fid": f"{model_root}/edm2-img64-l-1073741-0.040.pkl",  # fid = 1.33
}
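

# StandardRGBEncoder maps uint8 RGB images in [0, 255] to float32 tensors in [-1, 1]
# and back (the standard [-1, 1] pixel scaling).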
class StandardRGBEncoder:
    def __init__(self):
        super().__init__()

    def encode(self, x):  # raw pixels => final latents
        return x.to(torch.float32) / 127.5 - 1

    def decode(self, x):  # final latents => raw pixels
        return (x.to(torch.float32) * 127.5 + 128).clip(0, 255).to(torch.uint8)
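

# EDMScorer evaluates a frozen EDM2 denoiser at a fixed ladder of noise levels and,
# for each level, keeps the residual between the denoiser output and the skip-scaled
# input; these per-sigma scores feed the GMM and PatchFlow stages further below.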
class EDMScorer(torch.nn.Module):
    def __init__(
        self,
        net,
        stop_ratio=0.8,  # Maximum ratio of noise levels to compute
        num_steps=10,  # Number of noise levels to evaluate.
        use_fp16=False,  # Execute the underlying model at FP16 precision?
        sigma_min=0.002,  # Minimum supported noise level.
        sigma_max=80,  # Maximum supported noise level.
        sigma_data=0.5,  # Expected standard deviation of the training data.
        rho=7,  # Time step discretization.
    ):
        super().__init__()
        self.config = sanitize_locals(locals(), ignore_keys="net")
        self.config["EDMNet"] = dict(net.init_kwargs)
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.net = net.eval()
        self.encoder = StandardRGBEncoder()

        # Adjust noise levels based on how far we want to accumulate
        self.sigma_min = 1e-1
        self.sigma_max = sigma_max * stop_ratio
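
        # Noise levels follow the rho-spaced schedule from EDM (Karras et al., 2022),
        # interpolating between sigma_max ** (1 / rho) and sigma_min ** (1 / rho).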
        step_indices = torch.arange(num_steps, dtype=torch.float64)
        t_steps = (
            self.sigma_max ** (1 / rho)
            + step_indices
            / (num_steps - 1)
            * (self.sigma_min ** (1 / rho) - self.sigma_max ** (1 / rho))
        ) ** rho

        self.register_buffer("sigma_steps", t_steps.to(torch.float64))
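
    # forward() returns a tensor of shape (batch, num_steps, channels, height, width):
    # one score residual per noise level, stacked along dim 1.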
    def forward(
        self,
        x,
        force_fp32=False,
    ):
        x = self.encoder.encode(x).to(torch.float32)
        batch_scores = []
        for sigma in self.sigma_steps:
            xhat = self.net(x, sigma, force_fp32=force_fp32)
            c_skip = self.net.sigma_data**2 / (sigma**2 + self.net.sigma_data**2)
            score = xhat - (c_skip * x)
            batch_scores.append(score)
        batch_scores = torch.stack(batch_scores, dim=1)

        return batch_scores
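

# ScoreFlow pairs a frozen EDMScorer with a PatchFlow head (defined in flowutils):
# the stacked per-sigma scores are fed to the flow, whose output test_flow_runner
# below visualizes as a heatmap.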
class ScoreFlow(torch.nn.Module):
    def __init__(self, scorenet, device="cpu", **flow_kwargs):
        super().__init__()

        h = w = scorenet.net.img_resolution
        c = scorenet.net.img_channels
        num_sigmas = len(scorenet.sigma_steps)
        self.flow = PatchFlow((num_sigmas, c, h, w), **flow_kwargs)

        self.flow = self.flow.to(device)
        self.scorenet = scorenet.to(device).eval().requires_grad_(False)
        self.flow.init_weights()

        self.config = dict()
        self.config.update(**self.scorenet.config)
        self.config.update(self.flow.config)

    def forward(self, x, **score_kwargs):
        x_scores = self.scorenet(x, **score_kwargs)
        return self.flow(x_scores)
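

# build_model_from_config expects a nested dict with one block of kwargs per component.
# A hypothetical example (values are placeholders, not real settings):
#
#   model_params = {
#       "EDMNet": {...},      # kwargs for networks_edm2.Precond
#       "EDMScorer": {...},   # kwargs for EDMScorer, e.g. num_steps, stop_ratio
#       "PatchFlow": {...},   # kwargs forwarded to ScoreFlow / PatchFlow
#   }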
def build_model_from_config(model_params):
    net = Precond(**model_params["EDMNet"])
    scorenet = EDMScorer(net=net, **model_params["EDMScorer"])
    scoreflow = ScoreFlow(scorenet=scorenet, **model_params["PatchFlow"])
    print("Built model from config")
    return scoreflow


def build_model_from_pickle(preset="edm2-img64-s-fid", device="cpu"):
    netpath = config_presets[preset]
    with dnnlib.util.open_url(netpath, verbose=1) as f:
        data = pickle.load(f)
    net = data["ema"]
    model = EDMScorer(net, num_steps=20).to(device)
    return model
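

# Minimal usage sketch (illustrative, assuming the preset weights can be downloaded):
#   scorer = build_model_from_pickle("edm2-img64-s-fid", device=DEVICE)
#   x = torch.randint(0, 255, (1, 3, 64, 64), dtype=torch.uint8)
#   per_sigma_scores = scorer(x.to(DEVICE))  # (1, 20, 3, 64, 64) with num_steps=20


# GridSearchCV scoring callable: rate a fitted GMM pipeline by the 10th percentile of
# its per-sample log-likelihoods, favoring models with a higher likelihood floor.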
def quantile_scorer(gmm, X, y=None):
    return np.quantile(gmm.score_samples(X), 0.1)
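

# compute_gmm_likelihood loads the GMM pipeline and reference inlier NLLs written by
# train_gmm, returning the NLL of x_score plus the fraction of reference samples with
# a lower NLL (an empirical percentile).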
def compute_gmm_likelihood(x_score, gmmdir):
    with open(f"{gmmdir}/gmm.pkl", "rb") as f:
        clf = load(f)
        nll = -clf.score_samples(x_score)

    with np.load(f"{gmmdir}/refscores.npz") as f:
        ref_nll = f["arr_0"]
        percentile = (ref_nll < nll).mean()

    return nll, percentile
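

# The two runners below are local smoke tests: they expect an image file (e.g.
# "goldfish.JPEG") in the working directory, resize it to 64x64, and push it
# through the scorer / score-flow pipeline.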
def test_runner(device="cpu"):
    # f = "doge.jpg"
    f = "goldfish.JPEG"
    image = PIL.Image.open(f).resize((64, 64), PIL.Image.Resampling.LANCZOS)
    image = np.array(image)
    image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
    x = torch.from_numpy(image).unsqueeze(0).to(device)
    model = build_model_from_pickle(device=device)
    scores = model(x)

    return scores


def test_flow_runner(preset, device="cpu", load_weights=None):
    # f = "doge.jpg"
    f = "goldfish.JPEG"
    image = PIL.Image.open(f).resize((64, 64), PIL.Image.Resampling.LANCZOS)
    image = np.array(image)
    image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
    x = torch.from_numpy(image).unsqueeze(0).to(device)

    scorenet = build_model_from_pickle(preset)
    score_flow = ScoreFlow(scorenet, device=device)
    if load_weights is not None:
        score_flow.flow.load_state_dict(torch.load(load_weights))

    # Run the flow once and reuse the result (the original ran the forward pass twice)
    heatmap = score_flow(x)
    print(heatmap.shape)
    heatmap = heatmap.detach().cpu().numpy()
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min()) * 255
    im = PIL.Image.fromarray(heatmap[0, 0].astype(np.uint8))
    im.convert("RGB").save("heatmap.png")

    return
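

# NOTE: cmdline() currently only resolves the global DEVICE, and common_args() is a
# pass-through decorator; the CLI subcommands that would dispatch to cache_score_norms,
# train_gmm, and train_flow (presumably via the imported click package) are not wired
# up here.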
def cmdline():
    global DEVICE
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def common_args(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


def train_gmm(preset, outdir, gridsearch=False, **kwargs):
    outdir = f"{outdir}/{preset}"
    score_path = f"{outdir}/imagenette_score_norms.pt"
    X = torch.load(score_path).numpy()
    print(f"Loaded score norms from: {score_path} - # Samples: {X.shape[0]}")

    gm = GaussianMixture(
        n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
    )
    clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])

    if gridsearch:
        param_grid = dict(
            GMM__n_components=range(2, 11, 1),
        )
        grid = GridSearchCV(
            estimator=clf,
            param_grid=param_grid,
            cv=5,
            n_jobs=2,
            verbose=1,
            scoring=quantile_scorer,
        )
        grid_result = grid.fit(X)

        print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
        print("-----" * 15)
        means = grid_result.cv_results_["mean_test_score"]
        stds = grid_result.cv_results_["std_test_score"]
        params = grid_result.cv_results_["params"]
        for mean, stdev, param in zip(means, stds, params):
            print("%f (%f) with: %r" % (mean, stdev, param))

        clf = grid.best_estimator_

    clf.fit(X)
    inlier_nll = -clf.score_samples(X)

    print("Saving reference inlier scores ... ")
    os.makedirs(outdir, exist_ok=True)
    with open(f"{outdir}/refscores.npz", "wb") as f:
        np.savez_compressed(f, inlier_nll)

    with open(f"{outdir}/gmm.pkl", "wb") as f:
        dump(clf, f, protocol=5)

    print("Saved GMM pickle.")
def cache_score_norms(preset, dataset_path, outdir, batch_size):
    device = DEVICE
    dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
    refimg, reflabel = dsobj[0]
    print(f"Loading dataset from {dataset_path}")
    print(
        f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}"
    )
    dsloader = torch.utils.data.DataLoader(
        dsobj, batch_size=batch_size, num_workers=4, prefetch_factor=2
    )

    model = build_model_from_pickle(preset=preset, device=device)

    score_norms = []
    for x, _ in tqdm(dsloader):
        s = model(x.to(device))
        s = s.square().sum(dim=(2, 3, 4)) ** 0.5
        score_norms.append(s.cpu())

    score_norms = torch.cat(score_norms, dim=0)

    os.makedirs(f"{outdir}/{preset}/", exist_ok=True)
    with open(f"{outdir}/{preset}/imagenette_score_norms.pt", "wb") as f:
        torch.save(score_norms, f)

    print(f"Computed score norms for {score_norms.shape[0]} samples")
def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
    print("using device:", DEVICE)
    device = DEVICE
    dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
    print(f"Loaded {len(dsobj)} samples from {dataset_path}")

    # Subset of training dataset
    val_ratio = 0.1
    train_len = int((1 - val_ratio) * len(dsobj))
    val_len = len(dsobj) - train_len
    print(
        f"Generating train/test split with ratio={val_ratio} -> {train_len}/{val_len}..."
    )
    train_ds = Subset(dsobj, range(train_len))
    val_ds = Subset(dsobj, range(train_len, train_len + val_len))

    trainiter = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, num_workers=4, prefetch_factor=2, shuffle=True
    )
    testiter = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size * 2, num_workers=4, prefetch_factor=2
    )

    scorenet = build_model_from_pickle(preset)
    model = ScoreFlow(scorenet, device=device, **flow_kwargs)
    opt = torch.optim.AdamW(model.flow.parameters(), lr=3e-4, weight_decay=1e-5)
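
    # Bind PatchFlow.stochastic_step (defined in flowutils) into training and evaluation
    # closures; only the training closure receives the optimizer, and each closure fixes
    # its own n_patches budget per call.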
    train_step = partial(
        PatchFlow.stochastic_step,
        flow_model=model.flow,
        opt=opt,
        train=True,
        n_patches=128,
        device=device,
    )
    eval_step = partial(
        PatchFlow.stochastic_step,
        flow_model=model.flow,
        train=False,
        n_patches=256,
        device=device,
    )

    experiment_dir = f"{outdir}/{preset}"
    os.makedirs(experiment_dir, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
    writer = SummaryWriter(f"{experiment_dir}/logs/{timestamp}")

    with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
        json.dump(model.config, f, sort_keys=True, indent=4)

    with open(f"{experiment_dir}/config.json", "w") as f:
        json.dump(model.config, f, sort_keys=True, indent=4)

    # totaliters = int(epochs * train_len)
    pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
    step = 0
    for e in pbar:
        for x, _ in trainiter:
            x = x.to(device)
            scores = model.scorenet(x)

            if step == 0:
                with torch.inference_mode():
                    val_loss = eval_step(scores, x)
                    # Log details about model
                    writer.add_graph(
                        model.flow.flows,
                        (
                            torch.zeros(1, scores.shape[1], device=device),
                            torch.zeros(
                                1,
                                model.flow.position_encoding.cached_penc.shape[-1],
                                device=device,
                            ),
                        ),
                    )

            train_loss = train_step(scores, x)

            if (step + 1) % 10 == 0:
                prev_val_loss = val_loss
                val_loss = 0.0
                with torch.inference_mode():
                    for i, (x, _) in enumerate(testiter):
                        x = x.to(device)
                        scores = model.scorenet(x)
                        val_loss += eval_step(scores, x)
                        break
                val_loss /= i + 1
| writer.add_scalar("loss/val", train_loss, step) | |

                if val_loss < prev_val_loss:
                    torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")

            writer.add_scalar("loss/train", train_loss, step)
            pbar.set_description(
                f"Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
            )
            step += 1

    # Squeeze the juice: reload the best checkpoint and briefly fine-tune on the held-out split
    best_ckpt = torch.load(f"{experiment_dir}/flow.pt")
    model.flow.load_state_dict(best_ckpt)

    pbar = tqdm(range(10), desc="(Tuning) Step:? - Loss: ?")
    for e in pbar:
        for x, _ in testiter:
            x = x.to(device)
            scores = model.scorenet(x)
            train_loss = train_step(scores, x)
            writer.add_scalar("loss/train", train_loss, step)
            pbar.set_description(f"(Tuning) Step: {step:d} - Loss: {train_loss:.3f}")
            step += 1

    # Save final model
    torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")

    writer.close()
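

# Example invocations kept for reference; the paths below refer to the author's
# original environment and are not expected to resolve here.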
# cache_score_norms(
#     preset=preset,
#     dataset_path="/GROND_STOR/amahmood/datasets/img64/",
#     device="cuda",
# )
# train_gmm(
#     f"out/msma/{preset}_imagenette_score_norms.pt", outdir=f"out/msma/{preset}"
# )
# s = test_runner(device=device)
# s = s.square().sum(dim=(2, 3, 4)) ** 0.5
# s = s.to("cpu").numpy()
# nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
# print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")

if __name__ == "__main__":
    cmdline()