# GraphicsAI / app.py
# Source: Hugging Face Space file page - commit d8988f8 ("Update app.py",
# verified) by gaur3009; file size 3.48 kB; Hub virus scan: no virus found.
import gradio as gr
from gradio_imageslider import ImageSlider
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# Allow reduced-precision (e.g. TF32) float32 matmuls for speed; "highest"
# would keep full float32 precision at the cost of throughput.
torch.set_float32_matmul_precision("high")

# Use the GPU when one is present, otherwise fall back to the CPU so the app
# still starts on CPU-only hosts (inference is just slower there).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load BiRefNet model for background removal. trust_remote_code lets
# transformers execute the model code shipped in the Hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(DEVICE)
birefnet.eval()  # inference only: make sure dropout/batch-norm are in eval mode

# BiRefNet preprocessing: fixed 1024x1024 input, float tensor in [0, 1],
# then per-channel normalization with the ImageNet mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
def load_img(image, output_type="numpy"):
    """Load an image and convert it to RGB.

    Args:
        image: A file path or file-like object readable by ``PIL.Image.open``,
            an already-decoded ``PIL.Image.Image``, or a NumPy array (the
            format ``gr.Image(type="numpy")`` delivers).
        output_type: ``"pil"`` returns a ``PIL.Image.Image``; any other value
            returns a ``numpy.ndarray``.

    Returns:
        The RGB image, as a PIL image or NumPy array depending on
        ``output_type``.
    """
    if isinstance(image, Image.Image):
        rgb = image.convert("RGB")
    elif isinstance(image, np.ndarray):
        rgb = Image.fromarray(image).convert("RGB")
    else:
        # Image.open is lazy and keeps the file open; convert() copies the
        # pixels, so the handle can be closed as soon as it returns.
        with Image.open(image) as opened:
            rgb = opened.convert("RGB")
    if output_type == "pil":
        return rgb
    return np.array(rgb)
def add_text_to_image(image, text, position, color, font_size):
    """Draw ``text`` onto a copy of ``image`` and return the result.

    Args:
        image: Image as a NumPy array accepted by ``PIL.Image.fromarray``.
        text: String to draw. Empty or ``None`` draws nothing.
        position: ``(x, y)`` of the text's top-left anchor, in pixels.
        color: Fill color understood by Pillow (e.g. ``"#ff0000"``).
        font_size: Font size in points. Coerced to ``int`` because sliders
            may deliver floats.

    Returns:
        A new NumPy array containing the image with the text drawn on it.
    """
    if not text:
        # Nothing to draw; still hand back a copy like the drawing path does.
        return np.array(image)
    img = Image.fromarray(image)
    draw = ImageDraw.Draw(img)
    size = int(font_size)
    try:
        font = ImageFont.truetype("arial.ttf", size)
    except OSError:
        # arial.ttf is usually absent outside Windows; use Pillow's bundled
        # font instead of failing the whole request.
        try:
            font = ImageFont.load_default(size=size)  # Pillow >= 10.1
        except TypeError:
            font = ImageFont.load_default()  # older Pillow: fixed-size bitmap font
    draw.text(position, text, fill=color, font=font)
    return np.array(img)
def inpaint_image(image, mask, inpaint_radius):
    """Fill the masked region of ``image`` with OpenCV's Telea inpainting.

    Args:
        image: uint8 image array in RGB order. RGBA and single-channel arrays
            are also accepted and converted to RGB first; alpha is dropped.
        mask: Array marking the region to fill. Pixels whose grayscale value
            is non-zero are inpainted. It may be single-channel, RGB or RGBA,
            and it is resized (nearest-neighbour) to the image size if needed.
        inpaint_radius: Radius of the neighbourhood around each inpainted
            point that the algorithm considers.

    Returns:
        The inpainted image as an RGB uint8 array with the same height and
        width as ``image``.
    """
    img = np.asarray(image)
    if img.ndim == 2:
        to_bgr = cv2.COLOR_GRAY2BGR
    elif img.shape[2] == 4:
        to_bgr = cv2.COLOR_RGBA2BGR
    else:
        to_bgr = cv2.COLOR_RGB2BGR
    img_cv = cv2.cvtColor(img, to_bgr)

    mask_arr = np.asarray(mask)
    if mask_arr.ndim == 3:
        if mask_arr.shape[2] == 1:
            mask_arr = mask_arr[:, :, 0]
        else:
            to_gray = (
                cv2.COLOR_RGBA2GRAY if mask_arr.shape[2] == 4 else cv2.COLOR_RGB2GRAY
            )
            mask_arr = cv2.cvtColor(mask_arr, to_gray)
    # cv2.inpaint requires an 8-bit single-channel mask; non-zero means "fill".
    mask_cv = np.where(mask_arr > 0, 255, 0).astype(np.uint8)
    height, width = img_cv.shape[:2]
    if mask_cv.shape != (height, width):
        mask_cv = cv2.resize(
            mask_cv, (width, height), interpolation=cv2.INTER_NEAREST
        )

    result = cv2.inpaint(img_cv, mask_cv, float(inpaint_radius), cv2.INPAINT_TELEA)
    return cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
