Spaces:
Sleeping
Sleeping
File size: 12,084 Bytes
43aec3d 25b27b6 43aec3d e1173e3 43aec3d 0a55189 43aec3d 0a55189 43aec3d e1173e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 |
"""
BiRefNet을 사용한 실제 배경 제거 API 서버
설치 필요:
pip install fastapi uvicorn python-multipart pillow
pip install torch torchvision transformers
pip install timm einops
실행:
uvicorn server_birefnet:app --reload --host 0.0.0.0 --port 8000
"""
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageOps
import io
import torch
import numpy as np
from typing import Tuple
import logging
# 로깅 설정
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="CleanCut API", version="1.0.0")
# CORS 설정
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 글로벌 모델 변수
model = None
device = None
def load_model():
"""BiRefNet 모델 로드"""
global model, device
try:
# GPU 사용 가능 여부 확인
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")
# Hugging Face에서 BiRefNet 모델 로드
from transformers import AutoModelForImageSegmentation, AutoProcessor
model_name = "ZhengPeng7/BiRefNet_HR"
logger.info(f"Loading model: {model_name}")
# 모델과 프로세서 로드
model = AutoModelForImageSegmentation.from_pretrained(
model_name,
trust_remote_code=True
)
model = model.to(device)
model.eval()
logger.info("Model loaded successfully")
return True
except Exception as e:
logger.error(f"Failed to load model: {e}")
logger.info("Using fallback mode (returning original image)")
return False
def process_image(image: Image.Image) -> Image.Image:
"""
BiRefNet을 사용해 배경 제거
Args:
image: PIL Image 객체
Returns:
배경이 제거된 RGBA PIL Image
"""
try:
if model is None:
logger.warning("Model not loaded, returning original image with alpha channel")
return image.convert("RGBA")
# 이미지 전처리
original_size = image.size
# 모델 입력 크기로 리사이즈 (BiRefNet은 다양한 크기 지원)
# 일반적으로 1024x1024가 좋은 성능을 보임
input_size = (1024, 1024)
image_resized = image.resize(input_size, Image.Resampling.LANCZOS)
# NumPy 배열로 변환
image_np = np.array(image_resized)
# 정규화 (0-1 범위)
if image_np.max() > 1:
image_np = image_np / 255.0
# 텐서로 변환 (batch_size, channels, height, width)
image_tensor = torch.from_numpy(image_np).float()
if len(image_tensor.shape) == 3:
image_tensor = image_tensor.permute(2, 0, 1) # HWC -> CHW
image_tensor = image_tensor.unsqueeze(0) # 배치 차원 추가
image_tensor = image_tensor.to(device)
# 모델 추론 - BiRefNet의 predict 메서드 사용
with torch.no_grad():
# BiRefNet은 PIL Image를 직접 받음
try:
# predict 메서드가 있는 경우
if hasattr(model, 'predict'):
mask = model.predict(image)
# mask가 PIL Image인 경우 numpy로 변환
if isinstance(mask, Image.Image):
mask = np.array(mask) / 255.0
else:
# 일반적인 forward 방식
output = model(image_tensor)
# 출력 형식에 따라 처리
if isinstance(output, dict):
mask = output.get('logits', output.get('out', output))
elif isinstance(output, (list, tuple)):
# BiRefNet이 리스트를 반환하는 경우 (multi-scale output)
# 마지막 스케일의 출력 사용
mask = output[-1] if len(output) > 0 else output[0]
else:
mask = output
# mask가 이미 텐서가 아닌 경우 텐서로 변환
if not isinstance(mask, torch.Tensor):
if isinstance(mask, list):
mask = mask[0] if len(mask) > 0 else mask
mask = torch.tensor(mask) if not isinstance(mask, torch.Tensor) else mask
# 시그모이드 적용하여 0-1 범위로 변환
mask = torch.sigmoid(mask)
mask = mask.squeeze().cpu().numpy()
except Exception as e:
logger.error(f"Model inference failed: {e}")
raise
# 마스크를 원본 크기로 리사이즈
mask_pil = Image.fromarray((mask * 255).astype(np.uint8))
mask_pil = mask_pil.resize(original_size, Image.Resampling.LANCZOS)
# 원본 이미지를 RGBA로 변환
image_rgba = image.convert("RGBA")
# 마스크를 알파 채널로 적용
image_rgba.putalpha(mask_pil)
return image_rgba
except Exception as e:
logger.error(f"Error processing image: {e}")
# 에러 발생 시 원본 이미지를 RGBA로 변환하여 반환
return image.convert("RGBA")
def simple_background_removal(image: Image.Image) -> Image.Image:
"""
간단한 배경 제거 (폴백 메서드)
실제 모델이 로드되지 않았을 때 사용
"""
# 이미지를 RGBA로 변환
image_rgba = image.convert("RGBA")
# 간단한 임계값 기반 마스크 생성 (데모용)
# 실제로는 BiRefNet 모델을 사용해야 함
data = image_rgba.getdata()
new_data = []
for item in data:
# 흰색 배경을 투명하게 만들기 (매우 단순한 예제)
if item[0] > 240 and item[1] > 240 and item[2] > 240:
new_data.append((item[0], item[1], item[2], 0))
else:
new_data.append(item)
image_rgba.putdata(new_data)
return image_rgba
@app.on_event("startup")
async def startup_event():
"""서버 시작 시 모델 로드"""
success = load_model()
if not success:
logger.warning("Running in demo mode without BiRefNet model")
@app.get("/")
async def root():
"""API 상태 확인"""
return {
"service": "CleanCut Background Removal API",
"status": "running",
"model_loaded": model is not None,
"device": str(device) if device else "cpu"
}
@app.get("/health")
async def health_check():
"""헬스 체크 엔드포인트"""
return {
"status": "healthy",
"model_loaded": model is not None
}
@app.post("/remove-background")
async def remove_background(
file: UploadFile = File(...),
quality: int = 95
):
"""
이미지 배경 제거 API
Args:
file: 업로드된 이미지 파일
quality: PNG 압축 품질 (1-100, 기본값 95)
Returns:
배경이 제거된 PNG 이미지
"""
try:
# 파일 유효성 검사
if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="File must be an image")
# 이미지 읽기
contents = await file.read()
image = Image.open(io.BytesIO(contents))
# EXIF 오리엔테이션 처리
try:
# EXIF 데이터에 따라 이미지 자동 회전
image = ImageOps.exif_transpose(image)
except Exception as e:
logger.debug(f"EXIF processing skipped: {e}")
# 이미지 크기 체크
width, height = image.size
if width < 100 or height < 100:
raise HTTPException(status_code=400, detail="Image too small (minimum 100x100)")
if width > 4096 or height > 4096:
# 큰 이미지는 자동 리사이징
max_size = 2048
if width > height:
new_width = max_size
new_height = int(height * (max_size / width))
else:
new_height = max_size
new_width = int(width * (max_size / height))
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
logger.info(f"Resized image from {width}x{height} to {new_width}x{new_height}")
# RGB로 변환 (RGBA 이미지 처리를 위해)
if image.mode != 'RGB':
image = image.convert('RGB')
logger.info(f"Processing image: {file.filename}, size: {image.size}")
# 배경 제거 처리
if model is not None:
result = process_image(image)
else:
# 모델이 없으면 간단한 폴백 메서드 사용
result = simple_background_removal(image)
# PNG로 저장
output = io.BytesIO()
result.save(output, format="PNG", quality=quality, optimize=True)
output.seek(0)
return Response(
content=output.getvalue(),
media_type="image/png",
headers={
"Content-Disposition": f"attachment; filename=cleaned_{file.filename}.png"
}
)
except Exception as e:
logger.error(f"Error processing request: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/remove-background-batch")
async def remove_background_batch(files: list[UploadFile] = File(...)):
"""
여러 이미지 배경 제거 (배치 처리)
Args:
files: 업로드된 이미지 파일 리스트
Returns:
처리 결과 정보
"""
results = []
for file in files:
try:
# 각 파일 처리
contents = await file.read()
image = Image.open(io.BytesIO(contents))
# EXIF 오리엔테이션 처리
try:
image = ImageOps.exif_transpose(image)
except Exception as e:
logger.debug(f"EXIF processing skipped: {e}")
# 이미지 크기 체크 및 리사이징
width, height = image.size
if width < 100 or height < 100:
raise ValueError("Image too small")
if width > 4096 or height > 4096:
max_size = 2048
if width > height:
new_width = max_size
new_height = int(height * (max_size / width))
else:
new_height = max_size
new_width = int(width * (max_size / height))
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
if image.mode != 'RGB':
image = image.convert('RGB')
# 배경 제거
if model is not None:
result = process_image(image)
else:
result = simple_background_removal(image)
# 결과 저장
output = io.BytesIO()
result.save(output, format="PNG", optimize=True)
results.append({
"filename": file.filename,
"status": "success",
"size": len(output.getvalue())
})
except Exception as e:
results.append({
"filename": file.filename,
"status": "failed",
"error": str(e)
})
return {"results": results}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
|