opencv/samples/dnn/download_models.py
Gursimar Singh 073488896e
Merge pull request #25326 from gursimarsingh:improved_text_detection_sample
Improved and refactored text detection sample in dnn module #25326

Clean up samples: #25006

This pull request merges and simplifies the different text detection samples in the dnn module of OpenCV into one file. An option is provided to choose the detection model: EAST or DB.

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [ ] There is an accuracy test, a performance test and test data in the opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
2024-09-09 17:43:15 +03:00

366 lines
14 KiB
Python

'''
Helper module to download extra data from Internet
'''
from __future__ import print_function
import os
import sys
import yaml
import argparse
import tarfile
import platform
import tempfile
import hashlib
import requests
import shutil
from pathlib import Path
from datetime import datetime
from urllib.request import urlopen
import xml.etree.ElementTree as ET
__all__ = ["downloadFile"]
class HashMismatchException(Exception):
    """Error carrying the expected and the actually computed hash of a file.

    Both values are kept as attributes (``expected``, ``actual``) and are
    rendered into the message returned by ``str()``.
    """

    def __init__(self, expected, actual):
        # The base class gets no arguments; the message is built lazily in __str__.
        super().__init__()
        self.expected = expected
        self.actual = actual

    def __str__(self):
        return f'Hash mismatch: expected {self.expected} vs actual of {self.actual}'
def getHashsumFromFile(filepath):
    """Return the hexadecimal SHA-1 digest of the file at *filepath*.

    When the path does not exist nothing is read, so the digest of empty
    input is returned. The file is consumed in 10 MiB chunks.
    """
    digest = hashlib.sha1()
    if not os.path.exists(filepath):
        return digest.hexdigest()

    print(' there is already a file with the same name')
    chunk_size = 10 * 1024 * 1024
    with open(filepath, 'rb') as stream:
        # iter() with a sentinel stops at the first empty read (end of file)
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
def checkHashsum(expected_sha, filepath, silent=True):
    """Compare the SHA-1 of *filepath* against *expected_sha*.

    Args:
        expected_sha: expected hexadecimal SHA-1 string (may be None).
        filepath: path of the file to verify.
        silent: when True a failed check is reported through the return
            value only; when False a failed check raises.

    Returns:
        True if the hashes are equal, False otherwise (only reachable with
        ``silent=True`` when the check fails).

    Raises:
        HashMismatchException: if ``silent`` is False and either the file is
            missing or its hash differs from *expected_sha*.
    """
    if not os.path.exists(filepath):
        print(f"{filepath} does not exist. Skipping hashsum matching")
        if not silent:
            # A missing file can never satisfy a mandatory verification; without
            # this the non-silent mode would quietly report False and callers
            # that rely on the exception would treat the file as verified.
            raise HashMismatchException(expected_sha, None)
        return False
    print(' expected SHA1: {}'.format(expected_sha))
    actual_sha = getHashsumFromFile(filepath)
    print(' actual SHA1:{}'.format(actual_sha))
    hashes_matched = expected_sha == actual_sha
    if not hashes_matched and not silent:
        raise HashMismatchException(expected_sha, actual_sha)
    return hashes_matched
def isArchive(filepath):
    """Tell whether *filepath* is a tar archive, as judged by ``tarfile.is_tarfile``."""
    readable_as_tar = tarfile.is_tarfile(filepath)
    return readable_as_tar
