mirror of
https://github.com/opencv/opencv.git
synced 2024-12-02 16:00:17 +08:00
1f3255d76b
Scripts for downloading models in DNN samples * Initial commit. Utility classes and functions for downloading files * updated download script * Support YAML parsing, update download script and configs * Fix problem with archived files * fix models.yml * Move download utilities to more appropriate place * Fix script description * Update README * update utilities for broader range of files * fix loading with no hashsum provided * remove unnecessary import * fix for Python2 * Add usage examples for downloadFile function * Add more secure cache folder selection * Remove trailing whitespaces * Fix indentation * Update function interface * Change function for temp dir, change entry name in models.yml * Update getCacheDirectory function call * Return python implementation for cache directory selection, use more specific env variable * Fix whitespace
364 lines
14 KiB
Python
364 lines
14 KiB
Python
'''
|
|
Helper module to download extra data from the Internet
|
|
'''
|
|
from __future__ import print_function
|
|
import os
|
|
import cv2
|
|
import sys
|
|
import yaml
|
|
import argparse
|
|
import tarfile
|
|
import platform
|
|
import tempfile
|
|
import hashlib
|
|
import requests
|
|
import shutil
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
if sys.version_info[0] < 3:
|
|
from urllib2 import urlopen
|
|
else:
|
|
from urllib.request import urlopen
|
|
import xml.etree.ElementTree as ET
|
|
|
|
# Names exported by `from <module> import *`: only downloadFile is listed.
__all__ = ["downloadFile"]
|
|
|
|
class HashMismatchException(Exception):
    '''
    Error carrying an expected and an actual hash value that differ.

    Both values are stored as attributes (`expected`, `actual`) and are
    rendered into the message returned by str().
    '''
    def __init__(self, expected, actual):
        # The base class is initialised without arguments, so `args` stays empty.
        super(HashMismatchException, self).__init__()
        self.expected, self.actual = expected, actual

    def __str__(self):
        template = 'Hash mismatch: expected {} vs actual of {}'
        return template.format(self.expected, self.actual)
|
|
|
|
def getHashsumFromFile(filepath):
    '''
    Return the SHA1 hex digest of the file at `filepath`.

    The file is consumed in 10 MiB chunks. If the path does not exist,
    nothing is read and the digest of empty input is returned.
    '''
    digest = hashlib.sha1()
    if not os.path.exists(filepath):
        return digest.hexdigest()

    print(' there is already a file with the same name')
    chunk_size = 10*1024*1024
    with open(filepath, 'rb') as stream:
        # iter() with a sentinel stops at the first empty read (end of file).
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
|
def checkHashsum(expected_sha, filepath, silent=True):
    '''
    Compare the SHA1 of `filepath` with `expected_sha`, printing both values.

    Returns the result of the comparison. When the values differ and
    `silent` is false, HashMismatchException is raised instead.
    '''
    print(' expected SHA1: {}'.format(expected_sha))
    actual_sha = getHashsumFromFile(filepath)
    print(' actual SHA1:{}'.format(actual_sha))

    matched = expected_sha == actual_sha
    if matched or silent:
        return matched
    raise HashMismatchException(expected_sha, actual_sha)
|
|
|
|
def isArchive(filepath):
    '''
    Tell whether `filepath` is an archive that the tarfile module can read.
    '''
    readable_by_tarfile = tarfile.is_tarfile(filepath)
    return readable_by_tarfile
