From 422debd897093ddc3a06de0e517f74a1e2d95a0f Mon Sep 17 00:00:00 2001 From: uvos Date: Thu, 5 Sep 2024 18:45:34 +0200 Subject: [PATCH] DanbooruTagger: add wd tagger support --- DanbooruTagger/DanbooruTagger.py | 87 ++++++-- .../deepdanbooru_onnx.py | 95 +-------- DanbooruTagger/deepdanbooru_onnx/__init__.py | 3 - .../__pycache__/__init__.cpython-312.pyc | Bin 306 -> 0 bytes .../deepdanbooru_onnx.cpython-312.pyc | Bin 13279 -> 0 bytes DanbooruTagger/utils.py | 33 +++ DanbooruTagger/wd_onnx.py | 196 ++++++++++++++++++ 7 files changed, 305 insertions(+), 109 deletions(-) rename DanbooruTagger/{deepdanbooru_onnx => }/deepdanbooru_onnx.py (62%) delete mode 100644 DanbooruTagger/deepdanbooru_onnx/__init__.py delete mode 100644 DanbooruTagger/deepdanbooru_onnx/__pycache__/__init__.cpython-312.pyc delete mode 100644 DanbooruTagger/deepdanbooru_onnx/__pycache__/deepdanbooru_onnx.cpython-312.pyc create mode 100644 DanbooruTagger/utils.py create mode 100644 DanbooruTagger/wd_onnx.py diff --git a/DanbooruTagger/DanbooruTagger.py b/DanbooruTagger/DanbooruTagger.py index 0ebeee1..1c7a217 100644 --- a/DanbooruTagger/DanbooruTagger.py +++ b/DanbooruTagger/DanbooruTagger.py @@ -1,5 +1,5 @@ -import warnings from deepdanbooru_onnx import DeepDanbooru +from wd_onnx import Wd from PIL import Image import argparse import cv2 @@ -13,13 +13,20 @@ image_ext_ocv = [".bmp", ".jpeg", ".jpg", ".png"] def find_image_files(path: str) -> list[str]: - paths = list() - for root, dirs, files in os.walk(path): - for filename in files: - name, extension = os.path.splitext(filename) - if extension.lower() in image_ext_ocv: - paths.append(os.path.join(root, filename)) - return paths + if os.path.isdir(path): + paths = list() + for root, dirs, files in os.walk(path): + for filename in files: + name, extension = os.path.splitext(filename) + if extension.lower() in image_ext_ocv: + paths.append(os.path.join(root, filename)) + return paths + else: + name, extension = os.path.splitext(path) + if extension.lower() in image_ext_ocv: + return [path] + else: + return [] def image_loader(paths: list[str]): @@ -35,14 +42,28 @@ def image_loader(paths: list[str]): yield image_pil, path -def pipeline(queue: Queue, image_paths: list[str], device: int): - danbooru = DeepDanbooru() +def danbooru_pipeline(queue: Queue, image_paths: list[str], device: int, cpu: bool): + danbooru = DeepDanbooru("cpu" if cpu else "auto") for path in image_paths: imageprompt = "" tags = danbooru(path) for tag in tags: imageprompt = imageprompt + ", " + tag + imageprompt = imageprompt[2:] + + queue.put({"file_name": path, "text": imageprompt}) + + +def wd_pipeline(queue: Queue, image_paths: list[str], device: int, cpu: bool): + wd = Wd("cpu" if cpu else "auto", threshold=0.3) + + for path in image_paths: + imageprompt = "" + tags = wd(path) + for tag in tags: + imageprompt = imageprompt + ", " + tag + imageprompt = imageprompt[2:] queue.put({"file_name": path, "text": imageprompt}) @@ -57,49 +78,71 @@ def split_list(input_list, count): def save_meta(meta_file, meta, reldir, common_description): meta["file_name"] = os.path.relpath(meta["file_name"], reldir) if common_description is not None: - meta["text"] = common_description + meta["text"] + meta["text"] = common_description + ", " + meta["text"] meta_file.write(json.dumps(meta) + '\n') if __name__ == "__main__": parser = argparse.ArgumentParser("A script to tag images via DeepDanbooru") - parser.add_argument('--batch', '-b', default=4, type=int, help="Batch size to use for inference") 
     parser.add_argument('--common_description', '-c', help="An optional description that will be preended to the ai generated one")
