From 7f1394e5141304dfaccc19a53efc93d2125956a2 Mon Sep 17 00:00:00 2001 From: Tiger Ren Date: Thu, 27 Jun 2024 14:47:09 +0800 Subject: [PATCH] add Stable diffusion support --- .gitignore | 4 +++ csv_reader.py | 9 +++++-- image_downloader.py | 6 +++-- sd_gen.py | 65 +++++++++++++++++++++++++++++++++++++++++++++ sd_model.py | 33 +++++++++++++++++++++++ 5 files changed, 113 insertions(+), 4 deletions(-) create mode 100644 sd_gen.py create mode 100644 sd_model.py diff --git a/.gitignore b/.gitignore index 5d381cc..23e8d48 100644 --- a/.gitignore +++ b/.gitignore @@ -160,3 +160,7 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +# User added ignore files +node_flow +regen_flow +backup \ No newline at end of file diff --git a/csv_reader.py b/csv_reader.py index 1fda049..8ffe270 100644 --- a/csv_reader.py +++ b/csv_reader.py @@ -1,6 +1,7 @@ import pandas as pd from avator import gen_avator from chatbot import ChatBot +from sd_gen import txt2img def log_info(id, title, description, initPrompt): print('----------id----------') @@ -26,13 +27,17 @@ df = pd.read_csv(file_path, sep=",") for index, row in df.iterrows(): retry_count = 0 max_retry_count = 3 - if index != 49: + if index != 51: continue bot = ChatBot() role_id = row['id'] role_title = row['title'] role_description = row['description'] role_initPrompt = row['initPrompt'] + print('start txt2img ') + txt2img(role_description,"", index) + print('end txt2img ') + log_info(role_id, role_title, role_description, role_initPrompt) if len(role_initPrompt) > 3000: role_initPrompt = role_initPrompt[:3000] @@ -55,7 +60,7 @@ for index, row in df.iterrows(): refined_desc = bot.talk_to_zhipu(f'"{role_description}" {safe_template}') refined_init_prompt = bot.talk_to_zhipu(f'"{role_initPrompt}" {safe_template}') success = gen_avator(row['id'], row['title'], refined_desc, refined_init_prompt, './avatar_nsfw', index=index) - + txt2img(role_description,"", index) # print(row['id'],'---', row['title'], row['description'], row['initPrompt']) print('--------------------') diff --git a/image_downloader.py b/image_downloader.py index 7bee9d8..db7694c 100644 --- a/image_downloader.py +++ b/image_downloader.py @@ -11,6 +11,8 @@ def download_image(url, filename): print(f"Error downloading image from {url}, reason: {e}") return False return True - - +def write_image_base64(b64img, filename): + import base64 + with open(filename, 'wb') as file: + file.write(base64.b64decode(b64img)) \ No newline at end of file diff --git a/sd_gen.py b/sd_gen.py new file mode 100644 index 0000000..515cad4 --- /dev/null +++ b/sd_gen.py @@ -0,0 +1,65 @@ +import os +import requests +import base64 + +base_url = 'http://192.168.2.157:7860' + +def txt2img(prompt, neg_prompt, id=None): + preset_prompt = 'realistic, masterpiece, best quality, highres, 4k, extremely detailed, ' + preset_neg_prompt = '(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)),((grayscale)), skin spots, acnes, skin blemishes, age spot, (ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), (tranny:1.331), mutated hands, (poorly drawn hands:1.5), blurry, (bad anatomy:1.21), (bad proportions:1.331), extra limbs, (disfigured:1.331), (missing arms:1.331), (extra legs:1.331), (fused fingers:1.5), (too many fingers:1.5), (unclear eyes:1.331), lowers, bad hands, missing fingers, extra digit,bad hands, missing fingers, (((extra arms and legs))),' + + # 定义请求URL + url = "http://192.168.2.157:7860/sdapi/v1/txt2img" + + # 定义请求体 + data = { + "prompt": f"{preset_prompt} ${prompt}", + "negative_prompt": f"{preset_neg_prompt}, ${neg_prompt}", + "sampler_index": "DPM++ 2M", + "save_images": "true", + "width": "512", + "height": "720", + "batch_size": 4, + # "upscaler": "latent", # Specify the upscaler type + # "upscale_factor": 2 # Set the upscaling factor to 2 + # "hr_scale": 2, + # "hr_upscaler": "latent" + } + + # 发送POST请求 + response = requests.post(url, json=data) + + # 检查请求是否成功 + if response.status_code == 200: + print("Request was successful.") + print("Response:") + print(response.json()) + # 提取base64编码的图像 + response_data = response.json() + if 'images' in response_data: + for i, image_base64 in enumerate(response_data['images']): + # 解码base64字符串 + image_data = base64.b64decode(image_base64) + + # 将图像保存到本地文件 + if id == None: + id = 0 + image_filename = f"{id}_output_image_{i}.png" + dir_path = os.path.join('.', 'regen_flow') + # if dir_path does not exist, then create the dir + if not os.path.exists(dir_path): + os.mkdir(dir_path) + + with open(os.path.join(dir_path, image_filename), 'wb') as image_file: + image_file.write(image_data) + + print(f"Image {i} saved as {image_filename}") + else: + print("No images found in the response.") + else: + print(f"Request failed with status code {response.status_code}") + print("Response:") + print(response.text) + + +# txt2img("an asian girl, kneel", "") \ No newline at end of file diff --git a/sd_model.py b/sd_model.py new file mode 100644 index 0000000..bb37b08 --- /dev/null +++ b/sd_model.py @@ -0,0 +1,33 @@ +import requests +def get_sd_model(): + url = "http://192.168.2.157:7860/sdapi/v1/sd-models" + response = requests.get(url) + if response.status_code == 200: + print("Request was successful.") + print("Response:") + print(response.json()) + + +def get_current_options(): + response = requests.get(f"http://192.168.2.157:7860/sdapi/v1/options") + if response.status_code == 200: + print(response.json()) + return response.json() + else: + print(f"Failed to get options with status code {response.status_code}") + return None + +def set_new_checkpoint(checkpoint_name): + options = get_current_options() + if options is not None: + options['sd_model_checkpoint'] = checkpoint_name + response = requests.post(f"http://192.168.2.157:7860/sdapi/v1/options", json=options) + if response.status_code == 200: + print(f"Successfully set new checkpoint to {checkpoint_name}") + else: + print(f"Failed to set new checkpoint with status code {response.status_code}") + print("Response:", response.text) + +get_current_options() +set_new_checkpoint("dreamshaper_8.safetensors [879db523c3]") +get_current_options() \ No newline at end of file