yeonjin98 commited on
Commit
43aec3d
·
verified ·
1 Parent(s): da9d55b

Update server_birefnet.py

Browse files
Files changed (1) hide show
  1. server_birefnet.py +327 -314
server_birefnet.py CHANGED
@@ -1,315 +1,328 @@
1
- """
2
- BiRefNet을 사용한 실제 배경 제거 API 서버
3
-
4
- 설치 필요:
5
- pip install fastapi uvicorn python-multipart pillow
6
- pip install torch torchvision transformers
7
- pip install timm einops
8
-
9
- 실행:
10
- uvicorn server_birefnet:app --reload --host 0.0.0.0 --port 8000
11
- """
12
-
13
- from fastapi import FastAPI, File, UploadFile, HTTPException
14
- from fastapi.responses import Response
15
- from fastapi.middleware.cors import CORSMiddleware
16
- from PIL import Image
17
- import io
18
- import torch
19
- import numpy as np
20
- from typing import Tuple
21
- import logging
22
-
23
- # 로깅 설정
24
- logging.basicConfig(level=logging.INFO)
25
- logger = logging.getLogger(__name__)
26
-
27
- app = FastAPI(title="CleanCut API", version="1.0.0")
28
-
29
- # CORS 설정
30
- app.add_middleware(
31
- CORSMiddleware,
32
- allow_origins=["*"],
33
- allow_credentials=True,
34
- allow_methods=["*"],
35
- allow_headers=["*"],
36
- )
37
-
38
- # 글로벌 모델 변수
39
- model = None
40
- device = None
41
-
42
- def load_model():
43
- """BiRefNet 모델 로드"""
44
- global model, device
45
-
46
- try:
47
- # GPU 사용 가능 여부 확인
48
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
49
- logger.info(f"Using device: {device}")
50
-
51
- # Hugging Face에서 BiRefNet 모델 로드
52
- from transformers import AutoModelForImageSegmentation, AutoProcessor
53
-
54
- model_name = "ZhengPeng7/BiRefNet"
55
- logger.info(f"Loading model: {model_name}")
56
-
57
- # 모델과 프로세서 로드
58
- model = AutoModelForImageSegmentation.from_pretrained(
59
- model_name,
60
- trust_remote_code=True
61
- )
62
- model = model.to(device)
63
- model.eval()
64
-
65
- logger.info("Model loaded successfully")
66
- return True
67
-
68
- except Exception as e:
69
- logger.error(f"Failed to load model: {e}")
70
- logger.info("Using fallback mode (returning original image)")
71
- return False
72
-
73
- def process_image(image: Image.Image) -> Image.Image:
74
- """
75
- BiRefNet을 사용해 배경 제거
76
-
77
- Args:
78
- image: PIL Image 객체
79
-
80
- Returns:
81
- 배경이 제거된 RGBA PIL Image
82
- """
83
- try:
84
- if model is None:
85
- logger.warning("Model not loaded, returning original image with alpha channel")
86
- return image.convert("RGBA")
87
-
88
- # 이미지 전처리
89
- original_size = image.size
90
-
91
- # 모델 입력 크기로 리사이즈 (BiRefNet은 다양한 크기 지원)
92
- # 일반적으로 1024x1024가 좋은 성능을 보임
93
- input_size = (1024, 1024)
94
- image_resized = image.resize(input_size, Image.Resampling.LANCZOS)
95
-
96
- # NumPy 배열로 변환
97
- image_np = np.array(image_resized)
98
-
99
- # 정규화 (0-1 범위)
100
- if image_np.max() > 1:
101
- image_np = image_np / 255.0
102
-
103
- # 텐서로 변환 (batch_size, channels, height, width)
104
- image_tensor = torch.from_numpy(image_np).float()
105
- if len(image_tensor.shape) == 3:
106
- image_tensor = image_tensor.permute(2, 0, 1) # HWC -> CHW
107
- image_tensor = image_tensor.unsqueeze(0) # 배치 차원 추가
108
- image_tensor = image_tensor.to(device)
109
-
110
- # 모델 추론 - BiRefNet의 predict 메서드 사용
111
- with torch.no_grad():
112
- # BiRefNet은 PIL Image를 직접 받음
113
- try:
114
- # predict 메서드가 있는 경우
115
- if hasattr(model, 'predict'):
116
- mask = model.predict(image)
117
- # mask가 PIL Image인 경우 numpy로 변환
118
- if isinstance(mask, Image.Image):
119
- mask = np.array(mask) / 255.0
120
- else:
121
- # 일반적인 forward 방식
122
- output = model(image_tensor)
123
-
124
- # 출력 형식에 따라 처리
125
- if isinstance(output, dict):
126
- mask = output.get('logits', output.get('out', output))
127
- elif isinstance(output, (list, tuple)):
128
- # BiRefNet이 리스트를 반환하는 경우 (multi-scale output)
129
- # 마지막 스케일의 출력 사용
130
- mask = output[-1] if len(output) > 0 else output[0]
131
- else:
132
- mask = output
133
-
134
- # mask가 이미 텐서가 아닌 경우 텐서로 변환
135
- if not isinstance(mask, torch.Tensor):
136
- if isinstance(mask, list):
137
- mask = mask[0] if len(mask) > 0 else mask
138
- mask = torch.tensor(mask) if not isinstance(mask, torch.Tensor) else mask
139
-
140
- # 시그모이드 적용하여 0-1 범위로 변환
141
- mask = torch.sigmoid(mask)
142
- mask = mask.squeeze().cpu().numpy()
143
- except Exception as e:
144
- logger.error(f"Model inference failed: {e}")
145
- raise
146
-
147
- # 마스크를 원본 크기로 리사이즈
148
- mask_pil = Image.fromarray((mask * 255).astype(np.uint8))
149
- mask_pil = mask_pil.resize(original_size, Image.Resampling.LANCZOS)
150
-
151
- # 원본 이미지를 RGBA로 변환
152
- image_rgba = image.convert("RGBA")
153
-
154
- # 마스크를 알파 채널로 적용
155
- image_rgba.putalpha(mask_pil)
156
-
157
- return image_rgba
158
-
159
- except Exception as e:
160
- logger.error(f"Error processing image: {e}")
161
- # 에러 발생 시 원본 이미지를 RGBA로 변환하여 반환
162
- return image.convert("RGBA")
163
-
164
- def simple_background_removal(image: Image.Image) -> Image.Image:
165
- """
166
- 간단한 배경 제거 (폴백 메서드)
167
- 실제 모델이 로드되지 않았을 때 사용
168
- """
169
- # 이미지를 RGBA로 변환
170
- image_rgba = image.convert("RGBA")
171
-
172
- # 간단한 임계값 기반 마스크 생성 (데모용)
173
- # 실제로는 BiRefNet 모델을 사용해야 함
174
- data = image_rgba.getdata()
175
- new_data = []
176
-
177
- for item in data:
178
- # 흰색 배경을 투명하게 만들기 (매우 단순한 예제)
179
- if item[0] > 240 and item[1] > 240 and item[2] > 240:
180
- new_data.append((item[0], item[1], item[2], 0))
181
- else:
182
- new_data.append(item)
183
-
184
- image_rgba.putdata(new_data)
185
- return image_rgba
186
-
187
- @app.on_event("startup")
188
- async def startup_event():
189
- """서버 시작 시 모델 로드"""
190
- success = load_model()
191
- if not success:
192
- logger.warning("Running in demo mode without BiRefNet model")
193
-
194
- @app.get("/")
195
- async def root():
196
- """API 상태 확인"""
197
- return {
198
- "service": "CleanCut Background Removal API",
199
- "status": "running",
200
- "model_loaded": model is not None,
201
- "device": str(device) if device else "cpu"
202
- }
203
-
204
- @app.get("/health")
205
- async def health_check():
206
- """헬스 체크 엔드포인트"""
207
- return {
208
- "status": "healthy",
209
- "model_loaded": model is not None
210
- }
211
-
212
- @app.post("/remove-background")
213
- async def remove_background(
214
- file: UploadFile = File(...),
215
- quality: int = 95
216
- ):
217
- """
218
- 이미지 배경 제거 API
219
-
220
- Args:
221
- file: 업로드된 이미지 파일
222
- quality: PNG 압축 품질 (1-100, 기본값 95)
223
-
224
- Returns:
225
- 배경이 제거된 PNG 이미지
226
- """
227
- try:
228
- # 파일 유효성 검사
229
- if not file.content_type.startswith("image/"):
230
- raise HTTPException(status_code=400, detail="File must be an image")
231
-
232
- # 이미지 읽기
233
- contents = await file.read()
234
- image = Image.open(io.BytesIO(contents))
235
-
236
- # RGB로 변환 (RGBA 이미지 처리를 위해)
237
- if image.mode != 'RGB':
238
- image = image.convert('RGB')
239
-
240
- logger.info(f"Processing image: {file.filename}, size: {image.size}")
241
-
242
- # 배경 제거 처리
243
- if model is not None:
244
- result = process_image(image)
245
- else:
246
- # 모델이 없으면 간단한 폴백 메서드 사용
247
- result = simple_background_removal(image)
248
-
249
- # PNG로 저장
250
- output = io.BytesIO()
251
- result.save(output, format="PNG", quality=quality, optimize=True)
252
- output.seek(0)
253
-
254
- return Response(
255
- content=output.getvalue(),
256
- media_type="image/png",
257
- headers={
258
- "Content-Disposition": f"attachment; filename=cleaned_{file.filename}.png"
259
- }
260
- )
261
-
262
- except Exception as e:
263
- logger.error(f"Error processing request: {e}")
264
- raise HTTPException(status_code=500, detail=str(e))
265
-
266
- @app.post("/remove-background-batch")
267
- async def remove_background_batch(files: list[UploadFile] = File(...)):
268
- """
269
- 여러 이미지 배경 제거 (배치 처리)
270
-
271
- Args:
272
- files: 업로드된 이미지 파일 리스트
273
-
274
- Returns:
275
- 처리 결과 정보
276
- """
277
- results = []
278
-
279
- for file in files:
280
- try:
281
- # 각 파일 처리
282
- contents = await file.read()
283
- image = Image.open(io.BytesIO(contents))
284
-
285
- if image.mode != 'RGB':
286
- image = image.convert('RGB')
287
-
288
- # 배경 제거
289
- if model is not None:
290
- result = process_image(image)
291
- else:
292
- result = simple_background_removal(image)
293
-
294
- # 결과 저장
295
- output = io.BytesIO()
296
- result.save(output, format="PNG", optimize=True)
297
-
298
- results.append({
299
- "filename": file.filename,
300
- "status": "success",
301
- "size": len(output.getvalue())
302
- })
303
-
304
- except Exception as e:
305
- results.append({
306
- "filename": file.filename,
307
- "status": "failed",
308
- "error": str(e)
309
- })
310
-
311
- return {"results": results}
312
-
313
- if __name__ == "__main__":
314
- import uvicorn
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ """
2
+ BiRefNet을 사용한 실제 배경 제거 API 서버
3
+
4
+ 설치 필요:
5
+ pip install fastapi uvicorn python-multipart pillow
6
+ pip install torch torchvision transformers
7
+ pip install timm einops
8
+
9
+ 실행:
10
+ uvicorn server_birefnet:app --reload --host 0.0.0.0 --port 8000
11
+ """
12
+
13
+ from fastapi import FastAPI, File, UploadFile, HTTPException
14
+ from fastapi.responses import Response
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from PIL import Image, ImageOps
17
+ import io
18
+ import torch
19
+ import numpy as np
20
+ from typing import Tuple
21
+ import logging
22
+
23
+ # 로깅 설정
24
+ logging.basicConfig(level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
26
+
27
+ app = FastAPI(title="CleanCut API", version="1.0.0")
28
+
29
+ # CORS 설정
30
+ app.add_middleware(
31
+ CORSMiddleware,
32
+ allow_origins=["*"],
33
+ allow_credentials=True,
34
+ allow_methods=["*"],
35
+ allow_headers=["*"],
36
+ )
37
+
38
+ # 글로벌 모델 변수
39
+ model = None
40
+ device = None
41
+
42
+ def load_model():
43
+ """BiRefNet 모델 로드"""
44
+ global model, device
45
+
46
+ try:
47
+ # GPU 사용 가능 여부 확인
48
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
49
+ logger.info(f"Using device: {device}")
50
+
51
+ # Hugging Face에서 BiRefNet 모델 로드
52
+ from transformers import AutoModelForImageSegmentation, AutoProcessor
53
+
54
+ model_name = "ZhengPeng7/BiRefNet"
55
+ logger.info(f"Loading model: {model_name}")
56
+
57
+ # 모델과 프로세서 로드
58
+ model = AutoModelForImageSegmentation.from_pretrained(
59
+ model_name,
60
+ trust_remote_code=True
61
+ )
62
+ model = model.to(device)
63
+ model.eval()
64
+
65
+ logger.info("Model loaded successfully")
66
+ return True
67
+
68
+ except Exception as e:
69
+ logger.error(f"Failed to load model: {e}")
70
+ logger.info("Using fallback mode (returning original image)")
71
+ return False
72
+
73
+ def process_image(image: Image.Image) -> Image.Image:
74
+ """
75
+ BiRefNet을 사용해 배경 제거
76
+
77
+ Args:
78
+ image: PIL Image 객체
79
+
80
+ Returns:
81
+ 배경이 제거된 RGBA PIL Image
82
+ """
83
+ try:
84
+ if model is None:
85
+ logger.warning("Model not loaded, returning original image with alpha channel")
86
+ return image.convert("RGBA")
87
+
88
+ # 이미지 전처리
89
+ original_size = image.size
90
+
91
+ # 모델 입력 크기로 리사이즈 (BiRefNet은 다양한 크기 지원)
92
+ # 일반적으로 1024x1024가 좋은 성능을 보임
93
+ input_size = (1024, 1024)
94
+ image_resized = image.resize(input_size, Image.Resampling.LANCZOS)
95
+
96
+ # NumPy 배열로 변환
97
+ image_np = np.array(image_resized)
98
+
99
+ # 정규화 (0-1 범위)
100
+ if image_np.max() > 1:
101
+ image_np = image_np / 255.0
102
+
103
+ # 텐서로 변환 (batch_size, channels, height, width)
104
+ image_tensor = torch.from_numpy(image_np).float()
105
+ if len(image_tensor.shape) == 3:
106
+ image_tensor = image_tensor.permute(2, 0, 1) # HWC -> CHW
107
+ image_tensor = image_tensor.unsqueeze(0) # 배치 차원 추가
108
+ image_tensor = image_tensor.to(device)
109
+
110
+ # 모델 추론 - BiRefNet의 predict 메서드 사용
111
+ with torch.no_grad():
112
+ # BiRefNet은 PIL Image를 직접 받음
113
+ try:
114
+ # predict 메서드가 있는 경우
115
+ if hasattr(model, 'predict'):
116
+ mask = model.predict(image)
117
+ # mask가 PIL Image인 경우 numpy로 변환
118
+ if isinstance(mask, Image.Image):
119
+ mask = np.array(mask) / 255.0
120
+ else:
121
+ # 일반적인 forward 방식
122
+ output = model(image_tensor)
123
+
124
+ # 출력 형식에 따라 처리
125
+ if isinstance(output, dict):
126
+ mask = output.get('logits', output.get('out', output))
127
+ elif isinstance(output, (list, tuple)):
128
+ # BiRefNet이 리스트를 반환하는 경우 (multi-scale output)
129
+ # 마지막 스케일의 출력 사용
130
+ mask = output[-1] if len(output) > 0 else output[0]
131
+ else:
132
+ mask = output
133
+
134
+ # mask가 이미 텐서가 아닌 경우 텐서��� 변환
135
+ if not isinstance(mask, torch.Tensor):
136
+ if isinstance(mask, list):
137
+ mask = mask[0] if len(mask) > 0 else mask
138
+ mask = torch.tensor(mask) if not isinstance(mask, torch.Tensor) else mask
139
+
140
+ # 시그모이드 적용하여 0-1 범위로 변환
141
+ mask = torch.sigmoid(mask)
142
+ mask = mask.squeeze().cpu().numpy()
143
+ except Exception as e:
144
+ logger.error(f"Model inference failed: {e}")
145
+ raise
146
+
147
+ # 마스크를 원본 크기로 리사이즈
148
+ mask_pil = Image.fromarray((mask * 255).astype(np.uint8))
149
+ mask_pil = mask_pil.resize(original_size, Image.Resampling.LANCZOS)
150
+
151
+ # 원본 이미지를 RGBA로 변환
152
+ image_rgba = image.convert("RGBA")
153
+
154
+ # 마스크를 알파 채널로 적용
155
+ image_rgba.putalpha(mask_pil)
156
+
157
+ return image_rgba
158
+
159
+ except Exception as e:
160
+ logger.error(f"Error processing image: {e}")
161
+ # 에러 발생 시 원본 이미지를 RGBA로 변환하여 반환
162
+ return image.convert("RGBA")
163
+
164
+ def simple_background_removal(image: Image.Image) -> Image.Image:
165
+ """
166
+ 간단한 배경 제거 (폴백 메서드)
167
+ 실제 모델이 로드되지 않았을 때 사용
168
+ """
169
+ # 이미지를 RGBA로 변환
170
+ image_rgba = image.convert("RGBA")
171
+
172
+ # 간단한 임계값 기반 마스크 생성 (데모용)
173
+ # 실제로는 BiRefNet 모델을 사용해야 함
174
+ data = image_rgba.getdata()
175
+ new_data = []
176
+
177
+ for item in data:
178
+ # 흰색 배경을 투명하게 만들기 (매우 단순한 예제)
179
+ if item[0] > 240 and item[1] > 240 and item[2] > 240:
180
+ new_data.append((item[0], item[1], item[2], 0))
181
+ else:
182
+ new_data.append(item)
183
+
184
+ image_rgba.putdata(new_data)
185
+ return image_rgba
186
+
187
+ @app.on_event("startup")
188
+ async def startup_event():
189
+ """서버 시작 시 모델 로드"""
190
+ success = load_model()
191
+ if not success:
192
+ logger.warning("Running in demo mode without BiRefNet model")
193
+
194
+ @app.get("/")
195
+ async def root():
196
+ """API 상태 확인"""
197
+ return {
198
+ "service": "CleanCut Background Removal API",
199
+ "status": "running",
200
+ "model_loaded": model is not None,
201
+ "device": str(device) if device else "cpu"
202
+ }
203
+
204
+ @app.get("/health")
205
+ async def health_check():
206
+ """헬스 체크 엔드포인트"""
207
+ return {
208
+ "status": "healthy",
209
+ "model_loaded": model is not None
210
+ }
211
+
212
+ @app.post("/remove-background")
213
+ async def remove_background(
214
+ file: UploadFile = File(...),
215
+ quality: int = 95
216
+ ):
217
+ """
218
+ 이미지 배경 제거 API
219
+
220
+ Args:
221
+ file: 업로드된 이미지 파일
222
+ quality: PNG 압축 품질 (1-100, 기본값 95)
223
+
224
+ Returns:
225
+ 배경이 제거된 PNG 이미지
226
+ """
227
+ try:
228
+ # 파일 유효성 검사
229
+ if not file.content_type.startswith("image/"):
230
+ raise HTTPException(status_code=400, detail="File must be an image")
231
+
232
+ # 이미지 읽기
233
+ contents = await file.read()
234
+ image = Image.open(io.BytesIO(contents))
235
+
236
+ # EXIF 오리엔테이션 처리
237
+ try:
238
+ # EXIF 데이터에 따라 이미지 자동 회전
239
+ image = ImageOps.exif_transpose(image)
240
+ except Exception as e:
241
+ logger.debug(f"EXIF processing skipped: {e}")
242
+
243
+ # RGB로 변환 (RGBA 이미지 처리를 위해)
244
+ if image.mode != 'RGB':
245
+ image = image.convert('RGB')
246
+
247
+ logger.info(f"Processing image: {file.filename}, size: {image.size}")
248
+
249
+ # 배경 제거 처리
250
+ if model is not None:
251
+ result = process_image(image)
252
+ else:
253
+ # 모델이 없으면 간단한 폴백 메서드 사용
254
+ result = simple_background_removal(image)
255
+
256
+ # PNG로 저장
257
+ output = io.BytesIO()
258
+ result.save(output, format="PNG", quality=quality, optimize=True)
259
+ output.seek(0)
260
+
261
+ return Response(
262
+ content=output.getvalue(),
263
+ media_type="image/png",
264
+ headers={
265
+ "Content-Disposition": f"attachment; filename=cleaned_{file.filename}.png"
266
+ }
267
+ )
268
+
269
+ except Exception as e:
270
+ logger.error(f"Error processing request: {e}")
271
+ raise HTTPException(status_code=500, detail=str(e))
272
+
273
+ @app.post("/remove-background-batch")
274
+ async def remove_background_batch(files: list[UploadFile] = File(...)):
275
+ """
276
+ 여러 이미지 배경 제거 (배치 처리)
277
+
278
+ Args:
279
+ files: 업로드된 이미지 파일 리스트
280
+
281
+ Returns:
282
+ 처리 결과 정보
283
+ """
284
+ results = []
285
+
286
+ for file in files:
287
+ try:
288
+ # 파일 처리
289
+ contents = await file.read()
290
+ image = Image.open(io.BytesIO(contents))
291
+
292
+ # EXIF 오리엔테이션 처리
293
+ try:
294
+ image = ImageOps.exif_transpose(image)
295
+ except Exception as e:
296
+ logger.debug(f"EXIF processing skipped: {e}")
297
+
298
+ if image.mode != 'RGB':
299
+ image = image.convert('RGB')
300
+
301
+ # 배경 제거
302
+ if model is not None:
303
+ result = process_image(image)
304
+ else:
305
+ result = simple_background_removal(image)
306
+
307
+ # 결과 저장
308
+ output = io.BytesIO()
309
+ result.save(output, format="PNG", optimize=True)
310
+
311
+ results.append({
312
+ "filename": file.filename,
313
+ "status": "success",
314
+ "size": len(output.getvalue())
315
+ })
316
+
317
+ except Exception as e:
318
+ results.append({
319
+ "filename": file.filename,
320
+ "status": "failed",
321
+ "error": str(e)
322
+ })
323
+
324
+ return {"results": results}
325
+
326
+ if __name__ == "__main__":
327
+ import uvicorn
328
  uvicorn.run(app, host="0.0.0.0", port=8000)