|
|
|
|
class DownloadInstance:
    '''
    Description of one file to obtain, plus the loader able to fetch it.

    Keyword arguments:
      name      -- label used in log messages (required)
      filename  -- name of the wanted file (required)
      save_dir  -- directory under which files are stored (required)
      loader    -- object providing load(filename, sha, save_dir); defaults to None
      sha       -- expected SHA1 of the file; defaults to None
    Any other keyword argument is ignored.
    '''
    def __init__(self, **kwargs):
        # Required entries raise KeyError when absent; the pop order is kept
        # so the first missing key reported stays the same.
        self.name = kwargs.pop('name')
        self.filename = kwargs.pop('filename')
        self.loader = kwargs.pop('loader', None)
        self.save_dir = kwargs.pop('save_dir')
        self.sha = kwargs.pop('sha', None)

    def __str__(self):
        return 'DownloadInstance <{}>'.format(self.name)

    def get(self):
        '''
        Make the file available under <save_dir>/<sha>/<filename>.

        Returns the path of the file, or None when loading or verification
        raised an exception (the problem is printed, not propagated).
        '''
        print(" Working on " + self.name)
        print(" Getting file " + self.filename)

        if self.sha is not None:
            filepath = os.path.join(self.save_dir, self.sha, self.filename)
            if checkHashsum(self.sha, filepath):
                print(' hash match - file already exists, skipping')
                return filepath
            print(' hash didn\'t match, loading file')
        else:
            print(' No expected hashsum provided, loading file')

        if not os.path.exists(self.save_dir):
            print(' creating directory: ' + self.save_dir)
            os.makedirs(self.save_dir)

        print(' hash check failed - loading')
        assert self.loader
        try:
            self.loader.load(self.filename, self.sha, self.save_dir)
            print(' done')
            print(' file {}'.format(self.filename))
            if self.sha is not None:
                # An expected hash exists: a mismatch raises and is reported below.
                checkHashsum(self.sha, filepath, silent=False)
            else:
                # No expected hash: adopt the actual one and file the download
                # under a directory named after it.
                download_path = os.path.join(self.save_dir, self.filename)
                self.sha = getHashsumFromFile(download_path)
                new_dir = os.path.join(self.save_dir, self.sha)
                if not os.path.exists(new_dir):
                    os.makedirs(new_dir)
                filepath = os.path.join(new_dir, self.filename)
                already_stored = os.path.exists(filepath)
                if not already_stored:
                    shutil.move(download_path, new_dir)
                print(' No expected hashsum provided, actual SHA is {}'.format(self.sha))
        except Exception as error:
            print(" There was some problem with loading file {} for {}".format(self.filename, self.name))
            print(" Exception: {}".format(error))
            return

        print(" Finished " + self.name)
        return filepath
|
|
|
|
class Loader(object):
    '''
    Base class that fetches a file and, when the fetched file is a tar
    archive, pulls the requested member out of it.

    Subclasses override `download`; the implementation here only prints a
    warning and reports a size of 0.
    '''
    MB = 1024*1024
    BUFSIZE = 10*MB

    def __init__(self, download_name, download_sha, archive_member = None):
        # Name and expected SHA1 of the file that is actually fetched; it may
        # be an archive that contains the file the caller asked for.
        self.download_name = download_name
        self.download_sha = download_sha
        # Name of the wanted entry inside the archive. None means the entry is
        # looked up by file name on the first extract() call.
        self.archive_member = archive_member

    def load(self, requested_file, sha, save_dir):
        '''
        Download `download_name` (unless a copy with a matching hash already
        exists) and, if it is not the requested file itself, extract
        `requested_file` from it.

        requested_file -- name of the file the caller wants
        sha            -- expected SHA1 of the requested file, or None; only
                          used to choose the extraction directory
        save_dir       -- root directory for downloads

        Raises HashMismatchException when a fresh download does not match
        `download_sha`, and Exception when the download is neither the
        requested file nor a tar archive.
        '''
        if self.download_sha is None:
            download_dir = save_dir
        else:
            # create a new folder in save_dir to avoid possible name conflicts
            download_dir = os.path.join(save_dir, self.download_sha)
        if not os.path.exists(download_dir):
            os.makedirs(download_dir)
        download_path = os.path.join(download_dir, self.download_name)
        print(" Preparing to download file " + self.download_name)
        if checkHashsum(self.download_sha, download_path):
            print(' hash match - file already exists, no need to download')
        else:
            filesize = self.download(download_path)
            print(' Downloaded {} with size {} Mb'.format(self.download_name, filesize/self.MB))
            if self.download_sha is not None:
                checkHashsum(self.download_sha, download_path, silent=False)

        if self.download_name == requested_file:
            return
        if not isArchive(download_path):
            raise Exception("Downloaded file has different name")

        if sha is not None:
            extract_dir = os.path.join(save_dir, sha)
        else:
            extract_dir = save_dir
        if not os.path.exists(extract_dir):
            os.makedirs(extract_dir)
        self.extract(requested_file, download_path, extract_dir)

    def download(self, filepath):
        '''
        Fetch the file into `filepath` and return its size in bytes.
        The base class fetches nothing and returns 0.
        '''
        print("Warning: download is not implemented, this is a base class")
        return 0

    def extract(self, requested_file, archive_path, save_dir):
        '''
        Copy one member of the tar archive `archive_path` to
        <save_dir>/<requested_file>.

        When `archive_member` is None the member is found by its base name
        and the resolved name is remembered in `archive_member`.
        Failures are printed rather than raised.
        '''
        filepath = os.path.join(save_dir, requested_file)
        try:
            with tarfile.open(archive_path) as f:
                if self.archive_member is None:
                    # Map base name -> full member name. Directories are
                    # skipped because they cannot be extracted as a file.
                    by_basename = dict((os.path.basename(m.name), m.name)
                                       for m in f.getmembers() if not m.isdir())
                    self.archive_member = by_basename[requested_file]
                if self.archive_member not in f.getnames():
                    raise Exception('archive member {} not found in {}'.format(
                        self.archive_member, archive_path))
                member_stream = f.extractfile(self.archive_member)
                if member_stream is None:
                    # extractfile() yields None for members without file data.
                    raise Exception('archive member {} is not a regular file'.format(
                        self.archive_member))
                self.save(filepath, member_stream)
        except Exception as e:
            print(' catch {}'.format(e))

    def save(self, filepath, r):
        '''
        Copy everything readable from `r` into `filepath`, BUFSIZE bytes at a
        time, printing one '>' per chunk as a progress indicator.
        '''
        with open(filepath, 'wb') as f:
            print(' progress ', end="")
            sys.stdout.flush()
            while True:
                buf = r.read(self.BUFSIZE)
                if not buf:
                    break
                f.write(buf)
                print('>', end="")
                sys.stdout.flush()