-    parser.add_argument('--image_dir', '-i', help="A directory containg the images to tag")
+    parser.add_argument('--image_dir', '-i', help="A directory containing the images to tag or a single image to tag")
+    parser.add_argument('--wd', '-w', action="store_true", help="Use the wd tagger instead of DeepDanbooru")
+    parser.add_argument('--cpu', action="store_true", help="Force cpu usage instead of gpu")
     args = parser.parse_args()
 
-    nparalell = 2
-
     image_paths = find_image_files(args.image_dir)
+
+    if len(image_paths) == 0:
+        print(f"Unable to find any images at {args.image_dir}")
+        exit(1)
+
+    nparalell = 4 if len(image_paths) > 4 else len(image_paths)
     image_path_chunks = list(split_list(image_paths, nparalell))
 
-    print(f"Will use {nparalell} processies to create tags")
+    print(f"Will use {nparalell} processes to create tags for {len(image_paths)} images")
 
     queue = Queue()
+    pipe = danbooru_pipeline if not args.wd else wd_pipeline
 
     processies = list()
     for i in range(0, nparalell):
-        processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], i)))
+        processies.append(Process(target=pipe, args=(queue, image_path_chunks[i], i, args.cpu)))
         processies[-1].start()
 
     progress = tqdm(desc="Generateing tags", total=len(image_paths))
     exit = False
-    with open(os.path.join(args.image_dir, "metadata.jsonl"), mode='w') as output_file:
+
+    if len(image_paths) > 1:
+        with open(os.path.join(args.image_dir, "metadata.jsonl"), mode='w') as output_file:
+            while not exit:
+                if not queue.empty():
+                    meta = queue.get()
+                    save_meta(output_file, meta, args.image_dir, args.common_description)
+                    progress.update()
+                exit = True
+                for process in processies:
+                    if process.is_alive():
+                        exit = False
+                        break
+
+            while not queue.empty():
+                meta = queue.get()
+                save_meta(output_file, meta, args.image_dir, args.common_description)
+                progress.update()
+    else:
         while not exit:
             if not queue.empty():
                 meta = queue.get()
-                save_meta(output_file, meta, args.image_dir, args.common_description)
+                print(meta)
                 progress.update()
             exit = True
             for process in processies:
                 if process.is_alive():
                     exit = False
                     break
-
         while not queue.empty():
             meta = queue.get()
-            save_meta(output_file, meta, args.image_dir, args.common_description)
+            print(meta)
            progress.update()
 
     for process in processies:
         process.join()
-
diff --git a/DanbooruTagger/deepdanbooru_onnx/deepdanbooru_onnx.py b/DanbooruTagger/deepdanbooru_onnx.py
similarity index 62%
rename from DanbooruTagger/deepdanbooru_onnx/deepdanbooru_onnx.py
rename to DanbooruTagger/deepdanbooru_onnx.py
index c108ceb..572d60a 100644
--- a/DanbooruTagger/deepdanbooru_onnx/deepdanbooru_onnx.py
+++ b/DanbooruTagger/deepdanbooru_onnx.py
@@ -2,13 +2,12 @@ import onnxruntime as ort
 from PIL import Image
 import numpy as np
 import os
-from tqdm import tqdm
-import requests
 import hashlib
 from typing import List, Union
-import shutil
 from pathlib import Path
