mirror of
https://github.com/opencv/opencv.git
synced 2025-06-13 13:13:26 +08:00

Extension to PR #26605 ldm inpainting sample #26904 This PR adds and fixes the following points in the ldm_inpainting sample on top of the original PR #26605 by @Abdurrahheem DONE: 1. Added functionality to load models from a YAML configuration file, allowing for automatic downloading if models are not found locally. 2. Updated the script usage instructions to reflect the correct command format. 3. Improved user interaction by adding instructions to the image window for inpainting controls. 4. Introduced a new models.yml configuration section for downloading inpainting model weights, including placeholders for model SHA1 checksums. 5. Fixed input types and names of the ONNX graph generation. 6. Added links to ONNX graphs in models.yml. 7. Added support for findModels and standardized the sample usage to match other dnn samples. 8. Fixes an issue in download_models.py when downloading models from dl.opencv.org. 9. Fixes an issue in common.py that printed duplicated positional arguments for samples that use multiple models. ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake --------- Co-authored-by: Abdurrahheem <abduragim.shtanchaev@xperience.ai>
384 lines
15 KiB
Python
384 lines
15 KiB
Python
'''
|
|
Helper module to download extra data from Internet
|
|
'''
|
|
from __future__ import print_function
|
|
import os
|
|
import sys
|
|
import yaml
|
|
import argparse
|
|
import tarfile
|
|
import platform
|
|
import tempfile
|
|
import hashlib
|
|
import requests
|
|
import shutil
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from urllib.request import Request, urlopen
|
|
import xml.etree.ElementTree as ET
|
|
|
|
__all__ = ["downloadFile"]
|
|
|
|
class HashMismatchException(Exception):
    """Error carrying the expected and the actually computed hash of a file.

    Both values are kept as attributes (``expected``, ``actual``) and are
    rendered into the message returned by ``str()``.
    """

    def __init__(self, expected, actual):
        super(HashMismatchException, self).__init__()
        self.expected, self.actual = expected, actual

    def __str__(self):
        template = 'Hash mismatch: expected {} vs actual of {}'
        return template.format(self.expected, self.actual)
|
|
|
|
def getHashsumFromFile(filepath):
    """Return the hexadecimal SHA-1 digest of the file at ``filepath``.

    The file is read in 10 MiB chunks. When ``filepath`` does not exist nothing
    is read, so the digest of empty input is returned.
    """
    digest = hashlib.sha1()
    if not os.path.exists(filepath):
        return digest.hexdigest()

    print(' there is already a file with the same name')
    chunk_size = 10 * 1024 * 1024
    with open(filepath, 'rb') as stream:
        # iter() with a sentinel stops on the first empty read (end of file).
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
|
def checkHashsum(expected_sha, filepath, silent=True):
    """Check the SHA-1 of ``filepath`` against ``expected_sha``.

    Returns False straight away (nothing is hashed) when the file does not
    exist. Otherwise returns whether the two digests are equal; if they differ
    and ``silent`` is false, HashMismatchException is raised instead.
    """
    file_is_present = os.path.exists(filepath)
    if not file_is_present:
        print(f"{filepath} does not exist. Skipping hashsum matching")
        return False

    print(' expected SHA1: {}'.format(expected_sha))
    computed_sha = getHashsumFromFile(filepath)
    print(' actual SHA1:{}'.format(computed_sha))

    matched = expected_sha == computed_sha
    if matched or silent:
        return matched
    raise HashMismatchException(expected_sha, computed_sha)
|
|
|
|
def isArchive(filepath):
    """Tell whether ``filepath`` is a tar archive, as judged by ``tarfile.is_tarfile``."""
    readable_as_tar = tarfile.is_tarfile(filepath)
    return readable_as_tar
