File size: 3,166 Bytes
40bae10
 
 
 
 
b79f629
40bae10
 
 
 
 
b79f629
 
 
 
 
 
 
 
 
40bae10
 
446fd2b
cdbf891
 
 
 
 
40bae10
 
 
a4e53a8
40bae10
b79f629
40bae10
c51c462
 
40bae10
 
 
 
 
 
 
 
 
 
 
 
0d66e01
 
 
 
 
b79f629
0d66e01
b79f629
 
 
 
 
40bae10
 
 
 
 
 
 
 
b79f629
40bae10
cdbf891
 
40bae10
e5dfc32
40bae10
b79f629
40bae10
 
 
 
b79f629
40bae10
b79f629
40bae10
 
 
 
b79f629
40bae10
 
 
b79f629
40bae10
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import cv2
import matplotlib.pyplot as plt
import copy
import numpy as np
import gradio as gr
import json  # Import json module
from src import model
from src import util
from src.body import Body
from src.hand import Hand

# This function will generate and save the pose data as JSON
def save_json(candidate, subset, json_file_path='./pose_data.json'):
    """Write the pose arrays to a JSON file and return that file's path.

    Args:
        candidate: Object exposing ``tolist()`` (e.g. a NumPy array); stored
            under the ``'candidate'`` key.
        subset: Object exposing ``tolist()``; stored under the ``'subset'``
            key.
        json_file_path: Destination file. An existing file is overwritten.

    Returns:
        ``json_file_path``, unchanged.
    """
    # Convert to plain nested lists first: array objects are not
    # JSON-serialisable as they are.
    payload = dict(candidate=candidate.tolist(), subset=subset.tolist())
    with open(json_file_path, 'w') as sink:
        sink.write(json.dumps(payload))
    return json_file_path

# Lazily-created (body_estimator, hand_estimator) pair shared by all requests.
_pose_models = None


def _get_pose_models():
    """Return the body and hand estimators, constructing them on first use."""
    global _pose_models
    if _pose_models is None:
        # Build both objects before publishing the tuple so that a concurrent
        # caller can never observe a half-initialised pair.
        _pose_models = (
            Body('model/body_pose_model.pth'),
            Hand('model/hand_pose_model.pth'),
        )
    return _pose_models


def pose_estimation(test_image):
    """Estimate body and hand poses in an image and save the results to disk.

    The detected pose is drawn on a zero-filled canvas with the same shape as
    the input, written to ``./out.jpg``, and the raw body keypoint data is
    written through ``save_json``.

    Args:
        test_image: Image array in RGB channel order (it is converted to BGR
            here), or ``None`` when the user submitted without uploading.

    Returns:
        Tuple ``(out_image_path, json_file_path)``: the path of the rendered
        pose image and the path returned by ``save_json``.

    Raises:
        gr.Error: If ``test_image`` is ``None``.
        RuntimeError: If OpenCV reports that the output image was not written.
    """
    if test_image is None:
        # Surface a readable message in the UI instead of an opaque OpenCV
        # failure from cvtColor(None).
        raise gr.Error("Please upload an image before starting the estimation.")

    oriImg = cv2.cvtColor(test_image, cv2.COLOR_RGB2BGR)

    # Reuse the estimators across calls; constructing them from the .pth files
    # on every request is wasted work.
    body_estimation, hand_estimation = _get_pose_models()

    # Body pose first: its output also drives the hand-region detection.
    candidate, subset = body_estimation(oriImg)
    canvas = np.zeros_like(oriImg)
    canvas = util.draw_bodypose(canvas, candidate, subset)
    hands_list = util.handDetect(candidate, subset, oriImg)

    all_hand_peaks = []
    for x, y, w, _is_left in hands_list:
        # Run the hand model on the square crop at (x, y) with side w.
        peaks = hand_estimation(oriImg[y:y+w, x:x+w, :])
        # Shift crop-relative coordinates back into full-image coordinates.
        # Entries equal to 0 are left untouched -- presumably they mark
        # "keypoint not found"; confirm against src.hand.Hand.
        peaks[:, 0] = np.where(peaks[:, 0]==0, peaks[:, 0], peaks[:, 0]+x)
        peaks[:, 1] = np.where(peaks[:, 1]==0, peaks[:, 1], peaks[:, 1]+y)
        all_hand_peaks.append(peaks)

    canvas = util.draw_handpose(canvas, all_hand_peaks)

    out_image_path = './out.jpg'
    # imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(out_image_path, canvas):
        raise RuntimeError(f"Could not write result image to {out_image_path}")

    # Save JSON data and return its path
    json_file_path = save_json(candidate, subset)

    return out_image_path, json_file_path

# Read an image file from disk and return its raw bytes
def convert_image_to_bytes(image_path):
    """Return the complete binary contents of the file at ``image_path``.

    Args:
        image_path: Path of the file to read; it is opened in binary mode.

    Returns:
        The file's contents as a ``bytes`` object.
    """
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    return raw_bytes

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Pose Estimation")
    with gr.Row():
        # Input arrives as a NumPy array; results are an image and a
        # downloadable JSON file.
        image = gr.Image(label="Upload Image", type="numpy")
        output_image = gr.Image(label="Estimation Result")
        output_json = gr.File(label="Download Pose Data as JSON", type="filepath")
    submit_button = gr.Button("Start Estimation")

    # Run pose estimation and display results when the button is clicked.
    submit_button.click(
        pose_estimation,
        inputs=[image],
        outputs=[output_image, output_json],
    )

    clear_button = gr.Button("Clear")

    def clear_outputs():
        """Return one empty value per output so both result widgets reset.

        Output components are updated from the handler's return values, so
        the reset has to be expressed as returned ``None`` values rather than
        as method calls on the components.
        """
        return None, None

    # Clear the results.
    clear_button.click(
        clear_outputs,
        inputs=[],
        outputs=[output_image, output_json],
    )

if __name__ == "__main__":
    demo.launch(debug=True)