diff --git a/DanbooruTagger/DanbooruTagger.py b/DanbooruTagger/DanbooruTagger.py
new file mode 100644
index 0000000..4c138d4
--- /dev/null
+++ b/DanbooruTagger/DanbooruTagger.py
@@ -0,0 +1,141 @@
+import warnings
+import argparse
+import cv2
+import torch
+import os
+import numpy
+import json
+from typing import Iterator
+from torch.multiprocessing import Process, Queue
+from transformers import LlavaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig, logging
+from tqdm import tqdm
+
+
+image_ext_ocv = [".bmp", ".jpeg", ".jpg", ".png"]
+
+
+def find_image_files(path: str) -> list[str]:
+    paths = list()
+    for root, dirs, files in os.walk(path):
+        for filename in files:
+            name, extension = os.path.splitext(filename)
+            if extension.lower() in image_ext_ocv:
+                paths.append(os.path.join(root, filename))
+    return paths
+
+
+def image_loader(paths: list[str]) -> Iterator[tuple[numpy.ndarray, str]]:
+    for path in paths:
+        imagebgr = cv2.imread(path)
+        if imagebgr is None:
+            print(f"Warning: could not load {path}")
+            continue
+        yield cv2.cvtColor(imagebgr, cv2.COLOR_BGR2RGB), path
+
+
+def pipeline(queue: Queue, image_paths: list[str], prompt: str, device: torch.device, model_name_or_path: str, batch_size: int, quantize: bool):
+    quantization_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=torch.float16,
+        bnb_4bit_use_double_quant=False,
+        bnb_4bit_quant_type='nf4',
+    ) if quantize else None
+    model = LlavaForConditionalGeneration.from_pretrained(model_name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True,
+                                                          quantization_config=quantization_config,
+                                                          device_map=device, attn_implementation="flash_attention_2")
+    processor = AutoProcessor.from_pretrained(model_name_or_path)
+    image_generator = image_loader(image_paths)
+
+    stop = False
+    while not stop:
+        prompts = list()
+        images = list()
+        filenames = list()
+        for i in range(0, batch_size):
+            image, filename = next(image_generator, (None, None))
+            if image is None:
+                stop = True
+                break
+            filenames.append(filename)
+            images.append(image)
+            prompts.append(prompt)
+
+        if len(images) == 0:
+            break
+
+        inputs = processor(text=prompts, images=images, return_tensors="pt").to(model.device)
+        generate_ids = model.generate(**inputs, max_new_tokens=100, min_new_tokens=3, length_penalty=1.0, do_sample=False, temperature=1.0, top_k=50, top_p=1.0)
+        decodes = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        for i, decoded in enumerate(decodes):
+            # the decoded text begins with the prompt, minus the <image> token
+            # that skip_special_tokens removed, so trim that prefix off
+            trim = len(prompt) - len("<image>")
+            queue.put({"file_name": filenames[i], "text": decoded[trim:].strip()})
+
+
+def split_list(input_list, count):
+    target_length = int(len(input_list) / count)
+    for i in range(0, count - 1):
+        yield input_list[i * target_length: (i + 1) * target_length]
+    yield input_list[(count - 1) * target_length: len(input_list)]
+
+
+def save_meta(meta_file, meta, reldir, common_description):
+    meta["file_name"] = os.path.relpath(meta["file_name"], reldir)
+    if common_description is not None:
+        meta["text"] = common_description + meta["text"]
+    meta_file.write(json.dumps(meta) + '\n')
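+
+
+# Entry point: the image list is split into one chunk per CUDA device, each
+# chunk is tagged by its own worker process, and finished records are drained
+# from the shared Queue into metadata.jsonl while the workers run.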
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="A script to tag images via llava")
+    parser.add_argument('--model', '-m', default="llava-hf/llava-1.5-13b-hf", help="model to use")
+    parser.add_argument('--quantize', '-q', action='store_true', help="load quantized")
+    parser.add_argument('--prompt', '-p', default="Please describe this image in 10 to 20 words.", help="Prompt to use on each image")
+    parser.add_argument('--batch', '-b', default=4, type=int, help="Batch size to use for inference")
+    parser.add_argument('--common_description', '-c', help="An optional description that will be prepended to the AI-generated one")
+    parser.add_argument('--image_dir', '-i', help="A directory containing the images to tag")
+    args = parser.parse_args()
+
+    prompt = "USER: <image>\n" + args.prompt + "\nASSISTANT: "
+    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+
+    image_paths = find_image_files(args.image_dir)
+    image_path_chunks = list(split_list(image_paths, torch.cuda.device_count()))
+
+    print(f"Will use {torch.cuda.device_count()} processes to create tags")
+
+    logging.set_verbosity_error()
+    warnings.filterwarnings("ignore")
+    torch.multiprocessing.set_start_method('spawn')
+
+    queue = Queue()
+    processes = list()
+    for i in range(0, torch.cuda.device_count()):
+        processes.append(Process(target=pipeline, args=(queue, image_path_chunks[i], prompt, torch.device(i), args.model, args.batch, args.quantize)))
+        processes[-1].start()
+
+    progress = tqdm(desc="Generating tags", total=len(image_paths))
+    done = False
+    with open(os.path.join(args.image_dir, "metadata.jsonl"), mode='w') as output_file:
+        while not done:
+            if not queue.empty():
+                meta = queue.get()
+                save_meta(output_file, meta, args.image_dir, args.common_description)
+                progress.update()
+            done = all(not process.is_alive() for process in processes)
+
+        while not queue.empty():
+            meta = queue.get()
+            save_meta(output_file, meta, args.image_dir, args.common_description)
+            progress.update()
+
+    for process in processes:
+        process.join()
diff --git a/DanbooruTagger/deepdanbooru_onnx/__init__.py b/DanbooruTagger/deepdanbooru_onnx/__init__.py
new file mode 100644
index 0000000..21d7e94
--- /dev/null
+++ b/DanbooruTagger/deepdanbooru_onnx/__init__.py
@@ -0,0 +1,3 @@
+from .deepdanbooru_onnx import DeepDanbooru
+from .deepdanbooru_onnx import process_image
+__version__ = '0.0.8'
\ No newline at end of file
diff --git a/DanbooruTagger/deepdanbooru_onnx/deepdanbooru_onnx.py b/DanbooruTagger/deepdanbooru_onnx/deepdanbooru_onnx.py
new file mode 100644
index 0000000..c108ceb
--- /dev/null
+++ b/DanbooruTagger/deepdanbooru_onnx/deepdanbooru_onnx.py
@@ -0,0 +1,244 @@
+import onnxruntime as ort
+from PIL import Image
+import numpy as np
+import os
+from tqdm import tqdm
+import requests
+import hashlib
+from typing import Iterator, List, Union
+import shutil
+from pathlib import Path
+
+
+def process_image(image: Image.Image) -> np.ndarray:
+    """
+    Convert an image to a numpy array.
+    :param image: the image to convert
+    :return: the numpy array
+    """
+
+    image = image.convert("RGB").resize((512, 512))
+    image = np.array(image).astype(np.float32) / 255
+    image = image.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
+    return image
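+
+
+# Note: process_image returns an NHWC float32 array of shape (1, 512, 512, 3)
+# with values scaled to [0, 1]; this is the layout the ONNX model expects and
+# what DeepDanbooru.from_ndarray_inference checks for below.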
+def download(url: str, save_path: str, md5: str, length: int) -> bool:
+    """
+    Download a file from url to save_path and verify it against md5.
+    If the file already exists it is overwritten.
+    :param url: the url of the file to download
+    :param save_path: the path to save the file to
+    :param md5: the expected md5 of the file
+    :param length: the length of the file in bytes
+    :return: True if the file was downloaded and verified successfully, False otherwise
+    """
+
+    try:
+        response = requests.get(url=url, stream=True)
+        with open(save_path, "wb") as f:
+            with tqdm.wrapattr(
+                response.raw, "read", total=length, desc="Downloading"
+            ) as r_raw:
+                shutil.copyfileobj(r_raw, f)
+        with open(save_path, "rb") as f:
+            return hashlib.md5(f.read()).hexdigest() == md5
+    except Exception as e:
+        print(e)
+        return False
+
+
+def download_model():
+    """
+    Download the model and tags file from the server.
+    :return: the path to the model and tags file
+    """
+
+    model_url = (
+        "https://huggingface.co/chinoll/deepdanbooru/resolve/main/deepdanbooru.onnx"
+    )
+    tags_url = "https://huggingface.co/chinoll/deepdanbooru/resolve/main/tags.txt"
+    model_md5 = "16be4e40ebcc0b1d1915bbf31f00969f"
+    tags_md5 = "a3f764de985cdeba89f1d232a4204402"
+    model_length = 643993025
+    tags_length = 133810
+
+    home = str(Path.home()) + "/.deepdanbooru_onnx/"
+    if not os.path.exists(home):
+        os.mkdir(home)
+
+    model_name = "deepdanbooru.onnx"
+    tags_name = "tags.txt"
+
+    model_path = home + model_name
+    tags_path = home + tags_name
+
+    # (re)download the model if it is missing or fails its md5 check
+    if os.path.exists(model_path) and hashlib.md5(open(model_path, "rb").read()).hexdigest() != model_md5:
+        os.remove(model_path)
+    if not os.path.exists(model_path) and not download(model_url, model_path, model_md5, model_length):
+        raise ValueError("Model download failed")
+
+    # same for the tags file
+    if os.path.exists(tags_path) and hashlib.md5(open(tags_path, "rb").read()).hexdigest() != tags_md5:
+        os.remove(tags_path)
+    if not os.path.exists(tags_path) and not download(tags_url, tags_path, tags_md5, tags_length):
+        raise ValueError("Tags download failed")
+
+    return model_path, tags_path
+
+
+class DeepDanbooru:
+    def __init__(
+        self,
+        mode: str = "auto",
+        model_path: Union[str, None] = None,
+        tags_path: Union[str, None] = None,
+        threshold: Union[float, int] = 0.6,
+        pin_memory: bool = False,
+        batch_size: int = 1,
+    ):
+        """
+        Initialize the DeepDanbooru class.
+        :param mode: the mode of the model, "cpu", "gpu", "tensorrt" or "auto"
+        :param model_path: the path to the model file
+        :param tags_path: the path to the tags file
+        :param threshold: the score threshold below which tags are dropped
+        :param pin_memory: whether to preprocess all images up front and keep them in memory
+        :param batch_size: the batch size used for list inference
+        """
+
+        providers = {
+            "cpu": "CPUExecutionProvider",
+            "gpu": "CUDAExecutionProvider",
+            "tensorrt": "TensorrtExecutionProvider",
+            "auto": (
+                "CUDAExecutionProvider"
+                if "CUDAExecutionProvider" in ort.get_available_providers()
+                else "CPUExecutionProvider"
+            ),
+        }
+
+        if not (isinstance(threshold, float) or isinstance(threshold, int)):
+            raise TypeError("threshold must be float or int")
+        if threshold < 0 or threshold > 1:
+            raise ValueError("threshold must be between 0 and 1")
+        if mode not in providers:
+            raise ValueError(
+                "Mode not supported. Please choose from: cpu, gpu, tensorrt, auto"
+            )
+        if providers[mode] not in ort.get_available_providers():
+            raise ValueError(
+                f"Your device does not support mode '{mode}'. Please choose from: cpu"
+            )
+        if model_path is not None and not os.path.exists(model_path):
+            raise FileNotFoundError("Model file not found")
+        if tags_path is not None and not os.path.exists(tags_path):
+            raise FileNotFoundError("Tags file not found")
+
+        if model_path is None or tags_path is None:
+            model_path, tags_path = download_model()
+
+        self.session = ort.InferenceSession(model_path, providers=[providers[mode]])
+        with open(tags_path, "r") as f:
+            self.tags = [line.strip() for line in f.readlines()]
+
+        self.input_name = self.session.get_inputs()[0].name
+        self.output_name = [output.name for output in self.session.get_outputs()]
+        self.threshold = threshold
+        self.pin_memory = pin_memory
+        self.batch_size = batch_size
+        self.mode = mode
+        self.cache = {}
+
+    def __str__(self) -> str:
+        return f"DeepDanbooru(mode={self.mode}, threshold={self.threshold}, pin_memory={self.pin_memory}, batch_size={self.batch_size})"
+
+    def __repr__(self) -> str:
+        return self.__str__()
+
+    def from_image_inference(self, image: Image.Image) -> dict:
+        image = process_image(image)
+        return self.predict(image)
+
+    def from_ndarray_inference(self, image: np.ndarray) -> dict:
+        if image.shape != (1, 512, 512, 3):
+            raise ValueError(f"Image must be {(1, 512, 512, 3)}")
+        return self.predict(image)
+
+    def from_file_inference(self, image: str) -> dict:
+        return self.from_image_inference(Image.open(image))
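+
+    # from_list_inference is a generator over batches: each image is md5-hashed
+    # and images whose hash is already in self.cache skip inference entirely,
+    # yielding the cached {tag: score} dict instead.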
+    def from_list_inference(self, image: Union[list, tuple]) -> Iterator[dict]:
+        if self.pin_memory:
+            image = [process_image(Image.open(i)) for i in image]
+        for batch in [
+            image[i : i + self.batch_size]
+            for i in range(0, len(image), self.batch_size)
+        ]:
+            bs = len(batch)
+            _imagelist, idx, hashlist = [], [], []
+            for j in range(bs):
+                img = Image.open(batch[j]) if not self.pin_memory else batch[j]
+                image_hash = hashlib.md5(np.array(img).astype(np.uint8)).hexdigest()
+                hashlist.append(image_hash)
+                if image_hash in self.cache:
+                    continue
+                if not self.pin_memory:
+                    _imagelist.append(process_image(img))
+                else:
+                    _imagelist.append(batch[j])
+                idx.append(j)
+
+            if len(_imagelist) != 0:
+                results = self.inference(np.vstack(_imagelist))
+                results_idx = 0
+            else:
+                results = []
+
+            for j in range(bs):
+                if j in idx:
+                    image_tag = {}
+                    for tag, score in zip(self.tags, results[results_idx]):
+                        if score >= self.threshold:
+                            image_tag[tag] = score
+                    results_idx += 1
+                    self.cache[hashlist[j]] = image_tag
+                    yield image_tag
+                else:
+                    yield self.cache[hashlist[j]]
+
+    def inference(self, image):
+        return self.session.run(self.output_name, {self.input_name: image})[0]
+
+    def predict(self, image):
+        result = self.inference(image)
+        image_tag = {}
+        for tag, score in zip(self.tags, result[0]):
+            if score >= self.threshold:
+                image_tag[tag] = score
+        return image_tag
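+
+    # __call__ dispatches on the input type: a str is treated as a file path,
+    # an ndarray must already have shape (1, 512, 512, 3), a list/tuple of
+    # paths goes through the batched generator above, and a PIL.Image is
+    # preprocessed and predicted directly.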
+
+    def __call__(self, image) -> Union[dict, Iterator[dict]]:
+        if isinstance(image, str):
+            return self.from_file_inference(image)
+        elif isinstance(image, np.ndarray):
+            return self.from_ndarray_inference(image)
+        elif isinstance(image, (list, tuple)):
+            return self.from_list_inference(image)
+        elif isinstance(image, Image.Image):
+            return self.from_image_inference(image)
+        else:
+            raise ValueError("Image must be a file path, a PIL.Image, a numpy array, or a list/tuple of file paths")
diff --git a/DanbooruTagger/example.py b/DanbooruTagger/example.py
new file mode 100644
index 0000000..18fa74d
--- /dev/null
+++ b/DanbooruTagger/example.py
@@ -0,0 +1,3 @@
+from deepdanbooru_onnx import DeepDanbooru
+danbooru = DeepDanbooru()
+print(danbooru("/run/media/philipp/20404acc-312c-44f2-b2d1-3a0a14257cc6/.Media/porn/00244-3145022840.png"))
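+# The defaults can be tuned; a hypothetical configuration might look like:
+# danbooru = DeepDanbooru(mode="auto", threshold=0.7, batch_size=4, pin_memory=True)
+# Passing a list of paths instead of a single path yields one {tag: score}
+# dict per image.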