|
|
|
|
class DownloadInstance:
    """One file to obtain, together with the loader that fetches it.

    Keyword arguments:
        name: label used in the progress messages.
        filename: name of the file the caller wants to end up with.
        save_dir: root directory under which files are stored.
        loader: object with a ``load(filename, sha, save_dir)`` method. It may
            be omitted at construction time but is required by ``get``.
        sha: expected SHA-1 of ``filename``, or None when it is not known.
    """

    def __init__(self, **kwargs):
        self.name = kwargs.pop('name')
        self.filename = kwargs.pop('filename')
        self.loader = kwargs.pop('loader', None)
        self.save_dir = kwargs.pop('save_dir')
        self.sha = kwargs.pop('sha', None)

    def __str__(self):
        return 'DownloadInstance <{}>'.format(self.name)

    def get(self):
        """Make sure the file is present in ``save_dir/<sha>/`` and return its path.

        With a known SHA the file is reused when it is already stored with a
        matching digest. Otherwise the loader is asked to produce it. Without a
        SHA the loaded file is hashed, ``self.sha`` is set to that digest and
        the file is moved into the matching sub-directory.

        Returns:
            The path of the stored file, or None when loading failed (the
            problem is printed).

        Raises:
            AssertionError: if no loader was configured.
        """
        print(" Working on " + self.name)
        print(" Getting file " + self.filename)
        filepath = None
        if self.sha is None:
            print(' No expected hashsum provided, loading file')
        else:
            filepath = os.path.join(self.save_dir, self.sha, self.filename)
            if checkHashsum(self.sha, filepath):
                print(' hash match - file already exists, skipping')
                return filepath
            print(' hash didn\'t match, loading file')

        if not os.path.exists(self.save_dir):
            print(' creating directory: ' + self.save_dir)
            os.makedirs(self.save_dir)

        print(' hash check failed - loading')
        # Raised explicitly (same exception type as before) so the check is
        # not stripped when Python runs with -O.
        if not self.loader:
            raise AssertionError('no loader configured for {}'.format(self.name))
        try:
            self.loader.load(self.filename, self.sha, self.save_dir)
            print(' done')
            print(' file {}'.format(self.filename))
            if self.sha is None:
                download_path = os.path.join(self.save_dir, self.filename)
                # Hashing a missing file would yield the digest of empty input
                # and create a bogus directory named after it.
                if not os.path.exists(download_path):
                    raise IOError('{} was not produced by the loader'.format(download_path))
                self.sha = getHashsumFromFile(download_path)
                new_dir = os.path.join(self.save_dir, self.sha)

                if not os.path.exists(new_dir):
                    os.makedirs(new_dir)
                filepath = os.path.join(new_dir, self.filename)
                if not os.path.exists(filepath):
                    shutil.move(download_path, new_dir)
                print(' No expected hashsum provided, actual SHA is {}'.format(self.sha))
            elif not checkHashsum(self.sha, filepath, silent=False):
                # checkHashsum raises on a mismatch, so False here means the
                # file is absent; never hand back a path that does not exist.
                raise IOError('{} is missing after loading'.format(filepath))
        except Exception as e:
            print(" There was some problem with loading file {} for {}".format(self.filename, self.name))
            print(" Exception: {}".format(e))
            return None

        print(" Finished " + self.name)
        return filepath
|
|
|
|
class Loader(object):
    """Base loader: downloads a file and, if it is a tar archive, extracts one member.

    Sub-classes implement ``download``; the base implementation only prints a
    warning and reports a size of 0.
    """

    MB = 1024*1024
    BUFSIZE = 10*MB

    def __init__(self, download_name, download_sha, archive_member = None):
        self.download_name = download_name
        self.download_sha = download_sha
        self.archive_member = archive_member

    def load(self, requested_file, sha, save_dir):
        """Download ``self.download_name`` and make ``requested_file`` available.

        The download goes to ``save_dir/<download_sha>/`` (or ``save_dir`` when
        no download SHA is known) and is skipped when a file with a matching
        hash is already there. If the downloaded file is not the requested one
        it must be a tar archive, from which the requested file is extracted
        into ``save_dir/<sha>/`` (or ``save_dir`` when ``sha`` is None).

        Raises:
            HashMismatchException: the downloaded file has an unexpected hash.
            Exception: the downloaded file is neither the requested file nor
                an archive, or extraction failed.
        """
        if self.download_sha is None:
            download_dir = save_dir
        else:
            # create a new folder in save_dir to avoid possible name conflicts
            download_dir = os.path.join(save_dir, self.download_sha)
        if not os.path.exists(download_dir):
            os.makedirs(download_dir)
        download_path = os.path.join(download_dir, self.download_name)
        print(" Preparing to download file " + self.download_name)
        if checkHashsum(self.download_sha, download_path):
            print(' hash match - file already exists, no need to download')
        else:
            filesize = self.download(download_path)
            print(' Downloaded {} with size {} Mb'.format(self.download_name, filesize/self.MB))
            if self.download_sha is not None:
                checkHashsum(self.download_sha, download_path, silent=False)

        if self.download_name == requested_file:
            return
        if not isArchive(download_path):
            raise Exception("Downloaded file has different name")

        extract_dir = save_dir if sha is None else os.path.join(save_dir, sha)
        if not os.path.exists(extract_dir):
            os.makedirs(extract_dir)
        self.extract(requested_file, download_path, extract_dir)

    def download(self, filepath):
        """Fetch the file to ``filepath`` and return its size in bytes (stub here)."""
        print("Warning: download is not implemented, this is a base class")
        return 0

    def extract(self, requested_file, archive_path, save_dir):
        """Extract one member of the tar archive to ``save_dir/requested_file``.

        When no archive member was configured, the member whose base name
        equals ``requested_file`` is used and remembered in
        ``self.archive_member``.

        Raises:
            The exception that made the extraction fail (unknown member,
            member that is not a regular file, I/O error), after printing it.
        """
        filepath = os.path.join(save_dir, requested_file)
        try:
            with tarfile.open(archive_path) as f:
                names = f.getnames()
                if not self.archive_member:
                    # Map base name -> full member path. Keeping only the
                    # directory part would select the folder, not the file.
                    by_basename = dict((os.path.basename(elem), elem) for elem in names)
                    self.archive_member = by_basename[requested_file]
                if self.archive_member not in names:
                    raise KeyError('{} is not a member of {}'.format(self.archive_member, archive_path))
                member_stream = f.extractfile(self.archive_member)
                # extractfile() gives None for members that are not regular
                # files; check before save() creates an empty output file.
                if member_stream is None:
                    raise IOError('{} is not a regular file in {}'.format(self.archive_member, archive_path))
                self.save(filepath, member_stream)
        except Exception as e:
            print(' catch {}'.format(e))
            # Propagate: a swallowed failure would let callers continue as if
            # the requested file had been extracted.
            raise

    def save(self, filepath, r):
        """Copy the readable object ``r`` to ``filepath``, printing one '>' per chunk."""
        with open(filepath, 'wb') as f:
            print(' progress ', end="")
            sys.stdout.flush()
            while True:
                buf = r.read(self.BUFSIZE)
                if not buf:
                    break
                f.write(buf)
                print('>', end="")
                sys.stdout.flush()
