Add extra files for storing multimodal data in RAG
mm_rag/MLM/client.py
ADDED
@@ -0,0 +1,135 @@
"""Base interface for clients making requests/calls to a visual language model provider API"""

from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Union, Iterator
import requests
import json
from utility import isBase64, encode_image, encode_image_from_path_or_url, lvlm_inference

class BaseClient(ABC):
    def __init__(self,
                 hostname: str = "127.0.0.1",
                 port: int = 8090,
                 timeout: int = 60,
                 url: Optional[str] = None):
        self.connection_url = f"http://{hostname}:{port}" if url is None else url
        self.timeout = timeout
        # self.headers = {'Content-Type': 'application/x-www-form-urlencoded'}
        self.headers = {'Content-Type': 'application/json'}

    def root(self):
        """Request the welcome message."""
        connection_route = f"{self.connection_url}/"
        return requests.get(connection_route)

    @abstractmethod
    def generate(self,
                 prompt: str,
                 image: str,
                 **kwargs
                 ) -> str:
        """Send a request to the visual language model API
        and return the generated text.

        Use this method when you want to call the visual language model API to generate text without streaming.

        Args:
            prompt: A prompt.
            image: A string that is either a path to an image or a base64 encoding of an image.
            **kwargs: Arbitrary additional keyword arguments.
                These are usually passed to the model provider API call as generation hyperparameters.

        Returns:
            Text returned from the visual language model provider API call.
        """

    def generate_stream(
        self,
        prompt: str,
        image: str,
        **kwargs
    ) -> Iterator[str]:
        """Send a request to the visual language model API
        and return an iterator over the text streamed back from the API call.

        Use this method when you want to call the visual language model API to stream generated text.

        Args:
            prompt: A prompt.
            image: A string that is either a path to an image or a base64 encoding of an image.
            **kwargs: Arbitrary additional keyword arguments.
                These are usually passed to the model provider API call as generation hyperparameters.

        Returns:
            Iterator of text streamed from the visual language model provider API call.
        """
        raise NotImplementedError()

    def generate_batch(
        self,
        prompt: List[str],
        image: List[str],
        **kwargs
    ) -> List[str]:
        """Send a request to the visual language model API for multi-batch generation
        and return a list of generated texts.

        Use this method when you want to call the visual language model API for multi-batch text generation.
        Multi-batch generation does not support streaming.

        Args:
            prompt: List of prompts.
            image: List of strings, each of which is either a path to an image or a base64 encoding of an image.
            **kwargs: Arbitrary additional keyword arguments.
                These are usually passed to the model provider API call as generation hyperparameters.

        Returns:
            List of texts returned from the visual language model provider API call.
        """
        raise NotImplementedError()

class PredictionGuardClient(BaseClient):

    generate_kwargs = ['max_tokens',
                       'temperature',
                       'top_p',
                       'top_k']

    def filter_accepted_genkwargs(self, kwargs):
        gen_args = {}
        if "generate_kwargs" in kwargs and isinstance(kwargs["generate_kwargs"], dict):
            gen_args = {k: kwargs["generate_kwargs"][k]
                        for k in self.generate_kwargs
                        if k in kwargs["generate_kwargs"]}
        return gen_args

    def generate(self,
                 prompt: str,
                 image: str,
                 **kwargs
                 ) -> str:
        """Send a request to Prediction Guard's API
        and return the text generated by the LLaVA model.

        Use this method when you want to call the LLaVA model API to generate text without streaming.

        Args:
            prompt: A prompt.
            image: A string that is either a path/URL to an image or a base64 encoding of an image.
            **kwargs: Arbitrary additional keyword arguments.
                These are usually passed to the model provider API call as generation hyperparameters.

        Returns:
            Text returned from the visual language model provider API call.
        """
        assert image is not None and len(image) != 0, "the input image cannot be empty; it must be either a base64-encoded image or a path/URL to an image"
        if isBase64(image):
            base64_image = image
        else:  # this is a path or URL to an image
            base64_image = encode_image_from_path_or_url(image)

        args = self.filter_accepted_genkwargs(kwargs)
        return lvlm_inference(prompt=prompt, image=base64_image, **args)
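
Usage sketch for the new client (illustrative only; it assumes the `utility` helpers imported above are available and a Prediction Guard LLaVA endpoint is reachable at the default host/port, and the prompt, image path, and generation settings are made-up examples):

    from mm_rag.MLM.client import PredictionGuardClient

    client = PredictionGuardClient()  # defaults to http://127.0.0.1:8090
    caption = client.generate(
        prompt="Describe what is happening in this frame.",
        image="frames/frame_0001.jpg",  # hypothetical path; a URL or base64 string also works
        generate_kwargs={"max_tokens": 128, "temperature": 0.6},
    )
    print(caption)
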
mm_rag/MLM/lvlm.py
ADDED
@@ -0,0 +1,301 @@
from .client import PredictionGuardClient
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, root_validator
from typing import Any, Optional, List, Dict, Iterator, AsyncIterator
from langchain_core.callbacks import CallbackManagerForLLMRun
from utility import get_from_dict_or_env, MultimodalModelInput

from langchain_core.runnables import RunnableConfig, ensure_config
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.prompt_values import StringPromptValue
# from langchain_core.outputs import GenerationChunk, LLMResult
from langchain_core.language_models.llms import BaseLLM
from langchain_core.callbacks import (
    # CallbackManager,
    CallbackManagerForLLMRun,
)
# from langchain_core.load import dumpd
from langchain_core.runnables.config import run_in_executor

class LVLM(LLM):
    """This class extends the LLM class to implement a custom request to an LVLM provider API."""

    client: Any = None  #: :meta private:
    hostname: Optional[str] = None
    port: Optional[int] = None
    url: Optional[str] = None
    max_new_tokens: Optional[int] = 200
    temperature: Optional[float] = 0.6
    top_k: Optional[float] = 0
    stop: Optional[List[str]] = None
    ignore_eos: Optional[bool] = False
    do_sample: Optional[bool] = True
    lazy_mode: Optional[bool] = True
    hpu_graphs: Optional[bool] = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the access token and python package exist in the environment if needed."""
        if values['client'] is None:
            # check whether the URL of the API is provided
            url = get_from_dict_or_env(values, 'url', "VLM_URL", None)
            if url is None:
                hostname = get_from_dict_or_env(values, 'hostname', 'VLM_HOSTNAME', None)
                port = get_from_dict_or_env(values, 'port', 'VLM_PORT', None)
                if hostname is not None and port is not None:
                    values['client'] = PredictionGuardClient(hostname=hostname, port=port)
                else:
                    # use the default hostname and port to create the client
                    values['client'] = PredictionGuardClient()
            else:
                values['client'] = PredictionGuardClient(url=url)
        return values

    @property
    def _llm_type(self) -> str:
        """Return the type of LLM."""
        return "Large Vision Language Model"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Prediction Guard API."""
        return {
            "max_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "ignore_eos": self.ignore_eos,
            "do_sample": self.do_sample,
            "stop": self.stop,
        }

    def get_params(self, **kwargs):
        params = self._default_params
        params.update(kwargs)
        return params

    def _call(
        self,
        prompt: str,
        image: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the VLM on the given input.

        Args:
            prompt: The prompt to generate from.
            image: Either a path to an image or a base64 encoding of the image.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
                If stop tokens are not supported, consider raising NotImplementedError.
        Returns:
            The model output as a string. The actual completion DOES NOT include the prompt.
            Example: TBD
        """
        params = {}
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        params['generate_kwargs'] = self.get_params(**kwargs)
        response = self.client.generate(prompt=prompt, image=image, **params)
        return response

    def _stream(
        self,
        prompt: str,
        image: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        """Stream the VLM on the given prompt and image.

        Args:
            prompt: The prompt to generate from.
            image: Either a path to an image or a base64 encoding of the image.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
                If stop tokens are not supported, consider raising NotImplementedError.
        Returns:
            An iterator of output strings. The actual completions DO NOT include the prompt.
            Example: TBD
        """
        params = {}
        params['generate_kwargs'] = self.get_params(**kwargs)
        for chunk in self.client.generate_stream(prompt=prompt, image=image, **params):
            yield chunk

    async def _astream(
        self,
        prompt: str,
        image: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """An async version of the _stream method that streams the VLM on the given prompt and image.

        Args:
            prompt: The prompt to generate from.
            image: Either a path to an image or a base64 encoding of the image.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
                If stop tokens are not supported, consider raising NotImplementedError.
        Returns:
            An async iterator of output strings. The actual completions DO NOT include the prompt.
            Example: TBD
        """
        iterator = await run_in_executor(
            None,
            self._stream,
            prompt,
            image,
            stop,
            run_manager.get_sync() if run_manager else None,
            **kwargs,
        )
        done = object()
        while True:
            item = await run_in_executor(
                None,
                next,
                iterator,
                done,  # type: ignore[call-arg, arg-type]
            )
            if item is done:
                break
            yield item  # type: ignore[misc]

    def invoke(
        self,
        input: MultimodalModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        config = ensure_config(config)
        if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
            return (
                self.generate_prompt(
                    [self._convert_input(StringPromptValue(text=input['prompt']))],
                    stop=stop,
                    callbacks=config.get("callbacks"),
                    tags=config.get("tags"),
                    metadata=config.get("metadata"),
                    run_name=config.get("run_name"),
                    run_id=config.pop("run_id", None),
                    image=input['image'],
                    **kwargs,
                )
                .generations[0][0]
                .text
            )
        return (
            self.generate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                **kwargs,
            )
            .generations[0][0]
            .text
        )

    async def ainvoke(
        self,
        input: MultimodalModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        config = ensure_config(config)
        if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
            llm_result = await self.agenerate_prompt(
                [self._convert_input(StringPromptValue(text=input['prompt']))],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                image=input['image'],
                **kwargs,
            )
        else:
            llm_result = await self.agenerate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                **kwargs,
            )
        return llm_result.generations[0][0].text

    def stream(
        self,
        input: MultimodalModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        if type(self)._stream == BaseLLM._stream:
            # model doesn't implement streaming, so use the default implementation
            yield self.invoke(input, config=config, stop=stop, **kwargs)
        else:
            if stop is not None:
                raise ValueError("stop kwargs are not permitted.")
            image = None
            prompt = None
            if isinstance(input, dict) and 'prompt' in input.keys():
                prompt = self._convert_input(input['prompt']).to_string()
            else:
                raise ValueError("prompt must be provided")
            if isinstance(input, dict) and 'image' in input.keys():
                image = input['image']

            for chunk in self._stream(
                prompt=prompt, image=image, **kwargs
            ):
                yield chunk

    async def astream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        if (
            type(self)._astream is BaseLLM._astream
            and type(self)._stream is BaseLLM._stream
        ):
            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
            return
        else:
            if stop is not None:
                raise ValueError("stop kwargs are not permitted.")
            image = None
            if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
                prompt = self._convert_input(input['prompt']).to_string()
                image = input['image']
            else:
                raise ValueError("both prompt and image must be provided")

            async for chunk in self._astream(
                prompt=prompt, image=image, **kwargs
            ):
                yield chunk
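
Usage sketch for the LVLM wrapper (illustrative; it assumes a VLM endpoint is reachable, e.g. configured through the VLM_URL or VLM_HOSTNAME/VLM_PORT environment variables, and the image path is a stand-in):

    from mm_rag.MLM.lvlm import LVLM

    lvlm = LVLM(max_new_tokens=150, temperature=0.5)
    answer = lvlm.invoke({
        "prompt": "What objects are visible on the table?",
        "image": "frames/frame_0042.jpg",  # hypothetical path; a base64 string is also accepted
    })
    print(answer)

    # Streaming variant, only meaningful when the underlying client implements generate_stream:
    # for chunk in lvlm.stream({"prompt": "...", "image": "frames/frame_0042.jpg"}):
    #     print(chunk, end="")
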
mm_rag/embeddings/__pycache__/bridgetower_embeddings.cpython-311.pyc
ADDED
Binary file (3.23 kB).
mm_rag/embeddings/bridgetower_embeddings.py
ADDED
@@ -0,0 +1,56 @@
from typing import List
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
    BaseModel,
)
from utility import encode_image, bt_embedding_from_prediction_guard
from tqdm import tqdm

class BridgeTowerEmbeddings(BaseModel, Embeddings):
    """BridgeTower embedding model"""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using BridgeTower.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        for text in texts:
            embedding = bt_embedding_from_prediction_guard(text, "")
            embeddings.append(embedding)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using BridgeTower.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]

    def embed_image_text_pairs(self, texts: List[str], images: List[str], batch_size=2) -> List[List[float]]:
        """Embed a list of image-text pairs using BridgeTower.

        Args:
            texts: The list of texts to embed.
            images: The list of paths to the images to embed.
            batch_size: The batch size to process; defaults to 2.
        Returns:
            List of embeddings, one for each image-text pair.
        """
        # the length of texts must be equal to the length of images
        assert len(texts) == len(images), "the length of captions should be equal to the length of images"

        embeddings = []
        for path_to_img, text in tqdm(zip(images, texts), total=len(texts)):
            embedding = bt_embedding_from_prediction_guard(text, encode_image(path_to_img))
            embeddings.append(embedding)
        return embeddings
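
Usage sketch for the embeddings class (illustrative; it assumes Prediction Guard credentials are configured for bt_embedding_from_prediction_guard and that the image path is a stand-in):

    from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings

    embedder = BridgeTowerEmbeddings()
    text_vecs = embedder.embed_documents(["a dog chasing a ball", "a crowded street at night"])
    pair_vecs = embedder.embed_image_text_pairs(
        texts=["a dog chasing a ball"],
        images=["frames/dog.jpg"],  # hypothetical path
    )
    print(len(text_vecs), len(pair_vecs[0]))  # number of texts, embedding dimensionality
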
mm_rag/vectorstores/multimodal_lancedb.py
ADDED
@@ -0,0 +1,131 @@
from typing import Any, Iterable, List, Optional
from langchain_core.embeddings import Embeddings
import uuid
from langchain_community.vectorstores.lancedb import LanceDB

class MultimodalLanceDB(LanceDB):
    """`LanceDB` vector store to process multimodal data

    To use, you should have the ``lancedb`` python package installed.
    You can install it with ``pip install lancedb``.

    Args:
        connection: LanceDB connection to use. If not provided, a new connection
            will be created.
        embedding: Embedding to use for the vectorstore.
        vector_key: Key to use for the vector in the database. Defaults to ``vector``.
        id_key: Key to use for the id in the database. Defaults to ``id``.
        text_key: Key to use for the text in the database. Defaults to ``text``.
        image_path_key: Key to use for the path to an image in the database. Defaults to ``image_path``.
        table_name: Name of the table to use. Defaults to ``vectorstore``.
        api_key: API key to use for a LanceDB cloud database.
        region: Region to use for a LanceDB cloud database.
        mode: Mode to use for adding data to the table. Defaults to ``append``.

    Example:
        .. code-block:: python

            vectorstore = MultimodalLanceDB(embedding=embedding_function, uri='/lancedb')
            vectorstore.add_texts(['text1', 'text2'])
            result = vectorstore.similarity_search('text1')
    """

    def __init__(
        self,
        connection: Optional[Any] = None,
        embedding: Optional[Embeddings] = None,
        uri: Optional[str] = "/tmp/lancedb",
        vector_key: Optional[str] = "vector",
        id_key: Optional[str] = "id",
        text_key: Optional[str] = "text",
        image_path_key: Optional[str] = "image_path",
        table_name: Optional[str] = "vectorstore",
        api_key: Optional[str] = None,
        region: Optional[str] = None,
        mode: Optional[str] = "append",
    ):
        super(MultimodalLanceDB, self).__init__(connection, embedding, uri, vector_key, id_key, text_key, table_name, api_key, region, mode)
        self._image_path_key = image_path_key

    def add_text_image_pairs(
        self,
        texts: Iterable[str],
        image_paths: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Turn text-image pairs into embeddings and add them to the database.

        Args:
            texts: Iterable of strings to combine with the corresponding images and add to the vectorstore.
            image_paths: Iterable of paths to images (as strings) to combine with the corresponding texts and add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of ids to associate with the texts.

        Returns:
            List of ids of the added text-image pairs.
        """
        # the length of texts must be equal to the length of images
        assert len(texts) == len(image_paths), "the length of transcripts should be equal to the length of images"

        # Embed texts and create documents
        docs = []
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        embeddings = self._embedding.embed_image_text_pairs(texts=list(texts), images=list(image_paths))  # type: ignore
        for idx, text in enumerate(texts):
            embedding = embeddings[idx]
            metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
            docs.append(
                {
                    self._vector_key: embedding,
                    self._id_key: ids[idx],
                    self._text_key: text,
                    self._image_path_key: image_paths[idx],
                    "metadata": metadata,
                }
            )

        if 'mode' in kwargs:
            mode = kwargs['mode']
        else:
            mode = self.mode
        if self._table_name in self._connection.table_names():
            tbl = self._connection.open_table(self._table_name)
            if self.api_key is None:
                tbl.add(docs, mode=mode)
            else:
                tbl.add(docs)
        else:
            self._connection.create_table(self._table_name, data=docs)
        return ids

    @classmethod
    def from_text_image_pairs(
        cls,
        texts: List[str],
        image_paths: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        connection: Any = None,
        vector_key: Optional[str] = "vector",
        id_key: Optional[str] = "id",
        text_key: Optional[str] = "text",
        image_path_key: Optional[str] = "image_path",
        table_name: Optional[str] = "vectorstore",
        **kwargs: Any,
    ):
        instance = MultimodalLanceDB(
            connection=connection,
            embedding=embedding,
            vector_key=vector_key,
            id_key=id_key,
            text_key=text_key,
            image_path_key=image_path_key,
            table_name=table_name,
        )
        instance.add_text_image_pairs(texts, image_paths, metadatas=metadatas, **kwargs)

        return instance