| | |
| | |
| |
|
| | import os |
| | import shutil |
| | from omegaconf import OmegaConf |
| | from cog import BasePredictor, Input, Path |
| |
|
| | from sampler import ResShiftSampler |
| |
|
class Predictor(BasePredictor):
    """Cog predictor that drives ``ResShiftSampler`` on a single input image."""

    def setup(self) -> None:
        """Load the per-task configuration files once so every prediction can reuse them."""
        # Keys must match the `task` choices offered by `predict`.
        self.configs = {
            "realsr": OmegaConf.load('./configs/realsr_swinunet_realesrgan256_journal.yaml'),
            "bicsr": OmegaConf.load('./configs/bicx4_swinunet_lpips.yaml'),
        }

    def predict(
        self,
        image: Path = Input(description="Grayscale input image"),
        scale: int = Input(description="Factor to scale image by.", default=4),
        chop_size: int = Input(
            choices=[512, 256], description="Chopping forward.", default=512
        ),
        task: str = Input(
            choices=["realsr", "bicsr"],
            description="Choose a task",
            default="realsr",
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed.", default=12345
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Args:
            image: Path of the image handed to the sampler.
            scale: Scale factor, written to the diffusion config and passed
                to the sampler as ``sf``.
            chop_size: Patch size for chopped inference (512 or 256); the
                stride is derived from it.
            task: Selects which configuration and checkpoint are used.
            seed: Seed forwarded to the sampler.

        Returns:
            Path to a copy of the sampler's output image at ``/tmp/out.png``.

        Raises:
            RuntimeError: If the sampler left no file in the output directory.
        """
        # NOTE(review): the declared default is 12345, so a blank input resolves
        # to 12345 and this branch only runs when None is passed explicitly.
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        # The config object is shared across calls; every field set below is
        # overwritten on each prediction, so no stale value leaks between runs.
        configs = self.configs[task]

        if task == 'realsr':
            ckpt_path = "weights/resshift_realsrx4_s4_v3.pth"
        else:
            ckpt_path = "weights/resshift_bicsrx4_s4.pth"
        configs.model.ckpt_path = ckpt_path
        configs.diffusion.params.steps = 4
        configs.diffusion.params.sf = scale
        configs.autoencoder.ckpt_path = "weights/autoencoder_vq_f4.pth"

        chop_stride = 448 if chop_size == 512 else 224

        resshift_sampler = ResShiftSampler(
            configs,
            sf=scale,
            chop_size=chop_size,
            chop_stride=chop_stride,
            chop_bs=1,
            use_amp=True,
            seed=seed,
            padding_offset=configs.model.params.get('lq_size', 64),
        )

        # Start from an empty directory so the file picked up below can only
        # come from this run.
        out_path = "out_dir"
        if os.path.exists(out_path):
            shutil.rmtree(out_path)
        resshift_sampler.inference(
            str(image),
            out_path,
            mask_path=None,
            bs=1,
            noise_repeat=False,
        )

        # os.listdir order is unspecified; sort so the choice is deterministic,
        # and fail with a clear message instead of a bare IndexError.
        produced = sorted(os.listdir(out_path)) if os.path.isdir(out_path) else []
        if not produced:
            raise RuntimeError(f"Sampler produced no output in {out_path!r}")

        out = "/tmp/out.png"
        shutil.copy(os.path.join(out_path, produced[0]), out)

        return Path(out)
| |
|