|
|
|
|
class URLLoader(Loader):
    """Loader that fetches the file from ``url`` through ``urllib``."""

    def __init__(self, download_name, download_sha, url, archive_member = None):
        # The base initializer stores download_name, download_sha and archive_member.
        super(URLLoader, self).__init__(download_name, download_sha, archive_member)
        self.url = url

    def download(self, filepath):
        """Save the content behind ``self.url`` to ``filepath``; return the file size in bytes."""
        request = Request(self.url, headers={'User-Agent': 'Wget/1.20.3'})
        with urlopen(request, timeout=60) as response:
            self.printRequest(response)
            self.save(filepath, response)
            size_on_disk = os.path.getsize(filepath)
        return size_on_disk

    def printRequest(self, r):
        """Print status code, status message and announced size (in Mb) of response ``r``."""
        header_map = dict(r.info())
        announced = '<unknown>'
        for header in ('content-length', 'Content-Length'):
            if header in header_map:
                announced = int(header_map[header]) / self.MB
                break
        print(' {} {} [{} Mb]'.format(r.getcode(), r.msg, announced))
|
|
|
|
class GDriveLoader(Loader):
    """Loader that fetches a file from Google Drive by its file id (``gid``)."""

    BUFSIZE = 1024 * 1024
    PROGRESS_SIZE = 10 * 1024 * 1024
    # Seconds to wait for the server; the same limit URLLoader passes to urlopen.
    TIMEOUT = 60

    def __init__(self, download_name, download_sha, gid, archive_member = None):
        # The base initializer stores download_name, download_sha and archive_member.
        super(GDriveLoader, self).__init__(download_name, download_sha, archive_member)
        self.gid = gid

    def download(self, filepath):
        """Stream the Drive file to ``filepath`` and return the number of bytes written.

        A first request is made with the file id only. If the response carries
        a cookie whose name starts with ``download_warning``, the request is
        repeated with that cookie's value as the ``confirm`` parameter. One
        '>' is printed per PROGRESS_SIZE bytes received.

        Raises:
            requests.HTTPError: the server answered with an error status.
        """
        URL = "https://docs.google.com/uc?export=download"

        def get_confirm_token(response):  # in case of large files
            for key, value in response.cookies.items():
                if key.startswith('download_warning'):
                    return value
            return None

        # The context manager closes the session (and its connections) even
        # when the download fails half-way.
        with requests.Session() as session:  # re-use cookies
            response = session.get(URL, params = { 'id' : self.gid }, stream = True,
                                   timeout = self.TIMEOUT)
            token = get_confirm_token(response)

            if token:
                response.close()
                params = { 'id' : self.gid, 'confirm' : token }
                response = session.get(URL, params = params, stream = True,
                                       timeout = self.TIMEOUT)

            # Do not store an HTTP error page as if it were the requested file.
            response.raise_for_status()

            sz = 0
            progress_sz = self.PROGRESS_SIZE
            with open(filepath, "wb") as f:
                for chunk in response.iter_content(self.BUFSIZE):
                    if not chunk:
                        continue  # keep-alive

                    f.write(chunk)
                    sz += len(chunk)
                    if sz >= progress_sz:
                        progress_sz += self.PROGRESS_SIZE
                        print('>', end='')
                        sys.stdout.flush()
        print('')
        return sz
