From 44f41c5718817f3a37c4c1295ec8519182184151 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Fri, 4 Jun 2021 15:22:43 -0700 Subject: [PATCH] Add Cog configuration and scripts --- .gitignore | 1 - cog.yaml | 10 ++++++++++ predict.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 cog.yaml create mode 100644 predict.py diff --git a/.gitignore b/.gitignore index b7c3883..b6df430 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,3 @@ *.caffemodel *.mat *.npy - diff --git a/cog.yaml b/cog.yaml new file mode 100644 index 0000000..3282232 --- /dev/null +++ b/cog.yaml @@ -0,0 +1,10 @@ +predict: "predict.py:Predictor" +build: + python_version: "3.8" + python_packages: + - "numpy==1.20.0" + - "opencv-python==4.5.1.48" + - "torch==1.8.0" + system_packages: + - "libgl1-mesa-dev" + - "libglib2.0-0" diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..d524041 --- /dev/null +++ b/predict.py @@ -0,0 +1,47 @@ +import tempfile +from pathlib import Path +import cv2 +import numpy as np +import torch +import RRDBNet_arch as arch +import cog + +model_path = ( + "models/RRDB_ESRGAN_x4.pth" +) + + +class Predictor(cog.Predictor): + def setup(self): + if torch.cuda.is_available(): + self.device = torch.device("cuda:0") + else: + self.device = torch.device("cpu") + print("Loading model...") + self.model = arch.RRDBNet(3, 3, 64, 23, gc=32) + self.model.load_state_dict(torch.load(model_path), strict=True) + self.model.eval() + self.model = self.model.to(self.device) + + @cog.input("image", type=Path, help="Low-resolution input image") + def predict(self, image): + print("Reading input image...") + img = cv2.imread(str(image), cv2.IMREAD_COLOR) + img = img * 1.0 / 255 + img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() + img_LR = img.unsqueeze(0) + img_LR = img_LR.to(self.device) + + print("Upscaling...") + with torch.no_grad(): + output = ( + self.model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy() + ) + output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) + output = (output * 255.0).round() + out_path = Path(tempfile.mkdtemp()) / "out.png" + + print("Saving result...") + cv2.imwrite(str(out_path), output) + + return out_path