mirror of
https://github.com/opencv/opencv.git
synced 2024-12-04 00:39:11 +08:00
073488896e
Improved and refactored text detection sample in dnn module #25326 Clean up samples: #25006 This pull request merges and simplifies the different text detection samples in the dnn module of OpenCV into one file. An option is provided to choose the detection model: EAST or DB. ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
366 lines
14 KiB
Python
366 lines
14 KiB
Python
'''
|
|
Helper module to download extra data from Internet
|
|
'''
|
|
from __future__ import print_function
|
|
import os
|
|
import sys
|
|
import yaml
|
|
import argparse
|
|
import tarfile
|
|
import platform
|
|
import tempfile
|
|
import hashlib
|
|
import requests
|
|
import shutil
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from urllib.request import urlopen
|
|
import xml.etree.ElementTree as ET
|
|
|
|
# Names exported by a star-import of this module: only downloadFile.
__all__ = ["downloadFile"]
|
|
|
|
class HashMismatchException(Exception):
    """Signals that a computed hashsum differs from the expected one.

    Attributes:
        expected: the hashsum the caller expected.
        actual: the hashsum that was actually computed.
    """

    def __init__(self, expected, actual):
        super(HashMismatchException, self).__init__()
        self.expected, self.actual = expected, actual

    def __str__(self):
        template = 'Hash mismatch: expected {} vs actual of {}'
        return template.format(self.expected, self.actual)
|
|
|
|
def getHashsumFromFile(filepath):
    """Return the SHA-1 hex digest of the file stored at ``filepath``.

    The file is consumed in 10 MiB chunks so large files are never held in
    memory at once. When ``filepath`` does not exist nothing is hashed and
    the digest of empty input is returned.
    """
    digest = hashlib.sha1()
    if not os.path.exists(filepath):
        return digest.hexdigest()

    print(' there is already a file with the same name')
    chunk_size = 10 * 1024 * 1024
    with open(filepath, 'rb') as stream:
        # iter() with a sentinel stops at the first empty read (end of file).
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
|
def checkHashsum(expected_sha, filepath, silent=True):
    """Check whether the file at ``filepath`` has the hashsum ``expected_sha``.

    Args:
        expected_sha: hashsum the file is expected to have.
        filepath: path of the file to verify.
        silent: when False, a mismatch raises instead of returning False.

    Returns:
        True when the hashsums are equal; False when they differ (and
        ``silent`` is true) or when ``filepath`` does not exist.

    Raises:
        HashMismatchException: the hashsums differ and ``silent`` is False.
    """
    if os.path.exists(filepath):
        print(' expected SHA1: {}'.format(expected_sha))
        computed_sha = getHashsumFromFile(filepath)
        print(' actual SHA1:{}'.format(computed_sha))
        if expected_sha == computed_sha:
            return True
        if silent:
            return False
        raise HashMismatchException(expected_sha, computed_sha)

    # A missing file is never an error here, regardless of ``silent``.
    print(f"{filepath} does not exist. Skipping hashsum matching")
    return False
|
|
|
|
def isArchive(filepath):
    """Tell whether ``filepath`` is a tar archive that ``tarfile`` can read."""
    readable_as_tar = tarfile.is_tarfile(filepath)
    return readable_as_tar