|
|
|
|
def produceDownloadInstance(instance_name, filename, sha, url, save_dir, download_name=None, download_sha=None, archive_member=None):
    """Build a DownloadInstance with the loader that fits ``url``.

    Args:
        instance_name: label of the instance, used in messages.
        filename: name of the file that is finally wanted.
        sha: expected SHA-1 of ``filename`` (may be None).
        url: where to download from.
        save_dir: root directory for downloaded files.
        download_name: name of the file that is actually downloaded; defaults
            to ``filename``.
        download_sha: expected SHA-1 of the downloaded file; defaults to ``sha``.
        archive_member: passed through to the loader.

    Returns:
        A DownloadInstance using GDriveLoader when ``url`` contains
        "drive.google.com" and an ``id=`` parameter, URLLoader otherwise.
    """
    spec_param = url
    loader = URLLoader
    if download_name is None:
        download_name = filename
    if download_sha is None:
        download_sha = sha
    if "drive.google.com" in url:
        token = ""
        # Break the URL at every delimiter so the id is found wherever it sits
        # in the query ("uc?id=X&export=download" as well as
        # "uc?export=download&id=X"). The last "id=" wins.
        for param in url.replace('?', '&').replace('/', '&').split('&'):
            if param.startswith("id="):
                token = param[3:]
        if token:
            loader = GDriveLoader
            spec_param = token
        else:
            print("Warning: possibly wrong Google Drive link")
    return DownloadInstance(
        name=instance_name,
        filename=filename,
        sha=sha,
        save_dir=save_dir,
        loader=loader(download_name, download_sha, spec_param, archive_member)
    )
|
|
|
|
def getSaveDir():
    """Return the directory used to store downloads, creating it if needed.

    ``OPENCV_DOWNLOAD_DATA_PATH`` is used as is when set. Otherwise the result
    is a "downloads" folder inside a base directory chosen per platform:

    * macOS: ``TMPDIR`` if it is a directory, else ``/tmp``;
    * Windows: ``tempfile.gettempdir()``;
    * elsewhere: ``XDG_CACHE_HOME`` if it is a directory, else ``$HOME/.cache/``
      if that exists, else ``tempfile.gettempdir()``.

    A warning is printed whenever a world-accessible temporary directory is
    chosen as a fallback.
    """
    env_path = os.environ.get("OPENCV_DOWNLOAD_DATA_PATH", None)
    if env_path:
        save_dir = env_path
    else:
        # TODO reuse binding function cv2.utils.fs.getCacheDirectory when issue #19011 is fixed
        system_name = platform.system()
        if system_name == "Darwin":
            # On Apple devices
            temp_env = os.environ.get("TMPDIR", None)
            if temp_env is None or not os.path.isdir(temp_env):
                temp_dir = Path("/tmp")
                print("Using world accessible cache directory. This may be not secure: ", temp_dir)
            else:
                temp_dir = temp_env
        elif system_name == "Windows":
            temp_dir = tempfile.gettempdir()
        else:
            temp_dir = None
            xdg_cache_env = os.environ.get("XDG_CACHE_HOME", None)
            if xdg_cache_env and os.path.isdir(xdg_cache_env):
                temp_dir = xdg_cache_env
            else:
                home_env = os.environ.get("HOME", None)
                if home_env and os.path.isdir(home_env):
                    home_path = os.path.join(home_env, ".cache/")
                    if os.path.isdir(home_path):
                        temp_dir = home_path
            # Also reached when HOME exists but has no ".cache" folder; that
            # case used to leave temp_dir unassigned.
            if temp_dir is None:
                temp_dir = tempfile.gettempdir()
                print("Using world accessible cache directory. This may be not secure: ", temp_dir)

        save_dir = os.path.join(temp_dir, "downloads")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    return save_dir
