"""Gradio demo for Sapiens Pose-308 estimation.

Pipeline: an RTMDet person detector (run on CPU) finds people, each detection
is cropped and fed to a Sapiens TorchScript pose model on CUDA, and the
resulting Goliath keypoints / skeleton are drawn back onto the input image.
The per-person keypoints are also offered as a downloadable JSON file.
"""
import os
import sys
import json
import tempfile
import subprocess
import importlib.util
from typing import List

# NOTE(review): `spaces` is kept ahead of torch, matching the original import order.
import spaces
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import cv2
from gradio.themes.utils import sizes
from classes_and_palettes import (
    COCO_KPTS_COLORS,
    COCO_WHOLEBODY_KPTS_COLORS,
    GOLIATH_KPTS_COLORS,
    GOLIATH_SKELETON_INFO,
    GOLIATH_KEYPOINTS,
)


def is_package_installed(package_name):
    """Return True if *package_name* is importable in this interpreter."""
    return importlib.util.find_spec(package_name) is not None


def find_wheel(package_path):
    """Return the path of a ``.whl`` file in ``<package_path>/dist``, or None.

    Candidates are sorted by name so the choice is deterministic when the
    directory holds more than one wheel.
    """
    dist_dir = os.path.join(package_path, "dist")
    if not os.path.exists(dist_dir):
        return None
    wheel_files = sorted(f for f in os.listdir(dist_dir) if f.endswith(".whl"))
    return os.path.join(dist_dir, wheel_files[0]) if wheel_files else None


def install_from_wheel(package_name, package_path):
    """pip-install *package_name* from its prebuilt wheel.

    Exits the process with status 1 if no wheel is found under *package_path*.
    """
    wheel_file = find_wheel(package_path)
    if wheel_file is None:
        print(f"{package_name} wheel not found in {package_path}. Please build it first.")
        sys.exit(1)
    print(f"Installing {package_name} from wheel: {wheel_file}")
    subprocess.check_call([sys.executable, "-m", "pip", "install", wheel_file])


def install_local_packages():
    """Install the vendored mm* packages from local wheels if they are missing."""
    packages = [
        ("mmengine", "./external/engine"),
        ("mmcv", "./external/cv"),
        ("mmdet", "./external/det"),
    ]
    for package_name, package_path in packages:
        if is_package_installed(package_name):
            print(f"{package_name} is already installed.")
            continue
        print(f"Installing {package_name}...")
        install_from_wheel(package_name, package_path)


# Run the installation at the start of your app; this must happen before the
# detector_utils import below, which presumably depends on the mm* packages.
install_local_packages()

from detector_utils import (  # noqa: E402  (must follow the install step)
    adapt_mmdet_pipeline,
    init_detector,
    process_images_detector,
)


class Config:
    """Static paths and model settings used by the demo."""

    ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
    # Model-size label (shown in the UI dropdown) -> TorchScript checkpoint file.
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_goliath_best_goliath_AP_575_torchscript.pt2",
        "0.6b": "sapiens_0.6b_goliath_best_goliath_AP_600_torchscript.pt2",
        "1b": "sapiens_1b_goliath_best_goliath_AP_640_torchscript.pt2",
    }
    DETECTION_CHECKPOINT = os.path.join(
        CHECKPOINTS_DIR, "rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth"
    )
    DETECTION_CONFIG = os.path.join(ASSETS_DIR, "rtmdet_m_640-8xb32_coco-person_no_nms.py")
    # Pose-model input size as (height, width), passed to transforms.Resize.
    INPUT_SIZE = (1024, 768)
    # Heatmap resolution as (width, height) that keypoint coordinates are
    # scaled from when drawing. Assumed to be INPUT_SIZE / 4 — confirm against
    # the checkpoint's actual output shape.
    HEATMAP_SIZE = (192, 256)


class ModelManager:
    """Helpers for loading and running the TorchScript pose models."""

    @staticmethod
    def load_model(checkpoint_name: str):
        """Load *checkpoint_name* from CHECKPOINTS_DIR onto CUDA in eval mode.

        Returns None when *checkpoint_name* is None.
        """
        if checkpoint_name is None:
            return None
        checkpoint_path = os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name)
        model = torch.jit.load(checkpoint_path)
        model.eval()
        model.to("cuda")
        return model

    @staticmethod
    @torch.inference_mode()
    def run_model(model, input_tensor):
        """Run *model* on *input_tensor* without autograd tracking."""
        return model(input_tensor)