class DownloadInstance:
    """One file to obtain: its name, target filename, hash and the loader that fetches it.

    Keyword arguments ``name``, ``filename`` and ``save_dir`` are required
    (a missing one raises KeyError); ``loader`` and ``sha`` default to None.
    """

    def __init__(self, **kwargs):
        self.name = kwargs['name']
        self.filename = kwargs['filename']
        self.loader = kwargs.get('loader')
        self.save_dir = kwargs['save_dir']
        self.sha = kwargs.get('sha')

    def __str__(self):
        return f'DownloadInstance <{self.name}>'

    def get(self):
        """Make the file available locally and return its path.

        The file lives in ``<save_dir>/<sha>/<filename>``. If a hash is known
        and a matching file is already there, that path is returned right
        away. Otherwise the loader is invoked; when no hash was given, the
        hash of the downloaded file is computed, stored in ``self.sha`` and
        the file is moved into the per-hash directory.

        Returns:
            The path of the file, or None if loading raised an exception
            (the exception is printed, not propagated).
        """
        print(" Working on " + self.name)
        print(" Getting file " + self.filename)
        if self.sha is not None:
            filepath = os.path.join(self.save_dir, self.sha, self.filename)
            if checkHashsum(self.sha, filepath):
                print(' hash match - file already exists, skipping')
                return filepath
            print(" hash didn't match, loading file")
        else:
            print(' No expected hashsum provided, loading file')

        if not os.path.exists(self.save_dir):
            print(' creating directory: ' + self.save_dir)
            os.makedirs(self.save_dir)

        print(' hash check failed - loading')
        assert self.loader
        try:
            self.loader.load(self.filename, self.sha, self.save_dir)
            print(' done')
            print(f' file {self.filename}')
            if self.sha is not None:
                # Hash known up front: verification raises on mismatch
                checkHashsum(self.sha, filepath, silent=False)
            else:
                # No hash known: adopt the hash of what was downloaded and
                # file it under the per-hash directory
                downloaded = os.path.join(self.save_dir, self.filename)
                self.sha = getHashsumFromFile(downloaded)
                hashed_dir = os.path.join(self.save_dir, self.sha)
                if not os.path.exists(hashed_dir):
                    os.makedirs(hashed_dir)
                filepath = os.path.join(hashed_dir, self.filename)
                if not os.path.exists(filepath):
                    shutil.move(downloaded, hashed_dir)
                print(f' No expected hashsum provided, actual SHA is {self.sha}')
        except Exception as err:
            print(f" There was some problem with loading file {self.filename} for {self.name}")
            print(f" Exception: {err}")
            return None

        print(" Finished " + self.name)
        return filepath
class Loader(object):
    """Base class that obtains a file and, if needed, unpacks it from a tar archive.

    Subclasses override ``download``; the base implementation only prints a
    warning and returns 0.

    Attributes:
        download_name: name of the file that is actually downloaded.
        download_sha: expected SHA-1 of the downloaded file, or None.
        archive_member: name of the archive member to extract, or None to
            look it up by base name on first extraction.
    """
    MB = 1024*1024
    BUFSIZE = 10*MB

    def __init__(self, download_name, download_sha, archive_member = None):
        self.download_name = download_name
        self.download_sha = download_sha
        self.archive_member = archive_member

    def load(self, requested_file, sha, save_dir):
        """Download ``download_name`` (unless already present) and extract *requested_file* if it differs.

        Args:
            requested_file: name of the file the caller finally wants.
            sha: expected SHA-1 of *requested_file* or None; selects the
                extraction directory ``<save_dir>/<sha>``.
            save_dir: base directory for downloads.

        Raises:
            Exception: if the downloaded file has a different name than
                *requested_file* and is not a tar archive. Hash verification
                of the download may raise as well.
        """
        if self.download_sha is None:
            download_dir = save_dir
        else:
            # create a new folder in save_dir to avoid possible name conflicts
            download_dir = os.path.join(save_dir, self.download_sha)
        if not os.path.exists(download_dir):
            os.makedirs(download_dir)
        download_path = os.path.join(download_dir, self.download_name)
        print(" Preparing to download file " + self.download_name)
        if checkHashsum(self.download_sha, download_path):
            print(' hash match - file already exists, no need to download')
        else:
            filesize = self.download(download_path)
            print(' Downloaded {} with size {} Mb'.format(self.download_name, filesize/self.MB))
            if self.download_sha is not None:
                checkHashsum(self.download_sha, download_path, silent=False)
        if self.download_name == requested_file:
            return
        if not isArchive(download_path):
            raise Exception("Downloaded file has different name")
        if sha is not None:
            extract_dir = os.path.join(save_dir, sha)
        else:
            extract_dir = save_dir
        if not os.path.exists(extract_dir):
            os.makedirs(extract_dir)
        self.extract(requested_file, download_path, extract_dir)

    def download(self, filepath):
        """Fetch the file to *filepath* and return its size in bytes (base class: no-op returning 0)."""
        print("Warning: download is not implemented, this is a base class")
        return 0

    def extract(self, requested_file, archive_path, save_dir):
        """Extract one member of the tar archive to ``<save_dir>/<requested_file>``.

        If ``archive_member`` is None it is resolved by matching
        *requested_file* against the base names of all archive members and
        remembered on the instance. Any failure is printed, not raised.
        """
        filepath = os.path.join(save_dir, requested_file)
        try:
            with tarfile.open(archive_path) as f:
                names = f.getnames()
                if self.archive_member is None:
                    # Map base name -> full member path, so that a file stored
                    # in a sub-directory resolves to the file itself rather
                    # than to the directory that contains it.
                    by_basename = dict((os.path.basename(elem), elem) for elem in names)
                    self.archive_member = by_basename[requested_file]
                if self.archive_member not in names:
                    raise KeyError('{} not found in archive'.format(self.archive_member))
                member_stream = f.extractfile(self.archive_member)
                if member_stream is None:
                    # extractfile() yields None for members that are not
                    # regular files; fail before an empty output file is created
                    raise ValueError('{} is not a regular file in archive'.format(self.archive_member))
                self.save(filepath, member_stream)
        except Exception as e:
            print(' catch {}'.format(e))

    def save(self, filepath, r):
        """Copy the readable object *r* to *filepath* in ``BUFSIZE`` chunks, printing '>' per chunk."""
        with open(filepath, 'wb') as f:
            print(' progress ', end="")
            sys.stdout.flush()
            while True:
                buf = r.read(self.BUFSIZE)
                if not buf:
                    break
                f.write(buf)
                print('>', end="")
                sys.stdout.flush()
