DanbooruTagger: add wd tagger support

This commit is contained in:
uvos 2024-09-05 18:45:34 +02:00
parent 2a6908c849
commit 422debd897
7 changed files with 305 additions and 109 deletions
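
For orientation, the new flags added by this commit can be exercised as below. This is an illustrative sketch only: the script filename (main.py) is an assumption, since the diff viewer does not show the first file's path; the flags themselves come from the argparse definitions in the diff.

# Hypothetical invocation of the tagging script with the new wd backend;
# "main.py" is an assumed filename, the flags are from this commit.
import subprocess

subprocess.run(
    ["python", "main.py", "--image_dir", "./images", "--wd", "--cpu"],
    check=True,
)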

View File

@@ -1,5 +1,5 @@
import warnings
from deepdanbooru_onnx import DeepDanbooru
from wd_onnx import Wd
from PIL import Image
import argparse
import cv2
@@ -13,13 +13,20 @@ image_ext_ocv = [".bmp", ".jpeg", ".jpg", ".png"]
def find_image_files(path: str) -> list[str]:
    paths = list()
    for root, dirs, files in os.walk(path):
        for filename in files:
            name, extension = os.path.splitext(filename)
            if extension.lower() in image_ext_ocv:
                paths.append(os.path.join(root, filename))
    return paths
    if os.path.isdir(path):
        paths = list()
        for root, dirs, files in os.walk(path):
            for filename in files:
                name, extension = os.path.splitext(filename)
                if extension.lower() in image_ext_ocv:
                    paths.append(os.path.join(root, filename))
        return paths
    else:
        name, extension = os.path.splitext(path)
        if extension.lower() in image_ext_ocv:
            return [path]
        else:
            return []
def image_loader(paths: list[str]):
@@ -35,14 +42,28 @@ def image_loader(paths: list[str]):
        yield image_pil, path
def pipeline(queue: Queue, image_paths: list[str], device: int):
    danbooru = DeepDanbooru()
def danbooru_pipeline(queue: Queue, image_paths: list[str], device: int, cpu: bool):
    danbooru = DeepDanbooru("cpu" if cpu else "auto")
    for path in image_paths:
        imageprompt = ""
        tags = danbooru(path)
        for tag in tags:
            imageprompt = imageprompt + ", " + tag
        imageprompt = imageprompt[2:]
        queue.put({"file_name": path, "text": imageprompt})
def wd_pipeline(queue: Queue, image_paths: list[str], device: int, cpu: bool):
    wd = Wd("cpu" if cpu else "auto", threshold=0.3)
    for path in image_paths:
        imageprompt = ""
        tags = wd(path)
        for tag in tags:
            imageprompt = imageprompt + ", " + tag
        imageprompt = imageprompt[2:]
        queue.put({"file_name": path, "text": imageprompt})
@@ -57,49 +78,71 @@ def split_list(input_list, count):
def save_meta(meta_file, meta, reldir, common_description):
    meta["file_name"] = os.path.relpath(meta["file_name"], reldir)
    if common_description is not None:
        meta["text"] = common_description + meta["text"]
        meta["text"] = common_description + ", " + meta["text"]
    meta_file.write(json.dumps(meta) + '\n')
if __name__ == "__main__":
    parser = argparse.ArgumentParser("A script to tag images via DeepDanbooru or the wd tagger")
    parser.add_argument('--batch', '-b', default=4, type=int, help="Batch size to use for inference")
    parser.add_argument('--common_description', '-c', help="An optional description that will be prepended to the AI-generated one")
    parser.add_argument('--image_dir', '-i', help="A directory containing the images to tag")
    parser.add_argument('--image_dir', '-i', help="A directory containing the images to tag or a single image to tag")
    parser.add_argument('--wd', '-w', action="store_true", help="use the wd tagger instead of DeepDanbooru")
    parser.add_argument('--cpu', action="store_true", help="force cpu usage instead of gpu")
    args = parser.parse_args()
    nparalell = 2
    image_paths = find_image_files(args.image_dir)
    if len(image_paths) == 0:
        print(f"Unable to find any images at {args.image_dir}")
        exit(1)
    nparalell = 4 if len(image_paths) > 4 else len(image_paths)
    image_path_chunks = list(split_list(image_paths, nparalell))
    print(f"Will use {nparalell} processes to create tags")
    print(f"Will use {nparalell} processes to create tags for {len(image_paths)} images")
    queue = Queue()
    pipe = danbooru_pipeline if not args.wd else wd_pipeline
    processies = list()
    for i in range(0, nparalell):
        processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], i)))
        processies.append(Process(target=pipe, args=(queue, image_path_chunks[i], i, args.cpu)))
        processies[-1].start()
    progress = tqdm(desc="Generating tags", total=len(image_paths))
    exit = False
    with open(os.path.join(args.image_dir, "metadata.jsonl"), mode='w') as output_file:
    if len(image_paths) > 1:
        with open(os.path.join(args.image_dir, "metadata.jsonl"), mode='w') as output_file:
            while not exit:
                if not queue.empty():
                    meta = queue.get()
                    save_meta(output_file, meta, args.image_dir, args.common_description)
                    progress.update()
                exit = True
                for process in processies:
                    if process.is_alive():
                        exit = False
                        break
            while not queue.empty():
                meta = queue.get()
                save_meta(output_file, meta, args.image_dir, args.common_description)
                progress.update()
    else:
        while not exit:
            if not queue.empty():
                meta = queue.get()
                print(meta)
                progress.update()
            exit = True
            for process in processies:
                if process.is_alive():
                    exit = False
                    break
        while not queue.empty():
            meta = queue.get()
            print(meta)
            progress.update()
    for process in processies:
        process.join()
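
The shutdown logic above is easy to get subtly wrong, so here is a minimal, self-contained sketch of the same drain-until-workers-exit pattern. This is an illustration, not part of the commit; the worker is a stand-in for the real tagger pipelines.

# Minimal sketch of the queue-draining pattern used by the script above.
from multiprocessing import Process, Queue

def worker(queue: Queue, items: list[str]):
    for item in items:
        queue.put({"file_name": item, "text": "example, tags"})

if __name__ == "__main__":
    queue = Queue()
    chunks = [["a.png", "b.png"], ["c.png"]]
    workers = [Process(target=worker, args=(queue, chunk)) for chunk in chunks]
    for p in workers:
        p.start()
    done = False
    while not done:
        while not queue.empty():
            print(queue.get())
        # only stop once every worker has exited...
        done = all(not p.is_alive() for p in workers)
    # ...then drain anything enqueued after the last emptiness check
    while not queue.empty():
        print(queue.get())
    for p in workers:
        p.join()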

View File

@@ -2,13 +2,12 @@ import onnxruntime as ort
from PIL import Image
import numpy as np
import os
from tqdm import tqdm
import requests
import hashlib
from typing import List, Union
import shutil
from pathlib import Path
from utils import download
def process_image(image: Image.Image) -> np.ndarray:
"""
@@ -18,38 +17,9 @@ def process_image(image: Image.Image) -> np.ndarray:
"""
image = image.convert("RGB").resize((512, 512))
image = np.array(image).astype(np.float32) / 255
image = image.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
return image
def download(url: str, save_path: str, md5: str, length: str) -> bool:
    """
    Download a file from url to save_path.
    If the file already exists, check its md5.
    If the md5 matches, return True; if the md5 doesn't match, return False.
    :param url: the url of the file to download
    :param save_path: the path to save the file
    :param md5: the md5 of the file
    :param length: the length of the file
    :return: True if the file is downloaded successfully, False otherwise
    """
    try:
        response = requests.get(url=url, stream=True)
        with open(save_path, "wb") as f:
            with tqdm.wrapattr(
                response.raw, "read", total=length, desc="Downloading"
            ) as r_raw:
                shutil.copyfileobj(r_raw, f)
        return (
            True
            if hashlib.md5(open(save_path, "rb").read()).hexdigest() == md5
            else False
        )
    except Exception as e:
        print(e)
        return False
    imagenp = np.array(image).astype(np.float32) / 255
    imagenp = imagenp.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
    return imagenp
def download_model():
@@ -109,7 +79,7 @@
    ):
        """
        Initialize the DeepDanbooru class.
        :param mode: the mode of the model, "cpu" or "gpu" or "auto"
        :param mode: the mode of the model, "cpu", "cuda", "hip" or "auto"
        :param model_path: the path to the model file
        :param tags_path: the path to the tags file
        :param threshold: the threshold of the model
@@ -119,11 +89,13 @@
        providers = {
            "cpu": "CPUExecutionProvider",
            "gpu": "CUDAExecutionProvider",
            "cuda": "CUDAExecutionProvider",
            "hip": "ROCMExecutionProvider",
            "tensorrt": "TensorrtExecutionProvider",
            "auto": (
                "CUDAExecutionProvider"
                if "CUDAExecutionProvider" in ort.get_available_providers()
                else "ROCMExecutionProvider" if "ROCMExecutionProvider" in ort.get_available_providers()
                else "CPUExecutionProvider"
            ),
        }
@@ -166,8 +138,8 @@
        return self.__str__()
    def from_image_inference(self, image: Image.Image) -> dict:
        image = process_image(image)
        return self.predict(image)
        imagenp = process_image(image)
        return self.predict(imagenp)
    def from_ndarray_inferece(self, image: np.ndarray) -> dict:
        if image.shape != (1, 512, 512, 3):
@@ -177,49 +149,6 @@
    def from_file_inference(self, image: str) -> dict:
        return self.from_image_inference(Image.open(image))
    def from_list_inference(self, image: Union[list, tuple]) -> List[dict]:
        if self.pin_memory:
            image = [process_image(Image.open(i)) for i in image]
        for i in [
            image[i : i + self.batch_size]
            for i in range(0, len(image), self.batch_size)
        ]:
            imagelist = i
            bs = len(i)
            _imagelist, idx, hashlist = [], [], []
            for j in range(len(i)):
                img = Image.open(i[j]) if not self.pin_memory else imagelist[j]
                image_hash = hashlib.md5(np.array(img).astype(np.uint8)).hexdigest()
                hashlist.append(image_hash)
                if image_hash in self.cache:
                    continue
                if not self.pin_memory:
                    _imagelist.append(process_image(img))
                else:
                    _imagelist.append(imagelist[j])
                idx.append(j)
            imagelist = _imagelist
            if len(imagelist) != 0:
                _image = np.vstack(imagelist)
                results = self.inference(_image)
                results_idx = 0
            else:
                results = []
            for i in range(bs):
                image_tag = {}
                if i in idx:
                    hash = hashlist[i]
                    for tag, score in zip(self.tags, results[results_idx]):
                        if score >= self.threshold:
                            image_tag[tag] = score
                    results_idx += 1
                    self.cache[hash] = image_tag
                    yield image_tag
                else:
                    yield self.cache[hashlist[i]]
    def inference(self, image):
        return self.session.run(self.output_name, {self.input_name: image})[0]
@@ -236,8 +165,6 @@
            return self.from_file_inference(image)
        elif isinstance(image, np.ndarray):
            return self.from_ndarray_inferece(image)
        elif isinstance(image, list) or isinstance(image, tuple):
            return self.from_list_inference(image)
        elif isinstance(image, Image.Image):
            return self.from_image_inference(image)
        else:
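
The new "auto" entry in the providers dict prefers CUDA, then ROCm, then falls back to the CPU. As a standalone illustration (not part of the commit), the same selection can be written as:

# Sketch of the provider auto-selection logic added in this commit:
# prefer CUDA, then ROCm, and fall back to the CPU provider.
import onnxruntime as ort

def pick_provider() -> str:
    available = ort.get_available_providers()
    for candidate in ("CUDAExecutionProvider", "ROCMExecutionProvider"):
        if candidate in available:
            return candidate
    return "CPUExecutionProvider"

# e.g. ort.InferenceSession("model.onnx", providers=[pick_provider()])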

View File

@@ -1,3 +0,0 @@
from .deepdanbooru_onnx import DeepDanbooru
from .deepdanbooru_onnx import process_image
__version__ = '0.0.8'

DanbooruTagger/utils.py Normal file

@@ -0,0 +1,33 @@
import requests
import shutil
import hashlib
from tqdm import tqdm
def download(url: str, save_path: str, md5: str, length: int) -> bool:
    """
    Download a file from url to save_path.
    If the file already exists, check its md5.
    If the md5 matches, return True; if the md5 doesn't match, return False.
    :param url: the url of the file to download
    :param save_path: the path to save the file
    :param md5: the expected md5 of the file
    :param length: the length of the file in bytes (used for the progress bar)
    :return: True if the file is downloaded successfully, False otherwise
    """
    try:
        response = requests.get(url=url, stream=True)
        with open(save_path, "wb") as f:
            with tqdm.wrapattr(
                response.raw, "read", total=length, desc="Downloading"
            ) as r_raw:
                shutil.copyfileobj(r_raw, f)
        with open(save_path, "rb") as f:
            return hashlib.md5(f.read()).hexdigest() == md5
    except Exception as e:
        print(e)
        return False
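
A quick, hypothetical exercise of this helper; the URL, checksum, and size below are placeholders for illustration, not real artifacts:

# Hypothetical usage of utils.download(); url, md5 and length are placeholders.
from utils import download

ok = download(
    url="https://example.com/model.onnx",
    save_path="/tmp/model.onnx",
    md5="0123456789abcdef0123456789abcdef",
    length=1234567,
)
if not ok:
    raise ValueError("download failed or checksum mismatch")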

DanbooruTagger/wd_onnx.py Normal file

@@ -0,0 +1,196 @@
import onnxruntime as ort
from PIL import Image
import numpy as np
import os
import hashlib
from typing import List, Union
from pathlib import Path
import csv
from utils import download
def process_image(image: Image.Image, target_size: int) -> np.ndarray:
    canvas = Image.new("RGBA", image.size, (255, 255, 255))
    canvas.paste(image, mask=image.split()[3] if image.mode == 'RGBA' else None)
    image = canvas.convert("RGB")
    # Pad image to a square
    max_dim = max(image.size)
    pad_left = (max_dim - image.size[0]) // 2
    pad_top = (max_dim - image.size[1]) // 2
    padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
    padded_image.paste(image, (pad_left, pad_top))
    # Resize
    padded_image = padded_image.resize((target_size, target_size), Image.Resampling.BICUBIC)
    # Convert to numpy array, flipping RGB to BGR as the wd models expect
    image_array = np.asarray(padded_image, dtype=np.float32)[..., [2, 1, 0]]
    return np.expand_dims(image_array, axis=0)
def download_model():
    """
    Download the model and tags file from the server.
    :return: the path to the model and tags file
    """
    model_url = (
        "https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/resolve/main/model.onnx"
    )
    tags_url = "https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/resolve/main/selected_tags.csv"
    model_md5 = "1fc4f456261c457a08d4b9e3379cac39"
    tags_md5 = "55c11e40f95e63ea9ac21a065a73fd0f"
    model_length = 378536310
    tags_length = 308468
    home = str(Path.home()) + "/.wd_onnx/"
    if not os.path.exists(home):
        os.mkdir(home)
    model_name = "wd.onnx"
    tags_name = "selected_tags.csv"
    model_path = home + model_name
    tags_path = home + tags_name
    if os.path.exists(model_path):
        if hashlib.md5(open(model_path, "rb").read()).hexdigest() != model_md5:
            os.remove(model_path)
            if not download(model_url, model_path, model_md5, model_length):
                raise ValueError("Model download failed")
    else:
        if not download(model_url, model_path, model_md5, model_length):
            raise ValueError("Model download failed")
    if os.path.exists(tags_path):
        if hashlib.md5(open(tags_path, "rb").read()).hexdigest() != tags_md5:
            os.remove(tags_path)
            if not download(tags_url, tags_path, tags_md5, tags_length):
                raise ValueError("Tags download failed")
    else:
        if not download(tags_url, tags_path, tags_md5, tags_length):
            raise ValueError("Tags download failed")
    return model_path, tags_path
class Wd:
    def __init__(
        self,
        mode: str = "auto",
        model_path: Union[str, None] = None,
        tags_path: Union[str, None] = None,
        threshold: Union[float, int] = 0.6,
        pin_memory: bool = False,
        batch_size: int = 1,
    ):
        """
        Initialize the Wd class.
        :param mode: the mode of the model, "cpu", "cuda", "hip" or "auto"
        :param model_path: the path to the model file
        :param tags_path: the path to the tags file
        :param threshold: the threshold of the model
        :param pin_memory: whether to use pin memory
        :param batch_size: the batch size of the model
        """
        providers = {
            "cpu": "CPUExecutionProvider",
            "cuda": "CUDAExecutionProvider",
            "hip": "ROCMExecutionProvider",
            "tensorrt": "TensorrtExecutionProvider",
            "auto": (
                "CUDAExecutionProvider"
                if "CUDAExecutionProvider" in ort.get_available_providers()
                else "ROCMExecutionProvider" if "ROCMExecutionProvider" in ort.get_available_providers()
                else "CPUExecutionProvider"
            ),
        }
        if not (isinstance(threshold, float) or isinstance(threshold, int)):
            raise TypeError("threshold must be float or int")
        if threshold < 0 or threshold > 1:
            raise ValueError("threshold must be between 0 and 1")
        if mode not in providers:
            raise ValueError(
                "Mode not supported. Please choose from: cpu, cuda, hip, tensorrt, auto"
            )
        if providers[mode] not in ort.get_available_providers():
            raise ValueError(
                f"The {mode} provider is not available on this device; try mode='cpu'"
            )
        if model_path is not None and not os.path.exists(model_path):
            raise FileNotFoundError("Model file not found")
        if tags_path is not None and not os.path.exists(tags_path):
            raise FileNotFoundError("Tags file not found")
        if model_path is None or tags_path is None:
            model_path, tags_path = download_model()
        self.session = ort.InferenceSession(model_path, providers=[providers[mode]])
        self.tags = list()
        with open(tags_path, "r") as tagfile:
            reader = csv.DictReader(tagfile)
            for row in reader:
                self.tags.append(row)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = [output.name for output in self.session.get_outputs()]
        self.threshold = threshold
        self.pin_memory = pin_memory
        self.batch_size = batch_size
        self.target_size = self.session.get_inputs()[0].shape[2]
        self.mode = mode
        self.cache = {}
    def __str__(self) -> str:
        return f"Wd(mode={self.mode}, threshold={self.threshold}, pin_memory={self.pin_memory}, batch_size={self.batch_size})"
    def __repr__(self) -> str:
        return self.__str__()
    def from_image_inference(self, image: Image.Image) -> dict:
        imagenp = process_image(image, self.target_size)
        return self.predict(imagenp)
    def from_ndarray_inferece(self, image: np.ndarray) -> dict:
        if image.shape != (1, self.target_size, self.target_size, 3):
            raise ValueError(f"Image must be {(1, self.target_size, self.target_size, 3)}")
        return self.predict(image)
    def from_file_inference(self, path: str) -> dict:
        image = Image.open(path)
        return self.from_image_inference(image)
    def inference(self, image):
        return self.session.run(self.output_name, {self.input_name: image})[0]
    def predict(self, image):
        result = self.inference(image)
        tags = self.tags
        for tag, score in zip(tags, result[0]):
            tag['score'] = float(score)
        # the first four rows of selected_tags.csv are the rating tags
        ratings = tags[:4]
        tags = tags[4:]
        image_tags = {}
        rating = max(ratings, key=lambda el: el['score'])
        image_tags[rating['name']] = rating['score']
        for tag in tags:
            if tag['score'] >= self.threshold:
                image_tags[tag['name']] = tag['score']
        return image_tags
    def __call__(self, image) -> Union[dict, List[dict]]:
        if isinstance(image, str):
            return self.from_file_inference(image)
        elif isinstance(image, np.ndarray):
            return self.from_ndarray_inferece(image)
        elif isinstance(image, Image.Image):
            return self.from_image_inference(image)
        else:
            raise ValueError("Image must be a file path, a PIL.Image or a numpy array")
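
To sanity-check the new tagger end to end, something like the following should work; the image path is a placeholder, and the model is fetched into ~/.wd_onnx on first use:

# Hypothetical smoke test for the Wd class; "example.jpg" is a placeholder.
from wd_onnx import Wd

wd = Wd("cpu", threshold=0.3)
tags = wd("example.jpg")  # accepts a file path, a PIL.Image or an ndarray
for name, score in sorted(tags.items(), key=lambda kv: -kv[1]):
    print(f"{name}: {score:.2f}")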