class ImageProcessor:
    """Person detection, per-person pose estimation and visualization."""

    def __init__(self):
        self.transform = transforms.Compose([
            transforms.Resize(Config.INPUT_SIZE),
            transforms.ToTensor(),
            # Mean/std are given in 0-255 units and rescaled, because
            # ToTensor maps pixel values into [0, 1].
            transforms.Normalize(
                mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
                std=[58.5 / 255, 57.0 / 255, 57.5 / 255],
            ),
        ])
        self.detector = init_detector(
            Config.DETECTION_CONFIG, Config.DETECTION_CHECKPOINT, device="cpu"
        )
        self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)

    def detect_persons(self, image: Image.Image):
        """Return the person bboxes detected in *image* (see get_person_bboxes)."""
        # Convert the PIL image to a numpy array and add a leading batch axis.
        # NOTE(review): the array keeps PIL channel order (RGB for the UI
        # input); confirm process_images_detector does not expect BGR.
        batch = np.expand_dims(np.array(image), axis=0)
        bboxes_batch = process_images_detector(batch, self.detector)
        # A single image was submitted, so only the first result is relevant.
        return self.get_person_bboxes(bboxes_batch[0])

    def get_person_bboxes(self, bboxes_batch, score_thr=0.3):
        """Filter detections into ``[x1, y1, x2, y2, score]`` boxes.

        5-element boxes are kept when their score exceeds *score_thr*;
        4-element boxes are kept with a default score of 1.0. Other lengths
        are dropped.
        """
        person_bboxes = []
        for bbox in bboxes_batch:
            if len(bbox) == 5:  # [x1, y1, x2, y2, score]
                if bbox[4] > score_thr:
                    person_bboxes.append(bbox)
            elif len(bbox) == 4:  # [x1, y1, x2, y2]
                # Build a new list: `bbox + [1.0]` would add 1.0 element-wise
                # (shifting the box) if bbox is a numpy array.
                person_bboxes.append([*bbox, 1.0])
        return person_bboxes

    @spaces.GPU
    @torch.inference_mode()
    def estimate_pose(self, image: Image.Image, bboxes: List[List[float]], model_name: str, kpt_threshold: float):
        """Estimate keypoints for every bbox and draw them on a copy of *image*.

        Returns:
            (annotated PIL image, list with one keypoint dict per processed bbox)
        """
        # NOTE(review): the checkpoint is reloaded from disk on every request;
        # consider caching if the GPU-session semantics of `spaces` allow it.
        pose_model = ModelManager.load_model(Config.CHECKPOINTS[model_name])
        result_image = image.copy()
        all_keypoints = []  # One keypoint dict per processed person.
        for bbox in bboxes:
            x1, y1, x2, y2 = self._bbox_corners(bbox)
            if x2 <= x1 or y2 <= y1:
                # Zero-area box: the empty crop cannot be resized to the model input.
                continue
            # Crop from the untouched input rather than result_image, so the
            # boxes/skeletons drawn for earlier people never leak into this crop.
            cropped_img = self.crop_image(image, bbox)
            input_tensor = self.transform(cropped_img).unsqueeze(0).to("cuda")
            heatmaps = ModelManager.run_model(pose_model, input_tensor)
            keypoints = self.heatmaps_to_keypoints(heatmaps[0].cpu().numpy())
            all_keypoints.append(keypoints)
            result_image = self.draw_keypoints(result_image, keypoints, bbox, kpt_threshold)
        return result_image, all_keypoints

    def process_image(self, image: Image.Image, model_name: str, kpt_threshold: str):
        """Detect people in *image*, then estimate and draw their poses."""
        bboxes = self.detect_persons(image)
        return self.estimate_pose(image, bboxes, model_name, float(kpt_threshold))

    @staticmethod
    def _bbox_corners(bbox):
        """Return integer ``(x1, y1, x2, y2)`` from a 4- or 5+-element bbox.

        Raises:
            ValueError: if *bbox* has fewer than 4 elements.
        """
        if len(bbox) < 4:
            raise ValueError(f"Unexpected bbox format: {bbox}")
        x1, y1, x2, y2 = (int(v) for v in bbox[:4])
        return x1, y1, x2, y2

    def crop_image(self, image, bbox):
        """Return the region of PIL *image* covered by *bbox*."""
        return image.crop(self._bbox_corners(bbox))

    @staticmethod
    def heatmaps_to_keypoints(heatmaps):
        """Convert per-joint heatmaps into ``{name: (x, y, confidence)}``.

        *heatmaps* is indexed as ``heatmaps[joint][y, x]``; each keypoint is the
        location of its heatmap's maximum (first occurrence on ties), and the
        confidence is the heatmap value there. Only the first
        ``min(num_joints, len(GOLIATH_KEYPOINTS))`` joints are returned.
        """
        num_kpts = min(heatmaps.shape[0], len(GOLIATH_KEYPOINTS))  # Expected 308.
        height, width = heatmaps.shape[1:]
        flat = heatmaps[:num_kpts].reshape(num_kpts, height * width)
        peak_idx = flat.argmax(axis=1)
        ys, xs = np.unravel_index(peak_idx, (height, width))
        confs = flat[np.arange(num_kpts), peak_idx]
        return {
            name: (float(x), float(y), float(conf))
            for name, x, y, conf in zip(GOLIATH_KEYPOINTS, xs, ys, confs)
        }

    @staticmethod
    def draw_keypoints(image, keypoints, bbox, kpt_threshold):
        """Draw *bbox*, confident keypoints and skeleton links onto *image*.

        Keypoint coordinates are in heatmap space (Config.HEATMAP_SIZE) and are
        rescaled into the bbox. Returns a new PIL image.
        """
        canvas = np.array(image)
        x1, y1, x2, y2 = ImageProcessor._bbox_corners(bbox)
        bbox_width = x2 - x1
        bbox_height = y2 - y1
        hm_width, hm_height = Config.HEATMAP_SIZE

        def to_image_coords(x, y):
            # Map a heatmap-space point into the bbox region of the full image.
            return int(x * bbox_width / hm_width) + x1, int(y * bbox_height / hm_height) + y1

        # Scale marker radius / line thickness with the bbox size.
        bbox_size = np.sqrt(bbox_width * bbox_height)
        radius = max(1, int(bbox_size * 0.006))  # minimum 1 pixel
        thickness = max(1, int(bbox_size * 0.006))  # minimum 1 pixel
        bbox_thickness = max(1, thickness // 4)

        cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), bbox_thickness)

        # Keypoints: the color is chosen by the keypoint's position in the dict.
        for i, (x, y, conf) in enumerate(keypoints.values()):
            if conf > kpt_threshold and i < len(GOLIATH_KPTS_COLORS):
                cv2.circle(canvas, to_image_coords(x, y), radius, GOLIATH_KPTS_COLORS[i], -1)

        # Skeleton: draw a link only when both endpoints are confident.
        for link_info in GOLIATH_SKELETON_INFO.values():
            pt1_name, pt2_name = link_info["link"]
            if pt1_name not in keypoints or pt2_name not in keypoints:
                continue
            pt1 = keypoints[pt1_name]
            pt2 = keypoints[pt2_name]
            if pt1[2] > kpt_threshold and pt2[2] > kpt_threshold:
                cv2.line(
                    canvas,
                    to_image_coords(pt1[0], pt1[1]),
                    to_image_coords(pt2[0], pt2[1]),
                    link_info["color"],
                    thickness=thickness,
                )

        return Image.fromarray(canvas)