+from utils import download
+
 
 def process_image(image: Image.Image) -> np.ndarray:
     """
@@ -18,38 +17,9 @@ def process_image(image: Image.Image) -> np.ndarray:
     image = image.convert("RGB").resize((512, 512))
-    image = np.array(image).astype(np.float32) / 255
-    image = image.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
-    return image
-
-
-def download(url: str, save_path: str, md5: str, length: str) -> bool:
-    """
-    Download a file from url to save_path.
-    If the file already exists, check its md5.
- If the md5 matches, return True,if the md5 doesn't match, return False. - :param url: the url of the file to download - :param save_path: the path to save the file - :param md5: the md5 of the file - :param length: the length of the file - :return: True if the file is downloaded successfully, False otherwise - """ - - try: - response = requests.get(url=url, stream=True) - with open(save_path, "wb") as f: - with tqdm.wrapattr( - response.raw, "read", total=length, desc="Downloading" - ) as r_raw: - shutil.copyfileobj(r_raw, f) - return ( - True - if hashlib.md5(open(save_path, "rb").read()).hexdigest() == md5 - else False - ) - except Exception as e: - print(e) - return False + imagenp = np.array(image).astype(np.float32) / 255 + imagenp = imagenp.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1)) + return imagenp def download_model(): @@ -109,7 +79,7 @@ class DeepDanbooru: ): """ Initialize the DeepDanbooru class. - :param mode: the mode of the model, "cpu" or "gpu" or "auto" + :param mode: the mode of the model, "cpu", "cuda", "hip" or "auto" :param model_path: the path to the model file :param tags_path: the path to the tags file :param threshold: the threshold of the model @@ -119,11 +89,13 @@ class DeepDanbooru: providers = { "cpu": "CPUExecutionProvider", - "gpu": "CUDAExecutionProvider", + "cuda": "CUDAExecutionProvider", + "hip": "ROCMExecutionProvider", "tensorrt": "TensorrtExecutionProvider", "auto": ( "CUDAExecutionProvider" if "CUDAExecutionProvider" in ort.get_available_providers() + else "ROCMExecutionProvider" if "ROCMExecutionProvider" in ort.get_available_providers() else "CPUExecutionProvider" ), } @@ -166,8 +138,8 @@ class DeepDanbooru: return self.__str__() def from_image_inference(self, image: Image.Image) -> dict: - image = process_image(image) - return self.predict(image) + imagenp = process_image(image) + return self.predict(imagenp) def from_ndarray_inferece(self, image: np.ndarray) -> dict: if image.shape != (1, 512, 512, 3): @@ -177,49 +149,6 @@ class DeepDanbooru: def from_file_inference(self, image: str) -> dict: return self.from_image_inference(Image.open(image)) - def from_list_inference(self, image: Union[list, tuple]) -> List[dict]: - if self.pin_memory: - image = [process_image(Image.open(i)) for i in image] - for i in [ - image[i : i + self.batch_size] - for i in range(0, len(image), self.batch_size) - ]: - imagelist = i - bs = len(i) - _imagelist, idx, hashlist = [], [], [] - for j in range(len(i)): - img = Image.open(i[j]) if not self.pin_memory else imagelist[j] - image_hash = hashlib.md5(np.array(img).astype(np.uint8)).hexdigest() - hashlist.append(image_hash) - if image_hash in self.cache: - continue - if not self.pin_memory: - _imagelist.append(process_image(img)) - else: - _imagelist.append(imagelist[j]) - idx.append(j) - - imagelist = _imagelist - if len(imagelist) != 0: - _image = np.vstack(imagelist) - results = self.inference(_image) - results_idx = 0 - else: - results = [] - - for i in range(bs): - image_tag = {} - if i in idx: - hash = hashlist[i] - for tag, score in zip(self.tags, results[results_idx]): - if score >= self.threshold: - image_tag[tag] = score - results_idx += 1 - self.cache[hash] = image_tag - yield image_tag - else: - yield self.cache[hashlist[i]] - def inference(self, image): return self.session.run(self.output_name, {self.input_name: image})[0] @@ -236,8 +165,6 @@ class DeepDanbooru: return self.from_file_inference(image) elif isinstance(image, np.ndarray): return self.from_ndarray_inferece(image) - elif 
isinstance(image, list) or isinstance(image, tuple): - return self.from_list_inference(image) elif isinstance(image, Image.Image): return self.from_image_inference(image) else: diff --git a/DanbooruTagger/deepdanbooru_onnx/__init__.py b/DanbooruTagger/deepdanbooru_onnx/__init__.py deleted file mode 100644 index 21d7e94..0000000 --- a/DanbooruTagger/deepdanbooru_onnx/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .deepdanbooru_onnx import DeepDanbooru -from .deepdanbooru_onnx import process_image -__version__ = '0.0.8' \ No newline at end of file diff --git a/DanbooruTagger/deepdanbooru_onnx/__pycache__/__init__.cpython-312.pyc b/DanbooruTagger/deepdanbooru_onnx/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index 6edaed3d17922c1ad9e4b14ddada348199ba33c9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 306 zcmX@j%ge<81Ocy;(h7j|V-N=hn4pZ$3P8qmh7^Vr#vF!R#wbQc5St0eW{P40vYDfp zD_JyIUorxfX)@m8aY;=ra7oNd%Fiz<1@d_dit>|Fi;Lqka}(23t5^;64D>AgG+Ay5 zrT`VDKo!R4=jBxtF$0Y#VgVAjxZ~r?Qj3Z+^Yh~4S2BDC3Nic&)6dAyP1P^R$jr$s zD9{fm%11$UnKK;The(N}PRpSVAM_zvvgAYZ5#N@bC{83dw$iMY;*2EH6v@ntEQ_WR zC7VU%#)ajkyQZ52R<`MeTD!2?El?LIP`fD>0rn4Rsu3|!QDd7Qn_?GG>|nD=f9!tW z8_tlD=w9dueDjX)eeZkUd+)p6=-+KNGX>A@{(E5X(*}zAEf$oAv_Rp}N03~jSc;`X zR38m*O-M7s^f9ztr|r{`HhrHS-b~0aV(2sE+Khch(q`&2!CMlEv_sgZm4m5B3B<)iHL6mf- z!+}T`t55mFVM!N2C`r$A;wT^X(+VN*MzR_BJ#soKhfzCIv447PU$~s4SvV0dJ)z%fu`cHG!1c3Yv3SR;$1T?L}I3te!Qn zhMT#OiefM-7;uap%4>>6%^r9TMZ)JfUiA3F9>PwK81eW#;n9)kxW~uyzVRjk))Dpb zz7ZLu!y^uJd1$|kK?al8LWU?nLK-FJxnl>tnolWMblwW{`1JoE@iL#{Cg3so6xV~t zDSnD8LV&zl(!`fwg+pog8YM$+VB7#gAjV0Wa8%L}nkBtY5XYmOWEc!ZeB$<2$t3c= zun>(1PzT_{z9?)j-U#Rf9MkhvfAer;glmot2SR~pwD}Yt8G;Fq1j0kjJ%`D5JjHWS zJ|YOshkfCJNQ58l^$iVid^5{&Q8riA9|?yqtgdN_j!QP^?B@ibpRk#)hEZ{)g%L<5 zsihLiR(9p-OHU`(WvreB>xQ&-L(-hFZl64|Xf3&N{L=C1({DVpV69JE>)&gAzvJDG zjJ0XbnYOk~9?4ouuRL?$%gAvF=;29#313f4C-N{Te`8 z9FtEyf&nRs8RbKgN${QL`lEQ5m=U&J(ucV4kT~pDw=6&!;75+qN5H6Sl!+2^hlZkp zxzmdXl_w68>6S)0*Av?O(TUS6qtv%k6FLzwlP8Glu;PhpiyaqM8)sPEExoeU0gql( z4@?mE$ciVfkL#3CL8iemfrFBxdv#eXRk{bulwAJ3Y%WFxOVFK%P;xhOnVUIY1uK-6 z{9A#EYtC)N{m;@p6l?0C8mJc7lw%s2I!n>;ce&>RbCw#TSJgkS)Qgo0)<%kVPZ-1+ zr5raX6d*R%FkyJvFm96HWAqp$WAK`N-vdI)Aw3)!3&WXVK?n>6LY!xikBoR=ub^B| zcZ)0&x(10LAYi@_&-vJK4|f3swXntGALjh$JONSgz`j}r4rR}XPXqt}lf{LnmmlS} z1Y~sB&PF&Pyg`)VRH&oAkiabyM?jY>YLLK40XaY{o0|xkRZe`V6Oeo1##iR0PNfhY zXpte2r~>?jF!GU+!Q}jENJT2*JUzL261D{dl}?WJ2%~zz~V}JK?H`6NSd($NsFwPEXvA($O4V26CqP!I5Apu!gS)y|cO(42PO**;}lw3SYK z6TKN*{Su`$ZJqHgmb$OiUaieKDz90uS`&lGqjzdQXEbFcOB7?<`k=IY(NVtOs82iU zXWEl>^NuZF7@+R9VH60dilM023MJD}s+u^!v|Ul-2!^(g5v7+m!FJanbJj z{gMvKUkOjc#9nJSWTw8~cyJ$c&s1@+Rd>&0#Qeqs2B`nVHX74b-N8=NFLoNB;TQXK zhYW@f7;OE(Y&x`&`M_y8)WCeupo6?05UP)A4ZX||(s6T~0Yf3T+l4}|Se1FbMgzCagPNYg75FoIdTp_g8D*Rzt%=?liLhJ< zOmNmC`i6u8#fyfv!0}+A=k%VUNw7gtUzjg~feaHnF)WHvp`*EZcywq8RKuXp&o%iY z&Hmv)I1&mivsIfxG)F?`x#kgHAiS)l3GLe0fxj3RPOC}0AWEK=9Rpk&*S3ut@cXw7 zw6HC^TiOQ(2Di5iZriqd$L>MNux z^MyvaBRn7By%xztmaZSoaY7}@y_DnrUfhGfAh^~(i*55X(9hkZbziL{ce zS*|AZ%GH<^P#oh%%ssNsEm2s~Pdf3amyl**3S>R;QcS05Px$xkdLm&GgNm&V_?n6cEP*0!ZBZIg!=Ew(E=FYUat@6x_Rp`hIVNRFUW6?D|o_2U9yA}U?elwIk()cGATaU!YBRBT|tduOJ6SIV(#$w*ntr^izkPfF|gT!(}hXS~TXa~J3BhX^6tGi7ZlYg^Xh zR={{P@0Tff#RGcefMvQyaU?R-r!j%Iru?xkR>LtY!)YO<2Q_G7vD#%N9h7vOKG&w_ zz*B;f6{0#ih#Z+XW>CYLA&O(lmmLxjf?EC*yO9)PhzLl$tmfXGNL;rvuvEY$-&N%GokVE1;j2v#JoJ2hi#SxJrmC zxh3mi@X-~wy<5`a=H^cVHvTC{;IIs1Odt|sC37?o?jM1}%8yIt0d!OPF}C1WeN}+R zJQ_%k79oXL9L;K2hSlWFr(%AwVl}%5M8Ae zk23OM>Wa~?8KV4}7hOPuF>VajtpJ50TgsLx<6YHUF4qD)z*7Bcgabm2=v1wEg6mdP 
zKv5Z))#F|Wu3rrSc!3uAc?{moN_+mT;Ds5Pax;RxF=2|Eij4uYaFAJ;U^G?Y#I;aT zS0}E6lDZDy7Et|2GI?O9Q~m0U0}ux7SK}n(RzJjBkV9jTlALLb@;ZI5gFrPaIUaf6LL^GJJZzD-`_MyQ*^3wesOeXC!X{#5uIi)6LAj4A7MBsevQu7Sc6i2yIiQNi zpvtHK7HS~uk%j5XPI=KP4@Jrz6t%_2l3a&Qoj!7b^Fvr75{BsP`2dJ;$$jYb;RCCy zt9rSx5aD@oRg+gMF~D|+e?Vdsb6}Dd86R7lpYX`2AbJKM{DtvjTq6i*$Z=?N5o2|$ zAqO~djN`(dZN!dmiFG`UD;|!B5Y~xCBfQA5O`cOB&If$)4@V#pO=4>u9w29n2gupt zQOJqao{fz19+o>F@I$;vSdPULt9mS=?wAXQ4aqJ94n7ze4YM(4ud?I>7pa7M6cfax z2_prG=D(uN66IYZF+i(S0r8=J-+8c$eFGr~5vy?KfDi}^qA%>{Bpu<7!~kz3Q!m8Z ziA~EN1P<|gG1-U7en_O!qp(H0BjQmcM+QUNQgU>K2RWVtY&{V9!r+lXkTX1*(2@yb zTcJRh6EJQc4T1LpmqIWZ$&3^R!qHKPW@%+Zmz4TR-h!=>QIRxTmx5c~UM?m{Wiq*Qaz zNBB^-sW(!tO0I`4^hH(oJ2Nrju#s$D zli2+p2l1*OoA2!YrGI|giN&&-)S5k+vb`zm-mJwn{n9u0Cc2XJcb_KqyGGVHz9S@mBYA@1_0kH7wNcP1oN9DNEh0`r8|3H{5QTZMsuEzv=K|SwpI^BU9FyvUUOyV9?$wyInC`aeM9T+B>Fs z??D2vJ5$z?vUWVMxKozenW|J>>tbVTrm-z;t(!WLFnv-{y-=|^U9mZN;_mJbd++T} z9qgH}=$+DDwl22rm>a*_kZC=-(0VG}dMeZU+()jSg{s}($zZ_syoxwozQI0?&$bQ=!c<&9o^|2-I*PyKB{>( zc{FkK#>ux%-miIfN;_|?%vRL{WJ$KVk#xwtT$PAD>#ms6e#`vj!&1uKOMm%LPpx^5 z7BFV@%RN=c8nyqn&Up+jt?(*v*_|X-5VJ7~88jO~Q9^Fz116IOT5?wug{{!SKueK` z9JGKMEOs{uExLE$Qo*b5j+GT|#f_+h_QtFrbrhk$H)h*{B0FdQl44#?@BNDQ!nE+b ztT+*~G|xk=e32&mk%s<$@R?!zy9&pI?N4g|388x-nWVC%4Ks(|?|QdurZXA3)AsXy zKi!w=IPu}u)S0g@oEb`=8A?4roEjcYosWG%!5zp^&1dp!iK5Qa$29Odra6u;M;E51 zE{*IR!O>D5WF=vIJ!Bqt8LZeMJ=81>SJ3=8!15O$DV&zEzaK9FiyUjY`uhlv<}FOj z?z|Ge6o2F8l(CkOL@pcsh0BUG5xP1cbFIL}j;j~b6SzqbX2p|-=%M(F&r=*hne~NJ#Jx(U_&1ON*3-pyJWG?vT}eLb3=S2z4)3UDKR6;(P<505RqiQd zgO}-s?TOl%TpP=Y1B$T{a((MjRJxaubmRh9R#Y!zk5?cO3fbf)JP5PoIyuKEKNqWC z&8L+xw-8k;B(Mi84h&SJsyE%YYyydzcsbRyFST`ls#X!o@+Nsa|8UNPe%o8vUN44RzF;6X<$&x5?f zTudJ3T}g!-;8xUYsRP<0BG85Q1S6uT4`)ob^l>c-{!Dzqr?pz6RCz~r(M=C~oK z-iGAc;{`VW5cQ{u1*RZKLtnn}7{1lSb;a=JcVPamMxm3Y!mi@I6rScG@+Y5HJO;QV-6%L!yKF zh?QcdAN$I<`@tiFyAcFuFsl>%5rD$8ocD~BoNOV=ss)Mo!>VOD#E$`^Pz`(KKdJvk z{c69m|C{@>olkz`Ixy$^&WrE08%}azW9(b+tyCg*{0Upk$;JNhr1TM z8k{MeX?fQ*U%no^nP>0qnPwJi8*jXD{RME1DrYO_8gAAuc=x8gdw(|Y^WmQkXS~NU zwZ|7~pG((1m#OWYHsiMo_N(?8^Sono)>%8_dcXSJ>a=rn*0m;EQFHB;Z@-dttzB@f zOS{&+H!^o9vwqjFT)Q7ARdC-xzZ;r&ZO=NK*UVSViI#t`ExCcnrAn%*0cK<`o8JA# z%QL0-?e&XQH8-lRSIz9sRBfHsXTg7Jy55v5&s1-lHe}a6anrjsc_H22m05oRK6!9&ziv)2 ziPPB{4_N!%?{;TuTC-bQ;XcOy-EQ~5=efZjmt0xj7l(W?h zZ$5(_-KOiCX3ivi@4kReovY1alkcX2>&fi@)A&=mw@EsP>@#qbp&|} zleZxeaCu~bM8pw~Ucd@xHRnnF2ZHxP0$hbpiq>9TJEO~#uAkCpUFB1z#p>Dwdn0f? zkSw30XR9FC0NdTP;nInzL(}f8-IdVZD7jvev9Hh8t)Dg}8opfu+ugJQw!5P$WvpDy ze6W3Gi+>pO>vE@1%;`w5$#OX!^&>9ViYLD~@U4RC3diaE1yivvbe3UNZO}aG$6{5j zVg-ZXST?G9;iI$@x=ZtzmE2=T$XkVKPU*9P8|w8?zI_>1^{OEM_4jMESJNHawcJb0 zeg236bPWG4`y?7GP%kx$qtOuOwekZn7~hG>-(!NhTr%VwL;mluq$TQBM>_EDAP|Ge zBqj$jS$WbWh&A|qp-_=ngIo=!Jte#W2{`0%H_W}=GjFNOIx7~On?Xe+k7S&!DQoMZ ztsYW&f7o_ z0g5eY=a%H@jB|U+y8RQpEq^{!b|z&#L$Mb)q35CcL1omJC>Bq+C@AXM#;srJW5+Caf}Jm3xLs zfCg_|vS5{!a#t0E-Id*)n-V9%1?t7}cVlj7Xz}TkQdcCe{$zj%=rDxnTVzPIxakekDVQ0E-*XI=0$_zs` zc`GCfei;NxOAL%~{NF;I#GLASN?sft`40ec9OU}|EwuA!&yz1@Bt3ki51%B-YeSYC zwHL&&;LzcBrg+iLlNBX)G&x{5v2G_OPh&z>><5^u!2~tTSH?Z^*4oGaD-==u3FvSy zF*HqoOjZ4wa(_&feN351#{Mx?@*8UH$CT~Y6fC8gE={;^)LgGwq9B_r!>Gk0%q~?| t)9&epB??|M@+*0Wl bool: + """ + Download a file from url to save_path. + If the file already exists, check its md5. + If the md5 matches, return True,if the md5 doesn't match, return False. 
+    :param url: the url of the file to download
+    :param save_path: the path to save the file
+    :param md5: the md5 of the file
+    :param length: the length of the file
+    :return: True if the file is downloaded successfully, False otherwise
+    """
+
+    try:
+        response = requests.get(url=url, stream=True)
+        with open(save_path, "wb") as f:
+            with tqdm.wrapattr(
+                response.raw, "read", total=length, desc="Downloading"
+            ) as r_raw:
+                shutil.copyfileobj(r_raw, f)
+        return (
+            True
+            if hashlib.md5(open(save_path, "rb").read()).hexdigest() == md5
+            else False
+        )
+    except Exception as e:
+        print(e)
+        return False
diff --git a/DanbooruTagger/wd_onnx.py b/DanbooruTagger/wd_onnx.py
new file mode 100644
index 0000000..3025dde
--- /dev/null
+++ b/DanbooruTagger/wd_onnx.py
@@ -0,0 +1,196 @@
+import onnxruntime as ort
+from PIL import Image
+import numpy as np
+import os
+import hashlib
+from typing import List, Union
+from pathlib import Path
+import csv
+
+from utils import download
+
+
+def process_image(image: Image.Image, target_size: int) -> np.ndarray:
+    canvas = Image.new("RGBA", image.size, (255, 255, 255))
+    canvas.paste(image, mask=image.split()[3] if image.mode == 'RGBA' else None)
+    image = canvas.convert("RGB")
+
+    # Pad image to a square
+    max_dim = max(image.size)
+    pad_left = (max_dim - image.size[0]) // 2
+    pad_top = (max_dim - image.size[1]) // 2
+    padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
+    padded_image.paste(image, (pad_left, pad_top))
+
+    # Resize
+    padded_image = padded_image.resize((target_size, target_size), Image.Resampling.BICUBIC)
+
+    # Convert to numpy array
+    image_array = np.asarray(padded_image, dtype=np.float32)[..., [2, 1, 0]]
+
+    return np.expand_dims(image_array, axis=0)
+
+
+def download_model():
+    """
+    Download the model and tags file from the server.
+    :return: the path to the model and tags file
+    """
+
+    model_url = (
+        "https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/resolve/main/model.onnx"
+    )
+    tags_url = "https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/resolve/main/selected_tags.csv"
+    model_md5 = "1fc4f456261c457a08d4b9e3379cac39"
+    tags_md5 = "55c11e40f95e63ea9ac21a065a73fd0f"
+    model_length = 378536310
+    tags_length = 308468
+
+    home = str(Path.home()) + "/.wd_onnx/"
+    if not os.path.exists(home):
+        os.mkdir(home)
+
+    model_name = "wd.onnx"
+    tags_name = "selected_tags.csv"
+
+    model_path = home + model_name
+    tags_path = home + tags_name
+    if os.path.exists(model_path):
+        if hashlib.md5(open(model_path, "rb").read()).hexdigest() != model_md5:
+            os.remove(model_path)
+            if not download(model_url, model_path, model_md5, model_length):
+                raise ValueError("Model download failed")
+
+    else:
+        if not download(model_url, model_path, model_md5, model_length):
+            raise ValueError("Model download failed")
+
+    if os.path.exists(tags_path):
+        if hashlib.md5(open(tags_path, "rb").read()).hexdigest() != tags_md5:
+            os.remove(tags_path)
+            if not download(tags_url, tags_path, tags_md5, tags_length):
+                raise ValueError("Tags download failed")
+    else:
+        if not download(tags_url, tags_path, tags_md5, tags_length):
+            raise ValueError("Tags download failed")
+    return model_path, tags_path
+
+
+class Wd:
+    def __init__(
+        self,
+        mode: str = "auto",
+        model_path: Union[str, None] = None,
+        tags_path: Union[str, None] = None,
+        threshold: Union[float, int] = 0.6,
+        pin_memory: bool = False,
+        batch_size: int = 1,
+    ):
+        """
+        Initialize the Wd class.
+        :param mode: the mode of the model, "cpu", "cuda", "hip" or "auto"
+        :param model_path: the path to the model file
+        :param tags_path: the path to the tags file
+        :param threshold: the threshold of the model
+        :param pin_memory: whether to use pin memory
+        :param batch_size: the batch size of the model
+        """
+
+        providers = {
+            "cpu": "CPUExecutionProvider",
+            "cuda": "CUDAExecutionProvider",
+            "hip": "ROCMExecutionProvider",
+            "tensorrt": "TensorrtExecutionProvider",
+            "auto": (
+                "CUDAExecutionProvider"
+                if "CUDAExecutionProvider" in ort.get_available_providers()
+                else "ROCMExecutionProvider" if "ROCMExecutionProvider" in ort.get_available_providers()
+                else "CPUExecutionProvider"
+            ),
+        }
+
+        if not (isinstance(threshold, float) or isinstance(threshold, int)):
+            raise TypeError("threshold must be float or int")
+        if threshold < 0 or threshold > 1:
+            raise ValueError("threshold must be between 0 and 1")
+        if mode not in providers:
+            raise ValueError(
+                "Mode not supported. Please choose from: cpu, cuda, hip, tensorrt, auto"
+            )
+        if providers[mode] not in ort.get_available_providers():
+            raise ValueError(
+                f"The {mode} provider is not available on this device. Please use cpu instead"
+            )
+        if model_path is not None and not os.path.exists(model_path):
+            raise FileNotFoundError("Model file not found")
+        if tags_path is not None and not os.path.exists(tags_path):
+            raise FileNotFoundError("Tags file not found")
+
+        if model_path is None or tags_path is None:
+            model_path, tags_path = download_model()
+
+        self.session = ort.InferenceSession(model_path, providers=[providers[mode]])
+        self.tags = list()
+        with open(tags_path, "r") as tagfile:
+            reader = csv.DictReader(tagfile)
+            for row in reader:
+                self.tags.append(row)
+
+        self.input_name = self.session.get_inputs()[0].name
+        self.output_name = [output.name for output in self.session.get_outputs()]
+        self.threshold = threshold
+        self.pin_memory = pin_memory
+        self.batch_size = batch_size
+        self.target_size = self.session.get_inputs()[0].shape[2]
+        self.mode = mode
+        self.cache = {}
+
+    def __str__(self) -> str:
+        return f"Wd(mode={self.mode}, threshold={self.threshold}, pin_memory={self.pin_memory}, batch_size={self.batch_size})"
+
+    def __repr__(self) -> str:
+        return self.__str__()
+
+    def from_image_inference(self, image: Image.Image) -> dict:
+        imagenp = process_image(image, self.target_size)
+        return self.predict(imagenp)
+
+    def from_ndarray_inference(self, image: np.ndarray) -> dict:
+        if image.shape != (1, self.target_size, self.target_size, 3):
+            raise ValueError(f"Image must be {(1, self.target_size, self.target_size, 3)}")
+        return self.predict(image)
+
+    def from_file_inference(self, path: str) -> dict:
+        image = Image.open(path)
+        return self.from_image_inference(image)
+
+    def inference(self, image):
+        return self.session.run(self.output_name, {self.input_name: image})[0]
+
+    def predict(self, image):
+        result = self.inference(image)
+        tags = self.tags
+        for tag, score in zip(tags, result[0]):
+            tag['score'] = float(score)
+        ratings = tags[:4]
+        tags = tags[4:]
+
+        image_tags = {}
+
+        rating = max(ratings, key=lambda el: el['score'])
+        image_tags[rating['name']] = rating['score']
+
+        for tag in tags:
+            if tag['score'] >= self.threshold:
+                image_tags[tag['name']] = tag['score']
+        return image_tags
+
+    def __call__(self, image) -> Union[dict, List[dict]]:
+        if isinstance(image, str):
+            return self.from_file_inference(image)
+        elif isinstance(image, np.ndarray):
+            return self.from_ndarray_inference(image)
+        elif isinstance(image, Image.Image):
+            return self.from_image_inference(image)
+        else:
raise ValueError("Image must be a file path or a numpy array or list/tuple")