DanbooruTagger: add wd tagger support
This commit is contained in:
parent
2a6908c849
commit
422debd897
@ -1,5 +1,5 @@
|
||||
import warnings
|
||||
from deepdanbooru_onnx import DeepDanbooru
|
||||
from wd_onnx import Wd
|
||||
from PIL import Image
|
||||
import argparse
|
||||
import cv2
|
||||
@ -13,13 +13,20 @@ image_ext_ocv = [".bmp", ".jpeg", ".jpg", ".png"]
|
||||
|
||||
|
||||
def find_image_files(path: str) -> list[str]:
|
||||
paths = list()
|
||||
for root, dirs, files in os.walk(path):
|
||||
for filename in files:
|
||||
name, extension = os.path.splitext(filename)
|
||||
if extension.lower() in image_ext_ocv:
|
||||
paths.append(os.path.join(root, filename))
|
||||
return paths
|
||||
if os.path.isdir(path):
|
||||
paths = list()
|
||||
for root, dirs, files in os.walk(path):
|
||||
for filename in files:
|
||||
name, extension = os.path.splitext(filename)
|
||||
if extension.lower() in image_ext_ocv:
|
||||
paths.append(os.path.join(root, filename))
|
||||
return paths
|
||||
else:
|
||||
name, extension = os.path.splitext(path)
|
||||
if extension.lower() in image_ext_ocv:
|
||||
return [path]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def image_loader(paths: list[str]):
|
||||
@ -35,14 +42,28 @@ def image_loader(paths: list[str]):
|
||||
yield image_pil, path
|
||||
|
||||
|
||||
def pipeline(queue: Queue, image_paths: list[str], device: int):
|
||||
danbooru = DeepDanbooru()
|
||||
def danbooru_pipeline(queue: Queue, image_paths: list[str], device: int, cpu: bool):
|
||||
danbooru = DeepDanbooru("cpu" if cpu else "auto")
|
||||
|
||||
for path in image_paths:
|
||||
imageprompt = ""
|
||||
tags = danbooru(path)
|
||||
for tag in tags:
|
||||
imageprompt = imageprompt + ", " + tag
|
||||
imageprompt = imageprompt[2:]
|
||||
|
||||
queue.put({"file_name": path, "text": imageprompt})
|
||||
|
||||
|
||||
def wd_pipeline(queue: Queue, image_paths: list[str], device: int, cpu: bool):
|
||||
wd = Wd("cpu" if cpu else "auto", threshold=0.3)
|
||||
|
||||
for path in image_paths:
|
||||
imageprompt = ""
|
||||
tags = wd(path)
|
||||
for tag in tags:
|
||||
imageprompt = imageprompt + ", " + tag
|
||||
imageprompt = imageprompt[2:]
|
||||
|
||||
queue.put({"file_name": path, "text": imageprompt})
|
||||
|
||||
@ -57,49 +78,71 @@ def split_list(input_list, count):
|
||||
def save_meta(meta_file, meta, reldir, common_description):
|
||||
meta["file_name"] = os.path.relpath(meta["file_name"], reldir)
|
||||
if common_description is not None:
|
||||
meta["text"] = common_description + meta["text"]
|
||||
meta["text"] = common_description + ", " + meta["text"]
|
||||
meta_file.write(json.dumps(meta) + '\n')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("A script to tag images via DeepDanbooru")
|
||||
parser.add_argument('--batch', '-b', default=4, type=int, help="Batch size to use for inference")
|
||||
parser.add_argument('--common_description', '-c', help="An optional description that will be preended to the ai generated one")
|
||||
parser.add_argument('--image_dir', '-i', help="A directory containg the images to tag")
|
||||
parser.add_argument('--image_dir', '-i', help="A directory containg the images to tag or a singular image to tag")
|
||||
parser.add_argument('--wd', '-w', action="store_true", help="use wd tagger instead of DeepDanbooru")
|
||||
parser.add_argument('--cpu', action="store_true", help="force cpu usge instead of gpu")
|
||||
args = parser.parse_args()
|
||||
|
||||
nparalell = 2
|
||||
|
||||
image_paths = find_image_files(args.image_dir)
|
||||
|
||||
if len(image_paths) == 0:
|
||||
print("Unable to find any images at {args.image_dir}")
|
||||
exit(1)
|
||||
|
||||
nparalell = 4 if len(image_paths) > 4 else len(image_paths)
|
||||
image_path_chunks = list(split_list(image_paths, nparalell))
|
||||
|
||||
print(f"Will use {nparalell} processies to create tags")
|
||||
print(f"Will use {nparalell} processies to create tags for {len(image_paths)} images")
|
||||
|
||||
queue = Queue()
|
||||
pipe = danbooru_pipeline if not args.wd else wd_pipeline
|
||||
processies = list()
|
||||
for i in range(0, nparalell):
|
||||
processies.append(Process(target=pipeline, args=(queue, image_path_chunks[i], i)))
|
||||
processies.append(Process(target=pipe, args=(queue, image_path_chunks[i], i, args.cpu)))
|
||||
processies[-1].start()
|
||||
|
||||
progress = tqdm(desc="Generateing tags", total=len(image_paths))
|
||||
exit = False
|
||||
with open(os.path.join(args.image_dir, "metadata.jsonl"), mode='w') as output_file:
|
||||
|
||||
if len(image_paths) > 1:
|
||||
with open(os.path.join(args.image_dir, "metadata.jsonl"), mode='w') as output_file:
|
||||
while not exit:
|
||||
if not queue.empty():
|
||||
meta = queue.get()
|
||||
save_meta(output_file, meta, args.image_dir, args.common_description)
|
||||
progress.update()
|
||||
exit = True
|
||||
for process in processies:
|
||||
if process.is_alive():
|
||||
exit = False
|
||||
break
|
||||
|
||||
while not queue.empty():
|
||||
meta = queue.get()
|
||||
save_meta(output_file, meta, args.image_dir, args.common_description)
|
||||
progress.update()
|
||||
else:
|
||||
while not exit:
|
||||
if not queue.empty():
|
||||
meta = queue.get()
|
||||
save_meta(output_file, meta, args.image_dir, args.common_description)
|
||||
print(meta)
|
||||
progress.update()
|
||||
exit = True
|
||||
for process in processies:
|
||||
if process.is_alive():
|
||||
exit = False
|
||||
break
|
||||
|
||||
while not queue.empty():
|
||||
meta = queue.get()
|
||||
save_meta(output_file, meta, args.image_dir, args.common_description)
|
||||
print(meta)
|
||||
progress.update()
|
||||
|
||||
for process in processies:
|
||||
process.join()
|
||||
|
||||
|
@ -2,13 +2,12 @@ import onnxruntime as ort
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
import hashlib
|
||||
from typing import List, Union
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from utils import download
|
||||
|
||||
|
||||
def process_image(image: Image.Image) -> np.ndarray:
|
||||
"""
|
||||
@ -18,38 +17,9 @@ def process_image(image: Image.Image) -> np.ndarray:
|
||||
"""
|
||||
|
||||
image = image.convert("RGB").resize((512, 512))
|
||||
image = np.array(image).astype(np.float32) / 255
|
||||
image = image.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
|
||||
return image
|
||||
|
||||
|
||||
def download(url: str, save_path: str, md5: str, length: str) -> bool:
|
||||
"""
|
||||
Download a file from url to save_path.
|
||||
If the file already exists, check its md5.
|
||||
If the md5 matches, return True,if the md5 doesn't match, return False.
|
||||
:param url: the url of the file to download
|
||||
:param save_path: the path to save the file
|
||||
:param md5: the md5 of the file
|
||||
:param length: the length of the file
|
||||
:return: True if the file is downloaded successfully, False otherwise
|
||||
"""
|
||||
|
||||
try:
|
||||
response = requests.get(url=url, stream=True)
|
||||
with open(save_path, "wb") as f:
|
||||
with tqdm.wrapattr(
|
||||
response.raw, "read", total=length, desc="Downloading"
|
||||
) as r_raw:
|
||||
shutil.copyfileobj(r_raw, f)
|
||||
return (
|
||||
True
|
||||
if hashlib.md5(open(save_path, "rb").read()).hexdigest() == md5
|
||||
else False
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return False
|
||||
imagenp = np.array(image).astype(np.float32) / 255
|
||||
imagenp = imagenp.transpose((2, 0, 1)).reshape(1, 3, 512, 512).transpose((0, 2, 3, 1))
|
||||
return imagenp
|
||||
|
||||
|
||||
def download_model():
|
||||
@ -109,7 +79,7 @@ class DeepDanbooru:
|
||||
):
|
||||
"""
|
||||
Initialize the DeepDanbooru class.
|
||||
:param mode: the mode of the model, "cpu" or "gpu" or "auto"
|
||||
:param mode: the mode of the model, "cpu", "cuda", "hip" or "auto"
|
||||
:param model_path: the path to the model file
|
||||
:param tags_path: the path to the tags file
|
||||
:param threshold: the threshold of the model
|
||||
@ -119,11 +89,13 @@ class DeepDanbooru:
|
||||
|
||||
providers = {
|
||||
"cpu": "CPUExecutionProvider",
|
||||
"gpu": "CUDAExecutionProvider",
|
||||
"cuda": "CUDAExecutionProvider",
|
||||
"hip": "ROCMExecutionProvider",
|
||||
"tensorrt": "TensorrtExecutionProvider",
|
||||
"auto": (
|
||||
"CUDAExecutionProvider"
|
||||
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||
else "ROCMExecutionProvider" if "ROCMExecutionProvider" in ort.get_available_providers()
|
||||
else "CPUExecutionProvider"
|
||||
),
|
||||
}
|
||||
@ -166,8 +138,8 @@ class DeepDanbooru:
|
||||
return self.__str__()
|
||||
|
||||
def from_image_inference(self, image: Image.Image) -> dict:
|
||||
image = process_image(image)
|
||||
return self.predict(image)
|
||||
imagenp = process_image(image)
|
||||
return self.predict(imagenp)
|
||||
|
||||
def from_ndarray_inferece(self, image: np.ndarray) -> dict:
|
||||
if image.shape != (1, 512, 512, 3):
|
||||
@ -177,49 +149,6 @@ class DeepDanbooru:
|
||||
def from_file_inference(self, image: str) -> dict:
|
||||
return self.from_image_inference(Image.open(image))
|
||||
|
||||
def from_list_inference(self, image: Union[list, tuple]) -> List[dict]:
|
||||
if self.pin_memory:
|
||||
image = [process_image(Image.open(i)) for i in image]
|
||||
for i in [
|
||||
image[i : i + self.batch_size]
|
||||
for i in range(0, len(image), self.batch_size)
|
||||
]:
|
||||
imagelist = i
|
||||
bs = len(i)
|
||||
_imagelist, idx, hashlist = [], [], []
|
||||
for j in range(len(i)):
|
||||
img = Image.open(i[j]) if not self.pin_memory else imagelist[j]
|
||||
image_hash = hashlib.md5(np.array(img).astype(np.uint8)).hexdigest()
|
||||
hashlist.append(image_hash)
|
||||
if image_hash in self.cache:
|
||||
continue
|
||||
if not self.pin_memory:
|
||||
_imagelist.append(process_image(img))
|
||||
else:
|
||||
_imagelist.append(imagelist[j])
|
||||
idx.append(j)
|
||||
|
||||
imagelist = _imagelist
|
||||
if len(imagelist) != 0:
|
||||
_image = np.vstack(imagelist)
|
||||
results = self.inference(_image)
|
||||
results_idx = 0
|
||||
else:
|
||||
results = []
|
||||
|
||||
for i in range(bs):
|
||||
image_tag = {}
|
||||
if i in idx:
|
||||
hash = hashlist[i]
|
||||
for tag, score in zip(self.tags, results[results_idx]):
|
||||
if score >= self.threshold:
|
||||
image_tag[tag] = score
|
||||
results_idx += 1
|
||||
self.cache[hash] = image_tag
|
||||
yield image_tag
|
||||
else:
|
||||
yield self.cache[hashlist[i]]
|
||||
|
||||
def inference(self, image):
|
||||
return self.session.run(self.output_name, {self.input_name: image})[0]
|
||||
|
||||
@ -236,8 +165,6 @@ class DeepDanbooru:
|
||||
return self.from_file_inference(image)
|
||||
elif isinstance(image, np.ndarray):
|
||||
return self.from_ndarray_inferece(image)
|
||||
elif isinstance(image, list) or isinstance(image, tuple):
|
||||
return self.from_list_inference(image)
|
||||
elif isinstance(image, Image.Image):
|
||||
return self.from_image_inference(image)
|
||||
else:
|
@ -1,3 +0,0 @@
|
||||
from .deepdanbooru_onnx import DeepDanbooru
|
||||
from .deepdanbooru_onnx import process_image
|
||||
__version__ = '0.0.8'
|
Binary file not shown.
Binary file not shown.
33
DanbooruTagger/utils.py
Normal file
33
DanbooruTagger/utils.py
Normal file
@ -0,0 +1,33 @@
|
||||
import requests
|
||||
import shutil
|
||||
import hashlib
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def download(url: str, save_path: str, md5: str, length: str) -> bool:
|
||||
"""
|
||||
Download a file from url to save_path.
|
||||
If the file already exists, check its md5.
|
||||
If the md5 matches, return True,if the md5 doesn't match, return False.
|
||||
:param url: the url of the file to download
|
||||
:param save_path: the path to save the file
|
||||
:param md5: the md5 of the file
|
||||
:param length: the length of the file
|
||||
:return: True if the file is downloaded successfully, False otherwise
|
||||
"""
|
||||
|
||||
try:
|
||||
response = requests.get(url=url, stream=True)
|
||||
with open(save_path, "wb") as f:
|
||||
with tqdm.wrapattr(
|
||||
response.raw, "read", total=length, desc="Downloading"
|
||||
) as r_raw:
|
||||
shutil.copyfileobj(r_raw, f)
|
||||
return (
|
||||
True
|
||||
if hashlib.md5(open(save_path, "rb").read()).hexdigest() == md5
|
||||
else False
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return False
|
196
DanbooruTagger/wd_onnx.py
Normal file
196
DanbooruTagger/wd_onnx.py
Normal file
@ -0,0 +1,196 @@
|
||||
import onnxruntime as ort
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import os
|
||||
import hashlib
|
||||
from typing import List, Union
|
||||
from pathlib import Path
|
||||
import csv
|
||||
|
||||
from utils import download
|
||||
|
||||
|
||||
def process_image(image: Image.Image, target_size: int) -> np.ndarray:
|
||||
canvas = Image.new("RGBA", image.size, (255, 255, 255))
|
||||
canvas.paste(image, mask=image.split()[3] if image.mode == 'RGBA' else None)
|
||||
image = canvas.convert("RGB")
|
||||
|
||||
# Pad image to a square
|
||||
max_dim = max(image.size)
|
||||
pad_left = (max_dim - image.size[0]) // 2
|
||||
pad_top = (max_dim - image.size[1]) // 2
|
||||
padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
|
||||
padded_image.paste(image, (pad_left, pad_top))
|
||||
|
||||
# Resize
|
||||
padded_image = padded_image.resize((target_size, target_size), Image.Resampling.BICUBIC)
|
||||
|
||||
# Convert to numpy array
|
||||
image_array = np.asarray(padded_image, dtype=np.float32)[..., [2, 1, 0]]
|
||||
|
||||
return np.expand_dims(image_array, axis=0)
|
||||
|
||||
|
||||
def download_model():
|
||||
"""
|
||||
Download the model and tags file from the server.
|
||||
:return: the path to the model and tags file
|
||||
"""
|
||||
|
||||
model_url = (
|
||||
"https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/resolve/main/model.onnx"
|
||||
)
|
||||
tags_url = "https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/resolve/main/selected_tags.csv"
|
||||
model_md5 = "1fc4f456261c457a08d4b9e3379cac39"
|
||||
tags_md5 = "55c11e40f95e63ea9ac21a065a73fd0f"
|
||||
model_length = 378536310
|
||||
tags_length = 308468
|
||||
|
||||
home = str(Path.home()) + "/.wd_onnx/"
|
||||
if not os.path.exists(home):
|
||||
os.mkdir(home)
|
||||
|
||||
model_name = "wd.onnx"
|
||||
tags_name = "selected_tags.csv"
|
||||
|
||||
model_path = home + model_name
|
||||
tags_path = home + tags_name
|
||||
if os.path.exists(model_path):
|
||||
if hashlib.md5(open(model_path, "rb").read()).hexdigest() != model_md5:
|
||||
os.remove(model_path)
|
||||
if not download(model_url, model_path, model_md5, model_length):
|
||||
raise ValueError("Model download failed")
|
||||
|
||||
else:
|
||||
if not download(model_url, model_path, model_md5, model_length):
|
||||
raise ValueError("Model download failed")
|
||||
|
||||
if os.path.exists(tags_path):
|
||||
if hashlib.md5(open(tags_path, "rb").read()).hexdigest() != tags_md5:
|
||||
os.remove(tags_path)
|
||||
if not download(tags_url, tags_path, tags_md5, tags_length):
|
||||
raise ValueError("Tags download failed")
|
||||
else:
|
||||
if not download(tags_url, tags_path, tags_md5, tags_length):
|
||||
raise ValueError("Tags download failed")
|
||||
return model_path, tags_path
|
||||
|
||||
|
||||
class Wd:
|
||||
def __init__(
|
||||
self,
|
||||
mode: str = "auto",
|
||||
model_path: Union[str, None] = None,
|
||||
tags_path: Union[str, None] = None,
|
||||
threshold: Union[float, int] = 0.6,
|
||||
pin_memory: bool = False,
|
||||
batch_size: int = 1,
|
||||
):
|
||||
"""
|
||||
Initialize the DeepDanbooru class.
|
||||
:param mode: the mode of the model, "cpu", "cuda", "hip" or "auto"
|
||||
:param model_path: the path to the model file
|
||||
:param tags_path: the path to the tags file
|
||||
:param threshold: the threshold of the model
|
||||
:param pin_memory: whether to use pin memory
|
||||
:param batch_size: the batch size of the model
|
||||
"""
|
||||
|
||||
providers = {
|
||||
"cpu": "CPUExecutionProvider",
|
||||
"cuda": "CUDAExecutionProvider",
|
||||
"hip": "ROCMExecutionProvider",
|
||||
"tensorrt": "TensorrtExecutionProvider",
|
||||
"auto": (
|
||||
"CUDAExecutionProvider"
|
||||
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||
else "ROCMExecutionProvider" if "ROCMExecutionProvider" in ort.get_available_providers()
|
||||
else "CPUExecutionProvider"
|
||||
),
|
||||
}
|
||||
|
||||
if not (isinstance(threshold, float) or isinstance(threshold, int)):
|
||||
raise TypeError("threshold must be float or int")
|
||||
if threshold < 0 or threshold > 1:
|
||||
raise ValueError("threshold must be between 0 and 1")
|
||||
if mode not in providers:
|
||||
raise ValueError(
|
||||
"Mode not supported. Please choose from: cpu, gpu, tensorrt"
|
||||
)
|
||||
if providers[mode] not in ort.get_available_providers():
|
||||
raise ValueError(
|
||||
f"Your device is not supported {mode}. Please choose from: cpu"
|
||||
)
|
||||
if model_path is not None and not os.path.exists(model_path):
|
||||
raise FileNotFoundError("Model file not found")
|
||||
if tags_path is not None and not os.path.exists(tags_path):
|
||||
raise FileNotFoundError("Tags file not found")
|
||||
|
||||
if model_path is None or tags_path is None:
|
||||
model_path, tags_path = download_model()
|
||||
|
||||
self.session = ort.InferenceSession(model_path, providers=[providers[mode]])
|
||||
self.tags = list()
|
||||
with open(tags_path, "r") as tagfile:
|
||||
reader = csv.DictReader(tagfile)
|
||||
for row in reader:
|
||||
self.tags.append(row)
|
||||
|
||||
self.input_name = self.session.get_inputs()[0].name
|
||||
self.output_name = [output.name for output in self.session.get_outputs()]
|
||||
self.threshold = threshold
|
||||
self.pin_memory = pin_memory
|
||||
self.batch_size = batch_size
|
||||
self.target_size = self.session.get_inputs()[0].shape[2]
|
||||
self.mode = mode
|
||||
self.cache = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Wd(mode={self.mode}, threshold={self.threshold}, pin_memory={self.pin_memory}, batch_size={self.batch_size})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
def from_image_inference(self, image: Image.Image) -> dict:
|
||||
imagenp = process_image(image, self.target_size)
|
||||
return self.predict(imagenp)
|
||||
|
||||
def from_ndarray_inferece(self, image: np.ndarray) -> dict:
|
||||
if image.shape != (1, 512, 512, 3):
|
||||
raise ValueError(f"Image must be {(1, 512, 512, 3)}")
|
||||
return self.predict(image)
|
||||
|
||||
def from_file_inference(self, path: str) -> dict:
|
||||
image = Image.open(path)
|
||||
return self.from_image_inference(Image.open(path))
|
||||
|
||||
def inference(self, image):
|
||||
return self.session.run(self.output_name, {self.input_name: image})[0]
|
||||
|
||||
def predict(self, image):
|
||||
result = self.inference(image)
|
||||
tags = self.tags
|
||||
for tag, score in zip(tags, result[0]):
|
||||
tag['score'] = float(score)
|
||||
ratings = tags[:4]
|
||||
tags = tags[4:]
|
||||
|
||||
image_tags = {}
|
||||
|
||||
rating = max(ratings, key=lambda el: el['score'])
|
||||
image_tags[rating['name']] = rating['score']
|
||||
|
||||
for tag in tags:
|
||||
if tag['score'] >= self.threshold:
|
||||
image_tags[tag['name']] = tag['score']
|
||||
return image_tags
|
||||
|
||||
def __call__(self, image) -> Union[dict, List[dict]]:
|
||||
if isinstance(image, str):
|
||||
return self.from_file_inference(image)
|
||||
elif isinstance(image, np.ndarray):
|
||||
return self.from_ndarray_inferece(image)
|
||||
elif isinstance(image, Image.Image):
|
||||
return self.from_image_inference(image)
|
||||
else:
|
||||
raise ValueError("Image must be a file path or a numpy array or list/tuple")
|
Loading…
x
Reference in New Issue
Block a user