class GradioInterface:
    """Builds the Gradio Blocks UI around an ImageProcessor."""

    def __init__(self):
        self.image_processor = ImageProcessor()

    def create_interface(self):
        """Create and return the demo's ``gr.Blocks`` app."""
        app_styles = """ """
        header_html = f"""
        {app_styles}

        Sapiens: Pose Estimation

        ECCV 2024 (Oral)

        Meta presents Sapiens, foundation models for human tasks pretrained on 300 million human images. This demo showcases the finetuned pose estimation model.

        """
        # Reload the page with `?__theme=dark` unless it is already set.
        js_func = """
        function refresh() {
            const url = new URL(window.location);
            if (url.searchParams.get('__theme') !== 'dark') {
                url.searchParams.set('__theme', 'dark');
                window.location.href = url.href;
            }
        }
        """

        def process_image(image, model_name, kpt_threshold):
            """Run pose estimation and dump the keypoints to a temp JSON file."""
            if image is None:
                raise gr.Error("Please upload an input image first.")
            result_image, keypoints = self.image_processor.process_image(image, model_name, kpt_threshold)
            # delete=False: the path is returned to the gr.File output, so the
            # file must outlive this function.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w") as json_file:
                json.dump(keypoints, json_file)
            return result_image, json_file.name

        images_dir = os.path.join(Config.ASSETS_DIR, "images")

        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row():
                        model_name = gr.Dropdown(
                            label="Model Size",
                            choices=list(Config.CHECKPOINTS.keys()),
                            value="1b",
                        )
                        kpt_threshold = gr.Dropdown(
                            label="Min Keypoint Confidence",
                            choices=["0.1", "0.2", "0.3", "0.4", "0.5", "0.6", "0.7", "0.8", "0.9"],
                            value="0.3",
                        )
                    gr.Examples(
                        inputs=input_image,
                        examples_per_page=14,
                        # Sorted so the gallery order is stable across runs.
                        examples=[os.path.join(images_dir, img) for img in sorted(os.listdir(images_dir))],
                    )
                with gr.Column():
                    result_image = gr.Image(label="Pose-308 Result", type="pil", elem_classes="image-preview")
                    json_output = gr.File(label="Pose-308 Output (.json)")
                    run_button = gr.Button("Run")

            run_button.click(
                fn=process_image,
                inputs=[input_image, model_name, kpt_threshold],
                outputs=[result_image, json_output],
            )

        return demo


def main():
    """Enable TF32 when CUDA is available, then build and launch the demo."""
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    interface = GradioInterface()
    demo = interface.create_interface()
    demo.launch(share=False)


if __name__ == "__main__":
    main()