File size: 4,298 Bytes
8616c88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40c9160
 
 
8616c88
40c9160
 
 
8616c88
40c9160
 
 
 
 
 
 
 
 
 
8616c88
40c9160
 
 
 
 
 
 
 
 
 
 
 
8616c88
40c9160
 
 
 
 
 
 
 
8616c88
40c9160
8616c88
 
 
 
 
40c9160
8616c88
 
 
40c9160
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import os

# 1. 모델 아키텍처 재정의 (저장된 가중치를 불러오기 위함)
class DynamicConv2D(layers.Layer):
    def __init__(self, k=3, **kwargs):
        super().__init__(**kwargs)
        assert k % 2 == 1
        self.k = k
        self.generator = layers.Dense(k * k)

    def call(self, x):
        B, H, W, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
        kernels = self.generator(x)
        kernels = tf.nn.softmax(kernels, axis=-1)
        pad = (self.k - 1) // 2
        x_pad = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
        patches = tf.image.extract_patches(
            images=x_pad,
            sizes=[1, self.k, self.k, 1],
            strides=[1, 1, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        patches = tf.reshape(patches, [B, H, W, self.k * self.k, C])
        kernels_exp = tf.expand_dims(kernels, axis=-1)
        return tf.reduce_sum(patches * kernels_exp, axis=3)

def build_dynamic_model(input_shape=(28, 28, 1), num_classes=26):
    inputs = layers.Input(shape=input_shape)
    
    # 일반 Conv로 기초 특징 추출
    x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    
    # Dynamic Convolution 적용
    x = DynamicConv2D(k=3)(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = DynamicConv2D(k=3)(x)
    x = layers.Activation('relu')(x)
    x = layers.GlobalAveragePooling2D()(x)
    
    x = layers.Dense(128, activation='relu')(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    
    return models.Model(inputs, outputs)

# 2. 모델 로드 (파일 경로 확인 필요)
model = build_dynamic_model()
# 파일 경로를 절대 경로로 설정
weights_path = 'dynamic_conv_alphabet.weights.h5'

if os.path.exists(weights_path):
    model.load_weights(weights_path)
    print(f"가중치 로드 성공: {weights_path}")
else:
    print(f"오류: {weights_path} 파일을 찾을 수 없습니다. 먼저 학습을 수행하세요.")
# 3. 예측 함수 정의

# ... (상단 모델 정의 및 로드 부분은 동일)

def classify_alphabet(sketchpad):
    # 수정: sketchpad가 None이거나 데이터가 없는 경우 처리
    if sketchpad is None or (isinstance(sketchpad, dict) and sketchpad.get("composite") is None):
        return {"상태": "글씨를 기다리는 중..."}
    
    try:
        # 3-1. 이미지 데이터 추출
        # Gradio 버전에 따라 구조가 다를 수 있으므로 안전하게 접근
        img_data = sketchpad["composite"]
        
        # 만약 이미지가 투명도가 없는 3채널이라면 흑백 전환, 4채널이면 Alpha 사용
        if img_data.shape[-1] == 4:
            img = img_data[:, :, 3] # Alpha 채널 (글씨 부분)
        else:
            img = tf.image.rgb_to_grayscale(img_data)[:, :, 0]

        # 3-2. 전처리
        img = tf.cast(img, tf.float32)
        img = tf.image.resize(tf.expand_dims(img, axis=-1), (28, 28))
        
        # EMNIST 데이터 방향에 맞게 전치(Transpose)
        img = tf.image.transpose(img) 
        
        img = img / 255.0
        img = tf.expand_dims(img, axis=0)

        # 3-3. 추론
        preds = model.predict(img, verbose=0)[0]
        
        results = {}
        for i in range(26):
            char = chr(ord('A') + i)
            results[char] = float(preds[i])
            
        return results
    except Exception as e:
        return {"에러": str(e)}

# 4. Gradio 인터페이스 설정 수정
interface = gr.Interface(
    fn=classify_alphabet,
    inputs=gr.Sketchpad(label="알파벳을 그려보세요 (A-Z)", type="numpy"),
    outputs=gr.Label(num_top_classes=3, label="예측 결과"),
    title="Dynamic Conv Alphabet Recognizer",
    live=True
)

if __name__ == "__main__":
    # 수정: SSR 모드와 Hot Reload 관련 에러를 방지하기 위해 설정을 추가합니다.
    interface.launch(
        server_name="0.0.0.0", 
        ssr_mode=False  # 로그에 나온 SSR 관련 이슈 방지
    )