import time
from io import BytesIO
import os
from dotenv import load_dotenv
from PIL import Image
import logging
from typing import List
from huggingface_hub import login
from fastapi import FastAPI, File, UploadFile
from vllm import LLM, SamplingParams
import torch
import torch._dynamo

torch._dynamo.config.suppress_errors = True

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Set the cache directory to a writable path
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"

token = os.getenv("huggingface_ankit")
# Login to the Hugging Face Hub
login(token)

app = FastAPI()

llm = None

def load_vllm_model():
    global llm
    logger.info("Loading vLLM model...")
    if llm is None:
        llm = LLM(
            model="google/paligemma2-3b-mix-448",
            trust_remote_code=True,
            max_model_len=4096,
            dtype="float16",
        )
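
# Optional warm-up (sketch, not in the original file): load the model once at startup
# instead of lazily on the first request, so the first call does not pay the load time.
# @app.on_event("startup")
# async def warmup():
#     load_vllm_model()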

@app.post("/batch_extract_text")  # assumed route: no decorator was present; path mirrors the legacy batch endpoint below
async def batch_extract_text_vllm(files: List[UploadFile] = File(...)):
    try:
        start_time = time.time()
        load_vllm_model()
        results = []
        sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

        # Load images
        images = []
        for file in files:
            image_data = await file.read()
            img = Image.open(BytesIO(image_data)).convert("RGB")
            images.append(img)

        # Run OCR on each image
        for image in images:
            inputs = {
                "prompt": "ocr",
                "multi_modal_data": {
                    "image": image
                },
            }
            outputs = llm.generate(inputs, sampling_params)
            for o in outputs:
                generated_text = o.outputs[0].text
                results.append(generated_text)
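
        # Alternative (sketch, not from the original file): llm.generate also accepts a list
        # of prompt dicts, so the per-image loop above could be collapsed into one batched call:
        # outputs = llm.generate(
        #     [{"prompt": "ocr", "multi_modal_data": {"image": img}} for img in images],
        #     sampling_params,
        # )
        # results = [o.outputs[0].text for o in outputs]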

        logger.info(f"vLLM Batch processing completed in {time.time() - start_time:.2f} seconds")
        return {"extracted_texts": results}
    except Exception as e:
        logger.error(f"Error in batch processing vLLM: {str(e)}")
        return {"error": str(e)}
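

# Local entrypoint (assumption): the commented-out legacy versions below start uvicorn on
# port 7860, while the active code above does not launch a server itself, so a minimal
# equivalent is sketched here in case the Space is not started by an external command.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)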

# # main.py
# from fastapi import FastAPI, File, UploadFile
# from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
# from transformers.image_utils import load_image
# import torch
# from io import BytesIO
# import os
# from dotenv import load_dotenv
# from PIL import Image
# from huggingface_hub import login
#
# # Load environment variables
# load_dotenv()
#
# # Set the cache directory to a writable path
# os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
#
# token = os.getenv("huggingface_ankit")
# # Login to the Hugging Face Hub
# login(token)
#
# app = FastAPI()
#
# model_id = "google/paligemma2-3b-mix-448"
# model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to('cuda')
# processor = PaliGemmaProcessor.from_pretrained(model_id)
#
# def predict(image):
#     prompt = "<image> ocr"
#     model_inputs = processor(text=prompt, images=image, return_tensors="pt").to('cuda')
#     input_len = model_inputs["input_ids"].shape[-1]
#     with torch.inference_mode():
#         generation = model.generate(**model_inputs, max_new_tokens=200)
#     torch.cuda.empty_cache()
#     decoded = processor.decode(generation[0], skip_special_tokens=True)  # [len(prompt):].lstrip("\n")
#     return decoded
#
# @app.post("/extract_text")
# async def extract_text(file: UploadFile = File(...)):
#     image = Image.open(BytesIO(await file.read())).convert("RGB")  # Ensure it's a valid PIL image
#     text = predict(image)
#     return {"extracted_text": text}
#
# @app.post("/batch_extract_text")
# async def batch_extract_text(files: list[UploadFile] = File(...)):
#     # if len(files) > 20:
#     #     return {"error": "A maximum of 20 images can be processed at a time."}
#     images = [Image.open(BytesIO(await file.read())).convert("RGB") for file in files]
#     prompts = ["OCR"] * len(images)
#     model_inputs = processor(text=prompts, images=images, return_tensors="pt").to(torch.bfloat16).to(model.device)
#     input_len = model_inputs["input_ids"].shape[-1]
#     with torch.inference_mode():
#         generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
#     torch.cuda.empty_cache()
#     extracted_texts = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
#     return {"extracted_texts": extracted_texts}
#
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=7860)

# Global variables for model and processor
# model = None
# processor = None
#
# def load_model():
#     """Load model and processor when needed"""
#     global model, processor
#     if model is None:
#         model_id = "google/paligemma2-3b-mix-448"
#         logger.info(f"Loading model {model_id}")
#         # Load model with memory-efficient settings
#         model = PaliGemmaForConditionalGeneration.from_pretrained(
#             model_id,
#             device_map="auto",
#             torch_dtype=torch.bfloat16  # Use lower precision for memory efficiency
#         )
#         processor = PaliGemmaProcessor.from_pretrained(model_id)
#         logger.info("Model loaded successfully")
#
# def clean_memory():
#     """Force garbage collection and clear CUDA cache"""
#     gc.collect()
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()
#         # Clear GPU cache
#         torch.cuda.empty_cache()
#         logger.info(f"Memory allocated after clearing cache: {torch.cuda.memory_allocated()} bytes")
#     logger.info("Memory cleaned")
#
# def predict(image):
#     """Process a single image"""
#     load_model()  # Ensure model is loaded
#     # Process input
#     prompt = "<image> ocr"
#     model_inputs = processor(text=prompt, images=image, return_tensors="pt")
#     # Move to appropriate device
#     model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
#     # Generate with memory optimization
#     with torch.inference_mode():
#         generation = model.generate(**model_inputs, max_new_tokens=200)
#     # Decode output
#     decoded = processor.decode(generation[0], skip_special_tokens=True)
#     # Clean up intermediates
#     del model_inputs, generation
#     clean_memory()
#     # del model, processor
#     return decoded
#
# @app.post("/extract_text")
# async def extract_text(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
#     """Extract text from a single image"""
#     try:
#         start_time = time.time()
#         image = Image.open(BytesIO(await file.read())).convert("RGB")
#         text = predict(image)
#         # Schedule cleanup after response
#         background_tasks.add_task(clean_memory)
#         logger.info(f"Processing completed in {time.time() - start_time:.2f} seconds")
#         return {"extracted_text": text}
#     except Exception as e:
#         logger.error(f"Error processing image: {str(e)}")
#         return {"error": str(e)}

# @app.post("/batch_extract_text")
# async def batch_extract_text(batch_size: int, background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
#     """Extract text from multiple images with batching"""
#     try:
#         start_time = time.time()
#         # Limit batch size for memory management
#         max_batch_size = 32  # Adjust based on your GPU memory
#         # if len(files) > 32:
#         #     return {"error": "A maximum of 20 images can be processed at a time."}
#         load_model()  # Ensure model is loaded
#         all_results = []
#         # Process in smaller batches
#         for i in range(0, len(files), max_batch_size):
#             batch_files = files[i:i+max_batch_size]
#             # Load images
#             images = []
#             for file in batch_files:
#                 image_data = await file.read()
#                 img = Image.open(BytesIO(image_data)).convert("RGB")
#                 images.append(img)
#             # Create batch inputs
#             prompts = ["<image> ocr"] * len(images)
#             model_inputs = processor(text=prompts, images=images, return_tensors="pt")
#             # Move to appropriate device
#             model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
#             # Generate with memory optimization
#             with torch.inference_mode():
#                 generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
#             # Decode outputs
#             batch_results = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
#             all_results.extend(batch_results)
#             # Clean up batch resources
#             del model_inputs, generations, images
#             clean_memory()
#         # Schedule cleanup after response
#         background_tasks.add_task(clean_memory)
#         logger.info(f"Batch processing completed in {time.time() - start_time:.2f} seconds")
#         return {"extracted_texts": all_results}
#     except Exception as e:
#         logger.error(f"Error in batch processing: {str(e)}")
#         return {"error": str(e)}
#
# Health check endpoint
# @app.get("/health")
# async def health_check():
#     # Generate a random image (20x40 pixels) with random RGB values
#     random_data = np.random.randint(0, 256, (20, 40, 3), dtype=np.uint8)
#     # Create an image from the random data
#     image = Image.fromarray(random_data)
#     predict(image)
#     clean_memory()
#     return {"status": "healthy"}
#
# if __name__ == "__main__":
#     import uvicorn
#     # Start the server with proper worker configuration
#     uvicorn.run(
#         app,
#         host="0.0.0.0",
#         port=7860,
#         log_level="info",
#         workers=1  # Multiple workers can cause GPU memory issues
#     )