class URLLoader(Loader):
    """Loader that downloads the file from a plain URL via ``urlopen``."""

    def __init__(self, download_name, download_sha, url, archive_member = None):
        # download_name / download_sha / archive_member are stored by the base class
        super(URLLoader, self).__init__(download_name, download_sha, archive_member)
        self.url = url

    def download(self, filepath):
        """Download ``self.url`` to *filepath* and return the resulting file size in bytes."""
        # The context manager closes the response even if saving fails
        with urlopen(self.url, timeout=60) as r:
            self.printRequest(r)
            self.save(filepath, r)
        return os.path.getsize(filepath)

    def printRequest(self, r):
        """Print status code, status message and announced size of response *r*."""
        def getMB(r):
            # Header lookup on the message object ignores the header name's case
            length = r.info().get('Content-Length')
            if length is not None:
                return int(length) / self.MB
            return '<unknown>'
        print(' {} {} [{} Mb]'.format(r.getcode(), r.msg, getMB(r)))
class GDriveLoader(Loader):
    """Loader that downloads a file from Google Drive by its file id (``gid``)."""
    BUFSIZE = 1024 * 1024
    PROGRESS_SIZE = 10 * 1024 * 1024
    TIMEOUT = 60  # seconds; same value URLLoader passes to urlopen

    def __init__(self, download_name, download_sha, gid, archive_member = None):
        # download_name / download_sha / archive_member are stored by the base class
        super(GDriveLoader, self).__init__(download_name, download_sha, archive_member)
        self.gid = gid

    def download(self, filepath):
        """Stream the Drive file to *filepath* and return the number of bytes written.

        A '>' is printed for every ``PROGRESS_SIZE`` bytes received.

        Raises:
            requests.HTTPError: if the final response has an error status, so
                that an error page is not saved as if it were the file.
        """
        URL = "https://docs.google.com/uc?export=download"

        def get_confirm_token(response): # in case of large files
            for key, value in response.cookies.items():
                if key.startswith('download_warning'):
                    return value
            return None

        sz = 0
        progress_sz = self.PROGRESS_SIZE
        # Session re-uses cookies between the two requests; leaving the
        # ``with`` block releases its connections.
        with requests.Session() as session:
            response = session.get(URL, params = { 'id' : self.gid }, stream = True, timeout = self.TIMEOUT)
            try:
                token = get_confirm_token(response)
                if token:
                    # Done with the first (warning) response; a streamed
                    # response keeps its connection until closed
                    response.close()
                    params = { 'id' : self.gid, 'confirm' : token }
                    response = session.get(URL, params = params, stream = True, timeout = self.TIMEOUT)
                response.raise_for_status()
                with open(filepath, "wb") as f:
                    for chunk in response.iter_content(self.BUFSIZE):
                        if not chunk:
                            continue # keep-alive
                        f.write(chunk)
                        sz += len(chunk)
                        if sz >= progress_sz:
                            progress_sz += self.PROGRESS_SIZE
                            print('>', end='')
                            sys.stdout.flush()
            finally:
                response.close()
        print('')
        return sz
