diff --git a/README.md b/README.md
index 8c185eb..5fb61d0 100755
--- a/README.md
+++ b/README.md
@@ -2,9 +2,11 @@
 Official PyTorch implementation of CRAFT text detector | [Paper](https://arxiv.org/abs/1904.01941) | [Pretrained Model](https://drive.google.com/open?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ) | [Supplementary](https://youtu.be/HI8MzpY8KMI)
 
 **[Youngmin Baek](mailto:youngmin.baek@navercorp.com), Bado Lee, Dongyoon Han, Sangdoo Yun, Hwalsuk Lee.**
-
+
 Clova AI Research, NAVER Corp.
 
+**Packaged by [Ashish Jha](mailto:arj7192@gmail.com)**
+
 ### Sample Results
 
 ### Overview
@@ -16,15 +18,39 @@ PyTorch implementation for the CRAFT text detector that effectively detects text areas
 **13 Jun, 2019**: Initial update
 **20 Jul, 2019**: Added post-processing for polygon result
 **28 Sep, 2019**: Added the trained model on IC15 and the link refiner
+**25 Jan, 2020**: Put it together as a PyPI package
 
 ## Getting started
-### Install dependencies
+
+### Use it straight from PyPI
+#### Installation
+```
+pip install craft-text-detection
+```
+#### Usage
+```
+import craft
+import cv2
+img = cv2.imread('/path/to/image/file')
+
+# run the detector
+bboxes, polys, heatmap = craft.detect_text(img)
+
+# view the image with bounding boxes
+img_boxed = craft.show_bounding_boxes(img, bboxes)
+cv2.imshow('fig', img_boxed)
+cv2.waitKey(0)
+
+# view detection heatmap
+cv2.imshow('fig', heatmap)
+cv2.waitKey(0)
+```
+
+### Use from source - install dependencies
 #### Requirements
 - PyTorch>=0.4.1
 - torchvision>=0.2.1
 - opencv-python>=3.4.2
-- check requiremtns.txt
+- check requirements.txt
 ```
 pip install -r requirements.txt
 ```
@@ -33,34 +59,13 @@
 The code for training is not included in this repository, and we cannot release the full training code for IP reasons.
 
-### Test instruction using pretrained model
-- Download the trained models
- 
- *Model name* | *Used datasets* | *Languages* | *Purpose* | *Model Link* |
- | :--- | :--- | :--- | :--- | :--- |
-General | SynthText, IC13, IC17 | Eng + MLT | For general purpose | [Click](https://drive.google.com/open?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ)
-IC15 | SynthText, IC15 | Eng | For IC15 only | [Click](https://drive.google.com/open?id=1i2R7UIUqmkUtF0jv_3MXTqmQ_9wuAnLf)
-LinkRefiner | CTW1500 | - | Used with the General Model | [Click](https://drive.google.com/open?id=1XSaFwBkOaFOdtk4Ane3DFyJGPRw6v5bO)
-
-* Run with pretrained model
-``` (with python 3.7)
-python test.py --trained_model=[weightfile] --test_folder=[folder path to test images]
-```
-
-The result image and socre maps will be saved to `./result` by default.
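+
+To keep the outputs on disk instead of a display window, a minimal sketch (reusing the `detect_text` and `show_bounding_boxes` helpers from the Usage section above; the output filenames are illustrative):
+```
+import cv2
+import craft
+
+img = cv2.imread('/path/to/image/file')
+bboxes, polys, heatmap = craft.detect_text(img)
+
+# save the annotated image and the score heatmap
+cv2.imwrite('result_boxed.png', craft.show_bounding_boxes(img, bboxes))
+cv2.imwrite('result_heatmap.png', heatmap)
+```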
-
-### Arguments
-* `--trained_model`: pretrained model
+### Arguments for detect_text
 * `--text_threshold`: text confidence threshold
 * `--low_text`: text low-bound score
 * `--link_threshold`: link confidence threshold
-* `--cuda`: use cuda for inference (default:True)
 * `--canvas_size`: max image size for inference
 * `--mag_ratio`: image magnification ratio
-* `--poly`: enable polygon type result
-* `--show_time`: show processing time
-* `--test_folder`: folder path to input images
-* `--refine`: use link refiner for sentense-level dataset
+* `--refine`: use link refiner for sentence-level dataset
 * `--refiner_model`: pretrained refiner model
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/craft/__init__.py b/craft/__init__.py
new file mode 100644
index 0000000..36b144a
--- /dev/null
+++ b/craft/__init__.py
@@ -0,0 +1,29 @@
+from craft_detector.run_craft import *
+from craft_detector.downloader import *
+import os
+
+# fetch the pretrained weights once and keep a single model instance at module level
+if not os.path.isfile('/tmp/craft_mlt_25k.pth'):
+    print("downloading model")
+    download_file_from_google_drive('1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ', '/tmp/craft_mlt_25k.pth')
+
+net = CRAFT()
+net.load_state_dict(copyStateDict(torch.load('/tmp/craft_mlt_25k.pth', map_location='cpu')))
+net.eval()
+refine_net = None
+
+
+def detect_text(img, text_threshold=0.7, link_threshold=0.4, low_text=0.4):
+    bboxes, polys, score_text = test_net(net, img, text_threshold, link_threshold, low_text, False, False,
+                                         refine_net)
+    return bboxes, polys, score_text
+
+
+def show_bounding_boxes(img, bboxes):
+    img = np.array(img)
+    for box in bboxes:
+        poly = np.array(box).astype(np.int32).reshape((-1))
+        poly = poly.reshape(-1, 2)
+        cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2)
+    return img
diff --git a/craft_detector/__init__.py b/craft_detector/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/craft_detector/basenet/__init__.py b/craft_detector/basenet/__init__.py
new file mode 100644
index 0000000..e69de29
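`craft/__init__.py` above downloads the weights and loads the model once at import time, so repeated `detect_text` calls reuse the same network. A minimal sketch of tuning the three exposed thresholds (the image path is illustrative):

```python
import cv2
import craft

img = cv2.imread('figures/test_figure.png')

# text_threshold: minimum region score to count as text (raise for precision)
# low_text: lower bound used to grow each detected region outward
# link_threshold: affinity score needed to merge neighbouring characters
bboxes, polys, heatmap = craft.detect_text(img, text_threshold=0.8,
                                           link_threshold=0.4, low_text=0.3)
print(len(bboxes), 'word boxes detected')
```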
diff --git a/craft_detector/basenet/vgg16_bn.py b/craft_detector/basenet/vgg16_bn.py
new file mode 100644
index 0000000..f3f21a7
--- /dev/null
+++ b/craft_detector/basenet/vgg16_bn.py
@@ -0,0 +1,73 @@
+from collections import namedtuple
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+from torchvision import models
+from torchvision.models.vgg import model_urls
+
+
+def init_weights(modules):
+    for m in modules:
+        if isinstance(m, nn.Conv2d):
+            init.xavier_uniform_(m.weight.data)
+            if m.bias is not None:
+                m.bias.data.zero_()
+        elif isinstance(m, nn.BatchNorm2d):
+            m.weight.data.fill_(1)
+            m.bias.data.zero_()
+        elif isinstance(m, nn.Linear):
+            m.weight.data.normal_(0, 0.01)
+            m.bias.data.zero_()
+
+
+class vgg16_bn(torch.nn.Module):
+    def __init__(self, pretrained=True, freeze=True):
+        super(vgg16_bn, self).__init__()
+        model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
+        vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
+        self.slice1 = torch.nn.Sequential()
+        self.slice2 = torch.nn.Sequential()
+        self.slice3 = torch.nn.Sequential()
+        self.slice4 = torch.nn.Sequential()
+        self.slice5 = torch.nn.Sequential()
+        for x in range(12):          # conv2_2
+            self.slice1.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(12, 19):      # conv3_3
+            self.slice2.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(19, 29):      # conv4_3
+            self.slice3.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(29, 39):      # conv5_3
+            self.slice4.add_module(str(x), vgg_pretrained_features[x])
+
+        # fc6, fc7 without atrous conv
+        self.slice5 = torch.nn.Sequential(
+            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
+            nn.Conv2d(1024, 1024, kernel_size=1)
+        )
+
+        if not pretrained:
+            init_weights(self.slice1.modules())
+            init_weights(self.slice2.modules())
+            init_weights(self.slice3.modules())
+            init_weights(self.slice4.modules())
+
+        init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7
+
+        if freeze:
+            for param in self.slice1.parameters():  # only first conv
+                param.requires_grad = False
+
+    def forward(self, X):
+        h = self.slice1(X)
+        h_relu2_2 = h
+        h = self.slice2(h)
+        h_relu3_2 = h
+        h = self.slice3(h)
+        h_relu4_3 = h
+        h = self.slice4(h)
+        h_relu5_3 = h
+        h = self.slice5(h)
+        h_fc7 = h
+        vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
+        out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
+        return out
diff --git a/craft.py b/craft_detector/craft.py
similarity index 97%
rename from craft.py
rename to craft_detector/craft.py
index 27131df..1ce2997 100755
--- a/craft.py
+++ b/craft_detector/craft.py
@@ -8,7 +8,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from basenet.vgg16_bn import vgg16_bn, init_weights
+from craft_detector.basenet.vgg16_bn import vgg16_bn, init_weights
+
 
 class double_conv(nn.Module):
     def __init__(self, in_ch, mid_ch, out_ch):
diff --git a/craft_utils.py b/craft_detector/craft_utils.py
similarity index 100%
rename from craft_utils.py
rename to craft_detector/craft_utils.py
diff --git a/craft_detector/downloader.py b/craft_detector/downloader.py
new file mode 100644
index 0000000..d84678f
--- /dev/null
+++ b/craft_detector/downloader.py
@@ -0,0 +1,39 @@
+import requests
+
+
+def download_file_from_google_drive(id, destination):
+    URL = "https://docs.google.com/uc?export=download"
+
+    session = requests.Session()
+
+    response = session.get(URL, params={'id': id}, stream=True)
+    token = get_confirm_token(response)
+
+    if token:
+        params = {'id': id, 'confirm': token}
+        response = session.get(URL, params=params, stream=True)
+
+    save_response_content(response, destination)
+
+
+def get_confirm_token(response):
+    for key, value in response.cookies.items():
+        if key.startswith('download_warning'):
+            return value
+
+    return None
+
+
+def save_response_content(response, destination):
+    CHUNK_SIZE = 32768
+
+    with open(destination, "wb") as f:
+        for chunk in response.iter_content(CHUNK_SIZE):
+            if chunk:  # filter out keep-alive new chunks
+                f.write(chunk)
+
+
+# if __name__ == "__main__":
+#     file_id = 'TAKE ID FROM SHAREABLE LINK'
+#     destination = 'DESTINATION FILE ON YOUR DISK'
+#     download_file_from_google_drive(file_id, destination)
diff --git a/file_utils.py b/craft_detector/file_utils.py
similarity index 100%
rename from file_utils.py
rename to craft_detector/file_utils.py
diff --git a/imgproc.py b/craft_detector/imgproc.py
similarity index 100%
rename from imgproc.py
rename to craft_detector/imgproc.py
diff --git a/refinenet.py b/craft_detector/refinenet.py
similarity index 100%
rename from refinenet.py
rename to craft_detector/refinenet.py
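`vgg16_bn` above slices torchvision's VGG-16-BN into the named feature stages that CRAFT's decoder consumes. A quick sketch for inspecting what each stage emits (assuming `pretrained=False` to skip the torchvision weight download; shapes depend on the input size):

```python
import torch
from craft_detector.basenet.vgg16_bn import vgg16_bn

backbone = vgg16_bn(pretrained=False, freeze=False)
with torch.no_grad():
    feats = backbone(torch.zeros(1, 3, 768, 768))  # dummy RGB batch

# namedtuple fields run from the deepest stage (fc7) to the shallowest (relu2_2)
for name, tensor in zip(feats._fields, feats):
    print(name, tuple(tensor.shape))
```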
diff --git a/craft_detector/run_craft.py b/craft_detector/run_craft.py
new file mode 100755
index 0000000..77be7f9
--- /dev/null
+++ b/craft_detector/run_craft.py
@@ -0,0 +1,95 @@
+"""
+Copyright (c) 2019-present NAVER Corp.
+MIT License
+"""
+
+# -*- coding: utf-8 -*-
+import sys
+import os
+import time
+import argparse
+
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torch.autograd import Variable
+
+from PIL import Image
+
+import cv2
+from skimage import io
+import numpy as np
+from craft_detector import craft_utils, imgproc
+import json
+import zipfile
+
+from craft_detector.craft import CRAFT
+
+from collections import OrderedDict
+
+
+def copyStateDict(state_dict):
+    if list(state_dict.keys())[0].startswith("module"):
+        start_idx = 1
+    else:
+        start_idx = 0
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        name = ".".join(k.split(".")[start_idx:])
+        new_state_dict[name] = v
+    return new_state_dict
+
+
+def str2bool(v):
+    return v.lower() in ("yes", "y", "true", "t", "1")
+
+
+def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
+    t0 = time.time()
+
+    # resize
+    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, 1280, interpolation=cv2.INTER_LINEAR,
+                                                                          mag_ratio=1.5)
+    ratio_h = ratio_w = 1 / target_ratio
+
+    # preprocessing
+    x = imgproc.normalizeMeanVariance(img_resized)
+    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
+    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
+    if cuda:
+        x = x.cuda()
+
+    # forward pass
+    with torch.no_grad():
+        y, feature = net(x)
+
+    # make score and link map
+    score_text = y[0, :, :, 0].cpu().data.numpy()
+    score_link = y[0, :, :, 1].cpu().data.numpy()
+
+    # refine link
+    if refine_net is not None:
+        with torch.no_grad():
+            y_refiner = refine_net(y, feature)
+        score_link = y_refiner[0, :, :, 0].cpu().data.numpy()
+
+    t0 = time.time() - t0
+    t1 = time.time()
+
+    # Post-processing
+    boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)
+
+    # coordinate adjustment
+    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
+    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
+    for k in range(len(polys)):
+        if polys[k] is None:
+            polys[k] = boxes[k]
+
+    t1 = time.time() - t1
+
+    # render results (optional)
+    render_img = score_text.copy()
+    render_img = np.hstack((render_img, score_link))
+    ret_score_text = imgproc.cvt2HeatmapImg(render_img)
+
+    return boxes, polys, ret_score_text
\ No newline at end of file
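`test_net` above is the complete inference pipeline: aspect-ratio resize to a 1280 canvas, mean/variance normalisation, one forward pass, then box extraction from the region and affinity score maps. The packaged `detect_text` calls it with `cuda=False, poly=False`; a sketch of calling it directly for polygon output (it assumes the weights are already at `/tmp/craft_mlt_25k.pth`, where `craft/__init__.py` puts them):

```python
import cv2
import torch
from craft_detector.run_craft import CRAFT, copyStateDict, test_net

net = CRAFT()
net.load_state_dict(copyStateDict(torch.load('/tmp/craft_mlt_25k.pth', map_location='cpu')))
net.eval()

img = cv2.imread('figures/test_figure.png')
# poly=True returns polygon outlines where the post-processor can trace them,
# falling back to the quadrilateral box for the rest
boxes, polys, heatmap = test_net(net, img, 0.7, 0.4, 0.4, cuda=False, poly=True)
```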
diff --git a/test.py b/craft_detector/test.py
similarity index 94%
rename from test.py
rename to craft_detector/test.py
index 482b503..22cacb6 100755
--- a/test.py
+++ b/craft_detector/test.py
@@ -47,7 +47,7 @@ def str2bool(v):
 parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
 parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
 parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
-parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
+parser.add_argument('--cuda', default=False, type=str2bool, help='Use cuda for inference')
 parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
 parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
 parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
@@ -122,7 +122,6 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
 if __name__ == '__main__':
     # load net
     net = CRAFT()     # initialize
-    print('Loading weights from checkpoint (' + args.trained_model + ')')
 
     if args.cuda:
         net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
@@ -158,6 +157,13 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
     for k, image_path in enumerate(image_list):
         print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
         image = imgproc.loadImage(image_path)
+        #
+        # (H, W) = image.shape[:2]
+        # (newW, newH) = (2400, 1800)
+        # # (newW, newH) = (W, H)
+        # rW = W / float(newW)
+        # rH = H / float(newH)
+        # image = cv2.resize(image, (newW, newH))
 
         bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text,
                                              args.cuda, args.poly, refine_net)
diff --git a/figures/test_figure.png b/figures/test_figure.png
new file mode 100644
index 0000000..e14b409
Binary files /dev/null and b/figures/test_figure.png differ
diff --git a/requirements.txt b/requirements.txt
index f4b2412..5b040aa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,7 @@
-torch==0.4.1.post2
+torch==0.4.1
 torchvision==0.2.1
 opencv-python==3.4.2.17
 scikit-image==0.14.2
-scipy==1.1.0
\ No newline at end of file
+scipy==1.1.0
+pillow==6.2.0
+requests==2.22.0
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..d1d3d75
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,26 @@
+import setuptools
+
+with open("README.md", "r") as fh:
+    long_description = fh.read()
+
+with open('requirements.txt') as f:
+    required = f.read().splitlines()
+
+setuptools.setup(
+    name="craft-text-detection",
+    version="0.0.1",
+    author="Clova AI Research, NAVER Corp., Ashish Jha",
+    author_email="youngmin.baek@navercorp.com, arj7192@gmail.com",
+    description="Official implementation of Character Region Awareness for Text Detection (CRAFT)",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    url="https://github.com/clovaai/CRAFT-pytorch",
+    # craft/__init__.py imports from craft_detector, so both packages must ship
+    packages=['craft', 'craft_detector', 'craft_detector.basenet'],
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+    ],
+    install_requires=required,
+    python_requires='>=3.6',
+)
\ No newline at end of file
diff --git a/unittests/__init__.py b/unittests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/unittests/test_craft.py b/unittests/test_craft.py
new file mode 100644
index 0000000..d7558ec
--- /dev/null
+++ b/unittests/test_craft.py
@@ -0,0 +1,18 @@
+import unittest
+import craft
+import cv2
+import os
+
+
+class TestCraft(unittest.TestCase):
+    def test_craft(self):
+        root_dir = os.path.dirname(os.path.abspath(__file__))
+        img = cv2.imread(os.path.join(root_dir, '../figures/test_figure.png'))
+        bboxes, _, _ = craft.detect_text(img)
+        self.assertEqual(len(bboxes), 2)
+        self.assertEqual(len(bboxes[0]), 4)
+        self.assertEqual(len(bboxes[0][0]), 2)
+        self.assertEqual(int(bboxes[0][0][0]), 82)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
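The smoke test above pins the expected box count and one corner coordinate for `figures/test_figure.png`. A sketch for running it from a source checkout (the first `import craft` inside the test triggers the model download):

```python
import unittest

# discover and run everything under unittests/
suite = unittest.defaultTestLoader.discover('unittests')
unittest.TextTestRunner(verbosity=2).run(suite)
```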