diff --git a/DanbooruTagger/DanbooruTagger.py b/DanbooruTagger/DanbooruTagger.py
index 0ebeee1..1c7a217 100644
--- a/DanbooruTagger/DanbooruTagger.py
+++ b/DanbooruTagger/DanbooruTagger.py
@@ -1,5 +1,5 @@
-import warnings
 from deepdanbooru_onnx import DeepDanbooru
+from wd_onnx import Wd
 from PIL import Image
 import argparse
 import cv2
@@ -13,13 +13,20 @@ image_ext_ocv = [".bmp", ".jpeg", ".jpg", ".png"]
 
 
 def find_image_files(path: str) -> list[str]:
-    paths = list()
-    for root, dirs, files in os.walk(path):
-        for filename in files:
-            name, extension = os.path.splitext(filename)
-            if extension.lower() in image_ext_ocv:
-                paths.append(os.path.join(root, filename))
-    return paths
+    if os.path.isdir(path):
+        paths = list()
+        for root, dirs, files in os.walk(path):
+            for filename in files:
+                name, extension = os.path.splitext(filename)
+                if extension.lower() in image_ext_ocv:
+                    paths.append(os.path.join(root, filename))
+        return paths
+    else:
+        name, extension = os.path.splitext(path)
+        if extension.lower() in image_ext_ocv:
+            return [path]
+        else:
+            return []
 
 
 def image_loader(paths: list[str]):
@@ -35,14 +42,28 @@ def image_loader(paths: list[str]):
             yield image_pil, path
 
 
-def pipeline(queue: Queue, image_paths: list[str], device: int):
-    danbooru = DeepDanbooru()
+def danbooru_pipeline(queue: Queue, image_paths: list[str], device: int, cpu: bool):
+    danbooru = DeepDanbooru("cpu" if cpu else "auto")
 
     for path in image_paths:
         imageprompt = ""
         tags = danbooru(path)
         for tag in tags:
             imageprompt = imageprompt + ", " + tag
+        imageprompt = imageprompt[2:]
+
+        queue.put({"file_name": path, "text": imageprompt})
+
+
+def wd_pipeline(queue: Queue, image_paths: list[str], device: int, cpu: bool):
+    wd = Wd("cpu" if cpu else "auto", threshold=0.3)
+
+    for path in image_paths:
+        imageprompt = ""
+        tags = wd(path)
+        for tag in tags:
+            imageprompt = imageprompt + ", " + tag
+        imageprompt = imageprompt[2:]
 
         queue.put({"file_name": path, "text": imageprompt})
 
@@ -57,49 +78,71 @@ def split_list(input_list, count):
 def save_meta(meta_file, meta, reldir, common_description):
     meta["file_name"] = os.path.relpath(meta["file_name"], reldir)
     if common_description is not None:
-        meta["text"] = common_description + meta["text"]
+        meta["text"] = common_description + ", " + meta["text"]
     meta_file.write(json.dumps(meta) + '\n')
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("A script to tag images via DeepDanbooru")
-    parser.add_argument('--batch', '-b', default=4, type=int, help="Batch size to use for inference")
     parser.add_argument('--common_description', '-c', help="An optional description that will be preended to the ai generated one")
-    parser.add_argument('--image_dir', '-i', help="A directory containg the images to tag")
+    parser.add_argument('--image_dir', '-i', help="A directory containing the images to tag, or a single image to tag")
+    parser.add_argument('--wd', '-w', action="store_true", help="use the wd tagger instead of DeepDanbooru")
+    parser.add_argument('--cpu', action="store_true", help="force cpu usage instead of gpu")
    args = parser.parse_args()
 
-    nparalell = 2
-
     image_paths = find_image_files(args.image_dir)
+
+    if len(image_paths) == 0:
+        print(f"Unable to find any images at {args.image_dir}")
+        exit(1)
+
+    nparalell = 4 if len(image_paths) > 4 else len(image_paths)
     image_path_chunks = list(split_list(image_paths, nparalell))
 
-    print(f"Will use {nparalell} processies to create tags")
+    print(f"Will use {nparalell} processes to create tags for {len(image_paths)} images")
 
     queue = Queue()
+    pipe = danbooru_pipeline if not args.wd else wd_pipeline
     processies = list()
     for i in range(0, nparalell):
-        processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], i)))
+        processies.append(Process(target=pipe, args=(queue, image_path_chunks[i], i, args.cpu)))
         processies[-1].start()
 
     progress = tqdm(desc="Generateing tags", total=len(image_paths))
     exit = False
-    with open(os.path.join(args.image_dir, "metadata.jsonl"), mode='w') as output_file:
+
+    if len(image_paths) > 1:
+        with open(os.path.join(args.image_dir, "metadata.jsonl"), mode='w') as output_file:
+            while not exit:
+                if not queue.empty():
+                    meta = queue.get()
+                    save_meta(output_file, meta, args.image_dir, args.common_description)
+                    progress.update()
+                exit = True
+                for process in processies:
+                    if process.is_alive():
+                        exit = False
+                        break
+
+            while not queue.empty():
+                meta = queue.get()
+                save_meta(output_file, meta, args.image_dir, args.common_description)
+                progress.update()
+    else:
         while not exit:
             if not queue.empty():
                 meta = queue.get()
-                save_meta(output_file, meta, args.image_dir, args.common_description)
+                print(meta)
                 progress.update()
             exit = True
             for process in processies:
                 if process.is_alive():
                     exit = False
                     break
-
         while not queue.empty():
             meta = queue.get()
-            save_meta(output_file, meta, args.image_dir, args.common_description)
+            print(meta)
             progress.update()
 
     for process in processies:
         process.join()
-
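For reference, a minimal sketch (not part of the patch) of how the two pipeline entry points above are meant to be driven; it assumes the script's directory is on the import path and that "example.png" exists:

    from multiprocessing import Process, Queue

    from DanbooruTagger import wd_pipeline  # danbooru_pipeline has the same signature

    # Each worker tags its chunk of paths and pushes one
    # {"file_name": ..., "text": ...} dict per image onto the queue.
    queue = Queue()
    worker = Process(target=wd_pipeline, args=(queue, ["example.png"], 0, True))
    worker.start()
    while worker.is_alive() or not queue.empty():
        if not queue.empty():
            print(queue.get())
    worker.join()

The polling loop mirrors the one in the script: the queue is drained while the workers are still alive, so a backed-up queue cannot keep a worker from exiting.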
diff --git a/DanbooruTagger/deepdanbooru_onnx/deepdanbooru_onnx.py b/DanbooruTagger/deepdanbooru_onnx.py
similarity index 62%
rename from DanbooruTagger/deepdanbooru_onnx/deepdanbooru_onnx.py
rename to DanbooruTagger/deepdanbooru_onnx.py
index c108ceb..572d60a 100644
--- a/DanbooruTagger/deepdanbooru_onnx/deepdanbooru_onnx.py
+++ b/DanbooruTagger/deepdanbooru_onnx.py
@@ -2,13 +2,12 @@ import onnxruntime as ort
 from PIL import Image
 import numpy as np
 import os
-from tqdm import tqdm
-import requests
 import hashlib
 from typing import List, Union
-import shutil
 from pathlib import Path
 
+from utils import download
+
 
 def process_image(image: Image.Image) -> np.ndarray:
     """
@@ -18,38 +17,9 @@ def process_image(image: Image.Image) -> np.ndarray:
     """
 
     image = image.convert("RGB").resize((512, 512))
-    image = np.array(image).astype(np.float32) / 255
-    image = image.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
-    return image
-
-
-def download(url: str, save_path: str, md5: str, length: str) -> bool:
-    """
-    Download a file from url to save_path.
-    If the file already exists, check its md5.
-    If the md5 matches, return True,if the md5 doesn't match, return False.
-    :param url: the url of the file to download
-    :param save_path: the path to save the file
-    :param md5: the md5 of the file
-    :param length: the length of the file
-    :return: True if the file is downloaded successfully, False otherwise
-    """
-
-    try:
-        response = requests.get(url=url, stream=True)
-        with open(save_path, "wb") as f:
-            with tqdm.wrapattr(
-                response.raw, "read", total=length, desc="Downloading"
-            ) as r_raw:
-                shutil.copyfileobj(r_raw, f)
-        return (
-            True
-            if hashlib.md5(open(save_path, "rb").read()).hexdigest() == md5
-            else False
-        )
-    except Exception as e:
-        print(e)
-        return False
+    imagenp = np.array(image).astype(np.float32) / 255
+    imagenp = imagenp.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
+    return imagenp
 
 
 def download_model():
@@ -109,7 +79,7 @@ class DeepDanbooru:
     ):
         """
         Initialize the DeepDanbooru class.
-        :param mode: the mode of the model, "cpu" or "gpu" or "auto"
+        :param mode: the mode of the model, "cpu", "cuda", "hip" or "auto"
         :param model_path: the path to the model file
         :param tags_path: the path to the tags file
         :param threshold: the threshold of the model
@@ -119,11 +89,13 @@ class DeepDanbooru:
 
         providers = {
             "cpu": "CPUExecutionProvider",
-            "gpu": "CUDAExecutionProvider",
+            "cuda": "CUDAExecutionProvider",
+            "hip": "ROCMExecutionProvider",
             "tensorrt": "TensorrtExecutionProvider",
             "auto": (
                 "CUDAExecutionProvider"
                 if "CUDAExecutionProvider" in ort.get_available_providers()
+                else "ROCMExecutionProvider" if "ROCMExecutionProvider" in ort.get_available_providers()
                 else "CPUExecutionProvider"
             ),
         }
@@ -166,8 +138,8 @@ class DeepDanbooru:
         return self.__str__()
 
     def from_image_inference(self, image: Image.Image) -> dict:
-        image = process_image(image)
-        return self.predict(image)
+        imagenp = process_image(image)
+        return self.predict(imagenp)
 
     def from_ndarray_inferece(self, image: np.ndarray) -> dict:
         if image.shape != (1, 512, 512, 3):
@@ -177,49 +149,6 @@ class DeepDanbooru:
     def from_file_inference(self, image: str) -> dict:
         return self.from_image_inference(Image.open(image))
 
-    def from_list_inference(self, image: Union[list, tuple]) -> List[dict]:
-        if self.pin_memory:
-            image = [process_image(Image.open(i)) for i in image]
-        for i in [
-            image[i : i + self.batch_size]
-            for i in range(0, len(image), self.batch_size)
-        ]:
-            imagelist = i
-            bs = len(i)
-            _imagelist, idx, hashlist = [], [], []
-            for j in range(len(i)):
-                img = Image.open(i[j]) if not self.pin_memory else imagelist[j]
-                image_hash = hashlib.md5(np.array(img).astype(np.uint8)).hexdigest()
-                hashlist.append(image_hash)
-                if image_hash in self.cache:
-                    continue
-                if not self.pin_memory:
-                    _imagelist.append(process_image(img))
-                else:
-                    _imagelist.append(imagelist[j])
-                idx.append(j)
-
-            imagelist = _imagelist
-            if len(imagelist) != 0:
-                _image = np.vstack(imagelist)
-                results = self.inference(_image)
-                results_idx = 0
-            else:
-                results = []
-
-            for i in range(bs):
-                image_tag = {}
-                if i in idx:
-                    hash = hashlist[i]
-                    for tag, score in zip(self.tags, results[results_idx]):
-                        if score >= self.threshold:
-                            image_tag[tag] = score
-                    results_idx += 1
-                    self.cache[hash] = image_tag
-                    yield image_tag
-                else:
-                    yield self.cache[hashlist[i]]
-
     def inference(self, image):
         return self.session.run(self.output_name, {self.input_name: image})[0]
 
@@ -236,8 +165,6 @@ class DeepDanbooru:
             return self.from_file_inference(image)
         elif isinstance(image, np.ndarray):
             return self.from_ndarray_inferece(image)
-        elif isinstance(image, list) or isinstance(image, tuple):
-            return self.from_list_inference(image)
         elif isinstance(image, Image.Image):
             return self.from_image_inference(image)
         else:
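From the caller's side, the renamed device modes look like this ("example.png" is a placeholder):

    from deepdanbooru_onnx import DeepDanbooru

    # "auto" now tries CUDAExecutionProvider, then ROCMExecutionProvider,
    # and finally falls back to the CPU provider; "hip" selects ROCm explicitly.
    danbooru = DeepDanbooru("auto")
    tags = danbooru("example.png")  # dict of {tag: score} at or above the threshold
    print(", ".join(tags))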
diff --git a/DanbooruTagger/deepdanbooru_onnx/__init__.py b/DanbooruTagger/deepdanbooru_onnx/__init__.py
deleted file mode 100644
index 21d7e94..0000000
--- a/DanbooruTagger/deepdanbooru_onnx/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .deepdanbooru_onnx import DeepDanbooru
-from .deepdanbooru_onnx import process_image
-__version__ = '0.0.8'
\ No newline at end of file
diff --git a/DanbooruTagger/deepdanbooru_onnx/__pycache__/__init__.cpython-312.pyc b/DanbooruTagger/deepdanbooru_onnx/__pycache__/__init__.cpython-312.pyc
deleted file mode 100644
index 6edaed3..0000000
Binary files a/DanbooruTagger/deepdanbooru_onnx/__pycache__/__init__.cpython-312.pyc and /dev/null differ
diff --git a/DanbooruTagger/deepdanbooru_onnx/__pycache__/deepdanbooru_onnx.cpython-312.pyc b/DanbooruTagger/deepdanbooru_onnx/__pycache__/deepdanbooru_onnx.cpython-312.pyc
deleted file mode 100644
index e5960a9..0000000
Binary files a/DanbooruTagger/deepdanbooru_onnx/__pycache__/deepdanbooru_onnx.cpython-312.pyc and /dev/null differ
diff --git a/DanbooruTagger/utils.py b/DanbooruTagger/utils.py
new file mode 100644
index 0000000..d0946f2
--- /dev/null
+++ b/DanbooruTagger/utils.py
@@ -0,0 +1,33 @@
+import requests
+import shutil
+import hashlib
+from tqdm import tqdm
+
+
+def download(url: str, save_path: str, md5: str, length: int) -> bool:
+    """
+    Download a file from url to save_path.
+    If the file already exists, check its md5.
+    If the md5 matches, return True; if the md5 doesn't match, return False.
+    :param url: the url of the file to download
+    :param save_path: the path to save the file
+    :param md5: the md5 of the file
+    :param length: the length of the file in bytes
+    :return: True if the file is downloaded successfully, False otherwise
+    """
+
+    try:
+        response = requests.get(url=url, stream=True)
+        with open(save_path, "wb") as f:
+            with tqdm.wrapattr(
+                response.raw, "read", total=length, desc="Downloading"
+            ) as r_raw:
+                shutil.copyfileobj(r_raw, f)
+        return (
+            True
+            if hashlib.md5(open(save_path, "rb").read()).hexdigest() == md5
+            else False
+        )
+    except Exception as e:
+        print(e)
+        return False
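The helper moves here with its behavior unchanged; a sketch of the call pattern the model downloaders rely on (URL, path, digest and length below are placeholders):

    from utils import download

    # Streams the URL to disk with a tqdm progress bar and returns True only
    # if the MD5 of the written file matches the expected digest.
    ok = download("https://example.com/model.onnx", "/tmp/model.onnx",
                  "0123456789abcdef0123456789abcdef", 1024)
    if not ok:
        raise ValueError("Model download failed")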
diff --git a/DanbooruTagger/wd_onnx.py b/DanbooruTagger/wd_onnx.py
new file mode 100644
index 0000000..3025dde
--- /dev/null
+++ b/DanbooruTagger/wd_onnx.py
@@ -0,0 +1,196 @@
+import onnxruntime as ort
+from PIL import Image
+import numpy as np
+import os
+import hashlib
+from typing import List, Union
+from pathlib import Path
+import csv
+
+from utils import download
+
+
+def process_image(image: Image.Image, target_size: int) -> np.ndarray:
+    canvas = Image.new("RGBA", image.size, (255, 255, 255))
+    canvas.paste(image, mask=image.split()[3] if image.mode == 'RGBA' else None)
+    image = canvas.convert("RGB")
+
+    # Pad image to a square
+    max_dim = max(image.size)
+    pad_left = (max_dim - image.size[0]) // 2
+    pad_top = (max_dim - image.size[1]) // 2
+    padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
+    padded_image.paste(image, (pad_left, pad_top))
+
+    # Resize
+    padded_image = padded_image.resize((target_size, target_size), Image.Resampling.BICUBIC)
+
+    # Convert to numpy array (the model expects BGR channel order)
+    image_array = np.asarray(padded_image, dtype=np.float32)[..., [2, 1, 0]]
+
+    return np.expand_dims(image_array, axis=0)
+
+
+def download_model():
+    """
+    Download the model and tags file from the server.
+    :return: the path to the model and tags file
+    """
+
+    model_url = (
+        "https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/resolve/main/model.onnx"
+    )
+    tags_url = "https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/resolve/main/selected_tags.csv"
+    model_md5 = "1fc4f456261c457a08d4b9e3379cac39"
+    tags_md5 = "55c11e40f95e63ea9ac21a065a73fd0f"
+    model_length = 378536310
+    tags_length = 308468
+
+    home = str(Path.home()) + "/.wd_onnx/"
+    if not os.path.exists(home):
+        os.mkdir(home)
+
+    model_name = "wd.onnx"
+    tags_name = "selected_tags.csv"
+
+    model_path = home + model_name
+    tags_path = home + tags_name
+    if os.path.exists(model_path):
+        if hashlib.md5(open(model_path, "rb").read()).hexdigest() != model_md5:
+            os.remove(model_path)
+            if not download(model_url, model_path, model_md5, model_length):
+                raise ValueError("Model download failed")
+
+    else:
+        if not download(model_url, model_path, model_md5, model_length):
+            raise ValueError("Model download failed")
+
+    if os.path.exists(tags_path):
+        if hashlib.md5(open(tags_path, "rb").read()).hexdigest() != tags_md5:
+            os.remove(tags_path)
+            if not download(tags_url, tags_path, tags_md5, tags_length):
+                raise ValueError("Tags download failed")
+    else:
+        if not download(tags_url, tags_path, tags_md5, tags_length):
+            raise ValueError("Tags download failed")
+    return model_path, tags_path
+
+
+class Wd:
+    def __init__(
+        self,
+        mode: str = "auto",
+        model_path: Union[str, None] = None,
+        tags_path: Union[str, None] = None,
+        threshold: Union[float, int] = 0.6,
+        pin_memory: bool = False,
+        batch_size: int = 1,
+    ):
+        """
+        Initialize the Wd class.
+        :param mode: the mode of the model, "cpu", "cuda", "hip" or "auto"
+        :param model_path: the path to the model file
+        :param tags_path: the path to the tags file
+        :param threshold: the threshold of the model
+        :param pin_memory: whether to use pin memory
+        :param batch_size: the batch size of the model
+        """
+
+        providers = {
+            "cpu": "CPUExecutionProvider",
+            "cuda": "CUDAExecutionProvider",
+            "hip": "ROCMExecutionProvider",
+            "tensorrt": "TensorrtExecutionProvider",
+            "auto": (
+                "CUDAExecutionProvider"
+                if "CUDAExecutionProvider" in ort.get_available_providers()
+                else "ROCMExecutionProvider" if "ROCMExecutionProvider" in ort.get_available_providers()
+                else "CPUExecutionProvider"
+            ),
+        }
+
+        if not (isinstance(threshold, float) or isinstance(threshold, int)):
+            raise TypeError("threshold must be float or int")
+        if threshold < 0 or threshold > 1:
+            raise ValueError("threshold must be between 0 and 1")
+        if mode not in providers:
+            raise ValueError(
+                "Mode not supported. Please choose from: cpu, cuda, hip, tensorrt, auto"
+            )
+        if providers[mode] not in ort.get_available_providers():
+            raise ValueError(
+                f"Your device does not support {mode}. Please choose from: cpu"
+            )
+        if model_path is not None and not os.path.exists(model_path):
+            raise FileNotFoundError("Model file not found")
+        if tags_path is not None and not os.path.exists(tags_path):
+            raise FileNotFoundError("Tags file not found")
+
+        if model_path is None or tags_path is None:
+            model_path, tags_path = download_model()
+
+        self.session = ort.InferenceSession(model_path, providers=[providers[mode]])
+        self.tags = list()
+        with open(tags_path, "r") as tagfile:
+            reader = csv.DictReader(tagfile)
+            for row in reader:
+                self.tags.append(row)
+
+        self.input_name = self.session.get_inputs()[0].name
+        self.output_name = [output.name for output in self.session.get_outputs()]
+        self.threshold = threshold
+        self.pin_memory = pin_memory
+        self.batch_size = batch_size
+        self.target_size = self.session.get_inputs()[0].shape[2]
+        self.mode = mode
+        self.cache = {}
+
+    def __str__(self) -> str:
+        return f"Wd(mode={self.mode}, threshold={self.threshold}, pin_memory={self.pin_memory}, batch_size={self.batch_size})"
+
+    def __repr__(self) -> str:
+        return self.__str__()
+
+    def from_image_inference(self, image: Image.Image) -> dict:
+        imagenp = process_image(image, self.target_size)
+        return self.predict(imagenp)
+
+    def from_ndarray_inference(self, image: np.ndarray) -> dict:
+        if image.shape != (1, self.target_size, self.target_size, 3):
+            raise ValueError(f"Image must be {(1, self.target_size, self.target_size, 3)}")
+        return self.predict(image)
+
+    def from_file_inference(self, path: str) -> dict:
+        image = Image.open(path)
+        return self.from_image_inference(image)
+
+    def inference(self, image):
+        return self.session.run(self.output_name, {self.input_name: image})[0]
+
+    def predict(self, image):
+        result = self.inference(image)
+        tags = self.tags
+        for tag, score in zip(tags, result[0]):
+            tag['score'] = float(score)
+        ratings = tags[:4]
+        tags = tags[4:]
+
+        image_tags = {}
+
+        rating = max(ratings, key=lambda el: el['score'])
+        image_tags[rating['name']] = rating['score']
+
+        for tag in tags:
+            if tag['score'] >= self.threshold:
+                image_tags[tag['name']] = tag['score']
+        return image_tags
+
+    def __call__(self, image) -> Union[dict, List[dict]]:
+        if isinstance(image, str):
+            return self.from_file_inference(image)
+        elif isinstance(image, np.ndarray):
+            return self.from_ndarray_inference(image)
+        elif isinstance(image, Image.Image):
+            return self.from_image_inference(image)
+        else:
+            raise ValueError("Image must be a file path, a numpy array or a PIL image")
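Finally, a minimal sketch of the new Wd tagger used on its own, mirroring what wd_pipeline does per image ("example.png" is a placeholder):

    from wd_onnx import Wd

    # Downloads wd-vit-tagger-v3 into ~/.wd_onnx/ on first use.
    wd = Wd("cpu", threshold=0.3)
    tags = wd("example.png")  # {rating: score, tag: score, ...}
    print(", ".join(tags))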