def produceDownloadInstance(instance_name, filename, sha, url, save_dir, download_name=None, download_sha=None, archive_member=None):
    """Build a DownloadInstance with the loader that fits *url*.

    Args:
        instance_name: display name of the instance.
        filename: name of the file that is finally wanted.
        sha: expected SHA-1 of *filename*, or None.
        url: download URL; Google Drive links with an ``id`` query parameter
            are served by GDriveLoader, everything else by URLLoader.
        save_dir: base directory for downloads.
        download_name: name of the downloaded file; defaults to *filename*.
        download_sha: SHA-1 of the downloaded file; defaults to *sha*.
        archive_member: archive member to extract, passed on to the loader.

    Returns:
        The configured DownloadInstance.
    """
    # Local import: urllib.parse is not among the file-level imports
    from urllib.parse import parse_qs, urlsplit

    spec_param = url
    loader = URLLoader
    if download_name is None:
        download_name = filename
    if download_sha is None:
        download_sha = sha
    # ``url and`` keeps a missing URL from raising TypeError on the ``in`` test
    if url and "drive.google.com" in url:
        token = ""
        url_parts = url.rsplit('/', 1)
        token_part = url_parts[-1]
        if "&id=" not in token_part:
            # Look in front of the last '/' instead (nothing to look at when
            # the URL has no '/' at all)
            token_part = url_parts[0] if len(url_parts) > 1 else ""
        for param in token_part.split("&"):
            if param.startswith("id="):
                token = param[3:]
        if not token:
            # The scan above only sees "id=" after an '&'; parse the query
            # string to also catch it as the first parameter ("uc?id=...&...")
            ids = parse_qs(urlsplit(url).query).get("id")
            if ids:
                token = ids[-1]
        if token:
            loader = GDriveLoader
            spec_param = token
        else:
            print("Warning: possibly wrong Google Drive link")
    return DownloadInstance(
        name=instance_name,
        filename=filename,
        sha=sha,
        save_dir=save_dir,
        loader=loader(download_name, download_sha, spec_param, archive_member)
    )
def getSaveDir():
    """Return the directory used for downloads, creating it if it does not exist.

    ``OPENCV_DOWNLOAD_DATA_PATH`` is used verbatim when set. Otherwise the
    result is ``<cache dir>/downloads`` where the cache dir is:

    * Darwin: ``TMPDIR`` if it is a directory, else ``/tmp``;
    * Windows: ``tempfile.gettempdir()``;
    * elsewhere: ``XDG_CACHE_HOME`` if it is a directory, else
      ``$HOME/.cache/`` if it is a directory, else ``tempfile.gettempdir()``.

    A warning is printed whenever a world accessible fallback is chosen.
    """
    env_path = os.environ.get("OPENCV_DOWNLOAD_DATA_PATH", None)
    if env_path:
        save_dir = env_path
    else:
        # TODO reuse binding function cv2.utils.fs.getCacheDirectory when issue #19011 is fixed
        system_name = platform.system()
        if system_name == "Darwin":
            # On Apple devices
            temp_env = os.environ.get("TMPDIR", None)
            if temp_env is None or not os.path.isdir(temp_env):
                temp_dir = Path("/tmp")
                print("Using world accessible cache directory. This may be not secure: ", temp_dir)
            else:
                temp_dir = temp_env
        elif system_name == "Windows":
            temp_dir = tempfile.gettempdir()
        else:
            # Start undecided so that every path that finds no usable cache
            # directory reaches the fallback below (previously some of them
            # left temp_dir unbound and raised UnboundLocalError).
            temp_dir = None
            xdg_cache_env = os.environ.get("XDG_CACHE_HOME", None)
            if xdg_cache_env and os.path.isdir(xdg_cache_env):
                temp_dir = xdg_cache_env
            else:
                home_env = os.environ.get("HOME", None)
                if home_env and os.path.isdir(home_env):
                    home_path = os.path.join(home_env, ".cache/")
                    if os.path.isdir(home_path):
                        temp_dir = home_path
            if temp_dir is None:
                temp_dir = tempfile.gettempdir()
                print("Using world accessible cache directory. This may be not secure: ", temp_dir)
        save_dir = os.path.join(temp_dir, "downloads")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    return save_dir
