Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
def edm_sampler(
    net,
    x_N,
    conditioning=None,
    latents=None,
    randn_like=torch.randn_like,
    num_steps=18,
    sigma_min=0.002,
    sigma_max=80,
    rho=7,
    S_churn=0,
    S_min=0,
    S_max=float("inf"),
    S_noise=1,
):
    """Draw samples with a stochastic second-order (Heun) EDM-style sampler.

    The noise schedule interpolates linearly in ``sigma ** (1 / rho)`` space
    from ``sigma_max`` down to ``sigma_min`` and then appends a final level
    of 0. Each step optionally adds "churn" noise, takes an Euler step, and
    (except on the last step) applies a second-order correction. The
    denoiser can carry state between calls: every ``net`` call receives the
    latents returned by the previous call.

    Args:
        net: Denoiser. Must expose ``sigma_min``, ``sigma_max`` and
            ``round_sigma(sigma)``, and be callable as
            ``net(x, sigma, conditioning, previous_latents=latents)``
            returning ``(denoised, latents)``. ``sigma`` is passed as a 1-D
            tensor with one entry per item along dim 0 of ``x``.
        x_N: Initial noise. It is multiplied by the largest noise level
            before sampling. Dim 0 is treated as the batch dimension.
        conditioning: Passed through unchanged to every ``net`` call.
        latents: ``previous_latents`` given to the first ``net`` call.
        randn_like: Noise generator used for the churn step.
        num_steps: Number of non-zero noise levels (must be >= 1). ``net``
            is evaluated ``2 * num_steps - 1`` times.
        sigma_min, sigma_max: Requested noise range, clipped to what the
            network supports.
        rho: Exponent of the time-step discretization.
        S_churn, S_min, S_max, S_noise: Stochasticity controls. Churn is
            applied only at steps whose noise level lies in
            ``[S_min, S_max]``.

    Returns:
        The final sample as a float64 tensor with the shape of ``x_N``.

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization. The denominator is clamped to 1 so a
    # single-step schedule is [sigma_max] rather than 0/0 = NaN; for
    # num_steps >= 2 this is identical to dividing by (num_steps - 1).
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=x_N.device)
    max_inv_rho = sigma_max ** (1 / rho)
    min_inv_rho = sigma_min ** (1 / rho)
    t_steps = (
        max_inv_rho
        + step_indices / max(num_steps - 1, 1) * (min_inv_rho - max_inv_rho)
    ) ** rho
    t_steps = torch.cat(
        [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]
    )  # t_N = 0

    # Dim 0 of the sample is the batch; sigma is broadcast to one value per item.
    batch_size = x_N.shape[0]

    # Main sampling loop.
    x_next = x_N.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily (no-op when gamma == 0).
        gamma = (
            min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        )
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised, latents = net(
            x_hat, t_hat.expand(batch_size), conditioning, previous_latents=latents
        )
        denoised = denoised.to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction. Skipped on the last step because
        # t_next == 0 there and d_prime would divide by zero.
        if i < num_steps - 1:
            denoised, latents = net(
                x_next,
                t_next.expand(batch_size),
                conditioning,
                previous_latents=latents,
            )
            denoised = denoised.to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next