Spaces:
Sleeping
Sleeping
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import copy | |
| import numpy as np | |
| import gradio as gr | |
| import json # Import json module | |
| from src import model | |
| from src import util | |
| from src.body import Body | |
| from src.hand import Hand | |
def save_json(candidate, subset, json_file_path='./pose_data.json'):
    """Serialize body-pose detection results to a JSON file.

    Both inputs are passed through ``np.asarray`` before ``tolist()``, so
    NumPy arrays and plain nested lists are both accepted.

    Args:
        candidate: Array-like of detected keypoints, as returned first by the
            ``Body`` estimator (exact column layout not visible here — TODO
            confirm against ``src.body``).
        subset: Array-like returned second by the ``Body`` estimator
            (presumably the per-person keypoint index table — TODO confirm).
        json_file_path: Destination file; it is overwritten if it exists.

    Returns:
        ``json_file_path``, the path the JSON document was written to.
    """
    pose_data = {
        'candidate': np.asarray(candidate).tolist(),
        'subset': np.asarray(subset).tolist(),
    }
    with open(json_file_path, 'w', encoding='utf-8') as json_file:
        json.dump(pose_data, json_file)
    return json_file_path
# Lazily loaded (body_estimation, hand_estimation) pair, reused by every call
# so the model checkpoints are read from disk only once per process.
_estimators = None


def _get_estimators():
    """Load the body and hand pose models on first use and cache them."""
    global _estimators
    if _estimators is None:
        _estimators = (
            Body('model/body_pose_model.pth'),
            Hand('model/hand_pose_model.pth'),
        )
    return _estimators


def pose_estimation(test_image):
    """Run body and hand pose estimation on an uploaded image.

    Args:
        test_image: RGB image array, as delivered by
            ``gr.Image(type="numpy")``; ``None`` when nothing was uploaded.

    Returns:
        Tuple ``(out_image_path, json_file_path)``:
        - the skeleton rendering (drawn on a black canvas of the input's
          size) written to ``./out.jpg``;
        - the body-pose JSON path returned by ``save_json``.

    Raises:
        gr.Error: if no image was provided.
    """
    if test_image is None:
        raise gr.Error("Please upload an image first.")

    body_estimation, hand_estimation = _get_estimators()

    # The estimators work on OpenCV's B,G,R channel order; the UI gives R,G,B.
    oriImg = cv2.cvtColor(test_image, cv2.COLOR_RGB2BGR)

    candidate, subset = body_estimation(oriImg)
    # Draw on a blank canvas so the output shows only the skeleton.
    canvas = np.zeros_like(oriImg)
    canvas = util.draw_bodypose(canvas, candidate, subset)

    hands_list = util.handDetect(candidate, subset, oriImg)
    all_hand_peaks = []
    for x, y, w, _is_left in hands_list:
        peaks = hand_estimation(oriImg[y:y+w, x:x+w, :])
        # Shift crop-local coordinates back into full-image coordinates;
        # zeros (presumably "keypoint not found") are left as-is.
        peaks[:, 0] = np.where(peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x)
        peaks[:, 1] = np.where(peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y)
        all_hand_peaks.append(peaks)
    canvas = util.draw_handpose(canvas, all_hand_peaks)

    # canvas is still B,G,R, which is the order cv2.imwrite expects.
    out_image_path = './out.jpg'
    cv2.imwrite(out_image_path, canvas)

    json_file_path = save_json(candidate, subset)
    return out_image_path, json_file_path
def convert_image_to_bytes(image_path):
    """Return the entire contents of the file at *image_path* as bytes.

    The file is opened in binary mode, so no decoding is performed.
    """
    with open(image_path, mode="rb") as handle:
        payload = handle.read()
    return payload
def clear_outputs():
    """Reset the result image and the JSON download.

    Gradio clears an output component when its handler returns ``None`` for
    it. Calling ``.clear()`` on a component does not do this: on these
    components ``.clear`` registers an event listener rather than resetting
    the value.
    """
    return None, None


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Pose Estimation")
    with gr.Row():
        image = gr.Image(label="Upload Image", type="numpy")
        output_image = gr.Image(label="Estimation Result")
        output_json = gr.File(label="Download Pose Data as JSON", type="filepath")

    submit_button = gr.Button("Start Estimation")
    # Run pose estimation and display the results when the button is clicked.
    submit_button.click(
        pose_estimation,
        inputs=[image],
        outputs=[output_image, output_json],
    )

    # Clear the results (the uploaded input image is kept).
    clear_button = gr.Button("Clear")
    clear_button.click(
        clear_outputs,
        inputs=[],
        outputs=[output_image, output_json],
    )

if __name__ == "__main__":
    demo.launch(debug=True)