def downloadFile(url, sha=None, save_dir=None, filename=None):
    """Download *url* and return the local path of the file.

    Args:
        url: address to download from.
        sha: expected SHA-1 of the file, or None to skip verification.
        save_dir: base directory for downloads; defaults to ``getSaveDir()``.
        filename: name to store the file under; defaults to ``download_``
            followed by the current timestamp.

    Returns:
        Whatever ``DownloadInstance.get()`` returns for the created instance.
    """
    if save_dir is None:
        save_dir = getSaveDir()
    if filename is None:
        # Timestamp without ':' or spaces: str(datetime.now()) contains both
        # and ':' is not allowed in file names on Windows
        filename = "download_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%f")
    name = filename
    return produceDownloadInstance(name, filename, sha, url, save_dir).get()
def parseMetalinkFile(metalink_filepath, save_dir):
    """Create one download instance per ``file`` element of a Metalink document.

    For each element the ``url``, ``identity`` and ``hash`` child texts and
    the ``name`` attribute are read and handed to ``produceDownloadInstance``
    together with *save_dir*.

    Returns:
        List of the created instances, in document order.
    """
    namespaces = {'ml': 'urn:ietf:params:xml:ns:metalink'}
    root = ET.parse(metalink_filepath).getroot()
    instances = []
    for entry in root.findall('ml:file', namespaces):
        link = entry.find('ml:url', namespaces).text
        target_name = entry.attrib['name']
        identity = entry.find('ml:identity', namespaces).text
        expected_hash = entry.find('ml:hash', namespaces).text
        instance = produceDownloadInstance(identity, target_name, expected_hash, link, save_dir)
        instances.append(instance)
    return instances
def parseYAMLFile(yaml_filepath, save_dir):
    """Create download instances for the entries of a YAML file that carry ``load_info``.

    Entries without (or with empty) ``load_info`` are skipped. For the rest,
    the base name of ``model`` becomes the filename and the ``sha1``, ``url``,
    ``download_sha``, ``download_name`` and ``member`` keys of ``load_info``
    are forwarded to ``produceDownloadInstance``.

    Returns:
        List of the created instances, in file order.
    """
    with open(yaml_filepath, 'r') as stream:
        entries = yaml.safe_load(stream)

    instances = []
    for name, params in entries.items():
        load_info = params.get("load_info", None)
        if not load_info:
            continue
        fname = os.path.basename(params.get("model"))
        instance = produceDownloadInstance(
            name,
            fname,
            load_info.get("sha1"),
            load_info.get("url"),
            save_dir,
            download_name=load_info.get("download_name"),
            download_sha=load_info.get("download_sha"),
            archive_member=load_info.get("member"),
        )
        instances.append(instance)
    return instances
if __name__ == '__main__':
    # Command line entry point: download every known model, or only those
    # whose name starts with the optional positional argument.
    parser = argparse.ArgumentParser(description='This is a utility script for downloading DNN models for samples.')
    parser.add_argument('--save_dir', action="store", default=os.getcwd(),
                        help='Path to the directory to store downloaded files')
    parser.add_argument('model_name', type=str, default="", nargs='?', action="store",
                        help='name of the model to download')
    args = parser.parse_args()
    models = []
    save_dir = args.save_dir
    selected_model_name = args.model_name
    # Both description files are looked up relative to the current directory
    models.extend(parseMetalinkFile('face_detector/weights.meta4', save_dir))
    models.extend(parseYAMLFile('models.yml', save_dir))
    for m in models:
        print(m)
        if selected_model_name and not m.name.startswith(selected_model_name):
            continue
        # Report the model actually being fetched, not the filter prefix
        # (which is empty when no model was selected)
        print('Model: ' + m.name)
        m.get()