|
|
|
|
class URLLoader(Loader):
    '''
    Loader that fetches the file from a plain URL with urlopen.
    '''
    def __init__(self, download_name, download_sha, url, archive_member = None):
        # download_name / download_sha / archive_member are stored by the base class.
        super(URLLoader, self).__init__(download_name, download_sha, archive_member)
        self.url = url

    def download(self, filepath):
        '''
        Download `self.url` into `filepath` and return the resulting file
        size in bytes. The request uses a 60 second timeout.
        '''
        r = urlopen(self.url, timeout=60)
        try:
            self.printRequest(r)
            self.save(filepath, r)
        finally:
            # Release the connection even when saving fails.
            r.close()
        return os.path.getsize(filepath)

    def printRequest(self, r):
        '''
        Print the status code, status message and announced size of response `r`.
        '''
        def getMB(r):
            # Header names are compared case-insensitively.
            headers = dict((key.lower(), value) for key, value in r.info().items())
            if 'content-length' in headers:
                return int(headers['content-length']) / self.MB
            return '<unknown>'
        print(' {} {} [{} Mb]'.format(r.getcode(), r.msg, getMB(r)))
|
|
|
|
class GDriveLoader(Loader):
    '''
    Loader that fetches a file by id through the docs.google.com download
    endpoint, using the `requests` package.
    '''
    BUFSIZE = 1024 * 1024
    PROGRESS_SIZE = 10 * 1024 * 1024

    def __init__(self, download_name, download_sha, gid, archive_member = None):
        super(GDriveLoader, self).__init__(download_name, download_sha, archive_member)
        self.download_name = download_name
        self.download_sha = download_sha
        self.gid = gid

    def download(self, filepath):
        '''
        Stream the file identified by `self.gid` into `filepath` and return
        the number of bytes written. One '>' is printed for every
        PROGRESS_SIZE bytes received.
        '''
        endpoint = "https://docs.google.com/uc?export=download"
        session = requests.Session() # re-use cookies
        response = session.get(endpoint, params = { 'id' : self.gid }, stream = True)

        # In case of large files: a cookie whose name starts with
        # 'download_warning' carries a token, and the request is repeated
        # with that token as `confirm`.
        token = next((value for key, value in response.cookies.items()
                      if key.startswith('download_warning')), None)
        if token:
            confirmed = { 'id' : self.gid, 'confirm' : token }
            response = session.get(endpoint, params = confirmed, stream = True)

        written = 0
        next_mark = self.PROGRESS_SIZE
        with open(filepath, "wb") as out:
            for chunk in response.iter_content(self.BUFSIZE):
                if chunk:  # empty chunks are keep-alives and are skipped
                    out.write(chunk)
                    written += len(chunk)
                    if written >= next_mark:
                        next_mark += self.PROGRESS_SIZE
                        print('>', end='')
                        sys.stdout.flush()
        print('')
        return written
