Packaging CRAFT text detection under PyPI #90

Open · wants to merge 3 commits into master
57 changes: 31 additions & 26 deletions README.md
@@ -2,9 +2,11 @@
Official Pytorch implementation of CRAFT text detector | [Paper](https://arxiv.org/abs/1904.01941) | [Pretrained Model](https://drive.google.com/open?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ) | [Supplementary](https://youtu.be/HI8MzpY8KMI)

**[Youngmin Baek](mailto:[email protected]), Bado Lee, Dongyoon Han, Sangdoo Yun, Hwalsuk Lee.**

Clova AI Research, NAVER Corp.

**Packaged by [Ashish Jha](mailto:[email protected])**

### Sample Results

### Overview
@@ -16,15 +18,39 @@ PyTorch implementation for CRAFT text detector that effectively detects text area
**13 Jun, 2019**: Initial update
**20 Jul, 2019**: Added post-processing for polygon result
**28 Sep, 2019**: Added the trained model on IC15 and the link refiner
**25 Jan, 2020**: Put it together as a PyPI package


## Getting started
### Install dependencies

### Use it straight from PyPI
#### Installation
```
pip install craft-text-detection
```
#### Usage
```
import craft
import cv2
img = cv2.imread('/path/to/image/file')

# run the detector
bboxes, polys, heatmap = craft.detect_text(img)

# view the image with bounding boxes
img_boxed = craft.show_bounding_boxes(img, bboxes)
cv2.imshow('fig', img_boxed)
cv2.waitKey(0)  # cv2.imshow needs a waitKey call to actually render the window

# view detection heatmap
cv2.imshow('fig', heatmap)
cv2.waitKey(0)
```
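Beyond displaying results, the detector output can be persisted directly. A minimal sketch, assuming the API above and that each entry of `bboxes` reshapes to an N×2 array of corner points (as `show_bounding_boxes` in this PR implies); the output file names are arbitrary:
```
import craft
import cv2
import numpy as np

img = cv2.imread('/path/to/image/file')
bboxes, polys, heatmap = craft.detect_text(img)

# save the annotated image and the detection heatmap
cv2.imwrite('result_boxed.png', craft.show_bounding_boxes(img, bboxes))
cv2.imwrite('result_heatmap.png', heatmap)

# print the axis-aligned extent of each detected text region
for box in bboxes:
    pts = np.array(box).reshape(-1, 2)
    x0, y0 = pts.min(axis=0)
    x1, y1 = pts.max(axis=0)
    print('text region: x=[{:.0f}, {:.0f}], y=[{:.0f}, {:.0f}]'.format(x0, x1, y0, y1))
```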

### Use from source - install dependencies
#### Requirements
- PyTorch>=0.4.1
- torchvision>=0.2.1
- opencv-python>=3.4.2
- check requiremtns.txt
- check requirements.txt
```
pip install -r requirements.txt
```
@@ -33,34 +59,13 @@ pip install -r requirements.txt
The code for training is not included in this repository, and we cannot release the full training code for IP reasons.


### Test instruction using pretrained model
- Download the trained models

| *Model name* | *Used datasets* | *Languages* | *Purpose* | *Model Link* |
| :--- | :--- | :--- | :--- | :--- |
| General | SynthText, IC13, IC17 | Eng + MLT | For general purpose | [Click](https://drive.google.com/open?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ) |
| IC15 | SynthText, IC15 | Eng | For IC15 only | [Click](https://drive.google.com/open?id=1i2R7UIUqmkUtF0jv_3MXTqmQ_9wuAnLf) |
| LinkRefiner | CTW1500 | - | Used with the General Model | [Click](https://drive.google.com/open?id=1XSaFwBkOaFOdtk4Ane3DFyJGPRw6v5bO) |

* Run with pretrained model (with python 3.7)
```
python test.py --trained_model=[weightfile] --test_folder=[folder path to test images]
```

The result image and score maps will be saved to `./result` by default.

### Arguments
* `--trained_model`: pretrained model
### Arguments for detect_text
* `--text_threshold`: text confidence threshold
* `--low_text`: text low-bound score
* `--link_threshold`: link confidence threshold
* `--cuda`: use cuda for inference (default:True)
* `--canvas_size`: max image size for inference
* `--mag_ratio`: image magnification ratio
* `--poly`: enable polygon type result
* `--show_time`: show processing time
* `--test_folder`: folder path to input images
* `--refine`: use link refiner for sentense-level dataset
* `--refine`: use link refiner for sentence-level dataset
* `--refiner_model`: pretrained refiner model
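
The three detection thresholds in the list above map onto keyword arguments of `craft.detect_text` (defaults visible in `craft/__init__.py` below); the remaining flags are CLI-only and have no package-level counterpart in this PR. A minimal sketch:
```
import craft
import cv2

img = cv2.imread('/path/to/image/file')

# stricter text confidence, looser character linking than the defaults
bboxes, polys, heatmap = craft.detect_text(img,
                                           text_threshold=0.8,  # default 0.7
                                           link_threshold=0.3,  # default 0.4
                                           low_text=0.4)        # default 0.4
```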


Empty file added __init__.py
Empty file.
29 changes: 29 additions & 0 deletions craft/__init__.py
@@ -0,0 +1,29 @@
from craft_detector.run_craft import *
from craft_detector.downloader import *
import os

# download the pretrained weights once and cache them under /tmp
if not os.path.isfile('/tmp/craft_mlt_25k.pth'):
    print("downloading model")
    download_file_from_google_drive('1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ', '/tmp/craft_mlt_25k.pth')

# load the network on CPU at import time; no link refiner by default
net = CRAFT()
net.load_state_dict(copyStateDict(torch.load('/tmp/craft_mlt_25k.pth', map_location='cpu')))
net.eval()
refine_net = None


def detect_text(img, text_threshold=0.7, link_threshold=0.4, low_text=0.4):
    bboxes, polys, score_text = test_net(net, img, text_threshold, link_threshold, low_text, False, False,
                                         refine_net)
    return bboxes, polys, score_text


def show_bounding_boxes(img, bboxes):
    img = np.array(img)
    for box in bboxes:
        # each box is a polygon given as corner points; draw it closed
        poly = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(img, [poly], True, color=(0, 0, 255), thickness=2)
    return img
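
One design note on the block above: the weights are cached at the hard-coded path `/tmp/craft_mlt_25k.pth` and the download runs at import time, so the first `import craft` is slow and the path breaks on Windows. A hypothetical, more portable variant (not part of this PR) could derive the cache path from `tempfile`:
```
import os
import tempfile

from craft_detector.downloader import download_file_from_google_drive

# hypothetical portable cache path; '/tmp' is Unix-only
MODEL_PATH = os.path.join(tempfile.gettempdir(), 'craft_mlt_25k.pth')

if not os.path.isfile(MODEL_PATH):
    print("downloading model")
    download_file_from_google_drive('1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ', MODEL_PATH)
```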
Empty file added craft_detector/__init__.py
Empty file.
Empty file.
73 changes: 73 additions & 0 deletions craft_detector/basenet/vgg16_bn.py
@@ -0,0 +1,73 @@
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision import models
from torchvision.models.vgg import model_urls


def init_weights(modules):
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()


class vgg16_bn(torch.nn.Module):
    def __init__(self, pretrained=True, freeze=True):
        super(vgg16_bn, self).__init__()
        model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
        vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(12):         # conv2_2
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 19):     # conv3_3
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(19, 29):     # conv4_3
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(29, 39):     # conv5_3
            self.slice4.add_module(str(x), vgg_pretrained_features[x])

        # fc6, fc7 without atrous conv
        self.slice5 = torch.nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1)
        )

        if not pretrained:
            init_weights(self.slice1.modules())
            init_weights(self.slice2.modules())
            init_weights(self.slice3.modules())
            init_weights(self.slice4.modules())

        init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7

        if freeze:
            for param in self.slice1.parameters():  # freeze the first conv block only
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu2_2 = h
        h = self.slice2(h)
        h_relu3_2 = h
        h = self.slice3(h)
        h_relu4_3 = h
        h = self.slice4(h)
        h_relu5_3 = h
        h = self.slice5(h)
        h_fc7 = h
        vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
        out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
        return out
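
A quick smoke test can confirm the slice boundaries above. This is a sketch assuming the torch/torchvision versions pinned in requirements.txt; `pretrained=False` avoids the model-zoo download:
```
import torch

from craft_detector.basenet.vgg16_bn import vgg16_bn

net = vgg16_bn(pretrained=False, freeze=False)
out = net(torch.randn(1, 3, 224, 224))

# the deepest feature map (fc7) and the shallowest (relu2_2)
print(out.fc7.shape)      # expected: torch.Size([1, 1024, 14, 14])
print(out.relu2_2.shape)  # expected: torch.Size([1, 128, 112, 112])
```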
3 changes: 2 additions & 1 deletion craft.py → craft_detector/craft.py
@@ -8,7 +8,8 @@
import torch.nn as nn
import torch.nn.functional as F

from basenet.vgg16_bn import vgg16_bn, init_weights
from craft_detector.basenet.vgg16_bn import vgg16_bn, init_weights


class double_conv(nn.Module):
def __init__(self, in_ch, mid_ch, out_ch):
File renamed without changes.
39 changes: 39 additions & 0 deletions craft_detector/downloader.py
@@ -0,0 +1,39 @@
import requests


def download_file_from_google_drive(id, destination):
    URL = "https://docs.google.com/uc?export=download"

    session = requests.Session()

    response = session.get(URL, params={'id': id}, stream=True)
    token = get_confirm_token(response)

    if token:
        params = {'id': id, 'confirm': token}
        response = session.get(URL, params=params, stream=True)

    save_response_content(response, destination)


def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None


def save_response_content(response, destination):
    CHUNK_SIZE = 32768

    with open(destination, "wb") as f:
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk:  # filter out keep-alive chunks
                f.write(chunk)


# if __name__ == "__main__":
#     file_id = 'TAKE ID FROM SHAREABLE LINK'
#     destination = 'DESTINATION FILE ON YOUR DISK'
#     download_file_from_google_drive(file_id, destination)
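
A note on the confirm-token round trip above: for large files, Google Drive first returns an HTML virus-scan warning page and sets a `download_warning` cookie; re-requesting with `confirm=<token>` yields the actual file stream. Drive has changed this interstitial over time, so the cookie-based handshake may need revisiting.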
File renamed without changes.
File renamed without changes.
File renamed without changes.
95 changes: 95 additions & 0 deletions craft_detector/run_craft.py
@@ -0,0 +1,95 @@
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""

# -*- coding: utf-8 -*-
import sys
import os
import time
import argparse

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

from PIL import Image

import cv2
from skimage import io
import numpy as np
from craft_detector import craft_utils, imgproc
import json
import zipfile

from craft_detector.craft import CRAFT

from collections import OrderedDict


def copyStateDict(state_dict):
if list(state_dict.keys())[0].startswith("module"):
start_idx = 1
else:
start_idx = 0
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = ".".join(k.split(".")[start_idx:])
new_state_dict[name] = v
return new_state_dict


def str2bool(v):
return v.lower() in ("yes", "y", "true", "t", "1")


def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
t0 = time.time()

# resize
img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, 1280, interpolation=cv2.INTER_LINEAR,
mag_ratio=1.5)
ratio_h = ratio_w = 1 / target_ratio

# preprocessing
x = imgproc.normalizeMeanVariance(img_resized)
x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]
x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]
if cuda:
x = x.cuda()

# forward pass
with torch.no_grad():
y, feature = net(x)

# make score and link map
score_text = y[0,:,:,0].cpu().data.numpy()
score_link = y[0,:,:,1].cpu().data.numpy()

# refine link
if refine_net is not None:
with torch.no_grad():
y_refiner = refine_net(y, feature)
score_link = y_refiner[0,:,:,0].cpu().data.numpy()

t0 = time.time() - t0
t1 = time.time()

# Post-processing
boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)

# coordinate adjustment
boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
for k in range(len(polys)):
if polys[k] is None: polys[k] = boxes[k]

t1 = time.time() - t1

# render results (optional)
render_img = score_text.copy()
render_img = np.hstack((render_img, score_link))
ret_score_text = imgproc.cvt2HeatmapImg(render_img)

return boxes, polys, ret_score_text
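
For callers that bypass the package-level wrapper, `test_net` can be driven directly. A minimal sketch, assuming weights already downloaded to `/tmp/craft_mlt_25k.pth`, a CPU-only run, and `imgproc.loadImage` from the renamed helper modules:
```
import torch

from craft_detector import imgproc
from craft_detector.run_craft import CRAFT, copyStateDict, test_net

net = CRAFT()
net.load_state_dict(copyStateDict(torch.load('/tmp/craft_mlt_25k.pth', map_location='cpu')))
net.eval()

image = imgproc.loadImage('/path/to/image/file')
boxes, polys, heatmap = test_net(net, image, text_threshold=0.7,
                                 link_threshold=0.4, low_text=0.4,
                                 cuda=False, poly=False)
print('{} text regions found'.format(len(boxes)))
```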
10 changes: 8 additions & 2 deletions test.py → craft_detector/test.py
@@ -47,7 +47,7 @@ def str2bool(v):
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
parser.add_argument('--cuda', default=False, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
@@ -122,7 +122,6 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
if __name__ == '__main__':
    # load net
    net = CRAFT()    # initialize
    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
@@ -158,6 +157,13 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)
        # optional fixed-size resize, left commented out in this PR:
        # (H, W) = image.shape[:2]
        # (newW, newH) = (2400, 1800)
        # # (newW, newH) = (W, H)
        # rW = W / float(newW)
        # rH = H / float(newH)
        # image = cv2.resize(image, (newW, newH))

        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net)

Binary file added figures/test_figure.png
6 changes: 4 additions & 2 deletions requirements.txt
@@ -1,5 +1,7 @@
torch==0.4.1.post2
torch==0.4.1
torchvision==0.2.1
opencv-python==3.4.2.17
scikit-image==0.14.2
scipy==1.1.0
pillow==6.2.0
requests==2.22.0