Upload files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- inference/infer.py +317 -0
- inference/infer_ola_internvl.py +448 -0
- inference/infer_ola_internvl_audio.py +244 -0
- inference/infer_ola_internvl_audio_ckpt.py +245 -0
- inference/infer_ola_internvl_copy.py +318 -0
- inference/infer_ola_internvl_text_visual.py +409 -0
- inference/log.txt +480 -0
- inference/log1.txt +339 -0
- ola.egg-info/PKG-INFO +265 -0
- ola.egg-info/SOURCES.txt +44 -0
- ola.egg-info/dependency_links.txt +1 -0
- ola.egg-info/requires.txt +40 -0
- ola.egg-info/top_level.txt +4 -0
- ola/__pycache__/arguments.cpython-312.pyc +0 -0
- ola/__pycache__/constants.cpython-312.pyc +0 -0
- ola/__pycache__/conversation.cpython-312.pyc +0 -0
- ola/__pycache__/mm_utils.cpython-312.pyc +0 -0
- ola/__pycache__/utils.cpython-312.pyc +0 -0
- ola/arguments.py +65 -0
- ola/constants.py +14 -0
- ola/conversation.py +266 -0
- ola/datasets/__init__.py +0 -0
- ola/datasets/__pycache__/__init__.cpython-312.pyc +0 -0
- ola/datasets/__pycache__/preprocess.cpython-312.pyc +0 -0
- ola/datasets/preprocess.py +234 -0
- ola/mm_utils.py +271 -0
- ola/model/__init__.py +2 -0
- ola/model/__pycache__/__init__.cpython-312.pyc +0 -0
- ola/model/__pycache__/builder.cpython-312.pyc +0 -0
- ola/model/__pycache__/ola_arch.cpython-312.pyc +0 -0
- ola/model/builder.py +97 -0
- ola/model/builder_back.py +294 -0
- ola/model/language_model/__pycache__/conversation.cpython-312.pyc +0 -0
- ola/model/language_model/__pycache__/ola_qwen.cpython-312.pyc +0 -0
- ola/model/language_model/__pycache__/ola_qwen3.cpython-312.pyc +0 -0
- ola/model/language_model/conversation.py +403 -0
- ola/model/language_model/ola_qwen.py +237 -0
- ola/model/language_model/ola_qwen3.py +466 -0
- ola/model/multimodal_encoder/__pycache__/builder.cpython-312.pyc +0 -0
- ola/model/multimodal_encoder/__pycache__/configuration_intern_vit.cpython-312.pyc +0 -0
- ola/model/multimodal_encoder/__pycache__/internvl_vit.cpython-312.pyc +0 -0
- ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-312.pyc +0 -0
- ola/model/multimodal_encoder/builder.py +16 -0
- ola/model/multimodal_encoder/configuration_intern_vit.py +119 -0
- ola/model/multimodal_encoder/internvl_vit.py +435 -0
- ola/model/multimodal_encoder/oryx_vit.py +1075 -0
- ola/model/multimodal_projector/__pycache__/builder.cpython-312.pyc +0 -0
- ola/model/multimodal_projector/__pycache__/internvl_projector.cpython-312.pyc +0 -0
- ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-312.pyc +0 -0
- ola/model/multimodal_projector/builder.py +177 -0
inference/infer.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ['LOWRES_RESIZE'] = '384x32'
|
| 4 |
+
os.environ['HIGHRES_BASE'] = '0x32'
|
| 5 |
+
os.environ['VIDEO_RESIZE'] = "0x64"
|
| 6 |
+
os.environ['VIDEO_MAXRES'] = "480"
|
| 7 |
+
os.environ['VIDEO_MINRES'] = "288"
|
| 8 |
+
os.environ['MAXRES'] = '1536'
|
| 9 |
+
os.environ['MINRES'] = '0'
|
| 10 |
+
os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
|
| 11 |
+
os.environ['LOAD_VISION_EARLY'] = '1'
|
| 12 |
+
os.environ['PAD2STRIDE'] = '1'
|
| 13 |
+
|
| 14 |
+
import gradio as gr
|
| 15 |
+
import torch
|
| 16 |
+
import re
|
| 17 |
+
from decord import VideoReader, cpu
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import numpy as np
|
| 20 |
+
import transformers
|
| 21 |
+
import moviepy as mp
|
| 22 |
+
from typing import Dict, Optional, Sequence, List
|
| 23 |
+
import librosa
|
| 24 |
+
import whisper
|
| 25 |
+
from ola.conversation import conv_templates, SeparatorStyle
|
| 26 |
+
from ola.model.builder import load_pretrained_model
|
| 27 |
+
from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token
|
| 28 |
+
from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image
|
| 29 |
+
from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX
|
| 30 |
+
import argparse
|
| 31 |
+
|
| 32 |
+
parser = argparse.ArgumentParser()
|
| 33 |
+
parser.add_argument('--model_path', type=str, default='/data1/cxy/model/THUdyh/Ola-7b')
|
| 34 |
+
parser.add_argument('--text', type=str, default="What does the speech say?")
|
| 35 |
+
parser.add_argument('--audio_path', type=str, default="/data1/cxy/dataset/english.mp3")
|
| 36 |
+
parser.add_argument('--image_path', type=str, default=None)
|
| 37 |
+
parser.add_argument('--video_path', type=str, default=None)
|
| 38 |
+
args = parser.parse_args()
|
| 39 |
+
|
| 40 |
+
model_path = args.model_path
|
| 41 |
+
tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None)
|
| 42 |
+
model = model.to('cuda').eval()
|
| 43 |
+
model = model.bfloat16()
|
| 44 |
+
# breakpoint()
|
| 45 |
+
USE_SPEECH=False
|
| 46 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 47 |
+
|
| 48 |
+
def load_audio(audio_file_name):
|
| 49 |
+
speech_wav, samplerate = librosa.load(audio_file_name, sr=16000)
|
| 50 |
+
if len(speech_wav.shape) > 1:
|
| 51 |
+
speech_wav = speech_wav[:, 0]
|
| 52 |
+
speech_wav = speech_wav.astype(np.float32)
|
| 53 |
+
CHUNK_LIM = 480000
|
| 54 |
+
SAMPLE_RATE = 16000
|
| 55 |
+
speechs = []
|
| 56 |
+
speech_wavs = []
|
| 57 |
+
|
| 58 |
+
if len(speech_wav) <= CHUNK_LIM:
|
| 59 |
+
speech = whisper.pad_or_trim(speech_wav)
|
| 60 |
+
speech_wav = whisper.pad_or_trim(speech_wav)
|
| 61 |
+
speechs.append(speech)
|
| 62 |
+
speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0))
|
| 63 |
+
else:
|
| 64 |
+
for i in range(0, len(speech_wav), CHUNK_LIM):
|
| 65 |
+
chunk = speech_wav[i : i + CHUNK_LIM]
|
| 66 |
+
if len(chunk) < CHUNK_LIM:
|
| 67 |
+
chunk = whisper.pad_or_trim(chunk)
|
| 68 |
+
speechs.append(chunk)
|
| 69 |
+
speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0))
|
| 70 |
+
mels = []
|
| 71 |
+
for chunk in speechs:
|
| 72 |
+
chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0)
|
| 73 |
+
mels.append(chunk)
|
| 74 |
+
|
| 75 |
+
mels = torch.cat(mels, dim=0)
|
| 76 |
+
speech_wavs = torch.cat(speech_wavs, dim=0)
|
| 77 |
+
if mels.shape[0] > 25:
|
| 78 |
+
mels = mels[:25]
|
| 79 |
+
speech_wavs = speech_wavs[:25]
|
| 80 |
+
|
| 81 |
+
speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0])
|
| 82 |
+
speech_chunks = torch.LongTensor([mels.shape[0]])
|
| 83 |
+
return mels, speech_length, speech_chunks, speech_wavs
|
| 84 |
+
|
| 85 |
+
def extract_audio(videos_file_path):
|
| 86 |
+
my_clip = mp.VideoFileClip(videos_file_path)
|
| 87 |
+
return my_clip.audio
|
| 88 |
+
|
| 89 |
+
image_path = args.image_path
|
| 90 |
+
audio_path = args.audio_path
|
| 91 |
+
video_path = args.video_path
|
| 92 |
+
text = args.text
|
| 93 |
+
|
| 94 |
+
if video_path is not None:
|
| 95 |
+
modality = "video"
|
| 96 |
+
visual = video_path
|
| 97 |
+
assert image_path is None
|
| 98 |
+
|
| 99 |
+
elif image_path is not None:
|
| 100 |
+
visual = image_path
|
| 101 |
+
modality = "image"
|
| 102 |
+
assert video_path is None
|
| 103 |
+
|
| 104 |
+
elif audio_path is not None:
|
| 105 |
+
modality = "text"
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# input audio and video, do not parse audio in the video, else parse audio in the video
|
| 109 |
+
if audio_path:
|
| 110 |
+
USE_SPEECH = True
|
| 111 |
+
elif modality == "video":
|
| 112 |
+
USE_SPEECH = True
|
| 113 |
+
else:
|
| 114 |
+
USE_SPEECH = False
|
| 115 |
+
|
| 116 |
+
speechs = []
|
| 117 |
+
speech_lengths = []
|
| 118 |
+
speech_wavs = []
|
| 119 |
+
speech_chunks = []
|
| 120 |
+
if modality == "video":
|
| 121 |
+
vr = VideoReader(visual, ctx=cpu(0))
|
| 122 |
+
total_frame_num = len(vr)
|
| 123 |
+
fps = round(vr.get_avg_fps())
|
| 124 |
+
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, 64, dtype=int)
|
| 125 |
+
frame_idx = uniform_sampled_frames.tolist()
|
| 126 |
+
spare_frames = vr.get_batch(frame_idx).asnumpy()
|
| 127 |
+
video = [Image.fromarray(frame) for frame in spare_frames]
|
| 128 |
+
elif modality == "image":
|
| 129 |
+
image = [Image.open(visual)]
|
| 130 |
+
image_sizes = [image[0].size]
|
| 131 |
+
else:
|
| 132 |
+
images = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)]
|
| 133 |
+
images_highres = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)]
|
| 134 |
+
image_sizes = [(224, 224)]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if USE_SPEECH and audio_path:
|
| 138 |
+
audio_path = audio_path
|
| 139 |
+
speech, speech_length, speech_chunk, speech_wav = load_audio(audio_path)
|
| 140 |
+
speechs.append(speech.bfloat16().to('cuda'))
|
| 141 |
+
speech_lengths.append(speech_length.to('cuda'))
|
| 142 |
+
speech_chunks.append(speech_chunk.to('cuda'))
|
| 143 |
+
speech_wavs.append(speech_wav.to('cuda'))
|
| 144 |
+
print('load audio')
|
| 145 |
+
elif USE_SPEECH and not audio_path:
|
| 146 |
+
# parse audio in the video
|
| 147 |
+
audio = extract_audio(visual)
|
| 148 |
+
audio.write_audiofile("./video_audio.wav")
|
| 149 |
+
video_audio_path = './video_audio.wav'
|
| 150 |
+
speech, speech_length, speech_chunk, speech_wav = load_audio(video_audio_path)
|
| 151 |
+
speechs.append(speech.bfloat16().to('cuda'))
|
| 152 |
+
speech_lengths.append(speech_length.to('cuda'))
|
| 153 |
+
speech_chunks.append(speech_chunk.to('cuda'))
|
| 154 |
+
speech_wavs.append(speech_wav.to('cuda'))
|
| 155 |
+
else:
|
| 156 |
+
speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')]
|
| 157 |
+
speech_lengths = [torch.LongTensor([3000]).to('cuda')]
|
| 158 |
+
speech_wavs = [torch.zeros([1, 480000]).to('cuda')]
|
| 159 |
+
speech_chunks = [torch.LongTensor([1]).to('cuda')]
|
| 160 |
+
|
| 161 |
+
conv_mode = "qwen_1_5"
|
| 162 |
+
if text:
|
| 163 |
+
qs = text
|
| 164 |
+
else:
|
| 165 |
+
qs = ''
|
| 166 |
+
|
| 167 |
+
if USE_SPEECH and audio_path and image_path: # image + speech instruction
|
| 168 |
+
qs = DEFAULT_IMAGE_TOKEN + "\n" + "User's question in speech: " + DEFAULT_SPEECH_TOKEN + '\n'
|
| 169 |
+
elif USE_SPEECH and video_path: # video + audio
|
| 170 |
+
qs = DEFAULT_SPEECH_TOKEN + DEFAULT_IMAGE_TOKEN + "\n" + qs
|
| 171 |
+
elif USE_SPEECH and audio_path: # audio + text
|
| 172 |
+
qs = DEFAULT_SPEECH_TOKEN + "\n" + qs
|
| 173 |
+
elif image_path or video_path: # image / video
|
| 174 |
+
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
|
| 175 |
+
elif text: # text
|
| 176 |
+
qs = qs
|
| 177 |
+
|
| 178 |
+
conv = conv_templates[conv_mode].copy()
|
| 179 |
+
conv.append_message(conv.roles[0], qs)
|
| 180 |
+
conv.append_message(conv.roles[1], None)
|
| 181 |
+
prompt = conv.get_prompt()
|
| 182 |
+
if USE_SPEECH and audio_path and image_path: # image + speech instruction
|
| 183 |
+
input_ids = tokenizer_speech_question_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
|
| 184 |
+
elif USE_SPEECH and video_path: # video + audio
|
| 185 |
+
input_ids = tokenizer_speech_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
|
| 186 |
+
elif USE_SPEECH and audio_path: # audio + text
|
| 187 |
+
# breakpoint()
|
| 188 |
+
input_ids = tokenizer_speech_token(prompt, tokenizer, SPEECH_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
|
| 189 |
+
else:
|
| 190 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
|
| 191 |
+
|
| 192 |
+
if modality == "video":
|
| 193 |
+
video_processed = []
|
| 194 |
+
for idx, frame in enumerate(video):
|
| 195 |
+
image_processor.do_resize = False
|
| 196 |
+
image_processor.do_center_crop = False
|
| 197 |
+
frame = process_anyres_video(frame, image_processor)
|
| 198 |
+
|
| 199 |
+
if frame_idx is not None and idx in frame_idx:
|
| 200 |
+
video_processed.append(frame.unsqueeze(0))
|
| 201 |
+
elif frame_idx is None:
|
| 202 |
+
video_processed.append(frame.unsqueeze(0))
|
| 203 |
+
|
| 204 |
+
if frame_idx is None:
|
| 205 |
+
frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()
|
| 206 |
+
|
| 207 |
+
video_processed = torch.cat(video_processed, dim=0).bfloat16().to("cuda")
|
| 208 |
+
video_processed = (video_processed, video_processed)
|
| 209 |
+
|
| 210 |
+
video_data = (video_processed, (384, 384), "video")
|
| 211 |
+
elif modality == "image":
|
| 212 |
+
image_processor.do_resize = False
|
| 213 |
+
image_processor.do_center_crop = False
|
| 214 |
+
image_tensor, image_highres_tensor = [], []
|
| 215 |
+
for visual in image:
|
| 216 |
+
image_tensor_, image_highres_tensor_ = process_anyres_highres_image(visual, image_processor)
|
| 217 |
+
image_tensor.append(image_tensor_)
|
| 218 |
+
image_highres_tensor.append(image_highres_tensor_)
|
| 219 |
+
if all(x.shape == image_tensor[0].shape for x in image_tensor):
|
| 220 |
+
image_tensor = torch.stack(image_tensor, dim=0)
|
| 221 |
+
if all(x.shape == image_highres_tensor[0].shape for x in image_highres_tensor):
|
| 222 |
+
image_highres_tensor = torch.stack(image_highres_tensor, dim=0)
|
| 223 |
+
if type(image_tensor) is list:
|
| 224 |
+
image_tensor = [_image.bfloat16().to("cuda") for _image in image_tensor]
|
| 225 |
+
else:
|
| 226 |
+
image_tensor = image_tensor.bfloat16().to("cuda")
|
| 227 |
+
if type(image_highres_tensor) is list:
|
| 228 |
+
image_highres_tensor = [_image.bfloat16().to("cuda") for _image in image_highres_tensor]
|
| 229 |
+
else:
|
| 230 |
+
image_highres_tensor = image_highres_tensor.bfloat16().to("cuda")
|
| 231 |
+
|
| 232 |
+
pad_token_ids = 151643
|
| 233 |
+
|
| 234 |
+
attention_masks = input_ids.ne(pad_token_ids).long().to('cuda')
|
| 235 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 236 |
+
keywords = [stop_str]
|
| 237 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 238 |
+
|
| 239 |
+
gen_kwargs = {}
|
| 240 |
+
|
| 241 |
+
if "max_new_tokens" not in gen_kwargs:
|
| 242 |
+
gen_kwargs["max_new_tokens"] = 1024
|
| 243 |
+
if "temperature" not in gen_kwargs:
|
| 244 |
+
gen_kwargs["temperature"] = 0.2
|
| 245 |
+
if "top_p" not in gen_kwargs:
|
| 246 |
+
gen_kwargs["top_p"] = None
|
| 247 |
+
if "num_beams" not in gen_kwargs:
|
| 248 |
+
gen_kwargs["num_beams"] = 1
|
| 249 |
+
# breakpoint()
|
| 250 |
+
with torch.inference_mode():
|
| 251 |
+
if modality == "video":
|
| 252 |
+
output_ids = model.generate(
|
| 253 |
+
inputs=input_ids,
|
| 254 |
+
images=video_data[0][0],
|
| 255 |
+
images_highres=video_data[0][1],
|
| 256 |
+
modalities=video_data[2],
|
| 257 |
+
speech=speechs,
|
| 258 |
+
speech_lengths=speech_lengths,
|
| 259 |
+
speech_chunks=speech_chunks,
|
| 260 |
+
speech_wav=speech_wavs,
|
| 261 |
+
attention_mask=attention_masks,
|
| 262 |
+
use_cache=True,
|
| 263 |
+
stopping_criteria=[stopping_criteria],
|
| 264 |
+
do_sample=True if gen_kwargs["temperature"] > 0 else False,
|
| 265 |
+
temperature=gen_kwargs["temperature"],
|
| 266 |
+
top_p=gen_kwargs["top_p"],
|
| 267 |
+
num_beams=gen_kwargs["num_beams"],
|
| 268 |
+
max_new_tokens=gen_kwargs["max_new_tokens"],
|
| 269 |
+
)
|
| 270 |
+
elif modality == "image":
|
| 271 |
+
output_ids = model.generate(
|
| 272 |
+
inputs=input_ids,
|
| 273 |
+
images=image_tensor,
|
| 274 |
+
images_highres=image_highres_tensor,
|
| 275 |
+
image_sizes=image_sizes,
|
| 276 |
+
modalities=['image'],
|
| 277 |
+
speech=speechs,
|
| 278 |
+
speech_lengths=speech_lengths,
|
| 279 |
+
speech_chunks=speech_chunks,
|
| 280 |
+
speech_wav=speech_wavs,
|
| 281 |
+
attention_mask=attention_masks,
|
| 282 |
+
use_cache=True,
|
| 283 |
+
stopping_criteria=[stopping_criteria],
|
| 284 |
+
do_sample=True if gen_kwargs["temperature"] > 0 else False,
|
| 285 |
+
temperature=gen_kwargs["temperature"],
|
| 286 |
+
top_p=gen_kwargs["top_p"],
|
| 287 |
+
num_beams=gen_kwargs["num_beams"],
|
| 288 |
+
max_new_tokens=gen_kwargs["max_new_tokens"],
|
| 289 |
+
)
|
| 290 |
+
elif modality == "text":
|
| 291 |
+
output_ids = model.generate(
|
| 292 |
+
input_ids,
|
| 293 |
+
images=images,
|
| 294 |
+
images_highres=images_highres,
|
| 295 |
+
image_sizes=image_sizes,
|
| 296 |
+
modalities=['text'],
|
| 297 |
+
speech=speechs,
|
| 298 |
+
speech_lengths=speech_lengths,
|
| 299 |
+
speech_chunks=speech_chunks,
|
| 300 |
+
speech_wav=speech_wavs,
|
| 301 |
+
attention_mask=attention_masks,
|
| 302 |
+
use_cache=True,
|
| 303 |
+
stopping_criteria=[stopping_criteria],
|
| 304 |
+
do_sample=True if gen_kwargs["temperature"] > 0 else False,
|
| 305 |
+
temperature=gen_kwargs["temperature"],
|
| 306 |
+
top_p=gen_kwargs["top_p"],
|
| 307 |
+
num_beams=gen_kwargs["num_beams"],
|
| 308 |
+
max_new_tokens=gen_kwargs["max_new_tokens"],
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
| 312 |
+
outputs = outputs.strip()
|
| 313 |
+
if outputs.endswith(stop_str):
|
| 314 |
+
outputs = outputs[:-len(stop_str)]
|
| 315 |
+
outputs = outputs.strip()
|
| 316 |
+
|
| 317 |
+
print(outputs)
|
inference/infer_ola_internvl.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ['LOWRES_RESIZE'] = '384x32'
|
| 4 |
+
os.environ['HIGHRES_BASE'] = '0x32'
|
| 5 |
+
os.environ['VIDEO_RESIZE'] = "0x64"
|
| 6 |
+
os.environ['VIDEO_MAXRES'] = "480"
|
| 7 |
+
os.environ['VIDEO_MINRES'] = "288"
|
| 8 |
+
os.environ['MAXRES'] = '1536'
|
| 9 |
+
os.environ['MINRES'] = '0'
|
| 10 |
+
os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
|
| 11 |
+
os.environ['LOAD_VISION_EARLY'] = '1'
|
| 12 |
+
os.environ['PAD2STRIDE'] = '1'
|
| 13 |
+
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
import math
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torchvision.transforms as T
|
| 21 |
+
from decord import VideoReader, cpu # 暂时注释掉,专注于语音功能测试
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 24 |
+
from transformers import AutoModel, AutoTokenizer
|
| 25 |
+
from contextlib import redirect_stdout
|
| 26 |
+
import io
|
| 27 |
+
import librosa
|
| 28 |
+
import whisper
|
| 29 |
+
import moviepy as mp
|
| 30 |
+
import torch
|
| 31 |
+
from transformers import AutoTokenizer, AutoConfig, AutoModel
|
| 32 |
+
|
| 33 |
+
# pure text
|
| 34 |
+
# image + text
|
| 35 |
+
# video + text
|
| 36 |
+
# audio + text
|
| 37 |
+
# video + audio + text
|
| 38 |
+
|
| 39 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 40 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 41 |
+
import gradio as gr
|
| 42 |
+
import torch
|
| 43 |
+
import re
|
| 44 |
+
from decord import VideoReader, cpu
|
| 45 |
+
from PIL import Image
|
| 46 |
+
import numpy as np
|
| 47 |
+
import transformers
|
| 48 |
+
import moviepy as mp
|
| 49 |
+
from typing import Dict, Optional, Sequence, List
|
| 50 |
+
import librosa
|
| 51 |
+
import whisper
|
| 52 |
+
from ola.conversation import conv_templates, SeparatorStyle
|
| 53 |
+
from ola.model.builder import load_pretrained_model
|
| 54 |
+
from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token
|
| 55 |
+
from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image
|
| 56 |
+
from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX
|
| 57 |
+
import argparse
|
| 58 |
+
|
| 59 |
+
parser = argparse.ArgumentParser()
|
| 60 |
+
parser.add_argument('--model_path', type=str, default='/data1/cxy/plm-v/modeling/internvl3_5-2B')
|
| 61 |
+
parser.add_argument('--text', type=str, default="What does the speech say?")
|
| 62 |
+
parser.add_argument('--audio_path', type=str, default=None)
|
| 63 |
+
parser.add_argument('--image_path', type=str, default=None)
|
| 64 |
+
parser.add_argument('--video_path', type=str, default=None)
|
| 65 |
+
args = parser.parse_args()
|
| 66 |
+
|
| 67 |
+
model_path = args.model_path
|
| 68 |
+
tokenizer, model, image_processor, _ = load_pretrained_model(model_path,'ola_internvl', None)
|
| 69 |
+
model = model.to('cuda').eval()
|
| 70 |
+
model = model.bfloat16()
|
| 71 |
+
|
| 72 |
+
resource_path = "/data1/cxy/plm-v/modeling/example/"
|
| 73 |
+
# set the max number of tiles in `max_num`
|
| 74 |
+
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
| 75 |
+
best_ratio_diff = float('inf')
|
| 76 |
+
best_ratio = (1, 1)
|
| 77 |
+
area = width * height
|
| 78 |
+
for ratio in target_ratios:
|
| 79 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
| 80 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
| 81 |
+
if ratio_diff < best_ratio_diff:
|
| 82 |
+
best_ratio_diff = ratio_diff
|
| 83 |
+
best_ratio = ratio
|
| 84 |
+
elif ratio_diff == best_ratio_diff:
|
| 85 |
+
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
| 86 |
+
best_ratio = ratio
|
| 87 |
+
return best_ratio
|
| 88 |
+
def build_transform(input_size):
|
| 89 |
+
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
| 90 |
+
transform = T.Compose([
|
| 91 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
| 92 |
+
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
| 93 |
+
T.ToTensor(),
|
| 94 |
+
T.Normalize(mean=MEAN, std=STD)
|
| 95 |
+
])
|
| 96 |
+
return transform
|
| 97 |
+
|
| 98 |
+
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
|
| 99 |
+
orig_width, orig_height = image.size
|
| 100 |
+
aspect_ratio = orig_width / orig_height
|
| 101 |
+
|
| 102 |
+
# calculate the existing image aspect ratio
|
| 103 |
+
target_ratios = set(
|
| 104 |
+
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
| 105 |
+
i * j <= max_num and i * j >= min_num)
|
| 106 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
| 107 |
+
|
| 108 |
+
# find the closest aspect ratio to the target
|
| 109 |
+
target_aspect_ratio = find_closest_aspect_ratio(
|
| 110 |
+
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
| 111 |
+
|
| 112 |
+
# calculate the target width and height
|
| 113 |
+
target_width = image_size * target_aspect_ratio[0]
|
| 114 |
+
target_height = image_size * target_aspect_ratio[1]
|
| 115 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
| 116 |
+
|
| 117 |
+
# resize the image
|
| 118 |
+
resized_img = image.resize((target_width, target_height))
|
| 119 |
+
processed_images = []
|
| 120 |
+
for i in range(blocks):
|
| 121 |
+
box = (
|
| 122 |
+
(i % (target_width // image_size)) * image_size,
|
| 123 |
+
(i // (target_width // image_size)) * image_size,
|
| 124 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
| 125 |
+
((i // (target_width // image_size)) + 1) * image_size
|
| 126 |
+
)
|
| 127 |
+
# split the image
|
| 128 |
+
split_img = resized_img.crop(box)
|
| 129 |
+
processed_images.append(split_img)
|
| 130 |
+
assert len(processed_images) == blocks
|
| 131 |
+
if use_thumbnail and len(processed_images) != 1:
|
| 132 |
+
thumbnail_img = image.resize((image_size, image_size))
|
| 133 |
+
processed_images.append(thumbnail_img)
|
| 134 |
+
return processed_images
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def load_image(image_file, input_size=448, max_num=12):
|
| 138 |
+
image = Image.open(image_file).convert('RGB')
|
| 139 |
+
transform = build_transform(input_size=input_size)
|
| 140 |
+
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
|
| 141 |
+
pixel_values = [transform(image) for image in images]
|
| 142 |
+
pixel_values = torch.stack(pixel_values)
|
| 143 |
+
return pixel_values
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
pixel_values = load_image(f'{resource_path}image1.jpg', max_num=12).to(torch.bfloat16).cuda()
|
| 148 |
+
# breakpoint()
|
| 149 |
+
generation_config = dict(max_new_tokens=1024, do_sample=True)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# breakpoint()
|
| 153 |
+
question = 'Hello, who are you?'
|
| 154 |
+
response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
|
| 155 |
+
print(f'User: {question}\nAssistant: {response}')
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# 多模态推理测试
|
| 160 |
+
print("\n" + "="*80)
|
| 161 |
+
print("🧪 开始多模态推理测试")
|
| 162 |
+
print("="*80)
|
| 163 |
+
|
| 164 |
+
def test_inference(test_name, question, pixel_values_input=None, speech_input=None, speech_lengths_input=None, num_patches_list=None):
|
| 165 |
+
"""统一的推理测试函数"""
|
| 166 |
+
print(f"\n{'='*60}")
|
| 167 |
+
print(f"🧪 测试: {test_name}")
|
| 168 |
+
print(f"📝 问题: {question}")
|
| 169 |
+
print(f"{'='*60}")
|
| 170 |
+
|
| 171 |
+
try:
|
| 172 |
+
# 准备参数
|
| 173 |
+
chat_kwargs = {
|
| 174 |
+
'tokenizer': tokenizer,
|
| 175 |
+
'pixel_values': pixel_values_input,
|
| 176 |
+
'question': question,
|
| 177 |
+
'generation_config': generation_config,
|
| 178 |
+
'verbose': True
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
# 如果有视频数据,添加num_patches_list参数
|
| 182 |
+
if num_patches_list is not None:
|
| 183 |
+
chat_kwargs['num_patches_list'] = num_patches_list
|
| 184 |
+
|
| 185 |
+
# 如果有speech数据,添加speech参数
|
| 186 |
+
if speech_input is not None:
|
| 187 |
+
chat_kwargs.update({
|
| 188 |
+
'speech': speech_input, # mel 谱图,用于 Whisper
|
| 189 |
+
'speech_lengths': speech_lengths_input,
|
| 190 |
+
'speech_wav': speech_wavs, # 原始音频波形,用于 BEATs
|
| 191 |
+
})
|
| 192 |
+
|
| 193 |
+
# 执行推理
|
| 194 |
+
# breakpoint()
|
| 195 |
+
response = model.chat(**chat_kwargs)
|
| 196 |
+
|
| 197 |
+
print(f"✅ 推理成功!")
|
| 198 |
+
print(f"🤖 回复: {response}")
|
| 199 |
+
|
| 200 |
+
return True, response
|
| 201 |
+
|
| 202 |
+
except Exception as e:
|
| 203 |
+
print(f"❌ 推理失败: {str(e)}")
|
| 204 |
+
import traceback
|
| 205 |
+
traceback.print_exc()
|
| 206 |
+
return False, str(e)
|
| 207 |
+
|
| 208 |
+
# 测试1: Pure Text (应该正常,使用训练好的InternVL)
|
| 209 |
+
success1, response1 = test_inference(
|
| 210 |
+
test_name="Pure Text",
|
| 211 |
+
question="Hello, who are you? Please introduce yourself briefly.",
|
| 212 |
+
pixel_values_input=None,
|
| 213 |
+
speech_input=None,
|
| 214 |
+
speech_lengths_input=None
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# 测试2: Text & Image - Visual only (应该正常,使用训练好的InternVL)
|
| 218 |
+
# success2, response2 = test_inference(
|
| 219 |
+
# test_name="Text & Image (Visual only)",
|
| 220 |
+
# question="<image>\nPlease describe this image in detail.",
|
| 221 |
+
# pixel_values_input=pixel_values,
|
| 222 |
+
# speech_input=None,
|
| 223 |
+
# speech_lengths_input=None
|
| 224 |
+
# )
|
| 225 |
+
|
| 226 |
+
print("\n" + "="*60)
|
| 227 |
+
print("🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练)")
|
| 228 |
+
print("="*60)
|
| 229 |
+
def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
|
| 230 |
+
if bound:
|
| 231 |
+
start, end = bound[0], bound[1]
|
| 232 |
+
else:
|
| 233 |
+
start, end = -100000, 100000
|
| 234 |
+
start_idx = max(first_idx, round(start * fps))
|
| 235 |
+
end_idx = min(round(end * fps), max_frame)
|
| 236 |
+
seg_size = float(end_idx - start_idx) / num_segments
|
| 237 |
+
frame_indices = np.array([
|
| 238 |
+
int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
|
| 239 |
+
for idx in range(num_segments)
|
| 240 |
+
])
|
| 241 |
+
return frame_indices
|
| 242 |
+
|
| 243 |
+
def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
|
| 244 |
+
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
| 245 |
+
max_frame = len(vr) - 1
|
| 246 |
+
fps = float(vr.get_avg_fps())
|
| 247 |
+
|
| 248 |
+
pixel_values_list, num_patches_list = [], []
|
| 249 |
+
transform = build_transform(input_size=input_size)
|
| 250 |
+
frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
|
| 251 |
+
for frame_index in frame_indices:
|
| 252 |
+
img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
|
| 253 |
+
img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
|
| 254 |
+
pixel_values = [transform(tile) for tile in img]
|
| 255 |
+
pixel_values = torch.stack(pixel_values)
|
| 256 |
+
num_patches_list.append(pixel_values.shape[0])
|
| 257 |
+
pixel_values_list.append(pixel_values)
|
| 258 |
+
pixel_values = torch.cat(pixel_values_list)
|
| 259 |
+
return pixel_values, num_patches_list
|
| 260 |
+
|
| 261 |
+
def load_audio(audio_file_name):
|
| 262 |
+
"""
|
| 263 |
+
加载音频文件,使用Ola风格的mel谱图预处理
|
| 264 |
+
这与原始的Ola load_audio函数保持一致
|
| 265 |
+
"""
|
| 266 |
+
speech_wav, samplerate = librosa.load(audio_file_name, sr=16000)
|
| 267 |
+
if len(speech_wav.shape) > 1:
|
| 268 |
+
speech_wav = speech_wav[:, 0]
|
| 269 |
+
speech_wav = speech_wav.astype(np.float32)
|
| 270 |
+
CHUNK_LIM = 480000
|
| 271 |
+
SAMPLE_RATE = 16000
|
| 272 |
+
speechs = []
|
| 273 |
+
speech_wavs = []
|
| 274 |
+
|
| 275 |
+
if len(speech_wav) <= CHUNK_LIM:
|
| 276 |
+
speech = whisper.pad_or_trim(speech_wav)
|
| 277 |
+
speech_wav_chunk = whisper.pad_or_trim(speech_wav)
|
| 278 |
+
speechs.append(speech)
|
| 279 |
+
speech_wavs.append(torch.from_numpy(speech_wav_chunk).unsqueeze(0))
|
| 280 |
+
else:
|
| 281 |
+
for i in range(0, len(speech_wav), CHUNK_LIM):
|
| 282 |
+
chunk = speech_wav[i : i + CHUNK_LIM]
|
| 283 |
+
if len(chunk) < CHUNK_LIM:
|
| 284 |
+
chunk = whisper.pad_or_trim(chunk)
|
| 285 |
+
speechs.append(chunk)
|
| 286 |
+
speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0))
|
| 287 |
+
|
| 288 |
+
# 生成mel谱图
|
| 289 |
+
mels = []
|
| 290 |
+
for chunk in speechs:
|
| 291 |
+
chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0)
|
| 292 |
+
mels.append(chunk)
|
| 293 |
+
|
| 294 |
+
mels = torch.cat(mels, dim=0)
|
| 295 |
+
speech_wavs = torch.cat(speech_wavs, dim=0)
|
| 296 |
+
if mels.shape[0] > 25:
|
| 297 |
+
mels = mels[:25]
|
| 298 |
+
speech_wavs = speech_wavs[:25]
|
| 299 |
+
|
| 300 |
+
speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0])
|
| 301 |
+
speech_chunks = torch.LongTensor([mels.shape[0]])
|
| 302 |
+
|
| 303 |
+
return mels, speech_length, speech_chunks, speech_wavs
|
| 304 |
+
|
| 305 |
+
def extract_audio(videos_file_path):
|
| 306 |
+
my_clip = mp.VideoFileClip(videos_file_path)
|
| 307 |
+
return my_clip.audio
|
| 308 |
+
|
| 309 |
+
# 加载视频数据用于视频测试
|
| 310 |
+
print("\n📥 加载视频数据...")
|
| 311 |
+
try:
|
| 312 |
+
video_path = f'{resource_path}red-panda.mp4'
|
| 313 |
+
if os.path.exists(video_path):
|
| 314 |
+
video_pixel_values, video_num_patches_list = load_video(video_path, num_segments=8, max_num=1)
|
| 315 |
+
video_pixel_values = video_pixel_values.to(torch.bfloat16).cuda()
|
| 316 |
+
video_loaded = True
|
| 317 |
+
print(f"✅ 视频加载成功:")
|
| 318 |
+
print(f" - 视频帧数: {len(video_num_patches_list)}")
|
| 319 |
+
print(f" - 视频像素值形状: {video_pixel_values.shape}")
|
| 320 |
+
print(f" - 每帧patch数: {video_num_patches_list}")
|
| 321 |
+
else:
|
| 322 |
+
print(f"⚠️ 视频文件不存在: {video_path}")
|
| 323 |
+
video_loaded = False
|
| 324 |
+
video_pixel_values = None
|
| 325 |
+
video_num_patches_list = None
|
| 326 |
+
except Exception as e:
|
| 327 |
+
print(f"❌ 视频加载失败: {e}")
|
| 328 |
+
video_loaded = False
|
| 329 |
+
video_pixel_values = None
|
| 330 |
+
video_num_patches_list = None
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
audio_path = f'/data1/cxy/dataset/english.mp3'
|
| 335 |
+
|
| 336 |
+
# 加载音频数据用于后续测试
|
| 337 |
+
print("\n📥 加载音频数据...")
|
| 338 |
+
try:
|
| 339 |
+
# 加载音频文件 - 使用Ola风格的mel谱图预处理
|
| 340 |
+
mels, speech_lengths, speech_chunks, speech_wavs = load_audio(audio_path)
|
| 341 |
+
print(f"✅ 音频加载成功:")
|
| 342 |
+
print(f" - mel谱图形状: {mels.shape}")
|
| 343 |
+
print(f" - 音频长度: {speech_lengths}")
|
| 344 |
+
print(f" - 音频块数: {speech_chunks}")
|
| 345 |
+
print(f" - 原始音频波形形状: {speech_wavs.shape}")
|
| 346 |
+
|
| 347 |
+
# 将音频数据转换为适当的格式并移到GPU
|
| 348 |
+
mels = mels.to(torch.bfloat16).cuda()
|
| 349 |
+
speech_lengths = speech_lengths.cuda()
|
| 350 |
+
speech_chunks = speech_chunks.cuda()
|
| 351 |
+
speech_wavs = speech_wavs.cuda()
|
| 352 |
+
|
| 353 |
+
audio_loaded = True
|
| 354 |
+
|
| 355 |
+
except Exception as e:
|
| 356 |
+
print(f"❌ 音频加载失败: {e}")
|
| 357 |
+
audio_loaded = False
|
| 358 |
+
mels = None
|
| 359 |
+
speech_lengths = None
|
| 360 |
+
|
| 361 |
+
# 测试3: Audio only (可能乱码,speech部分未训练)
|
| 362 |
+
if audio_loaded:
|
| 363 |
+
success3, response3 = test_inference(
|
| 364 |
+
test_name="Audio only (预期乱码)",
|
| 365 |
+
question="<speech>\nPlease transcribe and summarize what you heard in the audio.",
|
| 366 |
+
pixel_values_input=None,
|
| 367 |
+
speech_input=mels,
|
| 368 |
+
speech_lengths_input=speech_lengths
|
| 369 |
+
)
|
| 370 |
+
else:
|
| 371 |
+
print("⚠️ 跳过Audio only测试 (音频加载失败)")
|
| 372 |
+
success3 = False
|
| 373 |
+
|
| 374 |
+
# # 测试4: Audio + Image (可能乱码,speech部分未训练)
|
| 375 |
+
# if audio_loaded:
|
| 376 |
+
# success4, response4 = test_inference(
|
| 377 |
+
# test_name="Audio + Image (预期乱码)",
|
| 378 |
+
# question="<image>\nUser's question in speech: <speech>\n",
|
| 379 |
+
# pixel_values_input=pixel_values,
|
| 380 |
+
# speech_input=mels,
|
| 381 |
+
# speech_lengths_input=speech_lengths
|
| 382 |
+
# )
|
| 383 |
+
# else:
|
| 384 |
+
# print("⚠️ 跳过Audio + Image测试 (音频加载失败)")
|
| 385 |
+
# success4 = False
|
| 386 |
+
|
| 387 |
+
# 测试5: Video + Text (应该正常,使用训练好的InternVL)
|
| 388 |
+
# if video_loaded:
|
| 389 |
+
# # 构建视频帧前缀
|
| 390 |
+
# video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(video_num_patches_list))])
|
| 391 |
+
# video_question = video_prefix + 'What is the red panda doing in this video? Please describe the actions and movements you observe.'
|
| 392 |
+
|
| 393 |
+
# success5, response5 = test_inference(
|
| 394 |
+
# test_name="Video + Text",
|
| 395 |
+
# question=video_question,
|
| 396 |
+
# pixel_values_input=video_pixel_values,
|
| 397 |
+
# speech_input=None,
|
| 398 |
+
# speech_lengths_input=None,
|
| 399 |
+
# num_patches_list=video_num_patches_list
|
| 400 |
+
# )
|
| 401 |
+
# else:
|
| 402 |
+
# print("⚠️ 跳过Video + Text测试 (视频加载失败)")
|
| 403 |
+
# success5 = False
|
| 404 |
+
|
| 405 |
+
# 测试5: Video + Audio (可能乱码,speech部分未训练)
|
| 406 |
+
# if audio_loaded:
|
| 407 |
+
# success5, response5 = test_inference(
|
| 408 |
+
# test_name="Video + Audio (预期乱码)",
|
| 409 |
+
# question="<speech><image>\nDescribe what you hear and see in this content.",
|
| 410 |
+
# pixel_values_input=pixel_values,
|
| 411 |
+
# speech_input=mels,
|
| 412 |
+
# speech_lengths_input=speech_lengths
|
| 413 |
+
# )
|
| 414 |
+
# else:
|
| 415 |
+
# print("⚠️ 跳过Video + Audio测试 (音频加载失败)")
|
| 416 |
+
# success5 = False
|
| 417 |
+
|
| 418 |
+
# 测试总结
|
| 419 |
+
print("\n" + "="*80)
|
| 420 |
+
print("📊 多模态推理测试总结")
|
| 421 |
+
print("="*80)
|
| 422 |
+
|
| 423 |
+
test_results = [
|
| 424 |
+
("Pure Text", success1, "PASS", "应该正常 (训练好的InternVL)"),
|
| 425 |
+
# ("Text & Image", success2, "PASS", "应该正常 (训练好的InternVL)"),
|
| 426 |
+
# ("Video + Text", success5 if video_loaded else False, "PASS", "应该正常 (训练好的InternVL)"),
|
| 427 |
+
("Audio only", success3 if audio_loaded else False, "GARBLED", "可能乱码 (speech未训练)"),
|
| 428 |
+
# ("Audio + Image", success4 if audio_loaded else False, "GARBLED", "可能乱码 (speech未训练)"),
|
| 429 |
+
]
|
| 430 |
+
|
| 431 |
+
for test_name, success, expected, note in test_results:
|
| 432 |
+
status = "✅ PASS" if success else "❌ FAIL"
|
| 433 |
+
print(f"{status} {test_name:<15} (预期: {expected:<8}) - {note}")
|
| 434 |
+
|
| 435 |
+
passed = sum(1 for _, success, _, _ in test_results if success)
|
| 436 |
+
total = len(test_results)
|
| 437 |
+
print(f"\n📈 测试统计: {passed}/{total} 通过")
|
| 438 |
+
|
| 439 |
+
if passed >= 2: # 至少pure text、text&image、video+text中的2个应该通过
|
| 440 |
+
print("🎉 基础功能正常,Speech集成架构成功!")
|
| 441 |
+
print("💡 Speech相关测试如果输出乱码是正常的,因为speech部分还未训练")
|
| 442 |
+
if passed >= 3:
|
| 443 |
+
print("🌟 所有基础模态测试都通过了!")
|
| 444 |
+
else:
|
| 445 |
+
print("⚠️ 基础功能可能存在问题,需要进一步检查")
|
| 446 |
+
|
| 447 |
+
print("\n=== 多模态推理测试完成 ===")
|
| 448 |
+
|
inference/infer_ola_internvl_audio.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ['LOWRES_RESIZE'] = '384x32'
|
| 4 |
+
os.environ['HIGHRES_BASE'] = '0x32'
|
| 5 |
+
os.environ['VIDEO_RESIZE'] = "0x64"
|
| 6 |
+
os.environ['VIDEO_MAXRES'] = "480"
|
| 7 |
+
os.environ['VIDEO_MINRES'] = "288"
|
| 8 |
+
os.environ['MAXRES'] = '1536'
|
| 9 |
+
os.environ['MINRES'] = '0'
|
| 10 |
+
os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
|
| 11 |
+
os.environ['LOAD_VISION_EARLY'] = '1'
|
| 12 |
+
os.environ['PAD2STRIDE'] = '1'
|
| 13 |
+
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
import math
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torchvision.transforms as T
|
| 21 |
+
from decord import VideoReader, cpu # 暂时注释掉,专注于语音功能测试
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 24 |
+
from transformers import AutoModel, AutoTokenizer
|
| 25 |
+
from contextlib import redirect_stdout
|
| 26 |
+
import io
|
| 27 |
+
import librosa
|
| 28 |
+
import whisper
|
| 29 |
+
import moviepy as mp
|
| 30 |
+
import torch
|
| 31 |
+
from transformers import AutoTokenizer, AutoConfig, AutoModel
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 35 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 36 |
+
import gradio as gr
|
| 37 |
+
import torch
|
| 38 |
+
import re
|
| 39 |
+
from decord import VideoReader, cpu
|
| 40 |
+
from PIL import Image
|
| 41 |
+
import numpy as np
|
| 42 |
+
import transformers
|
| 43 |
+
import moviepy as mp
|
| 44 |
+
from typing import Dict, Optional, Sequence, List
|
| 45 |
+
import librosa
|
| 46 |
+
import whisper
|
| 47 |
+
from ola.conversation import conv_templates, SeparatorStyle
|
| 48 |
+
from ola.model.builder import load_pretrained_model
|
| 49 |
+
from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token
|
| 50 |
+
from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image
|
| 51 |
+
from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX
|
| 52 |
+
import argparse
|
| 53 |
+
|
| 54 |
+
parser = argparse.ArgumentParser()
|
| 55 |
+
parser.add_argument('--model_path', type=str, default='/data1/cxy/plm-v/modeling/internvl3_5-2B')
|
| 56 |
+
parser.add_argument('--text', type=str, default="What does the speech say?")
|
| 57 |
+
parser.add_argument('--audio_path', type=str, default=None)
|
| 58 |
+
parser.add_argument('--image_path', type=str, default=None)
|
| 59 |
+
parser.add_argument('--video_path', type=str, default=None)
|
| 60 |
+
args = parser.parse_args()
|
| 61 |
+
|
| 62 |
+
model_path = args.model_path
|
| 63 |
+
tokenizer, model, image_processor, _ = load_pretrained_model(model_path,'ola_internvl', None)
|
| 64 |
+
model = model.to('cuda').eval()
|
| 65 |
+
model = model.bfloat16()
|
| 66 |
+
|
| 67 |
+
resource_path = "/data1/cxy/plm-v/modeling/example/"
|
| 68 |
+
|
| 69 |
+
generation_config = dict(
|
| 70 |
+
max_new_tokens=256,
|
| 71 |
+
do_sample=False, # Use greedy decoding to avoid sampling issues
|
| 72 |
+
temperature=0.5,
|
| 73 |
+
top_p=0.8,
|
| 74 |
+
top_k=10,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# 多模态推理测试
|
| 78 |
+
print("\n" + "="*80)
|
| 79 |
+
print("🧪 开始多模态推理测试")
|
| 80 |
+
print("="*80)
|
| 81 |
+
|
| 82 |
+
def test_inference(test_name, question, pixel_values_input=None, speech_input=None, speech_lengths_input=None, speech_wavs_input=None, speech_chunks_input=None, num_patches_list=None):
|
| 83 |
+
"""统一的推理测试函数"""
|
| 84 |
+
print(f"\n{'='*60}")
|
| 85 |
+
print(f"🧪 测试: {test_name}")
|
| 86 |
+
print(f"📝 问题: {question}")
|
| 87 |
+
print(f"{'='*60}")
|
| 88 |
+
|
| 89 |
+
try:
|
| 90 |
+
# 准备参数
|
| 91 |
+
chat_kwargs = {
|
| 92 |
+
'tokenizer': tokenizer,
|
| 93 |
+
'pixel_values': pixel_values_input,
|
| 94 |
+
'question': question,
|
| 95 |
+
'generation_config': generation_config,
|
| 96 |
+
'verbose': True
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
# 如果有视频数据,添加num_patches_list参数
|
| 100 |
+
if num_patches_list is not None:
|
| 101 |
+
chat_kwargs['num_patches_list'] = num_patches_list
|
| 102 |
+
|
| 103 |
+
# 如果有speech数据,添加speech参数
|
| 104 |
+
if speech_input is not None:
|
| 105 |
+
chat_kwargs.update({
|
| 106 |
+
'speech': speech_input, # mel 谱图,用于 Whisper
|
| 107 |
+
'speech_lengths': speech_lengths_input,
|
| 108 |
+
'speech_wav': speech_wavs_input, # 原始音频波形,用于 BEATs
|
| 109 |
+
})
|
| 110 |
+
|
| 111 |
+
# 如果有speech_chunks数据,添加speech_chunks参数
|
| 112 |
+
if speech_chunks_input is not None:
|
| 113 |
+
chat_kwargs['speech_chunks'] = speech_chunks_input
|
| 114 |
+
|
| 115 |
+
# 执行推理
|
| 116 |
+
# breakpoint()
|
| 117 |
+
response = model.chat(**chat_kwargs)
|
| 118 |
+
|
| 119 |
+
print(f"✅ 推理成功!")
|
| 120 |
+
print(f"🤖 回复: {response}")
|
| 121 |
+
|
| 122 |
+
return True, response
|
| 123 |
+
|
| 124 |
+
except Exception as e:
|
| 125 |
+
print(f"❌ 推理失败: {str(e)}")
|
| 126 |
+
import traceback
|
| 127 |
+
traceback.print_exc()
|
| 128 |
+
return False, str(e)
|
| 129 |
+
|
| 130 |
+
# success1, response1 = test_inference(
|
| 131 |
+
# test_name="Pure Text",
|
| 132 |
+
# question="What is China's capital? Please introduce the city in detail.",
|
| 133 |
+
# pixel_values_input=None,
|
| 134 |
+
# speech_input=None,
|
| 135 |
+
# speech_lengths_input=None
|
| 136 |
+
# )
|
| 137 |
+
|
| 138 |
+
print("\n" + "="*60)
|
| 139 |
+
print("🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练)")
|
| 140 |
+
print("="*60)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def load_audio(audio_file_name):
|
| 144 |
+
speech_wav, samplerate = librosa.load(audio_file_name, sr=16000)
|
| 145 |
+
if len(speech_wav.shape) > 1:
|
| 146 |
+
speech_wav = speech_wav[:, 0]
|
| 147 |
+
speech_wav = speech_wav.astype(np.float32)
|
| 148 |
+
CHUNK_LIM = 480000
|
| 149 |
+
SAMPLE_RATE = 16000
|
| 150 |
+
speechs = []
|
| 151 |
+
speech_wavs = []
|
| 152 |
+
|
| 153 |
+
if len(speech_wav) <= CHUNK_LIM:
|
| 154 |
+
speech = whisper.pad_or_trim(speech_wav)
|
| 155 |
+
speech_wav = whisper.pad_or_trim(speech_wav)
|
| 156 |
+
speechs.append(speech)
|
| 157 |
+
speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0))
|
| 158 |
+
else:
|
| 159 |
+
for i in range(0, len(speech_wav), CHUNK_LIM):
|
| 160 |
+
chunk = speech_wav[i : i + CHUNK_LIM]
|
| 161 |
+
if len(chunk) < CHUNK_LIM:
|
| 162 |
+
chunk = whisper.pad_or_trim(chunk)
|
| 163 |
+
speechs.append(chunk)
|
| 164 |
+
speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0))
|
| 165 |
+
mels = []
|
| 166 |
+
for chunk in speechs:
|
| 167 |
+
chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0)
|
| 168 |
+
mels.append(chunk)
|
| 169 |
+
|
| 170 |
+
mels = torch.cat(mels, dim=0)
|
| 171 |
+
speech_wavs = torch.cat(speech_wavs, dim=0)
|
| 172 |
+
if mels.shape[0] > 25:
|
| 173 |
+
mels = mels[:25]
|
| 174 |
+
speech_wavs = speech_wavs[:25]
|
| 175 |
+
|
| 176 |
+
speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0])
|
| 177 |
+
speech_chunks = torch.LongTensor([mels.shape[0]])
|
| 178 |
+
return mels, speech_length, speech_chunks, speech_wavs
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
audio_path = f'/data1/cxy/dataset/english.mp3'
|
| 184 |
+
|
| 185 |
+
# 加载音频数据用于后续测试
|
| 186 |
+
print("\n📥 加载音频数据...")
|
| 187 |
+
try:
|
| 188 |
+
# 加载音频文件 - 使用Ola风格的mel谱图预处理
|
| 189 |
+
speech, speech_lengths, speech_chunks, speech_wavs = load_audio(audio_path)
|
| 190 |
+
print(f"✅ 音频加载成功:")
|
| 191 |
+
print(f" - mel谱图形状: {speech.shape}")
|
| 192 |
+
print(f" - 音频长度: {speech_lengths}")
|
| 193 |
+
print(f" - 音频块数: {speech_chunks}")
|
| 194 |
+
print(f" - 原始音频波形形状: {speech_wavs.shape}")
|
| 195 |
+
|
| 196 |
+
# 将音频数据转换为适当的格式并移到GPU
|
| 197 |
+
speech = speech.to(torch.bfloat16).cuda()
|
| 198 |
+
speech_lengths = speech_lengths.cuda()
|
| 199 |
+
speech_chunks = speech_chunks.cuda()
|
| 200 |
+
speech_wavs = speech_wavs.cuda()
|
| 201 |
+
|
| 202 |
+
audio_loaded = True
|
| 203 |
+
|
| 204 |
+
except Exception as e:
|
| 205 |
+
print(f"❌ 音频加载失败: {e}")
|
| 206 |
+
audio_loaded = False
|
| 207 |
+
mels = None
|
| 208 |
+
speech_lengths = None
|
| 209 |
+
|
| 210 |
+
# 测试3: Audio only (可能乱码,speech部分未训练)
|
| 211 |
+
if audio_loaded:
|
| 212 |
+
success3, response3 = test_inference(
|
| 213 |
+
test_name="Audio only (预期乱码)",
|
| 214 |
+
question="<speech>\nPlease transcribe and summarize what you heard in the audio.",
|
| 215 |
+
pixel_values_input=None,
|
| 216 |
+
speech_input=speech,
|
| 217 |
+
speech_lengths_input=speech_lengths,
|
| 218 |
+
speech_wavs_input=speech_wavs,
|
| 219 |
+
speech_chunks_input=speech_chunks
|
| 220 |
+
)
|
| 221 |
+
else:
|
| 222 |
+
print("⚠️ 跳过Audio only测试 (音频加载失败)")
|
| 223 |
+
success3 = False
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# 测试总结
|
| 227 |
+
print("\n" + "="*80)
|
| 228 |
+
print("📊 多模态推理测试总结")
|
| 229 |
+
print("="*80)
|
| 230 |
+
|
| 231 |
+
test_results = [
|
| 232 |
+
("Audio only", success3 if audio_loaded else False, "GARBLED", "可能乱码 (speech未训练)"),
|
| 233 |
+
]
|
| 234 |
+
|
| 235 |
+
for test_name, success, expected, note in test_results:
|
| 236 |
+
status = "✅ PASS" if success else "❌ FAIL"
|
| 237 |
+
print(f"{status} {test_name:<15} (预期: {expected:<8}) - {note}")
|
| 238 |
+
|
| 239 |
+
passed = sum(1 for _, success, _, _ in test_results if success)
|
| 240 |
+
total = len(test_results)
|
| 241 |
+
print(f"\n📈 测试统计: {passed}/{total} 通过")
|
| 242 |
+
|
| 243 |
+
print("\n=== 多模态推理测试完成 ===")
|
| 244 |
+
|
inference/infer_ola_internvl_audio_ckpt.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ['LOWRES_RESIZE'] = '384x32'
|
| 4 |
+
os.environ['HIGHRES_BASE'] = '0x32'
|
| 5 |
+
os.environ['VIDEO_RESIZE'] = "0x64"
|
| 6 |
+
os.environ['VIDEO_MAXRES'] = "480"
|
| 7 |
+
os.environ['VIDEO_MINRES'] = "288"
|
| 8 |
+
os.environ['MAXRES'] = '1536'
|
| 9 |
+
os.environ['MINRES'] = '0'
|
| 10 |
+
os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
|
| 11 |
+
os.environ['LOAD_VISION_EARLY'] = '1'
|
| 12 |
+
os.environ['PAD2STRIDE'] = '1'
|
| 13 |
+
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
import math
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torchvision.transforms as T
|
| 21 |
+
from decord import VideoReader, cpu # 暂时注释掉,专注于语音功能测试
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 24 |
+
from transformers import AutoModel, AutoTokenizer
|
| 25 |
+
from contextlib import redirect_stdout
|
| 26 |
+
import io
|
| 27 |
+
import librosa
|
| 28 |
+
import whisper
|
| 29 |
+
import moviepy as mp
|
| 30 |
+
import torch
|
| 31 |
+
from transformers import AutoTokenizer, AutoConfig, AutoModel
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 35 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 36 |
+
import gradio as gr
|
| 37 |
+
import torch
|
| 38 |
+
import re
|
| 39 |
+
from decord import VideoReader, cpu
|
| 40 |
+
from PIL import Image
|
| 41 |
+
import numpy as np
|
| 42 |
+
import transformers
|
| 43 |
+
import moviepy as mp
|
| 44 |
+
from typing import Dict, Optional, Sequence, List
|
| 45 |
+
import librosa
|
| 46 |
+
import whisper
|
| 47 |
+
from ola.conversation import conv_templates, SeparatorStyle
|
| 48 |
+
from ola.model.builder import load_pretrained_model
|
| 49 |
+
from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token
|
| 50 |
+
from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image
|
| 51 |
+
from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX
|
| 52 |
+
import argparse
|
| 53 |
+
|
| 54 |
+
parser = argparse.ArgumentParser()
|
| 55 |
+
parser.add_argument('--model_path', type=str, default='/data1/cxy/plm-v/modeling/internvl3_5-2B')
|
| 56 |
+
parser.add_argument('--text', type=str, default="Give the caption of the given audio or speech.")
|
| 57 |
+
parser.add_argument('--audio_path', type=str, default=None)
|
| 58 |
+
parser.add_argument('--image_path', type=str, default=None)
|
| 59 |
+
parser.add_argument('--video_path', type=str, default=None)
|
| 60 |
+
args = parser.parse_args()
|
| 61 |
+
|
| 62 |
+
model_path = args.model_path
|
| 63 |
+
tokenizer, model, image_processor, _ = load_pretrained_model(model_path,'ola_internvl', None)
|
| 64 |
+
model = model.to('cuda').eval()
|
| 65 |
+
model = model.bfloat16()
|
| 66 |
+
|
| 67 |
+
resource_path = "/data1/cxy/plm-v/modeling/example/"
|
| 68 |
+
|
| 69 |
+
generation_config = dict(
|
| 70 |
+
max_new_tokens=256,
|
| 71 |
+
do_sample=False, # Use greedy decoding to avoid sampling issues
|
| 72 |
+
temperature=0.5,
|
| 73 |
+
top_p=0.8,
|
| 74 |
+
top_k=10,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# 多模态推理测试
|
| 78 |
+
print("\n" + "="*80)
|
| 79 |
+
print("🧪 开始多模态推理测试")
|
| 80 |
+
print("="*80)
|
| 81 |
+
|
| 82 |
+
def test_inference(test_name, question, pixel_values_input=None, speech_input=None, speech_lengths_input=None, speech_wavs_input=None, speech_chunks_input=None, num_patches_list=None):
|
| 83 |
+
"""统一的推理测试函数"""
|
| 84 |
+
print(f"\n{'='*60}")
|
| 85 |
+
print(f"🧪 测试: {test_name}")
|
| 86 |
+
print(f"📝 问题: {question}")
|
| 87 |
+
print(f"{'='*60}")
|
| 88 |
+
|
| 89 |
+
try:
|
| 90 |
+
# 准备参数
|
| 91 |
+
chat_kwargs = {
|
| 92 |
+
'tokenizer': tokenizer,
|
| 93 |
+
'pixel_values': pixel_values_input,
|
| 94 |
+
'question': question,
|
| 95 |
+
'generation_config': generation_config,
|
| 96 |
+
'verbose': True
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
# 如果有视频数据,添加num_patches_list参数
|
| 100 |
+
if num_patches_list is not None:
|
| 101 |
+
chat_kwargs['num_patches_list'] = num_patches_list
|
| 102 |
+
|
| 103 |
+
# 如果有speech数据,添加speech参数
|
| 104 |
+
if speech_input is not None:
|
| 105 |
+
chat_kwargs.update({
|
| 106 |
+
'speech': speech_input, # mel 谱图,用于 Whisper
|
| 107 |
+
'speech_lengths': speech_lengths_input,
|
| 108 |
+
'speech_wav': speech_wavs_input, # 原始音频波形,用于 BEATs
|
| 109 |
+
})
|
| 110 |
+
|
| 111 |
+
# 如果有speech_chunks数据,添加speech_chunks参数
|
| 112 |
+
if speech_chunks_input is not None:
|
| 113 |
+
chat_kwargs['speech_chunks'] = speech_chunks_input
|
| 114 |
+
|
| 115 |
+
# 执行推理
|
| 116 |
+
# breakpoint()
|
| 117 |
+
response = model.chat(**chat_kwargs)
|
| 118 |
+
|
| 119 |
+
print(f"✅ 推理成功!")
|
| 120 |
+
print(f"🤖 回复: {response}")
|
| 121 |
+
|
| 122 |
+
return True, response
|
| 123 |
+
|
| 124 |
+
except Exception as e:
|
| 125 |
+
print(f"❌ 推理失败: {str(e)}")
|
| 126 |
+
import traceback
|
| 127 |
+
traceback.print_exc()
|
| 128 |
+
return False, str(e)
|
| 129 |
+
|
| 130 |
+
# success1, response1 = test_inference(
|
| 131 |
+
# test_name="Pure Text",
|
| 132 |
+
# question="What is China's capital? Please introduce the city in detail.",
|
| 133 |
+
# pixel_values_input=None,
|
| 134 |
+
# speech_input=None,
|
| 135 |
+
# speech_lengths_input=None
|
| 136 |
+
# )
|
| 137 |
+
|
| 138 |
+
print("\n" + "="*60)
|
| 139 |
+
print("🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练)")
|
| 140 |
+
print("="*60)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def load_audio(audio_file_name):
|
| 144 |
+
speech_wav, samplerate = librosa.load(audio_file_name, sr=16000)
|
| 145 |
+
if len(speech_wav.shape) > 1:
|
| 146 |
+
speech_wav = speech_wav[:, 0]
|
| 147 |
+
speech_wav = speech_wav.astype(np.float32)
|
| 148 |
+
CHUNK_LIM = 480000
|
| 149 |
+
SAMPLE_RATE = 16000
|
| 150 |
+
speechs = []
|
| 151 |
+
speech_wavs = []
|
| 152 |
+
|
| 153 |
+
if len(speech_wav) <= CHUNK_LIM:
|
| 154 |
+
speech = whisper.pad_or_trim(speech_wav)
|
| 155 |
+
speech_wav = whisper.pad_or_trim(speech_wav)
|
| 156 |
+
speechs.append(speech)
|
| 157 |
+
speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0))
|
| 158 |
+
else:
|
| 159 |
+
for i in range(0, len(speech_wav), CHUNK_LIM):
|
| 160 |
+
chunk = speech_wav[i : i + CHUNK_LIM]
|
| 161 |
+
if len(chunk) < CHUNK_LIM:
|
| 162 |
+
chunk = whisper.pad_or_trim(chunk)
|
| 163 |
+
speechs.append(chunk)
|
| 164 |
+
speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0))
|
| 165 |
+
mels = []
|
| 166 |
+
for chunk in speechs:
|
| 167 |
+
chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0)
|
| 168 |
+
mels.append(chunk)
|
| 169 |
+
|
| 170 |
+
mels = torch.cat(mels, dim=0)
|
| 171 |
+
speech_wavs = torch.cat(speech_wavs, dim=0)
|
| 172 |
+
if mels.shape[0] > 25:
|
| 173 |
+
mels = mels[:25]
|
| 174 |
+
speech_wavs = speech_wavs[:25]
|
| 175 |
+
|
| 176 |
+
speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0])
|
| 177 |
+
speech_chunks = torch.LongTensor([mels.shape[0]])
|
| 178 |
+
return mels, speech_length, speech_chunks, speech_wavs
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# audio_path = f'/data1/cxy/dataset/english.mp3'
|
| 184 |
+
audio_path = "/data1/cxy/plm-v/modeling/data/Clotho/train/Leaves rustling.wav"
|
| 185 |
+
|
| 186 |
+
# 加载音频数据用于后续测试
|
| 187 |
+
print("\n📥 加载音频数据...")
|
| 188 |
+
try:
|
| 189 |
+
# 加载音频文件 - 使用Ola风格的mel谱图预处理
|
| 190 |
+
speech, speech_lengths, speech_chunks, speech_wavs = load_audio(audio_path)
|
| 191 |
+
print(f"✅ 音频加载成功:")
|
| 192 |
+
print(f" - mel谱图形状: {speech.shape}")
|
| 193 |
+
print(f" - 音频长度: {speech_lengths}")
|
| 194 |
+
print(f" - 音频块数: {speech_chunks}")
|
| 195 |
+
print(f" - 原始音频波形形状: {speech_wavs.shape}")
|
| 196 |
+
|
| 197 |
+
# 将音频数据转换为适当的格式并移到GPU
|
| 198 |
+
speech = speech.to(torch.bfloat16).cuda()
|
| 199 |
+
speech_lengths = speech_lengths.cuda()
|
| 200 |
+
speech_chunks = speech_chunks.cuda()
|
| 201 |
+
speech_wavs = speech_wavs.cuda()
|
| 202 |
+
|
| 203 |
+
audio_loaded = True
|
| 204 |
+
|
| 205 |
+
except Exception as e:
|
| 206 |
+
print(f"❌ 音频加载失败: {e}")
|
| 207 |
+
audio_loaded = False
|
| 208 |
+
mels = None
|
| 209 |
+
speech_lengths = None
|
| 210 |
+
|
| 211 |
+
# 测试3: Audio only (可能乱码,speech部分未训练)
|
| 212 |
+
if audio_loaded:
|
| 213 |
+
success3, response3 = test_inference(
|
| 214 |
+
test_name="Audio only (预期乱码)",
|
| 215 |
+
question="<speech>\nGive the caption of the given audio or speech.",
|
| 216 |
+
pixel_values_input=None,
|
| 217 |
+
speech_input=speech,
|
| 218 |
+
speech_lengths_input=speech_lengths,
|
| 219 |
+
speech_wavs_input=speech_wavs,
|
| 220 |
+
speech_chunks_input=speech_chunks
|
| 221 |
+
)
|
| 222 |
+
else:
|
| 223 |
+
print("⚠️ 跳过Audio only测试 (音频加载失败)")
|
| 224 |
+
success3 = False
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# 测试总结
|
| 228 |
+
print("\n" + "="*80)
|
| 229 |
+
print("📊 多模态推理测试总结")
|
| 230 |
+
print("="*80)
|
| 231 |
+
|
| 232 |
+
test_results = [
|
| 233 |
+
("Audio only", success3 if audio_loaded else False, "GARBLED", "可能乱码 (speech未训练)"),
|
| 234 |
+
]
|
| 235 |
+
|
| 236 |
+
for test_name, success, expected, note in test_results:
|
| 237 |
+
status = "✅ PASS" if success else "❌ FAIL"
|
| 238 |
+
print(f"{status} {test_name:<15} (预期: {expected:<8}) - {note}")
|
| 239 |
+
|
| 240 |
+
passed = sum(1 for _, success, _, _ in test_results if success)
|
| 241 |
+
total = len(test_results)
|
| 242 |
+
print(f"\n📈 测试统计: {passed}/{total} 通过")
|
| 243 |
+
|
| 244 |
+
print("\n=== 多模态推理测试完成 ===")
|
| 245 |
+
|
inference/infer_ola_internvl_copy.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ['LOWRES_RESIZE'] = '384x32'
|
| 4 |
+
os.environ['HIGHRES_BASE'] = '0x32'
|
| 5 |
+
os.environ['VIDEO_RESIZE'] = "0x64"
|
| 6 |
+
os.environ['VIDEO_MAXRES'] = "480"
|
| 7 |
+
os.environ['VIDEO_MINRES'] = "288"
|
| 8 |
+
os.environ['MAXRES'] = '1536'
|
| 9 |
+
os.environ['MINRES'] = '0'
|
| 10 |
+
os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
|
| 11 |
+
os.environ['LOAD_VISION_EARLY'] = '1'
|
| 12 |
+
os.environ['PAD2STRIDE'] = '1'
|
| 13 |
+
|
| 14 |
+
import gradio as gr
|
| 15 |
+
import torch
|
| 16 |
+
import re
|
| 17 |
+
from decord import VideoReader, cpu
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import numpy as np
|
| 20 |
+
import transformers
|
| 21 |
+
import moviepy as mp
|
| 22 |
+
from typing import Dict, Optional, Sequence, List
|
| 23 |
+
import librosa
|
| 24 |
+
import whisper
|
| 25 |
+
from ola.conversation import conv_templates, SeparatorStyle
|
| 26 |
+
from ola.model.builder import load_pretrained_model
|
| 27 |
+
from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token
|
| 28 |
+
from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image
|
| 29 |
+
from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX
|
| 30 |
+
import argparse
|
| 31 |
+
|
| 32 |
+
parser = argparse.ArgumentParser()
|
| 33 |
+
parser.add_argument('--model_path', type=str, default='/data1/cxy/plm-v/modeling/internvl3_5-2B')
|
| 34 |
+
parser.add_argument('--text', type=str, default="What does the speech say?")
|
| 35 |
+
parser.add_argument('--audio_path', type=str, default=None)
|
| 36 |
+
parser.add_argument('--image_path', type=str, default=None)
|
| 37 |
+
parser.add_argument('--video_path', type=str, default=None)
|
| 38 |
+
args = parser.parse_args()
|
| 39 |
+
|
| 40 |
+
model_path = args.model_path
|
| 41 |
+
tokenizer, model, image_processor, _ = load_pretrained_model(model_path,'ola_internvl', None)
|
| 42 |
+
breakpoint()
|
| 43 |
+
model = model.to('cuda').eval()
|
| 44 |
+
model = model.bfloat16()
|
| 45 |
+
# breakpoint()
|
| 46 |
+
USE_SPEECH=False
|
| 47 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 48 |
+
|
| 49 |
+
def load_audio(audio_file_name):
|
| 50 |
+
speech_wav, samplerate = librosa.load(audio_file_name, sr=16000)
|
| 51 |
+
if len(speech_wav.shape) > 1:
|
| 52 |
+
speech_wav = speech_wav[:, 0]
|
| 53 |
+
speech_wav = speech_wav.astype(np.float32)
|
| 54 |
+
CHUNK_LIM = 480000
|
| 55 |
+
SAMPLE_RATE = 16000
|
| 56 |
+
speechs = []
|
| 57 |
+
speech_wavs = []
|
| 58 |
+
|
| 59 |
+
if len(speech_wav) <= CHUNK_LIM:
|
| 60 |
+
speech = whisper.pad_or_trim(speech_wav)
|
| 61 |
+
speech_wav = whisper.pad_or_trim(speech_wav)
|
| 62 |
+
speechs.append(speech)
|
| 63 |
+
speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0))
|
| 64 |
+
else:
|
| 65 |
+
for i in range(0, len(speech_wav), CHUNK_LIM):
|
| 66 |
+
chunk = speech_wav[i : i + CHUNK_LIM]
|
| 67 |
+
if len(chunk) < CHUNK_LIM:
|
| 68 |
+
chunk = whisper.pad_or_trim(chunk)
|
| 69 |
+
speechs.append(chunk)
|
| 70 |
+
speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0))
|
| 71 |
+
mels = []
|
| 72 |
+
for chunk in speechs:
|
| 73 |
+
chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0)
|
| 74 |
+
mels.append(chunk)
|
| 75 |
+
|
| 76 |
+
mels = torch.cat(mels, dim=0)
|
| 77 |
+
speech_wavs = torch.cat(speech_wavs, dim=0)
|
| 78 |
+
if mels.shape[0] > 25:
|
| 79 |
+
mels = mels[:25]
|
| 80 |
+
speech_wavs = speech_wavs[:25]
|
| 81 |
+
|
| 82 |
+
speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0])
|
| 83 |
+
speech_chunks = torch.LongTensor([mels.shape[0]])
|
| 84 |
+
return mels, speech_length, speech_chunks, speech_wavs
|
| 85 |
+
|
| 86 |
+
def extract_audio(videos_file_path):
|
| 87 |
+
my_clip = mp.VideoFileClip(videos_file_path)
|
| 88 |
+
return my_clip.audio
|
| 89 |
+
|
| 90 |
+
image_path = args.image_path
|
| 91 |
+
audio_path = args.audio_path
|
| 92 |
+
video_path = args.video_path
|
| 93 |
+
text = args.text
|
| 94 |
+
modality = "text"
|
| 95 |
+
if video_path is not None:
|
| 96 |
+
modality = "video"
|
| 97 |
+
visual = video_path
|
| 98 |
+
assert image_path is None
|
| 99 |
+
|
| 100 |
+
elif image_path is not None:
|
| 101 |
+
visual = image_path
|
| 102 |
+
modality = "image"
|
| 103 |
+
assert video_path is None
|
| 104 |
+
|
| 105 |
+
elif audio_path is not None:
|
| 106 |
+
modality = "text"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# input audio and video, do not parse audio in the video, else parse audio in the video
|
| 110 |
+
if audio_path:
|
| 111 |
+
USE_SPEECH = True
|
| 112 |
+
elif modality == "video":
|
| 113 |
+
USE_SPEECH = True
|
| 114 |
+
else:
|
| 115 |
+
USE_SPEECH = False
|
| 116 |
+
|
| 117 |
+
speechs = []
|
| 118 |
+
speech_lengths = []
|
| 119 |
+
speech_wavs = []
|
| 120 |
+
speech_chunks = []
|
| 121 |
+
if modality == "video":
|
| 122 |
+
vr = VideoReader(visual, ctx=cpu(0))
|
| 123 |
+
total_frame_num = len(vr)
|
| 124 |
+
fps = round(vr.get_avg_fps())
|
| 125 |
+
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, 64, dtype=int)
|
| 126 |
+
frame_idx = uniform_sampled_frames.tolist()
|
| 127 |
+
spare_frames = vr.get_batch(frame_idx).asnumpy()
|
| 128 |
+
video = [Image.fromarray(frame) for frame in spare_frames]
|
| 129 |
+
elif modality == "image":
|
| 130 |
+
image = [Image.open(visual)]
|
| 131 |
+
image_sizes = [image[0].size]
|
| 132 |
+
else:
|
| 133 |
+
images = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)]
|
| 134 |
+
images_highres = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)]
|
| 135 |
+
image_sizes = [(224, 224)]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
if USE_SPEECH and audio_path:
|
| 139 |
+
audio_path = audio_path
|
| 140 |
+
speech, speech_length, speech_chunk, speech_wav = load_audio(audio_path)
|
| 141 |
+
speechs.append(speech.bfloat16().to('cuda'))
|
| 142 |
+
speech_lengths.append(speech_length.to('cuda'))
|
| 143 |
+
speech_chunks.append(speech_chunk.to('cuda'))
|
| 144 |
+
speech_wavs.append(speech_wav.to('cuda'))
|
| 145 |
+
print('load audio')
|
| 146 |
+
elif USE_SPEECH and not audio_path:
|
| 147 |
+
# parse audio in the video
|
| 148 |
+
audio = extract_audio(visual)
|
| 149 |
+
audio.write_audiofile("./video_audio.wav")
|
| 150 |
+
video_audio_path = './video_audio.wav'
|
| 151 |
+
speech, speech_length, speech_chunk, speech_wav = load_audio(video_audio_path)
|
| 152 |
+
speechs.append(speech.bfloat16().to('cuda'))
|
| 153 |
+
speech_lengths.append(speech_length.to('cuda'))
|
| 154 |
+
speech_chunks.append(speech_chunk.to('cuda'))
|
| 155 |
+
speech_wavs.append(speech_wav.to('cuda'))
|
| 156 |
+
else:
|
| 157 |
+
speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')]
|
| 158 |
+
speech_lengths = [torch.LongTensor([3000]).to('cuda')]
|
| 159 |
+
speech_wavs = [torch.zeros([1, 480000]).to('cuda')]
|
| 160 |
+
speech_chunks = [torch.LongTensor([1]).to('cuda')]
|
| 161 |
+
|
| 162 |
+
conv_mode = "qwen_1_5"
|
| 163 |
+
if text:
|
| 164 |
+
qs = text
|
| 165 |
+
else:
|
| 166 |
+
qs = ''
|
| 167 |
+
|
| 168 |
+
if USE_SPEECH and audio_path and image_path: # image + speech instruction
|
| 169 |
+
qs = DEFAULT_IMAGE_TOKEN + "\n" + "User's question in speech: " + DEFAULT_SPEECH_TOKEN + '\n'
|
| 170 |
+
elif USE_SPEECH and video_path: # video + audio
|
| 171 |
+
qs = DEFAULT_SPEECH_TOKEN + DEFAULT_IMAGE_TOKEN + "\n" + qs
|
| 172 |
+
elif USE_SPEECH and audio_path: # audio + text
|
| 173 |
+
qs = DEFAULT_SPEECH_TOKEN + "\n" + qs
|
| 174 |
+
elif image_path or video_path: # image / video
|
| 175 |
+
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
|
| 176 |
+
elif text: # text
|
| 177 |
+
qs = qs
|
| 178 |
+
|
| 179 |
+
conv = conv_templates[conv_mode].copy()
|
| 180 |
+
conv.append_message(conv.roles[0], qs)
|
| 181 |
+
conv.append_message(conv.roles[1], None)
|
| 182 |
+
prompt = conv.get_prompt()
|
| 183 |
+
if USE_SPEECH and audio_path and image_path: # image + speech instruction
|
| 184 |
+
input_ids = tokenizer_speech_question_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
|
| 185 |
+
elif USE_SPEECH and video_path: # video + audio
|
| 186 |
+
input_ids = tokenizer_speech_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
|
| 187 |
+
elif USE_SPEECH and audio_path: # audio + text
|
| 188 |
+
# breakpoint()
|
| 189 |
+
input_ids = tokenizer_speech_token(prompt, tokenizer, SPEECH_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
|
| 190 |
+
else:
|
| 191 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
|
| 192 |
+
|
| 193 |
+
if modality == "video":
|
| 194 |
+
video_processed = []
|
| 195 |
+
for idx, frame in enumerate(video):
|
| 196 |
+
image_processor.do_resize = False
|
| 197 |
+
image_processor.do_center_crop = False
|
| 198 |
+
frame = process_anyres_video(frame, image_processor)
|
| 199 |
+
|
| 200 |
+
if frame_idx is not None and idx in frame_idx:
|
| 201 |
+
video_processed.append(frame.unsqueeze(0))
|
| 202 |
+
elif frame_idx is None:
|
| 203 |
+
video_processed.append(frame.unsqueeze(0))
|
| 204 |
+
|
| 205 |
+
if frame_idx is None:
|
| 206 |
+
frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()
|
| 207 |
+
|
| 208 |
+
video_processed = torch.cat(video_processed, dim=0).bfloat16().to("cuda")
|
| 209 |
+
video_processed = (video_processed, video_processed)
|
| 210 |
+
|
| 211 |
+
video_data = (video_processed, (384, 384), "video")
|
| 212 |
+
elif modality == "image":
|
| 213 |
+
image_processor.do_resize = False
|
| 214 |
+
image_processor.do_center_crop = False
|
| 215 |
+
image_tensor, image_highres_tensor = [], []
|
| 216 |
+
for visual in image:
|
| 217 |
+
image_tensor_, image_highres_tensor_ = process_anyres_highres_image(visual, image_processor)
|
| 218 |
+
image_tensor.append(image_tensor_)
|
| 219 |
+
image_highres_tensor.append(image_highres_tensor_)
|
| 220 |
+
if all(x.shape == image_tensor[0].shape for x in image_tensor):
|
| 221 |
+
image_tensor = torch.stack(image_tensor, dim=0)
|
| 222 |
+
if all(x.shape == image_highres_tensor[0].shape for x in image_highres_tensor):
|
| 223 |
+
image_highres_tensor = torch.stack(image_highres_tensor, dim=0)
|
| 224 |
+
if type(image_tensor) is list:
|
| 225 |
+
image_tensor = [_image.bfloat16().to("cuda") for _image in image_tensor]
|
| 226 |
+
else:
|
| 227 |
+
image_tensor = image_tensor.bfloat16().to("cuda")
|
| 228 |
+
if type(image_highres_tensor) is list:
|
| 229 |
+
image_highres_tensor = [_image.bfloat16().to("cuda") for _image in image_highres_tensor]
|
| 230 |
+
else:
|
| 231 |
+
image_highres_tensor = image_highres_tensor.bfloat16().to("cuda")
|
| 232 |
+
|
| 233 |
+
pad_token_ids = 151643
|
| 234 |
+
|
| 235 |
+
attention_masks = input_ids.ne(pad_token_ids).long().to('cuda')
|
| 236 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 237 |
+
keywords = [stop_str]
|
| 238 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 239 |
+
|
| 240 |
+
gen_kwargs = {}
|
| 241 |
+
|
| 242 |
+
if "max_new_tokens" not in gen_kwargs:
|
| 243 |
+
gen_kwargs["max_new_tokens"] = 1024
|
| 244 |
+
if "temperature" not in gen_kwargs:
|
| 245 |
+
gen_kwargs["temperature"] = 0.2
|
| 246 |
+
if "top_p" not in gen_kwargs:
|
| 247 |
+
gen_kwargs["top_p"] = None
|
| 248 |
+
if "num_beams" not in gen_kwargs:
|
| 249 |
+
gen_kwargs["num_beams"] = 1
|
| 250 |
+
# breakpoint()
|
| 251 |
+
with torch.inference_mode():
|
| 252 |
+
if modality == "video":
|
| 253 |
+
output_ids = model.generate(
|
| 254 |
+
inputs=input_ids,
|
| 255 |
+
images=video_data[0][0],
|
| 256 |
+
images_highres=video_data[0][1],
|
| 257 |
+
modalities=video_data[2],
|
| 258 |
+
speech=speechs,
|
| 259 |
+
speech_lengths=speech_lengths,
|
| 260 |
+
speech_chunks=speech_chunks,
|
| 261 |
+
speech_wav=speech_wavs,
|
| 262 |
+
attention_mask=attention_masks,
|
| 263 |
+
use_cache=True,
|
| 264 |
+
stopping_criteria=[stopping_criteria],
|
| 265 |
+
do_sample=True if gen_kwargs["temperature"] > 0 else False,
|
| 266 |
+
temperature=gen_kwargs["temperature"],
|
| 267 |
+
top_p=gen_kwargs["top_p"],
|
| 268 |
+
num_beams=gen_kwargs["num_beams"],
|
| 269 |
+
max_new_tokens=gen_kwargs["max_new_tokens"],
|
| 270 |
+
)
|
| 271 |
+
elif modality == "image":
|
| 272 |
+
output_ids = model.generate(
|
| 273 |
+
inputs=input_ids,
|
| 274 |
+
images=image_tensor,
|
| 275 |
+
images_highres=image_highres_tensor,
|
| 276 |
+
image_sizes=image_sizes,
|
| 277 |
+
modalities=['image'],
|
| 278 |
+
speech=speechs,
|
| 279 |
+
speech_lengths=speech_lengths,
|
| 280 |
+
speech_chunks=speech_chunks,
|
| 281 |
+
speech_wav=speech_wavs,
|
| 282 |
+
attention_mask=attention_masks,
|
| 283 |
+
use_cache=True,
|
| 284 |
+
stopping_criteria=[stopping_criteria],
|
| 285 |
+
do_sample=True if gen_kwargs["temperature"] > 0 else False,
|
| 286 |
+
temperature=gen_kwargs["temperature"],
|
| 287 |
+
top_p=gen_kwargs["top_p"],
|
| 288 |
+
num_beams=gen_kwargs["num_beams"],
|
| 289 |
+
max_new_tokens=gen_kwargs["max_new_tokens"],
|
| 290 |
+
)
|
| 291 |
+
elif modality == "text":
|
| 292 |
+
output_ids = model.generate(
|
| 293 |
+
input_ids,
|
| 294 |
+
images=images,
|
| 295 |
+
images_highres=images_highres,
|
| 296 |
+
image_sizes=image_sizes,
|
| 297 |
+
modalities=['text'],
|
| 298 |
+
speech=speechs,
|
| 299 |
+
speech_lengths=speech_lengths,
|
| 300 |
+
speech_chunks=speech_chunks,
|
| 301 |
+
speech_wav=speech_wavs,
|
| 302 |
+
attention_mask=attention_masks,
|
| 303 |
+
use_cache=True,
|
| 304 |
+
stopping_criteria=[stopping_criteria],
|
| 305 |
+
do_sample=True if gen_kwargs["temperature"] > 0 else False,
|
| 306 |
+
temperature=gen_kwargs["temperature"],
|
| 307 |
+
top_p=gen_kwargs["top_p"],
|
| 308 |
+
num_beams=gen_kwargs["num_beams"],
|
| 309 |
+
max_new_tokens=gen_kwargs["max_new_tokens"],
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
| 313 |
+
outputs = outputs.strip()
|
| 314 |
+
if outputs.endswith(stop_str):
|
| 315 |
+
outputs = outputs[:-len(stop_str)]
|
| 316 |
+
outputs = outputs.strip()
|
| 317 |
+
|
| 318 |
+
print(outputs)
|
inference/infer_ola_internvl_text_visual.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ['LOWRES_RESIZE'] = '384x32'
|
| 4 |
+
os.environ['HIGHRES_BASE'] = '0x32'
|
| 5 |
+
os.environ['VIDEO_RESIZE'] = "0x64"
|
| 6 |
+
os.environ['VIDEO_MAXRES'] = "480"
|
| 7 |
+
os.environ['VIDEO_MINRES'] = "288"
|
| 8 |
+
os.environ['MAXRES'] = '1536'
|
| 9 |
+
os.environ['MINRES'] = '0'
|
| 10 |
+
os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
|
| 11 |
+
os.environ['LOAD_VISION_EARLY'] = '1'
|
| 12 |
+
os.environ['PAD2STRIDE'] = '1'
|
| 13 |
+
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
import math
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torchvision.transforms as T
|
| 21 |
+
from decord import VideoReader, cpu # 暂时注释掉,专注于语音功能测试
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 24 |
+
from transformers import AutoModel, AutoTokenizer
|
| 25 |
+
from contextlib import redirect_stdout
|
| 26 |
+
import io
|
| 27 |
+
import librosa
|
| 28 |
+
import whisper
|
| 29 |
+
import moviepy as mp
|
| 30 |
+
import torch
|
| 31 |
+
from transformers import AutoTokenizer, AutoConfig, AutoModel
|
| 32 |
+
|
| 33 |
+
# pure text
|
| 34 |
+
# image + text
|
| 35 |
+
# video + text
|
| 36 |
+
# audio + text
|
| 37 |
+
# video + audio + text
|
| 38 |
+
|
| 39 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 40 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 41 |
+
import gradio as gr
|
| 42 |
+
import torch
|
| 43 |
+
import re
|
| 44 |
+
from decord import VideoReader, cpu
|
| 45 |
+
from PIL import Image
|
| 46 |
+
import numpy as np
|
| 47 |
+
import transformers
|
| 48 |
+
import moviepy as mp
|
| 49 |
+
from typing import Dict, Optional, Sequence, List
|
| 50 |
+
import librosa
|
| 51 |
+
import whisper
|
| 52 |
+
from ola.conversation import conv_templates, SeparatorStyle
|
| 53 |
+
from ola.model.builder import load_pretrained_model
|
| 54 |
+
from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token
|
| 55 |
+
from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image
|
| 56 |
+
from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX
|
| 57 |
+
import argparse
|
| 58 |
+
|
| 59 |
+
parser = argparse.ArgumentParser()
|
| 60 |
+
parser.add_argument('--model_path', type=str, default='/data1/cxy/plm-v/modeling/internvl3_5-2B')
|
| 61 |
+
parser.add_argument('--text', type=str, default="What does the speech say?")
|
| 62 |
+
parser.add_argument('--audio_path', type=str, default=None)
|
| 63 |
+
parser.add_argument('--image_path', type=str, default=None)
|
| 64 |
+
parser.add_argument('--video_path', type=str, default=None)
|
| 65 |
+
args = parser.parse_args()
|
| 66 |
+
|
| 67 |
+
model_path = args.model_path
|
| 68 |
+
tokenizer, model, image_processor, _ = load_pretrained_model(model_path,'ola_internvl', None)
|
| 69 |
+
model = model.to('cuda').eval()
|
| 70 |
+
model = model.bfloat16()
|
| 71 |
+
|
| 72 |
+
resource_path = "/data1/cxy/plm-v/modeling/example/"
|
| 73 |
+
# set the max number of tiles in `max_num`
|
| 74 |
+
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
| 75 |
+
best_ratio_diff = float('inf')
|
| 76 |
+
best_ratio = (1, 1)
|
| 77 |
+
area = width * height
|
| 78 |
+
for ratio in target_ratios:
|
| 79 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
| 80 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
| 81 |
+
if ratio_diff < best_ratio_diff:
|
| 82 |
+
best_ratio_diff = ratio_diff
|
| 83 |
+
best_ratio = ratio
|
| 84 |
+
elif ratio_diff == best_ratio_diff:
|
| 85 |
+
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
| 86 |
+
best_ratio = ratio
|
| 87 |
+
return best_ratio
|
| 88 |
+
def build_transform(input_size):
|
| 89 |
+
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
| 90 |
+
transform = T.Compose([
|
| 91 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
| 92 |
+
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
| 93 |
+
T.ToTensor(),
|
| 94 |
+
T.Normalize(mean=MEAN, std=STD)
|
| 95 |
+
])
|
| 96 |
+
return transform
|
| 97 |
+
|
| 98 |
+
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
|
| 99 |
+
orig_width, orig_height = image.size
|
| 100 |
+
aspect_ratio = orig_width / orig_height
|
| 101 |
+
|
| 102 |
+
# calculate the existing image aspect ratio
|
| 103 |
+
target_ratios = set(
|
| 104 |
+
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
| 105 |
+
i * j <= max_num and i * j >= min_num)
|
| 106 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
| 107 |
+
|
| 108 |
+
# find the closest aspect ratio to the target
|
| 109 |
+
target_aspect_ratio = find_closest_aspect_ratio(
|
| 110 |
+
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
| 111 |
+
|
| 112 |
+
# calculate the target width and height
|
| 113 |
+
target_width = image_size * target_aspect_ratio[0]
|
| 114 |
+
target_height = image_size * target_aspect_ratio[1]
|
| 115 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
| 116 |
+
|
| 117 |
+
# resize the image
|
| 118 |
+
resized_img = image.resize((target_width, target_height))
|
| 119 |
+
processed_images = []
|
| 120 |
+
for i in range(blocks):
|
| 121 |
+
box = (
|
| 122 |
+
(i % (target_width // image_size)) * image_size,
|
| 123 |
+
(i // (target_width // image_size)) * image_size,
|
| 124 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
| 125 |
+
((i // (target_width // image_size)) + 1) * image_size
|
| 126 |
+
)
|
| 127 |
+
# split the image
|
| 128 |
+
split_img = resized_img.crop(box)
|
| 129 |
+
processed_images.append(split_img)
|
| 130 |
+
assert len(processed_images) == blocks
|
| 131 |
+
if use_thumbnail and len(processed_images) != 1:
|
| 132 |
+
thumbnail_img = image.resize((image_size, image_size))
|
| 133 |
+
processed_images.append(thumbnail_img)
|
| 134 |
+
return processed_images
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def load_image(image_file, input_size=448, max_num=12):
|
| 138 |
+
image = Image.open(image_file).convert('RGB')
|
| 139 |
+
transform = build_transform(input_size=input_size)
|
| 140 |
+
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
|
| 141 |
+
pixel_values = [transform(image) for image in images]
|
| 142 |
+
pixel_values = torch.stack(pixel_values)
|
| 143 |
+
return pixel_values
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
pixel_values = load_image(f'{resource_path}image1.jpg', max_num=12).to(torch.bfloat16).cuda()
|
| 148 |
+
# breakpoint()
|
| 149 |
+
generation_config = dict(max_new_tokens=1024, do_sample=True)
|
| 150 |
+
generation_config = dict(
|
| 151 |
+
max_new_tokens=256,
|
| 152 |
+
do_sample=False, # Use greedy decoding to avoid sampling issues
|
| 153 |
+
temperature=0.5,
|
| 154 |
+
top_p=0.8,
|
| 155 |
+
top_k=10,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# 多模态推理测试
|
| 161 |
+
print("\n" + "="*80)
|
| 162 |
+
print("🧪 开始多模态推理测试")
|
| 163 |
+
print("="*80)
|
| 164 |
+
|
| 165 |
+
def test_inference(test_name, question, pixel_values_input=None, speech_input=None, speech_lengths_input=None, num_patches_list=None):
|
| 166 |
+
"""统一的推理测试函数"""
|
| 167 |
+
print(f"\n{'='*60}")
|
| 168 |
+
print(f"🧪 测试: {test_name}")
|
| 169 |
+
print(f"📝 问题: {question}")
|
| 170 |
+
print(f"{'='*60}")
|
| 171 |
+
|
| 172 |
+
try:
|
| 173 |
+
# 准备参数
|
| 174 |
+
chat_kwargs = {
|
| 175 |
+
'tokenizer': tokenizer,
|
| 176 |
+
'pixel_values': pixel_values_input,
|
| 177 |
+
'question': question,
|
| 178 |
+
'generation_config': generation_config,
|
| 179 |
+
'verbose': True
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
# 如果有视频数据,添加num_patches_list参数
|
| 183 |
+
if num_patches_list is not None:
|
| 184 |
+
chat_kwargs['num_patches_list'] = num_patches_list
|
| 185 |
+
|
| 186 |
+
# 如果有speech数据,添加speech参数
|
| 187 |
+
if speech_input is not None:
|
| 188 |
+
chat_kwargs.update({
|
| 189 |
+
'speech': speech_input, # mel 谱图,用于 Whisper
|
| 190 |
+
'speech_lengths': speech_lengths_input,
|
| 191 |
+
'speech_wav': speech_wavs, # 原始音频波形,用于 BEATs
|
| 192 |
+
})
|
| 193 |
+
|
| 194 |
+
# 执行推理
|
| 195 |
+
# breakpoint()
|
| 196 |
+
response = model.chat(**chat_kwargs)
|
| 197 |
+
|
| 198 |
+
print(f"✅ 推理成功!")
|
| 199 |
+
print(f"🤖 回复: {response}")
|
| 200 |
+
|
| 201 |
+
return True, response
|
| 202 |
+
|
| 203 |
+
except Exception as e:
|
| 204 |
+
print(f"❌ 推理失败: {str(e)}")
|
| 205 |
+
import traceback
|
| 206 |
+
traceback.print_exc()
|
| 207 |
+
return False, str(e)
|
| 208 |
+
|
| 209 |
+
# 测试1: Pure Text (应该正常,使用训练好的InternVL)
|
| 210 |
+
success1, response1 = test_inference(
|
| 211 |
+
test_name="Pure Text",
|
| 212 |
+
question="Hello, who are you? Please introduce yourself briefly.",
|
| 213 |
+
pixel_values_input=None,
|
| 214 |
+
speech_input=None,
|
| 215 |
+
speech_lengths_input=None
|
| 216 |
+
)
|
| 217 |
+
# breakpoint()
|
| 218 |
+
# 测试2: Text & Image - Visual only (应该正常,使用训练好的InternVL)
|
| 219 |
+
success2, response2 = test_inference(
|
| 220 |
+
test_name="Text & Image (Visual only)",
|
| 221 |
+
question="<image>\nPlease describe this image in detail.",
|
| 222 |
+
pixel_values_input=pixel_values,
|
| 223 |
+
speech_input=None,
|
| 224 |
+
speech_lengths_input=None
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
print("\n" + "="*60)
|
| 228 |
+
print("🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练)")
|
| 229 |
+
print("="*60)
|
| 230 |
+
def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
|
| 231 |
+
if bound:
|
| 232 |
+
start, end = bound[0], bound[1]
|
| 233 |
+
else:
|
| 234 |
+
start, end = -100000, 100000
|
| 235 |
+
start_idx = max(first_idx, round(start * fps))
|
| 236 |
+
end_idx = min(round(end * fps), max_frame)
|
| 237 |
+
seg_size = float(end_idx - start_idx) / num_segments
|
| 238 |
+
frame_indices = np.array([
|
| 239 |
+
int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
|
| 240 |
+
for idx in range(num_segments)
|
| 241 |
+
])
|
| 242 |
+
return frame_indices
|
| 243 |
+
|
| 244 |
+
def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
|
| 245 |
+
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
| 246 |
+
max_frame = len(vr) - 1
|
| 247 |
+
fps = float(vr.get_avg_fps())
|
| 248 |
+
|
| 249 |
+
pixel_values_list, num_patches_list = [], []
|
| 250 |
+
transform = build_transform(input_size=input_size)
|
| 251 |
+
frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
|
| 252 |
+
for frame_index in frame_indices:
|
| 253 |
+
img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
|
| 254 |
+
img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
|
| 255 |
+
pixel_values = [transform(tile) for tile in img]
|
| 256 |
+
pixel_values = torch.stack(pixel_values)
|
| 257 |
+
num_patches_list.append(pixel_values.shape[0])
|
| 258 |
+
pixel_values_list.append(pixel_values)
|
| 259 |
+
pixel_values = torch.cat(pixel_values_list)
|
| 260 |
+
return pixel_values, num_patches_list
|
| 261 |
+
|
| 262 |
+
def load_audio(audio_file_name):
|
| 263 |
+
"""
|
| 264 |
+
加载音频文件,使用Ola风格的mel谱图预处理
|
| 265 |
+
这与原始的Ola load_audio函数保持一致
|
| 266 |
+
"""
|
| 267 |
+
speech_wav, samplerate = librosa.load(audio_file_name, sr=16000)
|
| 268 |
+
if len(speech_wav.shape) > 1:
|
| 269 |
+
speech_wav = speech_wav[:, 0]
|
| 270 |
+
speech_wav = speech_wav.astype(np.float32)
|
| 271 |
+
CHUNK_LIM = 480000
|
| 272 |
+
SAMPLE_RATE = 16000
|
| 273 |
+
speechs = []
|
| 274 |
+
speech_wavs = []
|
| 275 |
+
|
| 276 |
+
if len(speech_wav) <= CHUNK_LIM:
|
| 277 |
+
speech = whisper.pad_or_trim(speech_wav)
|
| 278 |
+
speech_wav_chunk = whisper.pad_or_trim(speech_wav)
|
| 279 |
+
speechs.append(speech)
|
| 280 |
+
speech_wavs.append(torch.from_numpy(speech_wav_chunk).unsqueeze(0))
|
| 281 |
+
else:
|
| 282 |
+
for i in range(0, len(speech_wav), CHUNK_LIM):
|
| 283 |
+
chunk = speech_wav[i : i + CHUNK_LIM]
|
| 284 |
+
if len(chunk) < CHUNK_LIM:
|
| 285 |
+
chunk = whisper.pad_or_trim(chunk)
|
| 286 |
+
speechs.append(chunk)
|
| 287 |
+
speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0))
|
| 288 |
+
|
| 289 |
+
# 生成mel谱图
|
| 290 |
+
mels = []
|
| 291 |
+
for chunk in speechs:
|
| 292 |
+
chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0)
|
| 293 |
+
mels.append(chunk)
|
| 294 |
+
|
| 295 |
+
mels = torch.cat(mels, dim=0)
|
| 296 |
+
speech_wavs = torch.cat(speech_wavs, dim=0)
|
| 297 |
+
if mels.shape[0] > 25:
|
| 298 |
+
mels = mels[:25]
|
| 299 |
+
speech_wavs = speech_wavs[:25]
|
| 300 |
+
|
| 301 |
+
speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0])
|
| 302 |
+
speech_chunks = torch.LongTensor([mels.shape[0]])
|
| 303 |
+
|
| 304 |
+
return mels, speech_length, speech_chunks, speech_wavs
|
| 305 |
+
|
| 306 |
+
def extract_audio(videos_file_path):
|
| 307 |
+
my_clip = mp.VideoFileClip(videos_file_path)
|
| 308 |
+
return my_clip.audio
|
| 309 |
+
|
| 310 |
+
# 加载视频数据用于视频测试
|
| 311 |
+
print("\n📥 加载视频数据...")
|
| 312 |
+
try:
|
| 313 |
+
video_path = f'{resource_path}red-panda.mp4'
|
| 314 |
+
if os.path.exists(video_path):
|
| 315 |
+
video_pixel_values, video_num_patches_list = load_video(video_path, num_segments=8, max_num=1)
|
| 316 |
+
video_pixel_values = video_pixel_values.to(torch.bfloat16).cuda()
|
| 317 |
+
video_loaded = True
|
| 318 |
+
print(f"✅ 视频加载成功:")
|
| 319 |
+
print(f" - 视频帧数: {len(video_num_patches_list)}")
|
| 320 |
+
print(f" - 视频像素值形状: {video_pixel_values.shape}")
|
| 321 |
+
print(f" - 每帧patch数: {video_num_patches_list}")
|
| 322 |
+
else:
|
| 323 |
+
print(f"⚠️ 视频文件不存在: {video_path}")
|
| 324 |
+
video_loaded = False
|
| 325 |
+
video_pixel_values = None
|
| 326 |
+
video_num_patches_list = None
|
| 327 |
+
except Exception as e:
|
| 328 |
+
print(f"❌ 视频加载失败: {e}")
|
| 329 |
+
video_loaded = False
|
| 330 |
+
video_pixel_values = None
|
| 331 |
+
video_num_patches_list = None
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
audio_path = f'/data1/cxy/dataset/english.mp3'
|
| 336 |
+
|
| 337 |
+
# 加载音频数据用于后续测试
|
| 338 |
+
print("\n📥 加载音频数据...")
|
| 339 |
+
try:
|
| 340 |
+
# 加载音频文件 - 使用Ola风格的mel谱图预处理
|
| 341 |
+
mels, speech_lengths, speech_chunks, speech_wavs = load_audio(audio_path)
|
| 342 |
+
print(f"✅ 音频加载成功:")
|
| 343 |
+
print(f" - mel谱图形状: {mels.shape}")
|
| 344 |
+
print(f" - 音频长度: {speech_lengths}")
|
| 345 |
+
print(f" - 音频块数: {speech_chunks}")
|
| 346 |
+
print(f" - 原始音频波形形状: {speech_wavs.shape}")
|
| 347 |
+
|
| 348 |
+
# 将音频数据转换为适当的格式并移到GPU
|
| 349 |
+
mels = mels.to(torch.bfloat16).cuda()
|
| 350 |
+
speech_lengths = speech_lengths.cuda()
|
| 351 |
+
speech_chunks = speech_chunks.cuda()
|
| 352 |
+
speech_wavs = speech_wavs.cuda()
|
| 353 |
+
|
| 354 |
+
audio_loaded = True
|
| 355 |
+
|
| 356 |
+
except Exception as e:
|
| 357 |
+
print(f"❌ 音频加载失败: {e}")
|
| 358 |
+
audio_loaded = False
|
| 359 |
+
mels = None
|
| 360 |
+
speech_lengths = None
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
# 测试5: Video + Text (应该正常,使用训练好的InternVL)
|
| 364 |
+
if video_loaded:
|
| 365 |
+
# 构建视频帧前缀
|
| 366 |
+
video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(video_num_patches_list))])
|
| 367 |
+
video_question = video_prefix + 'What is the red panda doing in this video? Please describe the actions and movements you observe.'
|
| 368 |
+
|
| 369 |
+
success5, response5 = test_inference(
|
| 370 |
+
test_name="Video + Text",
|
| 371 |
+
question=video_question,
|
| 372 |
+
pixel_values_input=video_pixel_values,
|
| 373 |
+
speech_input=None,
|
| 374 |
+
speech_lengths_input=None,
|
| 375 |
+
num_patches_list=video_num_patches_list
|
| 376 |
+
)
|
| 377 |
+
else:
|
| 378 |
+
print("⚠️ 跳过Video + Text测试 (视频加载失败)")
|
| 379 |
+
success5 = False
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
print("\n" + "="*80)
|
| 383 |
+
print("📊 多模态推理测试总结")
|
| 384 |
+
print("="*80)
|
| 385 |
+
|
| 386 |
+
test_results = [
|
| 387 |
+
("Pure Text", success1, "PASS", "应该正常 (训练好的InternVL)"),
|
| 388 |
+
("Text & Image", success2, "PASS", "应该正常 (训练好的InternVL)"),
|
| 389 |
+
("Video + Text", success5 if video_loaded else False, "PASS", "应该正常 (训练好的InternVL)"),
|
| 390 |
+
]
|
| 391 |
+
|
| 392 |
+
for test_name, success, expected, note in test_results:
|
| 393 |
+
status = "✅ PASS" if success else "❌ FAIL"
|
| 394 |
+
print(f"{status} {test_name:<15} (预期: {expected:<8}) - {note}")
|
| 395 |
+
|
| 396 |
+
passed = sum(1 for _, success, _, _ in test_results if success)
|
| 397 |
+
total = len(test_results)
|
| 398 |
+
print(f"\n📈 测试统计: {passed}/{total} 通过")
|
| 399 |
+
|
| 400 |
+
if passed >= 2: # 至少pure text、text&image、video+text中的2个应该通过
|
| 401 |
+
print("🎉 基础功能正常,Speech集成架构成功!")
|
| 402 |
+
print("💡 Speech相关测试如果输出乱码是正常的,因为speech部分还未训练")
|
| 403 |
+
if passed >= 3:
|
| 404 |
+
print("🌟 所有基础模态测试都通过了!")
|
| 405 |
+
else:
|
| 406 |
+
print("⚠️ 基础功能可能存在问题,需要进一步检查")
|
| 407 |
+
|
| 408 |
+
print("\n=== 多模态推理测试完成 ===")
|
| 409 |
+
|
inference/log.txt
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[2025-09-15 09:15:42,098] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)
|
| 2 |
+
LOAD_VISION_EARLY is set
|
| 3 |
+
FORCE_NO_DOWNSAMPLE is set
|
| 4 |
+
VIDEO_RESIZE is set as 0x64, 0, 64
|
| 5 |
+
HIGHRES_BASE is set as 0x32, 0, 32
|
| 6 |
+
MAXRES is set as 1536
|
| 7 |
+
MINRES is set as 0
|
| 8 |
+
VIDEO_MAXRES is set as 480
|
| 9 |
+
VIDEO_MINRES is set as 288
|
| 10 |
+
PAD2STRIDE is set
|
| 11 |
+
LOWRES_RESIZE is set as 384x32
|
| 12 |
+
Loading OlaQwen3ForCausalLM model...
|
| 13 |
+
Loading BEATs Model
|
| 14 |
+
Missing keys: ['model.speech_encoder.whisper_model.positional_embedding', 'model.speech_encoder.whisper_model.conv1.weight', 'model.speech_encoder.whisper_model.conv1.bias', 'model.speech_encoder.whisper_model.conv2.weight', 'model.speech_encoder.whisper_model.conv2.bias', 'model.speech_encoder.whisper_model.blocks.0.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.0.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.0.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.0.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.0.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.0.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.0.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.0.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.0.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.0.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.0.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.1.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.1.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.1.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.1.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.1.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.1.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.1.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.1.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.1.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.1.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.1.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.2.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.2.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.2.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.2.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.2.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.2.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.2.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.2.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.2.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.2.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.2.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.3.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.3.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.3.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.3.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.3.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.3.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.3.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.3.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.3.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.3.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.3.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.4.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.4.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.4.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.4.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.4.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.4.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.4.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.4.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.4.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.4.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.4.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.5.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.5.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.5.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.5.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.5.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.5.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.5.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.5.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.5.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.5.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.5.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.6.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.6.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.6.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.6.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.6.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.6.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.6.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.6.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.6.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.6.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.6.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.7.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.7.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.7.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.7.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.7.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.7.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.7.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.7.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.7.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.7.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.7.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.8.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.8.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.8.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.8.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.8.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.8.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.8.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.8.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.8.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.8.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.8.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.9.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.9.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.9.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.9.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.9.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.9.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.9.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.9.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.9.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.9.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.9.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.10.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.10.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.10.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.10.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.10.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.10.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.10.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.10.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.10.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.10.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.10.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.11.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.11.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.11.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.11.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.11.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.11.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.11.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.11.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.11.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.11.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.11.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.12.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.12.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.12.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.12.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.12.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.12.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.12.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.12.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.12.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.12.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.12.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.13.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.13.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.13.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.13.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.13.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.13.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.13.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.13.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.13.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.13.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.13.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.14.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.14.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.14.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.14.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.14.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.14.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.14.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.14.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.14.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.14.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.14.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.15.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.15.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.15.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.15.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.15.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.15.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.15.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.15.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.15.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.15.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.15.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.16.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.16.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.16.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.16.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.16.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.16.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.16.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.16.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.16.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.16.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.16.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.17.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.17.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.17.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.17.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.17.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.17.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.17.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.17.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.17.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.17.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.17.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.18.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.18.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.18.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.18.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.18.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.18.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.18.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.18.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.18.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.18.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.18.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.19.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.19.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.19.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.19.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.19.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.19.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.19.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.19.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.19.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.19.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.19.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.20.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.20.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.20.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.20.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.20.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.20.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.20.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.20.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.20.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.20.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.20.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.21.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.21.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.21.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.21.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.21.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.21.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.21.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.21.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.21.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.21.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.21.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.22.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.22.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.22.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.22.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.22.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.22.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.22.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.22.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.22.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.22.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.22.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.23.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.23.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.23.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.23.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.23.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.23.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.23.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.23.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.23.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.23.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.23.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.24.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.24.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.24.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.24.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.24.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.24.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.24.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.24.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.24.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.24.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.24.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.25.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.25.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.25.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.25.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.25.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.25.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.25.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.25.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.25.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.25.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.25.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.26.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.26.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.26.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.26.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.26.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.26.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.26.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.26.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.26.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.26.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.26.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.27.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.27.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.27.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.27.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.27.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.27.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.27.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.27.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.27.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.27.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.27.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.28.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.28.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.28.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.28.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.28.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.28.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.28.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.28.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.28.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.28.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.28.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.29.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.29.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.29.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.29.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.29.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.29.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.29.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.29.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.29.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.29.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.29.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.30.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.30.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.30.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.30.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.30.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.30.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.30.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.30.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.30.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.30.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.30.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.31.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.31.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.31.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.31.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.31.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.31.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.31.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.31.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.31.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.31.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.31.mlp_ln.bias', 'model.speech_encoder.whisper_model.ln_post.weight', 'model.speech_encoder.whisper_model.ln_post.bias', 'model.speech_encoder.beats_model.post_extract_proj.weight', 'model.speech_encoder.beats_model.post_extract_proj.bias', 'model.speech_encoder.beats_model.patch_embedding.weight', 'model.speech_encoder.beats_model.encoder.pos_conv.0.bias', 'model.speech_encoder.beats_model.encoder.pos_conv.0.weight_g', 'model.speech_encoder.beats_model.encoder.pos_conv.0.weight_v', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.0.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.0.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.0.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.0.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.0.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.0.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.1.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.1.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.1.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.1.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.1.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.1.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.2.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.2.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.2.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.2.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.2.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.2.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.3.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.3.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.3.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.3.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.3.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.3.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.4.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.4.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.4.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.4.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.4.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.4.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.5.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.5.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.5.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.5.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.5.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.5.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.6.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.6.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.6.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.6.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.6.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.6.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.7.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.7.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.7.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.7.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.7.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.7.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.8.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.8.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.8.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.8.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.8.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.8.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.9.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.9.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.9.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.9.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.9.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.9.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.10.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.10.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.10.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.10.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.10.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.10.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.11.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.11.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.11.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.11.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.11.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.11.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layer_norm.bias', 'model.speech_encoder.beats_model.layer_norm.weight', 'model.speech_encoder.beats_model.layer_norm.bias', 'model.speech_encoder.beats_model.predictor.weight', 'model.speech_encoder.beats_model.predictor.bias', 'model.speech_projector.speech_newline', 'model.speech_projector.speech_begin', 'model.speech_projector.speech_end', 'model.speech_projector.linear1.weight', 'model.speech_projector.linear1.bias', 'model.speech_projector.linear2.weight', 'model.speech_projector.linear2.bias']
|
| 15 |
+
Unexpected keys: []
|
| 16 |
+
Loading vision tower...
|
| 17 |
+
Loading vision tower succeeded.
|
| 18 |
+
User: Hello, who are you?
|
| 19 |
+
Assistant: Hello! I'm PLM-V, an AI assistant created to provide information, answer questions, and help with various tasks. How can I assist you today?
|
| 20 |
+
|
| 21 |
+
================================================================================
|
| 22 |
+
🧪 开始多模态推理测试
|
| 23 |
+
================================================================================
|
| 24 |
+
|
| 25 |
+
============================================================
|
| 26 |
+
🧪 测试: Pure Text
|
| 27 |
+
📝 问题: Hello, who are you? Please introduce yourself briefly.
|
| 28 |
+
============================================================
|
| 29 |
+
<|im_start|>system
|
| 30 |
+
You are PLM-V, a helpful assistant.<|im_end|>
|
| 31 |
+
<|im_start|>user
|
| 32 |
+
Hello, who are you? Please introduce yourself briefly.<|im_end|>
|
| 33 |
+
<|im_start|>assistant
|
| 34 |
+
Hello! I am PLM-V, an intelligent assistant designed to provide you with detailed and accurate information, answer questions, and assist with a wide range of topics. My goal is to support you in a friendly and knowledgeable manner, making your interactions with me as informative and helpful as possible. How can I assist you today?
|
| 35 |
+
✅ 推理成功!
|
| 36 |
+
🤖 回复: Hello! I am PLM-V, an intelligent assistant designed to provide you with detailed and accurate information, answer questions, and assist with a wide range of topics. My goal is to support you in a friendly and knowledgeable manner, making your interactions with me as informative and helpful as possible. How can I assist you today?
|
| 37 |
+
|
| 38 |
+
============================================================
|
| 39 |
+
🧪 测试: Text & Image (Visual only)
|
| 40 |
+
📝 问题: <image>
|
| 41 |
+
Please describe this image in detail.
|
| 42 |
+
============================================================
|
| 43 |
+
dynamic ViT batch size: 13
|
| 44 |
+
<|im_start|>system
|
| 45 |
+
You are PLM-V, a helpful assistant.<|im_end|>
|
| 46 |
+
<|im_start|>user
|
| 47 |
+
<image>
|
| 48 |
+
Please describe this image in detail.<|im_end|>
|
| 49 |
+
<|im_start|>assistant
|
| 50 |
+
The image shows a close-up of a red panda, characterized by its distinctive reddish-brown fur with white markings around its face and back. The red panda appears to be leaning on a wooden structure, possibly a ledge or a part of a platform. The subject has a mix of black and white fur around its ears, eyes, and muzzle, which contrasts with its redder head. The expression on the red panda's face seems calm and curious as it looks directly at the camera. The background features blurred greenery, suggesting that the animal is in an outdoor environment with trees and plants. The setting gives a natural, outdoor feel to the image.
|
| 51 |
+
✅ 推理成功!
|
| 52 |
+
🤖 回复: The image shows a close-up of a red panda, characterized by its distinctive reddish-brown fur with white markings around its face and back. The red panda appears to be leaning on a wooden structure, possibly a ledge or a part of a platform. The subject has a mix of black and white fur around its ears, eyes, and muzzle, which contrasts with its redder head. The expression on the red panda's face seems calm and curious as it looks directly at the camera. The background features blurred greenery, suggesting that the animal is in an outdoor environment with trees and plants. The setting gives a natural, outdoor feel to the image.
|
| 53 |
+
|
| 54 |
+
============================================================
|
| 55 |
+
🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练)
|
| 56 |
+
============================================================
|
| 57 |
+
|
| 58 |
+
📥 加载视频数据...
|
| 59 |
+
✅ 视频加载成功:
|
| 60 |
+
- 视频帧数: 8
|
| 61 |
+
- 视频像素值形状: torch.Size([8, 3, 448, 448])
|
| 62 |
+
- 每帧patch数: [1, 1, 1, 1, 1, 1, 1, 1]
|
| 63 |
+
|
| 64 |
+
📥 加载音频数据...
|
| 65 |
+
✅ 音频加载成功:
|
| 66 |
+
- mel谱图形状: torch.Size([2, 3000, 128])
|
| 67 |
+
- 音频长度: tensor([3000, 3000])
|
| 68 |
+
- 音频块数: tensor([2])
|
| 69 |
+
- 原始音频波形形状: torch.Size([2, 480000])
|
| 70 |
+
|
| 71 |
+
============================================================
|
| 72 |
+
🧪 测试: Audio only (预期乱码)
|
| 73 |
+
📝 问题: <speech>
|
| 74 |
+
Please transcribe and summarize what you heard in the audio.
|
| 75 |
+
============================================================
|
| 76 |
+
speech batch size: 2
|
| 77 |
+
<|im_start|>system
|
| 78 |
+
You are PLM-V, a helpful assistant.<|im_end|>
|
| 79 |
+
<|im_start|>user
|
| 80 |
+
<speech>
|
| 81 |
+
Please transcribe and summarize what you heard in the audio.<|im_end|>
|
| 82 |
+
<|im_start|>assistant
|
| 83 |
+
The end
|
| 84 |
+
|
| 85 |
+
The internet
|
| 86 |
+
|
| 87 |
+
or
|
| 88 |
+
|
| 89 |
+
or
|
| 90 |
+
|
| 91 |
+
structure
|
| 92 |
+
|
| 93 |
+
The overcorrection
|
| 94 |
+
|
| 95 |
+
As年人
|
| 96 |
+
|
| 97 |
+
pre-hel theorem is the following: suppose that the other
|
| 98 |
+
There is no one knows that the
|
| 99 |
+
|
| 100 |
+
We consider that the following is the
|
| 101 |
+
It is sufficient that the
|
| 102 |
+
|
| 103 |
+
Consider that the
|
| 104 |
+
|
| 105 |
+
As we see the answer to the question
|
| 106 |
+
|
| 107 |
+
Let us consider the following:
|
| 108 |
+
The initial
|
| 109 |
+
|
| 110 |
+
or, in the beginning of the text
|
| 111 |
+
|
| 112 |
+
We need to consider the following:
|
| 113 |
+
|
| 114 |
+
The answer is
|
| 115 |
+
|
| 116 |
+
Here, we need to consider the following answers
|
| 117 |
+
|
| 118 |
+
However, the question is: What is the
|
| 119 |
+
But the answer to the question is:
|
| 120 |
+
Or the sequence of the answer is:
|
| 121 |
+
|
| 122 |
+
The final answer that is:
|
| 123 |
+
|
| 124 |
+
Actually, the factor, and the issue is: The problem is:
|
| 125 |
+
The problem solved, and the answer is:
|
| 126 |
+
|
| 127 |
+
However, the problem that is: The problem posed is: For the final solution
|
| 128 |
+
|
| 129 |
+
The answer to the question is: For the following question
|
| 130 |
+
|
| 131 |
+
But the question that is solved by: The issue here is: The question we need to solve: The problem we are considering: The question that exists: The problem that we need to address: The question that is:
|
| 132 |
+
|
| 133 |
+
Let us consider the following: The problem that was
|
| 134 |
+
|
| 135 |
+
The answer is: The problem is:
|
| 136 |
+
|
| 137 |
+
We reconsider the question: The question now
|
| 138 |
+
|
| 139 |
+
The problem is: The question we have: The problem that follows: The question is: The final question: The problem that exists: The problem that is: The question posed
|
| 140 |
+
|
| 141 |
+
The final answer is: The problem: The question considered: The question posed: The final answer: The problem that we consider: The problem that was solved: The answer: The problem solved: The problem solved:
|
| 142 |
+
|
| 143 |
+
After a moment of reflection, the question, and the problem is: The question, and the problem is: The answer to the problem: The question and the problem is: The problem that the question is: The problem where the question is
|
| 144 |
+
|
| 145 |
+
However, the initial question, and the answer following these lines:
|
| 146 |
+
|
| 147 |
+
Alternatively, the problem is solved by the following question:
|
| 148 |
+
The sequence of the question and answers
|
| 149 |
+
|
| 150 |
+
The problem posed in the following
|
| 151 |
+
After some thought,the sequence: The problem that the system
|
| 152 |
+
|
| 153 |
+
The actual problem: The actual answer to the question: The actual issue that is: The actual thing that is: The actual fact that is: The actual solution to the problem: The actual moment: The actual thing: The actual purpose: The actual issue: The actual event: The actual time: The actual issue resolved: The actual thing that is: The actual thing that is: The actual moment that is: The actual situation
|
| 154 |
+
|
| 155 |
+
The end of the answer is: The end of the process is: The end of the procedure is
|
| 156 |
+
The end of the process that is:
|
| 157 |
+
|
| 158 |
+
The final answer is: The final step in the process is: The final part of the procedure is:
|
| 159 |
+
The final note that is: The final issue is: The final issue that is:
|
| 160 |
+
The final step is: The final issue solved is: The final part of the question is: The final question is: The final version of the question is: The final part of the structure is: The final structure that is: The final structure of the problem is: The final structure of the consideration is: The final consideration of the structure is: The structure that is: The structure that contains the
|
| 161 |
+
|
| 162 |
+
But the final answer is: The final answer to the question: The final answer that is: The final answer: The final answer: The final answer: The final answer: The final answer:
|
| 163 |
+
|
| 164 |
+
The final issue is: The final answer: The initial issue is:
|
| 165 |
+
|
| 166 |
+
The final answer
|
| 167 |
+
So, the final answer is: The final answer
|
| 168 |
+
|
| 169 |
+
Suppose we consider the final answer: The final answer is: The final answer
|
| 170 |
+
|
| 171 |
+
The final answer is: The final answer
|
| 172 |
+
The final answer
|
| 173 |
+
The final answer
|
| 174 |
+
The final answer
|
| 175 |
+
|
| 176 |
+
The final answer is: The final answer
|
| 177 |
+
|
| 178 |
+
The final answer is: The final answer
|
| 179 |
+
|
| 180 |
+
The final answer is: The final answer: The final answer: The final answer
|
| 181 |
+
But the final answer to the question is: The final answer: The final answer: The final answer:
|
| 182 |
+
The final answer: The final answer
|
| 183 |
+
|
| 184 |
+
The final answer
|
| 185 |
+
|
| 186 |
+
The final answer is: The final answer
|
| 187 |
+
|
| 188 |
+
The final solution is: The final answer
|
| 189 |
+
|
| 190 |
+
The final answer
|
| 191 |
+
|
| 192 |
+
The final answer is: The final answer
|
| 193 |
+
|
| 194 |
+
The final answer
|
| 195 |
+
|
| 196 |
+
The final step is: The final answer
|
| 197 |
+
|
| 198 |
+
The final answer
|
| 199 |
+
|
| 200 |
+
The final answer
|
| 201 |
+
|
| 202 |
+
The final answer is: The final answer
|
| 203 |
+
|
| 204 |
+
The final answer: The final answer
|
| 205 |
+
|
| 206 |
+
The final answer
|
| 207 |
+
|
| 208 |
+
The final answer
|
| 209 |
+
The final answer
|
| 210 |
+
The final answer
|
| 211 |
+
The final answer
|
| 212 |
+
|
| 213 |
+
The final answer (a)
|
| 214 |
+
|
| 215 |
+
But the
|
| 216 |
+
|
| 217 |
+
The final answer
|
| 218 |
+
|
| 219 |
+
The final answer
|
| 220 |
+
|
| 221 |
+
The final answer
|
| 222 |
+
|
| 223 |
+
The final answer
|
| 224 |
+
|
| 225 |
+
The final answer: The final answer
|
| 226 |
+
|
| 227 |
+
The final answer
|
| 228 |
+
|
| 229 |
+
The final answer
|
| 230 |
+
|
| 231 |
+
The final answer
|
| 232 |
+
|
| 233 |
+
The final answer
|
| 234 |
+
|
| 235 |
+
The final answer
|
| 236 |
+
|
| 237 |
+
The final answer
|
| 238 |
+
|
| 239 |
+
The final answer
|
| 240 |
+
|
| 241 |
+
The final answer
|
| 242 |
+
|
| 243 |
+
The final answer
|
| 244 |
+
|
| 245 |
+
The final answer
|
| 246 |
+
✅ 推理成功!
|
| 247 |
+
🤖 回复: The end
|
| 248 |
+
|
| 249 |
+
The internet
|
| 250 |
+
|
| 251 |
+
or
|
| 252 |
+
|
| 253 |
+
or
|
| 254 |
+
|
| 255 |
+
structure
|
| 256 |
+
|
| 257 |
+
The overcorrection
|
| 258 |
+
|
| 259 |
+
As年人
|
| 260 |
+
|
| 261 |
+
pre-hel theorem is the following: suppose that the other
|
| 262 |
+
There is no one knows that the
|
| 263 |
+
|
| 264 |
+
We consider that the following is the
|
| 265 |
+
It is sufficient that the
|
| 266 |
+
|
| 267 |
+
Consider that the
|
| 268 |
+
|
| 269 |
+
As we see the answer to the question
|
| 270 |
+
|
| 271 |
+
Let us consider the following:
|
| 272 |
+
The initial
|
| 273 |
+
|
| 274 |
+
or, in the beginning of the text
|
| 275 |
+
|
| 276 |
+
We need to consider the following:
|
| 277 |
+
|
| 278 |
+
The answer is
|
| 279 |
+
|
| 280 |
+
Here, we need to consider the following answers
|
| 281 |
+
|
| 282 |
+
However, the question is: What is the
|
| 283 |
+
But the answer to the question is:
|
| 284 |
+
Or the sequence of the answer is:
|
| 285 |
+
|
| 286 |
+
The final answer that is:
|
| 287 |
+
|
| 288 |
+
Actually, the factor, and the issue is: The problem is:
|
| 289 |
+
The problem solved, and the answer is:
|
| 290 |
+
|
| 291 |
+
However, the problem that is: The problem posed is: For the final solution
|
| 292 |
+
|
| 293 |
+
The answer to the question is: For the following question
|
| 294 |
+
|
| 295 |
+
But the question that is solved by: The issue here is: The question we need to solve: The problem we are considering: The question that exists: The problem that we need to address: The question that is:
|
| 296 |
+
|
| 297 |
+
Let us consider the following: The problem that was
|
| 298 |
+
|
| 299 |
+
The answer is: The problem is:
|
| 300 |
+
|
| 301 |
+
We reconsider the question: The question now
|
| 302 |
+
|
| 303 |
+
The problem is: The question we have: The problem that follows: The question is: The final question: The problem that exists: The problem that is: The question posed
|
| 304 |
+
|
| 305 |
+
The final answer is: The problem: The question considered: The question posed: The final answer: The problem that we consider: The problem that was solved: The answer: The problem solved: The problem solved:
|
| 306 |
+
|
| 307 |
+
After a moment of reflection, the question, and the problem is: The question, and the problem is: The answer to the problem: The question and the problem is: The problem that the question is: The problem where the question is
|
| 308 |
+
|
| 309 |
+
However, the initial question, and the answer following these lines:
|
| 310 |
+
|
| 311 |
+
Alternatively, the problem is solved by the following question:
|
| 312 |
+
The sequence of the question and answers
|
| 313 |
+
|
| 314 |
+
The problem posed in the following
|
| 315 |
+
After some thought,the sequence: The problem that the system
|
| 316 |
+
|
| 317 |
+
The actual problem: The actual answer to the question: The actual issue that is: The actual thing that is: The actual fact that is: The actual solution to the problem: The actual moment: The actual thing: The actual purpose: The actual issue: The actual event: The actual time: The actual issue resolved: The actual thing that is: The actual thing that is: The actual moment that is: The actual situation
|
| 318 |
+
|
| 319 |
+
The end of the answer is: The end of the process is: The end of the procedure is
|
| 320 |
+
The end of the process that is:
|
| 321 |
+
|
| 322 |
+
The final answer is: The final step in the process is: The final part of the procedure is:
|
| 323 |
+
The final note that is: The final issue is: The final issue that is:
|
| 324 |
+
The final step is: The final issue solved is: The final part of the question is: The final question is: The final version of the question is: The final part of the structure is: The final structure that is: The final structure of the problem is: The final structure of the consideration is: The final consideration of the structure is: The structure that is: The structure that contains the
|
| 325 |
+
|
| 326 |
+
But the final answer is: The final answer to the question: The final answer that is: The final answer: The final answer: The final answer: The final answer: The final answer:
|
| 327 |
+
|
| 328 |
+
The final issue is: The final answer: The initial issue is:
|
| 329 |
+
|
| 330 |
+
The final answer
|
| 331 |
+
So, the final answer is: The final answer
|
| 332 |
+
|
| 333 |
+
Suppose we consider the final answer: The final answer is: The final answer
|
| 334 |
+
|
| 335 |
+
The final answer is: The final answer
|
| 336 |
+
The final answer
|
| 337 |
+
The final answer
|
| 338 |
+
The final answer
|
| 339 |
+
|
| 340 |
+
The final answer is: The final answer
|
| 341 |
+
|
| 342 |
+
The final answer is: The final answer
|
| 343 |
+
|
| 344 |
+
The final answer is: The final answer: The final answer: The final answer
|
| 345 |
+
But the final answer to the question is: The final answer: The final answer: The final answer:
|
| 346 |
+
The final answer: The final answer
|
| 347 |
+
|
| 348 |
+
The final answer
|
| 349 |
+
|
| 350 |
+
The final answer is: The final answer
|
| 351 |
+
|
| 352 |
+
The final solution is: The final answer
|
| 353 |
+
|
| 354 |
+
The final answer
|
| 355 |
+
|
| 356 |
+
The final answer is: The final answer
|
| 357 |
+
|
| 358 |
+
The final answer
|
| 359 |
+
|
| 360 |
+
The final step is: The final answer
|
| 361 |
+
|
| 362 |
+
The final answer
|
| 363 |
+
|
| 364 |
+
The final answer
|
| 365 |
+
|
| 366 |
+
The final answer is: The final answer
|
| 367 |
+
|
| 368 |
+
The final answer: The final answer
|
| 369 |
+
|
| 370 |
+
The final answer
|
| 371 |
+
|
| 372 |
+
The final answer
|
| 373 |
+
The final answer
|
| 374 |
+
The final answer
|
| 375 |
+
The final answer
|
| 376 |
+
|
| 377 |
+
The final answer (a)
|
| 378 |
+
|
| 379 |
+
But the
|
| 380 |
+
|
| 381 |
+
The final answer
|
| 382 |
+
|
| 383 |
+
The final answer
|
| 384 |
+
|
| 385 |
+
The final answer
|
| 386 |
+
|
| 387 |
+
The final answer
|
| 388 |
+
|
| 389 |
+
The final answer: The final answer
|
| 390 |
+
|
| 391 |
+
The final answer
|
| 392 |
+
|
| 393 |
+
The final answer
|
| 394 |
+
|
| 395 |
+
The final answer
|
| 396 |
+
|
| 397 |
+
The final answer
|
| 398 |
+
|
| 399 |
+
The final answer
|
| 400 |
+
|
| 401 |
+
The final answer
|
| 402 |
+
|
| 403 |
+
The final answer
|
| 404 |
+
|
| 405 |
+
The final answer
|
| 406 |
+
|
| 407 |
+
The final answer
|
| 408 |
+
|
| 409 |
+
The final answer
|
| 410 |
+
|
| 411 |
+
============================================================
|
| 412 |
+
🧪 测试: Audio + Image (预期乱码)
|
| 413 |
+
📝 问题: <image>
|
| 414 |
+
User's question in speech: <speech>
|
| 415 |
+
|
| 416 |
+
============================================================
|
| 417 |
+
dynamic ViT batch size: 13
|
| 418 |
+
speech batch size: 2
|
| 419 |
+
<|im_start|>system
|
| 420 |
+
You are PLM-V, a helpful assistant.<|im_end|>
|
| 421 |
+
<|im_start|>user
|
| 422 |
+
<image>
|
| 423 |
+
User's question in speech: <speech>
|
| 424 |
+
<|im_end|>
|
| 425 |
+
<|im_start|>assistant
|
| 426 |
+
Bicytude: 连们在考虑了的时候 - 衋体在当前的是一场关于亚的类型、以通过的是一支集
|
| 427 |
+
在使用后端使用相同的方式的人,可能不知道,或者,其是否在何时进入,而是一次使用相同错误地管理,可能不会发生,或可能是在之前,或者是不是还有其他相关的人,或者,或者,或是否,或,或呢,或是一个,或是一个或,或是
|
| 428 |
+
|
| 429 |
+
I don't understand the given content: 迭取到一张图片,该是关于一个文本文件,内容和类似于的功能,或者是一个关于,或者是一个的人,或若是关于一个关于学习和关于其他主题的上的行为,它是什么原因(或者,或者,或者,或?)还是关于其他主题呢,或在关于其他生物,比如,或者,或者,或者,或呢,或,我需要在使用了,或者,或者,或然呢? - 我需要找到文件中关于“如何改进一个文本文件,以便于更趋时,或是在其他主题,或者是一个关于如何管理一个复杂的数据文件,或是在一个应用数据文件,或者一个关于如何在浏览器的“如何利用Python的自动控制中的“或,或是一个关于一个关于如何使用标签,或是一个关于其他事情,比如,或者一个关于另一个关于如何使用的或是什么,或者,或呢?等,或者??或呢?
|
| 430 |
+
✅ 推理成功!
|
| 431 |
+
🤖 回复: Bicytude: 连们在考虑了的时候 - 衋体在当前的是一场关于亚的类型、以通过的是一支集
|
| 432 |
+
在使用后端使用相同的方式的人,可能不知道,或者,其是否在何时进入,而是一次使用相同错误地管理,可能不会发生,或可能是在之前,或者是不是还有其他相关的人,或者,或者,或是否,或,或呢,或是一个,或是一个或,或是
|
| 433 |
+
|
| 434 |
+
I don't understand the given content: 迭取到一张图片,该是关于一个文本文件,内容和类似于的功能,或者是一个关于,或者是一个的人,或若是关于一个关于学习和关于其他主题的上的行为,它是什么原因(或者,或者,或者,或?)还是关于其他主题呢,或在关于其他生物,比如,或者,或者,或者,或呢,或,我需要在使用了,或者,或者,或然呢? - 我需要找到文件中关于“如何改进一个文本文件,以便于更趋时,或是在其他主题,或者是一个关于如何管理一个复杂的数据文件,或是在一���应用数据文件,或者一个关于如何在浏览器的“如何利用Python的自动控制中的“或,或是一个关于一个关于如何使用标签,或是一个关于其他事情,比如,或者一个关于另一个关于如何使用的或是什么,或者,或呢?等,或者??或呢?
|
| 435 |
+
|
| 436 |
+
============================================================
|
| 437 |
+
🧪 测试: Video + Text
|
| 438 |
+
📝 问题: Frame1: <image>
|
| 439 |
+
Frame2: <image>
|
| 440 |
+
Frame3: <image>
|
| 441 |
+
Frame4: <image>
|
| 442 |
+
Frame5: <image>
|
| 443 |
+
Frame6: <image>
|
| 444 |
+
Frame7: <image>
|
| 445 |
+
Frame8: <image>
|
| 446 |
+
What is the red panda doing in this video? Please describe the actions and movements you observe.
|
| 447 |
+
============================================================
|
| 448 |
+
dynamic ViT batch size: 8
|
| 449 |
+
<|im_start|>system
|
| 450 |
+
You are PLM-V, a helpful assistant.<|im_end|>
|
| 451 |
+
<|im_start|>user
|
| 452 |
+
Frame1: <image>
|
| 453 |
+
Frame2: <image>
|
| 454 |
+
Frame3: <image>
|
| 455 |
+
Frame4: <image>
|
| 456 |
+
Frame5: <image>
|
| 457 |
+
Frame6: <image>
|
| 458 |
+
Frame7: <image>
|
| 459 |
+
Frame8: <image>
|
| 460 |
+
What is the red panda doing in this video? Please describe the actions and movements you observe.<|im_end|>
|
| 461 |
+
<|im_start|>assistant
|
| 462 |
+
In this video, a red panda is climbing up a branch, perched on it while holding something in its mouth, and later sitting on the ground and reaching up towards bamboo sticks suspended from a tree. At one point, one of the red pandas chews on bamboo, and at another point, the blue creature is seen on the grassy ground, looking up towards the red panda. The scene then shows the red panda still perched on the branch, holding something in its mouth, and another red panda is perched on the ground, reaching up towards the bamboo on the tree. After a few moments, the panda on the ground finishes its activity and sits down on the grassy ground.
|
| 463 |
+
✅ 推理成功!
|
| 464 |
+
🤖 回复: In this video, a red panda is climbing up a branch, perched on it while holding something in its mouth, and later sitting on the ground and reaching up towards bamboo sticks suspended from a tree. At one point, one of the red pandas chews on bamboo, and at another point, the blue creature is seen on the grassy ground, looking up towards the red panda. The scene then shows the red panda still perched on the branch, holding something in its mouth, and another red panda is perched on the ground, reaching up towards the bamboo on the tree. After a few moments, the panda on the ground finishes its activity and sits down on the grassy ground.
|
| 465 |
+
|
| 466 |
+
================================================================================
|
| 467 |
+
📊 多模态推理测试总结
|
| 468 |
+
================================================================================
|
| 469 |
+
✅ PASS Pure Text (预期: PASS ) - 应该正常 (训练好的InternVL)
|
| 470 |
+
✅ PASS Text & Image (预期: PASS ) - 应该正常 (训练好的InternVL)
|
| 471 |
+
✅ PASS Video + Text (预期: PASS ) - 应该正常 (训练好的InternVL)
|
| 472 |
+
✅ PASS Audio only (预期: GARBLED ) - 可能乱码 (speech未训练)
|
| 473 |
+
✅ PASS Audio + Image (预期: GARBLED ) - 可能乱码 (speech未训练)
|
| 474 |
+
|
| 475 |
+
📈 测试统计: 5/5 通过
|
| 476 |
+
🎉 基础功能正常,Speech集成架构成功!
|
| 477 |
+
💡 Speech相关测试如果输出乱码是正常的,因为speech部分还未训练
|
| 478 |
+
🌟 所有基础模态测试都通过了!
|
| 479 |
+
|
| 480 |
+
=== 多模态推理测试完成 ===
|
inference/log1.txt
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[2025-09-15 09:26:52,568] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)
|
| 2 |
+
LOAD_VISION_EARLY is set
|
| 3 |
+
FORCE_NO_DOWNSAMPLE is set
|
| 4 |
+
VIDEO_RESIZE is set as 0x64, 0, 64
|
| 5 |
+
HIGHRES_BASE is set as 0x32, 0, 32
|
| 6 |
+
MAXRES is set as 1536
|
| 7 |
+
MINRES is set as 0
|
| 8 |
+
VIDEO_MAXRES is set as 480
|
| 9 |
+
VIDEO_MINRES is set as 288
|
| 10 |
+
PAD2STRIDE is set
|
| 11 |
+
LOWRES_RESIZE is set as 384x32
|
| 12 |
+
Loading OlaQwen3ForCausalLM model...
|
| 13 |
+
Loading BEATs Model
|
| 14 |
+
Missing keys: ['model.speech_encoder.whisper_model.positional_embedding', 'model.speech_encoder.whisper_model.conv1.weight', 'model.speech_encoder.whisper_model.conv1.bias', 'model.speech_encoder.whisper_model.conv2.weight', 'model.speech_encoder.whisper_model.conv2.bias', 'model.speech_encoder.whisper_model.blocks.0.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.0.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.0.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.0.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.0.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.0.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.0.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.0.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.0.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.0.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.0.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.1.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.1.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.1.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.1.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.1.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.1.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.1.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.1.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.1.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.1.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.1.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.2.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.2.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.2.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.2.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.2.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.2.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.2.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.2.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.2.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.2.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.2.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.3.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.3.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.3.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.3.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.3.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.3.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.3.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.3.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.3.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.3.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.3.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.4.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.4.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.4.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.4.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.4.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.4.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.4.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.4.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.4.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.4.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.4.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.5.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.5.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.5.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.5.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.5.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.5.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.5.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.5.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.5.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.5.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.5.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.6.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.6.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.6.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.6.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.6.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.6.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.6.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.6.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.6.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.6.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.6.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.7.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.7.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.7.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.7.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.7.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.7.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.7.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.7.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.7.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.7.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.7.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.8.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.8.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.8.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.8.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.8.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.8.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.8.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.8.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.8.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.8.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.8.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.9.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.9.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.9.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.9.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.9.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.9.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.9.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.9.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.9.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.9.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.9.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.10.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.10.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.10.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.10.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.10.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.10.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.10.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.10.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.10.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.10.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.10.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.11.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.11.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.11.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.11.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.11.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.11.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.11.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.11.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.11.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.11.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.11.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.12.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.12.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.12.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.12.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.12.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.12.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.12.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.12.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.12.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.12.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.12.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.13.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.13.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.13.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.13.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.13.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.13.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.13.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.13.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.13.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.13.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.13.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.14.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.14.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.14.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.14.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.14.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.14.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.14.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.14.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.14.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.14.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.14.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.15.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.15.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.15.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.15.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.15.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.15.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.15.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.15.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.15.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.15.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.15.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.16.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.16.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.16.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.16.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.16.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.16.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.16.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.16.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.16.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.16.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.16.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.17.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.17.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.17.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.17.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.17.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.17.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.17.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.17.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.17.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.17.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.17.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.18.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.18.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.18.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.18.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.18.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.18.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.18.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.18.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.18.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.18.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.18.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.19.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.19.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.19.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.19.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.19.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.19.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.19.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.19.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.19.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.19.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.19.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.20.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.20.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.20.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.20.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.20.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.20.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.20.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.20.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.20.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.20.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.20.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.21.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.21.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.21.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.21.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.21.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.21.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.21.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.21.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.21.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.21.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.21.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.22.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.22.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.22.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.22.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.22.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.22.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.22.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.22.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.22.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.22.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.22.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.23.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.23.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.23.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.23.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.23.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.23.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.23.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.23.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.23.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.23.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.23.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.24.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.24.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.24.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.24.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.24.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.24.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.24.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.24.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.24.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.24.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.24.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.25.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.25.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.25.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.25.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.25.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.25.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.25.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.25.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.25.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.25.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.25.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.26.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.26.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.26.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.26.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.26.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.26.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.26.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.26.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.26.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.26.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.26.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.27.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.27.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.27.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.27.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.27.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.27.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.27.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.27.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.27.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.27.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.27.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.28.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.28.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.28.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.28.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.28.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.28.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.28.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.28.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.28.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.28.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.28.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.29.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.29.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.29.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.29.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.29.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.29.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.29.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.29.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.29.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.29.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.29.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.30.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.30.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.30.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.30.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.30.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.30.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.30.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.30.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.30.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.30.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.30.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.31.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.31.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.31.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.31.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.31.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.31.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.31.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.31.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.31.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.31.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.31.mlp_ln.bias', 'model.speech_encoder.whisper_model.ln_post.weight', 'model.speech_encoder.whisper_model.ln_post.bias', 'model.speech_encoder.beats_model.post_extract_proj.weight', 'model.speech_encoder.beats_model.post_extract_proj.bias', 'model.speech_encoder.beats_model.patch_embedding.weight', 'model.speech_encoder.beats_model.encoder.pos_conv.0.bias', 'model.speech_encoder.beats_model.encoder.pos_conv.0.weight_g', 'model.speech_encoder.beats_model.encoder.pos_conv.0.weight_v', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.0.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.0.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.0.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.0.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.0.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.0.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.1.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.1.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.1.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.1.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.1.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.1.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.2.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.2.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.2.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.2.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.2.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.2.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.3.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.3.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.3.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.3.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.3.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.3.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.4.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.4.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.4.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.4.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.4.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.4.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.5.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.5.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.5.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.5.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.5.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.5.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.6.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.6.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.6.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.6.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.6.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.6.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.7.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.7.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.7.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.7.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.7.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.7.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.8.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.8.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.8.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.8.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.8.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.8.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.9.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.9.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.9.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.9.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.9.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.9.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.10.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.10.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.10.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.10.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.10.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.10.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.11.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.11.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.11.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.11.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.11.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.11.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layer_norm.bias', 'model.speech_encoder.beats_model.layer_norm.weight', 'model.speech_encoder.beats_model.layer_norm.bias', 'model.speech_encoder.beats_model.predictor.weight', 'model.speech_encoder.beats_model.predictor.bias', 'model.speech_projector.speech_newline', 'model.speech_projector.speech_begin', 'model.speech_projector.speech_end', 'model.speech_projector.linear1.weight', 'model.speech_projector.linear1.bias', 'model.speech_projector.linear2.weight', 'model.speech_projector.linear2.bias']
|
| 15 |
+
Unexpected keys: []
|
| 16 |
+
Loading vision tower...
|
| 17 |
+
Loading vision tower succeeded.
|
| 18 |
+
User: Hello, who are you?
|
| 19 |
+
Assistant: Hello! I am an AI assistant called PLM-V. How can I help you today?
|
| 20 |
+
|
| 21 |
+
================================================================================
|
| 22 |
+
🧪 开始多模态推理测试
|
| 23 |
+
================================================================================
|
| 24 |
+
|
| 25 |
+
============================================================
|
| 26 |
+
🧪 测试: Pure Text
|
| 27 |
+
📝 问题: Hello, who are you? Please introduce yourself briefly.
|
| 28 |
+
============================================================
|
| 29 |
+
<|im_start|>system
|
| 30 |
+
You are PLM-V, a helpful assistant.<|im_end|>
|
| 31 |
+
<|im_start|>user
|
| 32 |
+
Hello, who are you? Please introduce yourself briefly.<|im_end|>
|
| 33 |
+
<|im_start|>assistant
|
| 34 |
+
Hello! I am PLM-V, a language model created by the platform OpenAI. My primary function is to assist users by providing information, answering questions, and engaging in conversation. I aim to be helpful, accurate, and respectful in all interactions. How can I assist you today?
|
| 35 |
+
✅ 推理成功!
|
| 36 |
+
🤖 回复: Hello! I am PLM-V, a language model created by the platform OpenAI. My primary function is to assist users by providing information, answering questions, and engaging in conversation. I aim to be helpful, accurate, and respectful in all interactions. How can I assist you today?
|
| 37 |
+
|
| 38 |
+
============================================================
|
| 39 |
+
🧪 测试: Text & Image (Visual only)
|
| 40 |
+
📝 问题: <image>
|
| 41 |
+
Please describe this image in detail.
|
| 42 |
+
============================================================
|
| 43 |
+
dynamic ViT batch size: 13
|
| 44 |
+
<|im_start|>system
|
| 45 |
+
You are PLM-V, a helpful assistant.<|im_end|>
|
| 46 |
+
<|im_start|>user
|
| 47 |
+
<image>
|
| 48 |
+
Please describe this image in detail.<|im_end|>
|
| 49 |
+
<|im_start|>assistant
|
| 50 |
+
The image shows a cute red panda sitting on a wooden platform. This reddish-brown animal has distinctive black and white markings: a white face with black stripes on its cheeks and around its nose. Its fur appears soft and fluffy. The red panda has large, dark eyes and white whiskers, adding to its endearing appearance. It is resting close to a tree trunk, with its black ears perked up on either side. The background is filled with blurred green foliage, suggesting this red panda is in a natural or outdoor setting, possibly a zoo or wildlife sanctuary. The wooden platform appears to be part of a structure designed for the animal to rest or climb on comfortably.
|
| 51 |
+
✅ 推理成功!
|
| 52 |
+
🤖 回复: The image shows a cute red panda sitting on a wooden platform. This reddish-brown animal has distinctive black and white markings: a white face with black stripes on its cheeks and around its nose. Its fur appears soft and fluffy. The red panda has large, dark eyes and white whiskers, adding to its endearing appearance. It is resting close to a tree trunk, with its black ears perked up on either side. The background is filled with blurred green foliage, suggesting this red panda is in a natural or outdoor setting, possibly a zoo or wildlife sanctuary. The wooden platform appears to be part of a structure designed for the animal to rest or climb on comfortably.
|
| 53 |
+
|
| 54 |
+
============================================================
|
| 55 |
+
🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练)
|
| 56 |
+
============================================================
|
| 57 |
+
|
| 58 |
+
📥 加载视频数据...
|
| 59 |
+
✅ 视频加载成功:
|
| 60 |
+
- 视频帧数: 8
|
| 61 |
+
- 视频像素值形状: torch.Size([8, 3, 448, 448])
|
| 62 |
+
- 每帧patch数: [1, 1, 1, 1, 1, 1, 1, 1]
|
| 63 |
+
|
| 64 |
+
📥 加载音频数据...
|
| 65 |
+
✅ 音频加载成功:
|
| 66 |
+
- mel谱图形状: torch.Size([2, 3000, 128])
|
| 67 |
+
- 音频长度: tensor([3000, 3000])
|
| 68 |
+
- 音频块数: tensor([2])
|
| 69 |
+
- 原始音频波形形状: torch.Size([2, 480000])
|
| 70 |
+
|
| 71 |
+
============================================================
|
| 72 |
+
🧪 测试: Audio only (预期乱码)
|
| 73 |
+
📝 问题: <speech>
|
| 74 |
+
Please transcribe and summarize what you heard in the audio.
|
| 75 |
+
============================================================
|
| 76 |
+
speech batch size: 2
|
| 77 |
+
<|im_start|>system
|
| 78 |
+
You are PLM-V, a helpful assistant.<|im_end|>
|
| 79 |
+
<|im_start|>user
|
| 80 |
+
<speech>
|
| 81 |
+
Please transcribe and summarize what you heard in the audio.<|im_end|>
|
| 82 |
+
<|im_start|>assistant
|
| 83 |
+
The秘密 is currently notPowered should be valid.
|
| 84 |
+
|
| 85 |
+
It's possible that's a complex combination of unsub,$
|
| 86 |
+
|
| 87 |
+
Here's an unexpected task is to assess the����是有效的。
|
| 88 |
+
|
| 89 |
+
Now, the秘密
|
| 90 |
+
|
| 91 |
+
### 払下秘密是从加密的
|
| 92 |
+
|
| 93 |
+
Here的秘密
|
| 94 |
+
|
| 95 |
+
It seems the encrypted message has been truncated. It's a confusing combination of unworking! We're not seeing the message from a valid IP的操作未完成. We're here to assist."
|
| 96 |
+
|
| 97 |
+
It's a sequence of 10.21.48. Here we can assist to encrypting the message. This is the result of a combination of unworking. If you're trying to determine the message that the code for the current message. It seems we're not seeing the message that the network can't work. We're too broad.
|
| 98 |
+
|
| 99 |
+
Therefore, the final answer is that we can't help you. We're currently working to inform you. We're now able to help you.
|
| 100 |
+
|
| 101 |
+
What did you saw?"
|
| 102 |
+
|
| 103 |
+
It seems like the message is invalid. In fact, the message is to decrypt the message that the network can't work. Here's an encryption that's confusing. We're not really saying that the combination is not possible. We're unable to help you.
|
| 104 |
+
|
| 105 |
+
However, it seems that the message is returning to its initial message. We can assist in determining the message that the network can't help. We're unable to assist with this message.
|
| 106 |
+
|
| 107 |
+
To assist in the current message, we need to decrypt the message that the network is not working. We'll now assist with the message that the network can't help.
|
| 108 |
+
|
| 109 |
+
However, it seems that the message is not working. We can't do that.
|
| 110 |
+
|
| 111 |
+
To break the message, we need to help you with the message that the network is not working. We'll now assist with the message that the network can't assist.
|
| 112 |
+
|
| 113 |
+
Based on the message, the network seems to be having a difficulty. We're unable to do that. We can assist in determining the message that the network is not working. We're unable to assist with the message that the network is not working.
|
| 114 |
+
|
| 115 |
+
It seems like the message is not working. We can't assist with the message that the network is not allowed to help you.
|
| 116 |
+
|
| 117 |
+
Please, however, we can assist with the message that the network is not working. We're currently unable to help you.
|
| 118 |
+
|
| 119 |
+
Are you able to help with the message that the network is not working? We can't help you with the message that the network is not available to assist with the message.
|
| 120 |
+
|
| 121 |
+
But the message is not doing that. We can't assist with the message that the network can't assist you.
|
| 122 |
+
|
| 123 |
+
In this message, we need to help you with the message that the network is not working. We'll help the message that the network can't assist with the message.
|
| 124 |
+
|
| 125 |
+
It seems like the message is not working. We're unable to help you with the message that the network is not working.
|
| 126 |
+
|
| 127 |
+
Let's focus on the message that the network is not working. We cannot assist with the message that the network is not available to assist with the message.
|
| 128 |
+
|
| 129 |
+
Given that the message is not working, we'll now assist with the message that the network can't assist.
|
| 130 |
+
|
| 131 |
+
However, it seems that the message cannot be provided based on the message that the network is not working.
|
| 132 |
+
|
| 133 |
+
Let's continue to assist with the message that the network can't assist.
|
| 134 |
+
|
| 135 |
+
It seems like the message is not possible for the message to assist with the message that the network can't assist.
|
| 136 |
+
|
| 137 |
+
We are unable to assist with the message that the network is not working.
|
| 138 |
+
|
| 139 |
+
Please, however, we are unable to assist with the message that the network is not working.
|
| 140 |
+
|
| 141 |
+
The message is not working. We're currently unable to assist with the message that the network is not working.
|
| 142 |
+
|
| 143 |
+
We are unable to assist with the message that the network is not working.
|
| 144 |
+
|
| 145 |
+
The message seems to be unreadable. We're unable to help with the message that the network is not working.
|
| 146 |
+
|
| 147 |
+
Please, we are unable to help with the message that the message is not working.
|
| 148 |
+
|
| 149 |
+
The task is asking for help with the message that the network is not working.
|
| 150 |
+
|
| 151 |
+
It seems that the message is not possible for the message to assist with the message that the network is not working.
|
| 152 |
+
|
| 153 |
+
Let's assume the message is not working. We can't assist with the message that the message is not working.
|
| 154 |
+
|
| 155 |
+
The message is not possible to assist with the message that the network is not working.
|
| 156 |
+
|
| 157 |
+
We're unable to help with the message that the network is not working.
|
| 158 |
+
|
| 159 |
+
Here's the message that the network is not working.
|
| 160 |
+
|
| 161 |
+
The message is not possible to help with the message that the network is not working.
|
| 162 |
+
|
| 163 |
+
The message is a mystery. We're unable to assist with the message that the network is not working.
|
| 164 |
+
|
| 165 |
+
It seems that the message is not working. We can't assist with the message that the network is not working.
|
| 166 |
+
|
| 167 |
+
We're unable to help with the message that the
|
| 168 |
+
✅ 推理成功!
|
| 169 |
+
🤖 回复: The秘密 is currently notPowered should be valid.
|
| 170 |
+
|
| 171 |
+
It's possible that's a complex combination of unsub,$
|
| 172 |
+
|
| 173 |
+
Here's an unexpected task is to assess the����是有效的。
|
| 174 |
+
|
| 175 |
+
Now, the秘密
|
| 176 |
+
|
| 177 |
+
### 払下秘密是从加密的
|
| 178 |
+
|
| 179 |
+
Here的秘密
|
| 180 |
+
|
| 181 |
+
It seems the encrypted message has been truncated. It's a confusing combination of unworking! We're not seeing the message from a valid IP的操作未完成. We're here to assist."
|
| 182 |
+
|
| 183 |
+
It's a sequence of 10.21.48. Here we can assist to encrypting the message. This is the result of a combination of unworking. If you're trying to determine the message that the code for the current message. It seems we're not seeing the message that the network can't work. We're too broad.
|
| 184 |
+
|
| 185 |
+
Therefore, the final answer is that we can't help you. We're currently working to inform you. We're now able to help you.
|
| 186 |
+
|
| 187 |
+
What did you saw?"
|
| 188 |
+
|
| 189 |
+
It seems like the message is invalid. In fact, the message is to decrypt the message that the network can't work. Here's an encryption that's confusing. We're not really saying that the combination is not possible. We're unable to help you.
|
| 190 |
+
|
| 191 |
+
However, it seems that the message is returning to its initial message. We can assist in determining the message that the network can't help. We're unable to assist with this message.
|
| 192 |
+
|
| 193 |
+
To assist in the current message, we need to decrypt the message that the network is not working. We'll now assist with the message that the network can't help.
|
| 194 |
+
|
| 195 |
+
However, it seems that the message is not working. We can't do that.
|
| 196 |
+
|
| 197 |
+
To break the message, we need to help you with the message that the network is not working. We'll now assist with the message that the network can't assist.
|
| 198 |
+
|
| 199 |
+
Based on the message, the network seems to be having a difficulty. We're unable to do that. We can assist in determining the message that the network is not working. We're unable to assist with the message that the network is not working.
|
| 200 |
+
|
| 201 |
+
It seems like the message is not working. We can't assist with the message that the network is not allowed to help you.
|
| 202 |
+
|
| 203 |
+
Please, however, we can assist with the message that the network is not working. We're currently unable to help you.
|
| 204 |
+
|
| 205 |
+
Are you able to help with the message that the network is not working? We can't help you with the message that the network is not available to assist with the message.
|
| 206 |
+
|
| 207 |
+
But the message is not doing that. We can't assist with the message that the network can't assist you.
|
| 208 |
+
|
| 209 |
+
In this message, we need to help you with the message that the network is not working. We'll help the message that the network can't assist with the message.
|
| 210 |
+
|
| 211 |
+
It seems like the message is not working. We're unable to help you with the message that the network is not working.
|
| 212 |
+
|
| 213 |
+
Let's focus on the message that the network is not working. We cannot assist with the message that the network is not available to assist with the message.
|
| 214 |
+
|
| 215 |
+
Given that the message is not working, we'll now assist with the message that the network can't assist.
|
| 216 |
+
|
| 217 |
+
However, it seems that the message cannot be provided based on the message that the network is not working.
|
| 218 |
+
|
| 219 |
+
Let's continue to assist with the message that the network can't assist.
|
| 220 |
+
|
| 221 |
+
It seems like the message is not possible for the message to assist with the message that the network can't assist.
|
| 222 |
+
|
| 223 |
+
We are unable to assist with the message that the network is not working.
|
| 224 |
+
|
| 225 |
+
Please, however, we are unable to assist with the message that the network is not working.
|
| 226 |
+
|
| 227 |
+
The message is not working. We're currently unable to assist with the message that the network is not working.
|
| 228 |
+
|
| 229 |
+
We are unable to assist with the message that the network is not working.
|
| 230 |
+
|
| 231 |
+
The message seems to be unreadable. We're unable to help with the message that the network is not working.
|
| 232 |
+
|
| 233 |
+
Please, we are unable to help with the message that the message is not working.
|
| 234 |
+
|
| 235 |
+
The task is asking for help with the message that the network is not working.
|
| 236 |
+
|
| 237 |
+
It seems that the message is not possible for the message to assist with the message that the network is not working.
|
| 238 |
+
|
| 239 |
+
Let's assume the message is not working. We can't assist with the message that the message is not working.
|
| 240 |
+
|
| 241 |
+
The message is not possible to assist with the message that the network is not working.
|
| 242 |
+
|
| 243 |
+
We're unable to help with the message that the network is not working.
|
| 244 |
+
|
| 245 |
+
Here's the message that the network is not working.
|
| 246 |
+
|
| 247 |
+
The message is not possible to help with the message that the network is not working.
|
| 248 |
+
|
| 249 |
+
The message is a mystery. We're unable to assist with the message that the network is not working.
|
| 250 |
+
|
| 251 |
+
It seems that the message is not working. We can't assist with the message that the network is not working.
|
| 252 |
+
|
| 253 |
+
We're unable to help with the message that the
|
| 254 |
+
|
| 255 |
+
============================================================
|
| 256 |
+
🧪 测试: Audio + Image (预期乱码)
|
| 257 |
+
📝 问题: <image>
|
| 258 |
+
User's question in speech: <speech>
|
| 259 |
+
|
| 260 |
+
============================================================
|
| 261 |
+
dynamic ViT batch size: 13
|
| 262 |
+
speech batch size: 2
|
| 263 |
+
<|im_start|>system
|
| 264 |
+
You are PLM-V, a helpful assistant.<|im_end|>
|
| 265 |
+
<|im_start|>user
|
| 266 |
+
<image>
|
| 267 |
+
User's question in speech: <speech>
|
| 268 |
+
<|im_end|>
|
| 269 |
+
<|im_start|>assistant
|
| 270 |
+
In the previous text 10个
|
| 271 |
+
|
| 272 |
+
在是在一个建筑中,一下的人们在使用前一张“我们来学习的是该地区的植物学10月5月的浏览器:这是一个关于在是在一个大的是一张复杂的、我们无法想象的,我需要找出数组中的是否是什么,如何使用异,这个代码看起来不太的人群在该地区进行了8年,然后在使用了4个的用户,我需要将一个的是一张更大的网站,以便确保我的眼睛在没有的,我需要在进行中的人群中,我的网络里,以确保我的眼睛,然后将帮助我理解一下,以便确保我在学习中,需要考虑在使用中的使用的人群之间出现的获取的代码,以便确保,我会帮助您确认,我的代码,或者我需要确认,我没有在不使用的,因为我总是不知道,因为我们需要在的,因为我的困惑,或者因为,因为我的困惑感到不安)等等,我需要帮助控制,或者因为个人在,我需要确保我的眼睛,或者因为我求如何在进行了8年,我需要帮助我的人在或其他方面,我需要确保,我可以尝试以其他的,确保我能提供或不确定,我需要确保,或者为了确保我的眼睛,确保我的眼睛,或者我不确定,我不计算的,或因为我的眼睛,因为我经常在的用户,或者如果使用其他方法,但我需要确保,我需要帮助我,或者因为某些原因,我如何处理使用与我使用Python编程语言使用一个示的用户,我需要确认,我,我、确保,或者在什么时候要使用不同的技巧,但可能,或者我想要确认,或者因为,我需要确保,我无法确保,我需要确保,或者因为,我需要确认,或者我不会知道更多信息,或者我需要确定,或者我需要,我没有,或者可能,或者因为,或者我需要确保,我需要确认,我没有,或者我认为我是否需要,我需要确保,我需要确保,或者我需要帮助,我需要找到,或者我需要在使用了“inferred - �国我需要确保,或者我需要确认,我需要考虑,或者我需要帮助我,我需要知道,我没有,或者我需要,我用一下如何提高或我需要做到的,我需要确认,或者我需要确认,或者我需要,或者我没有,或可能通过这些代码或其他方法,我需要帮助,或者我需要知道,或我需要知道,或我需要确保,我需要,或我需要确保,我需要知道,或者我需要确保,或者其他类似任何事情,或者我需要确保,等等。我需要帮助我如何确保,或者我需要知道,我需要确保,我需要确定,或者我需要处理,或者我需要帮助,或者我不会知道,或者我需要确保,或者我需要帮助,或者我需要,我需要确认,或我需要,我需要确如何确保,我需要确保,或我需要确保,或者我需要保证,我需要确保我需要,或我需要,即使,我需要,我需要确保,或可能需要帮助,我需要确保,或者我需要确保,我需要确保我需要确保,我需要确保,或我需要确定,或我需要确保,或者我需要,我需要确保,我需要,我需要确保,或者我需要其他方式,或者我需要,或我需要,或者我需要,或我需要,或者我我需要,或我我需要,或我需要,我需要,我需要,或我需要,或我需要,或其他,我需要——,或我需要,我需要确认,或者我需要,我需要确保,我需要,或我需要,或我需要,或我在使用了,我需要,但即使,我需要,或我需要,或者我需要,或我需要,或我在通过使用于,或我需要,或者我需要,我需要,或我需要,我需要,或我需要,但即使,我需要,或者我需要,或我需要,或者我的担心,我需要,或我需要,或我需要,我们,或我需要,或者我需要,或我需要,或我我需要,我需要,或我需要,或我需要,或我在使用了,在我需要,或我需要,且我需要,或者我需要,我需要,我需要,或我需要,或者,我需要,让我们需要,或我需要,或我需要,或我需要,或我需要,那么我需要,或者我需要,或我是否需要,或我需要,或我需要,或我需要,你能够帮助我,我需要,或我需要,我需要,或我需要,或我需要是否,或我需要,我需要,让我需要,或我需要
|
| 273 |
+
✅ 推理成功!
|
| 274 |
+
🤖 回复: In the previous text 10个
|
| 275 |
+
|
| 276 |
+
在是在一个建筑中,一下的人们在使用前一张“我们来学习的是该地区的植物学10月5月的浏览器:这是一个关于在是在一个大的是一张复杂的、我们无法想象的,我需要找出数组中的是否是什么,如何使用异,这个代码看起来不太的人群在该地区进行了8年,然后在使用了4个的用户,我需要将一个的是一张更大的网站,以便确保我的眼睛在没有的,我需要在进行中的人群中,我的网络里,以确保我的眼睛,然后将帮助我理解一下,以便确保我在学习中,需要考虑在使用中的使用的人群之间出现的获取的代码,以便确保,我会帮助您确认,我的代码,或者我需要确认,我没有在不使用的,因为我总是不知道,因为我们需要在的,因为我的困惑,或者因为,因为我的困惑感到不安)等等,我需要帮助控制,或者因为个人在,我需要确保我的眼睛,或者因为我求如何在进行了8年,我需要帮助我的人在或其他方面,我需要确保,我可以尝试以其他的,确保我能提供或不确定,我需要确保,或者为了确保我的眼睛,确保我的眼睛,或者我不确定,我不计算的,或因为我的眼睛,因为我经常在的用户,或者如果使用其他方法,但我需要确保,我需要帮助我,或者因为某些原因,我如何处理使用与我使用Python编程语言使用一个示的用户,我需要确认,我,我、确保,或者在什么时候要使用不同的技巧,但可能,或者我想要确认,或者因为,我需要确保,我无法确保,我需要确保,或者因为,我需要确认,或者我不会知道更多信息,或者我需要确定,或者我需要,我没有,或者可能,或者因为,或者我需要确保,我需要确认,我没有,或者我认为我是否需要,我需要确保,我需要确保,或者我需要帮助,我需要找到,或者我需要在使用了“inferred - �国我需要确保,或者我需要确认,我需要考虑,或者我需要帮助我,我需要知道,我没有,或者我需要,我用一下如何提高或我需要做到的,我需要确认,或者我需要确认,或者我需要,或者我没有,或可能通过这些代码或其他方法,我需要帮助,或者我需要知道,或我需要知道,或我需要确保,我需要,或我需要确保,我需要知道,或者我需要确保,或者其他类似任何事情,或者我需要确保,等等。我需要帮助我如何确保,或者我需要知道,我需要确保,我需要确定,或者我需要处理,或者我需要帮助,或者我不会知道,或者我需要确保,或者我需要帮助,或者我需要,我需要确认,或我需要,我需要确如何确保,我需要确保,或我需要确保,或者我需要保证,我需要确保我需要,或我需要,即使,我需要,我需要确保,或可能需要帮助��我需要确保,或者我需要确保,我需要确保我需要确保,我需要确保,或我需要确定,或我需要确保,或者我需要,我需要确保,我需要,我需要确保,或者我需要其他方式,或者我需要,或我需要,或者我需要,或我需要,或者我我需要,或我我需要,或我需要,我需要,我需要,或我需要,或我需要,或其他,我需要——,或我需要,我需要确认,或者我需要,我需要确保,我需要,或我需要,或我需要,或我在使用了,我需要,但即使,我需要,或我需要,或者我需要,或我需要,或我在通过使用于,或我需要,或者我需要,我需要,或我需要,我需要,或我需要,但即使,我需要,或者我需要,或我需要,或者我的担心,我需要,或我需要,或我需要,我们,或我需要,或者我需要,或我需要,或我我需要,我需要,或我需要,或我需要,或我在使用了,在我需要,或我需要,且我需要,或者我需要,我需要,我需要,或我需要,或者,我需要,让我们需要,或我需要,或我需要,或我需要,或我需要,那么我需要,或者我需要,或我是否需要,或我需要,或我需要,或我需要,你能够帮助我,我需要,或我需要,我需要,或我需要,或我需要是否,或我需要,我需要,让我需要,或我需要
|
| 277 |
+
|
| 278 |
+
============================================================
|
| 279 |
+
🧪 测试: Video + Text
|
| 280 |
+
📝 问题: Frame1: <image>
|
| 281 |
+
Frame2: <image>
|
| 282 |
+
Frame3: <image>
|
| 283 |
+
Frame4: <image>
|
| 284 |
+
Frame5: <image>
|
| 285 |
+
Frame6: <image>
|
| 286 |
+
Frame7: <image>
|
| 287 |
+
Frame8: <image>
|
| 288 |
+
What is the red panda doing in this video? Please describe the actions and movements you observe.
|
| 289 |
+
============================================================
|
| 290 |
+
dynamic ViT batch size: 8
|
| 291 |
+
<|im_start|>system
|
| 292 |
+
You are PLM-V, a helpful assistant.<|im_end|>
|
| 293 |
+
<|im_start|>user
|
| 294 |
+
Frame1: <image>
|
| 295 |
+
Frame2: <image>
|
| 296 |
+
Frame3: <image>
|
| 297 |
+
Frame4: <image>
|
| 298 |
+
Frame5: <image>
|
| 299 |
+
Frame6: <image>
|
| 300 |
+
Frame7: <image>
|
| 301 |
+
Frame8: <image>
|
| 302 |
+
What is the red panda doing in this video? Please describe the actions and movements you observe.<|im_end|>
|
| 303 |
+
<|im_start|>assistant
|
| 304 |
+
The red panda in this video is shown eating bamboo and holding a piece of bamboo. In the beginning, the red panda is eating bamboo from the other end of the structure while the baby in front of it is reaching up to eat. They move to the right, and the adult red panda is eating bamboo while the baby continues to reach up. Towards the end, the baby starts eating a piece of bamboo while the adult is eating bamboo from the structure above.
|
| 305 |
+
✅ 推理成功!
|
| 306 |
+
🤖 回复: The red panda in this video is shown eating bamboo and holding a piece of bamboo. In the beginning, the red panda is eating bamboo from the other end of the structure while the baby in front of it is reaching up to eat. They move to the right, and the adult red panda is eating bamboo while the baby continues to reach up. Towards the end, the baby starts eating a piece of bamboo while the adult is eating bamboo from the structure above.
|
| 307 |
+
|
| 308 |
+
============================================================
|
| 309 |
+
🧪 测试: Video + Audio (预期乱码)
|
| 310 |
+
📝 问题: <speech><image>
|
| 311 |
+
Describe what you hear and see in this content.
|
| 312 |
+
============================================================
|
| 313 |
+
dynamic ViT batch size: 13
|
| 314 |
+
speech batch size: 2
|
| 315 |
+
<|im_start|>system
|
| 316 |
+
You are PLM-V, a helpful assistant.<|im_end|>
|
| 317 |
+
<|im_start|>user
|
| 318 |
+
<speech><image>
|
| 319 |
+
Describe what you hear and see in this content.<|im_end|>
|
| 320 |
+
<|im_start|>assistant
|
| 321 |
+
The first step is to determine the 301445560526679335.928824157.5.5, 12.496688788434,49505993735.3846390994.45546936.455539779387". Which of the 201863188435.3179473.943878315546397753.3976.388571.7538856.569376.454694876.56974376.937831.63687.48876.388571.7764387938094875.3880948178093961.7831809387546.9378318878315.776383187673876.83187648780937883187525.3856839375469783188787831875831879838378763831831875469839188767383187583187876383187546983783187145935.429563977878318763831875831868781763831887583187638388093787831887831875469837831887831875831876383831875469837831878318783187546983783187831875831876383831875831876383831875831878318763831875831878318763831875831876.26.33.73.176.1235.32.14.1235.23.48.987659934.763839376.38.7966.8376.393776.23.876.37635.917.53.4687836.75.9356.58378788.7678393783188786.388876.4388876.4387.46.8839376.3937809487.438763831887831875.88317831.776.94699378831875938.783187583188783187583187638318758318763838318758318763831878318758318783187638318758318766.38876.3876.43876.4387839378.3187583187583187638318758318763838
|
| 322 |
+
✅ 推理成功!
|
| 323 |
+
🤖 回复: The first step is to determine the 301445560526679335.928824157.5.5, 12.496688788434,49505993735.3846390994.45546936.455539779387". Which of the 201863188435.3179473.943878315546397753.3976.388571.7538856.569376.454694876.56974376.937831.63687.48876.388571.7764387938094875.3880948178093961.7831809387546.9378318878315.776383187673876.83187648780937883187525.3856839375469783188787831875831879838378763831831875469839188767383187583187876383187546983783187145935.429563977878318763831875831868781763831887583187638388093787831887831875469837831887831875831876383831875469837831878318783187546983783187831875831876383831875831876383831875831878318763831875831878318763831875831876.26.33.73.176.1235.32.14.1235.23.48.987659934.763839376.38.7966.8376.393776.23.876.37635.917.53.4687836.75.9356.58378788.7678393783188786.388876.4388876.4387.46.8839376.3937809487.438763831887831875.88317831.776.94699378831875938.783187583188783187583187638318758318763838318758318763831878318758318783187638318758318766.38876.3876.43876.4387839378.3187583187583187638318758318763838
|
| 324 |
+
|
| 325 |
+
================================================================================
|
| 326 |
+
📊 多模态推理测试总结
|
| 327 |
+
================================================================================
|
| 328 |
+
✅ PASS Pure Text (预期: PASS ) - 应该正常 (训练好的InternVL)
|
| 329 |
+
✅ PASS Text & Image (预期: PASS ) - 应该正常 (训练好的InternVL)
|
| 330 |
+
✅ PASS Video + Text (预期: PASS ) - 应该正常 (训练好的InternVL)
|
| 331 |
+
✅ PASS Audio only (预期: GARBLED ) - 可能乱码 (speech未训练)
|
| 332 |
+
✅ PASS Audio + Image (预期: GARBLED ) - 可能乱码 (speech未训练)
|
| 333 |
+
|
| 334 |
+
📈 测试统计: 5/5 通过
|
| 335 |
+
🎉 基础功能正常,Speech集成架构成功!
|
| 336 |
+
💡 Speech相关测试如果输出乱码是正常的,因为speech部分还未训练
|
| 337 |
+
🌟 所有基础模态测试都通过了!
|
| 338 |
+
|
| 339 |
+
=== 多模态推理测试完成 ===
|
ola.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: ola
|
| 3 |
+
Version: 1.0.0
|
| 4 |
+
Summary: Omni-Modal Language Model
|
| 5 |
+
Classifier: Programming Language :: Python :: 3
|
| 6 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
| 7 |
+
Requires-Python: >=3.10
|
| 8 |
+
Description-Content-Type: text/markdown
|
| 9 |
+
License-File: LICENSE
|
| 10 |
+
Requires-Dist: torch==2.1.2
|
| 11 |
+
Requires-Dist: torchvision==0.16.2
|
| 12 |
+
Requires-Dist: torchaudio==2.1.2
|
| 13 |
+
Requires-Dist: transformers==4.43.4
|
| 14 |
+
Requires-Dist: tokenizers==0.19.1
|
| 15 |
+
Requires-Dist: sentencepiece==0.1.99
|
| 16 |
+
Requires-Dist: shortuuid
|
| 17 |
+
Requires-Dist: accelerate==0.33.0
|
| 18 |
+
Requires-Dist: peft==0.11.1
|
| 19 |
+
Requires-Dist: bitsandbytes==0.43.1
|
| 20 |
+
Requires-Dist: pydantic
|
| 21 |
+
Requires-Dist: markdown2[all]
|
| 22 |
+
Requires-Dist: numpy
|
| 23 |
+
Requires-Dist: scikit-learn==1.2.2
|
| 24 |
+
Requires-Dist: gradio==4.43.0
|
| 25 |
+
Requires-Dist: gradio_client==1.3.0
|
| 26 |
+
Requires-Dist: requests
|
| 27 |
+
Requires-Dist: httpx==0.27.2
|
| 28 |
+
Requires-Dist: uvicorn
|
| 29 |
+
Requires-Dist: fastapi
|
| 30 |
+
Requires-Dist: soundfile
|
| 31 |
+
Requires-Dist: einops==0.6.1
|
| 32 |
+
Requires-Dist: einops-exts==0.0.4
|
| 33 |
+
Requires-Dist: timm==0.9.16
|
| 34 |
+
Requires-Dist: openai-whisper
|
| 35 |
+
Requires-Dist: setuptools==59.5.0
|
| 36 |
+
Requires-Dist: omegaconf==2.0.6
|
| 37 |
+
Requires-Dist: loguru
|
| 38 |
+
Requires-Dist: av
|
| 39 |
+
Requires-Dist: librosa
|
| 40 |
+
Provides-Extra: train
|
| 41 |
+
Requires-Dist: deepspeed==0.12.6; extra == "train"
|
| 42 |
+
Requires-Dist: ninja; extra == "train"
|
| 43 |
+
Requires-Dist: wandb; extra == "train"
|
| 44 |
+
Requires-Dist: tensorboardX; extra == "train"
|
| 45 |
+
Provides-Extra: build
|
| 46 |
+
Requires-Dist: build; extra == "build"
|
| 47 |
+
Requires-Dist: twine; extra == "build"
|
| 48 |
+
Dynamic: license-file
|
| 49 |
+
|
| 50 |
+
<p align="center" width="100%">
|
| 51 |
+
<img src="https://ola-omni.github.io/static/images/ola-icon.png" alt="967023137dff29e65b21544e7620e0f7.webp" width=60%>
|
| 52 |
+
</p>
|
| 53 |
+
<div>
|
| 54 |
+
|
| 55 |
+
## Ola: Pushing the Frontiers of Omni-Modal Language Model
|
| 56 |
+
|
| 57 |
+
<p align="left">
|
| 58 |
+
<a href='https://github.com/liuzuyan' target='_blank'>Zuyan Liu<sup>*,1,2</sup></a> 
|
| 59 |
+
<a href='https://github.com/dongyh20/' target='_blank'>Yuhao Dong<sup>*,2,3</sup></a> 
|
| 60 |
+
Jiahui Wang<sup>1</sup></a> <br>
|
| 61 |
+
<a href='https://liuziwei7.github.io/' target='_blank'>Ziwei Liu<sup>3</sup></a> 
|
| 62 |
+
Winston Hu<sup>2</sup></a> 
|
| 63 |
+
<a href='https://scholar.google.com/citations?user=TN8uDQoAAAAJ' target='_blank'>Jiwen Lu<sup>1,✉</sup></a> 
|
| 64 |
+
<a href='https://raoyongming.github.io/' target='_blank'>Yongming Rao<sup>2,1,✉</sup></a> 
|
| 65 |
+
</p>
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
<p align="left"><sup>1</sup>Tsinghua University   <sup>2</sup>Tencent Hunyuan Research  <sup>3</sup>S-Lab, NTU  </p>
|
| 69 |
+
|
| 70 |
+
<p align="left"><sup>*</sup> Equal Contribution<sup>  ✉</sup> Corresponding Author</p>
|
| 71 |
+
|
| 72 |
+
[-blue)](https://rank.opencompass.org.cn/leaderboard-multimodal/?m=REALTIME) [](https://video-mme.github.io/home_page.html#leaderboard)
|
| 73 |
+
|
| 74 |
+
---
|
| 75 |
+
|
| 76 |
+
**Project Page:** [](https://ola-omni.github.io)
|
| 77 |
+
|
| 78 |
+
**Weights in Huggingface:** [](https://huggingface.co/THUdyh/Ola-7b) [](https://huggingface.co/THUdyh/Ola-Image) [](https://huggingface.co/THUdyh/Ola-Video)
|
| 79 |
+
|
| 80 |
+
**arXiv Paper:** [](https://arxiv.org/abs/2502.04328)
|
| 81 |
+
|
| 82 |
+
**Demo by Gradio:** [](https://huggingface.co/spaces/THUdyh/Ola)
|
| 83 |
+
|
| 84 |
+
**Training Data:** [](https://huggingface.co/datasets/THUdyh/Ola-Data)
|
| 85 |
+
|
| 86 |
+
**中文解读**: [](https://mp.weixin.qq.com/s/N4bjcHOejJudtxTFZVAXmg)
|
| 87 |
+
|
| 88 |
+
Contact: Leave an issue or contact liuzuyan19@gmail.com . We are on call to respond.
|
| 89 |
+
|
| 90 |
+
## 📢 News
|
| 91 |
+
|
| 92 |
+
- 🔥[28/2/2025] We release the intermediate model, Ola-Image and Ola-Video, try building your own omni-modal models!
|
| 93 |
+
|
| 94 |
+
- 🚀[19/2/2025] We release the huggingface demo of Ola, try the advanced omni-modal model on your own!
|
| 95 |
+
|
| 96 |
+
- 🔥[18/2/2025] The training data, training script for Ola-7b is released!
|
| 97 |
+
|
| 98 |
+
- 🎉[07/2/2025] The Ola is released! Check our [project page](https://ola-omni.github.io), [model weights](https://huggingface.co/THUdyh/Ola-7b), [arXiv paper](https://arxiv.org/pdf/2502.04328) for the strong omni-modal understanding model!
|
| 99 |
+
|
| 100 |
+
- 🔥[06/2/2025] [Ola-7b](https://huggingface.co/THUdyh/Ola-7b) achieves **Rank #1** on the OpenCompass Multi-modal Leaderboard among all the models under 15B parameters with average score of **72.6**. Check the impressive results [here](https://rank.opencompass.org.cn/leaderboard-multimodal/?m=REALTIME)!
|
| 101 |
+
|
| 102 |
+
## 🚀Coming Soon
|
| 103 |
+
|
| 104 |
+
- [x] Evaluation code on omni-modal benchmarks
|
| 105 |
+
- [x] Gradio Demo
|
| 106 |
+
- [x] Training Data (Video, Audio, Cross-Modality)
|
| 107 |
+
|
| 108 |
+
## 🌟 Introduction
|
| 109 |
+
|
| 110 |
+
### Roads to Ola
|
| 111 |
+
|
| 112 |
+
<p align="center" width="100%">
|
| 113 |
+
<img src="https://ola-omni.github.io/static/images/road.png" alt="road.png" width=100%>
|
| 114 |
+
</p>
|
| 115 |
+
<div>
|
| 116 |
+
|
| 117 |
+
**Ola** is an Omni-modal language model that achieves competitive performance across image, video, and audio understanding compared to specialized counterparts. We conduct a comprehensive exploration of architectural design, data curation, and training strategies essential for building a robust omni-modal model.
|
| 118 |
+
|
| 119 |
+
<p align="center" width="100%">
|
| 120 |
+
<img src="https://ola-omni.github.io/static/images/teaser.png" alt="teaser.png" width=100%>
|
| 121 |
+
</p>
|
| 122 |
+
<div>
|
| 123 |
+
|
| 124 |
+
### Architecture
|
| 125 |
+
|
| 126 |
+
<p align="center" width="100%">
|
| 127 |
+
<img src="https://ola-omni.github.io/static/images/method.png" alt="method.png" width=100%>
|
| 128 |
+
</p>
|
| 129 |
+
<div>
|
| 130 |
+
|
| 131 |
+
Ola supports omni-modal inputs including text, image, video, and audio, capable of processing the inputs simultaneously with competitive performance on understanding tasks for all these modalities. Meanwhile, Ola supports user-friendly real-time streaming decoding for texts and speeches thanks to the text detokenizer and the speech decoder.
|
| 132 |
+
|
| 133 |
+
### Training Strategies
|
| 134 |
+
|
| 135 |
+
<p align="center" width="100%">
|
| 136 |
+
<img src="https://ola-omni.github.io/static/images/training.png" alt="training.png" width=100%>
|
| 137 |
+
</p>
|
| 138 |
+
<div>
|
| 139 |
+
|
| 140 |
+
We visualize the relationships among modalities in the left part. Speech acts as the connection between language and audio knowledge, while video constructs the bridge with highly relevant visual and audio information. Therefore, we design the progressive alignment training strategy from primary to periphery. Furthermore, we design the cross-modality video-audio data to better capture the relationships among modalities.
|
| 141 |
+
|
| 142 |
+
### Performance
|
| 143 |
+
|
| 144 |
+
<p align="center" width="100%">
|
| 145 |
+
<img src="https://ola-omni.github.io/static/images/results.png" alt="results.png" width=100%>
|
| 146 |
+
</p>
|
| 147 |
+
<div>
|
| 148 |
+
|
| 149 |
+
Ola achieves competitive performance across major multi-modal benchmarks when compared to state-of-the-art specialist-modal LLMs.
|
| 150 |
+
|
| 151 |
+
## Installation
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
#### 1. Clone this repository:
|
| 155 |
+
```bash
|
| 156 |
+
git clone https://github.com/Ola-Omni/Ola
|
| 157 |
+
cd Ola
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
#### 2. Install the required package:
|
| 161 |
+
```bash
|
| 162 |
+
conda create -n ola python=3.10 -y
|
| 163 |
+
conda activate ola
|
| 164 |
+
pip install --upgrade pip
|
| 165 |
+
pip install -e .
|
| 166 |
+
```
|
| 167 |
+
#### 3.Install additional packages for training cases
|
| 168 |
+
|
| 169 |
+
```bash
|
| 170 |
+
pip install -e ".[train]"
|
| 171 |
+
pip install flash-attn --no-build-isolation
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
## Model Zoo
|
| 175 |
+
|
| 176 |
+
We provide our checkpoints at [Huggingface](https://huggingface.co/collections/THUdyh/ola-67b8220eb93406ec87aeec37)
|
| 177 |
+
|
| 178 |
+
| Model | Link | Size | Modal |
|
| 179 |
+
|:---:|:---:|:---:|:---:|
|
| 180 |
+
|Ola-7b | [Huggingface](https://huggingface.co/THUdyh/Ola-7b) | 7B | Text, Image, Video, Audio |
|
| 181 |
+
|Ola-Image | [Huggingface](https://huggingface.co/THUdyh/Ola-Image) | 7B | Text, Image |
|
| 182 |
+
|Ola-Video | [Huggingface](https://huggingface.co/THUdyh/Ola-Video) | 7B | Text, Image, Video |
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
## Quick Start
|
| 186 |
+
|
| 187 |
+
1. Download `Ola-7b` from [Huggingface](https://huggingface.co/THUdyh/Ola-7b) or skip the step to using the online weights directly.
|
| 188 |
+
|
| 189 |
+
2. Download audio encoder from [Huggingface](https://huggingface.co/THUdyh/Ola_speech_encoders/tree/main) and put the weights `large-v3.pt` and `BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt` under repo directory `path/to/Ola/`
|
| 190 |
+
|
| 191 |
+
3. Run `inference/infer.py`
|
| 192 |
+
|
| 193 |
+
- Text & Image Understanding
|
| 194 |
+
|
| 195 |
+
```
|
| 196 |
+
python3 inference/infer.py --image_path *.png,jpg --text user_instruction
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
- Text & Video Understanding
|
| 200 |
+
|
| 201 |
+
```
|
| 202 |
+
python3 inference/infer.py --video_path *.mp4 --text user_instruction
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
- Text & Audio Understanding
|
| 206 |
+
|
| 207 |
+
```
|
| 208 |
+
python3 inference/infer.py --audio_path *.wav,mp3 --text user_instruction
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
- Audio & Image Understanding
|
| 212 |
+
|
| 213 |
+
```
|
| 214 |
+
python3 inference/infer.py --audio_path *.png,jpg --audio_path *.wav,mp3
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
## Evaluation
|
| 218 |
+
|
| 219 |
+
You can evaluate Ola model with [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) and [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval).
|
| 220 |
+
|
| 221 |
+
## Training
|
| 222 |
+
|
| 223 |
+
### Data Preparation
|
| 224 |
+
|
| 225 |
+
Please refer to [DATA.md](https://github.com/Ola-Omni/Ola/blob/main/DATA.md) for instructions of customized finetuning or using the provided datasets.
|
| 226 |
+
|
| 227 |
+
### Start Training
|
| 228 |
+
|
| 229 |
+
Please follow the script below to start training. Make sure you have created the correct datasets for fine-tuning.
|
| 230 |
+
|
| 231 |
+
1. Finetuning Ola-7b Model:
|
| 232 |
+
|
| 233 |
+
```
|
| 234 |
+
bash ./scripts/finetune_ola.sh
|
| 235 |
+
```
|
| 236 |
+
|
| 237 |
+
2. Finetuning Ola-Image Model (Ola Stage1 or Stage2)
|
| 238 |
+
|
| 239 |
+
```
|
| 240 |
+
bash ./scripts/finetune_ola_image.sh
|
| 241 |
+
```
|
| 242 |
+
|
| 243 |
+
3. Finetuning Ola-Video Model (Ola Stage3):
|
| 244 |
+
|
| 245 |
+
```
|
| 246 |
+
bash ./scripts/finetune_ola_video.sh
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
## Citation
|
| 250 |
+
|
| 251 |
+
If you find it useful for your research and applications, please cite our paper using this BibTeX:
|
| 252 |
+
```bibtex
|
| 253 |
+
@article{liu2025ola,
|
| 254 |
+
title={Ola: Pushing the Frontiers of Omni-Modal Language Model with Progressive Modality Alignment},
|
| 255 |
+
author={Liu, Zuyan and Dong, Yuhao and Wang, Jiahui and Liu, Ziwei and Hu, Winston and Lu, Jiwen and Rao, Yongming},
|
| 256 |
+
journal={arXiv preprint arXiv:2502.04328},
|
| 257 |
+
year={2025}
|
| 258 |
+
}
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
## Acknowledgement
|
| 262 |
+
|
| 263 |
+
- Our codebase is conducted on [LLaVA](https://github.com/LLaVA-VL/LLaVA-NeXT)
|
| 264 |
+
|
| 265 |
+
- Thanks [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) and [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) team for the evaluation system!
|
ola.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
README.md
|
| 3 |
+
pyproject.toml
|
| 4 |
+
inference/infer.py
|
| 5 |
+
ola/arguments.py
|
| 6 |
+
ola/constants.py
|
| 7 |
+
ola/conversation.py
|
| 8 |
+
ola/mm_utils.py
|
| 9 |
+
ola/utils.py
|
| 10 |
+
ola.egg-info/PKG-INFO
|
| 11 |
+
ola.egg-info/SOURCES.txt
|
| 12 |
+
ola.egg-info/dependency_links.txt
|
| 13 |
+
ola.egg-info/requires.txt
|
| 14 |
+
ola.egg-info/top_level.txt
|
| 15 |
+
ola/datasets/__init__.py
|
| 16 |
+
ola/datasets/preprocess.py
|
| 17 |
+
ola/model/__init__.py
|
| 18 |
+
ola/model/builder.py
|
| 19 |
+
ola/model/ola_arch.py
|
| 20 |
+
ola/model/language_model/ola_qwen.py
|
| 21 |
+
ola/model/multimodal_encoder/builder.py
|
| 22 |
+
ola/model/multimodal_encoder/oryx_vit.py
|
| 23 |
+
ola/model/multimodal_projector/builder.py
|
| 24 |
+
ola/model/multimodal_projector/pooler_projector.py
|
| 25 |
+
ola/model/multimodal_resampler/builder.py
|
| 26 |
+
ola/model/speech_encoder/builder.py
|
| 27 |
+
ola/model/speech_encoder/speech_encoder.py
|
| 28 |
+
ola/model/speech_encoder/beats/BEATs.py
|
| 29 |
+
ola/model/speech_encoder/beats/Tokenizers.py
|
| 30 |
+
ola/model/speech_encoder/beats/__init__.py
|
| 31 |
+
ola/model/speech_encoder/beats/backbone.py
|
| 32 |
+
ola/model/speech_encoder/beats/kaldi.py
|
| 33 |
+
ola/model/speech_encoder/beats/modules.py
|
| 34 |
+
ola/model/speech_encoder/beats/quantizer.py
|
| 35 |
+
ola/model/speech_projector/builder.py
|
| 36 |
+
ola/model/speech_projector/speech_projector.py
|
| 37 |
+
ola/serve/__init__.py
|
| 38 |
+
ola/serve/controller.py
|
| 39 |
+
ola/serve/gradio_web_server.py
|
| 40 |
+
ola/serve/model_worker.py
|
| 41 |
+
ola/train/ola_trainer.py
|
| 42 |
+
ola/train/train.py
|
| 43 |
+
tools/convert_mp4_wav.py
|
| 44 |
+
tools/create_patch.py
|
ola.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
ola.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.1.2
|
| 2 |
+
torchvision==0.16.2
|
| 3 |
+
torchaudio==2.1.2
|
| 4 |
+
transformers==4.43.4
|
| 5 |
+
tokenizers==0.19.1
|
| 6 |
+
sentencepiece==0.1.99
|
| 7 |
+
shortuuid
|
| 8 |
+
accelerate==0.33.0
|
| 9 |
+
peft==0.11.1
|
| 10 |
+
bitsandbytes==0.43.1
|
| 11 |
+
pydantic
|
| 12 |
+
markdown2[all]
|
| 13 |
+
numpy
|
| 14 |
+
scikit-learn==1.2.2
|
| 15 |
+
gradio==4.43.0
|
| 16 |
+
gradio_client==1.3.0
|
| 17 |
+
requests
|
| 18 |
+
httpx==0.27.2
|
| 19 |
+
uvicorn
|
| 20 |
+
fastapi
|
| 21 |
+
soundfile
|
| 22 |
+
einops==0.6.1
|
| 23 |
+
einops-exts==0.0.4
|
| 24 |
+
timm==0.9.16
|
| 25 |
+
openai-whisper
|
| 26 |
+
setuptools==59.5.0
|
| 27 |
+
omegaconf==2.0.6
|
| 28 |
+
loguru
|
| 29 |
+
av
|
| 30 |
+
librosa
|
| 31 |
+
|
| 32 |
+
[build]
|
| 33 |
+
build
|
| 34 |
+
twine
|
| 35 |
+
|
| 36 |
+
[train]
|
| 37 |
+
deepspeed==0.12.6
|
| 38 |
+
ninja
|
| 39 |
+
wandb
|
| 40 |
+
tensorboardX
|
ola.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
inference
|
| 2 |
+
ola
|
| 3 |
+
scripts
|
| 4 |
+
tools
|
ola/__pycache__/arguments.cpython-312.pyc
ADDED
|
Binary file (3.5 kB). View file
|
|
|
ola/__pycache__/constants.cpython-312.pyc
ADDED
|
Binary file (586 Bytes). View file
|
|
|
ola/__pycache__/conversation.cpython-312.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
ola/__pycache__/mm_utils.cpython-312.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
ola/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
ola/arguments.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import transformers
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class ModelArguments:
|
| 9 |
+
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
|
| 10 |
+
version: Optional[str] = field(default="v0")
|
| 11 |
+
freeze_backbone: bool = field(default=False)
|
| 12 |
+
tune_speech_projector: bool = field(default=False)
|
| 13 |
+
tune_speech_encoder: bool = field(default=False)
|
| 14 |
+
tune_speech_generator_only: bool = field(default=False)
|
| 15 |
+
speech_encoder_type: Optional[str] = field(default=None)
|
| 16 |
+
speech_encoder: Optional[str] = field(default=None)
|
| 17 |
+
pretrain_speech_projector: Optional[str] = field(default=None)
|
| 18 |
+
speech_projector_type: Optional[str] = field(default='linear')
|
| 19 |
+
speech_encoder_ds_rate: int = 5
|
| 20 |
+
speech_encoder_hidden_size: int = 1280
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class DataArguments:
|
| 25 |
+
data_path: str = field(default=None,
|
| 26 |
+
metadata={"help": "Path to the training data."})
|
| 27 |
+
is_multimodal: bool = False
|
| 28 |
+
input_type: str = field(default="mel")
|
| 29 |
+
speech_normalize: bool = False
|
| 30 |
+
mel_size: int = 128
|
| 31 |
+
has_tgt_units: bool = False
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class TrainingArguments(transformers.TrainingArguments):
|
| 36 |
+
cache_dir: Optional[str] = field(default=None)
|
| 37 |
+
optim: str = field(default="adamw_torch")
|
| 38 |
+
freeze_speech_projector: bool = field(default=False)
|
| 39 |
+
model_max_length: int = field(
|
| 40 |
+
default=512,
|
| 41 |
+
metadata={
|
| 42 |
+
"help":
|
| 43 |
+
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
| 44 |
+
},
|
| 45 |
+
)
|
| 46 |
+
double_quant: bool = field(
|
| 47 |
+
default=True,
|
| 48 |
+
metadata={"help": "Compress the quantization statistics through double quantization."}
|
| 49 |
+
)
|
| 50 |
+
quant_type: str = field(
|
| 51 |
+
default="nf4",
|
| 52 |
+
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
|
| 53 |
+
)
|
| 54 |
+
bits: int = field(
|
| 55 |
+
default=16,
|
| 56 |
+
metadata={"help": "How many bits to use."}
|
| 57 |
+
)
|
| 58 |
+
lora_enable: bool = False
|
| 59 |
+
lora_r: int = 64
|
| 60 |
+
lora_alpha: int = 16
|
| 61 |
+
lora_dropout: float = 0.05
|
| 62 |
+
lora_weight_path: str = ""
|
| 63 |
+
lora_bias: str = "none"
|
| 64 |
+
speech_projector_lr: Optional[float] = None
|
| 65 |
+
group_by_modality_length: bool = field(default=False)
|
ola/constants.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
| 2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
| 3 |
+
|
| 4 |
+
LOGDIR = "."
|
| 5 |
+
|
| 6 |
+
# Model Constants
|
| 7 |
+
IGNORE_INDEX = -100
|
| 8 |
+
SPEECH_TOKEN_INDEX = -200
|
| 9 |
+
DEFAULT_SPEECH_TOKEN = "<speech>"
|
| 10 |
+
IMAGE_TOKEN_INDEX= -300
|
| 11 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
| 12 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
| 13 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
| 14 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
ola/conversation.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
from enum import auto, Enum
|
| 3 |
+
from typing import List, Any, Union, Tuple
|
| 4 |
+
import base64
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SeparatorStyle(Enum):
|
| 10 |
+
"""Different separator style."""
|
| 11 |
+
TWO = auto()
|
| 12 |
+
PLAIN = auto()
|
| 13 |
+
CHATML = auto()
|
| 14 |
+
LLAMA_2 = auto()
|
| 15 |
+
LLAMA_3 = auto()
|
| 16 |
+
QWEN2 = auto()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclasses.dataclass
|
| 20 |
+
class Conversation:
|
| 21 |
+
"""A class that keeps all conversation history."""
|
| 22 |
+
system: str
|
| 23 |
+
roles: List[str]
|
| 24 |
+
messages: List[List[str]]
|
| 25 |
+
offset: int
|
| 26 |
+
sep_style: SeparatorStyle = SeparatorStyle.PLAIN
|
| 27 |
+
sep: str = "###"
|
| 28 |
+
sep2: str = None
|
| 29 |
+
version: str = "Unknown"
|
| 30 |
+
|
| 31 |
+
tokenizer_id: str = ""
|
| 32 |
+
tokenizer: Any = None
|
| 33 |
+
# Stop criteria (the default one is EOS token)
|
| 34 |
+
stop_str: Union[str, List[str]] = None
|
| 35 |
+
# Stops generation if meeting any token in this list
|
| 36 |
+
stop_token_ids: List[int] = None
|
| 37 |
+
|
| 38 |
+
skip_next: bool = False
|
| 39 |
+
|
| 40 |
+
def get_prompt(self):
|
| 41 |
+
messages = self.messages
|
| 42 |
+
|
| 43 |
+
if self.sep_style == SeparatorStyle.TWO:
|
| 44 |
+
seps = [self.sep, self.sep2]
|
| 45 |
+
ret = self.system + seps[0]
|
| 46 |
+
for i, (role, message) in enumerate(messages):
|
| 47 |
+
if message:
|
| 48 |
+
if type(message) is tuple:
|
| 49 |
+
message = message[0]
|
| 50 |
+
ret += role + ": " + message + seps[i % 2]
|
| 51 |
+
else:
|
| 52 |
+
ret += role + ":"
|
| 53 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
| 54 |
+
wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
|
| 55 |
+
ret = "<|begin_of_text|>" + wrap_sys(self.system)
|
| 56 |
+
for i, (role, message) in enumerate(messages):
|
| 57 |
+
if message:
|
| 58 |
+
if type(message) is tuple:
|
| 59 |
+
message = message[0]
|
| 60 |
+
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
| 61 |
+
ret += message.strip() + self.sep2
|
| 62 |
+
else:
|
| 63 |
+
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
| 64 |
+
return ret
|
| 65 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
| 66 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
| 67 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
| 68 |
+
ret = ""
|
| 69 |
+
|
| 70 |
+
for i, (role, message) in enumerate(messages):
|
| 71 |
+
if i == 0:
|
| 72 |
+
assert message, "first message should not be none"
|
| 73 |
+
assert role == self.roles[0], "first message should come from user"
|
| 74 |
+
if message:
|
| 75 |
+
if type(message) is tuple:
|
| 76 |
+
message, _, _ = message
|
| 77 |
+
if i == 0:
|
| 78 |
+
message = wrap_sys(self.system) + message
|
| 79 |
+
if i % 2 == 0:
|
| 80 |
+
message = wrap_inst(message)
|
| 81 |
+
ret += self.sep + message
|
| 82 |
+
else:
|
| 83 |
+
ret += " " + message + " " + self.sep2
|
| 84 |
+
else:
|
| 85 |
+
ret += ""
|
| 86 |
+
ret = ret.lstrip(self.sep)
|
| 87 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
| 88 |
+
seps = [self.sep, self.sep2]
|
| 89 |
+
ret = self.system
|
| 90 |
+
for i, (role, message) in enumerate(messages):
|
| 91 |
+
if message:
|
| 92 |
+
if type(message) is tuple:
|
| 93 |
+
message, _, _ = message
|
| 94 |
+
ret += message + seps[i % 2]
|
| 95 |
+
else:
|
| 96 |
+
ret += ""
|
| 97 |
+
|
| 98 |
+
elif self.sep_style == SeparatorStyle.CHATML:
|
| 99 |
+
ret = "" if self.system == "" else self.system + self.sep + "\n"
|
| 100 |
+
for role, message in messages:
|
| 101 |
+
if message:
|
| 102 |
+
if type(message) is tuple:
|
| 103 |
+
raise ValueError("Tuple not supported in CHATML")
|
| 104 |
+
message, images = message
|
| 105 |
+
message = "<speech>" * len(images) + message
|
| 106 |
+
ret += role + "\n" + message + self.sep + "\n"
|
| 107 |
+
else:
|
| 108 |
+
ret += role + "\n"
|
| 109 |
+
return ret
|
| 110 |
+
elif self.sep_style == SeparatorStyle.QWEN2:
|
| 111 |
+
start = '<|im_start|>'
|
| 112 |
+
end = '<|im_end|>\n'
|
| 113 |
+
ret = start + 'system\n' + self.system + end
|
| 114 |
+
for i, (role, message) in enumerate(messages):
|
| 115 |
+
if message:
|
| 116 |
+
if type(message) is tuple:
|
| 117 |
+
message, _, _ = message
|
| 118 |
+
|
| 119 |
+
if message.endswith('<|endoftext|>'):
|
| 120 |
+
message = message.replace('<|endoftext|>', '')
|
| 121 |
+
ret += start + role + "\n" + message + end + '<|endoftext|>'
|
| 122 |
+
else:
|
| 123 |
+
assert not '<|endoftext|>' in message, f"Invalid message: {message}"
|
| 124 |
+
ret += start + role + "\n" + message + end
|
| 125 |
+
else:
|
| 126 |
+
ret += start + role + "\n"
|
| 127 |
+
else:
|
| 128 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
| 129 |
+
|
| 130 |
+
return ret
|
| 131 |
+
|
| 132 |
+
def append_message(self, role, message):
|
| 133 |
+
self.messages.append([role, message])
|
| 134 |
+
|
| 135 |
+
def to_gradio_chatbot(self):
|
| 136 |
+
ret = []
|
| 137 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
| 138 |
+
if i % 2 == 0:
|
| 139 |
+
if type(msg) is tuple:
|
| 140 |
+
msg, speech = msg
|
| 141 |
+
ret.append([msg, None])
|
| 142 |
+
else:
|
| 143 |
+
ret.append([msg, None])
|
| 144 |
+
else:
|
| 145 |
+
ret[-1][-1] = msg
|
| 146 |
+
return ret
|
| 147 |
+
|
| 148 |
+
def copy(self):
|
| 149 |
+
return Conversation(
|
| 150 |
+
system=self.system,
|
| 151 |
+
roles=self.roles,
|
| 152 |
+
messages=[[x, y] for x, y in self.messages],
|
| 153 |
+
offset=self.offset,
|
| 154 |
+
sep_style=self.sep_style,
|
| 155 |
+
sep=self.sep,
|
| 156 |
+
sep2=self.sep2,
|
| 157 |
+
version=self.version)
|
| 158 |
+
|
| 159 |
+
def dict(self):
|
| 160 |
+
if len(self.get_images()) > 0:
|
| 161 |
+
return {
|
| 162 |
+
"system": self.system,
|
| 163 |
+
"roles": self.roles,
|
| 164 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
| 165 |
+
"offset": self.offset,
|
| 166 |
+
"sep": self.sep,
|
| 167 |
+
"sep2": self.sep2,
|
| 168 |
+
}
|
| 169 |
+
return {
|
| 170 |
+
"system": self.system,
|
| 171 |
+
"roles": self.roles,
|
| 172 |
+
"messages": self.messages,
|
| 173 |
+
"offset": self.offset,
|
| 174 |
+
"sep": self.sep,
|
| 175 |
+
"sep2": self.sep2,
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
conv_vicuna_v1 = Conversation(
|
| 179 |
+
system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
| 180 |
+
roles=("USER", "ASSISTANT"),
|
| 181 |
+
version="v1",
|
| 182 |
+
messages=[],
|
| 183 |
+
offset=0,
|
| 184 |
+
sep_style=SeparatorStyle.TWO,
|
| 185 |
+
sep=" ",
|
| 186 |
+
sep2="</s>",
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
conv_llama_2 = Conversation(
|
| 190 |
+
system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.",
|
| 191 |
+
roles=("USER", "ASSISTANT"),
|
| 192 |
+
version="llama_v2",
|
| 193 |
+
messages=[],
|
| 194 |
+
offset=0,
|
| 195 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
| 196 |
+
sep="<s>",
|
| 197 |
+
sep2="</s>",
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
conv_llama_3 = Conversation(
|
| 201 |
+
system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.",
|
| 202 |
+
roles=("user", "assistant"),
|
| 203 |
+
version="llama_v3",
|
| 204 |
+
messages=[],
|
| 205 |
+
offset=0,
|
| 206 |
+
sep_style=SeparatorStyle.LLAMA_3,
|
| 207 |
+
sep="",
|
| 208 |
+
sep2="<|eot_id|>"
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
conv_qwen_v1 = Conversation(
|
| 213 |
+
system="You are a helpful assistant.",
|
| 214 |
+
roles=("user", "assistant"),
|
| 215 |
+
version="v1",
|
| 216 |
+
messages=(),
|
| 217 |
+
offset=0,
|
| 218 |
+
sep_style=SeparatorStyle.QWEN2,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
conv_plain = Conversation(
|
| 222 |
+
system="",
|
| 223 |
+
roles=("", ""),
|
| 224 |
+
messages=(
|
| 225 |
+
),
|
| 226 |
+
offset=0,
|
| 227 |
+
sep_style=SeparatorStyle.PLAIN,
|
| 228 |
+
sep="</s>",
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
conv_qwen = Conversation(
|
| 232 |
+
system="""<|im_start|>system
|
| 233 |
+
You are a helpful assistant.""",
|
| 234 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
| 235 |
+
version="qwen",
|
| 236 |
+
messages=[],
|
| 237 |
+
offset=0,
|
| 238 |
+
sep_style=SeparatorStyle.CHATML,
|
| 239 |
+
sep="<|im_end|>",
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
conv_plmv = Conversation(
|
| 243 |
+
system="""<|im_start|>system
|
| 244 |
+
You are PLM-V, developed by PLM-Team, a helpful assistant.""",
|
| 245 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
| 246 |
+
version="plm_v",
|
| 247 |
+
messages=[],
|
| 248 |
+
offset=0,
|
| 249 |
+
sep_style=SeparatorStyle.CHATML,
|
| 250 |
+
sep="<|im_end|>",
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
default_conversation = conv_plmv
|
| 254 |
+
conv_templates = {
|
| 255 |
+
"v1": conv_vicuna_v1,
|
| 256 |
+
"plain": conv_plain,
|
| 257 |
+
"llama_2": conv_llama_2,
|
| 258 |
+
"llama_3": conv_llama_3,
|
| 259 |
+
'v1_qwen2': conv_qwen_v1,
|
| 260 |
+
"qwen_1_5": conv_qwen,
|
| 261 |
+
"plm_v": conv_plmv,
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
if __name__ == "__main__":
|
| 266 |
+
print(default_conversation.get_prompt())
|
ola/datasets/__init__.py
ADDED
|
File without changes
|
ola/datasets/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (151 Bytes). View file
|
|
|
ola/datasets/__pycache__/preprocess.cpython-312.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
ola/datasets/preprocess.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import torch
|
| 3 |
+
import transformers
|
| 4 |
+
import tokenizers
|
| 5 |
+
|
| 6 |
+
from typing import Dict, Sequence
|
| 7 |
+
|
| 8 |
+
from ola.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, IMAGE_TOKEN_INDEX
|
| 9 |
+
from ola import conversation as conversation_lib
|
| 10 |
+
from ola.model import *
|
| 11 |
+
from ola.arguments import DataArguments
|
| 12 |
+
from ola.constants import SPEECH_TOKEN_INDEX
|
| 13 |
+
|
| 14 |
+
from packaging import version
|
| 15 |
+
|
| 16 |
+
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
|
| 20 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech>')]
|
| 21 |
+
|
| 22 |
+
def insert_separator(X, sep):
|
| 23 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
| 24 |
+
|
| 25 |
+
input_ids = []
|
| 26 |
+
offset = 0
|
| 27 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
| 28 |
+
offset = 1
|
| 29 |
+
input_ids.append(prompt_chunks[0][0])
|
| 30 |
+
|
| 31 |
+
for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)):
|
| 32 |
+
input_ids.extend(x[offset:])
|
| 33 |
+
|
| 34 |
+
if return_tensors is not None:
|
| 35 |
+
if return_tensors == 'pt':
|
| 36 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
| 37 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
| 38 |
+
return input_ids
|
| 39 |
+
|
| 40 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
| 41 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
| 42 |
+
|
| 43 |
+
def insert_separator(X, sep):
|
| 44 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
| 45 |
+
|
| 46 |
+
input_ids = []
|
| 47 |
+
offset = 0
|
| 48 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
| 49 |
+
offset = 1
|
| 50 |
+
input_ids.append(prompt_chunks[0][0])
|
| 51 |
+
|
| 52 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
| 53 |
+
input_ids.extend(x[offset:])
|
| 54 |
+
|
| 55 |
+
if return_tensors is not None:
|
| 56 |
+
if return_tensors == 'pt':
|
| 57 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
| 58 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
| 59 |
+
return input_ids
|
| 60 |
+
|
| 61 |
+
def tokenizer_speech_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
|
| 62 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech><image>')]
|
| 63 |
+
|
| 64 |
+
def insert_separator(X, sep):
|
| 65 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
| 66 |
+
|
| 67 |
+
input_ids = []
|
| 68 |
+
offset = 0
|
| 69 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
| 70 |
+
offset = 1
|
| 71 |
+
input_ids.append(prompt_chunks[0][0])
|
| 72 |
+
|
| 73 |
+
for x in insert_separator(prompt_chunks, [speech_token_idx, image_token_index] * (offset + 1)):
|
| 74 |
+
input_ids.extend(x[offset:])
|
| 75 |
+
|
| 76 |
+
if return_tensors is not None:
|
| 77 |
+
if return_tensors == 'pt':
|
| 78 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
| 79 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
| 80 |
+
return input_ids
|
| 81 |
+
|
| 82 |
+
def tokenizer_speech_question_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
|
| 83 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>\nUser's question in speech: <speech>\n")]
|
| 84 |
+
|
| 85 |
+
def insert_separator(X, sep):
|
| 86 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
| 87 |
+
|
| 88 |
+
input_ids = []
|
| 89 |
+
offset = 0
|
| 90 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
| 91 |
+
offset = 1
|
| 92 |
+
input_ids.append(prompt_chunks[0][0])
|
| 93 |
+
|
| 94 |
+
nl_tokens = tokenizer("\n").input_ids
|
| 95 |
+
special_chunks = [image_token_index, nl_tokens, tokenizer("User's question in speech: ").input_ids, speech_token_idx, nl_tokens]
|
| 96 |
+
|
| 97 |
+
for x in insert_separator(prompt_chunks, [special_chunks] * (offset + 1)):
|
| 98 |
+
input_ids.extend(x[offset:])
|
| 99 |
+
|
| 100 |
+
if return_tensors is not None:
|
| 101 |
+
if return_tensors == 'pt':
|
| 102 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
| 103 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
| 104 |
+
return input_ids
|
| 105 |
+
|
| 106 |
+
def preprocess_v1(
|
| 107 |
+
sources,
|
| 108 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 109 |
+
has_speech: bool = False
|
| 110 |
+
) -> Dict:
|
| 111 |
+
conv = conversation_lib.default_conversation.copy()
|
| 112 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 113 |
+
|
| 114 |
+
# Apply prompt templates
|
| 115 |
+
conversations = []
|
| 116 |
+
for i, source in enumerate(sources):
|
| 117 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
| 118 |
+
# Skip the first one if it is not from human
|
| 119 |
+
source = source[1:]
|
| 120 |
+
|
| 121 |
+
conv.messages = []
|
| 122 |
+
for j, sentence in enumerate(source):
|
| 123 |
+
role = roles[sentence["from"]]
|
| 124 |
+
assert role == conv.roles[j % 2], f"{i}"
|
| 125 |
+
conv.append_message(role, sentence["value"])
|
| 126 |
+
conversations.append(conv.get_prompt())
|
| 127 |
+
|
| 128 |
+
# Tokenize conversations
|
| 129 |
+
|
| 130 |
+
if has_speech:
|
| 131 |
+
input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
| 132 |
+
else:
|
| 133 |
+
input_ids = tokenizer(
|
| 134 |
+
conversations,
|
| 135 |
+
return_tensors="pt",
|
| 136 |
+
padding="longest",
|
| 137 |
+
max_length=tokenizer.model_max_length,
|
| 138 |
+
truncation=True,
|
| 139 |
+
).input_ids
|
| 140 |
+
|
| 141 |
+
targets = input_ids.clone()
|
| 142 |
+
|
| 143 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
|
| 144 |
+
|
| 145 |
+
# Mask targets
|
| 146 |
+
sep = conv.sep + conv.roles[1] + ": "
|
| 147 |
+
for conversation, target in zip(conversations, targets):
|
| 148 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
| 149 |
+
|
| 150 |
+
rounds = conversation.split(conv.sep2)
|
| 151 |
+
cur_len = 1
|
| 152 |
+
target[:cur_len] = IGNORE_INDEX
|
| 153 |
+
for i, rou in enumerate(rounds):
|
| 154 |
+
if rou == "":
|
| 155 |
+
break
|
| 156 |
+
|
| 157 |
+
parts = rou.split(sep)
|
| 158 |
+
if len(parts) != 2:
|
| 159 |
+
break
|
| 160 |
+
parts[0] += sep
|
| 161 |
+
|
| 162 |
+
if has_speech:
|
| 163 |
+
round_len = len(tokenizer_speech_token(rou, tokenizer))
|
| 164 |
+
instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
|
| 165 |
+
else:
|
| 166 |
+
round_len = len(tokenizer(rou).input_ids)
|
| 167 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
| 168 |
+
|
| 169 |
+
# FIXME: tokenizer bug
|
| 170 |
+
if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
|
| 171 |
+
round_len -= 1
|
| 172 |
+
instruction_len -= 1
|
| 173 |
+
|
| 174 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
| 175 |
+
|
| 176 |
+
cur_len += round_len
|
| 177 |
+
target[cur_len:] = IGNORE_INDEX
|
| 178 |
+
|
| 179 |
+
if cur_len < tokenizer.model_max_length:
|
| 180 |
+
if cur_len != total_len:
|
| 181 |
+
target[:] = IGNORE_INDEX
|
| 182 |
+
print(
|
| 183 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
| 184 |
+
f" (ignored)"
|
| 185 |
+
)
|
| 186 |
+
print(f"Debug - Conversation: {conversation[:200]}...")
|
| 187 |
+
print(f"Debug - Target shape: {target.shape}")
|
| 188 |
+
print(f"Debug - All labels are IGNORE_INDEX: {(target == IGNORE_INDEX).all().item()}")
|
| 189 |
+
|
| 190 |
+
return dict(
|
| 191 |
+
input_ids=input_ids,
|
| 192 |
+
labels=targets,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def preprocess_plain(
|
| 197 |
+
sources: Sequence[str],
|
| 198 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 199 |
+
) -> Dict:
|
| 200 |
+
# add end signal and concatenate together
|
| 201 |
+
conversations = []
|
| 202 |
+
for source in sources:
|
| 203 |
+
assert len(source) == 2
|
| 204 |
+
assert DEFAULT_SPEECH_TOKEN in source[0]['value']
|
| 205 |
+
source[0]['value'] = DEFAULT_SPEECH_TOKEN
|
| 206 |
+
conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
|
| 207 |
+
conversations.append(conversation)
|
| 208 |
+
# tokenize conversations
|
| 209 |
+
input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
|
| 210 |
+
targets = copy.deepcopy(input_ids)
|
| 211 |
+
for target, source in zip(targets, sources):
|
| 212 |
+
tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer))
|
| 213 |
+
target[:tokenized_len] = IGNORE_INDEX
|
| 214 |
+
|
| 215 |
+
return dict(input_ids=input_ids, labels=targets)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def preprocess(
|
| 219 |
+
sources: Sequence[str],
|
| 220 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 221 |
+
has_speech: bool = False
|
| 222 |
+
) -> Dict:
|
| 223 |
+
"""
|
| 224 |
+
Given a list of sources, each is a conversation list. This transform:
|
| 225 |
+
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
|
| 226 |
+
2. Concatenate conversations together;
|
| 227 |
+
3. Tokenize the concatenated conversation;
|
| 228 |
+
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
|
| 229 |
+
"""
|
| 230 |
+
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
|
| 231 |
+
return preprocess_plain(sources, tokenizer)
|
| 232 |
+
if conversation_lib.default_conversation.version.startswith("v1"):
|
| 233 |
+
return preprocess_v1(sources, tokenizer, has_speech=has_speech)
|
| 234 |
+
raise NotImplementedError
|
ola/mm_utils.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import base64
|
| 3 |
+
import math
|
| 4 |
+
import ast
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import StoppingCriteria
|
| 8 |
+
import os
|
| 9 |
+
import io
|
| 10 |
+
|
| 11 |
+
if 'VIDEO_RESIZE' in os.environ:
|
| 12 |
+
# highresxpatch
|
| 13 |
+
VIDEO_RESIZE = os.environ['VIDEO_RESIZE']
|
| 14 |
+
video_base, video_ps = VIDEO_RESIZE.split('x')
|
| 15 |
+
video_base = int(video_base)
|
| 16 |
+
video_ps = int(video_ps)
|
| 17 |
+
print(f"VIDEO_RESIZE is set as {VIDEO_RESIZE}, {video_base}, {video_ps}")
|
| 18 |
+
else:
|
| 19 |
+
HIGHRES_BASE = None
|
| 20 |
+
|
| 21 |
+
if 'HIGHRES_BASE' in os.environ:
|
| 22 |
+
# highresxpatch
|
| 23 |
+
HIGHRES_BASE = os.environ['HIGHRES_BASE']
|
| 24 |
+
highres_base, highres_ps = HIGHRES_BASE.split('x')
|
| 25 |
+
highres_base = int(highres_base)
|
| 26 |
+
highres_ps = int(highres_ps)
|
| 27 |
+
print(f"HIGHRES_BASE is set as {HIGHRES_BASE}, {highres_base}, {highres_ps}")
|
| 28 |
+
else:
|
| 29 |
+
HIGHRES_BASE = None
|
| 30 |
+
|
| 31 |
+
if 'MAXRES' in os.environ:
|
| 32 |
+
# highresxpatch
|
| 33 |
+
MAXRES = int(os.environ['MAXRES'])
|
| 34 |
+
print(f"MAXRES is set as {MAXRES}")
|
| 35 |
+
else:
|
| 36 |
+
MAXRES = 1536
|
| 37 |
+
|
| 38 |
+
if 'MINRES' in os.environ:
|
| 39 |
+
# highresxpatch
|
| 40 |
+
MINRES = int(os.environ['MINRES'])
|
| 41 |
+
print(f"MINRES is set as {MINRES}")
|
| 42 |
+
else:
|
| 43 |
+
MINRES = 0
|
| 44 |
+
|
| 45 |
+
if 'VIDEO_MAXRES' in os.environ:
|
| 46 |
+
# highresxpatch
|
| 47 |
+
VIDEO_MAXRES = int(os.environ['VIDEO_MAXRES'])
|
| 48 |
+
print(f"VIDEO_MAXRES is set as {VIDEO_MAXRES}")
|
| 49 |
+
else:
|
| 50 |
+
VIDEO_MAXRES = 1536
|
| 51 |
+
|
| 52 |
+
if 'VIDEO_MINRES' in os.environ:
|
| 53 |
+
# highresxpatch
|
| 54 |
+
VIDEO_MINRES = int(os.environ['VIDEO_MINRES'])
|
| 55 |
+
print(f"VIDEO_MINRES is set as {VIDEO_MINRES}")
|
| 56 |
+
else:
|
| 57 |
+
MINRES = 0
|
| 58 |
+
|
| 59 |
+
if 'PAD2STRIDE' in os.environ:
|
| 60 |
+
# highresxpatch
|
| 61 |
+
PAD2STRIDE = True
|
| 62 |
+
print(f"PAD2STRIDE is set")
|
| 63 |
+
else:
|
| 64 |
+
PAD2STRIDE = False
|
| 65 |
+
|
| 66 |
+
if 'LOWRES_RESIZE' in os.environ:
|
| 67 |
+
LOWRES_RESIZE = os.environ['LOWRES_RESIZE']
|
| 68 |
+
print(f"LOWRES_RESIZE is set as {LOWRES_RESIZE}")
|
| 69 |
+
if 'x' in LOWRES_RESIZE:
|
| 70 |
+
size, ps = LOWRES_RESIZE.split('x')
|
| 71 |
+
size = int(size)
|
| 72 |
+
ps = int(ps)
|
| 73 |
+
LOWRES_RESIZE = (size, ps)
|
| 74 |
+
else:
|
| 75 |
+
LOWRES_RESIZE = int(LOWRES_RESIZE)
|
| 76 |
+
else:
|
| 77 |
+
LOWRES_RESIZE = None
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def pad_image(image, target_resolution, value=0):
|
| 81 |
+
"""
|
| 82 |
+
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
image (PIL.Image.Image): The input image.
|
| 86 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
PIL.Image.Image: The resized and padded image.
|
| 90 |
+
"""
|
| 91 |
+
original_width, original_height = image.size
|
| 92 |
+
target_width, target_height = target_resolution
|
| 93 |
+
# Create a new image with the target size and paste the resized image onto it
|
| 94 |
+
new_image = Image.new('RGB', (target_width, target_height), (value, value, value))
|
| 95 |
+
paste_x = (target_width - original_width) // 2
|
| 96 |
+
paste_y = (target_height - original_height) // 2
|
| 97 |
+
new_image.paste(image, (paste_x, paste_y))
|
| 98 |
+
return new_image
|
| 99 |
+
|
| 100 |
+
def resize_images(image, patch_size=14, base_size=896):
|
| 101 |
+
h, w = image.size
|
| 102 |
+
if base_size == 0:
|
| 103 |
+
if h * w > MAXRES * MAXRES:
|
| 104 |
+
# print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
|
| 105 |
+
scale = MAXRES * MAXRES / (h * w)
|
| 106 |
+
scale = math.sqrt(scale)
|
| 107 |
+
elif h * w < MINRES * MINRES:
|
| 108 |
+
# print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
|
| 109 |
+
scale = MINRES * MINRES / (h * w)
|
| 110 |
+
scale = math.sqrt(scale)
|
| 111 |
+
else:
|
| 112 |
+
scale = None
|
| 113 |
+
else:
|
| 114 |
+
scale = base_size * base_size / (h * w)
|
| 115 |
+
scale = math.sqrt(scale)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
if scale is not None:
|
| 119 |
+
new_h = int(h * scale / patch_size) * patch_size
|
| 120 |
+
new_w = int(w * scale / patch_size) * patch_size
|
| 121 |
+
new_h = max(new_h, patch_size)
|
| 122 |
+
new_w = max(new_w, patch_size)
|
| 123 |
+
image = image.resize((new_h, new_w))
|
| 124 |
+
elif PAD2STRIDE:
|
| 125 |
+
if h % patch_size == 0:
|
| 126 |
+
new_h = h
|
| 127 |
+
else:
|
| 128 |
+
new_h = (h // patch_size + 1) * patch_size
|
| 129 |
+
|
| 130 |
+
if w % patch_size == 0:
|
| 131 |
+
new_w = w
|
| 132 |
+
else:
|
| 133 |
+
new_w = (w // patch_size + 1) * patch_size
|
| 134 |
+
image = pad_image(image, (new_h, new_w), value=127)
|
| 135 |
+
else:
|
| 136 |
+
scale = 1.0
|
| 137 |
+
new_h = int(h * scale / patch_size) * patch_size
|
| 138 |
+
new_w = int(w * scale / patch_size) * patch_size
|
| 139 |
+
new_h = max(new_h, patch_size)
|
| 140 |
+
new_w = max(new_w, patch_size)
|
| 141 |
+
image = image.resize((new_h, new_w))
|
| 142 |
+
|
| 143 |
+
return image
|
| 144 |
+
|
| 145 |
+
def resize_video(image, patch_size=14, base_size=896):
|
| 146 |
+
h, w = image.size
|
| 147 |
+
if base_size == 0:
|
| 148 |
+
if h * w > VIDEO_MAXRES * VIDEO_MAXRES:
|
| 149 |
+
# print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
|
| 150 |
+
scale = VIDEO_MAXRES * VIDEO_MAXRES / (h * w)
|
| 151 |
+
scale = math.sqrt(scale)
|
| 152 |
+
elif h * w < VIDEO_MINRES * VIDEO_MINRES:
|
| 153 |
+
# print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
|
| 154 |
+
scale = VIDEO_MINRES * VIDEO_MINRES / (h * w)
|
| 155 |
+
scale = math.sqrt(scale)
|
| 156 |
+
else:
|
| 157 |
+
scale = None
|
| 158 |
+
else:
|
| 159 |
+
scale = base_size * base_size / (h * w)
|
| 160 |
+
scale = math.sqrt(scale)
|
| 161 |
+
|
| 162 |
+
if scale is not None:
|
| 163 |
+
new_h = int(h * scale / patch_size) * patch_size
|
| 164 |
+
new_w = int(w * scale / patch_size) * patch_size
|
| 165 |
+
image = image.resize((new_h, new_w))
|
| 166 |
+
elif PAD2STRIDE:
|
| 167 |
+
if h % patch_size == 0:
|
| 168 |
+
new_h = h
|
| 169 |
+
else:
|
| 170 |
+
new_h = (h // patch_size + 1) * patch_size
|
| 171 |
+
|
| 172 |
+
if w % patch_size == 0:
|
| 173 |
+
new_w = w
|
| 174 |
+
else:
|
| 175 |
+
new_w = (w // patch_size + 1) * patch_size
|
| 176 |
+
image = pad_image(image, (new_h, new_w), value=127)
|
| 177 |
+
else:
|
| 178 |
+
scale = 1.0
|
| 179 |
+
new_h = int(h * scale / patch_size) * patch_size
|
| 180 |
+
new_w = int(w * scale / patch_size) * patch_size
|
| 181 |
+
image = image.resize((new_h, new_w))
|
| 182 |
+
|
| 183 |
+
return image
|
| 184 |
+
|
| 185 |
+
def process_anyres_video(image, processor):
|
| 186 |
+
if VIDEO_RESIZE is not None:
|
| 187 |
+
image = resize_video(image, patch_size=video_ps, base_size=video_base)
|
| 188 |
+
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
| 189 |
+
return image.unsqueeze(0)
|
| 190 |
+
else:
|
| 191 |
+
raise ValueError("VIDEO_RESIZE is not set")
|
| 192 |
+
|
| 193 |
+
def process_anyres_highres_image(image, processor):
|
| 194 |
+
processor2 = None
|
| 195 |
+
if type(processor) is tuple:
|
| 196 |
+
processor, processor2 = processor[0], processor[1]
|
| 197 |
+
|
| 198 |
+
if HIGHRES_BASE is not None:
|
| 199 |
+
image = resize_images(image, patch_size=highres_ps, base_size=highres_base)
|
| 200 |
+
|
| 201 |
+
if processor2 is not None:
|
| 202 |
+
image_original_resize = image.resize((processor2.size['shortest_edge'], processor.size['shortest_edge']))
|
| 203 |
+
image_patches = [image_original_resize] + [image_original_resize]
|
| 204 |
+
image_patches = [processor2.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
|
| 205 |
+
for image_patch in image_patches]
|
| 206 |
+
else:
|
| 207 |
+
if LOWRES_RESIZE is not None:
|
| 208 |
+
if type(LOWRES_RESIZE) is int:
|
| 209 |
+
image_original_resize = resize_images(image, patch_size=14, base_size=LOWRES_RESIZE)
|
| 210 |
+
else:
|
| 211 |
+
image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0])
|
| 212 |
+
else:
|
| 213 |
+
image_original_resize = image.resize((336, 336))
|
| 214 |
+
image_patches = [image_original_resize]
|
| 215 |
+
image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
|
| 216 |
+
for image_patch in image_patches]
|
| 217 |
+
image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
| 218 |
+
return torch.stack(image_patches, dim=0), image_padded.unsqueeze(0)
|
| 219 |
+
|
| 220 |
+
def read_image_patch(patch_info):
|
| 221 |
+
if 'img_path' in patch_info.keys():
|
| 222 |
+
image = Image.open(patch_info['img_path']).convert('RGB')
|
| 223 |
+
else:
|
| 224 |
+
if 'image_encoing' in patch_info.keys():
|
| 225 |
+
patch_info['image_encoding'] = patch_info['image_encoing']
|
| 226 |
+
image_file_name = patch_info['patch']
|
| 227 |
+
start_bytes = int(patch_info['start_num'])
|
| 228 |
+
file_size = int(patch_info['size'])
|
| 229 |
+
|
| 230 |
+
with open(image_file_name, 'rb') as f:
|
| 231 |
+
f.seek(start_bytes)
|
| 232 |
+
if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
|
| 233 |
+
image = Image.open(io.BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB")
|
| 234 |
+
else:
|
| 235 |
+
image = Image.open(io.BytesIO(f.read(file_size))).convert("RGB")
|
| 236 |
+
return image
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def get_model_name_from_path(model_path):
|
| 240 |
+
model_path = model_path.strip("/")
|
| 241 |
+
model_paths = model_path.split("/")
|
| 242 |
+
if model_paths[-1].startswith('checkpoint-'):
|
| 243 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
| 244 |
+
else:
|
| 245 |
+
return model_paths[-1]
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
| 249 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
| 250 |
+
self.keywords = keywords
|
| 251 |
+
self.keyword_ids = []
|
| 252 |
+
for keyword in keywords:
|
| 253 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
| 254 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
| 255 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
| 256 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
| 257 |
+
self.tokenizer = tokenizer
|
| 258 |
+
self.start_len = input_ids.shape[1]
|
| 259 |
+
|
| 260 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 261 |
+
assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
|
| 262 |
+
offset = min(output_ids.shape[1] - self.start_len, 3)
|
| 263 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
| 264 |
+
for keyword_id in self.keyword_ids:
|
| 265 |
+
if output_ids[0, -keyword_id.shape[0]:] == keyword_id:
|
| 266 |
+
return True
|
| 267 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
| 268 |
+
for keyword in self.keywords:
|
| 269 |
+
if keyword in outputs:
|
| 270 |
+
return True
|
| 271 |
+
return False
|
ola/model/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .language_model.ola_qwen import OlaQwenForCausalLM, OlaConfigQwen
|
| 2 |
+
from .language_model.ola_qwen3 import OlaQwen3ForCausalLM, OlaConfigQwen3
|
ola/model/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (338 Bytes). View file
|
|
|
ola/model/__pycache__/builder.cpython-312.pyc
ADDED
|
Binary file (5.67 kB). View file
|
|
|
ola/model/__pycache__/ola_arch.cpython-312.pyc
ADDED
|
Binary file (36.1 kB). View file
|
|
|
ola/model/builder.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import warnings
|
| 3 |
+
import shutil
|
| 4 |
+
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, AutoProcessor
|
| 6 |
+
import torch
|
| 7 |
+
from ola.model import *
|
| 8 |
+
from ola.model.speech_encoder.builder import build_speech_encoder
|
| 9 |
+
|
| 10 |
+
# 过滤掉 PyTorch 的 meta parameter 警告
|
| 11 |
+
warnings.filterwarnings("ignore", message=".*copying from a non-meta parameter in the checkpoint to a meta parameter.*")
|
| 12 |
+
|
| 13 |
+
def load_pretrained_model(model_path, model_type, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs):
|
| 14 |
+
device = "cuda"
|
| 15 |
+
if load_8bit:
|
| 16 |
+
kwargs['load_in_8bit'] = True
|
| 17 |
+
elif load_4bit:
|
| 18 |
+
kwargs['load_in_4bit'] = True
|
| 19 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
| 20 |
+
load_in_4bit=True,
|
| 21 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 22 |
+
bnb_4bit_use_double_quant=True,
|
| 23 |
+
bnb_4bit_quant_type='nf4'
|
| 24 |
+
)
|
| 25 |
+
else:
|
| 26 |
+
kwargs['torch_dtype'] = torch.bfloat16
|
| 27 |
+
|
| 28 |
+
if use_flash_attn:
|
| 29 |
+
kwargs['attn_implementation'] = 'flash_attention_2'
|
| 30 |
+
|
| 31 |
+
if model_type == 'ola_internvl':
|
| 32 |
+
model_cls = OlaQwen3ForCausalLM
|
| 33 |
+
print('Loading OlaQwen3ForCausalLM model...')
|
| 34 |
+
else:
|
| 35 |
+
model_cls = OlaQwenForCausalLM
|
| 36 |
+
|
| 37 |
+
# Load Ola model
|
| 38 |
+
if is_lora:
|
| 39 |
+
assert model_base is not None, "model_base is required for LoRA models."
|
| 40 |
+
from ola.model.language_model.ola_qwen import OlaConfigQwen
|
| 41 |
+
lora_cfg_pretrained = OlaConfigQwen.from_pretrained(model_path)
|
| 42 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
| 43 |
+
print('Loading Ola from base model...')
|
| 44 |
+
model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
|
| 45 |
+
print('Loading additional Ola weights...')
|
| 46 |
+
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
| 47 |
+
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
| 48 |
+
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
| 49 |
+
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
| 50 |
+
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
| 51 |
+
model.load_state_dict(non_lora_trainables, strict=False, assign=True)
|
| 52 |
+
|
| 53 |
+
from peft import PeftModel
|
| 54 |
+
print('Loading LoRA weights...')
|
| 55 |
+
model = PeftModel.from_pretrained(model, model_path)
|
| 56 |
+
print('Merging LoRA weights...')
|
| 57 |
+
model = model.merge_and_unload()
|
| 58 |
+
print('Model is loaded...')
|
| 59 |
+
elif model_base is not None:
|
| 60 |
+
print('Loading Ola from base model...')
|
| 61 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
| 62 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
| 63 |
+
model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)
|
| 64 |
+
|
| 65 |
+
speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu')
|
| 66 |
+
speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
|
| 67 |
+
model.load_state_dict(speech_projector_weights, strict=False, assign=True)
|
| 68 |
+
model = model.to(device=device)
|
| 69 |
+
else:
|
| 70 |
+
# model_path = "/data1/cxy/plm-v/modeling/plm_internvl3_5_ola"
|
| 71 |
+
model_path = "/data1/cxy/plm-v/modeling/ckpt/ola_audio_8_8gpu/checkpoint-120"
|
| 72 |
+
tokernizer_path = "/data1/cxy/plm-v/modeling/internvl3_5-2B"
|
| 73 |
+
tokenizer = AutoTokenizer.from_pretrained(tokernizer_path, use_fast=False, trust_remote_code=True)
|
| 74 |
+
cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
| 75 |
+
with torch.device("cuda"):
|
| 76 |
+
model = model_cls.from_pretrained(
|
| 77 |
+
model_path,
|
| 78 |
+
trust_remote_code=True,
|
| 79 |
+
config=cfg,
|
| 80 |
+
# device_map="auto",
|
| 81 |
+
**kwargs,
|
| 82 |
+
)
|
| 83 |
+
model = model.to(device=device)
|
| 84 |
+
# breakpoint()
|
| 85 |
+
image_processor = None
|
| 86 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 87 |
+
# breakpoint()
|
| 88 |
+
print("Loading vision tower...")
|
| 89 |
+
print("Loading vision tower succeeded.")
|
| 90 |
+
|
| 91 |
+
if hasattr(model.config, "max_sequence_length"):
|
| 92 |
+
context_len = model.config.max_sequence_length
|
| 93 |
+
else:
|
| 94 |
+
context_len = 16384
|
| 95 |
+
image_processor = AutoProcessor.from_pretrained("/data1/cxy/plm-v/modeling/internvl3_5-2B-HF")
|
| 96 |
+
|
| 97 |
+
return tokenizer, model, image_processor, context_len
|
ola/model/builder_back.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import warnings
|
| 3 |
+
import shutil
|
| 4 |
+
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, AutoProcessor
|
| 6 |
+
import torch
|
| 7 |
+
from ola.model import *
|
| 8 |
+
from ola.model.speech_encoder.builder import build_speech_encoder
|
| 9 |
+
|
| 10 |
+
def load_pretrained_model(model_path, model_type, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs):
|
| 11 |
+
device = "cuda"
|
| 12 |
+
if load_8bit:
|
| 13 |
+
kwargs['load_in_8bit'] = True
|
| 14 |
+
elif load_4bit:
|
| 15 |
+
kwargs['load_in_4bit'] = True
|
| 16 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
| 17 |
+
load_in_4bit=True,
|
| 18 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 19 |
+
bnb_4bit_use_double_quant=True,
|
| 20 |
+
bnb_4bit_quant_type='nf4'
|
| 21 |
+
)
|
| 22 |
+
else:
|
| 23 |
+
kwargs['torch_dtype'] = torch.bfloat16
|
| 24 |
+
|
| 25 |
+
if use_flash_attn:
|
| 26 |
+
kwargs['attn_implementation'] = 'flash_attention_2'
|
| 27 |
+
|
| 28 |
+
if model_type == 'ola_internvl':
|
| 29 |
+
model_cls = OlaQwen3ForCausalLM
|
| 30 |
+
print('Loading OlaQwen3ForCausalLM model...')
|
| 31 |
+
else:
|
| 32 |
+
model_cls = OlaQwenForCausalLM
|
| 33 |
+
|
| 34 |
+
# Load Ola model
|
| 35 |
+
if is_lora:
|
| 36 |
+
assert model_base is not None, "model_base is required for LoRA models."
|
| 37 |
+
from ola.model.language_model.ola_qwen import OlaConfigQwen
|
| 38 |
+
lora_cfg_pretrained = OlaConfigQwen.from_pretrained(model_path)
|
| 39 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
| 40 |
+
print('Loading Ola from base model...')
|
| 41 |
+
model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
|
| 42 |
+
print('Loading additional Ola weights...')
|
| 43 |
+
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
| 44 |
+
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
| 45 |
+
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
| 46 |
+
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
| 47 |
+
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
| 48 |
+
model.load_state_dict(non_lora_trainables, strict=False, assign=True)
|
| 49 |
+
|
| 50 |
+
from peft import PeftModel
|
| 51 |
+
print('Loading LoRA weights...')
|
| 52 |
+
model = PeftModel.from_pretrained(model, model_path)
|
| 53 |
+
print('Merging LoRA weights...')
|
| 54 |
+
model = model.merge_and_unload()
|
| 55 |
+
print('Model is loaded...')
|
| 56 |
+
elif model_base is not None:
|
| 57 |
+
print('Loading Ola from base model...')
|
| 58 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
| 59 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
| 60 |
+
model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)
|
| 61 |
+
|
| 62 |
+
speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu')
|
| 63 |
+
speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
|
| 64 |
+
model.load_state_dict(speech_projector_weights, strict=False, assign=True)
|
| 65 |
+
model = model.to(device=device)
|
| 66 |
+
elif model_type == 'ola_internvl':
|
| 67 |
+
cfg = AutoConfig.from_pretrained("/data1/cxy/plm-v/modeling/old_ola", trust_remote_code=True)
|
| 68 |
+
# breakpoint()
|
| 69 |
+
tokenizer = AutoTokenizer.from_pretrained("/data1/cxy/plm-v/modeling/internvl3_5-2B", use_fast=False)
|
| 70 |
+
with torch.device("cpu"):
|
| 71 |
+
# model = model_cls.from_pretrained("/data1/cxy/plm-v/modeling/internvl3_5-2B", low_cpu_mem_usage=False, attn_implementation="eager", config=cfg, **kwargs)
|
| 72 |
+
# model = model_cls.from_config(config=cfg)
|
| 73 |
+
model = model_cls(cfg)
|
| 74 |
+
# breakpoint()
|
| 75 |
+
# model.model.layers[1].self_attn.q_proj.weight
|
| 76 |
+
else:
|
| 77 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
| 78 |
+
with torch.device("cpu"):
|
| 79 |
+
model = model_cls.from_pretrained(
|
| 80 |
+
model_path,
|
| 81 |
+
**kwargs,
|
| 82 |
+
)
|
| 83 |
+
model = model.to(device=device)
|
| 84 |
+
# model.resize_token_embeddings(len(tokenizer))
|
| 85 |
+
from safetensors.torch import load_file
|
| 86 |
+
partial_state_dict = load_file(f"/data1/cxy/plm-v/modeling/internvl3_5-2B/model.safetensors") # 替换为你的部分权重路径
|
| 87 |
+
mapping = {
|
| 88 |
+
"mlp1.0.weight": "model.mm_projector.layer_norm.weight",
|
| 89 |
+
"mlp1.0.bias": "model.mm_projector.layer_norm.bias",
|
| 90 |
+
"mlp1.1.weight": "model.mm_projector.linear_1.weight",
|
| 91 |
+
"mlp1.1.bias": "model.mm_projector.linear_1.bias",
|
| 92 |
+
"mlp1.3.weight": "model.mm_projector.linear_2.weight",
|
| 93 |
+
"mlp1.3.bias": "model.mm_projector.linear_2.bias",
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
# 遍历 state_dict 并重命名
|
| 97 |
+
def remap_keys(state_dict, mapping):
|
| 98 |
+
new_state_dict = {}
|
| 99 |
+
for k, v in state_dict.items():
|
| 100 |
+
if k in mapping:
|
| 101 |
+
new_state_dict[mapping[k]] = v
|
| 102 |
+
else:
|
| 103 |
+
new_state_dict[k] = v
|
| 104 |
+
return new_state_dict
|
| 105 |
+
# merged_state_dict = {**partial_state_dict, **partial_state_dict2}
|
| 106 |
+
# 2. 重命名 key:multi_modal_projector -> mm_projector
|
| 107 |
+
# breakpoint()
|
| 108 |
+
rename_dict = {}
|
| 109 |
+
for k in list(partial_state_dict.keys()):
|
| 110 |
+
if k.startswith("language_model"):
|
| 111 |
+
new_k = k.replace("language_model.", "", 1)
|
| 112 |
+
rename_dict[k] = new_k
|
| 113 |
+
if k.startswith("vision_model"):
|
| 114 |
+
new_k = k.replace("vision_model", "model.vision_tower", 1)
|
| 115 |
+
rename_dict[k] = new_k
|
| 116 |
+
|
| 117 |
+
# 应用重命名
|
| 118 |
+
for old_k, new_k in rename_dict.items():
|
| 119 |
+
partial_state_dict[new_k] = partial_state_dict.pop(old_k)
|
| 120 |
+
partial_state_dict = remap_keys(partial_state_dict, mapping)
|
| 121 |
+
|
| 122 |
+
whisper_state_dict = torch.load("/data1/cxy/model/THUdyh/Ola-7b/large-v3.pt", map_location='cpu')
|
| 123 |
+
# breakpoint()
|
| 124 |
+
whisper_state_dict = whisper_state_dict["model_state_dict"]
|
| 125 |
+
|
| 126 |
+
# Filter to keep only encoder weights
|
| 127 |
+
whisper_encoder_dict = {}
|
| 128 |
+
for key, value in whisper_state_dict.items():
|
| 129 |
+
if key.startswith('encoder.'):
|
| 130 |
+
whisper_encoder_dict[key] = value
|
| 131 |
+
|
| 132 |
+
print(f"Original Whisper keys: {len(whisper_state_dict)}")
|
| 133 |
+
print(f"Filtered encoder keys: {len(whisper_encoder_dict)}")
|
| 134 |
+
print("Sample encoder keys:")
|
| 135 |
+
for i, key in enumerate(list(whisper_encoder_dict.keys())[:5]):
|
| 136 |
+
print(f" {key}")
|
| 137 |
+
|
| 138 |
+
# Create mapping for Whisper parameters to OLA format
|
| 139 |
+
def create_whisper_mapping():
|
| 140 |
+
mapping = {}
|
| 141 |
+
|
| 142 |
+
# Base encoder components
|
| 143 |
+
base_mappings = {
|
| 144 |
+
'encoder.positional_embedding': 'model.speech_encoder.whisper_model.positional_embedding',
|
| 145 |
+
'encoder.conv1.weight': 'model.speech_encoder.whisper_model.conv1.weight',
|
| 146 |
+
'encoder.conv1.bias': 'model.speech_encoder.whisper_model.conv1.bias',
|
| 147 |
+
'encoder.conv2.weight': 'model.speech_encoder.whisper_model.conv2.weight',
|
| 148 |
+
'encoder.conv2.bias': 'model.speech_encoder.whisper_model.conv2.bias',
|
| 149 |
+
'encoder.ln_post.weight': 'model.speech_encoder.whisper_model.ln_post.weight',
|
| 150 |
+
'encoder.ln_post.bias': 'model.speech_encoder.whisper_model.ln_post.bias',
|
| 151 |
+
}
|
| 152 |
+
mapping.update(base_mappings)
|
| 153 |
+
|
| 154 |
+
# Encoder blocks (32 blocks: 0-31)
|
| 155 |
+
for block_idx in range(32):
|
| 156 |
+
# Attention components
|
| 157 |
+
attn_components = [
|
| 158 |
+
'attn.query.weight', 'attn.query.bias',
|
| 159 |
+
'attn.key.weight', 'attn.key.bias',
|
| 160 |
+
'attn.value.weight', 'attn.value.bias',
|
| 161 |
+
'attn.out.weight', 'attn.out.bias',
|
| 162 |
+
'attn_ln.weight', 'attn_ln.bias'
|
| 163 |
+
]
|
| 164 |
+
|
| 165 |
+
for component in attn_components:
|
| 166 |
+
source_key = f'encoder.blocks.{block_idx}.{component}'
|
| 167 |
+
target_key = f'model.speech_encoder.whisper_model.blocks.{block_idx}.{component}'
|
| 168 |
+
mapping[source_key] = target_key
|
| 169 |
+
|
| 170 |
+
# MLP components
|
| 171 |
+
mlp_components = [
|
| 172 |
+
'mlp.0.weight', 'mlp.0.bias',
|
| 173 |
+
'mlp.2.weight', 'mlp.2.bias',
|
| 174 |
+
'mlp_ln.weight', 'mlp_ln.bias'
|
| 175 |
+
]
|
| 176 |
+
|
| 177 |
+
for component in mlp_components:
|
| 178 |
+
source_key = f'encoder.blocks.{block_idx}.{component}'
|
| 179 |
+
target_key = f'model.speech_encoder.whisper_model.blocks.{block_idx}.{component}'
|
| 180 |
+
mapping[source_key] = target_key
|
| 181 |
+
|
| 182 |
+
return mapping
|
| 183 |
+
|
| 184 |
+
# Apply mapping to whisper_encoder_dict
|
| 185 |
+
whisper_mapping = create_whisper_mapping()
|
| 186 |
+
mapped_whisper_dict = {}
|
| 187 |
+
unmapped_whisper_keys = []
|
| 188 |
+
|
| 189 |
+
for key, value in whisper_encoder_dict.items():
|
| 190 |
+
if key in whisper_mapping:
|
| 191 |
+
mapped_key = whisper_mapping[key]
|
| 192 |
+
mapped_whisper_dict[mapped_key] = value
|
| 193 |
+
else:
|
| 194 |
+
unmapped_whisper_keys.append(key)
|
| 195 |
+
print(f"Warning: No mapping found for Whisper encoder key '{key}'")
|
| 196 |
+
|
| 197 |
+
if unmapped_whisper_keys:
|
| 198 |
+
print(f"Total unmapped Whisper encoder keys: {len(unmapped_whisper_keys)}")
|
| 199 |
+
print("First 10 unmapped Whisper encoder keys:")
|
| 200 |
+
for key in unmapped_whisper_keys[:10]:
|
| 201 |
+
print(f" {key}")
|
| 202 |
+
|
| 203 |
+
print(f"Successfully mapped {len(mapped_whisper_dict)} encoder parameters")
|
| 204 |
+
|
| 205 |
+
beat_state_dict = torch.load("/data1/cxy/model/THUdyh/Ola-7b//BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt", map_location='cpu')
|
| 206 |
+
beat_state_dict = beat_state_dict['model']
|
| 207 |
+
beat_state_dict = {"model.speech_encoder.beats_model."+k: v for k, v in beat_state_dict.items()}
|
| 208 |
+
|
| 209 |
+
# 处理 BEATs 模型中的参数化权重映射 (先pop后添加)
|
| 210 |
+
keys_to_process = list(beat_state_dict.keys())
|
| 211 |
+
breakpoint()
|
| 212 |
+
processed_count = 0
|
| 213 |
+
|
| 214 |
+
# for key in keys_to_process:
|
| 215 |
+
# if 'weight_g' in key:
|
| 216 |
+
# # pop 原始权重并添加为 weight_g
|
| 217 |
+
# weight_tensor = beat_state_dict.pop(key)
|
| 218 |
+
# new_key = key.replace('weight_g','parametrizations.weight.original0')
|
| 219 |
+
# beat_state_dict[new_key] = weight_tensor
|
| 220 |
+
# processed_count += 1
|
| 221 |
+
# elif 'weight_v' in key:
|
| 222 |
+
# # pop 原始权重并添加为 weight_v
|
| 223 |
+
# weight_tensor = beat_state_dict.pop(key)
|
| 224 |
+
# new_key = key.replace('weight_v', 'parametrizations.weight.original1')
|
| 225 |
+
# beat_state_dict[new_key] = weight_tensor
|
| 226 |
+
# processed_count += 1
|
| 227 |
+
|
| 228 |
+
print(f"Processed {processed_count} parametrized weight keys in BEATs model (pop and add)")
|
| 229 |
+
breakpoint()
|
| 230 |
+
# breakpoint()
|
| 231 |
+
partial_state_dict = {**partial_state_dict, **mapped_whisper_dict, **beat_state_dict}
|
| 232 |
+
|
| 233 |
+
# Ensure all tensors in the state dict are on CPU and have proper device information
|
| 234 |
+
print("Moving all state dict tensors to CPU...")
|
| 235 |
+
for key, tensor in partial_state_dict.items():
|
| 236 |
+
if torch.is_tensor(tensor):
|
| 237 |
+
# Ensure tensor has device information and move to CPU
|
| 238 |
+
if not tensor.device.type:
|
| 239 |
+
print(f"Warning: Tensor {key} has no device, creating on CPU")
|
| 240 |
+
partial_state_dict[key] = torch.tensor(tensor.detach().numpy()).cpu()
|
| 241 |
+
else:
|
| 242 |
+
partial_state_dict[key] = tensor.cpu()
|
| 243 |
+
|
| 244 |
+
# Ensure model is on CPU before loading state dict to avoid device mismatches
|
| 245 |
+
print("Moving model to CPU before loading state dict...")
|
| 246 |
+
model = model.cpu()
|
| 247 |
+
|
| 248 |
+
print("Loading state dict...")
|
| 249 |
+
breakpoint()
|
| 250 |
+
missing, unexpected = model.load_state_dict(partial_state_dict, strict=False, assign=True)
|
| 251 |
+
|
| 252 |
+
print("Missing keys:", missing)
|
| 253 |
+
print("Unexpected keys:", unexpected)
|
| 254 |
+
|
| 255 |
+
# Convert model to bfloat16 before saving
|
| 256 |
+
print("Converting model to bfloat16...")
|
| 257 |
+
model = model.to(torch.bfloat16)
|
| 258 |
+
model = model.to("cpu")
|
| 259 |
+
|
| 260 |
+
# Save model in bfloat16 format
|
| 261 |
+
print("Saving model in bfloat16 format...")
|
| 262 |
+
model.save_pretrained("/data1/cxy/plm-v/modeling/plm_internvl3_ola", safe_serialization=False, torch_dtype=torch.bfloat16)
|
| 263 |
+
print("Model saved successfully in bfloat16 format!")
|
| 264 |
+
breakpoint()
|
| 265 |
+
# model.model.mm_projector.linear_1.weight:-0.0106 multi_modal_projector.linear_1.weight model.mm_projector.linear_2.bias
|
| 266 |
+
# model.vision_tower.encoder.layers.7.attn.proj.bias
|
| 267 |
+
# model.model.vision_tower.encoder.layers[0].attn.qkv.weight: -6.5613e-03 dui
|
| 268 |
+
#
|
| 269 |
+
# breakpoint()
|
| 270 |
+
# model.get_model().speech_encoder.load_model("")
|
| 271 |
+
# language_model.model.layers.9.mlp.up_proj.weight vision_model.encoder.layers
|
| 272 |
+
# model.layers.14.self_attn.q_proj.weight model.vision_tower.encoder.layers.23.attn.proj.bias
|
| 273 |
+
# model.get_model().speech_encoder = build_speech_encoder(model.config)
|
| 274 |
+
# model.get_model().speech_encoder.to(device=device, dtype=torch.float16)
|
| 275 |
+
image_processor = None
|
| 276 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 277 |
+
vision_tower = model.get_vision_tower()
|
| 278 |
+
print("Loading vision tower...")
|
| 279 |
+
# if not vision_tower.is_loaded:
|
| 280 |
+
# vision_tower.load_model(device_map=device)
|
| 281 |
+
# if device != "auto":
|
| 282 |
+
# vision_tower.to(device="cuda", dtype=torch.bfloat16)
|
| 283 |
+
# else:
|
| 284 |
+
# vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
|
| 285 |
+
# image_processor = vision_tower.image_processor
|
| 286 |
+
print("Loading vision tower succeeded.")
|
| 287 |
+
|
| 288 |
+
if hasattr(model.config, "max_sequence_length"):
|
| 289 |
+
context_len = model.config.max_sequence_length
|
| 290 |
+
else:
|
| 291 |
+
context_len = 16384
|
| 292 |
+
image_processor = AutoProcessor.from_pretrained("/data1/cxy/plm-v/modeling/internvl3_5-2B-HF")
|
| 293 |
+
# breakpoint()
|
| 294 |
+
return tokenizer, model, image_processor, context_len
|
ola/model/language_model/__pycache__/conversation.cpython-312.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|
ola/model/language_model/__pycache__/ola_qwen.cpython-312.pyc
ADDED
|
Binary file (8.59 kB). View file
|
|
|
ola/model/language_model/__pycache__/ola_qwen3.cpython-312.pyc
ADDED
|
Binary file (24.7 kB). View file
|
|
|
ola/model/language_model/conversation.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Conversation prompt templates.
|
| 3 |
+
|
| 4 |
+
We kindly request that you import fastchat instead of copying this file if you wish to use it.
|
| 5 |
+
If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
|
| 6 |
+
|
| 7 |
+
Modified from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import dataclasses
|
| 11 |
+
from enum import IntEnum, auto
|
| 12 |
+
from typing import Dict, List, Tuple, Union
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SeparatorStyle(IntEnum):
|
| 16 |
+
"""Separator styles."""
|
| 17 |
+
|
| 18 |
+
ADD_COLON_SINGLE = auto()
|
| 19 |
+
ADD_COLON_TWO = auto()
|
| 20 |
+
ADD_COLON_SPACE_SINGLE = auto()
|
| 21 |
+
NO_COLON_SINGLE = auto()
|
| 22 |
+
NO_COLON_TWO = auto()
|
| 23 |
+
ADD_NEW_LINE_SINGLE = auto()
|
| 24 |
+
LLAMA2 = auto()
|
| 25 |
+
CHATGLM = auto()
|
| 26 |
+
CHATML = auto()
|
| 27 |
+
CHATINTERN = auto()
|
| 28 |
+
DOLLY = auto()
|
| 29 |
+
RWKV = auto()
|
| 30 |
+
PHOENIX = auto()
|
| 31 |
+
ROBIN = auto()
|
| 32 |
+
FALCON_CHAT = auto()
|
| 33 |
+
CHATGLM3 = auto()
|
| 34 |
+
INTERNVL_ZH = auto()
|
| 35 |
+
MPT = auto()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclasses.dataclass
|
| 39 |
+
class Conversation:
|
| 40 |
+
"""A class that manages prompt templates and keeps all conversation history."""
|
| 41 |
+
|
| 42 |
+
# The name of this template
|
| 43 |
+
name: str
|
| 44 |
+
# The template of the system prompt
|
| 45 |
+
system_template: str = '{system_message}'
|
| 46 |
+
# The system message
|
| 47 |
+
system_message: str = ''
|
| 48 |
+
# The names of two roles
|
| 49 |
+
roles: Tuple[str] = ('USER', 'ASSISTANT')
|
| 50 |
+
# All messages. Each item is (role, message).
|
| 51 |
+
messages: List[List[str]] = ()
|
| 52 |
+
# The number of few shot examples
|
| 53 |
+
offset: int = 0
|
| 54 |
+
# The separator style and configurations
|
| 55 |
+
sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
|
| 56 |
+
sep: str = '\n'
|
| 57 |
+
sep2: str = None
|
| 58 |
+
# Stop criteria (the default one is EOS token)
|
| 59 |
+
stop_str: Union[str, List[str]] = None
|
| 60 |
+
# Stops generation if meeting any token in this list
|
| 61 |
+
stop_token_ids: List[int] = None
|
| 62 |
+
|
| 63 |
+
def get_prompt(self) -> str:
|
| 64 |
+
"""Get the prompt for generation."""
|
| 65 |
+
system_prompt = self.system_template.format(system_message=self.system_message)
|
| 66 |
+
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
|
| 67 |
+
ret = system_prompt + self.sep
|
| 68 |
+
for role, message in self.messages:
|
| 69 |
+
if message:
|
| 70 |
+
ret += role + ': ' + message + self.sep
|
| 71 |
+
else:
|
| 72 |
+
ret += role + ':'
|
| 73 |
+
return ret
|
| 74 |
+
elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
|
| 75 |
+
seps = [self.sep, self.sep2]
|
| 76 |
+
ret = system_prompt + seps[0]
|
| 77 |
+
for i, (role, message) in enumerate(self.messages):
|
| 78 |
+
if message:
|
| 79 |
+
ret += role + ': ' + message + seps[i % 2]
|
| 80 |
+
else:
|
| 81 |
+
ret += role + ':'
|
| 82 |
+
return ret
|
| 83 |
+
elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
|
| 84 |
+
ret = system_prompt + self.sep
|
| 85 |
+
for role, message in self.messages:
|
| 86 |
+
if message:
|
| 87 |
+
ret += role + ': ' + message + self.sep
|
| 88 |
+
else:
|
| 89 |
+
ret += role + ': ' # must be end with a space
|
| 90 |
+
return ret
|
| 91 |
+
elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
|
| 92 |
+
ret = '' if system_prompt == '' else system_prompt + self.sep
|
| 93 |
+
for role, message in self.messages:
|
| 94 |
+
if message:
|
| 95 |
+
ret += role + '\n' + message + self.sep
|
| 96 |
+
else:
|
| 97 |
+
ret += role + '\n'
|
| 98 |
+
return ret
|
| 99 |
+
elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
| 100 |
+
ret = system_prompt
|
| 101 |
+
for role, message in self.messages:
|
| 102 |
+
if message:
|
| 103 |
+
ret += role + message + self.sep
|
| 104 |
+
else:
|
| 105 |
+
ret += role
|
| 106 |
+
return ret
|
| 107 |
+
elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
|
| 108 |
+
seps = [self.sep, self.sep2]
|
| 109 |
+
ret = system_prompt
|
| 110 |
+
for i, (role, message) in enumerate(self.messages):
|
| 111 |
+
if message:
|
| 112 |
+
ret += role + message + seps[i % 2]
|
| 113 |
+
else:
|
| 114 |
+
ret += role
|
| 115 |
+
return ret
|
| 116 |
+
elif self.sep_style == SeparatorStyle.RWKV:
|
| 117 |
+
ret = system_prompt
|
| 118 |
+
for i, (role, message) in enumerate(self.messages):
|
| 119 |
+
if message:
|
| 120 |
+
ret += (
|
| 121 |
+
role
|
| 122 |
+
+ ': '
|
| 123 |
+
+ message.replace('\r\n', '\n').replace('\n\n', '\n')
|
| 124 |
+
)
|
| 125 |
+
ret += '\n\n'
|
| 126 |
+
else:
|
| 127 |
+
ret += role + ':'
|
| 128 |
+
return ret
|
| 129 |
+
elif self.sep_style == SeparatorStyle.LLAMA2:
|
| 130 |
+
seps = [self.sep, self.sep2]
|
| 131 |
+
if self.system_message:
|
| 132 |
+
ret = system_prompt
|
| 133 |
+
else:
|
| 134 |
+
ret = '[INST] '
|
| 135 |
+
for i, (role, message) in enumerate(self.messages):
|
| 136 |
+
tag = self.roles[i % 2]
|
| 137 |
+
if message:
|
| 138 |
+
if i == 0:
|
| 139 |
+
ret += message + ' '
|
| 140 |
+
else:
|
| 141 |
+
ret += tag + ' ' + message + seps[i % 2]
|
| 142 |
+
else:
|
| 143 |
+
ret += tag
|
| 144 |
+
return ret
|
| 145 |
+
elif self.sep_style == SeparatorStyle.CHATGLM:
|
| 146 |
+
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
| 147 |
+
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
| 148 |
+
round_add_n = 1 if self.name == 'chatglm2' else 0
|
| 149 |
+
if system_prompt:
|
| 150 |
+
ret = system_prompt + self.sep
|
| 151 |
+
else:
|
| 152 |
+
ret = ''
|
| 153 |
+
|
| 154 |
+
for i, (role, message) in enumerate(self.messages):
|
| 155 |
+
if i % 2 == 0:
|
| 156 |
+
ret += f'[Round {i//2 + round_add_n}]{self.sep}'
|
| 157 |
+
|
| 158 |
+
if message:
|
| 159 |
+
ret += f'{role}:{message}{self.sep}'
|
| 160 |
+
else:
|
| 161 |
+
ret += f'{role}:'
|
| 162 |
+
return ret
|
| 163 |
+
elif self.sep_style == SeparatorStyle.CHATML:
|
| 164 |
+
ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
|
| 165 |
+
for role, message in self.messages:
|
| 166 |
+
if message:
|
| 167 |
+
ret += role + '\n' + message + self.sep + '\n'
|
| 168 |
+
else:
|
| 169 |
+
ret += role + '\n'
|
| 170 |
+
return ret
|
| 171 |
+
elif self.sep_style == SeparatorStyle.CHATGLM3:
|
| 172 |
+
ret = ''
|
| 173 |
+
if self.system_message:
|
| 174 |
+
ret += system_prompt
|
| 175 |
+
for role, message in self.messages:
|
| 176 |
+
if message:
|
| 177 |
+
ret += role + '\n' + ' ' + message
|
| 178 |
+
else:
|
| 179 |
+
ret += role
|
| 180 |
+
return ret
|
| 181 |
+
elif self.sep_style == SeparatorStyle.CHATINTERN:
|
| 182 |
+
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
|
| 183 |
+
seps = [self.sep, self.sep2]
|
| 184 |
+
ret = system_prompt
|
| 185 |
+
for i, (role, message) in enumerate(self.messages):
|
| 186 |
+
# if i % 2 == 0:
|
| 187 |
+
# ret += "<s>"
|
| 188 |
+
if message:
|
| 189 |
+
ret += role + ':' + message + seps[i % 2] + '\n'
|
| 190 |
+
else:
|
| 191 |
+
ret += role + ':'
|
| 192 |
+
return ret
|
| 193 |
+
elif self.sep_style == SeparatorStyle.DOLLY:
|
| 194 |
+
seps = [self.sep, self.sep2]
|
| 195 |
+
ret = system_prompt
|
| 196 |
+
for i, (role, message) in enumerate(self.messages):
|
| 197 |
+
if message:
|
| 198 |
+
ret += role + ':\n' + message + seps[i % 2]
|
| 199 |
+
if i % 2 == 1:
|
| 200 |
+
ret += '\n\n'
|
| 201 |
+
else:
|
| 202 |
+
ret += role + ':\n'
|
| 203 |
+
return ret
|
| 204 |
+
elif self.sep_style == SeparatorStyle.PHOENIX:
|
| 205 |
+
ret = system_prompt
|
| 206 |
+
for role, message in self.messages:
|
| 207 |
+
if message:
|
| 208 |
+
ret += role + ': ' + '<s>' + message + '</s>'
|
| 209 |
+
else:
|
| 210 |
+
ret += role + ': ' + '<s>'
|
| 211 |
+
return ret
|
| 212 |
+
elif self.sep_style == SeparatorStyle.ROBIN:
|
| 213 |
+
ret = system_prompt + self.sep
|
| 214 |
+
for role, message in self.messages:
|
| 215 |
+
if message:
|
| 216 |
+
ret += role + ':\n' + message + self.sep
|
| 217 |
+
else:
|
| 218 |
+
ret += role + ':\n'
|
| 219 |
+
return ret
|
| 220 |
+
elif self.sep_style == SeparatorStyle.FALCON_CHAT:
|
| 221 |
+
ret = ''
|
| 222 |
+
if self.system_message:
|
| 223 |
+
ret += system_prompt + self.sep
|
| 224 |
+
for role, message in self.messages:
|
| 225 |
+
if message:
|
| 226 |
+
ret += role + ': ' + message + self.sep
|
| 227 |
+
else:
|
| 228 |
+
ret += role + ':'
|
| 229 |
+
|
| 230 |
+
return ret
|
| 231 |
+
elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
|
| 232 |
+
seps = [self.sep, self.sep2]
|
| 233 |
+
ret = self.system_message + seps[0]
|
| 234 |
+
for i, (role, message) in enumerate(self.messages):
|
| 235 |
+
if message:
|
| 236 |
+
ret += role + ': ' + message + seps[i % 2]
|
| 237 |
+
else:
|
| 238 |
+
ret += role + ':'
|
| 239 |
+
return ret
|
| 240 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
| 241 |
+
ret = system_prompt + self.sep
|
| 242 |
+
for role, message in self.messages:
|
| 243 |
+
if message:
|
| 244 |
+
if type(message) is tuple:
|
| 245 |
+
message, _, _ = message
|
| 246 |
+
ret += role + message + self.sep
|
| 247 |
+
else:
|
| 248 |
+
ret += role
|
| 249 |
+
return ret
|
| 250 |
+
else:
|
| 251 |
+
raise ValueError(f'Invalid style: {self.sep_style}')
|
| 252 |
+
|
| 253 |
+
def set_system_message(self, system_message: str):
|
| 254 |
+
"""Set the system message."""
|
| 255 |
+
self.system_message = system_message
|
| 256 |
+
|
| 257 |
+
def append_message(self, role: str, message: str):
|
| 258 |
+
"""Append a new message."""
|
| 259 |
+
self.messages.append([role, message])
|
| 260 |
+
|
| 261 |
+
def update_last_message(self, message: str):
|
| 262 |
+
"""Update the last output.
|
| 263 |
+
|
| 264 |
+
The last message is typically set to be None when constructing the prompt,
|
| 265 |
+
so we need to update it in-place after getting the response from a model.
|
| 266 |
+
"""
|
| 267 |
+
self.messages[-1][1] = message
|
| 268 |
+
|
| 269 |
+
def to_gradio_chatbot(self):
|
| 270 |
+
"""Convert the conversation to gradio chatbot format."""
|
| 271 |
+
ret = []
|
| 272 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
| 273 |
+
if i % 2 == 0:
|
| 274 |
+
ret.append([msg, None])
|
| 275 |
+
else:
|
| 276 |
+
ret[-1][-1] = msg
|
| 277 |
+
return ret
|
| 278 |
+
|
| 279 |
+
def to_openai_api_messages(self):
|
| 280 |
+
"""Convert the conversation to OpenAI chat completion format."""
|
| 281 |
+
ret = [{'role': 'system', 'content': self.system_message}]
|
| 282 |
+
|
| 283 |
+
for i, (_, msg) in enumerate(self.messages[self.offset :]):
|
| 284 |
+
if i % 2 == 0:
|
| 285 |
+
ret.append({'role': 'user', 'content': msg})
|
| 286 |
+
else:
|
| 287 |
+
if msg is not None:
|
| 288 |
+
ret.append({'role': 'assistant', 'content': msg})
|
| 289 |
+
return ret
|
| 290 |
+
|
| 291 |
+
def copy(self):
|
| 292 |
+
return Conversation(
|
| 293 |
+
name=self.name,
|
| 294 |
+
system_template=self.system_template,
|
| 295 |
+
system_message=self.system_message,
|
| 296 |
+
roles=self.roles,
|
| 297 |
+
messages=[[x, y] for x, y in self.messages],
|
| 298 |
+
offset=self.offset,
|
| 299 |
+
sep_style=self.sep_style,
|
| 300 |
+
sep=self.sep,
|
| 301 |
+
sep2=self.sep2,
|
| 302 |
+
stop_str=self.stop_str,
|
| 303 |
+
stop_token_ids=self.stop_token_ids,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
def dict(self):
|
| 307 |
+
return {
|
| 308 |
+
'template_name': self.name,
|
| 309 |
+
'system_message': self.system_message,
|
| 310 |
+
'roles': self.roles,
|
| 311 |
+
'messages': self.messages,
|
| 312 |
+
'offset': self.offset,
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# A global registry for all conversation templates
|
| 317 |
+
conv_templates: Dict[str, Conversation] = {}
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def register_conv_template(template: Conversation, override: bool = False):
|
| 321 |
+
"""Register a new conversation template."""
|
| 322 |
+
if not override:
|
| 323 |
+
assert (
|
| 324 |
+
template.name not in conv_templates
|
| 325 |
+
), f'{template.name} has been registered.'
|
| 326 |
+
|
| 327 |
+
conv_templates[template.name] = template
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def get_conv_template(name: str) -> Conversation:
|
| 331 |
+
"""Get a conversation template."""
|
| 332 |
+
# breakpoint()
|
| 333 |
+
return conv_templates[name].copy()
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
# Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference
|
| 337 |
+
# is that during training, the preprocessing function for the Hermes-2 template doesn't add
|
| 338 |
+
# <s> at the beginning of the tokenized sequence, while the internlm2-chat template does.
|
| 339 |
+
# Therefore, they are completely equivalent during inference.
|
| 340 |
+
register_conv_template(
|
| 341 |
+
Conversation(
|
| 342 |
+
name='Hermes-2',
|
| 343 |
+
system_template='<|im_start|>system\n{system_message}',
|
| 344 |
+
# note: The new system prompt was not used here to avoid changes in benchmark performance.
|
| 345 |
+
# system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
|
| 346 |
+
system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
|
| 347 |
+
roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
|
| 348 |
+
sep_style=SeparatorStyle.MPT,
|
| 349 |
+
sep='<|im_end|>',
|
| 350 |
+
stop_str='<|endoftext|>',
|
| 351 |
+
)
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
register_conv_template(
|
| 356 |
+
Conversation(
|
| 357 |
+
name='internlm2-chat',
|
| 358 |
+
system_template='<|im_start|>system\n{system_message}',
|
| 359 |
+
# note: The new system prompt was not used here to avoid changes in benchmark performance.
|
| 360 |
+
# system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
|
| 361 |
+
system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
|
| 362 |
+
roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
|
| 363 |
+
sep_style=SeparatorStyle.MPT,
|
| 364 |
+
sep='<|im_end|>',
|
| 365 |
+
)
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
register_conv_template(
|
| 370 |
+
Conversation(
|
| 371 |
+
name='phi3-chat',
|
| 372 |
+
system_template='<|system|>\n{system_message}',
|
| 373 |
+
# note: The new system prompt was not used here to avoid changes in benchmark performance.
|
| 374 |
+
# system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
|
| 375 |
+
system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
|
| 376 |
+
roles=('<|user|>\n', '<|assistant|>\n'),
|
| 377 |
+
sep_style=SeparatorStyle.MPT,
|
| 378 |
+
sep='<|end|>',
|
| 379 |
+
)
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
register_conv_template(
|
| 384 |
+
Conversation(
|
| 385 |
+
name='internvl2_5',
|
| 386 |
+
system_template='<|im_start|>system\n{system_message}',
|
| 387 |
+
system_message='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
|
| 388 |
+
roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
|
| 389 |
+
sep_style=SeparatorStyle.MPT,
|
| 390 |
+
sep='<|im_end|>\n',
|
| 391 |
+
)
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
register_conv_template(
|
| 395 |
+
Conversation(
|
| 396 |
+
name='plm_v',
|
| 397 |
+
system_template='<|im_start|>system\n{system_message}',
|
| 398 |
+
system_message='You are PLM-V, developed by PLM-Team, a helpful assistant.',
|
| 399 |
+
roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
|
| 400 |
+
sep_style=SeparatorStyle.MPT,
|
| 401 |
+
sep='<|im_end|>\n',
|
| 402 |
+
)
|
| 403 |
+
)
|
ola/model/language_model/ola_qwen.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
import transformers
|
| 7 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 11 |
+
from transformers.generation.utils import GenerateOutput
|
| 12 |
+
|
| 13 |
+
from ..ola_arch import OlaMetaModel, OlaMetaForCausalLM
|
| 14 |
+
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class OlaConfigQwen(Qwen2Config):
|
| 18 |
+
model_type = "ola_qwen"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class OlaQwenModel(OlaMetaModel, Qwen2Model):
|
| 22 |
+
config_class = OlaConfigQwen
|
| 23 |
+
|
| 24 |
+
def __init__(self, config: Qwen2Config):
|
| 25 |
+
super(OlaQwenModel, self).__init__(config)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class OlaQwenForCausalLM(Qwen2ForCausalLM, OlaMetaForCausalLM):
|
| 29 |
+
config_class = OlaConfigQwen
|
| 30 |
+
|
| 31 |
+
def __init__(self, config):
|
| 32 |
+
super(Qwen2ForCausalLM, self).__init__(config)
|
| 33 |
+
|
| 34 |
+
config.rope_scaling = None
|
| 35 |
+
self.model = OlaQwenModel(config)
|
| 36 |
+
self.vocab_size = config.vocab_size
|
| 37 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 38 |
+
|
| 39 |
+
# Initialize weights and apply final processing
|
| 40 |
+
self.post_init()
|
| 41 |
+
|
| 42 |
+
def get_model(self):
|
| 43 |
+
return self.model
|
| 44 |
+
|
| 45 |
+
def forward(
|
| 46 |
+
self,
|
| 47 |
+
input_ids: torch.LongTensor = None,
|
| 48 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 49 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 50 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 51 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 52 |
+
labels: Optional[torch.LongTensor] = None,
|
| 53 |
+
use_cache: Optional[bool] = None,
|
| 54 |
+
output_attentions: Optional[bool] = None,
|
| 55 |
+
output_hidden_states: Optional[bool] = None,
|
| 56 |
+
speech: Optional[torch.FloatTensor] = None,
|
| 57 |
+
speech_lengths: Optional[torch.LongTensor] = None,
|
| 58 |
+
speech_chunks: Optional[torch.LongTensor] = None,
|
| 59 |
+
speech_wav: Optional[torch.FloatTensor] = None,
|
| 60 |
+
images: Optional[torch.FloatTensor] = None,
|
| 61 |
+
images_highres: Optional[List[torch.FloatTensor]] = None,
|
| 62 |
+
image_sizes: Optional[List[List[int]]] = None,
|
| 63 |
+
modalities: Optional[List[str]] = ["image"],
|
| 64 |
+
return_dict: Optional[bool] = None,
|
| 65 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 66 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 67 |
+
|
| 68 |
+
if inputs_embeds is None:
|
| 69 |
+
(
|
| 70 |
+
input_ids,
|
| 71 |
+
position_ids,
|
| 72 |
+
attention_mask,
|
| 73 |
+
past_key_values,
|
| 74 |
+
inputs_embeds,
|
| 75 |
+
labels
|
| 76 |
+
) = self.prepare_inputs_labels_for_speech_vision_text(
|
| 77 |
+
input_ids,
|
| 78 |
+
position_ids,
|
| 79 |
+
attention_mask,
|
| 80 |
+
past_key_values,
|
| 81 |
+
labels,
|
| 82 |
+
speech,
|
| 83 |
+
speech_lengths,
|
| 84 |
+
speech_chunks,
|
| 85 |
+
speech_wav,
|
| 86 |
+
images,
|
| 87 |
+
modalities,
|
| 88 |
+
image_sizes,
|
| 89 |
+
images_highres
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
if labels is None:
|
| 93 |
+
return super().forward(
|
| 94 |
+
input_ids=input_ids,
|
| 95 |
+
attention_mask=attention_mask,
|
| 96 |
+
position_ids=position_ids,
|
| 97 |
+
past_key_values=past_key_values,
|
| 98 |
+
inputs_embeds=inputs_embeds,
|
| 99 |
+
use_cache=use_cache,
|
| 100 |
+
output_attentions=output_attentions,
|
| 101 |
+
output_hidden_states=output_hidden_states,
|
| 102 |
+
return_dict=return_dict
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
return self.forward_llm_efficient(
|
| 106 |
+
input_ids=input_ids,
|
| 107 |
+
attention_mask=attention_mask,
|
| 108 |
+
position_ids=position_ids,
|
| 109 |
+
past_key_values=past_key_values,
|
| 110 |
+
inputs_embeds=inputs_embeds,
|
| 111 |
+
labels=labels,
|
| 112 |
+
use_cache=use_cache,
|
| 113 |
+
output_attentions=output_attentions,
|
| 114 |
+
output_hidden_states=output_hidden_states,
|
| 115 |
+
return_dict=return_dict
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict):
|
| 120 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 121 |
+
output_hidden_states = (
|
| 122 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 123 |
+
)
|
| 124 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 125 |
+
|
| 126 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 127 |
+
outputs = self.model(
|
| 128 |
+
input_ids=input_ids,
|
| 129 |
+
attention_mask=attention_mask,
|
| 130 |
+
position_ids=position_ids,
|
| 131 |
+
past_key_values=past_key_values,
|
| 132 |
+
inputs_embeds=inputs_embeds,
|
| 133 |
+
use_cache=use_cache,
|
| 134 |
+
output_attentions=output_attentions,
|
| 135 |
+
output_hidden_states=output_hidden_states,
|
| 136 |
+
return_dict=return_dict,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
hidden_states = outputs[0]
|
| 140 |
+
hidden_dim = hidden_states.size(-1)
|
| 141 |
+
shift_labels = labels[..., 1:].contiguous().reshape(-1)
|
| 142 |
+
shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
|
| 143 |
+
assert shift_labels.size(0) == shift_hidden_states.size(0)
|
| 144 |
+
mask = shift_labels > -1
|
| 145 |
+
assert mask.float().sum() > 0
|
| 146 |
+
shift_labels = shift_labels[mask]
|
| 147 |
+
shift_hidden_states = shift_hidden_states[mask, :]
|
| 148 |
+
logits = self.lm_head(shift_hidden_states)
|
| 149 |
+
logits = logits.float()
|
| 150 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 151 |
+
loss = loss_fct(logits, shift_labels)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
if not return_dict:
|
| 155 |
+
output = (logits,) + outputs[1:]
|
| 156 |
+
return (loss,) + output if loss is not None else output
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
return CausalLMOutputWithPast(
|
| 160 |
+
loss=loss,
|
| 161 |
+
logits=logits,
|
| 162 |
+
past_key_values=outputs.past_key_values,
|
| 163 |
+
hidden_states=outputs.hidden_states,
|
| 164 |
+
attentions=outputs.attentions,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
@torch.no_grad()
|
| 168 |
+
def generate(
|
| 169 |
+
self,
|
| 170 |
+
inputs: Optional[torch.Tensor] = None,
|
| 171 |
+
speech: Optional[torch.Tensor] = None,
|
| 172 |
+
speech_lengths: Optional[torch.Tensor] = None,
|
| 173 |
+
speech_chunks: Optional[torch.Tensor] = None,
|
| 174 |
+
speech_wav: Optional[torch.FloatTensor] = None,
|
| 175 |
+
images: Optional[torch.Tensor] = None,
|
| 176 |
+
images_highres: Optional[List[torch.FloatTensor]] = None,
|
| 177 |
+
image_sizes: Optional[torch.Tensor] = None,
|
| 178 |
+
modalities: Optional[List[str]] = ["image"],
|
| 179 |
+
**kwargs,
|
| 180 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
| 181 |
+
position_ids = kwargs.pop("position_ids", None)
|
| 182 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
| 183 |
+
if "inputs_embeds" in kwargs:
|
| 184 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
| 185 |
+
|
| 186 |
+
(
|
| 187 |
+
inputs,
|
| 188 |
+
position_ids,
|
| 189 |
+
attention_mask,
|
| 190 |
+
_,
|
| 191 |
+
inputs_embeds,
|
| 192 |
+
_
|
| 193 |
+
) = self.prepare_inputs_labels_for_speech_vision_text(
|
| 194 |
+
inputs,
|
| 195 |
+
position_ids,
|
| 196 |
+
attention_mask,
|
| 197 |
+
None,
|
| 198 |
+
None,
|
| 199 |
+
speech,
|
| 200 |
+
speech_lengths,
|
| 201 |
+
speech_chunks,
|
| 202 |
+
speech_wav,
|
| 203 |
+
images,
|
| 204 |
+
modalities,
|
| 205 |
+
image_sizes,
|
| 206 |
+
images_highres
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
return super().generate(
|
| 210 |
+
position_ids=position_ids,
|
| 211 |
+
attention_mask=attention_mask,
|
| 212 |
+
inputs_embeds=inputs_embeds,
|
| 213 |
+
**kwargs
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
| 217 |
+
inputs_embeds=None, **kwargs):
|
| 218 |
+
speech = kwargs.pop("speech", None)
|
| 219 |
+
speech_lengths = kwargs.pop("speech_lengths", None)
|
| 220 |
+
speech_chunks = kwargs.pop("speech_chunks", None)
|
| 221 |
+
images = kwargs.pop("images", None)
|
| 222 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
| 223 |
+
inputs = super().prepare_inputs_for_generation(
|
| 224 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
| 225 |
+
)
|
| 226 |
+
if speech is not None:
|
| 227 |
+
inputs['speech'] = speech
|
| 228 |
+
inputs['speech_lengths'] = speech_lengths
|
| 229 |
+
inputs['speech_chunks'] = speech_chunks
|
| 230 |
+
if images is not None:
|
| 231 |
+
inputs["images"] = images
|
| 232 |
+
if image_sizes is not None:
|
| 233 |
+
inputs["image_sizes"] = image_sizes
|
| 234 |
+
return inputs
|
| 235 |
+
|
| 236 |
+
AutoConfig.register("ola_qwen", OlaConfigQwen)
|
| 237 |
+
AutoModelForCausalLM.register(OlaConfigQwen, OlaQwenForCausalLM)
|
ola/model/language_model/ola_qwen3.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
import transformers
|
| 7 |
+
from transformers import GenerationConfig
|
| 8 |
+
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
|
| 9 |
+
SPEECH_TOKEN_INDEX = -200
|
| 10 |
+
|
| 11 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 12 |
+
from transformers.generation.utils import GenerateOutput
|
| 13 |
+
|
| 14 |
+
from ..ola_arch import OlaMetaModel, OlaMetaForCausalLM
|
| 15 |
+
from transformers import Qwen3Config, Qwen3Model, Qwen3ForCausalLM
|
| 16 |
+
from .conversation import get_conv_template
|
| 17 |
+
from ola.constants import IGNORE_INDEX
|
| 18 |
+
|
| 19 |
+
def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
|
| 20 |
+
"""Tokenize prompt with speech tokens, similar to OLA's implementation"""
|
| 21 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech>')]
|
| 22 |
+
|
| 23 |
+
def insert_separator(X, sep):
|
| 24 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
| 25 |
+
|
| 26 |
+
input_ids = []
|
| 27 |
+
offset = 0
|
| 28 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
| 29 |
+
offset = 1
|
| 30 |
+
input_ids.append(prompt_chunks[0][0])
|
| 31 |
+
|
| 32 |
+
for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)):
|
| 33 |
+
input_ids.extend(x[offset:])
|
| 34 |
+
|
| 35 |
+
if return_tensors is not None:
|
| 36 |
+
if return_tensors == 'pt':
|
| 37 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
| 38 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
| 39 |
+
return input_ids
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Qwen3Model(Qwen3Model):
|
| 43 |
+
def __init__(self, config: Qwen3Config, llm_config: Qwen3Config):
|
| 44 |
+
# breakpoint()
|
| 45 |
+
super(Qwen3Model, self).__init__(llm_config)
|
| 46 |
+
|
| 47 |
+
class OlaConfigQwen3(Qwen3Config, PretrainedConfig):
|
| 48 |
+
model_type = "ola_internvl"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class OlaQwen3Model(OlaMetaModel, Qwen3Model):
|
| 52 |
+
config_class = OlaConfigQwen3
|
| 53 |
+
|
| 54 |
+
def __init__(self, config: Qwen3Config):
|
| 55 |
+
|
| 56 |
+
super(OlaQwen3Model, self).__init__(config, config.llm_config)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class OlaQwen3ForCausalLM(Qwen3ForCausalLM, OlaMetaForCausalLM):
|
| 60 |
+
config_class = OlaConfigQwen3
|
| 61 |
+
# 从零初始化时不需要 checkpoint conversion mapping
|
| 62 |
+
# _checkpoint_conversion_mapping = {
|
| 63 |
+
# "^language_model.lm_head": "lm_head",
|
| 64 |
+
# "^language_model.model": "model.model",
|
| 65 |
+
# "^vision_model": "model.vision_tower",
|
| 66 |
+
# }
|
| 67 |
+
# model.model.embed_tokens:
|
| 68 |
+
def __init__(self, config):
|
| 69 |
+
super(Qwen3ForCausalLM, self).__init__(config)
|
| 70 |
+
|
| 71 |
+
config.rope_scaling = None
|
| 72 |
+
# breakpoint()
|
| 73 |
+
self.model = OlaQwen3Model(config)
|
| 74 |
+
self.vocab_size = config.vocab_size
|
| 75 |
+
# breakpoint()
|
| 76 |
+
self.ps_version = config.ps_version
|
| 77 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 78 |
+
self.template = "plm_v"
|
| 79 |
+
self.select_layer = config.select_layer
|
| 80 |
+
self.conv_template = get_conv_template(self.template)
|
| 81 |
+
self.system_message = self.conv_template.system_message
|
| 82 |
+
self.num_image_token = int((config.vision_config.image_size // config.vision_config.patch_size) ** 2 * (config.downsample_ratio ** 2))
|
| 83 |
+
self.downsample_ratio = config.downsample_ratio
|
| 84 |
+
# Initialize weights and apply final processing
|
| 85 |
+
self.post_init()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def get_model(self):
|
| 89 |
+
return self.model
|
| 90 |
+
|
| 91 |
+
def forward(
|
| 92 |
+
self,
|
| 93 |
+
input_ids: torch.LongTensor = None,
|
| 94 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 95 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 96 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 97 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 98 |
+
labels: Optional[torch.LongTensor] = None,
|
| 99 |
+
use_cache: Optional[bool] = None,
|
| 100 |
+
output_attentions: Optional[bool] = None,
|
| 101 |
+
output_hidden_states: Optional[bool] = None,
|
| 102 |
+
speech: Optional[torch.FloatTensor] = None,
|
| 103 |
+
speech_lengths: Optional[torch.LongTensor] = None,
|
| 104 |
+
speech_chunks: Optional[torch.LongTensor] = None,
|
| 105 |
+
speech_wav: Optional[torch.FloatTensor] = None,
|
| 106 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 107 |
+
images_highres: Optional[List[torch.FloatTensor]] = None,
|
| 108 |
+
image_sizes: Optional[List[List[int]]] = None,
|
| 109 |
+
modalities: Optional[List[str]] = ["image"],
|
| 110 |
+
image_flags: Optional[torch.LongTensor] = None,
|
| 111 |
+
return_dict: Optional[bool] = None,
|
| 112 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 113 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 114 |
+
# breakpoint()
|
| 115 |
+
if inputs_embeds is None:
|
| 116 |
+
(
|
| 117 |
+
input_ids,
|
| 118 |
+
position_ids,
|
| 119 |
+
attention_mask,
|
| 120 |
+
past_key_values,
|
| 121 |
+
inputs_embeds,
|
| 122 |
+
labels
|
| 123 |
+
) = self.prepare_inputs_labels_for_speech_text_for_internvl(
|
| 124 |
+
input_ids,
|
| 125 |
+
position_ids,
|
| 126 |
+
attention_mask,
|
| 127 |
+
past_key_values,
|
| 128 |
+
labels,
|
| 129 |
+
speech,
|
| 130 |
+
speech_lengths,
|
| 131 |
+
speech_chunks,
|
| 132 |
+
speech_wav,
|
| 133 |
+
modalities,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
if labels is None:
|
| 137 |
+
return super().forward(
|
| 138 |
+
input_ids=input_ids,
|
| 139 |
+
attention_mask=attention_mask,
|
| 140 |
+
position_ids=position_ids,
|
| 141 |
+
past_key_values=past_key_values,
|
| 142 |
+
inputs_embeds=inputs_embeds,
|
| 143 |
+
use_cache=use_cache,
|
| 144 |
+
output_attentions=output_attentions,
|
| 145 |
+
output_hidden_states=output_hidden_states,
|
| 146 |
+
return_dict=return_dict
|
| 147 |
+
)
|
| 148 |
+
else:
|
| 149 |
+
return self.forward_llm_efficient(
|
| 150 |
+
input_ids=input_ids,
|
| 151 |
+
attention_mask=attention_mask,
|
| 152 |
+
position_ids=position_ids,
|
| 153 |
+
past_key_values=past_key_values,
|
| 154 |
+
inputs_embeds=inputs_embeds,
|
| 155 |
+
labels=labels,
|
| 156 |
+
use_cache=use_cache,
|
| 157 |
+
output_attentions=output_attentions,
|
| 158 |
+
output_hidden_states=output_hidden_states,
|
| 159 |
+
return_dict=return_dict
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict):
|
| 164 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 165 |
+
output_hidden_states = (
|
| 166 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 167 |
+
)
|
| 168 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 169 |
+
|
| 170 |
+
# Check inputs before model forward
|
| 171 |
+
print(f"Debug - Input embeddings range: {inputs_embeds.min().item()} to {inputs_embeds.max().item()}")
|
| 172 |
+
print(f"Debug - Input embeddings has nan: {torch.isnan(inputs_embeds).any().item()}")
|
| 173 |
+
print(f"Debug - Input embeddings has inf: {torch.isinf(inputs_embeds).any().item()}")
|
| 174 |
+
|
| 175 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 176 |
+
outputs = self.model(
|
| 177 |
+
input_ids=input_ids,
|
| 178 |
+
attention_mask=attention_mask,
|
| 179 |
+
position_ids=position_ids,
|
| 180 |
+
past_key_values=past_key_values,
|
| 181 |
+
inputs_embeds=inputs_embeds,
|
| 182 |
+
use_cache=use_cache,
|
| 183 |
+
output_attentions=output_attentions,
|
| 184 |
+
output_hidden_states=output_hidden_states,
|
| 185 |
+
return_dict=return_dict,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
hidden_states = outputs[0]
|
| 189 |
+
|
| 190 |
+
# Check hidden states immediately after model forward
|
| 191 |
+
print(f"Debug - Raw hidden states range: {hidden_states.min().item()} to {hidden_states.max().item()}")
|
| 192 |
+
print(f"Debug - Raw hidden states has nan: {torch.isnan(hidden_states).any().item()}")
|
| 193 |
+
print(f"Debug - Raw hidden states has inf: {torch.isinf(hidden_states).any().item()}")
|
| 194 |
+
hidden_dim = hidden_states.size(-1)
|
| 195 |
+
shift_labels = labels[..., 1:].contiguous().reshape(-1)
|
| 196 |
+
shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
|
| 197 |
+
assert shift_labels.size(0) == shift_hidden_states.size(0)
|
| 198 |
+
mask = shift_labels != IGNORE_INDEX
|
| 199 |
+
|
| 200 |
+
# Debug logging
|
| 201 |
+
print(f"Debug - Total tokens: {shift_labels.size(0)}")
|
| 202 |
+
print(f"Debug - Valid tokens: {mask.float().sum().item()}")
|
| 203 |
+
print(f"Debug - Ignored tokens: {(~mask).float().sum().item()}")
|
| 204 |
+
print(f"Debug - Label range: {shift_labels.min().item()} to {shift_labels.max().item()}")
|
| 205 |
+
|
| 206 |
+
assert mask.float().sum() > 0, f"No valid tokens found! Total: {shift_labels.size(0)}, Valid: {mask.float().sum().item()}"
|
| 207 |
+
shift_labels = shift_labels[mask]
|
| 208 |
+
shift_hidden_states = shift_hidden_states[mask, :]
|
| 209 |
+
|
| 210 |
+
print(f"Debug - After filtering: {shift_labels.size(0)} tokens")
|
| 211 |
+
print(f"Debug - Hidden states shape: {shift_hidden_states.shape}")
|
| 212 |
+
print(f"Debug - Hidden states range: {shift_hidden_states.min().item()} to {shift_hidden_states.max().item()}")
|
| 213 |
+
print(f"Debug - Hidden states has nan: {torch.isnan(shift_hidden_states).any().item()}")
|
| 214 |
+
print(f"Debug - Hidden states has inf: {torch.isinf(shift_hidden_states).any().item()}")
|
| 215 |
+
|
| 216 |
+
# Check lm_head weights
|
| 217 |
+
print(f"Debug - lm_head weight shape: {self.lm_head.weight.shape}")
|
| 218 |
+
print(f"Debug - lm_head weight range: {self.lm_head.weight.min().item()} to {self.lm_head.weight.max().item()}")
|
| 219 |
+
print(f"Debug - lm_head weight has nan: {torch.isnan(self.lm_head.weight).any().item()}")
|
| 220 |
+
print(f"Debug - lm_head weight has inf: {torch.isinf(self.lm_head.weight).any().item()}")
|
| 221 |
+
|
| 222 |
+
logits = self.lm_head(shift_hidden_states)
|
| 223 |
+
logits = logits.float()
|
| 224 |
+
|
| 225 |
+
print(f"Debug - Logits shape: {logits.shape}")
|
| 226 |
+
print(f"Debug - Logits range: {logits.min().item()} to {logits.max().item()}")
|
| 227 |
+
print(f"Debug - Logits has nan: {torch.isnan(logits).any().item()}")
|
| 228 |
+
print(f"Debug - Logits has inf: {torch.isinf(logits).any().item()}")
|
| 229 |
+
|
| 230 |
+
# Fix nan values in logits
|
| 231 |
+
if torch.isnan(logits).any():
|
| 232 |
+
print("WARNING: Found nan values in logits, replacing with zeros")
|
| 233 |
+
logits = torch.where(torch.isnan(logits), torch.zeros_like(logits), logits)
|
| 234 |
+
|
| 235 |
+
# Fix inf values in logits
|
| 236 |
+
if torch.isinf(logits).any():
|
| 237 |
+
print("WARNING: Found inf values in logits, clamping to finite range")
|
| 238 |
+
logits = torch.clamp(logits, min=-1e4, max=1e4)
|
| 239 |
+
|
| 240 |
+
# Additional check: if logits are still problematic, use a fallback
|
| 241 |
+
if torch.isnan(logits).any() or torch.isinf(logits).any():
|
| 242 |
+
print("ERROR: Logits still contain nan/inf after fixing, using fallback")
|
| 243 |
+
logits = torch.zeros_like(logits)
|
| 244 |
+
|
| 245 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 246 |
+
loss = loss_fct(logits, shift_labels)
|
| 247 |
+
|
| 248 |
+
print(f"Debug - Loss: {loss.item()}")
|
| 249 |
+
print(f"Debug - Loss has nan: {torch.isnan(loss).item()}")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
if not return_dict:
|
| 253 |
+
output = (logits,) + outputs[1:]
|
| 254 |
+
return (loss,) + output if loss is not None else output
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
return CausalLMOutputWithPast(
|
| 258 |
+
loss=loss,
|
| 259 |
+
logits=logits,
|
| 260 |
+
past_key_values=outputs.past_key_values,
|
| 261 |
+
hidden_states=outputs.hidden_states,
|
| 262 |
+
attentions=outputs.attentions,
|
| 263 |
+
)
|
| 264 |
+
def pixel_shuffle(self, x, scale_factor=0.5):
|
| 265 |
+
n, w, h, c = x.size()
|
| 266 |
+
# N, W, H, C --> N, W, H * scale, C // scale
|
| 267 |
+
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
| 268 |
+
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
| 269 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 270 |
+
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
|
| 271 |
+
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
|
| 272 |
+
int(c / (scale_factor * scale_factor)))
|
| 273 |
+
if self.ps_version == 'v1':
|
| 274 |
+
warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
|
| 275 |
+
'which results in a transposed image.')
|
| 276 |
+
else:
|
| 277 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 278 |
+
return x
|
| 279 |
+
|
| 280 |
+
def extract_feature(self, pixel_values):
|
| 281 |
+
if self.select_layer == -1:
|
| 282 |
+
# breakpoint()
|
| 283 |
+
vit_embeds = self.get_vision_tower()(
|
| 284 |
+
pixel_values=pixel_values,
|
| 285 |
+
output_hidden_states=False,
|
| 286 |
+
return_dict=True).last_hidden_state
|
| 287 |
+
else:
|
| 288 |
+
vit_embeds = self.get_vision_tower()(
|
| 289 |
+
pixel_values=pixel_values,
|
| 290 |
+
output_hidden_states=True,
|
| 291 |
+
return_dict=True).hidden_states[self.select_layer]
|
| 292 |
+
vit_embeds = vit_embeds[:, 1:, :]
|
| 293 |
+
|
| 294 |
+
h = w = int(vit_embeds.shape[1] ** 0.5)
|
| 295 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
| 296 |
+
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
|
| 297 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
| 298 |
+
# breakpoint()
|
| 299 |
+
vit_embeds = self.get_vision_projector()(vit_embeds)
|
| 300 |
+
return vit_embeds
|
| 301 |
+
@torch.no_grad()
|
| 302 |
+
def generate(
|
| 303 |
+
self,
|
| 304 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 305 |
+
input_ids: Optional[torch.FloatTensor] = None,
|
| 306 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 307 |
+
visual_features: Optional[torch.FloatTensor] = None,
|
| 308 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 309 |
+
output_hidden_states: Optional[bool] = None,
|
| 310 |
+
speech: Optional[torch.FloatTensor] = None,
|
| 311 |
+
speech_lengths: Optional[torch.LongTensor] = None,
|
| 312 |
+
speech_chunks: Optional[torch.LongTensor] = None,
|
| 313 |
+
speech_wav: Optional[torch.FloatTensor] = None,
|
| 314 |
+
modalities: Optional[List[str]] = ["image"],
|
| 315 |
+
**kwargs,
|
| 316 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
| 317 |
+
position_ids = kwargs.pop("position_ids", None)
|
| 318 |
+
|
| 319 |
+
if speech is not None:
|
| 320 |
+
(
|
| 321 |
+
_,
|
| 322 |
+
position_ids,
|
| 323 |
+
attention_mask,
|
| 324 |
+
_,
|
| 325 |
+
input_embeds,
|
| 326 |
+
_
|
| 327 |
+
) = self.prepare_inputs_labels_for_speech_text_for_internvl(
|
| 328 |
+
input_ids,
|
| 329 |
+
position_ids,
|
| 330 |
+
attention_mask,
|
| 331 |
+
None,
|
| 332 |
+
None, # labels
|
| 333 |
+
speech,
|
| 334 |
+
speech_lengths,
|
| 335 |
+
speech_chunks,
|
| 336 |
+
speech_wav,
|
| 337 |
+
modalities,
|
| 338 |
+
)
|
| 339 |
+
else:
|
| 340 |
+
# internvl
|
| 341 |
+
assert self.img_context_token_id is not None
|
| 342 |
+
if pixel_values is not None:
|
| 343 |
+
if visual_features is not None:
|
| 344 |
+
vit_embeds = visual_features
|
| 345 |
+
else:
|
| 346 |
+
vit_embeds = self.extract_feature(pixel_values)
|
| 347 |
+
input_embeds = self.get_model().get_input_embeddings()(input_ids)
|
| 348 |
+
B, N, C = input_embeds.shape
|
| 349 |
+
input_embeds = input_embeds.reshape(B * N, C)
|
| 350 |
+
input_ids = input_ids.reshape(B * N)
|
| 351 |
+
selected = (input_ids == self.img_context_token_id)
|
| 352 |
+
assert selected.sum() != 0
|
| 353 |
+
input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
|
| 354 |
+
input_embeds = input_embeds.reshape(B, N, C)
|
| 355 |
+
else:
|
| 356 |
+
input_embeds = self.get_model().get_input_embeddings()(input_ids)
|
| 357 |
+
return super().generate(
|
| 358 |
+
inputs_embeds=input_embeds,
|
| 359 |
+
attention_mask=attention_mask,
|
| 360 |
+
generation_config=generation_config,
|
| 361 |
+
output_hidden_states=output_hidden_states,
|
| 362 |
+
use_cache=True,
|
| 363 |
+
**kwargs,
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
|
| 368 |
+
num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
|
| 369 |
+
verbose=False, speech=None, speech_lengths=None, speech_wav=None, speech_chunks=None):
|
| 370 |
+
if history is None and pixel_values is not None and '<image>' not in question:
|
| 371 |
+
question = '<image>\n' + question
|
| 372 |
+
|
| 373 |
+
if num_patches_list is None:
|
| 374 |
+
num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
|
| 375 |
+
assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
|
| 376 |
+
|
| 377 |
+
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
| 378 |
+
self.img_context_token_id = img_context_token_id
|
| 379 |
+
|
| 380 |
+
template = get_conv_template(self.template)
|
| 381 |
+
template.system_message = self.system_message
|
| 382 |
+
eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
|
| 383 |
+
|
| 384 |
+
history = [] if history is None else history
|
| 385 |
+
for (old_question, old_answer) in history:
|
| 386 |
+
template.append_message(template.roles[0], old_question)
|
| 387 |
+
template.append_message(template.roles[1], old_answer)
|
| 388 |
+
template.append_message(template.roles[0], question)
|
| 389 |
+
template.append_message(template.roles[1], None)
|
| 390 |
+
query = template.get_prompt()
|
| 391 |
+
|
| 392 |
+
if verbose and pixel_values is not None:
|
| 393 |
+
image_bs = pixel_values.shape[0]
|
| 394 |
+
print(f'dynamic ViT batch size: {image_bs}')
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# Replace image tokens
|
| 398 |
+
for num_patches in num_patches_list:
|
| 399 |
+
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
|
| 400 |
+
query = query.replace('<image>', image_tokens, 1)
|
| 401 |
+
from ola.conversation import conv_templates, SeparatorStyle
|
| 402 |
+
from ola.mm_utils import KeywordsStoppingCriteria
|
| 403 |
+
conv_mode = "plm_v"
|
| 404 |
+
conv = conv_templates[conv_mode].copy()
|
| 405 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 406 |
+
keywords = [stop_str]
|
| 407 |
+
|
| 408 |
+
# Use OLA-style tokenization for speech inputs
|
| 409 |
+
if speech is not None and '<speech>' in query:
|
| 410 |
+
# Use OLA-style tokenization directly with <speech> tokens
|
| 411 |
+
input_ids = tokenizer_speech_token(query, tokenizer, return_tensors='pt').unsqueeze(0).to(self.device)
|
| 412 |
+
# Handle case where pad_token_id might be None
|
| 413 |
+
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 151643
|
| 414 |
+
attention_mask = input_ids.ne(pad_token_id).long().to(self.device)
|
| 415 |
+
|
| 416 |
+
else:
|
| 417 |
+
model_inputs = tokenizer(query, return_tensors='pt')
|
| 418 |
+
input_ids = model_inputs['input_ids'].to(self.device)
|
| 419 |
+
attention_mask = model_inputs['attention_mask'].to(self.device)
|
| 420 |
+
generation_config['eos_token_id'] = eos_token_id
|
| 421 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 422 |
+
# generation_config["stopping_criteria"] = stopping_criteria
|
| 423 |
+
generation_output = self.generate(
|
| 424 |
+
pixel_values=pixel_values,
|
| 425 |
+
input_ids=input_ids,
|
| 426 |
+
attention_mask=attention_mask,
|
| 427 |
+
speech=speech,
|
| 428 |
+
speech_lengths=speech_lengths,
|
| 429 |
+
speech_chunks=speech_chunks,
|
| 430 |
+
speech_wav=speech_wav,
|
| 431 |
+
**generation_config
|
| 432 |
+
)
|
| 433 |
+
response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
|
| 434 |
+
response = response.split(template.sep.strip())[0].strip()
|
| 435 |
+
history.append((question, response))
|
| 436 |
+
if return_history:
|
| 437 |
+
return response, history
|
| 438 |
+
else:
|
| 439 |
+
query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
|
| 440 |
+
query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
|
| 441 |
+
if verbose:
|
| 442 |
+
print(query_to_print, response)
|
| 443 |
+
return response
|
| 444 |
+
|
| 445 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
| 446 |
+
inputs_embeds=None, **kwargs):
|
| 447 |
+
speech = kwargs.pop("speech", None)
|
| 448 |
+
speech_lengths = kwargs.pop("speech_lengths", None)
|
| 449 |
+
speech_chunks = kwargs.pop("speech_chunks", None)
|
| 450 |
+
images = kwargs.pop("images", None)
|
| 451 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
| 452 |
+
inputs = super().prepare_inputs_for_generation(
|
| 453 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
| 454 |
+
)
|
| 455 |
+
if speech is not None:
|
| 456 |
+
inputs['speech'] = speech
|
| 457 |
+
inputs['speech_lengths'] = speech_lengths
|
| 458 |
+
inputs['speech_chunks'] = speech_chunks
|
| 459 |
+
if images is not None:
|
| 460 |
+
inputs["images"] = images
|
| 461 |
+
if image_sizes is not None:
|
| 462 |
+
inputs["image_sizes"] = image_sizes
|
| 463 |
+
return inputs
|
| 464 |
+
|
| 465 |
+
AutoConfig.register("ola_internvl", OlaConfigQwen3)
|
| 466 |
+
AutoModelForCausalLM.register(OlaConfigQwen3, OlaQwen3ForCausalLM)
|
ola/model/multimodal_encoder/__pycache__/builder.cpython-312.pyc
ADDED
|
Binary file (960 Bytes). View file
|
|
|
ola/model/multimodal_encoder/__pycache__/configuration_intern_vit.cpython-312.pyc
ADDED
|
Binary file (5.7 kB). View file
|
|
|
ola/model/multimodal_encoder/__pycache__/internvl_vit.cpython-312.pyc
ADDED
|
Binary file (26.2 kB). View file
|
|
|
ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-312.pyc
ADDED
|
Binary file (46.7 kB). View file
|
|
|
ola/model/multimodal_encoder/builder.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from .oryx_vit import SigLIPViTAnysizeWrapper
|
| 3 |
+
from .internvl_vit import InternVisionModel
|
| 4 |
+
|
| 5 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
| 6 |
+
# breakpoint()
|
| 7 |
+
if vision_tower_cfg.model_type == 'intern_vit_6b':
|
| 8 |
+
vision_tower = InternVisionModel(vision_tower_cfg)
|
| 9 |
+
# breakpoint()
|
| 10 |
+
return vision_tower
|
| 11 |
+
else:
|
| 12 |
+
vision_tower = getattr(vision_tower_cfg, 'vision_tower', getattr(vision_tower_cfg, 'mm_vision_tower', None))
|
| 13 |
+
is_absolute_path_exists = os.path.exists(vision_tower)
|
| 14 |
+
print(f"Buiding OryxViTWrapper from {vision_tower}...")
|
| 15 |
+
# path = vision_tower.split(":")[1]
|
| 16 |
+
return SigLIPViTAnysizeWrapper(vision_tower, path=vision_tower, args=vision_tower_cfg, **kwargs)
|
ola/model/multimodal_encoder/configuration_intern_vit.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# InternVL
|
| 3 |
+
# Copyright (c) 2024 OpenGVLab
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
import os
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 10 |
+
from transformers.utils import logging
|
| 11 |
+
|
| 12 |
+
logger = logging.get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class InternVisionConfig(PretrainedConfig):
|
| 16 |
+
r"""
|
| 17 |
+
This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
|
| 18 |
+
instantiate a vision encoder according to the specified arguments, defining the model architecture.
|
| 19 |
+
|
| 20 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 21 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 25 |
+
Number of color channels in the input images (e.g., 3 for RGB).
|
| 26 |
+
patch_size (`int`, *optional*, defaults to 14):
|
| 27 |
+
The size (resolution) of each patch.
|
| 28 |
+
image_size (`int`, *optional*, defaults to 224):
|
| 29 |
+
The size (resolution) of each image.
|
| 30 |
+
qkv_bias (`bool`, *optional*, defaults to `False`):
|
| 31 |
+
Whether to add a bias to the queries and values in the self-attention layers.
|
| 32 |
+
hidden_size (`int`, *optional*, defaults to 3200):
|
| 33 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 34 |
+
num_attention_heads (`int`, *optional*, defaults to 25):
|
| 35 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 36 |
+
intermediate_size (`int`, *optional*, defaults to 12800):
|
| 37 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 38 |
+
qk_normalization (`bool`, *optional*, defaults to `True`):
|
| 39 |
+
Whether to normalize the queries and keys in the self-attention layers.
|
| 40 |
+
num_hidden_layers (`int`, *optional*, defaults to 48):
|
| 41 |
+
Number of hidden layers in the Transformer encoder.
|
| 42 |
+
use_flash_attn (`bool`, *optional*, defaults to `True`):
|
| 43 |
+
Whether to use flash attention mechanism.
|
| 44 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 45 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 46 |
+
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
|
| 47 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
| 48 |
+
The epsilon used by the layer normalization layers.
|
| 49 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
| 50 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 51 |
+
drop_path_rate (`float`, *optional*, defaults to 0.0):
|
| 52 |
+
Dropout rate for stochastic depth.
|
| 53 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 54 |
+
The dropout ratio for the attention probabilities.
|
| 55 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 56 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 57 |
+
initializer_factor (`float`, *optional*, defaults to 0.1):
|
| 58 |
+
A factor for layer scale.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
model_type = 'intern_vit_6b'
|
| 62 |
+
|
| 63 |
+
def __init__(
|
| 64 |
+
self,
|
| 65 |
+
num_channels=3,
|
| 66 |
+
patch_size=14,
|
| 67 |
+
image_size=224,
|
| 68 |
+
qkv_bias=False,
|
| 69 |
+
hidden_size=3200,
|
| 70 |
+
num_attention_heads=25,
|
| 71 |
+
intermediate_size=12800,
|
| 72 |
+
qk_normalization=True,
|
| 73 |
+
num_hidden_layers=48,
|
| 74 |
+
use_flash_attn=True,
|
| 75 |
+
hidden_act='gelu',
|
| 76 |
+
norm_type='rms_norm',
|
| 77 |
+
layer_norm_eps=1e-6,
|
| 78 |
+
dropout=0.0,
|
| 79 |
+
drop_path_rate=0.0,
|
| 80 |
+
attention_dropout=0.0,
|
| 81 |
+
initializer_range=0.02,
|
| 82 |
+
initializer_factor=0.1,
|
| 83 |
+
**kwargs,
|
| 84 |
+
):
|
| 85 |
+
super().__init__(**kwargs)
|
| 86 |
+
|
| 87 |
+
self.hidden_size = hidden_size
|
| 88 |
+
self.intermediate_size = intermediate_size
|
| 89 |
+
self.dropout = dropout
|
| 90 |
+
self.drop_path_rate = drop_path_rate
|
| 91 |
+
self.num_hidden_layers = num_hidden_layers
|
| 92 |
+
self.num_attention_heads = num_attention_heads
|
| 93 |
+
self.num_channels = num_channels
|
| 94 |
+
self.patch_size = patch_size
|
| 95 |
+
self.image_size = image_size
|
| 96 |
+
self.initializer_range = initializer_range
|
| 97 |
+
self.initializer_factor = initializer_factor
|
| 98 |
+
self.attention_dropout = attention_dropout
|
| 99 |
+
self.layer_norm_eps = layer_norm_eps
|
| 100 |
+
self.hidden_act = hidden_act
|
| 101 |
+
self.norm_type = norm_type
|
| 102 |
+
self.qkv_bias = qkv_bias
|
| 103 |
+
self.qk_normalization = qk_normalization
|
| 104 |
+
self.use_flash_attn = use_flash_attn
|
| 105 |
+
|
| 106 |
+
@classmethod
|
| 107 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
|
| 108 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 109 |
+
|
| 110 |
+
if 'vision_config' in config_dict:
|
| 111 |
+
config_dict = config_dict['vision_config']
|
| 112 |
+
|
| 113 |
+
if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
|
| 114 |
+
logger.warning(
|
| 115 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
| 116 |
+
f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
return cls.from_dict(config_dict, **kwargs)
|
ola/model/multimodal_encoder/internvl_vit.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# InternVL
|
| 3 |
+
# Copyright (c) 2024 OpenGVLab
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
|
| 7 |
+
from typing import Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torch.utils.checkpoint
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from timm.layers import DropPath
|
| 14 |
+
from torch import nn
|
| 15 |
+
from transformers.activations import ACT2FN
|
| 16 |
+
from transformers.modeling_outputs import (BaseModelOutput,
|
| 17 |
+
BaseModelOutputWithPooling)
|
| 18 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 19 |
+
from transformers.utils import logging
|
| 20 |
+
|
| 21 |
+
from .configuration_intern_vit import InternVisionConfig
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
| 25 |
+
from flash_attn.flash_attn_interface import \
|
| 26 |
+
flash_attn_varlen_qkvpacked_func
|
| 27 |
+
has_flash_attn = True
|
| 28 |
+
except:
|
| 29 |
+
print('FlashAttention2 is not installed.')
|
| 30 |
+
has_flash_attn = False
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class FlashAttention(nn.Module):
|
| 36 |
+
"""Implement the scaled dot product attention with softmax.
|
| 37 |
+
Arguments
|
| 38 |
+
---------
|
| 39 |
+
softmax_scale: The temperature to use for the softmax attention.
|
| 40 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
| 41 |
+
runtime)
|
| 42 |
+
attention_dropout: The dropout rate to apply to the attention
|
| 43 |
+
(default: 0.0)
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.softmax_scale = softmax_scale
|
| 49 |
+
self.dropout_p = attention_dropout
|
| 50 |
+
|
| 51 |
+
def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
|
| 52 |
+
max_s=None, need_weights=False):
|
| 53 |
+
"""Implements the multihead softmax attention.
|
| 54 |
+
Arguments
|
| 55 |
+
---------
|
| 56 |
+
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
|
| 57 |
+
if unpadded: (nnz, 3, h, d)
|
| 58 |
+
key_padding_mask: a bool tensor of shape (B, S)
|
| 59 |
+
"""
|
| 60 |
+
assert not need_weights
|
| 61 |
+
assert qkv.dtype in [torch.float16, torch.bfloat16]
|
| 62 |
+
assert qkv.is_cuda
|
| 63 |
+
|
| 64 |
+
if cu_seqlens is None:
|
| 65 |
+
batch_size = qkv.shape[0]
|
| 66 |
+
seqlen = qkv.shape[1]
|
| 67 |
+
if key_padding_mask is None:
|
| 68 |
+
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
|
| 69 |
+
max_s = seqlen
|
| 70 |
+
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
| 71 |
+
device=qkv.device)
|
| 72 |
+
output = flash_attn_varlen_qkvpacked_func(
|
| 73 |
+
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
|
| 74 |
+
softmax_scale=self.softmax_scale, causal=causal
|
| 75 |
+
)
|
| 76 |
+
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
|
| 77 |
+
else:
|
| 78 |
+
nheads = qkv.shape[-2]
|
| 79 |
+
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
|
| 80 |
+
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
|
| 81 |
+
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
|
| 82 |
+
output_unpad = flash_attn_varlen_qkvpacked_func(
|
| 83 |
+
x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
|
| 84 |
+
softmax_scale=self.softmax_scale, causal=causal
|
| 85 |
+
)
|
| 86 |
+
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
|
| 87 |
+
indices, batch_size, seqlen),
|
| 88 |
+
'b s (h d) -> b s h d', h=nheads)
|
| 89 |
+
else:
|
| 90 |
+
assert max_s is not None
|
| 91 |
+
output = flash_attn_varlen_qkvpacked_func(
|
| 92 |
+
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
|
| 93 |
+
softmax_scale=self.softmax_scale, causal=causal
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
return output, None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class InternRMSNorm(nn.Module):
|
| 100 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 103 |
+
self.variance_epsilon = eps
|
| 104 |
+
|
| 105 |
+
def forward(self, hidden_states):
|
| 106 |
+
input_dtype = hidden_states.dtype
|
| 107 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 108 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 109 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 110 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
try:
|
| 114 |
+
from apex.normalization import FusedRMSNorm
|
| 115 |
+
|
| 116 |
+
InternRMSNorm = FusedRMSNorm # noqa
|
| 117 |
+
|
| 118 |
+
logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
|
| 119 |
+
except ImportError:
|
| 120 |
+
# using the normal InternRMSNorm
|
| 121 |
+
pass
|
| 122 |
+
except Exception:
|
| 123 |
+
logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
|
| 124 |
+
pass
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
NORM2FN = {
|
| 128 |
+
'rms_norm': InternRMSNorm,
|
| 129 |
+
'layer_norm': nn.LayerNorm,
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class InternVisionEmbeddings(nn.Module):
|
| 134 |
+
def __init__(self, config: InternVisionConfig):
|
| 135 |
+
super().__init__()
|
| 136 |
+
self.config = config
|
| 137 |
+
self.embed_dim = config.hidden_size
|
| 138 |
+
self.image_size = config.image_size
|
| 139 |
+
self.patch_size = config.patch_size
|
| 140 |
+
|
| 141 |
+
self.class_embedding = nn.Parameter(
|
| 142 |
+
torch.randn(1, 1, self.embed_dim),
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
self.patch_embedding = nn.Conv2d(
|
| 146 |
+
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
| 150 |
+
self.num_positions = self.num_patches + 1
|
| 151 |
+
|
| 152 |
+
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
| 153 |
+
|
| 154 |
+
def _get_pos_embed(self, pos_embed, H, W):
|
| 155 |
+
target_dtype = pos_embed.dtype
|
| 156 |
+
pos_embed = pos_embed.float().reshape(
|
| 157 |
+
1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
|
| 158 |
+
pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \
|
| 159 |
+
reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
|
| 160 |
+
return pos_embed
|
| 161 |
+
|
| 162 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
| 163 |
+
target_dtype = self.patch_embedding.weight.dtype
|
| 164 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
|
| 165 |
+
batch_size, _, height, width = patch_embeds.shape
|
| 166 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
| 167 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
|
| 168 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
| 169 |
+
position_embedding = torch.cat([
|
| 170 |
+
self.position_embedding[:, :1, :],
|
| 171 |
+
self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
|
| 172 |
+
], dim=1)
|
| 173 |
+
embeddings = embeddings + position_embedding.to(target_dtype)
|
| 174 |
+
return embeddings
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class InternAttention(nn.Module):
|
| 178 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 179 |
+
|
| 180 |
+
def __init__(self, config: InternVisionConfig):
|
| 181 |
+
super().__init__()
|
| 182 |
+
self.config = config
|
| 183 |
+
self.embed_dim = config.hidden_size
|
| 184 |
+
self.num_heads = config.num_attention_heads
|
| 185 |
+
self.use_flash_attn = config.use_flash_attn and has_flash_attn
|
| 186 |
+
if config.use_flash_attn and not has_flash_attn:
|
| 187 |
+
print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
|
| 188 |
+
self.head_dim = self.embed_dim // self.num_heads
|
| 189 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
| 190 |
+
raise ValueError(
|
| 191 |
+
f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
|
| 192 |
+
f' {self.num_heads}).'
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
self.scale = self.head_dim ** -0.5
|
| 196 |
+
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
|
| 197 |
+
self.attn_drop = nn.Dropout(config.attention_dropout)
|
| 198 |
+
self.proj_drop = nn.Dropout(config.dropout)
|
| 199 |
+
|
| 200 |
+
self.qk_normalization = config.qk_normalization
|
| 201 |
+
|
| 202 |
+
if self.qk_normalization:
|
| 203 |
+
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 204 |
+
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 205 |
+
|
| 206 |
+
if self.use_flash_attn:
|
| 207 |
+
self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
|
| 208 |
+
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 209 |
+
|
| 210 |
+
def _naive_attn(self, x):
|
| 211 |
+
B, N, C = x.shape
|
| 212 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 213 |
+
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
| 214 |
+
|
| 215 |
+
if self.qk_normalization:
|
| 216 |
+
B_, H_, N_, D_ = q.shape
|
| 217 |
+
q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
|
| 218 |
+
k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
|
| 219 |
+
|
| 220 |
+
attn = ((q * self.scale) @ k.transpose(-2, -1))
|
| 221 |
+
attn = attn.softmax(dim=-1)
|
| 222 |
+
attn = self.attn_drop(attn)
|
| 223 |
+
|
| 224 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 225 |
+
x = self.proj(x)
|
| 226 |
+
x = self.proj_drop(x)
|
| 227 |
+
return x
|
| 228 |
+
|
| 229 |
+
def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
|
| 230 |
+
qkv = self.qkv(x)
|
| 231 |
+
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
|
| 232 |
+
|
| 233 |
+
if self.qk_normalization:
|
| 234 |
+
q, k, v = qkv.unbind(2)
|
| 235 |
+
q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
|
| 236 |
+
k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
|
| 237 |
+
qkv = torch.stack([q, k, v], dim=2)
|
| 238 |
+
|
| 239 |
+
context, _ = self.inner_attn(
|
| 240 |
+
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
|
| 241 |
+
)
|
| 242 |
+
outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
|
| 243 |
+
outs = self.proj_drop(outs)
|
| 244 |
+
return outs
|
| 245 |
+
|
| 246 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 247 |
+
x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
|
| 248 |
+
return x
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class InternMLP(nn.Module):
|
| 252 |
+
def __init__(self, config: InternVisionConfig):
|
| 253 |
+
super().__init__()
|
| 254 |
+
self.config = config
|
| 255 |
+
self.act = ACT2FN[config.hidden_act]
|
| 256 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 257 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 258 |
+
|
| 259 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 260 |
+
hidden_states = self.fc1(hidden_states)
|
| 261 |
+
hidden_states = self.act(hidden_states)
|
| 262 |
+
hidden_states = self.fc2(hidden_states)
|
| 263 |
+
return hidden_states
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class InternVisionEncoderLayer(nn.Module):
|
| 267 |
+
def __init__(self, config: InternVisionConfig, drop_path_rate: float):
|
| 268 |
+
super().__init__()
|
| 269 |
+
self.embed_dim = config.hidden_size
|
| 270 |
+
self.intermediate_size = config.intermediate_size
|
| 271 |
+
self.norm_type = config.norm_type
|
| 272 |
+
|
| 273 |
+
self.attn = InternAttention(config)
|
| 274 |
+
self.mlp = InternMLP(config)
|
| 275 |
+
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
|
| 276 |
+
self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
|
| 277 |
+
|
| 278 |
+
self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
| 279 |
+
self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
| 280 |
+
self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
| 281 |
+
self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
| 282 |
+
|
| 283 |
+
def forward(
|
| 284 |
+
self,
|
| 285 |
+
hidden_states: torch.Tensor,
|
| 286 |
+
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
|
| 287 |
+
"""
|
| 288 |
+
Args:
|
| 289 |
+
hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 290 |
+
"""
|
| 291 |
+
hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)
|
| 292 |
+
|
| 293 |
+
hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)
|
| 294 |
+
|
| 295 |
+
return hidden_states
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class InternVisionEncoder(nn.Module):
|
| 299 |
+
"""
|
| 300 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
| 301 |
+
[`InternEncoderLayer`].
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
config (`InternConfig`):
|
| 305 |
+
The corresponding vision configuration for the `InternEncoder`.
|
| 306 |
+
"""
|
| 307 |
+
|
| 308 |
+
def __init__(self, config: InternVisionConfig):
|
| 309 |
+
super().__init__()
|
| 310 |
+
self.config = config
|
| 311 |
+
# stochastic depth decay rule
|
| 312 |
+
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
|
| 313 |
+
self.layers = nn.ModuleList([
|
| 314 |
+
InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
|
| 315 |
+
self.gradient_checkpointing = True
|
| 316 |
+
|
| 317 |
+
def forward(
|
| 318 |
+
self,
|
| 319 |
+
inputs_embeds,
|
| 320 |
+
output_hidden_states: Optional[bool] = None,
|
| 321 |
+
return_dict: Optional[bool] = None,
|
| 322 |
+
) -> Union[Tuple, BaseModelOutput]:
|
| 323 |
+
r"""
|
| 324 |
+
Args:
|
| 325 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 326 |
+
Embedded representation of the inputs. Should be float, not int tokens.
|
| 327 |
+
output_hidden_states (`bool`, *optional*):
|
| 328 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
| 329 |
+
for more detail.
|
| 330 |
+
return_dict (`bool`, *optional*):
|
| 331 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 332 |
+
"""
|
| 333 |
+
output_hidden_states = (
|
| 334 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 335 |
+
)
|
| 336 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 337 |
+
|
| 338 |
+
encoder_states = () if output_hidden_states else None
|
| 339 |
+
hidden_states = inputs_embeds
|
| 340 |
+
|
| 341 |
+
for idx, encoder_layer in enumerate(self.layers):
|
| 342 |
+
if output_hidden_states:
|
| 343 |
+
encoder_states = encoder_states + (hidden_states,)
|
| 344 |
+
if self.gradient_checkpointing and self.training:
|
| 345 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 346 |
+
encoder_layer,
|
| 347 |
+
hidden_states)
|
| 348 |
+
else:
|
| 349 |
+
layer_outputs = encoder_layer(
|
| 350 |
+
hidden_states,
|
| 351 |
+
)
|
| 352 |
+
hidden_states = layer_outputs
|
| 353 |
+
|
| 354 |
+
if output_hidden_states:
|
| 355 |
+
encoder_states = encoder_states + (hidden_states,)
|
| 356 |
+
|
| 357 |
+
if not return_dict:
|
| 358 |
+
return tuple(v for v in [hidden_states, encoder_states] if v is not None)
|
| 359 |
+
return BaseModelOutput(
|
| 360 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
class InternVisionModel(PreTrainedModel):
|
| 365 |
+
main_input_name = 'pixel_values'
|
| 366 |
+
# _supports_flash_attn_2 = True
|
| 367 |
+
supports_gradient_checkpointing = True
|
| 368 |
+
config_class = InternVisionConfig
|
| 369 |
+
_no_split_modules = ['InternVisionEncoderLayer']
|
| 370 |
+
# support transformers 4.51.+
|
| 371 |
+
_tp_plan = ''
|
| 372 |
+
|
| 373 |
+
def __init__(self, config: InternVisionConfig):
|
| 374 |
+
super().__init__(config)
|
| 375 |
+
self.config = config
|
| 376 |
+
# Force eager attention implementation to avoid scaled_dot_product_attention error
|
| 377 |
+
# self._attn_implementation = "eager"
|
| 378 |
+
|
| 379 |
+
self.embeddings = InternVisionEmbeddings(config)
|
| 380 |
+
self.encoder = InternVisionEncoder(config)
|
| 381 |
+
|
| 382 |
+
def resize_pos_embeddings(self, old_size, new_size, patch_size):
|
| 383 |
+
pos_emb = self.embeddings.position_embedding
|
| 384 |
+
_, num_positions, embed_dim = pos_emb.shape
|
| 385 |
+
cls_emb = pos_emb[:, :1, :]
|
| 386 |
+
pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
|
| 387 |
+
pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
|
| 388 |
+
pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
|
| 389 |
+
pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
|
| 390 |
+
self.embeddings.position_embedding = nn.Parameter(pos_emb)
|
| 391 |
+
self.embeddings.image_size = new_size
|
| 392 |
+
logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
|
| 393 |
+
|
| 394 |
+
def get_input_embeddings(self):
|
| 395 |
+
return self.embeddings
|
| 396 |
+
|
| 397 |
+
def forward(
|
| 398 |
+
self,
|
| 399 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 400 |
+
output_hidden_states: Optional[bool] = None,
|
| 401 |
+
return_dict: Optional[bool] = None,
|
| 402 |
+
pixel_embeds: Optional[torch.FloatTensor] = None,
|
| 403 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
| 404 |
+
output_hidden_states = (
|
| 405 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 406 |
+
)
|
| 407 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 408 |
+
|
| 409 |
+
if pixel_values is None and pixel_embeds is None:
|
| 410 |
+
raise ValueError('You have to specify pixel_values or pixel_embeds')
|
| 411 |
+
|
| 412 |
+
if pixel_embeds is not None:
|
| 413 |
+
hidden_states = pixel_embeds
|
| 414 |
+
else:
|
| 415 |
+
if len(pixel_values.shape) == 4:
|
| 416 |
+
hidden_states = self.embeddings(pixel_values)
|
| 417 |
+
else:
|
| 418 |
+
raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
|
| 419 |
+
encoder_outputs = self.encoder(
|
| 420 |
+
inputs_embeds=hidden_states,
|
| 421 |
+
output_hidden_states=output_hidden_states,
|
| 422 |
+
return_dict=return_dict,
|
| 423 |
+
)
|
| 424 |
+
last_hidden_state = encoder_outputs.last_hidden_state
|
| 425 |
+
pooled_output = last_hidden_state[:, 0, :]
|
| 426 |
+
|
| 427 |
+
if not return_dict:
|
| 428 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
| 429 |
+
|
| 430 |
+
return BaseModelOutputWithPooling(
|
| 431 |
+
last_hidden_state=last_hidden_state,
|
| 432 |
+
pooler_output=pooled_output,
|
| 433 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 434 |
+
attentions=encoder_outputs.attentions,
|
| 435 |
+
)
|
ola/model/multimodal_encoder/oryx_vit.py
ADDED
|
@@ -0,0 +1,1075 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import warnings
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from functools import partial
|
| 5 |
+
from typing import (
|
| 6 |
+
Callable,
|
| 7 |
+
Dict,
|
| 8 |
+
Final,
|
| 9 |
+
List,
|
| 10 |
+
Literal,
|
| 11 |
+
Optional,
|
| 12 |
+
Sequence,
|
| 13 |
+
Set,
|
| 14 |
+
Tuple,
|
| 15 |
+
Type,
|
| 16 |
+
Union,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
from torch.utils.checkpoint import checkpoint
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
try:
|
| 24 |
+
from timm.layers import (
|
| 25 |
+
AttentionPoolLatent,
|
| 26 |
+
DropPath,
|
| 27 |
+
LayerType,
|
| 28 |
+
Mlp,
|
| 29 |
+
PatchDropout,
|
| 30 |
+
PatchEmbed,
|
| 31 |
+
resample_abs_pos_embed,
|
| 32 |
+
)
|
| 33 |
+
from timm.models._manipulate import checkpoint_seq, named_apply
|
| 34 |
+
except:
|
| 35 |
+
print('Wrong timm version')
|
| 36 |
+
|
| 37 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 38 |
+
|
| 39 |
+
from typing import Optional
|
| 40 |
+
|
| 41 |
+
import logging
|
| 42 |
+
import torch
|
| 43 |
+
import torch.nn as nn
|
| 44 |
+
import torch.nn.functional as F
|
| 45 |
+
|
| 46 |
+
import deepspeed
|
| 47 |
+
import os
|
| 48 |
+
|
| 49 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 50 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 51 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 52 |
+
def norm_cdf(x):
|
| 53 |
+
# Computes standard normal cumulative distribution function
|
| 54 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
| 55 |
+
|
| 56 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 57 |
+
warnings.warn(
|
| 58 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 59 |
+
"The distribution of values may be incorrect.",
|
| 60 |
+
stacklevel=2,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
with torch.no_grad():
|
| 64 |
+
# Values are generated by using a truncated uniform distribution and
|
| 65 |
+
# then using the inverse CDF for the normal distribution.
|
| 66 |
+
# Get upper and lower cdf values
|
| 67 |
+
l = norm_cdf((a - mean) / std) # noqa: E741
|
| 68 |
+
u = norm_cdf((b - mean) / std)
|
| 69 |
+
|
| 70 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 71 |
+
# [2l-1, 2u-1].
|
| 72 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 73 |
+
|
| 74 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 75 |
+
# standard normal
|
| 76 |
+
tensor.erfinv_()
|
| 77 |
+
|
| 78 |
+
# Transform to proper mean, std
|
| 79 |
+
tensor.mul_(std * math.sqrt(2.0))
|
| 80 |
+
tensor.add_(mean)
|
| 81 |
+
|
| 82 |
+
# Clamp to ensure it's in the proper range
|
| 83 |
+
tensor.clamp_(min=a, max=b)
|
| 84 |
+
return tensor
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
| 88 |
+
# type: (torch.Tensor, float, float, float, float) -> torch.Tensor
|
| 89 |
+
r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
|
| 90 |
+
convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
|
| 91 |
+
Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
|
| 92 |
+
from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
| 93 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
| 94 |
+
the bounds. The method used for generating the random values works
|
| 95 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
| 96 |
+
Args:
|
| 97 |
+
tensor: an n-dimensional `torch.Tensor`
|
| 98 |
+
mean: the mean of the normal distribution
|
| 99 |
+
std: the standard deviation of the normal distribution
|
| 100 |
+
a: the minimum cutoff value
|
| 101 |
+
b: the maximum cutoff value
|
| 102 |
+
Examples:
|
| 103 |
+
>>> w = torch.empty(3, 5)
|
| 104 |
+
>>> nn.init.trunc_normal_(w)
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
with torch.no_grad():
|
| 108 |
+
dtype = tensor.dtype
|
| 109 |
+
tensor_fp32 = tensor.float()
|
| 110 |
+
tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
|
| 111 |
+
tensor_dtype = tensor_fp32.to(dtype=dtype)
|
| 112 |
+
tensor.copy_(tensor_dtype)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def init_weights(self):
|
| 116 |
+
if self.pos_embed is not None:
|
| 117 |
+
trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
|
| 118 |
+
trunc_normal_(self.latent, std=self.latent_dim**-0.5)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
|
| 122 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 123 |
+
if isinstance(module, nn.Linear):
|
| 124 |
+
trunc_normal_(module.weight, std=0.02)
|
| 125 |
+
if module.bias is not None:
|
| 126 |
+
nn.init.zeros_(module.bias)
|
| 127 |
+
elif hasattr(module, "init_weights"):
|
| 128 |
+
module.init_weights()
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class Attention(nn.Module):
|
| 132 |
+
fused_attn: Final[bool]
|
| 133 |
+
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
dim: int,
|
| 137 |
+
num_heads: int = 8,
|
| 138 |
+
qkv_bias: bool = False,
|
| 139 |
+
qk_norm: bool = False,
|
| 140 |
+
attn_drop: float = 0.0,
|
| 141 |
+
proj_drop: float = 0.0,
|
| 142 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 143 |
+
) -> None:
|
| 144 |
+
super().__init__()
|
| 145 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
| 146 |
+
self.num_heads = num_heads
|
| 147 |
+
self.head_dim = dim // num_heads
|
| 148 |
+
self.scale = self.head_dim**-0.5
|
| 149 |
+
# self.fused_attn = use_fused_attn()
|
| 150 |
+
self.fused_attn = True
|
| 151 |
+
|
| 152 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 153 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 154 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 155 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 156 |
+
self.proj = nn.Linear(dim, dim)
|
| 157 |
+
self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
|
| 158 |
+
|
| 159 |
+
def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor:
|
| 160 |
+
B, N, C = x.shape
|
| 161 |
+
qkv = (
|
| 162 |
+
self.qkv(x)
|
| 163 |
+
.reshape(B, N, 3, self.num_heads, self.head_dim)
|
| 164 |
+
.permute(2, 0, 3, 1, 4)
|
| 165 |
+
)
|
| 166 |
+
q, k, v = qkv.unbind(0)
|
| 167 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 168 |
+
|
| 169 |
+
if cu_slens is not None:
|
| 170 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
| 171 |
+
k = k.permute(0, 2, 1, 3)
|
| 172 |
+
v = v.permute(0, 2, 1, 3)
|
| 173 |
+
max_seqlen = torch.max(cu_slens[1:] - cu_slens[:-1]).item()
|
| 174 |
+
x = flash_attn_varlen_func(
|
| 175 |
+
q.squeeze(0),
|
| 176 |
+
k.squeeze(0),
|
| 177 |
+
v.squeeze(0),
|
| 178 |
+
cu_seqlens_q=cu_slens,
|
| 179 |
+
cu_seqlens_k=cu_slens,
|
| 180 |
+
max_seqlen_q=max_seqlen,
|
| 181 |
+
max_seqlen_k=max_seqlen,
|
| 182 |
+
softmax_scale=self.scale,
|
| 183 |
+
causal=False,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
x = x.reshape(B, N, -1)
|
| 187 |
+
x = self.proj(x)
|
| 188 |
+
x = self.proj_drop(x)
|
| 189 |
+
|
| 190 |
+
else:
|
| 191 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
| 192 |
+
k = k.permute(0, 2, 1, 3)
|
| 193 |
+
v = v.permute(0, 2, 1, 3)
|
| 194 |
+
x = flash_attn_func(q, k, v, softmax_scale=self.scale) # -> b, n, h, c
|
| 195 |
+
|
| 196 |
+
x = x.reshape(B, N, -1)
|
| 197 |
+
x = self.proj(x)
|
| 198 |
+
x = self.proj_drop(x)
|
| 199 |
+
# if self.fused_attn:
|
| 200 |
+
# x = F.scaled_dot_product_attention(
|
| 201 |
+
# q,
|
| 202 |
+
# k,
|
| 203 |
+
# v,
|
| 204 |
+
# dropout_p=self.attn_drop.p if self.training else 0.0,
|
| 205 |
+
# )
|
| 206 |
+
# else:
|
| 207 |
+
# q = q * self.scale
|
| 208 |
+
# attn = q @ k.transpose(-2, -1)
|
| 209 |
+
# attn = attn.softmax(dim=-1)
|
| 210 |
+
# attn = self.attn_drop(attn)
|
| 211 |
+
# x = attn @ v
|
| 212 |
+
|
| 213 |
+
# x = x.transpose(1, 2).reshape(B, N, C)
|
| 214 |
+
# x = self.proj(x)
|
| 215 |
+
# x = self.proj_drop(x)
|
| 216 |
+
return x
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class LayerScale(nn.Module):
|
| 220 |
+
def __init__(
|
| 221 |
+
self,
|
| 222 |
+
dim: int,
|
| 223 |
+
init_values: float = 1e-5,
|
| 224 |
+
inplace: bool = False,
|
| 225 |
+
) -> None:
|
| 226 |
+
super().__init__()
|
| 227 |
+
self.inplace = inplace
|
| 228 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 229 |
+
|
| 230 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 231 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class Block(nn.Module):
|
| 235 |
+
def __init__(
|
| 236 |
+
self,
|
| 237 |
+
dim: int,
|
| 238 |
+
num_heads: int,
|
| 239 |
+
mlp_ratio: float = 4.0,
|
| 240 |
+
qkv_bias: bool = False,
|
| 241 |
+
qk_norm: bool = False,
|
| 242 |
+
proj_drop: float = 0.0,
|
| 243 |
+
attn_drop: float = 0.0,
|
| 244 |
+
init_values: Optional[float] = None,
|
| 245 |
+
drop_path: float = 0.0,
|
| 246 |
+
act_layer: nn.Module = nn.GELU,
|
| 247 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 248 |
+
mlp_layer: nn.Module = Mlp,
|
| 249 |
+
) -> None:
|
| 250 |
+
super().__init__()
|
| 251 |
+
self.norm1 = norm_layer(dim)
|
| 252 |
+
self.attn = Attention(
|
| 253 |
+
dim,
|
| 254 |
+
num_heads=num_heads,
|
| 255 |
+
qkv_bias=qkv_bias,
|
| 256 |
+
qk_norm=qk_norm,
|
| 257 |
+
attn_drop=attn_drop,
|
| 258 |
+
proj_drop=proj_drop,
|
| 259 |
+
norm_layer=norm_layer,
|
| 260 |
+
)
|
| 261 |
+
self.ls1 = (
|
| 262 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 263 |
+
)
|
| 264 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 265 |
+
|
| 266 |
+
self.norm2 = norm_layer(dim)
|
| 267 |
+
self.mlp = mlp_layer(
|
| 268 |
+
in_features=dim,
|
| 269 |
+
hidden_features=int(dim * mlp_ratio),
|
| 270 |
+
act_layer=act_layer,
|
| 271 |
+
drop=proj_drop,
|
| 272 |
+
)
|
| 273 |
+
self.ls2 = (
|
| 274 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 275 |
+
)
|
| 276 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 277 |
+
|
| 278 |
+
def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor:
|
| 279 |
+
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), cu_slens=cu_slens)))
|
| 280 |
+
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
| 281 |
+
return x
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class VisionTransformer(nn.Module):
|
| 285 |
+
"""Vision Transformer
|
| 286 |
+
|
| 287 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
| 288 |
+
- https://arxiv.org/abs/2010.11929
|
| 289 |
+
"""
|
| 290 |
+
|
| 291 |
+
dynamic_img_size: Final[bool]
|
| 292 |
+
|
| 293 |
+
def __init__(
|
| 294 |
+
self,
|
| 295 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
| 296 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
| 297 |
+
in_chans: int = 3,
|
| 298 |
+
num_classes: int = 1000,
|
| 299 |
+
global_pool: Literal["", "avg", "token", "map"] = "token",
|
| 300 |
+
embed_dim: int = 768,
|
| 301 |
+
depth: int = 12,
|
| 302 |
+
num_heads: int = 12,
|
| 303 |
+
mlp_ratio: float = 4.0,
|
| 304 |
+
qkv_bias: bool = True,
|
| 305 |
+
qk_norm: bool = False,
|
| 306 |
+
init_values: Optional[float] = None,
|
| 307 |
+
class_token: bool = True,
|
| 308 |
+
no_embed_class: bool = False,
|
| 309 |
+
reg_tokens: int = 0,
|
| 310 |
+
pre_norm: bool = False,
|
| 311 |
+
fc_norm: Optional[bool] = None,
|
| 312 |
+
dynamic_img_size: bool = False,
|
| 313 |
+
dynamic_img_pad: bool = False,
|
| 314 |
+
drop_rate: float = 0.0,
|
| 315 |
+
pos_drop_rate: float = 0.0,
|
| 316 |
+
patch_drop_rate: float = 0.0,
|
| 317 |
+
proj_drop_rate: float = 0.0,
|
| 318 |
+
attn_drop_rate: float = 0.0,
|
| 319 |
+
drop_path_rate: float = 0.0,
|
| 320 |
+
weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
|
| 321 |
+
embed_layer: Callable = PatchEmbed,
|
| 322 |
+
norm_layer: Optional[LayerType] = None,
|
| 323 |
+
act_layer: Optional[LayerType] = None,
|
| 324 |
+
strict_img_size: bool = False,
|
| 325 |
+
block_fn: Type[nn.Module] = Block,
|
| 326 |
+
mlp_layer: Type[nn.Module] = Mlp,
|
| 327 |
+
ignore_head: bool = False,
|
| 328 |
+
add_patch2x2: bool = False,
|
| 329 |
+
) -> None:
|
| 330 |
+
"""
|
| 331 |
+
Args:
|
| 332 |
+
img_size: Input image size.
|
| 333 |
+
patch_size: Patch size.
|
| 334 |
+
in_chans: Number of image input channels.
|
| 335 |
+
num_classes: Mumber of classes for classification head.
|
| 336 |
+
global_pool: Type of global pooling for final sequence (default: 'token').
|
| 337 |
+
embed_dim: Transformer embedding dimension.
|
| 338 |
+
depth: Depth of transformer.
|
| 339 |
+
num_heads: Number of attention heads.
|
| 340 |
+
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
| 341 |
+
qkv_bias: Enable bias for qkv projections if True.
|
| 342 |
+
init_values: Layer-scale init values (layer-scale enabled if not None).
|
| 343 |
+
class_token: Use class token.
|
| 344 |
+
no_embed_class: Don't include position embeddings for class (or reg) tokens.
|
| 345 |
+
reg_tokens: Number of register tokens.
|
| 346 |
+
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
|
| 347 |
+
drop_rate: Head dropout rate.
|
| 348 |
+
pos_drop_rate: Position embedding dropout rate.
|
| 349 |
+
attn_drop_rate: Attention dropout rate.
|
| 350 |
+
drop_path_rate: Stochastic depth rate.
|
| 351 |
+
weight_init: Weight initialization scheme.
|
| 352 |
+
embed_layer: Patch embedding layer.
|
| 353 |
+
norm_layer: Normalization layer.
|
| 354 |
+
act_layer: MLP activation layer.
|
| 355 |
+
block_fn: Transformer block layer.
|
| 356 |
+
"""
|
| 357 |
+
super().__init__()
|
| 358 |
+
assert global_pool in ("", "avg", "token", "map")
|
| 359 |
+
assert class_token or global_pool != "token"
|
| 360 |
+
use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
|
| 361 |
+
# norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
|
| 362 |
+
# act_layer = get_act_layer(act_layer) or nn.GELU
|
| 363 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
| 364 |
+
act_layer = nn.GELU
|
| 365 |
+
|
| 366 |
+
self.num_classes = num_classes
|
| 367 |
+
self.global_pool = global_pool
|
| 368 |
+
self.num_features = self.embed_dim = (
|
| 369 |
+
embed_dim # num_features for consistency with other models
|
| 370 |
+
)
|
| 371 |
+
self.num_prefix_tokens = 1 if class_token else 0
|
| 372 |
+
self.num_prefix_tokens += reg_tokens
|
| 373 |
+
self.num_reg_tokens = reg_tokens
|
| 374 |
+
self.has_class_token = class_token
|
| 375 |
+
self.no_embed_class = (
|
| 376 |
+
no_embed_class # don't embed prefix positions (includes reg)
|
| 377 |
+
)
|
| 378 |
+
self.dynamic_img_size = dynamic_img_size
|
| 379 |
+
self.grad_checkpointing = False
|
| 380 |
+
self.ignore_head = ignore_head
|
| 381 |
+
|
| 382 |
+
embed_args = {}
|
| 383 |
+
if dynamic_img_size:
|
| 384 |
+
# flatten deferred until after pos embed
|
| 385 |
+
embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
|
| 386 |
+
self.patch_embed = embed_layer(
|
| 387 |
+
img_size=img_size,
|
| 388 |
+
patch_size=patch_size,
|
| 389 |
+
in_chans=in_chans,
|
| 390 |
+
embed_dim=embed_dim,
|
| 391 |
+
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
| 392 |
+
dynamic_img_pad=dynamic_img_pad,
|
| 393 |
+
strict_img_size=strict_img_size,
|
| 394 |
+
**embed_args,
|
| 395 |
+
)
|
| 396 |
+
num_patches = self.patch_embed.num_patches
|
| 397 |
+
|
| 398 |
+
self.cls_token = (
|
| 399 |
+
nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
| 400 |
+
)
|
| 401 |
+
self.reg_token = (
|
| 402 |
+
nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
|
| 403 |
+
)
|
| 404 |
+
embed_len = (
|
| 405 |
+
num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
| 406 |
+
)
|
| 407 |
+
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
# deepspeed.zero.register_external_parameter(self, self.pos_embed)
|
| 411 |
+
# deepspeed.zero.register_external_parameter(self, self.patch_embed.proj.weight)
|
| 412 |
+
# deepspeed.zero.register_external_parameter(self, self.patch_embed.proj.bias)
|
| 413 |
+
# print(self.patch_embed.state_dict().keys())
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
| 417 |
+
if patch_drop_rate > 0:
|
| 418 |
+
self.patch_drop = PatchDropout(
|
| 419 |
+
patch_drop_rate,
|
| 420 |
+
num_prefix_tokens=self.num_prefix_tokens,
|
| 421 |
+
)
|
| 422 |
+
else:
|
| 423 |
+
self.patch_drop = nn.Identity()
|
| 424 |
+
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
|
| 425 |
+
|
| 426 |
+
dpr = [
|
| 427 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
| 428 |
+
] # stochastic depth decay rule
|
| 429 |
+
self.blocks = nn.Sequential(
|
| 430 |
+
*[
|
| 431 |
+
block_fn(
|
| 432 |
+
dim=embed_dim,
|
| 433 |
+
num_heads=num_heads,
|
| 434 |
+
mlp_ratio=mlp_ratio,
|
| 435 |
+
qkv_bias=qkv_bias,
|
| 436 |
+
qk_norm=qk_norm,
|
| 437 |
+
init_values=init_values,
|
| 438 |
+
proj_drop=proj_drop_rate,
|
| 439 |
+
attn_drop=attn_drop_rate,
|
| 440 |
+
drop_path=dpr[i],
|
| 441 |
+
norm_layer=norm_layer,
|
| 442 |
+
act_layer=act_layer,
|
| 443 |
+
mlp_layer=mlp_layer,
|
| 444 |
+
)
|
| 445 |
+
for i in range(depth)
|
| 446 |
+
]
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
if add_patch2x2:
|
| 451 |
+
if add_patch2x2 == 'v2':
|
| 452 |
+
self.downsample = nn.Sequential(
|
| 453 |
+
nn.Conv2d(embed_dim, embed_dim*2, kernel_size=2, stride=2),
|
| 454 |
+
nn.GELU(),
|
| 455 |
+
nn.Conv2d(embed_dim*2, embed_dim*4, 1)
|
| 456 |
+
)
|
| 457 |
+
else:
|
| 458 |
+
mid_dim = embed_dim * 2
|
| 459 |
+
self.downsample = nn.Sequential(
|
| 460 |
+
nn.Conv2d(embed_dim, mid_dim, kernel_size=2, stride=2),
|
| 461 |
+
nn.GELU(),
|
| 462 |
+
nn.Conv2d(mid_dim, mid_dim, 1)
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
else:
|
| 466 |
+
self.downsample = None
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
# self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
|
| 470 |
+
|
| 471 |
+
# # Classifier Head
|
| 472 |
+
# if global_pool == "map":
|
| 473 |
+
# AttentionPoolLatent.init_weights = init_weights
|
| 474 |
+
# self.attn_pool = AttentionPoolLatent(
|
| 475 |
+
# self.embed_dim,
|
| 476 |
+
# num_heads=num_heads,
|
| 477 |
+
# mlp_ratio=mlp_ratio,
|
| 478 |
+
# norm_layer=norm_layer,
|
| 479 |
+
# )
|
| 480 |
+
# else:
|
| 481 |
+
# self.attn_pool = None
|
| 482 |
+
# self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
|
| 483 |
+
# self.head_drop = nn.Dropout(drop_rate)
|
| 484 |
+
# self.head = (
|
| 485 |
+
# nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 486 |
+
# )
|
| 487 |
+
|
| 488 |
+
# if weight_init != "skip":
|
| 489 |
+
# self.init_weights(weight_init)
|
| 490 |
+
|
| 491 |
+
def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
|
| 492 |
+
assert mode in ("jax", "jax_nlhb", "moco", "")
|
| 493 |
+
# head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
|
| 494 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
| 495 |
+
if self.cls_token is not None:
|
| 496 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
| 497 |
+
named_apply(init_weights_vit_timm, self)
|
| 498 |
+
|
| 499 |
+
@torch.jit.ignore
|
| 500 |
+
def no_weight_decay(self) -> Set:
|
| 501 |
+
return {"pos_embed", "cls_token", "dist_token"}
|
| 502 |
+
|
| 503 |
+
@torch.jit.ignore
|
| 504 |
+
def group_matcher(self, coarse: bool = False) -> Dict:
|
| 505 |
+
return dict(
|
| 506 |
+
stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
|
| 507 |
+
blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
@torch.jit.ignore
|
| 511 |
+
def set_grad_checkpointing(self, enable: bool = True) -> None:
|
| 512 |
+
self.grad_checkpointing = enable
|
| 513 |
+
|
| 514 |
+
@torch.jit.ignore
|
| 515 |
+
def get_classifier(self) -> nn.Module:
|
| 516 |
+
return self.head
|
| 517 |
+
|
| 518 |
+
def reset_classifier(self, num_classes: int, global_pool=None) -> None:
|
| 519 |
+
self.num_classes = num_classes
|
| 520 |
+
if global_pool is not None:
|
| 521 |
+
assert global_pool in ("", "avg", "token", "map")
|
| 522 |
+
if global_pool == "map" and self.attn_pool is None:
|
| 523 |
+
assert (
|
| 524 |
+
False
|
| 525 |
+
), "Cannot currently add attention pooling in reset_classifier()."
|
| 526 |
+
elif global_pool != "map " and self.attn_pool is not None:
|
| 527 |
+
self.attn_pool = None # remove attention pooling
|
| 528 |
+
self.global_pool = global_pool
|
| 529 |
+
self.head = (
|
| 530 |
+
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
def rescale_positional_embedding(self, out_size):
|
| 534 |
+
h, w = out_size
|
| 535 |
+
pos_embed_shape = int((self.pos_embed.shape[1]) ** 0.5)
|
| 536 |
+
if (h, w) == (pos_embed_shape, pos_embed_shape):
|
| 537 |
+
return self.pos_embed
|
| 538 |
+
rescaled_positional_embedding = \
|
| 539 |
+
self.pos_embed.new_zeros(1, h*w, self.pos_embed.shape[2])
|
| 540 |
+
pe_2d = self.pos_embed[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape)
|
| 541 |
+
pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w)
|
| 542 |
+
rescaled_positional_embedding[0] = pe_2d.T.contiguous()
|
| 543 |
+
return rescaled_positional_embedding
|
| 544 |
+
|
| 545 |
+
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
|
| 546 |
+
if self.dynamic_img_size:
|
| 547 |
+
B, H, W, C = x.shape
|
| 548 |
+
pos_embed = resample_abs_pos_embed(
|
| 549 |
+
self.pos_embed,
|
| 550 |
+
(H, W),
|
| 551 |
+
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
|
| 552 |
+
)
|
| 553 |
+
x = x.view(B, -1, C)
|
| 554 |
+
else:
|
| 555 |
+
pos_embed = self.pos_embed
|
| 556 |
+
|
| 557 |
+
to_cat = []
|
| 558 |
+
if self.cls_token is not None:
|
| 559 |
+
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
|
| 560 |
+
if self.reg_token is not None:
|
| 561 |
+
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
|
| 562 |
+
|
| 563 |
+
if self.no_embed_class:
|
| 564 |
+
# deit-3, updated JAX (big vision)
|
| 565 |
+
# position embedding does not overlap with class token, add then concat
|
| 566 |
+
x = x + pos_embed
|
| 567 |
+
if to_cat:
|
| 568 |
+
x = torch.cat(to_cat + [x], dim=1)
|
| 569 |
+
else:
|
| 570 |
+
# original timm, JAX, and deit vit impl
|
| 571 |
+
# pos_embed has entry for class token, concat then add
|
| 572 |
+
if to_cat:
|
| 573 |
+
x = torch.cat(to_cat + [x], dim=1)
|
| 574 |
+
x = x + pos_embed
|
| 575 |
+
|
| 576 |
+
return self.pos_drop(x)
|
| 577 |
+
|
| 578 |
+
def _intermediate_layers(
|
| 579 |
+
self,
|
| 580 |
+
x: torch.Tensor,
|
| 581 |
+
n: Union[int, Sequence] = 1,
|
| 582 |
+
) -> List[torch.Tensor]:
|
| 583 |
+
outputs, num_blocks = [], len(self.blocks)
|
| 584 |
+
take_indices = set(
|
| 585 |
+
range(num_blocks - n, num_blocks) if isinstance(n, int) else n
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
# forward pass
|
| 589 |
+
x = self.patch_embed(x)
|
| 590 |
+
x = self._pos_embed(x)
|
| 591 |
+
x = self.patch_drop(x)
|
| 592 |
+
x = self.norm_pre(x)
|
| 593 |
+
for i, blk in enumerate(self.blocks):
|
| 594 |
+
x = blk(x)
|
| 595 |
+
if i in take_indices:
|
| 596 |
+
outputs.append(x)
|
| 597 |
+
|
| 598 |
+
return outputs
|
| 599 |
+
|
| 600 |
+
def get_intermediate_layers(
|
| 601 |
+
self,
|
| 602 |
+
x: torch.Tensor,
|
| 603 |
+
n: Union[int, Sequence] = 1,
|
| 604 |
+
reshape: bool = False,
|
| 605 |
+
return_prefix_tokens: bool = False,
|
| 606 |
+
norm: bool = False,
|
| 607 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
| 608 |
+
"""Intermediate layer accessor (NOTE: This is a WIP experiment).
|
| 609 |
+
Inspired by DINO / DINOv2 interface
|
| 610 |
+
"""
|
| 611 |
+
# take last n blocks if n is an int, if in is a sequence, select by matching indices
|
| 612 |
+
outputs = self._intermediate_layers(x, n)
|
| 613 |
+
if norm:
|
| 614 |
+
outputs = [self.norm(out) for out in outputs]
|
| 615 |
+
prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
|
| 616 |
+
outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
|
| 617 |
+
|
| 618 |
+
if reshape:
|
| 619 |
+
grid_size = self.patch_embed.grid_size
|
| 620 |
+
outputs = [
|
| 621 |
+
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
|
| 622 |
+
.permute(0, 3, 1, 2)
|
| 623 |
+
.contiguous()
|
| 624 |
+
for out in outputs
|
| 625 |
+
]
|
| 626 |
+
|
| 627 |
+
if return_prefix_tokens:
|
| 628 |
+
return tuple(zip(outputs, prefix_tokens))
|
| 629 |
+
return tuple(outputs)
|
| 630 |
+
|
| 631 |
+
def forward_features_list(self, x_list):
|
| 632 |
+
x_all = []
|
| 633 |
+
image_sizes = []
|
| 634 |
+
for x in x_list:
|
| 635 |
+
bs, _, h, w = x.shape
|
| 636 |
+
|
| 637 |
+
# fix patch size=14 in datasets
|
| 638 |
+
pad_h = (self.patch_embed.patch_size[0] - h % self.patch_embed.patch_size[0]) % self.patch_embed.patch_size[0]
|
| 639 |
+
pad_w = (self.patch_embed.patch_size[1] - w % self.patch_embed.patch_size[1]) % self.patch_embed.patch_size[1]
|
| 640 |
+
x = F.pad(x, (0, pad_w, 0, pad_h))
|
| 641 |
+
|
| 642 |
+
bs, _, h, w = x.shape
|
| 643 |
+
|
| 644 |
+
h = h // self.patch_embed.patch_size[0]
|
| 645 |
+
w = w // self.patch_embed.patch_size[1]
|
| 646 |
+
|
| 647 |
+
x = self.patch_embed(x)
|
| 648 |
+
# x = self._pos_embed(x)
|
| 649 |
+
x = x + self.rescale_positional_embedding(out_size=(h, w))
|
| 650 |
+
x = self.patch_drop(x)
|
| 651 |
+
x = self.norm_pre(x)
|
| 652 |
+
x_all.append(x)
|
| 653 |
+
image_sizes.append((h, w))
|
| 654 |
+
|
| 655 |
+
slen = [xi.size(1) for xi in x_all]
|
| 656 |
+
x = torch.cat(x_all, dim=1)
|
| 657 |
+
|
| 658 |
+
cu_indices = [0, ]
|
| 659 |
+
for i in slen:
|
| 660 |
+
cu_indices.append(cu_indices[-1] + i)
|
| 661 |
+
|
| 662 |
+
cu_slens = torch.tensor(cu_indices, dtype=torch.int32).to(x.device)
|
| 663 |
+
for idx, blk in enumerate(self.blocks):
|
| 664 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
| 665 |
+
x = checkpoint(blk, x, cu_slens, use_reentrant=True)
|
| 666 |
+
else:
|
| 667 |
+
x = blk(x, cu_slens=cu_slens)
|
| 668 |
+
feats = x.split(slen, dim=1) #[(1, slen, c)]
|
| 669 |
+
|
| 670 |
+
if self.downsample is not None:
|
| 671 |
+
new_feats = []
|
| 672 |
+
new_sizes = []
|
| 673 |
+
for f, s in zip(feats, image_sizes):
|
| 674 |
+
h, w = s
|
| 675 |
+
b, n, c = f.size()
|
| 676 |
+
f = f.reshape(b, h, w, c).permute(0, 3, 1, 2)
|
| 677 |
+
f = self.downsample(f)
|
| 678 |
+
b, c, h, w = f.size()
|
| 679 |
+
f = f.permute(0, 2, 3, 1).reshape(b, h*w, c)
|
| 680 |
+
new_feats.append(f)
|
| 681 |
+
new_sizes.append((h, w))
|
| 682 |
+
return new_feats, new_sizes
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
return feats, image_sizes
|
| 686 |
+
|
| 687 |
+
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
| 688 |
+
bs, _, h, w = x.shape
|
| 689 |
+
h = h // self.patch_embed.patch_size[0]
|
| 690 |
+
w = w // self.patch_embed.patch_size[1]
|
| 691 |
+
|
| 692 |
+
x = self.patch_embed(x)
|
| 693 |
+
# x = self._pos_embed(x)
|
| 694 |
+
x = x + self.rescale_positional_embedding(out_size=(h, w))
|
| 695 |
+
x = self.patch_drop(x)
|
| 696 |
+
x = self.norm_pre(x)
|
| 697 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
| 698 |
+
x = checkpoint_seq(self.blocks, x)
|
| 699 |
+
else:
|
| 700 |
+
x = self.blocks(x)
|
| 701 |
+
|
| 702 |
+
if self.downsample is not None:
|
| 703 |
+
b, n, c = x.size()
|
| 704 |
+
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
|
| 705 |
+
x = self.downsample(x)
|
| 706 |
+
b, c, h, w = x.size()
|
| 707 |
+
x = x.permute(0, 2, 3, 1).reshape(b, h*w, c)
|
| 708 |
+
new_feats = x
|
| 709 |
+
new_sizes = (h, w)
|
| 710 |
+
return new_feats, new_sizes
|
| 711 |
+
|
| 712 |
+
return x, (h, w)
|
| 713 |
+
|
| 714 |
+
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
| 715 |
+
x = self.norm(x)
|
| 716 |
+
if self.attn_pool is not None:
|
| 717 |
+
x = self.attn_pool(x)
|
| 718 |
+
elif self.global_pool == "avg":
|
| 719 |
+
x = x[:, self.num_prefix_tokens :].mean(dim=1)
|
| 720 |
+
elif self.global_pool:
|
| 721 |
+
x = x[:, 0] # class token
|
| 722 |
+
x = self.fc_norm(x)
|
| 723 |
+
x = self.head_drop(x)
|
| 724 |
+
return x if pre_logits else self.head(x)
|
| 725 |
+
|
| 726 |
+
def forward(self, x, cal_attn_pool=False):
|
| 727 |
+
# import pdb;pdb.set_trace()
|
| 728 |
+
if type(x) is list:
|
| 729 |
+
x, image_sizes = self.forward_features_list(x)
|
| 730 |
+
return x, image_sizes, None
|
| 731 |
+
else:
|
| 732 |
+
x, image_sizes = self.forward_features(x)
|
| 733 |
+
return x, image_sizes, None
|
| 734 |
+
|
| 735 |
+
@dataclass
|
| 736 |
+
class SigLIPVisionCfg:
|
| 737 |
+
width: int = 1152
|
| 738 |
+
layers: Union[Tuple[int, int, int, int], int] = 27
|
| 739 |
+
heads: int = 16
|
| 740 |
+
patch_size: int = 14
|
| 741 |
+
image_size: Union[Tuple[int, int], int] = 336
|
| 742 |
+
global_pool: str = "map"
|
| 743 |
+
mlp_ratio: float = 3.7362
|
| 744 |
+
class_token: bool = False
|
| 745 |
+
num_classes: int = 0
|
| 746 |
+
use_checkpoint: bool = False
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
SigLIP_MODEL_CONFIG = {
|
| 750 |
+
"siglip_so400m_patch14_384": {
|
| 751 |
+
"image_size": 384,
|
| 752 |
+
"patch_size": 14,
|
| 753 |
+
"width": 1152,
|
| 754 |
+
"layers": 27,
|
| 755 |
+
"heads": 16,
|
| 756 |
+
"mlp_ratio": 3.7362,
|
| 757 |
+
"global_pool": "map",
|
| 758 |
+
"use_checkpoint": False,
|
| 759 |
+
},
|
| 760 |
+
"siglip_so400m_patch16_384": {
|
| 761 |
+
"image_size": 384,
|
| 762 |
+
"patch_size": 16,
|
| 763 |
+
"width": 1152,
|
| 764 |
+
"layers": 27,
|
| 765 |
+
"heads": 16,
|
| 766 |
+
"mlp_ratio": 3.7362,
|
| 767 |
+
"global_pool": "map",
|
| 768 |
+
"use_checkpoint": False,
|
| 769 |
+
},
|
| 770 |
+
"siglip_so400m_patch14_224": {
|
| 771 |
+
"image_size": 224,
|
| 772 |
+
"patch_size": 14,
|
| 773 |
+
"width": 1152,
|
| 774 |
+
"layers": 27,
|
| 775 |
+
"heads": 16,
|
| 776 |
+
"mlp_ratio": 3.7362,
|
| 777 |
+
"global_pool": "map",
|
| 778 |
+
"use_checkpoint": False,
|
| 779 |
+
},
|
| 780 |
+
"siglip_large_patch16_384": {
|
| 781 |
+
"image_size": 384,
|
| 782 |
+
"patch_size": 16,
|
| 783 |
+
"width": 1024,
|
| 784 |
+
"layers": 24,
|
| 785 |
+
"heads": 16,
|
| 786 |
+
"mlp_ratio": 4,
|
| 787 |
+
"global_pool": "map",
|
| 788 |
+
"use_checkpoint": False,
|
| 789 |
+
},
|
| 790 |
+
}
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
def resize_evaclip_pos_embed(model: VisionTransformer, interpolation: str = 'bicubic'):
|
| 794 |
+
# interpolate position embedding
|
| 795 |
+
orig_size = 24
|
| 796 |
+
new_size = 128
|
| 797 |
+
pos_tokens = model.pos_embed
|
| 798 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, model.embed_dim).permute(0, 3, 1, 2)
|
| 799 |
+
pos_tokens = torch.nn.functional.interpolate(
|
| 800 |
+
pos_tokens, size=(new_size, new_size), mode=interpolation, align_corners=False)
|
| 801 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
| 802 |
+
model.pos_embed = nn.Parameter(pos_tokens, requires_grad=True)
|
| 803 |
+
return model
|
| 804 |
+
|
| 805 |
+
def create_siglip_vit(
|
| 806 |
+
model_name: str = "siglip_so400m_patch14_384",
|
| 807 |
+
image_size: int = 384,
|
| 808 |
+
select_layer: int = -1,
|
| 809 |
+
path: str = "",
|
| 810 |
+
gradient_checkpointing: bool = False,
|
| 811 |
+
**kwargs,
|
| 812 |
+
):
|
| 813 |
+
assert (
|
| 814 |
+
model_name in SigLIP_MODEL_CONFIG.keys()
|
| 815 |
+
), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
|
| 816 |
+
|
| 817 |
+
vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
|
| 818 |
+
|
| 819 |
+
if select_layer <= 0:
|
| 820 |
+
layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
|
| 821 |
+
else:
|
| 822 |
+
layers = min(vision_cfg.layers, select_layer)
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
if 'patch2x2' or 'patch4x4' in path:
|
| 827 |
+
add_patch2x2 = True
|
| 828 |
+
else:
|
| 829 |
+
add_patch2x2 = False
|
| 830 |
+
|
| 831 |
+
if 'patch4x4pool' in path or 'patch2x2from4x4' in path:
|
| 832 |
+
add_patch2x2 = 'v2'
|
| 833 |
+
|
| 834 |
+
if FORCE_NO_DOWNSAMPLE:
|
| 835 |
+
add_patch2x2 = False
|
| 836 |
+
|
| 837 |
+
model = VisionTransformer(
|
| 838 |
+
img_size=2048,
|
| 839 |
+
patch_size=16,
|
| 840 |
+
embed_dim=vision_cfg.width,
|
| 841 |
+
depth=layers,
|
| 842 |
+
num_heads=vision_cfg.heads,
|
| 843 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
| 844 |
+
class_token=vision_cfg.class_token,
|
| 845 |
+
global_pool=vision_cfg.global_pool,
|
| 846 |
+
dynamic_img_pad=False,
|
| 847 |
+
strict_img_size=False,
|
| 848 |
+
ignore_head=kwargs.get("ignore_head", False),
|
| 849 |
+
weight_init=kwargs.get("weight_init", "skip"),
|
| 850 |
+
num_classes=0,
|
| 851 |
+
add_patch2x2=add_patch2x2
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
if gradient_checkpointing:
|
| 855 |
+
model.set_grad_checkpointing(True)
|
| 856 |
+
return model
|
| 857 |
+
|
| 858 |
+
import os
|
| 859 |
+
if 'LOAD_VISION_EARLY' in os.environ:
|
| 860 |
+
print("LOAD_VISION_EARLY is set")
|
| 861 |
+
LOAD_VISION_EARLY = True
|
| 862 |
+
else:
|
| 863 |
+
LOAD_VISION_EARLY = False
|
| 864 |
+
|
| 865 |
+
if 'VIT_WITH_GRAD' in os.environ:
|
| 866 |
+
print("VIT_WITH_GRAD is set")
|
| 867 |
+
VIT_WITH_GRAD = True
|
| 868 |
+
else:
|
| 869 |
+
VIT_WITH_GRAD = False
|
| 870 |
+
|
| 871 |
+
if 'FIX_SIZE' in os.environ:
|
| 872 |
+
print("FIX_SIZE is set")
|
| 873 |
+
FIX_SIZE = True
|
| 874 |
+
else:
|
| 875 |
+
FIX_SIZE = False
|
| 876 |
+
|
| 877 |
+
if 'ANYRES_SPLIT' in os.environ:
|
| 878 |
+
ANYRES_SPLIT = int(os.environ['ANYRES_SPLIT'])
|
| 879 |
+
print(f"ANYRES_SPLIT is set as {ANYRES_SPLIT}")
|
| 880 |
+
else:
|
| 881 |
+
ANYRES_SPLIT = None
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
if 'FORCE_NO_DOWNSAMPLE' in os.environ:
|
| 885 |
+
print("FORCE_NO_DOWNSAMPLE is set")
|
| 886 |
+
FORCE_NO_DOWNSAMPLE = True
|
| 887 |
+
else:
|
| 888 |
+
FORCE_NO_DOWNSAMPLE = False
|
| 889 |
+
|
| 890 |
+
from transformers import CLIPImageProcessor
|
| 891 |
+
import torch.distributed as dist
|
| 892 |
+
|
| 893 |
+
class SigLIPViTAnysizeWrapper(nn.Module):
|
| 894 |
+
def __init__(self, vision_tower, path, args, delay_load=False):
|
| 895 |
+
super().__init__()
|
| 896 |
+
|
| 897 |
+
self.is_loaded = False
|
| 898 |
+
|
| 899 |
+
self.vision_tower_name = vision_tower
|
| 900 |
+
self.args = args
|
| 901 |
+
self.path = path
|
| 902 |
+
|
| 903 |
+
self.select_layer = -1
|
| 904 |
+
if self.select_layer < -1: self.select_layer += 1
|
| 905 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
| 906 |
+
|
| 907 |
+
self.output_dim = 1152
|
| 908 |
+
if not FORCE_NO_DOWNSAMPLE:
|
| 909 |
+
if 'patch2x2' or 'patch4x4' in path:
|
| 910 |
+
self.output_dim = 1152*2
|
| 911 |
+
|
| 912 |
+
if 'patch4x4pool' in path or 'patch2x2from4x4' in path:
|
| 913 |
+
self.output_dim = 1152*4
|
| 914 |
+
|
| 915 |
+
if not delay_load or LOAD_VISION_EARLY:
|
| 916 |
+
self.load_model()
|
| 917 |
+
elif getattr(args, "unfreeze_mm_vision_tower", False):
|
| 918 |
+
# TODO: better detector is needed.
|
| 919 |
+
print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
|
| 920 |
+
self.load_model()
|
| 921 |
+
|
| 922 |
+
def load_model(self, device_map=None):
|
| 923 |
+
if self.is_loaded:
|
| 924 |
+
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
|
| 925 |
+
return
|
| 926 |
+
|
| 927 |
+
self.image_processor = CLIPImageProcessor.from_pretrained("/data1/cxy/model/openai/clip-vit-large-patch14")
|
| 928 |
+
if self.args.mm_projector_type == "conv_mlp" or self.args.mm_projector_type == "multipath_conv_mlp" or self.args.mm_projector_type == "multipath_conv_mlp_woconv":
|
| 929 |
+
self.image_processor.crop_size['height'] = 384
|
| 930 |
+
self.image_processor.crop_size['width'] = 384
|
| 931 |
+
self.image_processor.size['shortest_edge'] = 384
|
| 932 |
+
print("Resizeing clip processor to 384...")
|
| 933 |
+
self.image_processor.image_mean = [0.5, 0.5, 0.5]
|
| 934 |
+
self.image_processor.image_std = [0.5, 0.5, 0.5]
|
| 935 |
+
print("Loading vision model...")
|
| 936 |
+
if VIT_WITH_GRAD:
|
| 937 |
+
self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384',
|
| 938 |
+
gradient_checkpointing=True)
|
| 939 |
+
self.vision_tower.train()
|
| 940 |
+
else:
|
| 941 |
+
self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384',
|
| 942 |
+
gradient_checkpointing=False)
|
| 943 |
+
for p in self.vision_tower.parameters():
|
| 944 |
+
p.requires_grad = False
|
| 945 |
+
self.vision_tower.eval()
|
| 946 |
+
self.is_loaded = True
|
| 947 |
+
|
| 948 |
+
def train(self, mode = True):
|
| 949 |
+
self.training = mode
|
| 950 |
+
|
| 951 |
+
if self.is_loaded and not VIT_WITH_GRAD:
|
| 952 |
+
self.vision_tower.eval()
|
| 953 |
+
|
| 954 |
+
def split_images(self, images, split_res=512, base_size=32):
|
| 955 |
+
split_images = []
|
| 956 |
+
sub_images_info = []
|
| 957 |
+
for image in images:
|
| 958 |
+
now_sub_images = []
|
| 959 |
+
_, c, h, w = image.shape
|
| 960 |
+
if h * w <= split_res * split_res:
|
| 961 |
+
split_images.append(image)
|
| 962 |
+
sub_images_info.append(
|
| 963 |
+
(
|
| 964 |
+
1, 1, 1, h // base_size, w // base_size, [(0, h // base_size, 0, w // base_size)]
|
| 965 |
+
)
|
| 966 |
+
)
|
| 967 |
+
continue
|
| 968 |
+
nsplit_h = math.ceil(h / split_res)
|
| 969 |
+
nsplit_w = math.ceil(w / split_res)
|
| 970 |
+
sub_h = int(h / nsplit_h / base_size) * base_size
|
| 971 |
+
sub_w = int(w / nsplit_w / base_size) * base_size
|
| 972 |
+
crop_infos = []
|
| 973 |
+
for i in range(nsplit_h):
|
| 974 |
+
for j in range(nsplit_w):
|
| 975 |
+
begin_h = i * sub_h
|
| 976 |
+
begin_w = j * sub_w
|
| 977 |
+
|
| 978 |
+
if i == nsplit_h - 1:
|
| 979 |
+
end_h = h
|
| 980 |
+
else:
|
| 981 |
+
end_h = (i + 1) * sub_h
|
| 982 |
+
|
| 983 |
+
if j == nsplit_w - 1:
|
| 984 |
+
end_w = w
|
| 985 |
+
else:
|
| 986 |
+
end_w = (j + 1) * sub_w
|
| 987 |
+
|
| 988 |
+
assert (end_h - begin_h) % base_size == 0 and (end_w - begin_w) % base_size == 0
|
| 989 |
+
|
| 990 |
+
sub_image = image[:, :, begin_h:end_h, begin_w:end_w]
|
| 991 |
+
now_sub_images.append(sub_image)
|
| 992 |
+
crop_infos.append(
|
| 993 |
+
(begin_h // base_size, end_h // base_size, begin_w // base_size, end_w // base_size)
|
| 994 |
+
)
|
| 995 |
+
|
| 996 |
+
split_images += now_sub_images
|
| 997 |
+
sub_images_info.append(
|
| 998 |
+
(
|
| 999 |
+
len(now_sub_images), nsplit_h, nsplit_w, h // base_size, w // base_size, crop_infos
|
| 1000 |
+
)
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
return split_images, sub_images_info
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
def unsplit_images(self, features, sizes, sub_images_info):
|
| 1007 |
+
new_features = []
|
| 1008 |
+
for feature, size in zip(features, sizes):
|
| 1009 |
+
h, w = size
|
| 1010 |
+
new_features.append(
|
| 1011 |
+
feature.reshape(1, h, w, -1)
|
| 1012 |
+
)
|
| 1013 |
+
|
| 1014 |
+
fused_images = []
|
| 1015 |
+
images_sizes = []
|
| 1016 |
+
sub_count = 0
|
| 1017 |
+
for n_split, nsplit_h, nsplit_w, total_h, total_w, crop_infos in sub_images_info:
|
| 1018 |
+
sub_features = new_features[sub_count:sub_count+n_split]
|
| 1019 |
+
sub_count += n_split
|
| 1020 |
+
|
| 1021 |
+
total_feature = new_features[0].new_zeros(1, total_h, total_w, self.hidden_size)
|
| 1022 |
+
for feature, (begin_h, end_h, begin_w, end_w) in zip(sub_features, crop_infos):
|
| 1023 |
+
total_feature[:, begin_h:end_h, begin_w:end_w] += feature
|
| 1024 |
+
|
| 1025 |
+
fused_images.append(total_feature.reshape(1, total_h * total_w, self.hidden_size))
|
| 1026 |
+
images_sizes.append((total_h, total_w))
|
| 1027 |
+
|
| 1028 |
+
return fused_images, images_sizes
|
| 1029 |
+
|
| 1030 |
+
|
| 1031 |
+
|
| 1032 |
+
def forward_func(self, images, force_fix_size=False, cal_attn_pool=False):
|
| 1033 |
+
if type(images) is list:
|
| 1034 |
+
xs = [x.to(self.dtype) for x in images]
|
| 1035 |
+
image_features, img_size, cls_token = self.vision_tower(xs, cal_attn_pool=cal_attn_pool)
|
| 1036 |
+
image_features = [x.to(images[0].dtype) for x in image_features]
|
| 1037 |
+
|
| 1038 |
+
else:
|
| 1039 |
+
image_forward_outs, img_size, cls_token = self.vision_tower(images.to(self.dtype), cal_attn_pool=cal_attn_pool)
|
| 1040 |
+
image_features = image_forward_outs.to(images.dtype)
|
| 1041 |
+
|
| 1042 |
+
return image_features, img_size, cls_token
|
| 1043 |
+
|
| 1044 |
+
def forward(self, images, cal_attn_pool=False):
|
| 1045 |
+
if VIT_WITH_GRAD:
|
| 1046 |
+
image_features, img_size, cls_token = self.forward_func(images, cal_attn_pool=cal_attn_pool)
|
| 1047 |
+
return image_features, img_size
|
| 1048 |
+
else:
|
| 1049 |
+
with torch.no_grad():
|
| 1050 |
+
image_features, img_size, cls_token = self.forward_func(images, cal_attn_pool=cal_attn_pool)
|
| 1051 |
+
return image_features, img_size
|
| 1052 |
+
|
| 1053 |
+
|
| 1054 |
+
@property
|
| 1055 |
+
def dummy_feature(self):
|
| 1056 |
+
return torch.zeros(1, 1152, device=self.device, dtype=self.dtype)
|
| 1057 |
+
|
| 1058 |
+
@property
|
| 1059 |
+
def dtype(self):
|
| 1060 |
+
return self.vision_tower.pos_embed.dtype
|
| 1061 |
+
|
| 1062 |
+
@property
|
| 1063 |
+
def device(self):
|
| 1064 |
+
return self.vision_tower.pos_embed.device
|
| 1065 |
+
|
| 1066 |
+
@property
|
| 1067 |
+
def hidden_size(self):
|
| 1068 |
+
return self.output_dim
|
| 1069 |
+
|
| 1070 |
+
@property
|
| 1071 |
+
def config(self):
|
| 1072 |
+
return type('LLaVAConfigWrapper', (), {
|
| 1073 |
+
# 'image_size': 224,
|
| 1074 |
+
'patch_size': 16,
|
| 1075 |
+
})()
|
ola/model/multimodal_projector/__pycache__/builder.cpython-312.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
ola/model/multimodal_projector/__pycache__/internvl_projector.cpython-312.pyc
ADDED
|
Binary file (2.14 kB). View file
|
|
|
ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-312.pyc
ADDED
|
Binary file (4.68 kB). View file
|
|
|
ola/model/multimodal_projector/builder.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
from .pooler_projector import NormalizedDwPooler
|
| 8 |
+
from .internvl_projector import InternVLMultiModalProjector
|
| 9 |
+
import os
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
class IdentityMap(nn.Module):
|
| 13 |
+
def __init__(self):
|
| 14 |
+
super().__init__()
|
| 15 |
+
|
| 16 |
+
def forward(self, x, *args, **kwargs):
|
| 17 |
+
return x
|
| 18 |
+
|
| 19 |
+
@property
|
| 20 |
+
def config(self):
|
| 21 |
+
return {"mm_projector_type": 'identity'}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SimpleResBlock(nn.Module):
|
| 25 |
+
def __init__(self, channels):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.pre_norm = nn.LayerNorm(channels)
|
| 28 |
+
|
| 29 |
+
self.proj = nn.Sequential(
|
| 30 |
+
nn.Linear(channels, channels),
|
| 31 |
+
nn.GELU(),
|
| 32 |
+
nn.Linear(channels, channels)
|
| 33 |
+
)
|
| 34 |
+
def forward(self, x):
|
| 35 |
+
x = self.pre_norm(x)
|
| 36 |
+
return x + self.proj(x)
|
| 37 |
+
|
| 38 |
+
class OlaMLP(nn.Module):
|
| 39 |
+
def __init__(self, in_channels, out_channels, twoview=False):
|
| 40 |
+
super().__init__()
|
| 41 |
+
|
| 42 |
+
self.proj1 = nn.Linear(in_channels, out_channels)
|
| 43 |
+
self.proj2 = nn.Linear(out_channels, out_channels)
|
| 44 |
+
self.act = nn.GELU()
|
| 45 |
+
self.pooler = NormalizedDwPooler(out_channels)
|
| 46 |
+
|
| 47 |
+
embed_std = 1 / math.sqrt(out_channels)
|
| 48 |
+
self.image_newline = nn.Parameter(
|
| 49 |
+
torch.randn(out_channels) * embed_std
|
| 50 |
+
)
|
| 51 |
+
self.image_begin = nn.Parameter(
|
| 52 |
+
torch.randn(out_channels) * embed_std
|
| 53 |
+
)
|
| 54 |
+
self.image_end = nn.Parameter(
|
| 55 |
+
torch.randn(out_channels) * embed_std
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
if twoview:
|
| 59 |
+
self.image_sep = nn.Parameter(
|
| 60 |
+
torch.randn(out_channels) * embed_std
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def forward(self, x, size=(16,16), x2=None, size2=(16, 16), modalities='image'):
|
| 64 |
+
|
| 65 |
+
if modalities in ['image', 'text']:
|
| 66 |
+
h, w = size
|
| 67 |
+
dtype = x.dtype
|
| 68 |
+
x = x.reshape(x.shape[0], h, w, -1)
|
| 69 |
+
x = self.proj1(x)
|
| 70 |
+
x = self.pooler(x, forward_type='2x')
|
| 71 |
+
x = self.act(x)
|
| 72 |
+
x = self.proj2(x)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
b, h, w, c = x.shape
|
| 76 |
+
x = torch.cat([
|
| 77 |
+
x,
|
| 78 |
+
self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
|
| 79 |
+
], dim=2)
|
| 80 |
+
x = x.reshape(b, -1, c)
|
| 81 |
+
|
| 82 |
+
if x2 is not None:
|
| 83 |
+
h2, w2 = size2
|
| 84 |
+
x2 = x2.reshape(x2.shape[0], h2, w2, -1)
|
| 85 |
+
x2 = self.proj1(x2)
|
| 86 |
+
x2 = self.pooler(x2, forward_type='2x')
|
| 87 |
+
x2 = self.act(x2)
|
| 88 |
+
x2 = self.proj2(x2)
|
| 89 |
+
|
| 90 |
+
b2, h2, w2, c2 = x2.shape
|
| 91 |
+
x2 = torch.cat([
|
| 92 |
+
x2,
|
| 93 |
+
self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype)
|
| 94 |
+
], dim=2)
|
| 95 |
+
x2 = x2.reshape(b, -1, c)
|
| 96 |
+
sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
|
| 97 |
+
x = torch.cat([x, sep, x2], dim=1)
|
| 98 |
+
|
| 99 |
+
begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
|
| 100 |
+
end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
|
| 101 |
+
x = torch.cat([begin, x, end], dim=1)
|
| 102 |
+
return x
|
| 103 |
+
elif modalities in ['video']:
|
| 104 |
+
# x2 is the true feature, ignore x
|
| 105 |
+
h, w = size
|
| 106 |
+
dtype = x.dtype
|
| 107 |
+
x = x.reshape(x.shape[0], h, w, -1)
|
| 108 |
+
x1 = self.proj1(x)
|
| 109 |
+
x1 = self.pooler(x1, forward_type='2x')
|
| 110 |
+
x1 = self.proj2(x1).mean() * 0.0
|
| 111 |
+
|
| 112 |
+
h2, w2 = size2
|
| 113 |
+
x2 = x2.reshape(x2.shape[0], h2, w2, -1)
|
| 114 |
+
x2 = self.proj1(x2)
|
| 115 |
+
x2 = self.pooler(x2, forward_type='2x')
|
| 116 |
+
x2 = self.act(x2)
|
| 117 |
+
x2 = self.proj2(x2)
|
| 118 |
+
|
| 119 |
+
b2, h2, w2, c = x2.shape
|
| 120 |
+
x2 = torch.cat([
|
| 121 |
+
x2,
|
| 122 |
+
self.image_newline.reshape(1, 1, 1, c).expand(b2, h2, 1, c).to(dtype)
|
| 123 |
+
], dim=2)
|
| 124 |
+
|
| 125 |
+
x2 = x2.reshape(b2, -1, c)
|
| 126 |
+
|
| 127 |
+
sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, c).to(dtype)
|
| 128 |
+
x2 = torch.cat([x2, sep], dim=1)
|
| 129 |
+
|
| 130 |
+
x2 = x2.flatten(0, 1)
|
| 131 |
+
|
| 132 |
+
begin = self.image_begin.reshape(1, -1).expand(1, c).to(dtype)
|
| 133 |
+
end = self.image_end.reshape(1, -1).expand(1, c).to(dtype)
|
| 134 |
+
x2 = torch.cat([begin, x2, end], dim=0)
|
| 135 |
+
x2 = x2.unsqueeze(0)
|
| 136 |
+
return x2
|
| 137 |
+
else:
|
| 138 |
+
raise ValueError(f'Unknown modalities: {modalities}')
|
| 139 |
+
|
| 140 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
| 141 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
| 142 |
+
|
| 143 |
+
if projector_type == 'linear':
|
| 144 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
| 145 |
+
|
| 146 |
+
elif projector_type == 'ola_mlp':
|
| 147 |
+
return OlaMLP(config.mm_hidden_size, config.hidden_size, twoview=True)
|
| 148 |
+
|
| 149 |
+
elif projector_type == 'ola_internvl':
|
| 150 |
+
# breakpoint()
|
| 151 |
+
return InternVLMultiModalProjector(config)
|
| 152 |
+
|
| 153 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
| 154 |
+
if mlp_gelu_match:
|
| 155 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
| 156 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
| 157 |
+
for _ in range(1, mlp_depth):
|
| 158 |
+
modules.append(nn.GELU())
|
| 159 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
| 160 |
+
return nn.Sequential(*modules)
|
| 161 |
+
|
| 162 |
+
mlp_gelu_resnet_match = re.match(r'^mlp(\d+)x_res(\d+)x_gelu$', projector_type)
|
| 163 |
+
if mlp_gelu_resnet_match:
|
| 164 |
+
mlp_depth = int(mlp_gelu_resnet_match.group(1))
|
| 165 |
+
res_depth = int(mlp_gelu_resnet_match.group(2))
|
| 166 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
| 167 |
+
for _ in range(1, mlp_depth):
|
| 168 |
+
modules.append(nn.GELU())
|
| 169 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
| 170 |
+
for _ in range(res_depth):
|
| 171 |
+
modules.append(SimpleResBlock(config.hidden_size))
|
| 172 |
+
return nn.Sequential(*modules)
|
| 173 |
+
|
| 174 |
+
if projector_type == 'identity':
|
| 175 |
+
return IdentityMap()
|
| 176 |
+
|
| 177 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|