|
|
|
|
class DownloadInstance:
    """One file to obtain: what it is called, where it goes and how to fetch it.

    Keyword arguments:
        name: label used in the printed log messages.
        filename: name of the file on disk.
        save_dir: root directory under which the file is stored.
        loader: object whose ``load(filename, sha, save_dir)`` fetches the
            file (optional at construction, required by ``get``).
        sha: expected hashsum of the file, or None when unknown.
    """

    def __init__(self, **kwargs):
        # The three mandatory keys raise KeyError when absent.
        self.name = kwargs['name']
        self.filename = kwargs['filename']
        self.save_dir = kwargs['save_dir']
        self.loader = kwargs.get('loader', None)
        self.sha = kwargs.get('sha', None)

    def __str__(self):
        return 'DownloadInstance <{}>'.format(self.name)

    def _storeUnderComputedSha(self):
        """Hash the freshly loaded file and move it into ``save_dir/<sha>/``.

        Used when no expected hashsum was given; records the computed value
        in ``self.sha`` and returns the final file path.
        """
        downloaded = os.path.join(self.save_dir, self.filename)
        self.sha = getHashsumFromFile(downloaded)
        target_dir = os.path.join(self.save_dir, self.sha)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
        target = os.path.join(target_dir, self.filename)
        if not os.path.exists(target):
            shutil.move(downloaded, target_dir)
        print(' No expected hashsum provided, actual SHA is {}'.format(self.sha))
        return target

    def get(self):
        """Make the file available locally and return its path.

        A file already present under ``save_dir/<sha>/`` with a matching
        hashsum is reused. Otherwise the loader is asked to fetch it.
        Returns None when loading or verification raises.
        """
        print(" Working on " + self.name)
        print(" Getting file " + self.filename)

        if self.sha is not None:
            filepath = os.path.join(self.save_dir, self.sha, self.filename)
            if checkHashsum(self.sha, filepath):
                print(' hash match - file already exists, skipping')
                return filepath
            print(" hash didn't match, loading file")
        else:
            print(' No expected hashsum provided, loading file')

        if not os.path.exists(self.save_dir):
            print(' creating directory: ' + self.save_dir)
            os.makedirs(self.save_dir)

        print(' hash check failed - loading')
        assert self.loader
        try:
            self.loader.load(self.filename, self.sha, self.save_dir)
            print(' done')
            print(' file {}'.format(self.filename))
            if self.sha is not None:
                # Verification failure raises and is reported below.
                checkHashsum(self.sha, filepath, silent=False)
            else:
                filepath = self._storeUnderComputedSha()
        except Exception as e:
            print(" There was some problem with loading file {} for {}".format(self.filename, self.name))
            print(" Exception: {}".format(e))
            return

        print(" Finished " + self.name)
        return filepath
|
|
|
|
class Loader(object):
    """Base class that downloads a file and, if needed, unpacks it from a tar archive.

    Subclasses provide the transport by overriding ``download``.

    Attributes:
        download_name: file name of the object that is downloaded.
        download_sha: expected hashsum of the downloaded object, or None.
        archive_member: path of the wanted file inside the archive; when
            None it is looked up by base name on first extraction.
    """
    MB = 1024*1024
    BUFSIZE = 10*MB

    def __init__(self, download_name, download_sha, archive_member = None):
        self.download_name = download_name
        self.download_sha = download_sha
        self.archive_member = archive_member

    def load(self, requested_file, sha, save_dir):
        """Download the object and make ``requested_file`` available.

        Args:
            requested_file: name of the file the caller ultimately wants.
            sha: expected hashsum of ``requested_file`` (selects the
                extraction sub-directory), or None.
            save_dir: root directory for downloads.

        Raises:
            HashMismatchException: the downloaded object does not match
                ``download_sha``.
            Exception: the downloaded object has a different name than
                ``requested_file`` and is not a tar archive.
        """
        if self.download_sha is None:
            download_dir = save_dir
        else:
            # create a new folder in save_dir to avoid possible name conflicts
            download_dir = os.path.join(save_dir, self.download_sha)
        if not os.path.exists(download_dir):
            os.makedirs(download_dir)
        download_path = os.path.join(download_dir, self.download_name)
        print(" Preparing to download file " + self.download_name)
        if checkHashsum(self.download_sha, download_path):
            print(' hash match - file already exists, no need to download')
        else:
            filesize = self.download(download_path)
            print(' Downloaded {} with size {} Mb'.format(self.download_name, filesize/self.MB))
            if self.download_sha is not None:
                checkHashsum(self.download_sha, download_path, silent=False)

        if self.download_name == requested_file:
            return
        if not isArchive(download_path):
            raise Exception("Downloaded file has different name")

        extract_dir = save_dir if sha is None else os.path.join(save_dir, sha)
        if not os.path.exists(extract_dir):
            os.makedirs(extract_dir)
        self.extract(requested_file, download_path, extract_dir)

    def download(self, filepath):
        """Fetch the object into ``filepath`` and return its size in bytes.

        The base implementation downloads nothing and returns 0.
        """
        print("Warning: download is not implemented, this is a base class")
        return 0

    def extract(self, requested_file, archive_path, save_dir):
        """Copy one member of the tar archive to ``save_dir/requested_file``.

        When ``archive_member`` was not given, the member whose base name
        equals ``requested_file`` is used, wherever it sits in the archive.
        Extraction is best-effort: any failure is printed, not raised.
        """
        filepath = os.path.join(save_dir, requested_file)
        try:
            with tarfile.open(archive_path) as f:
                names = f.getnames()
                if self.archive_member is None:
                    # Map base name -> FULL member path. Mapping to the
                    # directory part instead would select the directory
                    # entry for nested files, which cannot be extracted.
                    by_basename = {os.path.basename(elem): elem for elem in names}
                    self.archive_member = by_basename[requested_file]
                if self.archive_member not in names:
                    raise KeyError("archive has no member {!r}".format(self.archive_member))
                member_stream = f.extractfile(self.archive_member)
                if member_stream is None:
                    # extractfile() yields None for non-file members.
                    raise ValueError("archive member {!r} is not a regular file".format(self.archive_member))
                self.save(filepath, member_stream)
        except Exception as e:
            print(' catch {}'.format(e))

    def save(self, filepath, r):
        """Write everything readable from ``r`` to ``filepath``.

        Data is copied in ``BUFSIZE`` chunks; one '>' is printed per chunk
        as a progress indicator.
        """
        with open(filepath, 'wb') as f:
            print(' progress ', end="")
            sys.stdout.flush()
            while True:
                buf = r.read(self.BUFSIZE)
                if not buf:
                    break
                f.write(buf)
                print('>', end="")
                sys.stdout.flush()
