File size: 12,084 Bytes
43aec3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25b27b6
43aec3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1173e3
43aec3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a55189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43aec3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a55189
 
 
 
 
 
 
 
 
 
 
 
 
 
43aec3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1173e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
"""
BiRefNet을 사용한 실제 배경 제거 API 서버

설치 필요:
pip install fastapi uvicorn python-multipart pillow
pip install torch torchvision transformers
pip install timm einops

실행:
uvicorn server_birefnet:app --reload --host 0.0.0.0 --port 8000
"""

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageOps
import io
import torch
import numpy as np
from typing import Tuple
import logging

# 로깅 설정
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="CleanCut API", version="1.0.0")

# CORS 설정
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 글로벌 모델 변수
model = None
device = None

def load_model():
    """BiRefNet 모델 로드"""
    global model, device
    
    try:
        # GPU 사용 가능 여부 확인
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {device}")
        
        # Hugging Face에서 BiRefNet 모델 로드
        from transformers import AutoModelForImageSegmentation, AutoProcessor
        
        model_name = "ZhengPeng7/BiRefNet_HR"
        logger.info(f"Loading model: {model_name}")
        
        # 모델과 프로세서 로드
        model = AutoModelForImageSegmentation.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        model = model.to(device)
        model.eval()
        
        logger.info("Model loaded successfully")
        return True
        
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        logger.info("Using fallback mode (returning original image)")
        return False

def process_image(image: Image.Image) -> Image.Image:
    """
    BiRefNet을 사용해 배경 제거
    
    Args:
        image: PIL Image 객체
        
    Returns:
        배경이 제거된 RGBA PIL Image
    """
    try:
        if model is None:
            logger.warning("Model not loaded, returning original image with alpha channel")
            return image.convert("RGBA")
        
        # 이미지 전처리
        original_size = image.size
        
        # 모델 입력 크기로 리사이즈 (BiRefNet은 다양한 크기 지원)
        # 일반적으로 1024x1024가 좋은 성능을 보임
        input_size = (1024, 1024)
        image_resized = image.resize(input_size, Image.Resampling.LANCZOS)
        
        # NumPy 배열로 변환
        image_np = np.array(image_resized)
        
        # 정규화 (0-1 범위)
        if image_np.max() > 1:
            image_np = image_np / 255.0
        
        # 텐서로 변환 (batch_size, channels, height, width)
        image_tensor = torch.from_numpy(image_np).float()
        if len(image_tensor.shape) == 3:
            image_tensor = image_tensor.permute(2, 0, 1)  # HWC -> CHW
        image_tensor = image_tensor.unsqueeze(0)  # 배치 차원 추가
        image_tensor = image_tensor.to(device)
        
        # 모델 추론 - BiRefNet의 predict 메서드 사용
        with torch.no_grad():
            # BiRefNet은 PIL Image를 직접 받음
            try:
                # predict 메서드가 있는 경우
                if hasattr(model, 'predict'):
                    mask = model.predict(image)
                    # mask가 PIL Image인 경우 numpy로 변환
                    if isinstance(mask, Image.Image):
                        mask = np.array(mask) / 255.0
                else:
                    # 일반적인 forward 방식
                    output = model(image_tensor)
                    
                    # 출력 형식에 따라 처리
                    if isinstance(output, dict):
                        mask = output.get('logits', output.get('out', output))
                    elif isinstance(output, (list, tuple)):
                        # BiRefNet이 리스트를 반환하는 경우 (multi-scale output)
                        # 마지막 스케일의 출력 사용
                        mask = output[-1] if len(output) > 0 else output[0]
                    else:
                        mask = output
                    
                    # mask가 이미 텐서가 아닌 경우 텐서로 변환
                    if not isinstance(mask, torch.Tensor):
                        if isinstance(mask, list):
                            mask = mask[0] if len(mask) > 0 else mask
                        mask = torch.tensor(mask) if not isinstance(mask, torch.Tensor) else mask
                    
                    # 시그모이드 적용하여 0-1 범위로 변환
                    mask = torch.sigmoid(mask)
                    mask = mask.squeeze().cpu().numpy()
            except Exception as e:
                logger.error(f"Model inference failed: {e}")
                raise
        
        # 마스크를 원본 크기로 리사이즈
        mask_pil = Image.fromarray((mask * 255).astype(np.uint8))
        mask_pil = mask_pil.resize(original_size, Image.Resampling.LANCZOS)
        
        # 원본 이미지를 RGBA로 변환
        image_rgba = image.convert("RGBA")
        
        # 마스크를 알파 채널로 적용
        image_rgba.putalpha(mask_pil)
        
        return image_rgba
        
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        # 에러 발생 시 원본 이미지를 RGBA로 변환하여 반환
        return image.convert("RGBA")