def background_removal(image):
    """Cut out the foreground of ``image`` with the BiRefNet model.

    Args:
        image: A ``PIL.Image.Image``, a NumPy image array, or a path/file
            object readable by ``load_img``.

    Returns:
        ``(cutout, original)``:
        - ``cutout`` is the RGB image with BiRefNet's predicted mask
          attached as its alpha channel.
        - ``original`` is an unmodified RGB copy of the input.
    """
    if isinstance(image, Image.Image):
        im = image.convert("RGB")
    elif isinstance(image, np.ndarray):
        im = Image.fromarray(image).convert("RGB")
    else:
        im = load_img(image, output_type="pil")
    image_size = im.size
    origin = im.copy()

    # torchvision's Resize works on PIL images or tensors (not NumPy arrays),
    # so the PIL image goes straight into the transform. Inputs follow the
    # model's device instead of assuming CUDA.
    device = next(birefnet.parameters()).device
    input_images = transform_image(im).unsqueeze(0).to(device)
    with torch.no_grad():
        # The model output is indexed with [-1]; presumably the final
        # (finest) prediction - confirm against the BiRefNet model card.
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # Sigmoid output lies in [0, 1]; ToPILImage scales it to an 8-bit mask,
    # which is then resized back to the input resolution.
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    im.putalpha(mask)
    return (im, origin)
def update_image(
    image, text, color, font_size, mask_image, inpaint_radius, position=(50, 50)
):
    """Apply the "Design Editing" tab's edits: optional inpainting, then text.

    Args:
        image: Input image as a NumPy array, or ``None`` when nothing was uploaded.
        text: Text to draw.
        color: Text color.
        font_size: Text size in points.
        mask_image: Optional mask; non-zero pixels are inpainted.
        inpaint_radius: Radius passed to the inpainting algorithm.
        position: ``(x, y)`` of the text anchor. Defaults to the previously
            hard-coded ``(50, 50)``.

    Returns:
        The edited image as a NumPy array, or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        return None
    result = np.asarray(image)
    if mask_image is not None:
        # Inpaint before drawing the text: the opposite order would let the
        # inpainting erase any text that overlaps the masked region.
        result = inpaint_image(result, np.array(mask_image), inpaint_radius)
    return add_text_to_image(result, text, position, color, font_size)
def fn(image):
    """Gradio callback for the "Background Removal" tab.

    Forwards ``image`` to ``background_removal`` and passes its
    ``(image, original)`` pair back to the caller unchanged.
    """
    outcome = background_removal(image)
    return outcome
# Each Interface gets its own component instances: a Gradio component can be
# rendered only once, so the two tabs must not share an input component.

# --- Background Removal tab -------------------------------------------------
# fn returns an (image, original) pair, which is exactly what ImageSlider shows.
slider1 = ImageSlider(label="Original Image", type="pil")
image_input = gr.Image(type="filepath", label="Upload an image for background removal")

bg_removal_interface = gr.Interface(
    fn, inputs=image_input, outputs=slider1, examples=["chameleon.jpg"]
)

# --- Design Editing tab -----------------------------------------------------
# update_image works on NumPy arrays and returns a single image, so the output
# is a plain gr.Image rather than a two-image ImageSlider.
design_image_input = gr.Image(type="numpy", label="Upload an image to edit")
text_input = gr.Textbox(label="Enter Text to Add", placeholder="Your text here...")
color_input = gr.ColorPicker(value="#000000", label="Text Color")
font_size_input = gr.Slider(minimum=10, maximum=100, value=32, step=1, label="Font Size")
# Every input may be left empty (delivered as None); the old `optional=`
# keyword no longer exists in current Gradio.
mask_input = gr.Image(type="numpy", label="Upload Mask Image (for Inpainting)")
inpaint_radius_input = gr.Slider(minimum=1, maximum=50, value=3, label="Inpaint Radius")
design_output = gr.Image(type="numpy", label="Processed Image")

design_editing_interface = gr.Interface(
    fn=update_image,
    inputs=[
        design_image_input,
        text_input,
        color_input,
        font_size_input,
        mask_input,
        inpaint_radius_input,
    ],
    outputs=design_output,
)

demo = gr.TabbedInterface(
    [bg_removal_interface, design_editing_interface],
    ["Background Removal", "Design Editing"],
    title="Advanced Image Editor",
)

if __name__ == "__main__":
    demo.launch()