|
|
|
|
class URLLoader(Loader):
    """Loader that fetches the object from a URL using ``urlopen``.

    Attributes:
        url: address the object is downloaded from.
    """

    def __init__(self, download_name, download_sha, url, archive_member = None):
        # The base initializer already stores download_name/download_sha.
        super(URLLoader, self).__init__(download_name, download_sha, archive_member)
        self.url = url

    def download(self, filepath):
        """Download ``self.url`` into ``filepath`` and return its size in bytes."""
        # The context manager closes the connection even if saving fails.
        with urlopen(self.url, timeout=60) as r:
            self.printRequest(r)
            self.save(filepath, r)
        return os.path.getsize(filepath)

    def printRequest(self, r):
        """Print status code, status message and announced size of response ``r``."""
        def getMB(r):
            # Header lookup on the response's message object is
            # case-insensitive, so any spelling of Content-Length matches.
            length = r.info().get('Content-Length')
            if length is not None:
                return int(length) / self.MB
            return '<unknown>'
        print(' {} {} [{} Mb]'.format(r.getcode(), r.msg, getMB(r)))
|
|
|
|
class GDriveLoader(Loader):
    """Loader that fetches the object from Google Drive by file id.

    Attributes:
        gid: Google Drive id of the file to download.
    """
    BUFSIZE = 1024 * 1024
    PROGRESS_SIZE = 10 * 1024 * 1024

    def __init__(self, download_name, download_sha, gid, archive_member = None):
        # The base initializer already stores download_name/download_sha.
        super(GDriveLoader, self).__init__(download_name, download_sha, archive_member)
        self.gid = gid

    def download(self, filepath):
        """Download the Drive file into ``filepath`` and return the byte count.

        A '>' is printed for every ``PROGRESS_SIZE`` bytes written. An HTTP
        error status raises instead of saving the error page as the file.
        """
        URL = "https://docs.google.com/uc?export=download"
        # Same limit URLLoader passes to urlopen; without it a stalled
        # connection would block forever.
        timeout = 60

        def get_confirm_token(response): # in case of large files
            for key, value in response.cookies.items():
                if key.startswith('download_warning'):
                    return value
            return None

        sz = 0
        progress_sz = self.PROGRESS_SIZE
        with requests.Session() as session: # re-use cookies
            response = session.get(URL, params = { 'id' : self.gid }, stream = True, timeout = timeout)
            token = get_confirm_token(response)
            if token:
                # Release the first (warning page) response before retrying.
                response.close()
                params = { 'id' : self.gid, 'confirm' : token }
                response = session.get(URL, params = params, stream = True, timeout = timeout)
            try:
                response.raise_for_status()
                with open(filepath, "wb") as f:
                    for chunk in response.iter_content(self.BUFSIZE):
                        if not chunk:
                            continue # keep-alive

                        f.write(chunk)
                        sz += len(chunk)
                        if sz >= progress_sz:
                            progress_sz += self.PROGRESS_SIZE
                            print('>', end='')
                            sys.stdout.flush()
            finally:
                response.close()
        print('')
        return sz