def simple_background_removal(image: Image.Image) -> Image.Image:
    """
    간단한 배경 제거 (폴백 메서드)
    실제 모델이 로드되지 않았을 때 사용
    """
    # 이미지를 RGBA로 변환
    image_rgba = image.convert("RGBA")
    
    # 간단한 임계값 기반 마스크 생성 (데모용)
    # 실제로는 BiRefNet 모델을 사용해야 함
    data = image_rgba.getdata()
    new_data = []
    
    for item in data:
        # 흰색 배경을 투명하게 만들기 (매우 단순한 예제)
        if item[0] > 240 and item[1] > 240 and item[2] > 240:
            new_data.append((item[0], item[1], item[2], 0))
        else:
            new_data.append(item)
    
    image_rgba.putdata(new_data)
    return image_rgba

@app.on_event("startup")
async def startup_event():
    """서버 시작 시 모델 로드"""
    success = load_model()
    if not success:
        logger.warning("Running in demo mode without BiRefNet model")

@app.get("/")
async def root():
    """API 상태 확인"""
    return {
        "service": "CleanCut Background Removal API",
        "status": "running",
        "model_loaded": model is not None,
        "device": str(device) if device else "cpu"
    }

@app.get("/health")
async def health_check():
    """헬스 체크 엔드포인트"""
    return {
        "status": "healthy",
        "model_loaded": model is not None
    }

@app.post("/remove-background")
async def remove_background(
    file: UploadFile = File(...),
    quality: int = 95
):
    """
    이미지 배경 제거 API
    
    Args:
        file: 업로드된 이미지 파일
        quality: PNG 압축 품질 (1-100, 기본값 95)
        
    Returns:
        배경이 제거된 PNG 이미지
    """
    try:
        # 파일 유효성 검사
        if not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")
        
        # 이미지 읽기
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        
        # EXIF 오리엔테이션 처리
        try:
            # EXIF 데이터에 따라 이미지 자동 회전
            image = ImageOps.exif_transpose(image)
        except Exception as e:
            logger.debug(f"EXIF processing skipped: {e}")
        
        # 이미지 크기 체크
        width, height = image.size
        if width < 100 or height < 100:
            raise HTTPException(status_code=400, detail="Image too small (minimum 100x100)")
        if width > 4096 or height > 4096:
            # 큰 이미지는 자동 리사이징
            max_size = 2048
            if width > height:
                new_width = max_size
                new_height = int(height * (max_size / width))
            else:
                new_height = max_size
                new_width = int(width * (max_size / height))
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(f"Resized image from {width}x{height} to {new_width}x{new_height}")
        
        # RGB로 변환 (RGBA 이미지 처리를 위해)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
        logger.info(f"Processing image: {file.filename}, size: {image.size}")
        
        # 배경 제거 처리
        if model is not None:
            result = process_image(image)
        else:
            # 모델이 없으면 간단한 폴백 메서드 사용
            result = simple_background_removal(image)
        
        # PNG로 저장
        output = io.BytesIO()
        result.save(output, format="PNG", quality=quality, optimize=True)
        output.seek(0)
        
        return Response(
            content=output.getvalue(),
            media_type="image/png",
            headers={
                "Content-Disposition": f"attachment; filename=cleaned_{file.filename}.png"
            }
        )
        
    except Exception as e:
        logger.error(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/remove-background-batch")
async def remove_background_batch(files: list[UploadFile] = File(...)):
    """
    여러 이미지 배경 제거 (배치 처리)
    
    Args:
        files: 업로드된 이미지 파일 리스트
        
    Returns:
        처리 결과 정보
    """
    results = []
    
    for file in files:
        try:
            # 각 파일 처리
            contents = await file.read()
            image = Image.open(io.BytesIO(contents))
            
            # EXIF 오리엔테이션 처리
            try:
                image = ImageOps.exif_transpose(image)
            except Exception as e:
                logger.debug(f"EXIF processing skipped: {e}")
            
            # 이미지 크기 체크 및 리사이징
            width, height = image.size
            if width < 100 or height < 100:
                raise ValueError("Image too small")
            if width > 4096 or height > 4096:
                max_size = 2048
                if width > height:
                    new_width = max_size
                    new_height = int(height * (max_size / width))
                else:
                    new_height = max_size
                    new_width = int(width * (max_size / height))
                image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            
            if image.mode != 'RGB':
                image = image.convert('RGB')
            
            # 배경 제거
            if model is not None:
                result = process_image(image)
            else:
                result = simple_background_removal(image)
            
            # 결과 저장
            output = io.BytesIO()
            result.save(output, format="PNG", optimize=True)
            
            results.append({
                "filename": file.filename,
                "status": "success",
                "size": len(output.getvalue())
            })
            
        except Exception as e:
            results.append({
                "filename": file.filename,
                "status": "failed",
                "error": str(e)
            })
    
    return {"results": results}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)