|
|
|
|
def _findGDriveId(url):
    '''Return the value of the `id` parameter of a Google Drive link, or "" when none is found.'''
    token = ""
    # Primary lookup: the segment after the last '/', or everything before it
    # when that segment carries no "&id=".
    pieces = url.rsplit('/', 1)
    token_part = pieces[-1]
    if "&id=" not in token_part and len(pieces) > 1:
        token_part = pieces[-2]
    for param in token_part.split("&"):
        if param.startswith("id="):
            token = param[3:]
    if not token:
        # Fallback: `id` may be the first query parameter
        # ("...uc?id=<ID>&export=download"), which the lookup above misses.
        query = url.partition('?')[2].partition('#')[0]
        for param in query.split("&"):
            if param.startswith("id="):
                token = param[3:]
    return token


def produceDownloadInstance(instance_name, filename, sha, url, save_dir, download_name=None, download_sha=None, archive_member=None):
    '''
    Build a DownloadInstance for `filename`, choosing the loader from the URL.

    instance_name  -- label of the instance
    filename, sha  -- wanted file and its expected SHA1 (may be None)
    url            -- where to download from; links containing
                      "drive.google.com" with a recognisable id use
                      GDriveLoader, everything else uses URLLoader
    save_dir       -- root directory for downloads
    download_name, download_sha -- name and SHA1 of the file actually
                      downloaded when it differs from `filename` (for
                      example an archive); default to `filename` and `sha`
    archive_member -- name of the wanted entry inside the archive, if known
    '''
    spec_param = url
    loader = URLLoader
    if download_name is None:
        download_name = filename
    if download_sha is None:
        download_sha = sha
    if "drive.google.com" in url:
        token = _findGDriveId(url)
        if token:
            loader = GDriveLoader
            spec_param = token
        else:
            # No id found: the link is still tried as a plain URL.
            print("Warning: possibly wrong Google Drive link")
    return DownloadInstance(
        name=instance_name,
        filename=filename,
        sha=sha,
        save_dir=save_dir,
        loader=loader(download_name, download_sha, spec_param, archive_member)
    )
|
|
|
|
def getSaveDir():
    '''
    Return the directory used to store downloads, creating it if needed.

    The OPENCV_DOWNLOAD_DATA_PATH environment variable, when set and
    non-empty, is used as is. Otherwise a "downloads" directory is placed in
    a per-platform cache location:
      Darwin  -- $TMPDIR if it is a directory, else /tmp
      Windows -- tempfile.gettempdir()
      other   -- $XDG_CACHE_HOME, else $HOME/.cache/, else tempfile.gettempdir()
    A warning is printed when a world accessible directory is chosen.
    '''
    env_path = os.environ.get("OPENCV_DOWNLOAD_DATA_PATH", None)
    if env_path:
        save_dir = env_path
    else:
        # TODO reuse binding function cv2.utils.fs.getCacheDirectory when issue #19011 is fixed
        system = platform.system()
        if system == "Darwin":
            # On Apple devices
            temp_env = os.environ.get("TMPDIR", None)
            if temp_env is None or not os.path.isdir(temp_env):
                temp_dir = "/tmp"
                print("Using world accessible cache directory. This may be not secure: ", temp_dir)
            else:
                temp_dir = temp_env
        elif system == "Windows":
            temp_dir = tempfile.gettempdir()
        else:
            temp_dir = None
            xdg_cache_env = os.environ.get("XDG_CACHE_HOME", None)
            if xdg_cache_env and os.path.isdir(xdg_cache_env):
                temp_dir = xdg_cache_env
            else:
                home_env = os.environ.get("HOME", None)
                if home_env and os.path.isdir(home_env):
                    home_path = os.path.join(home_env, ".cache/")
                    if os.path.isdir(home_path):
                        temp_dir = home_path
            if temp_dir is None:
                # Every remaining case (no HOME, or HOME without .cache) ends
                # here, so temp_dir is always bound before it is used.
                temp_dir = tempfile.gettempdir()
                print("Using world accessible cache directory. This may be not secure: ", temp_dir)

        save_dir = os.path.join(temp_dir, "downloads")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    return save_dir