|
|
|
|
def _findGDriveFileId(url):
    """Return the Google Drive file id found in ``url``, or '' when there is none."""
    from urllib.parse import parse_qs, urlparse

    # Legacy lookup first, so links that already resolved keep resolving to
    # the same id: take the text after the last '/', or the text before it
    # when the tail carries no '&id='.
    parts = url.rsplit('/', 1)
    token_part = parts[-1]
    if "&id=" not in token_part:
        token_part = parts[0]
    token = ""
    for param in token_part.split("&"):
        if param.startswith("id="):
            token = param[3:]
    if token:
        return token

    # Fallback: 'id' given as the first query parameter ('...uc?id=XYZ').
    ids = parse_qs(urlparse(url).query).get("id")
    return ids[-1] if ids else ""


def produceDownloadInstance(instance_name, filename, sha, url, save_dir, download_name=None, download_sha=None, archive_member=None):
    """Build a DownloadInstance with the loader matching ``url``.

    Args:
        instance_name: label of the instance used in log messages.
        filename: name of the file the caller wants on disk.
        sha: expected hashsum of ``filename``, or None.
        url: download address; Google Drive links with a recognizable file
            id use GDriveLoader, everything else uses URLLoader.
        save_dir: root directory for downloads.
        download_name: name of the downloaded object (defaults to ``filename``).
        download_sha: hashsum of the downloaded object (defaults to ``sha``).
        archive_member: path of the file inside a downloaded archive.

    Returns:
        The configured DownloadInstance.
    """
    spec_param = url
    loader = URLLoader
    if download_name is None:
        download_name = filename
    if download_sha is None:
        download_sha = sha
    if "drive.google.com" in url:
        token = _findGDriveFileId(url)
        if token:
            loader = GDriveLoader
            spec_param = token
        else:
            print("Warning: possibly wrong Google Drive link")
    return DownloadInstance(
        name=instance_name,
        filename=filename,
        sha=sha,
        save_dir=save_dir,
        loader=loader(download_name, download_sha, spec_param, archive_member)
    )
|
|
|
|
def getSaveDir():
    """Return the directory used to store downloads, creating it if needed.

    ``OPENCV_DOWNLOAD_DATA_PATH`` is used verbatim when set. Otherwise the
    result is a ``downloads`` folder inside a per-platform base directory:
    ``TMPDIR`` (or ``/tmp``) on Darwin, the temp directory on Windows, and
    elsewhere ``XDG_CACHE_HOME``, then ``$HOME/.cache/``, then the temp
    directory.
    """
    env_path = os.environ.get("OPENCV_DOWNLOAD_DATA_PATH", None)
    if env_path:
        save_dir = env_path
    else:
        # TODO reuse binding function cv2.utils.fs.getCacheDirectory when issue #19011 is fixed
        temp_dir = None
        system = platform.system()
        if system == "Darwin":
            #On Apple devices
            temp_env = os.environ.get("TMPDIR", None)
            if temp_env is None or not os.path.isdir(temp_env):
                temp_dir = Path("/tmp")
                print("Using world accessible cache directory. This may be not secure: ", temp_dir)
            else:
                temp_dir = temp_env
        elif system == "Windows":
            temp_dir = tempfile.gettempdir()
        else:
            xdg_cache_env = os.environ.get("XDG_CACHE_HOME", None)
            if xdg_cache_env and os.path.isdir(xdg_cache_env):
                temp_dir = xdg_cache_env
            else:
                home_env = os.environ.get("HOME", None)
                if home_env and os.path.isdir(home_env):
                    home_path = os.path.join(home_env, ".cache/")
                    if os.path.isdir(home_path):
                        temp_dir = home_path
            if temp_dir is None:
                # Single fallback for every "no usable cache directory"
                # case, so temp_dir is always bound below.
                temp_dir = tempfile.gettempdir()
                print("Using world accessible cache directory. This may be not secure: ", temp_dir)

        save_dir = os.path.join(temp_dir, "downloads")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    return save_dir
