From 64ea09a2bc7e06404bd48d064995d87b96566aa5 Mon Sep 17 00:00:00 2001 From: Milly Date: Wed, 14 Sep 2022 06:11:06 +0000 Subject: [PATCH 1/3] fix: change jupyter to an optional dependency --- data/dataset/jnd_dataset.py | 6 +++++- data/dataset/twoafc_dataset.py | 6 +++++- requirements-dev.txt | 2 ++ requirements.txt | 1 - test_dataset_model.py | 6 +++++- test_network.py | 6 +++++- util/visualizer.py | 6 +++++- 7 files changed, 27 insertions(+), 6 deletions(-) create mode 100644 requirements-dev.txt diff --git a/data/dataset/jnd_dataset.py b/data/dataset/jnd_dataset.py index ef0d30f1..ac7efbb2 100644 --- a/data/dataset/jnd_dataset.py +++ b/data/dataset/jnd_dataset.py @@ -5,7 +5,11 @@ from PIL import Image import numpy as np import torch -from IPython import embed + +try: + from IPython import embed +except ModuleNotFoundError: + embed = lambda: None class JNDDataset(BaseDataset): def initialize(self, dataroot, load_size=64): diff --git a/data/dataset/twoafc_dataset.py b/data/dataset/twoafc_dataset.py index d5e7f86b..a6577826 100644 --- a/data/dataset/twoafc_dataset.py +++ b/data/dataset/twoafc_dataset.py @@ -5,7 +5,11 @@ from PIL import Image import numpy as np import torch -# from IPython import embed + +try: + from IPython import embed +except ModuleNotFoundError: + embed = lambda: None class TwoAFCDataset(BaseDataset): def initialize(self, dataroots, load_size=64): diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..9100805e --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,2 @@ +-r requirements.txt +jupyter diff --git a/requirements.txt b/requirements.txt index d219640d..ee7df1c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,3 @@ scikit-image>=0.13.0 opencv-python>=2.4.11 matplotlib>=1.5.1 tqdm>=4.28.1 -jupyter diff --git a/test_dataset_model.py b/test_dataset_model.py index d5339e29..104c8d74 100644 --- a/test_dataset_model.py +++ b/test_dataset_model.py @@ -2,7 +2,11 @@ import lpips from data import data_loader as dl import argparse -from IPython import embed + +try: + from IPython import embed +except ModuleNotFoundError: + embed = lambda: None parser = argparse.ArgumentParser() parser.add_argument('--dataset_mode', type=str, default='2afc', help='[2afc,jnd]') diff --git a/test_network.py b/test_network.py index dfc907a8..ccf9ac3d 100644 --- a/test_network.py +++ b/test_network.py @@ -1,6 +1,10 @@ import torch import lpips -from IPython import embed + +try: + from IPython import embed +except ModuleNotFoundError: + embed = lambda: None use_gpu = False # Whether to use GPU spatial = True # Return a spatial map of perceptual distance. diff --git a/util/visualizer.py b/util/visualizer.py index 499f9976..2d47bec7 100755 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -5,7 +5,11 @@ from . import html import matplotlib.pyplot as plt import math -# from IPython import embed + +try: + from IPython import embed +except ModuleNotFoundError: + embed = lambda: None def zoom_to_res(img,res=256,order=0,axis=0): # img 3xXxX From c5b22fe9ac6b48d608aff5e24d51538b0dd18402 Mon Sep 17 00:00:00 2001 From: Milly Date: Wed, 14 Sep 2022 06:25:20 +0000 Subject: [PATCH 2/3] fix: dynamically import matplotlib --- util/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/util/util.py b/util/util.py index 6d6367a9..11109c39 100644 --- a/util/util.py +++ b/util/util.py @@ -4,7 +4,6 @@ from PIL import Image import numpy as np import os -import matplotlib.pyplot as plt import torch def load_image(path): @@ -16,6 +15,7 @@ def load_image(path): import cv2 return cv2.imread(path)[:,:,::-1] else: + import matplotlib.pyplot as plt img = (255*plt.imread(path)[:,:,:3]).astype('uint8') return img From e2a32000c0710ec5e167d826ffc9d1322014d18f Mon Sep 17 00:00:00 2001 From: Milly Date: Wed, 14 Sep 2022 06:55:54 +0000 Subject: [PATCH 3/3] fix: fixed dependencies and added extra --- setup.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5f616bc8..016c293c 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,22 @@ packages=['lpips'], package_data={'lpips': ['weights/v0.0/*.pth','weights/v0.1/*.pth']}, include_package_data=True, - install_requires=["torch>=0.4.0", "torchvision>=0.2.1", "numpy>=1.14.3", "scipy>=1.0.1", "tqdm>=4.28.1"], + install_requires=[ + "torch>=0.4.0", + "torchvision>=0.2.1", + "numpy>=1.14.3", + "scipy>=1.0.1", + "scikit-image>=0.13.0", + "tqdm>=4.28.1", + ], + extras_require = { + "dev": ["jupyter"], + "loadimage": [ + "rawpy>=0.17.2", + "opencv-python>=2.4.11", + "matplotlib>=1.5.1", + ], + }, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: BSD License",