|
|
|
|
def downloadFile(url, sha=None, save_dir=None, filename=None):
    '''
    Download the file at `url` and return the path of the local copy, or
    None when the download failed.

    url      -- address of the file
    sha      -- expected SHA1 of the file; None skips the comparison
    save_dir -- directory for downloads; defaults to getSaveDir()
    filename -- name for the local file; defaults to a timestamp-based name
    '''
    if save_dir is None:
        save_dir = getSaveDir()
    if filename is None:
        # The timestamp uses only '-' and '_' as separators: the default
        # str(datetime) form contains ':' and a space, and ':' is not allowed
        # in Windows file names.
        filename = "download_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
    name = filename
    return produceDownloadInstance(name, filename, sha, url, save_dir).get()
|
|
|
|
def parseMetalinkFile(metalink_filepath, save_dir):
    '''
    Read a Metalink XML file and return one download instance per <file>
    element found directly under the root.

    For each element the URL, the `name` attribute, the identity and the
    hash are passed to produceDownloadInstance together with `save_dir`.
    '''
    NS = {'ml': 'urn:ietf:params:xml:ns:metalink'}

    def toInstance(file_elem):
        # Fields are read in this order: url, name attribute, identity, hash.
        url = file_elem.find('ml:url', NS).text
        fname = file_elem.attrib['name']
        name = file_elem.find('ml:identity', NS).text
        hash_sum = file_elem.find('ml:hash', NS).text
        return produceDownloadInstance(name, fname, hash_sum, url, save_dir)

    root = ET.parse(metalink_filepath).getroot()
    return [toInstance(file_elem) for file_elem in root.findall('ml:file', NS)]
|
|
|
|
def parseYAMLFile(yaml_filepath, save_dir):
    '''
    Read a YAML file of model descriptions and return a download instance
    for every entry that has a non-empty `load_info` section.

    The file name is the base name of the entry's `model` value; `sha1`,
    `url`, `download_sha`, `download_name` and `member` are taken from
    `load_info`.
    '''
    models = []
    with open(yaml_filepath, 'r') as stream:
        data_loaded = yaml.safe_load(stream)
        for name, params in data_loaded.items():
            load_info = params.get("load_info", None)
            if not load_info:
                # Entries without download information are skipped.
                continue
            fname = os.path.basename(params.get("model"))
            hash_sum = load_info.get("sha1")
            url = load_info.get("url")
            extras = {
                'download_sha': load_info.get("download_sha"),
                'download_name': load_info.get("download_name"),
                'archive_member': load_info.get("member"),
            }
            models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir, **extras))

    return models
|
|
|
|
if __name__ == '__main__':
    # Command line entry point: collect every known model description and
    # download those whose name starts with the optional `model_name`
    # argument (all of them when it is omitted).
    parser = argparse.ArgumentParser(description='This is a utility script for downloading DNN models for samples.')
    parser.add_argument(
        '--save_dir',
        action="store",
        default=os.getcwd(),
        help='Path to the directory to store downloaded files',
    )
    parser.add_argument(
        'model_name',
        type=str,
        default="",
        nargs='?',
        action="store",
        help='name of the model to download',
    )
    args = parser.parse_args()
    save_dir = args.save_dir
    selected_model_name = args.model_name

    # Both description files are looked up relative to the current directory.
    models = []
    models += parseMetalinkFile('face_detector/weights.meta4', save_dir)
    models += parseYAMLFile('models.yml', save_dir)

    for m in models:
        print(m)
        if not selected_model_name or m.name.startswith(selected_model_name):
            print('Model: ' + selected_model_name)
            m.get()