|
|
|
|
def downloadFile(url, sha=None, save_dir=None, filename=None):
    """Download the file at ``url`` and return its local path.

    Args:
        url: address to download from.
        sha: expected hashsum of the file, or None to skip verification.
        save_dir: root directory for downloads; defaults to ``getSaveDir()``.
        filename: name for the stored file; defaults to a timestamped
            ``download_...`` name.

    Returns:
        Whatever ``DownloadInstance.get()`` returns for the request.
    """
    if save_dir is None:
        save_dir = getSaveDir()
    if filename is None:
        # str(datetime.now()) contains a space and ':' characters, which
        # are not valid in Windows file names; use a safe timestamp.
        filename = "download_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%f")
    name = filename
    return produceDownloadInstance(name, filename, sha, url, save_dir).get()
|
|
|
|
def parseMetalinkFile(metalink_filepath, save_dir):
    """Create one download instance per ``file`` element of a Metalink document.

    Args:
        metalink_filepath: path of the Metalink XML file.
        save_dir: root directory handed to every created instance.

    Returns:
        List of instances in document order (empty when there are no
        ``file`` elements).
    """
    namespaces = {'ml': 'urn:ietf:params:xml:ns:metalink'}

    def child_text(element, tag):
        return element.find(tag, namespaces).text

    root = ET.parse(metalink_filepath).getroot()
    instances = []
    for entry in root.findall('ml:file', namespaces):
        link = child_text(entry, 'ml:url')
        file_name = entry.attrib['name']
        identity = child_text(entry, 'ml:identity')
        digest = child_text(entry, 'ml:hash')
        instances.append(
            produceDownloadInstance(identity, file_name, digest, link, save_dir))
    return instances
|
|
|
|
def parseYAMLFile(yaml_filepath, save_dir):
    """Create download instances for the entries of a YAML model list.

    Only top-level entries with a non-empty ``load_info`` mapping are
    used; the file name is the base name of the entry's ``model`` value.

    Args:
        yaml_filepath: path of the YAML file.
        save_dir: root directory handed to every created instance.

    Returns:
        List of instances in the order the entries were loaded.
    """
    with open(yaml_filepath, 'r') as stream:
        config = yaml.safe_load(stream)

    instances = []
    for name, params in config.items():
        load_info = params.get("load_info", None)
        if not load_info:
            continue
        instances.append(produceDownloadInstance(
            name,
            os.path.basename(params.get("model")),
            load_info.get("sha1"),
            load_info.get("url"),
            save_dir,
            download_name=load_info.get("download_name"),
            download_sha=load_info.get("download_sha"),
            archive_member=load_info.get("member"),
        ))
    return instances
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: download every known model, or only those
    # whose name starts with the optional positional argument.
    parser = argparse.ArgumentParser(
        description='This is a utility script for downloading DNN models for samples.')

    parser.add_argument(
        '--save_dir',
        action="store",
        default=os.getcwd(),
        help='Path to the directory to store downloaded files')
    parser.add_argument(
        'model_name',
        type=str,
        default="",
        nargs='?',
        action="store",
        help='name of the model to download')
    args = parser.parse_args()

    save_dir = args.save_dir
    selected_model_name = args.model_name

    # Both model lists are read relative to the current working directory.
    models = (parseMetalinkFile('face_detector/weights.meta4', save_dir)
              + parseYAMLFile('models.yml', save_dir))

    for m in models:
        print(m)
        if not selected_model_name or m.name.startswith(selected_model_name):
            print('Model: ' + selected_model_name)
            m.get()
|