|
|
|
|
def downloadFile(url, sha=None, save_dir=None, filename=None):
    """Download ``url`` and return the path of the stored file.

    Args:
        url: address to download from.
        sha: expected SHA-1 of the file, or None if unknown.
        save_dir: root directory for the file; ``getSaveDir()`` when None.
        filename: name to store the file under; a time-stamped
            ``download_...`` name when None.

    Returns:
        Whatever ``DownloadInstance.get()`` returns for the produced instance.
    """
    if save_dir is None:
        save_dir = getSaveDir()
    if filename is None:
        # str(datetime) contains spaces and ':' and ':' is not allowed in
        # Windows file names, so format the time stamp explicitly.
        filename = "download_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S_%f")
    return produceDownloadInstance(filename, filename, sha, url, save_dir).get()
|
|
|
|
def parseMetalinkFile(metalink_filepath, save_dir):
    """Create one download instance per ``file`` element of a Metalink document.

    For every element the ``url``, ``identity`` and ``hash`` children and the
    ``name`` attribute are read and handed to ``produceDownloadInstance``.
    Returns the instances as a list, in document order.
    """
    namespaces = {'ml': 'urn:ietf:params:xml:ns:metalink'}
    document_root = ET.parse(metalink_filepath).getroot()

    def to_instance(entry):
        link = entry.find('ml:url', namespaces).text
        file_name = entry.attrib['name']
        label = entry.find('ml:identity', namespaces).text
        digest = entry.find('ml:hash', namespaces).text
        return produceDownloadInstance(label, file_name, digest, link, save_dir)

    return [to_instance(entry) for entry in document_root.findall('ml:file', namespaces)]
|
|
|
|
def parseYAMLFile(yaml_filepath, save_dir, model_name):
    """Create download instances from the ``*load_info`` sections of a YAML file.

    Each top-level key of the file is a model name. Inside a model, every key
    ending in ``load_info`` describes one file:

    * ``config_load_info`` belongs to the file named by ``config`` and its
      fields are read without prefix (``sha1``, ``url``, ...);
    * ``<prefix>load_info`` (including the empty prefix) belongs to the file
      named by ``<prefix>model`` and its fields carry the same prefix
      (``<prefix>sha1``, ``<prefix>url``, ...).

    Args:
        yaml_filepath: path of the YAML file.
        save_dir: root directory handed to the download instances.
        model_name: only this model is processed; "" selects every model.

    Returns:
        A list of download instances (empty for an empty file).
    """
    models = []
    with open(yaml_filepath, 'r') as stream:
        # safe_load returns None for an empty document.
        data_loaded = yaml.safe_load(stream) or {}

    for name, params in data_loaded.items():
        if model_name != "" and name != model_name:
            continue
        if not isinstance(params, dict):
            continue  # model entry without a body
        for key, load_info in params.items():
            if not key.endswith("load_info") or not load_info:
                continue
            prefix = key[:-len('load_info')]
            if prefix == "config_":
                file_key, field_prefix = "config", ""
            else:
                file_key, field_prefix = prefix + "model", prefix

            file_path = params.get(file_key)
            if file_path is None:
                print("Warning: {} has '{}' but no '{}' entry, skipping".format(name, key, file_key))
                continue

            models.append(produceDownloadInstance(
                name,
                os.path.basename(file_path),
                load_info.get(field_prefix + "sha1"),
                load_info.get(field_prefix + "url"),
                save_dir,
                download_name=load_info.get(field_prefix + "download_name"),
                download_sha=load_info.get(field_prefix + "download_sha"),
                archive_member=load_info.get(field_prefix + "member")))

    return models
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: collect the models described by the metalink
    # and YAML files (both resolved relative to the current working directory)
    # and download those whose name starts with the optional model_name.
    parser = argparse.ArgumentParser(description='This is a utility script for downloading DNN models for samples.')

    parser.add_argument('--save_dir', action="store", default=os.getcwd(),
                        help='Path to the directory to store downloaded files')
    parser.add_argument('model_name', type=str, default="", nargs='?', action="store",
                        help='name of the model to download')
    args = parser.parse_args()
    models = []
    save_dir = args.save_dir
    selected_model_name = args.model_name
    models.extend(parseMetalinkFile('face_detector/weights.meta4', save_dir))
    models.extend(parseYAMLFile('models.yml', save_dir, selected_model_name))
    for m in models:
        print(m)
        if selected_model_name and not m.name.startswith(selected_model_name):
            continue
        # Report the model being fetched; the filter string is empty when no
        # model was selected on the command line.
        print('Model: ' + m.name)
        m.get()
|