Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- phivenv/Lib/site-packages/transformers/models/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/modeling_flax_roformer.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/modeling_roformer.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/modeling_tf_roformer.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/tokenization_roformer.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/tokenization_roformer_fast.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/tokenization_utils.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/__init__.py +33 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/configuration_rt_detr.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/configuration_rt_detr_resnet.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/image_processing_rt_detr.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/image_processing_rt_detr_fast.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/modeling_rt_detr.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/modeling_rt_detr_resnet.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/modular_rt_detr.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/configuration_rt_detr.py +372 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/configuration_rt_detr_resnet.py +114 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/image_processing_rt_detr.py +1103 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/image_processing_rt_detr_fast.py +590 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/modeling_rt_detr.py +2013 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/modeling_rt_detr_resnet.py +399 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr/modular_rt_detr.py +365 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__init__.py +29 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/configuration_rt_detr_v2.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/modeling_rt_detr_v2.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/modular_rt_detr_v2.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +387 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +1998 -0
- phivenv/Lib/site-packages/transformers/models/rt_detr_v2/modular_rt_detr_v2.py +636 -0
- phivenv/Lib/site-packages/transformers/models/rwkv/__init__.py +27 -0
- phivenv/Lib/site-packages/transformers/models/rwkv/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rwkv/__pycache__/configuration_rwkv.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rwkv/__pycache__/modeling_rwkv.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/rwkv/configuration_rwkv.py +120 -0
- phivenv/Lib/site-packages/transformers/models/rwkv/modeling_rwkv.py +798 -0
- phivenv/Lib/site-packages/transformers/models/sam/__init__.py +31 -0
- phivenv/Lib/site-packages/transformers/models/sam/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/sam/__pycache__/configuration_sam.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/sam/__pycache__/image_processing_sam.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/sam/__pycache__/image_processing_sam_fast.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/sam/__pycache__/modeling_sam.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/sam/__pycache__/modeling_tf_sam.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/sam/__pycache__/processing_sam.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/sam/configuration_sam.py +337 -0
- phivenv/Lib/site-packages/transformers/models/sam/image_processing_sam.py +1499 -0
- phivenv/Lib/site-packages/transformers/models/sam/image_processing_sam_fast.py +829 -0
- phivenv/Lib/site-packages/transformers/models/sam/modeling_sam.py +1368 -0
phivenv/Lib/site-packages/transformers/models/__pycache__/__init__.cpython-39.pyc
ADDED: Binary file (7.91 kB)

phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/__init__.cpython-39.pyc
ADDED: Binary file (671 Bytes)

phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/modeling_flax_roformer.cpython-39.pyc
ADDED: Binary file (28.7 kB)

phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/modeling_roformer.cpython-39.pyc
ADDED: Binary file (41.7 kB)

phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/modeling_tf_roformer.cpython-39.pyc
ADDED: Binary file (47 kB)

phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/tokenization_roformer.cpython-39.pyc
ADDED: Binary file (16.9 kB)

phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/tokenization_roformer_fast.cpython-39.pyc
ADDED: Binary file (4.66 kB)

phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/tokenization_utils.cpython-39.pyc
ADDED: Binary file (1.62 kB)
phivenv/Lib/site-packages/transformers/models/rt_detr/__init__.py
ADDED
@@ -0,0 +1,33 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_rt_detr import *
    from .configuration_rt_detr_resnet import *
    from .image_processing_rt_detr import *
    from .image_processing_rt_detr_fast import *
    from .modeling_rt_detr import *
    from .modeling_rt_detr_resnet import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/__init__.cpython-39.pyc
ADDED: Binary file (683 Bytes)

phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/configuration_rt_detr.cpython-39.pyc
ADDED: Binary file (15.1 kB)

phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/configuration_rt_detr_resnet.cpython-39.pyc
ADDED: Binary file (5.02 kB)

phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/image_processing_rt_detr.cpython-39.pyc
ADDED: Binary file (39.6 kB)

phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/image_processing_rt_detr_fast.cpython-39.pyc
ADDED: Binary file (18.4 kB)

phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/modeling_rt_detr.cpython-39.pyc
ADDED: Binary file (64.1 kB)

phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/modeling_rt_detr_resnet.cpython-39.pyc
ADDED: Binary file (11.9 kB)

phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/modular_rt_detr.cpython-39.pyc
ADDED: Binary file (11.3 kB)
phivenv/Lib/site-packages/transformers/models/rt_detr/configuration_rt_detr.py
ADDED
@@ -0,0 +1,372 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RT-DETR model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING
from .configuration_rt_detr_resnet import RTDetrResNetConfig


logger = logging.get_logger(__name__)


class RTDetrConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RTDetrModel`]. It is used to instantiate a
    RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the RT-DETR
    [PekingU/rtdetr_r50vd](https://huggingface.co/PekingU/rtdetr_r50vd) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        initializer_range (`float`, *optional*, defaults to 0.01):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_bias_prior_prob (`float`, *optional*):
            The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
            If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the batch normalization layers.
        backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
            The configuration of the backbone model.
        backbone (`str`, *optional*):
            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
            Whether to use pretrained weights for the backbone.
        use_timm_backbone (`bool`, *optional*, defaults to `False`):
            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
            library.
        freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
            Whether to freeze the batch normalization layers in the backbone.
        backbone_kwargs (`dict`, *optional*):
            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
        encoder_hidden_dim (`int`, *optional*, defaults to 256):
            Dimension of the layers in hybrid encoder.
        encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
            Multi level features input for encoder.
        feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
            Strides used in each feature map.
        encoder_layers (`int`, *optional*, defaults to 1):
            Total of layers to be used by the encoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 1024):
            Dimension of the "intermediate" (often named feed-forward) layer in decoder.
        encoder_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        dropout (`float`, *optional*, defaults to 0.0):
            The ratio for all dropout layers.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
            Indexes of the projected layers to be used in the encoder.
        positional_encoding_temperature (`int`, *optional*, defaults to 10000):
            The temperature parameter used to create the positional encodings.
        encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        activation_function (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        eval_size (`tuple[int, int]`, *optional*):
            Height and width used to computes the effective height and width of the position embeddings after taking
            into account the stride.
        normalize_before (`bool`, *optional*, defaults to `False`):
            Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
            feed-forward modules.
        hidden_expansion (`float`, *optional*, defaults to 1.0):
            Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
        d_model (`int`, *optional*, defaults to 256):
            Dimension of the layers exclude hybrid encoder.
        num_queries (`int`, *optional*, defaults to 300):
            Number of object queries.
        decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
            Multi level features dimension for decoder
        decoder_ffn_dim (`int`, *optional*, defaults to 1024):
            Dimension of the "intermediate" (often named feed-forward) layer in decoder.
        num_feature_levels (`int`, *optional*, defaults to 3):
            The number of input feature levels.
        decoder_n_points (`int`, *optional*, defaults to 4):
            The number of sampled keys in each feature level for each attention head in the decoder.
        decoder_layers (`int`, *optional*, defaults to 6):
            Number of decoder layers.
        decoder_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
            The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_denoising (`int`, *optional*, defaults to 100):
            The total number of denoising tasks or queries to be used for contrastive denoising.
        label_noise_ratio (`float`, *optional*, defaults to 0.5):
            The fraction of denoising labels to which random noise should be added.
        box_noise_scale (`float`, *optional*, defaults to 1.0):
            Scale or magnitude of noise to be added to the bounding boxes.
        learn_initial_query (`bool`, *optional*, defaults to `False`):
            Indicates whether the initial query embeddings for the decoder should be learned during training
        anchor_image_size (`tuple[int, int]`, *optional*):
            Height and width of the input image used during evaluation to generate the bounding box anchors. If None, automatic generate anchor is applied.
        disable_custom_kernels (`bool`, *optional*, defaults to `True`):
            Whether to disable custom kernels.
        with_box_refine (`bool`, *optional*, defaults to `True`):
            Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
            based on the predictions from the previous layer.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Whether the architecture has an encoder decoder structure.
        matcher_alpha (`float`, *optional*, defaults to 0.25):
            Parameter alpha used by the Hungarian Matcher.
        matcher_gamma (`float`, *optional*, defaults to 2.0):
            Parameter gamma used by the Hungarian Matcher.
        matcher_class_cost (`float`, *optional*, defaults to 2.0):
            The relative weight of the class loss used by the Hungarian Matcher.
        matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
            The relative weight of the bounding box loss used by the Hungarian Matcher.
        matcher_giou_cost (`float`, *optional*, defaults to 2.0):
            The relative weight of the giou loss of used by the Hungarian Matcher.
        use_focal_loss (`bool`, *optional*, defaults to `True`):
            Parameter informing if focal focal should be used.
        auxiliary_loss (`bool`, *optional*, defaults to `True`):
            Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
        focal_loss_alpha (`float`, *optional*, defaults to 0.75):
            Parameter alpha used to compute the focal loss.
        focal_loss_gamma (`float`, *optional*, defaults to 2.0):
            Parameter gamma used to compute the focal loss.
        weight_loss_vfl (`float`, *optional*, defaults to 1.0):
            Relative weight of the varifocal loss in the object detection loss.
        weight_loss_bbox (`float`, *optional*, defaults to 5.0):
            Relative weight of the L1 bounding box loss in the object detection loss.
        weight_loss_giou (`float`, *optional*, defaults to 2.0):
            Relative weight of the generalized IoU loss in the object detection loss.
        eos_coefficient (`float`, *optional*, defaults to 0.0001):
            Relative classification weight of the 'no-object' class in the object detection loss.

    Examples:

    ```python
    >>> from transformers import RTDetrConfig, RTDetrModel

    >>> # Initializing a RT-DETR configuration
    >>> configuration = RTDetrConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = RTDetrModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "rt_detr"
    layer_types = ["basic", "bottleneck"]
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "encoder_attention_heads",
    }

    def __init__(
        self,
        initializer_range=0.01,
        initializer_bias_prior_prob=None,
        layer_norm_eps=1e-5,
        batch_norm_eps=1e-5,
        # backbone
        backbone_config=None,
        backbone=None,
        use_pretrained_backbone=False,
        use_timm_backbone=False,
        freeze_backbone_batch_norms=True,
        backbone_kwargs=None,
        # encoder HybridEncoder
        encoder_hidden_dim=256,
        encoder_in_channels=[512, 1024, 2048],
        feat_strides=[8, 16, 32],
        encoder_layers=1,
        encoder_ffn_dim=1024,
        encoder_attention_heads=8,
        dropout=0.0,
        activation_dropout=0.0,
        encode_proj_layers=[2],
        positional_encoding_temperature=10000,
        encoder_activation_function="gelu",
        activation_function="silu",
        eval_size=None,
        normalize_before=False,
        hidden_expansion=1.0,
        # decoder RTDetrTransformer
        d_model=256,
        num_queries=300,
        decoder_in_channels=[256, 256, 256],
        decoder_ffn_dim=1024,
        num_feature_levels=3,
        decoder_n_points=4,
        decoder_layers=6,
        decoder_attention_heads=8,
        decoder_activation_function="relu",
        attention_dropout=0.0,
        num_denoising=100,
        label_noise_ratio=0.5,
        box_noise_scale=1.0,
        learn_initial_query=False,
        anchor_image_size=None,
        disable_custom_kernels=True,
        with_box_refine=True,
        is_encoder_decoder=True,
        # Loss
        matcher_alpha=0.25,
        matcher_gamma=2.0,
        matcher_class_cost=2.0,
        matcher_bbox_cost=5.0,
        matcher_giou_cost=2.0,
        use_focal_loss=True,
        auxiliary_loss=True,
        focal_loss_alpha=0.75,
        focal_loss_gamma=2.0,
        weight_loss_vfl=1.0,
        weight_loss_bbox=5.0,
        weight_loss_giou=2.0,
        eos_coefficient=1e-4,
        **kwargs,
    ):
        self.initializer_range = initializer_range
        self.initializer_bias_prior_prob = initializer_bias_prior_prob
        self.layer_norm_eps = layer_norm_eps
        self.batch_norm_eps = batch_norm_eps
        # backbone
        if backbone_config is None and backbone is None:
            logger.info(
                "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetr-ResNet` backbone."
            )
            backbone_config = RTDetrResNetConfig(
                num_channels=3,
                embedding_size=64,
                hidden_sizes=[256, 512, 1024, 2048],
                depths=[3, 4, 6, 3],
                layer_type="bottleneck",
                hidden_act="relu",
                downsample_in_first_stage=False,
                downsample_in_bottleneck=False,
                out_features=None,
                out_indices=[2, 3, 4],
            )
        elif isinstance(backbone_config, dict):
            backbone_model_type = backbone_config.pop("model_type")
            config_class = CONFIG_MAPPING[backbone_model_type]
            backbone_config = config_class.from_dict(backbone_config)

        verify_backbone_config_arguments(
            use_timm_backbone=use_timm_backbone,
            use_pretrained_backbone=use_pretrained_backbone,
            backbone=backbone,
            backbone_config=backbone_config,
            backbone_kwargs=backbone_kwargs,
        )

        self.backbone_config = backbone_config
        self.backbone = backbone
        self.use_pretrained_backbone = use_pretrained_backbone
        self.use_timm_backbone = use_timm_backbone
        self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
        self.backbone_kwargs = backbone_kwargs
        # encoder
        self.encoder_hidden_dim = encoder_hidden_dim
        self.encoder_in_channels = encoder_in_channels
        self.feat_strides = feat_strides
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        self.encode_proj_layers = encode_proj_layers
        self.encoder_layers = encoder_layers
        self.positional_encoding_temperature = positional_encoding_temperature
        self.eval_size = eval_size
        self.normalize_before = normalize_before
        self.encoder_activation_function = encoder_activation_function
        self.activation_function = activation_function
        self.hidden_expansion = hidden_expansion
        # decoder
        self.d_model = d_model
        self.num_queries = num_queries
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_in_channels = decoder_in_channels
        self.num_feature_levels = num_feature_levels
        self.decoder_n_points = decoder_n_points
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_activation_function = decoder_activation_function
        self.attention_dropout = attention_dropout
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.learn_initial_query = learn_initial_query
        self.anchor_image_size = anchor_image_size
        self.auxiliary_loss = auxiliary_loss
        self.disable_custom_kernels = disable_custom_kernels
        self.with_box_refine = with_box_refine
        # Loss
        self.matcher_alpha = matcher_alpha
        self.matcher_gamma = matcher_gamma
        self.matcher_class_cost = matcher_class_cost
        self.matcher_bbox_cost = matcher_bbox_cost
        self.matcher_giou_cost = matcher_giou_cost
        self.use_focal_loss = use_focal_loss
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma
        self.weight_loss_vfl = weight_loss_vfl
        self.weight_loss_bbox = weight_loss_bbox
        self.weight_loss_giou = weight_loss_giou
        self.eos_coefficient = eos_coefficient
        super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

    @property
    def num_attention_heads(self) -> int:
        return self.encoder_attention_heads

    @property
    def hidden_size(self) -> int:
        return self.d_model

    @property
    def sub_configs(self):
        return (
            {"backbone_config": type(self.backbone_config)}
            if getattr(self, "backbone_config", None) is not None
            else {}
        )

    @classmethod
    def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
        """Instantiate a [`RTDetrConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
        configuration.

        Args:
            backbone_config ([`PretrainedConfig`]):
                The backbone configuration.

        Returns:
            [`RTDetrConfig`]: An instance of a configuration object
        """
        return cls(
            backbone_config=backbone_config,
            **kwargs,
        )


__all__ = ["RTDetrConfig"]
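
A minimal sketch of the backbone-resolution behaviour documented above (a default `RTDetrResNetConfig` is built when no backbone arguments are given, and dict configs are resolved through `CONFIG_MAPPING`), assuming RT-DETR is available in the installed transformers; the values are illustrative only:

# Backbone resolution sketch: default backbone vs. an explicit RTDetrResNetConfig.
from transformers import RTDetrConfig, RTDetrResNetConfig

default_cfg = RTDetrConfig()                        # falls back to the ResNet-50-style backbone
print(type(default_cfg.backbone_config).__name__)   # RTDetrResNetConfig

custom_backbone = RTDetrResNetConfig(depths=[2, 2, 2, 2], layer_type="basic")
custom_cfg = RTDetrConfig.from_backbone_configs(custom_backbone, num_queries=100)
print(custom_cfg.num_queries, custom_cfg.backbone_config.depths)  # 100 [2, 2, 2, 2]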
phivenv/Lib/site-packages/transformers/models/rt_detr/configuration_rt_detr_resnet.py
ADDED
@@ -0,0 +1,114 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RT-DETR ResNet model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices


logger = logging.get_logger(__name__)


class RTDetrResNetConfig(BackboneConfigMixin, PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RTDetrResnetBackbone`]. It is used to instantiate an
    ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the ResNet
    [microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        embedding_size (`int`, *optional*, defaults to 64):
            Dimensionality (hidden size) for the embedding layer.
        hidden_sizes (`list[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`):
            Dimensionality (hidden size) at each stage.
        depths (`list[int]`, *optional*, defaults to `[3, 4, 6, 3]`):
            Depth (number of layers) for each stage.
        layer_type (`str`, *optional*, defaults to `"bottleneck"`):
            The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or
            `"bottleneck"` (used for larger models like resnet-50 and above).
        hidden_act (`str`, *optional*, defaults to `"relu"`):
            The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
            are supported.
        downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
            If `True`, the first stage will downsample the inputs using a `stride` of 2.
        downsample_in_bottleneck (`bool`, *optional*, defaults to `False`):
            If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2.
        out_features (`list[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        out_indices (`list[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
            If unset and `out_features` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.

    Example:
    ```python
    >>> from transformers import RTDetrResNetConfig, RTDetrResnetBackbone

    >>> # Initializing a ResNet resnet-50 style configuration
    >>> configuration = RTDetrResNetConfig()

    >>> # Initializing a model (with random weights) from the resnet-50 style configuration
    >>> model = RTDetrResnetBackbone(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "rt_detr_resnet"
    layer_types = ["basic", "bottleneck"]

    def __init__(
        self,
        num_channels=3,
        embedding_size=64,
        hidden_sizes=[256, 512, 1024, 2048],
        depths=[3, 4, 6, 3],
        layer_type="bottleneck",
        hidden_act="relu",
        downsample_in_first_stage=False,
        downsample_in_bottleneck=False,
        out_features=None,
        out_indices=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if layer_type not in self.layer_types:
            raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}")
        self.num_channels = num_channels
        self.embedding_size = embedding_size
        self.hidden_sizes = hidden_sizes
        self.depths = depths
        self.layer_type = layer_type
        self.hidden_act = hidden_act
        self.downsample_in_first_stage = downsample_in_first_stage
        self.downsample_in_bottleneck = downsample_in_bottleneck
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )


__all__ = ["RTDetrResNetConfig"]
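
A short, hedged illustration of how `out_indices` is aligned to `stage_names` via `get_aligned_output_features_output_indices` in the class above, assuming the class is importable; the indices [2, 3, 4] mirror the defaults `RTDetrConfig` passes for its backbone:

# out_features/out_indices alignment sketch for the backbone config.
from transformers import RTDetrResNetConfig

cfg = RTDetrResNetConfig(out_indices=[2, 3, 4])
print(cfg.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(cfg.out_features)  # aligned to the indices: ['stage2', 'stage3', 'stage4']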
phivenv/Lib/site-packages/transformers/models/rt_detr/image_processing_rt_detr.py
ADDED
|
@@ -0,0 +1,1103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for RT-DETR."""
|
| 16 |
+
|
| 17 |
+
import pathlib
|
| 18 |
+
from collections.abc import Iterable
|
| 19 |
+
from typing import Any, Callable, Optional, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
from ...feature_extraction_utils import BatchFeature
|
| 24 |
+
from ...image_processing_utils import BaseImageProcessor, get_size_dict
|
| 25 |
+
from ...image_transforms import (
|
| 26 |
+
PaddingMode,
|
| 27 |
+
center_to_corners_format,
|
| 28 |
+
corners_to_center_format,
|
| 29 |
+
pad,
|
| 30 |
+
rescale,
|
| 31 |
+
resize,
|
| 32 |
+
to_channel_dimension_format,
|
| 33 |
+
)
|
| 34 |
+
from ...image_utils import (
|
| 35 |
+
IMAGENET_DEFAULT_MEAN,
|
| 36 |
+
IMAGENET_DEFAULT_STD,
|
| 37 |
+
AnnotationFormat,
|
| 38 |
+
AnnotationType,
|
| 39 |
+
ChannelDimension,
|
| 40 |
+
ImageInput,
|
| 41 |
+
PILImageResampling,
|
| 42 |
+
get_image_size,
|
| 43 |
+
infer_channel_dimension_format,
|
| 44 |
+
is_scaled_image,
|
| 45 |
+
make_list_of_images,
|
| 46 |
+
to_numpy_array,
|
| 47 |
+
valid_images,
|
| 48 |
+
validate_annotations,
|
| 49 |
+
validate_preprocess_arguments,
|
| 50 |
+
)
|
| 51 |
+
from ...utils import (
|
| 52 |
+
filter_out_non_signature_kwargs,
|
| 53 |
+
is_flax_available,
|
| 54 |
+
is_jax_tensor,
|
| 55 |
+
is_tf_available,
|
| 56 |
+
is_tf_tensor,
|
| 57 |
+
is_torch_available,
|
| 58 |
+
is_torch_tensor,
|
| 59 |
+
logging,
|
| 60 |
+
requires_backends,
|
| 61 |
+
)
|
| 62 |
+
from ...utils.generic import TensorType
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if is_torch_available():
|
| 66 |
+
import torch
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 70 |
+
|
| 71 |
+
SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio
|
| 75 |
+
def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]:
|
| 76 |
+
"""
|
| 77 |
+
Computes the output image size given the input image size and the desired output size.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
image_size (`tuple[int, int]`):
|
| 81 |
+
The input image size.
|
| 82 |
+
size (`int`):
|
| 83 |
+
The desired output size.
|
| 84 |
+
max_size (`int`, *optional*):
|
| 85 |
+
The maximum allowed output size.
|
| 86 |
+
"""
|
| 87 |
+
height, width = image_size
|
| 88 |
+
raw_size = None
|
| 89 |
+
if max_size is not None:
|
| 90 |
+
min_original_size = float(min((height, width)))
|
| 91 |
+
max_original_size = float(max((height, width)))
|
| 92 |
+
if max_original_size / min_original_size * size > max_size:
|
| 93 |
+
raw_size = max_size * min_original_size / max_original_size
|
| 94 |
+
size = int(round(raw_size))
|
| 95 |
+
|
| 96 |
+
if (height <= width and height == size) or (width <= height and width == size):
|
| 97 |
+
oh, ow = height, width
|
| 98 |
+
elif width < height:
|
| 99 |
+
ow = size
|
| 100 |
+
if max_size is not None and raw_size is not None:
|
| 101 |
+
oh = int(raw_size * height / width)
|
| 102 |
+
else:
|
| 103 |
+
oh = int(size * height / width)
|
| 104 |
+
else:
|
| 105 |
+
oh = size
|
| 106 |
+
if max_size is not None and raw_size is not None:
|
| 107 |
+
ow = int(raw_size * width / height)
|
| 108 |
+
else:
|
| 109 |
+
ow = int(size * width / height)
|
| 110 |
+
|
| 111 |
+
return (oh, ow)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
|
| 115 |
+
def get_resize_output_image_size(
|
| 116 |
+
input_image: np.ndarray,
|
| 117 |
+
size: Union[int, tuple[int, int], list[int]],
|
| 118 |
+
max_size: Optional[int] = None,
|
| 119 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 120 |
+
) -> tuple[int, int]:
|
| 121 |
+
"""
|
| 122 |
+
Computes the output image size given the input image size and the desired output size. If the desired output size
|
| 123 |
+
is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output
|
| 124 |
+
image size is computed by keeping the aspect ratio of the input image size.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
input_image (`np.ndarray`):
|
| 128 |
+
The image to resize.
|
| 129 |
+
size (`int` or `tuple[int, int]` or `list[int]`):
|
| 130 |
+
The desired output size.
|
| 131 |
+
max_size (`int`, *optional*):
|
| 132 |
+
The maximum allowed output size.
|
| 133 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 134 |
+
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
| 135 |
+
"""
|
| 136 |
+
image_size = get_image_size(input_image, input_data_format)
|
| 137 |
+
if isinstance(size, (list, tuple)):
|
| 138 |
+
return size
|
| 139 |
+
|
| 140 |
+
return get_size_with_aspect_ratio(image_size, size, max_size)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# Copied from transformers.models.detr.image_processing_detr.get_image_size_for_max_height_width
|
| 144 |
+
def get_image_size_for_max_height_width(
|
| 145 |
+
input_image: np.ndarray,
|
| 146 |
+
max_height: int,
|
| 147 |
+
max_width: int,
|
| 148 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 149 |
+
) -> tuple[int, int]:
|
| 150 |
+
"""
|
| 151 |
+
Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
|
| 152 |
+
Important, even if image_height < max_height and image_width < max_width, the image will be resized
|
| 153 |
+
to at least one of the edges be equal to max_height or max_width.
|
| 154 |
+
For example:
|
| 155 |
+
- input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
|
| 156 |
+
- input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
|
| 157 |
+
Args:
|
| 158 |
+
input_image (`np.ndarray`):
|
| 159 |
+
The image to resize.
|
| 160 |
+
max_height (`int`):
|
| 161 |
+
The maximum allowed height.
|
| 162 |
+
max_width (`int`):
|
| 163 |
+
The maximum allowed width.
|
| 164 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 165 |
+
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
| 166 |
+
"""
|
| 167 |
+
image_size = get_image_size(input_image, input_data_format)
|
| 168 |
+
height, width = image_size
|
| 169 |
+
height_scale = max_height / height
|
| 170 |
+
width_scale = max_width / width
|
| 171 |
+
min_scale = min(height_scale, width_scale)
|
| 172 |
+
new_height = int(height * min_scale)
|
| 173 |
+
new_width = int(width * min_scale)
|
| 174 |
+
return new_height, new_width
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn
|
| 178 |
+
def get_numpy_to_framework_fn(arr) -> Callable:
|
| 179 |
+
"""
|
| 180 |
+
Returns a function that converts a numpy array to the framework of the input array.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
arr (`np.ndarray`): The array to convert.
|
| 184 |
+
"""
|
| 185 |
+
if isinstance(arr, np.ndarray):
|
| 186 |
+
return np.array
|
| 187 |
+
if is_tf_available() and is_tf_tensor(arr):
|
| 188 |
+
import tensorflow as tf
|
| 189 |
+
|
| 190 |
+
return tf.convert_to_tensor
|
| 191 |
+
if is_torch_available() and is_torch_tensor(arr):
|
| 192 |
+
import torch
|
| 193 |
+
|
| 194 |
+
return torch.tensor
|
| 195 |
+
if is_flax_available() and is_jax_tensor(arr):
|
| 196 |
+
import jax.numpy as jnp
|
| 197 |
+
|
| 198 |
+
return jnp.array
|
| 199 |
+
raise ValueError(f"Cannot convert arrays of type {type(arr)}")
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# Copied from transformers.models.detr.image_processing_detr.safe_squeeze
|
| 203 |
+
def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
|
| 204 |
+
"""
|
| 205 |
+
Squeezes an array, but only if the axis specified has dim 1.
|
| 206 |
+
"""
|
| 207 |
+
if axis is None:
|
| 208 |
+
return arr.squeeze()
|
| 209 |
+
|
| 210 |
+
try:
|
| 211 |
+
return arr.squeeze(axis=axis)
|
| 212 |
+
except ValueError:
|
| 213 |
+
return arr
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# Copied from transformers.models.detr.image_processing_detr.normalize_annotation
|
| 217 |
+
def normalize_annotation(annotation: dict, image_size: tuple[int, int]) -> dict:
|
| 218 |
+
image_height, image_width = image_size
|
| 219 |
+
norm_annotation = {}
|
| 220 |
+
for key, value in annotation.items():
|
| 221 |
+
if key == "boxes":
|
| 222 |
+
boxes = value
|
| 223 |
+
boxes = corners_to_center_format(boxes)
|
| 224 |
+
boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32)
|
| 225 |
+
norm_annotation[key] = boxes
|
| 226 |
+
else:
|
| 227 |
+
norm_annotation[key] = value
|
| 228 |
+
return norm_annotation
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# Copied from transformers.models.detr.image_processing_detr.max_across_indices
|
| 232 |
+
def max_across_indices(values: Iterable[Any]) -> list[Any]:
|
| 233 |
+
"""
|
| 234 |
+
Return the maximum value across all indices of an iterable of values.
|
| 235 |
+
"""
|
| 236 |
+
return [max(values_i) for values_i in zip(*values)]
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
|
| 240 |
+
def get_max_height_width(
|
| 241 |
+
images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
| 242 |
+
) -> list[int]:
|
| 243 |
+
"""
|
| 244 |
+
Get the maximum height and width across all images in a batch.
|
| 245 |
+
"""
|
| 246 |
+
if input_data_format is None:
|
| 247 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 248 |
+
|
| 249 |
+
if input_data_format == ChannelDimension.FIRST:
|
| 250 |
+
_, max_height, max_width = max_across_indices([img.shape for img in images])
|
| 251 |
+
elif input_data_format == ChannelDimension.LAST:
|
| 252 |
+
max_height, max_width, _ = max_across_indices([img.shape for img in images])
|
| 253 |
+
else:
|
| 254 |
+
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
|
| 255 |
+
return (max_height, max_width)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
|
| 259 |
+
def make_pixel_mask(
|
| 260 |
+
image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
| 261 |
+
) -> np.ndarray:
|
| 262 |
+
"""
|
| 263 |
+
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
image (`np.ndarray`):
|
| 267 |
+
Image to make the pixel mask for.
|
| 268 |
+
output_size (`tuple[int, int]`):
|
| 269 |
+
Output size of the mask.
|
| 270 |
+
"""
|
| 271 |
+
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
| 272 |
+
mask = np.zeros(output_size, dtype=np.int64)
|
| 273 |
+
mask[:input_height, :input_width] = 1
|
| 274 |
+
return mask
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def prepare_coco_detection_annotation(
    image,
    target,
    return_segmentation_masks: bool = False,
    input_data_format: Optional[Union[ChannelDimension, str]] = None,
):
    """
    Convert the target in COCO format into the format expected by RTDETR.
    """
    image_height, image_width = get_image_size(image, channel_dim=input_data_format)

    image_id = target["image_id"]
    image_id = np.asarray([image_id], dtype=np.int64)

    # Get all COCO annotations for the given image.
    annotations = target["annotations"]
    annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0]

    classes = [obj["category_id"] for obj in annotations]
    classes = np.asarray(classes, dtype=np.int64)

    # for conversion to coco api
    area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32)
    iscrowd = np.asarray([obj.get("iscrowd", 0) for obj in annotations], dtype=np.int64)

    boxes = [obj["bbox"] for obj in annotations]
    # guard against no boxes via resizing
    boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
    boxes[:, 2:] += boxes[:, :2]
    boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
    boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)

    keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

    new_target = {}
    new_target["image_id"] = image_id
    new_target["class_labels"] = classes[keep]
    new_target["boxes"] = boxes[keep]
    new_target["area"] = area[keep]
    new_target["iscrowd"] = iscrowd[keep]
    new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64)

    if annotations and "keypoints" in annotations[0]:
        keypoints = [obj["keypoints"] for obj in annotations]
        # Converting the filtered keypoints list to a numpy array
        keypoints = np.asarray(keypoints, dtype=np.float32)
        # Apply the keep mask here to filter the relevant annotations
        keypoints = keypoints[keep]
        num_keypoints = keypoints.shape[0]
        keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
        new_target["keypoints"] = keypoints

    return new_target

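# --- Illustrative sketch (not part of the upstream file): the box handling inside
# `prepare_coco_detection_annotation`. COCO stores boxes as `[x, y, width, height]`; they are turned
# into `[x_min, y_min, x_max, y_max]`, clipped to the image, and degenerate boxes are dropped via the
# `keep` mask. The function name `_example_coco_boxes` is hypothetical.
def _example_coco_boxes():
    image_width, image_height = 640, 480
    boxes = np.asarray([[10, 20, 30, 40], [650, 500, 10, 10]], dtype=np.float32).reshape(-1, 4)
    boxes[:, 2:] += boxes[:, :2]  # xywh -> xyxy
    boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
    boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
    keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
    assert np.allclose(boxes[0], [10, 20, 40, 60])
    assert keep.tolist() == [True, False]  # the second box lies fully outside the image and is dropped
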
# Copied from transformers.models.detr.image_processing_detr.resize_annotation
def resize_annotation(
    annotation: dict[str, Any],
    orig_size: tuple[int, int],
    target_size: tuple[int, int],
    threshold: float = 0.5,
    resample: PILImageResampling = PILImageResampling.NEAREST,
):
    """
    Resizes an annotation to a target size.

    Args:
        annotation (`dict[str, Any]`):
            The annotation dictionary.
        orig_size (`tuple[int, int]`):
            The original size of the input image.
        target_size (`tuple[int, int]`):
            The target size of the image, as returned by the preprocessing `resize` step.
        threshold (`float`, *optional*, defaults to 0.5):
            The threshold used to binarize the segmentation masks.
        resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`):
            The resampling filter to use when resizing the masks.
    """
    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size))
    ratio_height, ratio_width = ratios

    new_annotation = {}
    new_annotation["size"] = target_size

    for key, value in annotation.items():
        if key == "boxes":
            boxes = value
            scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
            new_annotation["boxes"] = scaled_boxes
        elif key == "area":
            area = value
            scaled_area = area * (ratio_width * ratio_height)
            new_annotation["area"] = scaled_area
        elif key == "masks":
            masks = value[:, None]
            masks = np.array([resize(mask, target_size, resample=resample) for mask in masks])
            masks = masks.astype(np.float32)
            masks = masks[:, 0] > threshold
            new_annotation["masks"] = masks
        elif key == "size":
            new_annotation["size"] = target_size
        else:
            new_annotation[key] = value

    return new_annotation

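# --- Illustrative sketch (not part of the upstream file): how `resize_annotation` rescales boxes and
# areas when an image goes from (480, 640) to the square (640, 640) used by RT-DETR. Boxes are scaled
# per-axis, areas by the product of the two ratios. The function name `_example_resize_annotation` is
# hypothetical.
def _example_resize_annotation():
    annotation = {
        "boxes": np.array([[10.0, 20.0, 110.0, 220.0]], dtype=np.float32),  # absolute xyxy pixels
        "area": np.array([1000.0], dtype=np.float32),
    }
    resized = resize_annotation(annotation, orig_size=(480, 640), target_size=(640, 640))
    # ratio_height = 640 / 480, ratio_width = 640 / 640 = 1.0
    assert np.allclose(resized["boxes"], [[10.0, 20.0 * 640 / 480, 110.0, 220.0 * 640 / 480]])
    assert np.isclose(resized["area"][0], 1000.0 * (640 / 480))
    assert resized["size"] == (640, 640)
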
class RTDetrImageProcessor(BaseImageProcessor):
|
| 385 |
+
r"""
|
| 386 |
+
Constructs a RT-DETR image processor.
|
| 387 |
+
|
| 388 |
+
Args:
|
| 389 |
+
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
|
| 390 |
+
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
|
| 391 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 392 |
+
Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
|
| 393 |
+
overridden by the `do_resize` parameter in the `preprocess` method.
|
| 394 |
+
size (`dict[str, int]` *optional*, defaults to `{"height": 640, "width": 640}`):
|
| 395 |
+
Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
|
| 396 |
+
in the `preprocess` method. Available options are:
|
| 397 |
+
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
|
| 398 |
+
Do NOT keep the aspect ratio.
|
| 399 |
+
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
|
| 400 |
+
the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
|
| 401 |
+
less or equal to `longest_edge`.
|
| 402 |
+
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
|
| 403 |
+
aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
|
| 404 |
+
`max_width`.
|
| 405 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
| 406 |
+
Resampling filter to use if resizing the image.
|
| 407 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 408 |
+
Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
|
| 409 |
+
`do_rescale` parameter in the `preprocess` method.
|
| 410 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 411 |
+
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
|
| 412 |
+
`preprocess` method.
|
| 413 |
+
Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
|
| 414 |
+
`preprocess` method.
|
| 415 |
+
do_normalize (`bool`, *optional*, defaults to `False`):
|
| 416 |
+
Whether to normalize the image.
|
| 417 |
+
image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
|
| 418 |
+
Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
|
| 419 |
+
channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
| 420 |
+
image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
|
| 421 |
+
Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
|
| 422 |
+
for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 423 |
+
do_convert_annotations (`bool`, *optional*, defaults to `True`):
|
| 424 |
+
Controls whether to convert the annotations to the format expected by the DETR model. Converts the
|
| 425 |
+
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
|
| 426 |
+
Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
|
| 427 |
+
do_pad (`bool`, *optional*, defaults to `False`):
|
| 428 |
+
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
|
| 429 |
+
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
|
| 430 |
+
If `pad_size` is provided, the image will be padded to the specified dimensions.
|
| 431 |
+
Otherwise, the image will be padded to the maximum height and width of the batch.
|
| 432 |
+
pad_size (`dict[str, int]`, *optional*):
|
| 433 |
+
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
|
| 434 |
+
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
|
| 435 |
+
height and width in the batch.
|
| 436 |
+
"""
|
| 437 |
+
|
| 438 |
+
model_input_names = ["pixel_values", "pixel_mask"]
|
| 439 |
+
|
| 440 |
+
def __init__(
|
| 441 |
+
self,
|
| 442 |
+
format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
|
| 443 |
+
do_resize: bool = True,
|
| 444 |
+
size: Optional[dict[str, int]] = None,
|
| 445 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 446 |
+
do_rescale: bool = True,
|
| 447 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 448 |
+
do_normalize: bool = False,
|
| 449 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 450 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 451 |
+
do_convert_annotations: bool = True,
|
| 452 |
+
do_pad: bool = False,
|
| 453 |
+
pad_size: Optional[dict[str, int]] = None,
|
| 454 |
+
**kwargs,
|
| 455 |
+
) -> None:
|
| 456 |
+
size = size if size is not None else {"height": 640, "width": 640}
|
| 457 |
+
size = get_size_dict(size, default_to_square=False)
|
| 458 |
+
|
| 459 |
+
if do_convert_annotations is None:
|
| 460 |
+
do_convert_annotations = do_normalize
|
| 461 |
+
|
| 462 |
+
super().__init__(**kwargs)
|
| 463 |
+
self.format = format
|
| 464 |
+
self.do_resize = do_resize
|
| 465 |
+
self.size = size
|
| 466 |
+
self.resample = resample
|
| 467 |
+
self.do_rescale = do_rescale
|
| 468 |
+
self.rescale_factor = rescale_factor
|
| 469 |
+
self.do_normalize = do_normalize
|
| 470 |
+
self.do_convert_annotations = do_convert_annotations
|
| 471 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
| 472 |
+
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
| 473 |
+
self.do_pad = do_pad
|
| 474 |
+
self.pad_size = pad_size
|
| 475 |
+
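    # --- Illustrative sketch (not part of the upstream file): constructing the processor directly
    # with the defaults documented in the class docstring above. In practice a checkpoint-specific
    # configuration is usually loaded with `RTDetrImageProcessor.from_pretrained(...)`; the keyword
    # arguments below simply restate the constructor defaults.
    #
    #     image_processor = RTDetrImageProcessor(
    #         do_resize=True,
    #         size={"height": 640, "width": 640},
    #         do_rescale=True,
    #         rescale_factor=1 / 255,
    #         do_normalize=False,
    #         do_pad=False,
    #     )
    #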
|
| 476 |
+
def prepare_annotation(
|
| 477 |
+
self,
|
| 478 |
+
image: np.ndarray,
|
| 479 |
+
target: dict,
|
| 480 |
+
format: Optional[AnnotationFormat] = None,
|
| 481 |
+
return_segmentation_masks: Optional[bool] = None,
|
| 482 |
+
masks_path: Optional[Union[str, pathlib.Path]] = None,
|
| 483 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 484 |
+
) -> dict:
|
| 485 |
+
"""
|
| 486 |
+
Prepare an annotation for feeding into RTDETR model.
|
| 487 |
+
"""
|
| 488 |
+
format = format if format is not None else self.format
|
| 489 |
+
|
| 490 |
+
if format == AnnotationFormat.COCO_DETECTION:
|
| 491 |
+
return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
|
| 492 |
+
target = prepare_coco_detection_annotation(
|
| 493 |
+
image, target, return_segmentation_masks, input_data_format=input_data_format
|
| 494 |
+
)
|
| 495 |
+
else:
|
| 496 |
+
raise ValueError(f"Format {format} is not supported.")
|
| 497 |
+
return target
|
| 498 |
+
|
| 499 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize
|
| 500 |
+
def resize(
|
| 501 |
+
self,
|
| 502 |
+
image: np.ndarray,
|
| 503 |
+
size: dict[str, int],
|
| 504 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 505 |
+
data_format: Optional[ChannelDimension] = None,
|
| 506 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 507 |
+
**kwargs,
|
| 508 |
+
) -> np.ndarray:
|
| 509 |
+
"""
|
| 510 |
+
Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
|
| 511 |
+
int, smaller edge of the image will be matched to this number.
|
| 512 |
+
|
| 513 |
+
Args:
|
| 514 |
+
image (`np.ndarray`):
|
| 515 |
+
Image to resize.
|
| 516 |
+
size (`dict[str, int]`):
|
| 517 |
+
Size of the image's `(height, width)` dimensions after resizing. Available options are:
|
| 518 |
+
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
|
| 519 |
+
Do NOT keep the aspect ratio.
|
| 520 |
+
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
|
| 521 |
+
the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
|
| 522 |
+
less or equal to `longest_edge`.
|
| 523 |
+
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
|
| 524 |
+
aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
|
| 525 |
+
`max_width`.
|
| 526 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
| 527 |
+
Resampling filter to use if resizing the image.
|
| 528 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 529 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 530 |
+
image is used.
|
| 531 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 532 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 533 |
+
"""
|
| 534 |
+
if "max_size" in kwargs:
|
| 535 |
+
logger.warning_once(
|
| 536 |
+
"The `max_size` parameter is deprecated and will be removed in v4.26. "
|
| 537 |
+
"Please specify in `size['longest_edge'] instead`.",
|
| 538 |
+
)
|
| 539 |
+
max_size = kwargs.pop("max_size")
|
| 540 |
+
else:
|
| 541 |
+
max_size = None
|
| 542 |
+
size = get_size_dict(size, max_size=max_size, default_to_square=False)
|
| 543 |
+
if "shortest_edge" in size and "longest_edge" in size:
|
| 544 |
+
new_size = get_resize_output_image_size(
|
| 545 |
+
image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
|
| 546 |
+
)
|
| 547 |
+
elif "max_height" in size and "max_width" in size:
|
| 548 |
+
new_size = get_image_size_for_max_height_width(
|
| 549 |
+
image, size["max_height"], size["max_width"], input_data_format=input_data_format
|
| 550 |
+
)
|
| 551 |
+
elif "height" in size and "width" in size:
|
| 552 |
+
new_size = (size["height"], size["width"])
|
| 553 |
+
else:
|
| 554 |
+
raise ValueError(
|
| 555 |
+
"Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
|
| 556 |
+
f" {size.keys()}."
|
| 557 |
+
)
|
| 558 |
+
image = resize(
|
| 559 |
+
image,
|
| 560 |
+
size=new_size,
|
| 561 |
+
resample=resample,
|
| 562 |
+
data_format=data_format,
|
| 563 |
+
input_data_format=input_data_format,
|
| 564 |
+
**kwargs,
|
| 565 |
+
)
|
| 566 |
+
return image
|
| 567 |
+
|
| 568 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation
|
| 569 |
+
def resize_annotation(
|
| 570 |
+
self,
|
| 571 |
+
annotation,
|
| 572 |
+
orig_size,
|
| 573 |
+
size,
|
| 574 |
+
resample: PILImageResampling = PILImageResampling.NEAREST,
|
| 575 |
+
) -> dict:
|
| 576 |
+
"""
|
| 577 |
+
Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched
|
| 578 |
+
to this number.
|
| 579 |
+
"""
|
| 580 |
+
return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample)
|
| 581 |
+
|
| 582 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
|
| 583 |
+
def rescale(
|
| 584 |
+
self,
|
| 585 |
+
image: np.ndarray,
|
| 586 |
+
rescale_factor: float,
|
| 587 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 588 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 589 |
+
) -> np.ndarray:
|
| 590 |
+
"""
|
| 591 |
+
Rescale the image by the given factor. image = image * rescale_factor.
|
| 592 |
+
|
| 593 |
+
Args:
|
| 594 |
+
image (`np.ndarray`):
|
| 595 |
+
Image to rescale.
|
| 596 |
+
rescale_factor (`float`):
|
| 597 |
+
The value to use for rescaling.
|
| 598 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 599 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 600 |
+
image is used. Can be one of:
|
| 601 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 602 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 603 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 604 |
+
The channel dimension format for the input image. If unset, is inferred from the input image. Can be
|
| 605 |
+
one of:
|
| 606 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 607 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 608 |
+
"""
|
| 609 |
+
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
|
| 610 |
+
|
| 611 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
|
| 612 |
+
def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
|
| 613 |
+
"""
|
| 614 |
+
Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to
|
| 615 |
+
`[center_x, center_y, width, height]` format and from absolute to relative pixel values.
|
| 616 |
+
"""
|
| 617 |
+
return normalize_annotation(annotation, image_size=image_size)
|
| 618 |
+
|
| 619 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._update_annotation_for_padded_image
|
| 620 |
+
def _update_annotation_for_padded_image(
|
| 621 |
+
self,
|
| 622 |
+
annotation: dict,
|
| 623 |
+
input_image_size: tuple[int, int],
|
| 624 |
+
output_image_size: tuple[int, int],
|
| 625 |
+
padding,
|
| 626 |
+
update_bboxes,
|
| 627 |
+
) -> dict:
|
| 628 |
+
"""
|
| 629 |
+
Update the annotation for a padded image.
|
| 630 |
+
"""
|
| 631 |
+
new_annotation = {}
|
| 632 |
+
new_annotation["size"] = output_image_size
|
| 633 |
+
|
| 634 |
+
for key, value in annotation.items():
|
| 635 |
+
if key == "masks":
|
| 636 |
+
masks = value
|
| 637 |
+
masks = pad(
|
| 638 |
+
masks,
|
| 639 |
+
padding,
|
| 640 |
+
mode=PaddingMode.CONSTANT,
|
| 641 |
+
constant_values=0,
|
| 642 |
+
input_data_format=ChannelDimension.FIRST,
|
| 643 |
+
)
|
| 644 |
+
masks = safe_squeeze(masks, 1)
|
| 645 |
+
new_annotation["masks"] = masks
|
| 646 |
+
elif key == "boxes" and update_bboxes:
|
| 647 |
+
boxes = value
|
| 648 |
+
boxes *= np.asarray(
|
| 649 |
+
[
|
| 650 |
+
input_image_size[1] / output_image_size[1],
|
| 651 |
+
input_image_size[0] / output_image_size[0],
|
| 652 |
+
input_image_size[1] / output_image_size[1],
|
| 653 |
+
input_image_size[0] / output_image_size[0],
|
| 654 |
+
]
|
| 655 |
+
)
|
| 656 |
+
new_annotation["boxes"] = boxes
|
| 657 |
+
elif key == "size":
|
| 658 |
+
new_annotation["size"] = output_image_size
|
| 659 |
+
else:
|
| 660 |
+
new_annotation[key] = value
|
| 661 |
+
return new_annotation
|
| 662 |
+
|
| 663 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
|
| 664 |
+
def _pad_image(
|
| 665 |
+
self,
|
| 666 |
+
image: np.ndarray,
|
| 667 |
+
output_size: tuple[int, int],
|
| 668 |
+
annotation: Optional[dict[str, Any]] = None,
|
| 669 |
+
constant_values: Union[float, Iterable[float]] = 0,
|
| 670 |
+
data_format: Optional[ChannelDimension] = None,
|
| 671 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 672 |
+
update_bboxes: bool = True,
|
| 673 |
+
) -> np.ndarray:
|
| 674 |
+
"""
|
| 675 |
+
Pad an image with zeros to the given size.
|
| 676 |
+
"""
|
| 677 |
+
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
| 678 |
+
output_height, output_width = output_size
|
| 679 |
+
|
| 680 |
+
pad_bottom = output_height - input_height
|
| 681 |
+
pad_right = output_width - input_width
|
| 682 |
+
padding = ((0, pad_bottom), (0, pad_right))
|
| 683 |
+
padded_image = pad(
|
| 684 |
+
image,
|
| 685 |
+
padding,
|
| 686 |
+
mode=PaddingMode.CONSTANT,
|
| 687 |
+
constant_values=constant_values,
|
| 688 |
+
data_format=data_format,
|
| 689 |
+
input_data_format=input_data_format,
|
| 690 |
+
)
|
| 691 |
+
if annotation is not None:
|
| 692 |
+
annotation = self._update_annotation_for_padded_image(
|
| 693 |
+
annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
|
| 694 |
+
)
|
| 695 |
+
return padded_image, annotation
|
| 696 |
+
|
| 697 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad
|
| 698 |
+
def pad(
|
| 699 |
+
self,
|
| 700 |
+
images: list[np.ndarray],
|
| 701 |
+
annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
|
| 702 |
+
constant_values: Union[float, Iterable[float]] = 0,
|
| 703 |
+
return_pixel_mask: bool = True,
|
| 704 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 705 |
+
data_format: Optional[ChannelDimension] = None,
|
| 706 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 707 |
+
update_bboxes: bool = True,
|
| 708 |
+
pad_size: Optional[dict[str, int]] = None,
|
| 709 |
+
) -> BatchFeature:
|
| 710 |
+
"""
|
| 711 |
+
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
|
| 712 |
+
in the batch and optionally returns their corresponding pixel mask.
|
| 713 |
+
|
| 714 |
+
Args:
|
| 715 |
+
images (list[`np.ndarray`]):
|
| 716 |
+
Images to pad.
|
| 717 |
+
annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
|
| 718 |
+
Annotations to transform according to the padding that is applied to the images.
|
| 719 |
+
constant_values (`float` or `Iterable[float]`, *optional*):
|
| 720 |
+
The value to use for the padding if `mode` is `"constant"`.
|
| 721 |
+
return_pixel_mask (`bool`, *optional*, defaults to `True`):
|
| 722 |
+
Whether to return a pixel mask.
|
| 723 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 724 |
+
The type of tensors to return. Can be one of:
|
| 725 |
+
- Unset: Return a list of `np.ndarray`.
|
| 726 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 727 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 728 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 729 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 730 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 731 |
+
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
| 732 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 733 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 734 |
+
update_bboxes (`bool`, *optional*, defaults to `True`):
|
| 735 |
+
Whether to update the bounding boxes in the annotations to match the padded images. If the
|
| 736 |
+
bounding boxes have not been converted to relative coordinates and `(centre_x, centre_y, width, height)`
|
| 737 |
+
format, the bounding boxes will not be updated.
|
| 738 |
+
pad_size (`dict[str, int]`, *optional*):
|
| 739 |
+
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
|
| 740 |
+
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
|
| 741 |
+
height and width in the batch.
|
| 742 |
+
"""
|
| 743 |
+
pad_size = pad_size if pad_size is not None else self.pad_size
|
| 744 |
+
if pad_size is not None:
|
| 745 |
+
padded_size = (pad_size["height"], pad_size["width"])
|
| 746 |
+
else:
|
| 747 |
+
padded_size = get_max_height_width(images, input_data_format=input_data_format)
|
| 748 |
+
|
| 749 |
+
annotation_list = annotations if annotations is not None else [None] * len(images)
|
| 750 |
+
padded_images = []
|
| 751 |
+
padded_annotations = []
|
| 752 |
+
for image, annotation in zip(images, annotation_list):
|
| 753 |
+
padded_image, padded_annotation = self._pad_image(
|
| 754 |
+
image,
|
| 755 |
+
padded_size,
|
| 756 |
+
annotation,
|
| 757 |
+
constant_values=constant_values,
|
| 758 |
+
data_format=data_format,
|
| 759 |
+
input_data_format=input_data_format,
|
| 760 |
+
update_bboxes=update_bboxes,
|
| 761 |
+
)
|
| 762 |
+
padded_images.append(padded_image)
|
| 763 |
+
padded_annotations.append(padded_annotation)
|
| 764 |
+
|
| 765 |
+
data = {"pixel_values": padded_images}
|
| 766 |
+
|
| 767 |
+
if return_pixel_mask:
|
| 768 |
+
masks = [
|
| 769 |
+
make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
|
| 770 |
+
for image in images
|
| 771 |
+
]
|
| 772 |
+
data["pixel_mask"] = masks
|
| 773 |
+
|
| 774 |
+
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
| 775 |
+
|
| 776 |
+
if annotations is not None:
|
| 777 |
+
encoded_inputs["labels"] = [
|
| 778 |
+
BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations
|
| 779 |
+
]
|
| 780 |
+
|
| 781 |
+
return encoded_inputs
|
| 782 |
+
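    # --- Illustrative sketch (not part of the upstream file): calling `pad` on its own, assuming two
    # channels-first NumPy images of different sizes. Both are padded (bottom/right, with zeros) to the
    # batch maximum, or to `pad_size` when given, and the matching `pixel_mask` is returned:
    #
    #     processor = RTDetrImageProcessor()
    #     batch = processor.pad(
    #         [np.zeros((3, 480, 640)), np.zeros((3, 512, 600))],
    #         return_pixel_mask=True,
    #         return_tensors="np",
    #     )
    #     batch["pixel_values"].shape  # (2, 3, 512, 640)
    #     batch["pixel_mask"].shape    # (2, 512, 640)
    #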
|
| 783 |
+
@filter_out_non_signature_kwargs()
|
| 784 |
+
def preprocess(
|
| 785 |
+
self,
|
| 786 |
+
images: ImageInput,
|
| 787 |
+
annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
|
| 788 |
+
return_segmentation_masks: Optional[bool] = None,
|
| 789 |
+
masks_path: Optional[Union[str, pathlib.Path]] = None,
|
| 790 |
+
do_resize: Optional[bool] = None,
|
| 791 |
+
size: Optional[dict[str, int]] = None,
|
| 792 |
+
resample=None, # PILImageResampling
|
| 793 |
+
do_rescale: Optional[bool] = None,
|
| 794 |
+
rescale_factor: Optional[Union[int, float]] = None,
|
| 795 |
+
do_normalize: Optional[bool] = None,
|
| 796 |
+
do_convert_annotations: Optional[bool] = None,
|
| 797 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 798 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 799 |
+
do_pad: Optional[bool] = None,
|
| 800 |
+
format: Optional[Union[str, AnnotationFormat]] = None,
|
| 801 |
+
return_tensors: Optional[Union[TensorType, str]] = None,
|
| 802 |
+
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
| 803 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 804 |
+
pad_size: Optional[dict[str, int]] = None,
|
| 805 |
+
) -> BatchFeature:
|
| 806 |
+
"""
|
| 807 |
+
Preprocess an image or a batch of images so that it can be used by the model.
|
| 808 |
+
|
| 809 |
+
Args:
|
| 810 |
+
images (`ImageInput`):
|
| 811 |
+
Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
|
| 812 |
+
from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 813 |
+
annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
|
| 814 |
+
List of annotations associated with the image or batch of images. If annotation is for object
|
| 815 |
+
detection, the annotations should be a dictionary with the following keys:
|
| 816 |
+
- "image_id" (`int`): The image id.
|
| 817 |
+
- "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a
|
| 818 |
+
dictionary. An image can have no annotations, in which case the list should be empty.
|
| 819 |
+
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
|
| 820 |
+
- "image_id" (`int`): The image id.
|
| 821 |
+
- "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
| 822 |
+
An image can have no segments, in which case the list should be empty.
|
| 823 |
+
- "file_name" (`str`): The file name of the image.
|
| 824 |
+
return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
|
| 825 |
+
Whether to return segmentation masks.
|
| 826 |
+
masks_path (`str` or `pathlib.Path`, *optional*):
|
| 827 |
+
Path to the directory containing the segmentation masks.
|
| 828 |
+
do_resize (`bool`, *optional*, defaults to self.do_resize):
|
| 829 |
+
Whether to resize the image.
|
| 830 |
+
size (`dict[str, int]`, *optional*, defaults to self.size):
|
| 831 |
+
Size of the image's `(height, width)` dimensions after resizing. Available options are:
|
| 832 |
+
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
|
| 833 |
+
Do NOT keep the aspect ratio.
|
| 834 |
+
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
|
| 835 |
+
the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
|
| 836 |
+
less or equal to `longest_edge`.
|
| 837 |
+
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
|
| 838 |
+
aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
|
| 839 |
+
`max_width`.
|
| 840 |
+
resample (`PILImageResampling`, *optional*, defaults to self.resample):
|
| 841 |
+
Resampling filter to use when resizing the image.
|
| 842 |
+
do_rescale (`bool`, *optional*, defaults to self.do_rescale):
|
| 843 |
+
Whether to rescale the image.
|
| 844 |
+
rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
|
| 845 |
+
Rescale factor to use when rescaling the image.
|
| 846 |
+
do_normalize (`bool`, *optional*, defaults to self.do_normalize):
|
| 847 |
+
Whether to normalize the image.
|
| 848 |
+
do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
|
| 849 |
+
Whether to convert the annotations to the format expected by the model. Converts the bounding
|
| 850 |
+
boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
|
| 851 |
+
and in relative coordinates.
|
| 852 |
+
image_mean (`float` or `list[float]`, *optional*, defaults to self.image_mean):
|
| 853 |
+
Mean to use when normalizing the image.
|
| 854 |
+
image_std (`float` or `list[float]`, *optional*, defaults to self.image_std):
|
| 855 |
+
Standard deviation to use when normalizing the image.
|
| 856 |
+
do_pad (`bool`, *optional*, defaults to self.do_pad):
|
| 857 |
+
Whether to pad the image. If `True`, padding will be applied to the bottom and right of
|
| 858 |
+
the image with zeros. If `pad_size` is provided, the image will be padded to the specified
|
| 859 |
+
dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
|
| 860 |
+
format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
|
| 861 |
+
Format of the annotations.
|
| 862 |
+
return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
|
| 863 |
+
Type of tensors to return. If `None`, will return the list of images.
|
| 864 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 865 |
+
The channel dimension format for the output image. Can be one of:
|
| 866 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 867 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 868 |
+
- Unset: Use the channel dimension format of the input image.
|
| 869 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 870 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 871 |
+
from the input image. Can be one of:
|
| 872 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 873 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 874 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 875 |
+
pad_size (`dict[str, int]`, *optional*):
|
| 876 |
+
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
|
| 877 |
+
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
|
| 878 |
+
height and width in the batch.
|
| 879 |
+
"""
|
| 880 |
+
do_resize = self.do_resize if do_resize is None else do_resize
|
| 881 |
+
size = self.size if size is None else size
|
| 882 |
+
size = get_size_dict(size=size, default_to_square=True)
|
| 883 |
+
resample = self.resample if resample is None else resample
|
| 884 |
+
do_rescale = self.do_rescale if do_rescale is None else do_rescale
|
| 885 |
+
rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
|
| 886 |
+
do_normalize = self.do_normalize if do_normalize is None else do_normalize
|
| 887 |
+
image_mean = self.image_mean if image_mean is None else image_mean
|
| 888 |
+
image_std = self.image_std if image_std is None else image_std
|
| 889 |
+
do_convert_annotations = (
|
| 890 |
+
self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
|
| 891 |
+
)
|
| 892 |
+
do_pad = self.do_pad if do_pad is None else do_pad
|
| 893 |
+
pad_size = self.pad_size if pad_size is None else pad_size
|
| 894 |
+
format = self.format if format is None else format
|
| 895 |
+
|
| 896 |
+
images = make_list_of_images(images)
|
| 897 |
+
|
| 898 |
+
if not valid_images(images):
|
| 899 |
+
raise ValueError(
|
| 900 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 901 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 902 |
+
)
|
| 903 |
+
|
| 904 |
+
# Here, the pad() method pads to the maximum of (width, height). It does not need to be validated.
|
| 905 |
+
|
| 906 |
+
validate_preprocess_arguments(
|
| 907 |
+
do_rescale=do_rescale,
|
| 908 |
+
rescale_factor=rescale_factor,
|
| 909 |
+
do_normalize=do_normalize,
|
| 910 |
+
image_mean=image_mean,
|
| 911 |
+
image_std=image_std,
|
| 912 |
+
do_resize=do_resize,
|
| 913 |
+
size=size,
|
| 914 |
+
resample=resample,
|
| 915 |
+
)
|
| 916 |
+
|
| 917 |
+
if annotations is not None and isinstance(annotations, dict):
|
| 918 |
+
annotations = [annotations]
|
| 919 |
+
|
| 920 |
+
if annotations is not None and len(images) != len(annotations):
|
| 921 |
+
raise ValueError(
|
| 922 |
+
f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
|
| 923 |
+
)
|
| 924 |
+
|
| 925 |
+
format = AnnotationFormat(format)
|
| 926 |
+
if annotations is not None:
|
| 927 |
+
validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
|
| 928 |
+
|
| 929 |
+
images = make_list_of_images(images)
|
| 930 |
+
if not valid_images(images):
|
| 931 |
+
raise ValueError(
|
| 932 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 933 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 934 |
+
)
|
| 935 |
+
|
| 936 |
+
# All transformations expect numpy arrays
|
| 937 |
+
images = [to_numpy_array(image) for image in images]
|
| 938 |
+
|
| 939 |
+
if do_rescale and is_scaled_image(images[0]):
|
| 940 |
+
logger.warning_once(
|
| 941 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 942 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
if input_data_format is None:
|
| 946 |
+
# We assume that all images have the same channel dimension format.
|
| 947 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 948 |
+
|
| 949 |
+
# prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
|
| 950 |
+
if annotations is not None:
|
| 951 |
+
prepared_images = []
|
| 952 |
+
prepared_annotations = []
|
| 953 |
+
for image, target in zip(images, annotations):
|
| 954 |
+
target = self.prepare_annotation(
|
| 955 |
+
image,
|
| 956 |
+
target,
|
| 957 |
+
format,
|
| 958 |
+
return_segmentation_masks=return_segmentation_masks,
|
| 959 |
+
masks_path=masks_path,
|
| 960 |
+
input_data_format=input_data_format,
|
| 961 |
+
)
|
| 962 |
+
prepared_images.append(image)
|
| 963 |
+
prepared_annotations.append(target)
|
| 964 |
+
images = prepared_images
|
| 965 |
+
annotations = prepared_annotations
|
| 966 |
+
del prepared_images, prepared_annotations
|
| 967 |
+
|
| 968 |
+
# transformations
|
| 969 |
+
if do_resize:
|
| 970 |
+
if annotations is not None:
|
| 971 |
+
resized_images, resized_annotations = [], []
|
| 972 |
+
for image, target in zip(images, annotations):
|
| 973 |
+
orig_size = get_image_size(image, input_data_format)
|
| 974 |
+
resized_image = self.resize(
|
| 975 |
+
image, size=size, resample=resample, input_data_format=input_data_format
|
| 976 |
+
)
|
| 977 |
+
resized_annotation = self.resize_annotation(
|
| 978 |
+
target, orig_size, get_image_size(resized_image, input_data_format)
|
| 979 |
+
)
|
| 980 |
+
resized_images.append(resized_image)
|
| 981 |
+
resized_annotations.append(resized_annotation)
|
| 982 |
+
images = resized_images
|
| 983 |
+
annotations = resized_annotations
|
| 984 |
+
del resized_images, resized_annotations
|
| 985 |
+
else:
|
| 986 |
+
images = [
|
| 987 |
+
self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
|
| 988 |
+
for image in images
|
| 989 |
+
]
|
| 990 |
+
|
| 991 |
+
if do_rescale:
|
| 992 |
+
images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
|
| 993 |
+
|
| 994 |
+
if do_normalize:
|
| 995 |
+
images = [
|
| 996 |
+
self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
|
| 997 |
+
]
|
| 998 |
+
|
| 999 |
+
if do_convert_annotations and annotations is not None:
|
| 1000 |
+
annotations = [
|
| 1001 |
+
self.normalize_annotation(annotation, get_image_size(image, input_data_format))
|
| 1002 |
+
for annotation, image in zip(annotations, images)
|
| 1003 |
+
]
|
| 1004 |
+
|
| 1005 |
+
if do_pad:
|
| 1006 |
+
# Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
|
| 1007 |
+
encoded_inputs = self.pad(
|
| 1008 |
+
images,
|
| 1009 |
+
annotations=annotations,
|
| 1010 |
+
return_pixel_mask=True,
|
| 1011 |
+
data_format=data_format,
|
| 1012 |
+
input_data_format=input_data_format,
|
| 1013 |
+
update_bboxes=do_convert_annotations,
|
| 1014 |
+
return_tensors=return_tensors,
|
| 1015 |
+
pad_size=pad_size,
|
| 1016 |
+
)
|
| 1017 |
+
else:
|
| 1018 |
+
images = [
|
| 1019 |
+
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 1020 |
+
for image in images
|
| 1021 |
+
]
|
| 1022 |
+
encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
|
| 1023 |
+
if annotations is not None:
|
| 1024 |
+
encoded_inputs["labels"] = [
|
| 1025 |
+
BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
|
| 1026 |
+
]
|
| 1027 |
+
|
| 1028 |
+
return encoded_inputs
|
| 1029 |
+
|
| 1030 |
+
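    # --- Illustrative sketch (not part of the upstream file): end-to-end preprocessing of one image
    # with a COCO-style detection target. The pixel values and annotation contents below are made up
    # for illustration only.
    #
    #     processor = RTDetrImageProcessor()
    #     image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    #     target = {
    #         "image_id": 0,
    #         "annotations": [{"category_id": 1, "bbox": [10, 20, 30, 40], "area": 1200, "iscrowd": 0}],
    #     }
    #     inputs = processor(images=image, annotations=target, return_tensors="pt")
    #     inputs["pixel_values"].shape  # torch.Size([1, 3, 640, 640])
    #     inputs["labels"][0]["boxes"]  # normalized (center_x, center_y, width, height) boxes
    #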
    def post_process_object_detection(
        self,
        outputs,
        threshold: float = 0.5,
        target_sizes: Union[TensorType, list[tuple]] = None,
        use_focal_loss: bool = True,
    ):
        """
        Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format. Only supports PyTorch.

        Args:
            outputs ([`DetrObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.5):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
            use_focal_loss (`bool` defaults to `True`):
                Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
                to compute the scores of each detection, otherwise, a softmax function is used.

        Returns:
            `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        requires_backends(self, ["torch"])
        out_logits, out_bbox = outputs.logits, outputs.pred_boxes
        # convert from relative cxcywh to absolute xyxy
        boxes = center_to_corners_format(out_bbox)
        if target_sizes is not None:
            if len(out_logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )
            if isinstance(target_sizes, list):
                img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
            else:
                img_h, img_w = target_sizes.unbind(1)
            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]

        num_top_queries = out_logits.shape[1]
        num_classes = out_logits.shape[2]

        if use_focal_loss:
            scores = torch.nn.functional.sigmoid(out_logits)
            scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
            labels = index % num_classes
            index = index // num_classes
            boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
        else:
            scores = torch.nn.functional.softmax(out_logits)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            if scores.shape[1] > num_top_queries:
                scores, index = torch.topk(scores, num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))

        results = []
        for score, label, box in zip(scores, labels, boxes):
            results.append(
                {
                    "scores": score[score > threshold],
                    "labels": label[score > threshold],
                    "boxes": box[score > threshold],
                }
            )

        return results


__all__ = ["RTDetrImageProcessor"]
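# --- Illustrative sketch (not part of the upstream file): decoding model outputs with
# `post_process_object_detection`, assuming torch is installed. `DummyOutput` stands in for the real
# model output class and is hypothetical; only `logits` and `pred_boxes` are needed here.
def _example_post_process():
    import collections

    DummyOutput = collections.namedtuple("DummyOutput", ["logits", "pred_boxes"])
    outputs = DummyOutput(
        logits=torch.randn(1, 300, 80),    # (batch, num_queries, num_classes)
        pred_boxes=torch.rand(1, 300, 4),  # relative (center_x, center_y, width, height)
    )
    processor = RTDetrImageProcessor()
    results = processor.post_process_object_detection(
        outputs, threshold=0.5, target_sizes=[(480, 640)], use_focal_loss=True
    )
    # each entry holds absolute xyxy boxes, scores and labels above the threshold
    return results[0]["boxes"], results[0]["scores"], results[0]["labels"]
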
phivenv/Lib/site-packages/transformers/models/rt_detr/image_processing_rt_detr_fast.py
ADDED
|
@@ -0,0 +1,590 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/rt_detr/modular_rt_detr.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_rt_detr.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import pathlib
from typing import Any, Optional, Union

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    SizeDict,
    get_image_size_for_max_height_width,
    get_max_height_width,
    safe_squeeze,
)
from ...image_transforms import center_to_corners_format, corners_to_center_format
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    AnnotationFormat,
    AnnotationType,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    validate_annotations,
)
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    auto_docstring,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
    requires_backends,
)
from ...utils.import_utils import requires
from .image_processing_rt_detr import get_size_with_aspect_ratio


if is_torch_available():
    import torch


if is_torchvision_v2_available():
    from torchvision.transforms.v2 import functional as F
elif is_torchvision_available():
    from torchvision.transforms import functional as F


class RTDetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
| 55 |
+
r"""
|
| 56 |
+
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
|
| 57 |
+
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
|
| 58 |
+
do_convert_annotations (`bool`, *optional*, defaults to `True`):
|
| 59 |
+
Controls whether to convert the annotations to the format expected by the RT_DETR model. Converts the
|
| 60 |
+
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
|
| 61 |
+
Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
|
| 62 |
+
do_pad (`bool`, *optional*, defaults to `True`):
|
| 63 |
+
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
|
| 64 |
+
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
|
| 65 |
+
If `pad_size` is provided, the image will be padded to the specified dimensions.
|
| 66 |
+
Otherwise, the image will be padded to the maximum height and width of the batch.
|
| 67 |
+
pad_size (`dict[str, int]`, *optional*):
|
| 68 |
+
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
|
| 69 |
+
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
|
| 70 |
+
height and width in the batch.
|
| 71 |
+
return_segmentation_masks (`bool`, *optional*, defaults to `False`):
|
| 72 |
+
Whether to return segmentation masks.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
format: Optional[Union[str, AnnotationFormat]]
|
| 76 |
+
do_convert_annotations: Optional[bool]
|
| 77 |
+
do_pad: Optional[bool]
|
| 78 |
+
pad_size: Optional[dict[str, int]]
|
| 79 |
+
return_segmentation_masks: Optional[bool]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def prepare_coco_detection_annotation(
|
| 86 |
+
image,
|
| 87 |
+
target,
|
| 88 |
+
return_segmentation_masks: bool = False,
|
| 89 |
+
input_data_format: Optional[Union[ChannelDimension, str]] = None,
|
| 90 |
+
):
|
| 91 |
+
"""
|
| 92 |
+
Convert the target in COCO format into the format expected by RT-DETR.
|
| 93 |
+
"""
|
| 94 |
+
image_height, image_width = image.size()[-2:]
|
| 95 |
+
|
| 96 |
+
image_id = target["image_id"]
|
| 97 |
+
image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)
|
| 98 |
+
|
| 99 |
+
# Get all COCO annotations for the given image.
|
| 100 |
+
annotations = target["annotations"]
|
| 101 |
+
classes = []
|
| 102 |
+
area = []
|
| 103 |
+
boxes = []
|
| 104 |
+
keypoints = []
|
| 105 |
+
for obj in annotations:
|
| 106 |
+
if "iscrowd" not in obj or obj["iscrowd"] == 0:
|
| 107 |
+
classes.append(obj["category_id"])
|
| 108 |
+
area.append(obj["area"])
|
| 109 |
+
boxes.append(obj["bbox"])
|
| 110 |
+
if "keypoints" in obj:
|
| 111 |
+
keypoints.append(obj["keypoints"])
|
| 112 |
+
|
| 113 |
+
classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
|
| 114 |
+
area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
|
| 115 |
+
iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
|
| 116 |
+
# guard against no boxes via resizing
|
| 117 |
+
boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
|
| 118 |
+
boxes[:, 2:] += boxes[:, :2]
|
| 119 |
+
boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
|
| 120 |
+
boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
|
| 121 |
+
|
| 122 |
+
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
|
| 123 |
+
|
| 124 |
+
new_target = {
|
| 125 |
+
"image_id": image_id,
|
| 126 |
+
"class_labels": classes[keep],
|
| 127 |
+
"boxes": boxes[keep],
|
| 128 |
+
"area": area[keep],
|
| 129 |
+
"iscrowd": iscrowd[keep],
|
| 130 |
+
"orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
if keypoints:
|
| 134 |
+
keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
|
| 135 |
+
        # Apply the keep mask here to filter the relevant annotations
        keypoints = keypoints[keep]
        num_keypoints = keypoints.shape[0]
        keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
        new_target["keypoints"] = keypoints

    return new_target


@auto_docstring
@requires(backends=("torchvision", "torch"))
class RTDetrImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_DEFAULT_MEAN
    image_std = IMAGENET_DEFAULT_STD
    format = AnnotationFormat.COCO_DETECTION
    do_resize = True
    do_rescale = True
    do_normalize = False
    do_pad = False
    size = {"height": 640, "width": 640}
    default_to_square = False
    model_input_names = ["pixel_values", "pixel_mask"]
    valid_kwargs = RTDetrFastImageProcessorKwargs
    do_convert_annotations = True

    def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorKwargs]) -> None:
        # Backwards compatibility
        do_convert_annotations = kwargs.get("do_convert_annotations")
        do_normalize = kwargs.get("do_normalize")
        if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
            self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize

        super().__init__(**kwargs)

    def prepare_annotation(
        self,
        image: torch.Tensor,
        target: dict,
        format: Optional[AnnotationFormat] = None,
        return_segmentation_masks: Optional[bool] = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> dict:
        """
        Prepare an annotation for feeding into RT_DETR model.
        """
        format = format if format is not None else self.format

        if format == AnnotationFormat.COCO_DETECTION:
            return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
            target = prepare_coco_detection_annotation(
                image, target, return_segmentation_masks, input_data_format=input_data_format
            )
        else:
            raise ValueError(f"Format {format} is not supported.")
        return target

    def resize(
        self,
        image: torch.Tensor,
        size: SizeDict,
        interpolation: "F.InterpolationMode" = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
        int, smaller edge of the image will be matched to this number.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`SizeDict`):
                Size of the image's `(height, width)` dimensions after resizing. Available options are:
                    - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
                      Do NOT keep the aspect ratio.
                    - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
                      the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
                      less or equal to `longest_edge`.
                    - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
                      aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
                      `max_width`.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                Resampling filter to use if resizing the image.
        """
        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
        if size.shortest_edge and size.longest_edge:
            # Resize the image so that the shortest edge or the longest edge is of the given size
            # while maintaining the aspect ratio of the original image.
            new_size = get_size_with_aspect_ratio(
                image.size()[-2:],
                size["shortest_edge"],
                size["longest_edge"],
            )
        elif size.max_height and size.max_width:
            new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"])
        elif size.height and size.width:
            new_size = (size["height"], size["width"])
        else:
            raise ValueError(
                "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
                f" {size.keys()}."
            )

        image = F.resize(
            image,
            size=new_size,
            interpolation=interpolation,
            **kwargs,
        )
        return image
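
    # A minimal sketch of how the three `size` forms above behave, assuming a channels-first
    # input of shape (3, 480, 640) (the numbers are illustrative only):
    #   {"height": 640, "width": 640}               -> resized to (3, 640, 640), aspect ratio NOT kept
    #   {"shortest_edge": 480, "longest_edge": 800} -> aspect ratio preserved; here the image already
    #                                                  satisfies both bounds, so it stays (3, 480, 640)
    #   {"max_height": 320, "max_width": 320}       -> scaled to fit inside 320x320, giving (3, 240, 320)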

    def resize_annotation(
        self,
        annotation: dict[str, Any],
        orig_size: tuple[int, int],
        target_size: tuple[int, int],
        threshold: float = 0.5,
        interpolation: "F.InterpolationMode" = None,
    ):
        """
        Resizes an annotation to a target size.

        Args:
            annotation (`dict[str, Any]`):
                The annotation dictionary.
            orig_size (`tuple[int, int]`):
                The original size of the input image.
            target_size (`tuple[int, int]`):
                The target size of the image, as returned by the preprocessing `resize` step.
            threshold (`float`, *optional*, defaults to 0.5):
                The threshold used to binarize the segmentation masks.
            resample (`InterpolationMode`, defaults to `F.InterpolationMode.NEAREST_EXACT`):
                The resampling filter to use when resizing the masks.
        """
        interpolation = (
            interpolation
            if interpolation is not None
            else F.InterpolationMode.NEAREST_EXACT
            if is_torchvision_v2_available()
            else F.InterpolationMode.NEAREST
        )
        ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)]

        new_annotation = {}
        new_annotation["size"] = target_size

        for key, value in annotation.items():
            if key == "boxes":
                boxes = value
                scaled_boxes = boxes * torch.as_tensor(
                    [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device
                )
                new_annotation["boxes"] = scaled_boxes
            elif key == "area":
                area = value
                scaled_area = area * (ratio_width * ratio_height)
                new_annotation["area"] = scaled_area
            elif key == "masks":
                masks = value[:, None]
                masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks]
                masks = torch.stack(masks).to(torch.float32)
                masks = masks[:, 0] > threshold
                new_annotation["masks"] = masks
            elif key == "size":
                new_annotation["size"] = target_size
            else:
                new_annotation[key] = value

        return new_annotation

    def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
        image_height, image_width = image_size
        norm_annotation = {}
        for key, value in annotation.items():
            if key == "boxes":
                boxes = value
                boxes = corners_to_center_format(boxes)
                boxes /= torch.as_tensor(
                    [image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device
                )
                norm_annotation[key] = boxes
            else:
                norm_annotation[key] = value
        return norm_annotation
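
    # Worked example for `normalize_annotation`, assuming an image size of (480, 640) (height, width)
    # and one corner-format box [160.0, 120.0, 480.0, 360.0] (x0, y0, x1, y1):
    #   corners_to_center_format -> [320.0, 240.0, 320.0, 240.0]  (cx, cy, w, h)
    #   divide by [640, 480, 640, 480] -> [0.5, 0.5, 0.5, 0.5], i.e. the relative cxcywh format RT-DETR expects.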

    def _update_annotation_for_padded_image(
        self,
        annotation: dict,
        input_image_size: tuple[int, int],
        output_image_size: tuple[int, int],
        padding,
        update_bboxes,
    ) -> dict:
        """
        Update the annotation for a padded image.
        """
        new_annotation = {}
        new_annotation["size"] = output_image_size
        ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size))

        for key, value in annotation.items():
            if key == "masks":
                masks = value
                masks = F.pad(
                    masks,
                    padding,
                    fill=0,
                )
                masks = safe_squeeze(masks, 1)
                new_annotation["masks"] = masks
            elif key == "boxes" and update_bboxes:
                boxes = value
                boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device)
                new_annotation["boxes"] = boxes
            elif key == "size":
                new_annotation["size"] = output_image_size
            else:
                new_annotation[key] = value
        return new_annotation

    def pad(
        self,
        image: torch.Tensor,
        padded_size: tuple[int, int],
        annotation: Optional[dict[str, Any]] = None,
        update_bboxes: bool = True,
        fill: int = 0,
    ):
        original_size = image.size()[-2:]
        padding_bottom = padded_size[0] - original_size[0]
        padding_right = padded_size[1] - original_size[1]
        if padding_bottom < 0 or padding_right < 0:
            raise ValueError(
                f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
                f"original size. Got padded size: {padded_size}, original size: {original_size}."
            )
        if original_size != padded_size:
            padding = [0, 0, padding_right, padding_bottom]
            image = F.pad(image, padding, fill=fill)
            if annotation is not None:
                annotation = self._update_annotation_for_padded_image(
                    annotation, original_size, padded_size, padding, update_bboxes
                )

        # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
        pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device)
        pixel_mask[: original_size[0], : original_size[1]] = 1

        return image, pixel_mask, annotation
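
    # Sketch of the `pad` contract, assuming a (3, 480, 640) image padded to `padded_size=(640, 640)`:
    # the image is padded on the bottom/right with `fill`, `pixel_mask` is a (640, 640) int64 tensor with
    # ones over the original (480, 640) top-left region and zeros over the padding, and (when
    # `update_bboxes=True`) the relative box coordinates are rescaled by
    # (width_ratio, height_ratio) = (640 / 640, 480 / 640) so they stay normalized w.r.t. the padded canvas.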

    @auto_docstring
    def preprocess(
        self,
        images: ImageInput,
        annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        **kwargs: Unpack[RTDetrFastImageProcessorKwargs],
    ) -> BatchFeature:
        r"""
        annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
            List of annotations associated with the image or batch of images. If annotation is for object
            detection, the annotations should be a dictionary with the following keys:
            - "image_id" (`int`): The image id.
            - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a
              dictionary. An image can have no annotations, in which case the list should be empty.
            If annotation is for segmentation, the annotations should be a dictionary with the following keys:
            - "image_id" (`int`): The image id.
            - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary.
              An image can have no segments, in which case the list should be empty.
            - "file_name" (`str`): The file name of the image.
        masks_path (`str` or `pathlib.Path`, *optional*):
            Path to the directory containing the segmentation masks.
        """
        return super().preprocess(images, annotations, masks_path, **kwargs)

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        annotations: Optional[Union[AnnotationType, list[AnnotationType]]],
        masks_path: Optional[Union[str, pathlib.Path]],
        return_segmentation_masks: bool,
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        do_convert_annotations: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        do_pad: bool,
        pad_size: Optional[dict[str, int]],
        format: Optional[Union[str, AnnotationFormat]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image or a batch of images so that it can be used by the model.
        """

        if annotations is not None and isinstance(annotations, dict):
            annotations = [annotations]

        if annotations is not None and len(images) != len(annotations):
            raise ValueError(
                f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
            )

        format = AnnotationFormat(format)
        if annotations is not None:
            validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)

        data = {}
        processed_images = []
        processed_annotations = []
        pixel_masks = []  # Initialize pixel_masks here
        for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
            # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
            if annotations is not None:
                annotation = self.prepare_annotation(
                    image,
                    annotation,
                    format,
                    return_segmentation_masks=return_segmentation_masks,
                    masks_path=masks_path,
                    input_data_format=ChannelDimension.FIRST,
                )

            if do_resize:
                resized_image = self.resize(image, size=size, interpolation=interpolation)
                if annotations is not None:
                    annotation = self.resize_annotation(
                        annotation,
                        orig_size=image.size()[-2:],
                        target_size=resized_image.size()[-2:],
                    )
                image = resized_image
            # Fused rescale and normalize
            image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
            if do_convert_annotations and annotations is not None:
                annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))

            processed_images.append(image)
            processed_annotations.append(annotation)
        images = processed_images
        annotations = processed_annotations if annotations is not None else None

        if do_pad:
            # depends on all resized image shapes so we need another loop
            if pad_size is not None:
                padded_size = (pad_size["height"], pad_size["width"])
            else:
                padded_size = get_max_height_width(images)

            padded_images = []
            padded_annotations = []
            for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
                # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
                if padded_size == image.size()[-2:]:
                    padded_images.append(image)
                    pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
                    padded_annotations.append(annotation)
                    continue
                image, pixel_mask, annotation = self.pad(
                    image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
                )
                padded_images.append(image)
                padded_annotations.append(annotation)
                pixel_masks.append(pixel_mask)
            images = padded_images
            annotations = padded_annotations if annotations is not None else None
            data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})

        data.update({"pixel_values": torch.stack(images, dim=0)})
        encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
        if annotations is not None:
            encoded_inputs["labels"] = [
                BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
            ]
        return encoded_inputs

    def post_process_object_detection(
        self,
        outputs,
        threshold: float = 0.5,
        target_sizes: Union[TensorType, list[tuple]] = None,
        use_focal_loss: bool = True,
    ):
        """
        Converts the raw output of [`RTDetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format. Only supports PyTorch.

        Args:
            outputs ([`RTDetrObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.5):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
            use_focal_loss (`bool` defaults to `True`):
                Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
                to compute the scores of each detection, otherwise, a softmax function is used.

        Returns:
            `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        requires_backends(self, ["torch"])
        out_logits, out_bbox = outputs.logits, outputs.pred_boxes
        # convert from relative cxcywh to absolute xyxy
        boxes = center_to_corners_format(out_bbox)
        if target_sizes is not None:
            if len(out_logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )
            if isinstance(target_sizes, list):
                img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
            else:
                img_h, img_w = target_sizes.unbind(1)
            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]

        num_top_queries = out_logits.shape[1]
        num_classes = out_logits.shape[2]

        if use_focal_loss:
            scores = torch.nn.functional.sigmoid(out_logits)
            scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
            labels = index % num_classes
            index = index // num_classes
            boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
        else:
            scores = torch.nn.functional.softmax(out_logits)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            if scores.shape[1] > num_top_queries:
                scores, index = torch.topk(scores, num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))

        results = []
        for score, label, box in zip(scores, labels, boxes):
            results.append(
                {
                    "scores": score[score > threshold],
                    "labels": label[score > threshold],
                    "boxes": box[score > threshold],
                }
            )

        return results


__all__ = ["RTDetrImageProcessorFast"]
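The fast image processor defined above is normally paired with `RTDetrForObjectDetection`. The snippet below is a minimal usage sketch, not part of the vendored file itself; it assumes the public `PekingU/rtdetr_r50vd` checkpoint and network access to fetch a sample image.

# Usage sketch only (assumed checkpoint: PekingU/rtdetr_r50vd), not part of the file above.
import requests
import torch
from PIL import Image

from transformers import AutoModelForObjectDetection, RTDetrImageProcessorFast

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = RTDetrImageProcessorFast.from_pretrained("PekingU/rtdetr_r50vd")
model = AutoModelForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")

# Resizes to 640x640, rescales to [0, 1] and returns "pixel_values" (normalization is off by default).
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Map the normalized cxcywh predictions back to absolute xyxy boxes for the original image size.
results = image_processor.post_process_object_detection(
    outputs, threshold=0.5, target_sizes=[(image.height, image.width)]
)
for score, label, box in zip(results[0]["scores"], results[0]["labels"], results[0]["boxes"]):
    print(f"{model.config.id2label[label.item()]}: {score:.2f} at {box.tolist()}")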
phivenv/Lib/site-packages/transformers/models/rt_detr/modeling_rt_detr.py
ADDED
@@ -0,0 +1,2013 @@
# coding=utf-8
# Copyright 2024 Baidu Inc and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch RT-DETR model."""

import math
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ...activations import ACT2CLS, ACT2FN
from ...image_transforms import center_to_corners_format, corners_to_center_format
from ...integrations import use_kernel_forward_from_hub
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import compile_compatible_method_lru_cache
from ...utils import (
    ModelOutput,
    auto_docstring,
    logging,
    torch_int,
)
from ...utils.backbone_utils import load_backbone
from .configuration_rt_detr import RTDetrConfig


logger = logging.get_logger(__name__)


# TODO: Replace all occurrences of the checkpoint with the final one


@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention
class MultiScaleDeformableAttention(nn.Module):
    def forward(
        self,
        value: Tensor,
        value_spatial_shapes: Tensor,
        value_spatial_shapes_list: list[tuple],
        level_start_index: Tensor,
        sampling_locations: Tensor,
        attention_weights: Tensor,
        im2col_step: int,
    ):
        batch_size, _, num_heads, hidden_dim = value.shape
        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
        value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
        sampling_grids = 2 * sampling_locations - 1
        sampling_value_list = []
        for level_id, (height, width) in enumerate(value_spatial_shapes_list):
            # batch_size, height*width, num_heads, hidden_dim
            # -> batch_size, height*width, num_heads*hidden_dim
            # -> batch_size, num_heads*hidden_dim, height*width
            # -> batch_size*num_heads, hidden_dim, height, width
            value_l_ = (
                value_list[level_id]
                .flatten(2)
                .transpose(1, 2)
                .reshape(batch_size * num_heads, hidden_dim, height, width)
            )
            # batch_size, num_queries, num_heads, num_points, 2
            # -> batch_size, num_heads, num_queries, num_points, 2
            # -> batch_size*num_heads, num_queries, num_points, 2
            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
            # batch_size*num_heads, hidden_dim, num_queries, num_points
            sampling_value_l_ = nn.functional.grid_sample(
                value_l_,
                sampling_grid_l_,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )
            sampling_value_list.append(sampling_value_l_)
        # (batch_size, num_queries, num_heads, num_levels, num_points)
        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
        # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
        attention_weights = attention_weights.transpose(1, 2).reshape(
            batch_size * num_heads, 1, num_queries, num_levels * num_points
        )
        output = (
            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
            .sum(-1)
            .view(batch_size, num_heads * hidden_dim, num_queries)
        )
        return output.transpose(1, 2).contiguous()
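

# Shape summary for the forward above (derivable from the reshapes): `value` is
# (batch_size, sum(H_l * W_l), num_heads, head_dim), `sampling_locations` is
# (batch_size, num_queries, num_heads, num_levels, num_points, 2) with coordinates in [0, 1],
# `attention_weights` is (batch_size, num_queries, num_heads, num_levels, num_points),
# and the returned tensor is (batch_size, num_queries, num_heads * head_dim).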


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the RTDetrDecoder. This class adds two attributes to
    BaseModelOutputWithCrossAttentions, namely:
    - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
    - a stacked tensor of intermediate reference points.
    """
)
class RTDetrDecoderOutput(ModelOutput):
    r"""
    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
        Stacked intermediate hidden states (output of each layer of the decoder).
    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
        Stacked intermediate logits (logits of each layer of the decoder).
    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
        Stacked intermediate reference points (reference points of each layer of the decoder).
    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked initial reference points (initial reference points of each layer of the decoder).
    cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
        used to compute the weighted average in the cross-attention heads.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    intermediate_hidden_states: Optional[torch.FloatTensor] = None
    intermediate_logits: Optional[torch.FloatTensor] = None
    intermediate_reference_points: Optional[torch.FloatTensor] = None
    intermediate_predicted_corners: Optional[torch.FloatTensor] = None
    initial_reference_points: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[tuple[torch.FloatTensor]] = None


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the RT-DETR encoder-decoder model.
    """
)
class RTDetrModelOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
        Stacked intermediate hidden states (output of each layer of the decoder).
    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
        Stacked intermediate logits (logits of each layer of the decoder).
    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate reference points (reference points of each layer of the decoder).
    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Initial reference points used for the first decoder layer.
    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Initial reference points sent through the Transformer decoder.
    enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
        Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
        picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e.
        foreground and background).
    enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`):
        Logits of predicted bounding boxes coordinates in the encoder stage.
    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
        picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
        foreground and background).
    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Logits of predicted bounding boxes coordinates in the first stage.
    denoising_meta_values (`dict`):
        Extra dictionary for the denoising related values.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    intermediate_hidden_states: Optional[torch.FloatTensor] = None
    intermediate_logits: Optional[torch.FloatTensor] = None
    intermediate_reference_points: Optional[torch.FloatTensor] = None
    intermediate_predicted_corners: Optional[torch.FloatTensor] = None
    initial_reference_points: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    init_reference_points: Optional[torch.FloatTensor] = None
    enc_topk_logits: Optional[torch.FloatTensor] = None
    enc_topk_bboxes: Optional[torch.FloatTensor] = None
    enc_outputs_class: Optional[torch.FloatTensor] = None
    enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
    denoising_meta_values: Optional[dict] = None


@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`RTDetrForObjectDetection`].
    """
)
class RTDetrObjectDetectionOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~RTDetrImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized (absolute) bounding boxes.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
        Stacked intermediate hidden states (output of each layer of the decoder).
    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`):
        Stacked intermediate logits (logits of each layer of the decoder).
    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate reference points (reference points of each layer of the decoder).
    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked initial reference points (initial reference points of each layer of the decoder).
    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Initial reference points sent through the Transformer decoder.
    enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Logits of predicted bounding boxes coordinates in the encoder.
    enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Logits of predicted bounding boxes coordinates in the encoder.
    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
        picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
        foreground and background).
    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Logits of predicted bounding boxes coordinates in the first stage.
    denoising_meta_values (`dict`):
        Extra dictionary for the denoising related values
    """

    loss: Optional[torch.FloatTensor] = None
    loss_dict: Optional[dict] = None
    logits: Optional[torch.FloatTensor] = None
    pred_boxes: Optional[torch.FloatTensor] = None
    auxiliary_outputs: Optional[list[dict]] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    intermediate_hidden_states: Optional[torch.FloatTensor] = None
    intermediate_logits: Optional[torch.FloatTensor] = None
    intermediate_reference_points: Optional[torch.FloatTensor] = None
    intermediate_predicted_corners: Optional[torch.FloatTensor] = None
    initial_reference_points: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    init_reference_points: Optional[tuple[torch.FloatTensor]] = None
    enc_topk_logits: Optional[torch.FloatTensor] = None
    enc_topk_bboxes: Optional[torch.FloatTensor] = None
    enc_outputs_class: Optional[torch.FloatTensor] = None
    enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
    denoising_meta_values: Optional[dict] = None


def _get_clones(partial_module, N):
    return nn.ModuleList([partial_module() for i in range(N)])


# Copied from transformers.models.conditional_detr.modeling_conditional_detr.inverse_sigmoid
def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
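

# Quick numeric check of `inverse_sigmoid` (the logit function log(x / (1 - x))):
# for x = 0.5 it returns 0.0, and torch.sigmoid(inverse_sigmoid(x)) recovers x up to the
# `eps` clamping, e.g. torch.sigmoid(inverse_sigmoid(torch.tensor(0.25))) ≈ 0.25.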


# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->RTDetr
class RTDetrFrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
    torchvision.models.resnet[18,34,50,101] produce nans.
    """

    def __init__(self, n):
        super().__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        # move reshapes to the beginning
        # to make it user-friendly
        weight = self.weight.reshape(1, -1, 1, 1)
        bias = self.bias.reshape(1, -1, 1, 1)
        running_var = self.running_var.reshape(1, -1, 1, 1)
        running_mean = self.running_mean.reshape(1, -1, 1, 1)
        epsilon = 1e-5
        scale = weight * (running_var + epsilon).rsqrt()
        bias = bias - running_mean * scale
        return x * scale + bias


# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->RTDetr
def replace_batch_norm(model):
    r"""
    Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrFrozenBatchNorm2d`.

    Args:
        model (torch.nn.Module):
            input model
    """
    for name, module in model.named_children():
        if isinstance(module, nn.BatchNorm2d):
            new_module = RTDetrFrozenBatchNorm2d(module.num_features)

            if module.weight.device != torch.device("meta"):
                new_module.weight.data.copy_(module.weight)
                new_module.bias.data.copy_(module.bias)
                new_module.running_mean.data.copy_(module.running_mean)
                new_module.running_var.data.copy_(module.running_var)

            model._modules[name] = new_module

        if len(list(module.children())) > 0:
            replace_batch_norm(module)


def get_contrastive_denoising_training_group(
    targets,
    num_classes,
    num_queries,
    class_embed,
    num_denoising_queries=100,
    label_noise_ratio=0.5,
    box_noise_scale=1.0,
):
    """
    Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.

    Args:
        targets (`list[dict]`):
            The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
        num_classes (`int`):
            Total number of classes in the dataset.
        num_queries (`int`):
            Number of query slots in the transformer.
        class_embed (`callable`):
            A function or a model layer to embed class labels.
        num_denoising_queries (`int`, *optional*, defaults to 100):
            Number of denoising queries.
        label_noise_ratio (`float`, *optional*, defaults to 0.5):
            Ratio of noise applied to labels.
        box_noise_scale (`float`, *optional*, defaults to 1.0):
            Scale of noise applied to bounding boxes.
    Returns:
        `tuple` comprising various elements:
        - **input_query_class** (`torch.FloatTensor`) --
          Class queries with applied label noise.
        - **input_query_bbox** (`torch.FloatTensor`) --
          Bounding box queries with applied box noise.
        - **attn_mask** (`torch.FloatTensor`) --
          Attention mask for separating denoising and reconstruction queries.
        - **denoising_meta_values** (`dict`) --
          Metadata including denoising positive indices, number of groups, and split sizes.
    """

    if num_denoising_queries <= 0:
        return None, None, None, None

    num_ground_truths = [len(t["class_labels"]) for t in targets]
    device = targets[0]["class_labels"].device

    max_gt_num = max(num_ground_truths)
    if max_gt_num == 0:
        return None, None, None, None

    num_groups_denoising_queries = num_denoising_queries // max_gt_num
    num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
    # pad gt to max_num of a batch
    batch_size = len(num_ground_truths)

    input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
    input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
    pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)

    for i in range(batch_size):
        num_gt = num_ground_truths[i]
        if num_gt > 0:
            input_query_class[i, :num_gt] = targets[i]["class_labels"]
            input_query_bbox[i, :num_gt] = targets[i]["boxes"]
            pad_gt_mask[i, :num_gt] = 1
    # each group has positive and negative queries.
    input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
    input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
    # positive and negative mask
    negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
    negative_gt_mask[:, max_gt_num:] = 1
    negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
    positive_gt_mask = 1 - negative_gt_mask
    # contrastive denoising training positive index
    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
    denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
    denoise_positive_idx = torch.split(
        denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
    )
    # total denoising queries
    num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)

    if label_noise_ratio > 0:
        mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
        # randomly put a new one here
        new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
        input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)

    if box_noise_scale > 0:
        known_bbox = center_to_corners_format(input_query_bbox)
        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
        rand_part = torch.rand_like(input_query_bbox)
        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        input_query_bbox = corners_to_center_format(known_bbox)
        input_query_bbox = inverse_sigmoid(input_query_bbox)

    input_query_class = class_embed(input_query_class)

    target_size = num_denoising_queries + num_queries
    attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
    # match query cannot see the reconstruction
    attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf

    # reconstructions cannot see each other
    for i in range(num_groups_denoising_queries):
        idx_block_start = max_gt_num * 2 * i
        idx_block_end = max_gt_num * 2 * (i + 1)
        attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
        attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf

    denoising_meta_values = {
        "dn_positive_idx": denoise_positive_idx,
        "dn_num_group": num_groups_denoising_queries,
        "dn_num_split": [num_denoising_queries, num_queries],
    }

    return input_query_class, input_query_bbox, attn_mask, denoising_meta_values
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
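
# Worked illustration of the attention mask built above (a minimal sketch, assuming
# max_gt_num = 2 and num_groups_denoising_queries = 2, i.e. 8 denoising queries placed in
# front of the matching queries): each group of 2 * max_gt_num denoising queries may only
# attend to itself, and the matching queries may not attend to any denoising query.
#
#   attn_mask[8:, :8]   = -inf   # matching queries cannot see the reconstructions
#   attn_mask[0:4, 4:8] = -inf   # denoising group 0 cannot see group 1
#   attn_mask[4:8, 0:4] = -inf   # denoising group 1 cannot see group 0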


class RTDetrConvEncoder(nn.Module):
    """
    Convolutional backbone using the modeling_rt_detr_resnet.py.

    nn.BatchNorm2d layers are replaced by RTDetrFrozenBatchNorm2d as defined above.
    https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/nn/backbone/presnet.py#L142
    """

    def __init__(self, config):
        super().__init__()

        backbone = load_backbone(config)

        if config.freeze_backbone_batch_norms:
            # replace batch norm by frozen batch norm
            with torch.no_grad():
                replace_batch_norm(backbone)
        self.model = backbone
        self.intermediate_channel_sizes = self.model.channels

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        # send pixel_values through the model to get list of feature maps
        features = self.model(pixel_values).feature_maps

        out = []
        for feature_map in features:
            # downsample pixel_mask to match shape of corresponding feature_map
            mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
            out.append((feature_map, mask))
        return out


class RTDetrConvNormLayer(nn.Module):
    def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2 if padding is None else padding,
            bias=False,
        )
        self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
        self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()

    def forward(self, hidden_state):
        hidden_state = self.conv(hidden_state)
        hidden_state = self.norm(hidden_state)
        hidden_state = self.activation(hidden_state)
        return hidden_state
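
# Note on the default padding above: for stride-1 convolutions with an odd kernel size,
# padding = (kernel_size - 1) // 2 keeps the spatial resolution unchanged, e.g. a 3x3
# kernel gets padding 1 and a 1x1 kernel gets padding 0.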


class RTDetrEncoderLayer(nn.Module):
    def __init__(self, config: RTDetrConfig):
        super().__init__()
        self.normalize_before = config.normalize_before

        # self-attention
        self.self_attn = RTDetrMultiheadAttention(
            embed_dim=config.encoder_hidden_dim,
            num_heads=config.num_attention_heads,
            dropout=config.dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.encoder_activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim)
        self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        **kwargs,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            position_embeddings (`torch.FloatTensor`, *optional*):
                Object queries (also called content embeddings), to be added to the hidden states.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        if self.normalize_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        if not self.normalize_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        if self.normalize_before:
            hidden_states = self.final_layer_norm(hidden_states)
        residual = hidden_states

        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)

        hidden_states = self.fc2(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        hidden_states = residual + hidden_states
        if not self.normalize_before:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.training:
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class RTDetrRepVggBlock(nn.Module):
    """
    RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
    """

    def __init__(self, config: RTDetrConfig):
        super().__init__()

        activation = config.activation_function
        hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion)
        self.conv1 = RTDetrConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1)
        self.conv2 = RTDetrConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0)
        self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()

    def forward(self, x):
        y = self.conv1(x) + self.conv2(x)
        return self.activation(y)


class RTDetrCSPRepLayer(nn.Module):
    """
    Cross Stage Partial (CSP) network layer with RepVGG blocks.
    """

    def __init__(self, config: RTDetrConfig):
        super().__init__()

        in_channels = config.encoder_hidden_dim * 2
        out_channels = config.encoder_hidden_dim
        num_blocks = 3
        activation = config.activation_function

        hidden_channels = int(out_channels * config.hidden_expansion)
        self.conv1 = RTDetrConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
        self.conv2 = RTDetrConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
        self.bottlenecks = nn.Sequential(*[RTDetrRepVggBlock(config) for _ in range(num_blocks)])
        if hidden_channels != out_channels:
            self.conv3 = RTDetrConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
        else:
            self.conv3 = nn.Identity()

    def forward(self, hidden_state):
        hidden_state_1 = self.conv1(hidden_state)
        hidden_state_1 = self.bottlenecks(hidden_state_1)
        hidden_state_2 = self.conv2(hidden_state)
        return self.conv3(hidden_state_1 + hidden_state_2)


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->RTDetr
class RTDetrMultiscaleDeformableAttention(nn.Module):
    """
    Multiscale deformable attention as proposed in Deformable DETR.
    """

    def __init__(self, config: RTDetrConfig, num_heads: int, n_points: int):
        super().__init__()

        self.attn = MultiScaleDeformableAttention()

        if config.d_model % num_heads != 0:
            raise ValueError(
                f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
            )
        dim_per_head = config.d_model // num_heads
        # check if dim_per_head is power of 2
        if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
            warnings.warn(
                "You'd better set embed_dim (d_model) in RTDetrMultiscaleDeformableAttention to make the"
                " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
                " implementation."
            )

        self.im2col_step = 64

        self.d_model = config.d_model
        self.n_levels = config.num_feature_levels
        self.n_heads = num_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
        self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
        self.value_proj = nn.Linear(config.d_model, config.d_model)
        self.output_proj = nn.Linear(config.d_model, config.d_model)

        self.disable_custom_kernels = config.disable_custom_kernels

    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
        return tensor if position_embeddings is None else tensor + position_embeddings

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings: Optional[torch.Tensor] = None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        output_attentions: bool = False,
    ):
        # add position embeddings to the hidden states before projecting to queries and keys
        if position_embeddings is not None:
            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

        batch_size, num_queries, _ = hidden_states.shape
        batch_size, sequence_length, _ = encoder_hidden_states.shape
        total_elements = sum(height * width for height, width in spatial_shapes_list)
        if total_elements != sequence_length:
            raise ValueError(
                "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
            )

        value = self.value_proj(encoder_hidden_states)
        if attention_mask is not None:
            # we invert the attention_mask
            value = value.masked_fill(~attention_mask[..., None], float(0))
        value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
        )
        attention_weights = self.attention_weights(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
        )
        attention_weights = F.softmax(attention_weights, -1).view(
            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
        )
        # batch_size, num_queries, n_heads, n_levels, n_points, 2
        num_coordinates = reference_points.shape[-1]
        if num_coordinates == 2:
            offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
            )
        elif num_coordinates == 4:
            sampling_locations = (
                reference_points[:, :, None, :, None, :2]
                + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
            )
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

        output = self.attn(
            value,
            spatial_shapes,
            spatial_shapes_list,
            level_start_index,
            sampling_locations,
            attention_weights,
            self.im2col_step,
        )

        output = self.output_proj(output)

        return output, attention_weights
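
# Shape sketch for the deformable attention above: `sampling_offsets` maps each query to
# n_heads * n_levels * n_points * 2 values, reshaped to (batch, queries, n_heads, n_levels,
# n_points, 2). With 2-d reference points the offsets are divided by each level's
# (width, height) so sampling_locations stay in normalized [0, 1] coordinates; with 4-d
# reference boxes the offsets are scaled by half the box width/height instead.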


class RTDetrMultiheadAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
        return tensor if position_embeddings is None else tensor + position_embeddings

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        batch_size, target_len, embed_dim = hidden_states.size()
        # add position embeddings to the hidden states before projecting to queries and keys
        if position_embeddings is not None:
            hidden_states_original = hidden_states
            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

        # get queries, keys and values
        query_states = self.q_proj(hidden_states) * self.scaling
        key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size)
        value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size)

        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
        query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        source_len = key_states.size(1)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
                f" {attn_weights.size()}"
            )

        # expand attention_mask
        if attention_mask is not None:
            # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size())

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, target_len, source_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
            if attention_mask.dtype == torch.bool:
                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
                    attention_mask, -torch.inf
                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, target_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
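
# The attention module above accepts either an additive float mask (added directly to the
# attention scores) or a boolean mask, which is converted by filling True positions with
# -inf before the softmax so that masked positions receive effectively zero weight. The
# denoising mask built in get_contrastive_denoising_training_group is already additive.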


class RTDetrDecoderLayer(nn.Module):
    def __init__(self, config: RTDetrConfig):
        super().__init__()
        # self-attention
        self.self_attn = RTDetrMultiheadAttention(
            embed_dim=config.d_model,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.decoder_activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        # cross-attention
        self.encoder_attn = RTDetrMultiscaleDeformableAttention(
            config,
            num_heads=config.decoder_attention_heads,
            n_points=config.decoder_n_points,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        # feedforward neural networks
        self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
        self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[torch.Tensor] = None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(seq_len, batch, embed_dim)`.
            position_embeddings (`torch.FloatTensor`, *optional*):
                Position embeddings that are added to the queries and keys in the self-attention layer.
            reference_points (`torch.FloatTensor`, *optional*):
                Reference points.
            spatial_shapes (`torch.LongTensor`, *optional*):
                Spatial shapes.
            level_start_index (`torch.LongTensor`, *optional*):
                Level start index.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=encoder_attention_mask,
            position_embeddings=position_embeddings,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        second_residual = hidden_states

        # Cross-Attention
        cross_attn_weights = None
        hidden_states, cross_attn_weights = self.encoder_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            position_embeddings=position_embeddings,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = second_residual + hidden_states

        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs


@auto_docstring
class RTDetrPreTrainedModel(PreTrainedModel):
    config: RTDetrConfig
    base_model_prefix = "rt_detr"
    main_input_name = "pixel_values"
    _no_split_modules = [r"RTDetrHybridEncoder", r"RTDetrDecoderLayer"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (RTDetrForObjectDetection, RTDetrDecoder)):
            if module.class_embed is not None:
                for layer in module.class_embed:
                    prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
                    bias = float(-math.log((1 - prior_prob) / prior_prob))
                    nn.init.xavier_uniform_(layer.weight)
                    nn.init.constant_(layer.bias, bias)

            if module.bbox_embed is not None:
                for layer in module.bbox_embed:
                    nn.init.constant_(layer.layers[-1].weight, 0)
                    nn.init.constant_(layer.layers[-1].bias, 0)

        elif isinstance(module, RTDetrMultiscaleDeformableAttention):
            nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
            default_dtype = torch.get_default_dtype()
            thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
                2.0 * math.pi / module.n_heads
            )
            grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
            grid_init = (
                (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
                .view(module.n_heads, 1, 1, 2)
                .repeat(1, module.n_levels, module.n_points, 1)
            )
            for i in range(module.n_points):
                grid_init[:, :, i, :] *= i + 1
            with torch.no_grad():
                module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
            nn.init.constant_(module.attention_weights.weight.data, 0.0)
            nn.init.constant_(module.attention_weights.bias.data, 0.0)
            nn.init.xavier_uniform_(module.value_proj.weight.data)
            nn.init.constant_(module.value_proj.bias.data, 0.0)
            nn.init.xavier_uniform_(module.output_proj.weight.data)
            nn.init.constant_(module.output_proj.bias.data, 0.0)

        elif isinstance(module, RTDetrModel):
            prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
            bias = float(-math.log((1 - prior_prob) / prior_prob))
            nn.init.xavier_uniform_(module.enc_score_head.weight)
            nn.init.constant_(module.enc_score_head.bias, bias)

        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()

        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

        if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
            nn.init.xavier_uniform_(module.weight_embedding.weight)
        if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
            nn.init.xavier_uniform_(module.denoising_class_embed.weight)
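
# The sampling-offset bias initialization above follows Deformable DETR: head h starts at
# angle 2 * pi * h / n_heads, the unit direction is rescaled so its largest component is 1,
# and sampling point i is pushed outwards by a factor (i + 1), spreading each head's initial
# sampling points along a distinct direction before training.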


class RTDetrEncoder(nn.Module):
    def __init__(self, config: RTDetrConfig):
        super().__init__()

        self.layers = nn.ModuleList([RTDetrEncoderLayer(config) for _ in range(config.encoder_layers)])

    def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor:
        hidden_states = src
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=src_mask,
                position_embeddings=pos_embed,
                output_attentions=output_attentions,
            )
        return hidden_states


class RTDetrHybridEncoder(nn.Module):
    """
    Hybrid encoder consisting of a projection layer, a set of `RTDetrEncoder`, a top-down Feature Pyramid Network
    (FPN) and a bottom-up Path Aggregation Network (PAN). More details in the paper: https://huggingface.co/papers/2304.08069

    Args:
        config: RTDetrConfig
    """

    def __init__(self, config: RTDetrConfig):
        super().__init__()
        self.config = config
        self.in_channels = config.encoder_in_channels
        self.feat_strides = config.feat_strides
        self.encoder_hidden_dim = config.encoder_hidden_dim
        self.encode_proj_layers = config.encode_proj_layers
        self.positional_encoding_temperature = config.positional_encoding_temperature
        self.eval_size = config.eval_size
        self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
        self.out_strides = self.feat_strides
        self.num_fpn_stages = len(self.in_channels) - 1
        self.num_pan_stages = len(self.in_channels) - 1
        activation = config.activation_function

        # encoder transformer
        self.encoder = nn.ModuleList([RTDetrEncoder(config) for _ in range(len(self.encode_proj_layers))])

        # top-down FPN
        self.lateral_convs = nn.ModuleList()
        self.fpn_blocks = nn.ModuleList()
        for _ in range(self.num_fpn_stages):
            lateral_conv = RTDetrConvNormLayer(
                config,
                in_channels=self.encoder_hidden_dim,
                out_channels=self.encoder_hidden_dim,
                kernel_size=1,
                stride=1,
                activation=activation,
            )
            fpn_block = RTDetrCSPRepLayer(config)
            self.lateral_convs.append(lateral_conv)
            self.fpn_blocks.append(fpn_block)

        # bottom-up PAN
        self.downsample_convs = nn.ModuleList()
        self.pan_blocks = nn.ModuleList()
        for _ in range(self.num_pan_stages):
            downsample_conv = RTDetrConvNormLayer(
                config,
                in_channels=self.encoder_hidden_dim,
                out_channels=self.encoder_hidden_dim,
                kernel_size=3,
                stride=2,
                activation=activation,
            )
            pan_block = RTDetrCSPRepLayer(config)
            self.downsample_convs.append(downsample_conv)
            self.pan_blocks.append(pan_block)

    @staticmethod
    def build_2d_sincos_position_embedding(
        width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
    ):
        grid_w = torch.arange(torch_int(width), device=device).to(dtype)
        grid_h = torch.arange(torch_int(height), device=device).to(dtype)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
        if embed_dim % 4 != 0:
            raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
        omega = 1.0 / (temperature**omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]
        out_h = grid_h.flatten()[..., None] @ omega[None]

        return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]

    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        position_embeddings=None,
        spatial_shapes=None,
        level_start_index=None,
        valid_ratios=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
                - 1 for pixel features that are real (i.e. **not masked**),
                - 0 for pixel features that are padding (i.e. **masked**).
                [What are attention masks?](../glossary#attention-mask)
            position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Position embeddings that are added to the queries and keys in each self-attention layer.
            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
                Spatial shapes of each feature map.
            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
                Starting index of each feature map.
            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
                Ratio of valid area in each feature level.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = inputs_embeds

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # encoder
        if self.config.encoder_layers > 0:
            for i, enc_ind in enumerate(self.encode_proj_layers):
                if output_hidden_states:
                    encoder_states = encoder_states + (hidden_states[enc_ind],)
                height, width = hidden_states[enc_ind].shape[2:]
                # flatten [batch, channel, height, width] to [batch, height*width, channel]
                src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
                if self.training or self.eval_size is None:
                    pos_embed = self.build_2d_sincos_position_embedding(
                        width,
                        height,
                        self.encoder_hidden_dim,
                        self.positional_encoding_temperature,
                        device=src_flatten.device,
                        dtype=src_flatten.dtype,
                    )
                else:
                    pos_embed = None

                layer_outputs = self.encoder[i](
                    src_flatten,
                    pos_embed=pos_embed,
                    output_attentions=output_attentions,
                )
                hidden_states[enc_ind] = (
                    layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
                )

                if output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1],)

                if output_hidden_states:
                    encoder_states = encoder_states + (hidden_states[enc_ind],)

        # top-down FPN
        fpn_feature_maps = [hidden_states[-1]]
        for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
            backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1]
            top_fpn_feature_map = fpn_feature_maps[-1]
            # apply lateral block
            top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
            fpn_feature_maps[-1] = top_fpn_feature_map
            # apply fpn block
            top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
            fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
            new_fpn_feature_map = fpn_block(fused_feature_map)
            fpn_feature_maps.append(new_fpn_feature_map)

        fpn_feature_maps = fpn_feature_maps[::-1]

        # bottom-up PAN
        pan_feature_maps = [fpn_feature_maps[0]]
        for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
            top_pan_feature_map = pan_feature_maps[-1]
            fpn_feature_map = fpn_feature_maps[idx + 1]
            downsampled_feature_map = downsample_conv(top_pan_feature_map)
            fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
            new_pan_feature_map = pan_block(fused_feature_map)
            pan_feature_maps.append(new_pan_feature_map)

        if not return_dict:
            return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions
        )
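
# The 2D sin-cos position embedding used by the hybrid encoder above follows the standard
# transformer recipe per axis: with pos_dim = embed_dim // 4 and
# omega_j = 1 / temperature ** (j / pos_dim), it concatenates sin(w * omega), cos(w * omega),
# sin(h * omega), cos(h * omega) for every (w, h) grid position, giving embed_dim channels.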


class RTDetrDecoder(RTDetrPreTrainedModel):
    def __init__(self, config: RTDetrConfig):
        super().__init__(config)

        self.dropout = config.dropout
        self.layers = nn.ModuleList([RTDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
        self.query_pos_head = RTDetrMLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2)

        # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
        self.bbox_embed = None
        self.class_embed = None

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings=None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        valid_ratios=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
                The query embeddings that are passed into the decoder.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
                in `[0, 1]`:
                - 1 for pixels that are real (i.e. **not masked**),
                - 0 for pixels that are padding (i.e. **masked**).
            position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
                Position embeddings that are added to the queries and keys in each self-attention layer.
            reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
                Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
            spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
                Spatial shapes of the feature maps.
            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
                Indexes for the start of each feature level. In range `[0, sequence_length]`.
            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
                Ratio of valid area in each feature level.

            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is not None:
            hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        intermediate = ()
        intermediate_reference_points = ()
        intermediate_logits = ()

        reference_points = F.sigmoid(reference_points)

        # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L252
        for idx, decoder_layer in enumerate(self.layers):
            reference_points_input = reference_points.unsqueeze(2)
            position_embeddings = self.query_pos_head(reference_points)

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                encoder_hidden_states=encoder_hidden_states,
                reference_points=reference_points_input,
                spatial_shapes=spatial_shapes,
                spatial_shapes_list=spatial_shapes_list,
                level_start_index=level_start_index,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            # hack implementation for iterative bounding box refinement
            if self.bbox_embed is not None:
                predicted_corners = self.bbox_embed[idx](hidden_states)
                new_reference_points = F.sigmoid(predicted_corners + inverse_sigmoid(reference_points))
                reference_points = new_reference_points.detach()

            intermediate += (hidden_states,)
            intermediate_reference_points += (
                (new_reference_points,) if self.bbox_embed is not None else (reference_points,)
            )

            if self.class_embed is not None:
                logits = self.class_embed[idx](hidden_states)
                intermediate_logits += (logits,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # Keep batch_size as first dimension
        intermediate = torch.stack(intermediate, dim=1)
        intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
        if self.class_embed is not None:
            intermediate_logits = torch.stack(intermediate_logits, dim=1)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    intermediate,
                    intermediate_logits,
                    intermediate_reference_points,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return RTDetrDecoderOutput(
            last_hidden_state=hidden_states,
            intermediate_hidden_states=intermediate,
            intermediate_logits=intermediate_logits,
            intermediate_reference_points=intermediate_reference_points,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
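
# Note on the iterative box refinement in the decoder above: `bbox_embed` and `class_embed`
# are left as None here and are expected to be attached externally (e.g. by a detection
# head); when set, each layer predicts corrections in logit space,
# new_reference_points = sigmoid(bbox_embed(hidden_states) + inverse_sigmoid(reference_points)),
# and the refined points are detached so gradients do not propagate through the box chain.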


# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
class RTDetrMLPPredictionHead(nn.Module):
    """
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
    Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_paddle/ppdet/modeling/transformers/utils.py#L453
    """

    def __init__(self, config, input_dim, d_model, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [d_model] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
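
# For example, with num_layers=3, input_dim=d_model and output_dim=4 the prediction head
# above is Linear(d_model, d_model) -> ReLU -> Linear(d_model, d_model) -> ReLU ->
# Linear(d_model, 4), i.e. ReLU follows every layer except the last.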


@auto_docstring(
    custom_intro="""
    RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
    """
)
class RTDetrModel(RTDetrPreTrainedModel):
    def __init__(self, config: RTDetrConfig):
        super().__init__(config)

        # Create backbone
        self.backbone = RTDetrConvEncoder(config)
        intermediate_channel_sizes = self.backbone.intermediate_channel_sizes

        # Create encoder input projection layers
        # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py#L212
        num_backbone_outs = len(intermediate_channel_sizes)
        encoder_input_proj_list = []
        for _ in range(num_backbone_outs):
            in_channels = intermediate_channel_sizes[_]
            encoder_input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
                    nn.BatchNorm2d(config.encoder_hidden_dim),
                )
            )
        self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list)

        # Create encoder
        self.encoder = RTDetrHybridEncoder(config)

        # denoising part
        if config.num_denoising > 0:
            self.denoising_class_embed = nn.Embedding(
                config.num_labels + 1, config.d_model, padding_idx=config.num_labels
            )

        # decoder embedding
        if config.learn_initial_query:
            self.weight_embedding = nn.Embedding(config.num_queries, config.d_model)

        # encoder head
        self.enc_output = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
        )
        self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
        self.enc_bbox_head = RTDetrMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3)

        # init encoder output anchors and valid_mask
        if config.anchor_image_size:
            self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)

        # Create decoder input projection layers
        # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
        num_backbone_outs = len(config.decoder_in_channels)
        decoder_input_proj_list = []
        for _ in range(num_backbone_outs):
            in_channels = config.decoder_in_channels[_]
            decoder_input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
                    nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
                )
            )
        for _ in range(config.num_feature_levels - num_backbone_outs):
            decoder_input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
                )
            )
            in_channels = config.d_model
        self.decoder_input_proj = nn.ModuleList(decoder_input_proj_list)

        # decoder
        self.decoder = RTDetrDecoder(config)

        self.post_init()

    def get_encoder(self):
        return self.encoder

    def freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad_(False)

    def unfreeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad_(True)

    @compile_compatible_method_lru_cache(maxsize=32)
    def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
        if spatial_shapes is None:
            spatial_shapes = [
                [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
                for s in self.config.feat_strides
            ]
        anchors = []
        for level, (height, width) in enumerate(spatial_shapes):
            grid_y, grid_x = torch.meshgrid(
                torch.arange(end=height, device=device).to(dtype),
                torch.arange(end=width, device=device).to(dtype),
                indexing="ij",
            )
            grid_xy = torch.stack([grid_x, grid_y], -1)
            grid_xy = grid_xy.unsqueeze(0) + 0.5
            grid_xy[..., 0] /= width
            grid_xy[..., 1] /= height
            wh = torch.ones_like(grid_xy) * grid_size * (2.0**level)
            anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
        # define the valid range for anchor coordinates
        eps = 1e-2
        anchors = torch.concat(anchors, 1)
        valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
        anchors = torch.log(anchors / (1 - anchors))
        anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device))

        return anchors, valid_mask

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[list[dict]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.FloatTensor], RTDetrModelOutput]:
        r"""
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation.
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, RTDetrModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
        >>> model = RTDetrModel.from_pretrained("PekingU/rtdetr_r50vd")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 300, 256]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, num_channels, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones(((batch_size, height, width)), device=device)

        features = self.backbone(pixel_values, pixel_mask)

        proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                proj_feats,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if output_hidden_states else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2
                else encoder_outputs[1]
                if output_attentions
                else None,
            )

        # Equivalent to def _get_encoder_input
        # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
        sources = []
        for level, source in enumerate(encoder_outputs[0]):
            sources.append(self.decoder_input_proj[level](source))

        # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
        if self.config.num_feature_levels > len(sources):
            _len_sources = len(sources)
            sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1])
            for i in range(_len_sources + 1, self.config.num_feature_levels):
                sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1]))

        # Prepare encoder inputs (by flattening)
        source_flatten = []
        spatial_shapes_list = []
        spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long)
        for level, source in enumerate(sources):
            height, width = source.shape[-2:]
            spatial_shapes[level, 0] = height
            spatial_shapes[level, 1] = width
            spatial_shapes_list.append((height, width))
            source = source.flatten(2).transpose(1, 2)
|
| 1689 |
+
source_flatten.append(source)
|
| 1690 |
+
source_flatten = torch.cat(source_flatten, 1)
|
| 1691 |
+
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
| 1692 |
+
|
| 1693 |
+
# prepare denoising training
|
| 1694 |
+
if self.training and self.config.num_denoising > 0 and labels is not None:
|
| 1695 |
+
(
|
| 1696 |
+
denoising_class,
|
| 1697 |
+
denoising_bbox_unact,
|
| 1698 |
+
attention_mask,
|
| 1699 |
+
denoising_meta_values,
|
| 1700 |
+
) = get_contrastive_denoising_training_group(
|
| 1701 |
+
targets=labels,
|
| 1702 |
+
num_classes=self.config.num_labels,
|
| 1703 |
+
num_queries=self.config.num_queries,
|
| 1704 |
+
class_embed=self.denoising_class_embed,
|
| 1705 |
+
num_denoising_queries=self.config.num_denoising,
|
| 1706 |
+
label_noise_ratio=self.config.label_noise_ratio,
|
| 1707 |
+
box_noise_scale=self.config.box_noise_scale,
|
| 1708 |
+
)
|
| 1709 |
+
else:
|
| 1710 |
+
denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None
|
| 1711 |
+
|
| 1712 |
+
batch_size = len(source_flatten)
|
| 1713 |
+
device = source_flatten.device
|
| 1714 |
+
dtype = source_flatten.dtype
|
| 1715 |
+
|
| 1716 |
+
# prepare input for decoder
|
| 1717 |
+
if self.training or self.config.anchor_image_size is None:
|
| 1718 |
+
# Pass spatial_shapes as tuple to make it hashable and make sure
|
| 1719 |
+
# lru_cache is working for generate_anchors()
|
| 1720 |
+
spatial_shapes_tuple = tuple(spatial_shapes_list)
|
| 1721 |
+
anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
|
| 1722 |
+
else:
|
| 1723 |
+
anchors, valid_mask = self.anchors, self.valid_mask
|
| 1724 |
+
anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)
|
| 1725 |
+
|
| 1726 |
+
# use the valid_mask to selectively retain values in the feature map where the mask is `True`
|
| 1727 |
+
memory = valid_mask.to(source_flatten.dtype) * source_flatten
|
| 1728 |
+
|
| 1729 |
+
output_memory = self.enc_output(memory)
|
| 1730 |
+
|
| 1731 |
+
enc_outputs_class = self.enc_score_head(output_memory)
|
| 1732 |
+
enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors
|
| 1733 |
+
|
| 1734 |
+
_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1)
|
| 1735 |
+
|
| 1736 |
+
reference_points_unact = enc_outputs_coord_logits.gather(
|
| 1737 |
+
dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1])
|
| 1738 |
+
)
|
| 1739 |
+
|
| 1740 |
+
enc_topk_bboxes = F.sigmoid(reference_points_unact)
|
| 1741 |
+
if denoising_bbox_unact is not None:
|
| 1742 |
+
reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
|
| 1743 |
+
|
| 1744 |
+
enc_topk_logits = enc_outputs_class.gather(
|
| 1745 |
+
dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
|
| 1746 |
+
)
|
| 1747 |
+
|
| 1748 |
+
# extract region features
|
| 1749 |
+
if self.config.learn_initial_query:
|
| 1750 |
+
target = self.weight_embedding.tile([batch_size, 1, 1])
|
| 1751 |
+
else:
|
| 1752 |
+
target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
|
| 1753 |
+
target = target.detach()
|
| 1754 |
+
|
| 1755 |
+
if denoising_class is not None:
|
| 1756 |
+
target = torch.concat([denoising_class, target], 1)
|
| 1757 |
+
|
| 1758 |
+
init_reference_points = reference_points_unact.detach()
|
| 1759 |
+
|
| 1760 |
+
# decoder
|
| 1761 |
+
decoder_outputs = self.decoder(
|
| 1762 |
+
inputs_embeds=target,
|
| 1763 |
+
encoder_hidden_states=source_flatten,
|
| 1764 |
+
encoder_attention_mask=attention_mask,
|
| 1765 |
+
reference_points=init_reference_points,
|
| 1766 |
+
spatial_shapes=spatial_shapes,
|
| 1767 |
+
spatial_shapes_list=spatial_shapes_list,
|
| 1768 |
+
level_start_index=level_start_index,
|
| 1769 |
+
output_attentions=output_attentions,
|
| 1770 |
+
output_hidden_states=output_hidden_states,
|
| 1771 |
+
return_dict=return_dict,
|
| 1772 |
+
)
|
| 1773 |
+
|
| 1774 |
+
if not return_dict:
|
| 1775 |
+
enc_outputs = tuple(
|
| 1776 |
+
value
|
| 1777 |
+
for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits]
|
| 1778 |
+
if value is not None
|
| 1779 |
+
)
|
| 1780 |
+
dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values])
|
| 1781 |
+
tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs
|
| 1782 |
+
|
| 1783 |
+
return tuple_outputs
|
| 1784 |
+
|
| 1785 |
+
return RTDetrModelOutput(
|
| 1786 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
| 1787 |
+
intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
|
| 1788 |
+
intermediate_logits=decoder_outputs.intermediate_logits,
|
| 1789 |
+
intermediate_reference_points=decoder_outputs.intermediate_reference_points,
|
| 1790 |
+
intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners,
|
| 1791 |
+
initial_reference_points=decoder_outputs.initial_reference_points,
|
| 1792 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
| 1793 |
+
decoder_attentions=decoder_outputs.attentions,
|
| 1794 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
| 1795 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
| 1796 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
| 1797 |
+
encoder_attentions=encoder_outputs.attentions,
|
| 1798 |
+
init_reference_points=init_reference_points,
|
| 1799 |
+
enc_topk_logits=enc_topk_logits,
|
| 1800 |
+
enc_topk_bboxes=enc_topk_bboxes,
|
| 1801 |
+
enc_outputs_class=enc_outputs_class,
|
| 1802 |
+
enc_outputs_coord_logits=enc_outputs_coord_logits,
|
| 1803 |
+
denoising_meta_values=denoising_meta_values,
|
| 1804 |
+
)
|
| 1805 |
+
|
| 1806 |
+
|
| 1807 |
+
@auto_docstring(
|
| 1808 |
+
custom_intro="""
|
| 1809 |
+
RT-DETR Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further
|
| 1810 |
+
decoded into scores and classes.
|
| 1811 |
+
"""
|
| 1812 |
+
)
|
| 1813 |
+
class RTDetrForObjectDetection(RTDetrPreTrainedModel):
|
| 1814 |
+
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
|
| 1815 |
+
_tied_weights_keys = ["bbox_embed", "class_embed"]
|
| 1816 |
+
# We can't initialize the model on meta device as some weights are modified during the initialization
|
| 1817 |
+
_no_split_modules = None
|
| 1818 |
+
|
| 1819 |
+
def __init__(self, config: RTDetrConfig):
|
| 1820 |
+
super().__init__(config)
|
| 1821 |
+
|
| 1822 |
+
# RTDETR encoder-decoder model
|
| 1823 |
+
self.model = RTDetrModel(config)
|
| 1824 |
+
|
| 1825 |
+
# Detection heads on top
|
| 1826 |
+
self.class_embed = partial(nn.Linear, config.d_model, config.num_labels)
|
| 1827 |
+
self.bbox_embed = partial(RTDetrMLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3)
|
| 1828 |
+
|
| 1829 |
+
# if two-stage, the last class_embed and bbox_embed is for region proposal generation
|
| 1830 |
+
num_pred = config.decoder_layers
|
| 1831 |
+
if config.with_box_refine:
|
| 1832 |
+
self.class_embed = _get_clones(self.class_embed, num_pred)
|
| 1833 |
+
self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
|
| 1834 |
+
else:
|
| 1835 |
+
self.class_embed = nn.ModuleList([self.class_embed() for _ in range(num_pred)])
|
| 1836 |
+
self.bbox_embed = nn.ModuleList([self.bbox_embed() for _ in range(num_pred)])
|
| 1837 |
+
|
| 1838 |
+
# hack implementation for iterative bounding box refinement
|
| 1839 |
+
self.model.decoder.class_embed = self.class_embed
|
| 1840 |
+
self.model.decoder.bbox_embed = self.bbox_embed
|
| 1841 |
+
|
| 1842 |
+
# Initialize weights and apply final processing
|
| 1843 |
+
self.post_init()
|
| 1844 |
+
|
| 1845 |
+
@torch.jit.unused
|
| 1846 |
+
def _set_aux_loss(self, outputs_class, outputs_coord):
|
| 1847 |
+
# this is a workaround to make torchscript happy, as torchscript
|
| 1848 |
+
# doesn't support dictionary with non-homogeneous values, such
|
| 1849 |
+
# as a dict having both a Tensor and a list.
|
| 1850 |
+
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
|
| 1851 |
+
|
| 1852 |
+
@auto_docstring
|
| 1853 |
+
def forward(
|
| 1854 |
+
self,
|
| 1855 |
+
pixel_values: torch.FloatTensor,
|
| 1856 |
+
pixel_mask: Optional[torch.LongTensor] = None,
|
| 1857 |
+
encoder_outputs: Optional[torch.FloatTensor] = None,
|
| 1858 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1859 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1860 |
+
labels: Optional[list[dict]] = None,
|
| 1861 |
+
output_attentions: Optional[bool] = None,
|
| 1862 |
+
output_hidden_states: Optional[bool] = None,
|
| 1863 |
+
return_dict: Optional[bool] = None,
|
| 1864 |
+
**kwargs,
|
| 1865 |
+
) -> Union[tuple[torch.FloatTensor], RTDetrObjectDetectionOutput]:
|
| 1866 |
+
r"""
|
| 1867 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 1868 |
+
Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
|
| 1869 |
+
can choose to directly pass a flattened representation of an image.
|
| 1870 |
+
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
|
| 1871 |
+
Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
|
| 1872 |
+
embedded representation.
|
| 1873 |
+
labels (`list[Dict]` of len `(batch_size,)`, *optional*):
|
| 1874 |
+
Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
|
| 1875 |
+
following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
|
| 1876 |
+
respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
|
| 1877 |
+
in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
|
| 1878 |
+
|
| 1879 |
+
Examples:
|
| 1880 |
+
|
| 1881 |
+
```python
|
| 1882 |
+
>>> from transformers import RTDetrImageProcessor, RTDetrForObjectDetection
|
| 1883 |
+
>>> from PIL import Image
|
| 1884 |
+
>>> import requests
|
| 1885 |
+
>>> import torch
|
| 1886 |
+
|
| 1887 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1888 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1889 |
+
|
| 1890 |
+
>>> image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
|
| 1891 |
+
>>> model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
|
| 1892 |
+
|
| 1893 |
+
>>> # prepare image for the model
|
| 1894 |
+
>>> inputs = image_processor(images=image, return_tensors="pt")
|
| 1895 |
+
|
| 1896 |
+
>>> # forward pass
|
| 1897 |
+
>>> outputs = model(**inputs)
|
| 1898 |
+
|
| 1899 |
+
>>> logits = outputs.logits
|
| 1900 |
+
>>> list(logits.shape)
|
| 1901 |
+
[1, 300, 80]
|
| 1902 |
+
|
| 1903 |
+
>>> boxes = outputs.pred_boxes
|
| 1904 |
+
>>> list(boxes.shape)
|
| 1905 |
+
[1, 300, 4]
|
| 1906 |
+
|
| 1907 |
+
>>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
|
| 1908 |
+
>>> target_sizes = torch.tensor([image.size[::-1]])
|
| 1909 |
+
>>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
|
| 1910 |
+
... 0
|
| 1911 |
+
... ]
|
| 1912 |
+
|
| 1913 |
+
>>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
|
| 1914 |
+
... box = [round(i, 2) for i in box.tolist()]
|
| 1915 |
+
... print(
|
| 1916 |
+
... f"Detected {model.config.id2label[label.item()]} with confidence "
|
| 1917 |
+
... f"{round(score.item(), 3)} at location {box}"
|
| 1918 |
+
... )
|
| 1919 |
+
Detected sofa with confidence 0.97 at location [0.14, 0.38, 640.13, 476.21]
|
| 1920 |
+
Detected cat with confidence 0.96 at location [343.38, 24.28, 640.14, 371.5]
|
| 1921 |
+
Detected cat with confidence 0.958 at location [13.23, 54.18, 318.98, 472.22]
|
| 1922 |
+
Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48]
|
| 1923 |
+
Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99]
|
| 1924 |
+
```"""
|
| 1925 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1926 |
+
output_hidden_states = (
|
| 1927 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1928 |
+
)
|
| 1929 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1930 |
+
|
| 1931 |
+
outputs = self.model(
|
| 1932 |
+
pixel_values,
|
| 1933 |
+
pixel_mask=pixel_mask,
|
| 1934 |
+
encoder_outputs=encoder_outputs,
|
| 1935 |
+
inputs_embeds=inputs_embeds,
|
| 1936 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 1937 |
+
labels=labels,
|
| 1938 |
+
output_attentions=output_attentions,
|
| 1939 |
+
output_hidden_states=output_hidden_states,
|
| 1940 |
+
return_dict=return_dict,
|
| 1941 |
+
)
|
| 1942 |
+
|
| 1943 |
+
denoising_meta_values = (
|
| 1944 |
+
outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None
|
| 1945 |
+
)
|
| 1946 |
+
|
| 1947 |
+
outputs_class = outputs.intermediate_logits if return_dict else outputs[2]
|
| 1948 |
+
outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3]
|
| 1949 |
+
predicted_corners = outputs.intermediate_predicted_corners if return_dict else outputs[4]
|
| 1950 |
+
initial_reference_points = outputs.initial_reference_points if return_dict else outputs[5]
|
| 1951 |
+
|
| 1952 |
+
logits = outputs_class[:, -1]
|
| 1953 |
+
pred_boxes = outputs_coord[:, -1]
|
| 1954 |
+
|
| 1955 |
+
loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
|
| 1956 |
+
if labels is not None:
|
| 1957 |
+
enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
|
| 1958 |
+
enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
|
| 1959 |
+
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
| 1960 |
+
logits,
|
| 1961 |
+
labels,
|
| 1962 |
+
self.device,
|
| 1963 |
+
pred_boxes,
|
| 1964 |
+
self.config,
|
| 1965 |
+
outputs_class,
|
| 1966 |
+
outputs_coord,
|
| 1967 |
+
enc_topk_logits=enc_topk_logits,
|
| 1968 |
+
enc_topk_bboxes=enc_topk_bboxes,
|
| 1969 |
+
denoising_meta_values=denoising_meta_values,
|
| 1970 |
+
predicted_corners=predicted_corners,
|
| 1971 |
+
initial_reference_points=initial_reference_points,
|
| 1972 |
+
**kwargs,
|
| 1973 |
+
)
|
| 1974 |
+
|
| 1975 |
+
if not return_dict:
|
| 1976 |
+
if auxiliary_outputs is not None:
|
| 1977 |
+
output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs
|
| 1978 |
+
else:
|
| 1979 |
+
output = (logits, pred_boxes) + outputs
|
| 1980 |
+
return ((loss, loss_dict) + output) if loss is not None else output
|
| 1981 |
+
|
| 1982 |
+
return RTDetrObjectDetectionOutput(
|
| 1983 |
+
loss=loss,
|
| 1984 |
+
loss_dict=loss_dict,
|
| 1985 |
+
logits=logits,
|
| 1986 |
+
pred_boxes=pred_boxes,
|
| 1987 |
+
auxiliary_outputs=auxiliary_outputs,
|
| 1988 |
+
last_hidden_state=outputs.last_hidden_state,
|
| 1989 |
+
intermediate_hidden_states=outputs.intermediate_hidden_states,
|
| 1990 |
+
intermediate_logits=outputs.intermediate_logits,
|
| 1991 |
+
intermediate_reference_points=outputs.intermediate_reference_points,
|
| 1992 |
+
intermediate_predicted_corners=outputs.intermediate_predicted_corners,
|
| 1993 |
+
initial_reference_points=outputs.initial_reference_points,
|
| 1994 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
| 1995 |
+
decoder_attentions=outputs.decoder_attentions,
|
| 1996 |
+
cross_attentions=outputs.cross_attentions,
|
| 1997 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
| 1998 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
| 1999 |
+
encoder_attentions=outputs.encoder_attentions,
|
| 2000 |
+
init_reference_points=outputs.init_reference_points,
|
| 2001 |
+
enc_topk_logits=outputs.enc_topk_logits,
|
| 2002 |
+
enc_topk_bboxes=outputs.enc_topk_bboxes,
|
| 2003 |
+
enc_outputs_class=outputs.enc_outputs_class,
|
| 2004 |
+
enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
|
| 2005 |
+
denoising_meta_values=outputs.denoising_meta_values,
|
| 2006 |
+
)
|
| 2007 |
+
|
| 2008 |
+
|
| 2009 |
+
__all__ = [
|
| 2010 |
+
"RTDetrForObjectDetection",
|
| 2011 |
+
"RTDetrModel",
|
| 2012 |
+
"RTDetrPreTrainedModel",
|
| 2013 |
+
]
|
phivenv/Lib/site-packages/transformers/models/rt_detr/modeling_rt_detr_resnet.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
PyTorch RTDetr specific ResNet model. The main difference between hugginface ResNet model is that this RTDetrResNet model forces to use shortcut at the first layer in the resnet-18/34 models.
|
| 17 |
+
See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L126 for details.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import math
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
from torch import Tensor, nn
|
| 24 |
+
|
| 25 |
+
from ...activations import ACT2FN
|
| 26 |
+
from ...modeling_outputs import BackboneOutput, BaseModelOutputWithNoAttention
|
| 27 |
+
from ...modeling_utils import PreTrainedModel
|
| 28 |
+
from ...utils import auto_docstring, logging
|
| 29 |
+
from ...utils.backbone_utils import BackboneMixin
|
| 30 |
+
from .configuration_rt_detr_resnet import RTDetrResNetConfig
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Copied from transformers.models.resnet.modeling_resnet.ResNetConvLayer -> RTDetrResNetConvLayer
|
| 37 |
+
class RTDetrResNetConvLayer(nn.Module):
|
| 38 |
+
def __init__(
|
| 39 |
+
self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
|
| 40 |
+
):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.convolution = nn.Conv2d(
|
| 43 |
+
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
|
| 44 |
+
)
|
| 45 |
+
self.normalization = nn.BatchNorm2d(out_channels)
|
| 46 |
+
self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
|
| 47 |
+
|
| 48 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 49 |
+
hidden_state = self.convolution(input)
|
| 50 |
+
hidden_state = self.normalization(hidden_state)
|
| 51 |
+
hidden_state = self.activation(hidden_state)
|
| 52 |
+
return hidden_state
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class RTDetrResNetEmbeddings(nn.Module):
|
| 56 |
+
"""
|
| 57 |
+
ResNet Embeddings (stem) composed of a deep aggressive convolution.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(self, config: RTDetrResNetConfig):
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.embedder = nn.Sequential(
|
| 63 |
+
*[
|
| 64 |
+
RTDetrResNetConvLayer(
|
| 65 |
+
config.num_channels,
|
| 66 |
+
config.embedding_size // 2,
|
| 67 |
+
kernel_size=3,
|
| 68 |
+
stride=2,
|
| 69 |
+
activation=config.hidden_act,
|
| 70 |
+
),
|
| 71 |
+
RTDetrResNetConvLayer(
|
| 72 |
+
config.embedding_size // 2,
|
| 73 |
+
config.embedding_size // 2,
|
| 74 |
+
kernel_size=3,
|
| 75 |
+
stride=1,
|
| 76 |
+
activation=config.hidden_act,
|
| 77 |
+
),
|
| 78 |
+
RTDetrResNetConvLayer(
|
| 79 |
+
config.embedding_size // 2,
|
| 80 |
+
config.embedding_size,
|
| 81 |
+
kernel_size=3,
|
| 82 |
+
stride=1,
|
| 83 |
+
activation=config.hidden_act,
|
| 84 |
+
),
|
| 85 |
+
]
|
| 86 |
+
)
|
| 87 |
+
self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 88 |
+
self.num_channels = config.num_channels
|
| 89 |
+
|
| 90 |
+
def forward(self, pixel_values: Tensor) -> Tensor:
|
| 91 |
+
num_channels = pixel_values.shape[1]
|
| 92 |
+
if num_channels != self.num_channels:
|
| 93 |
+
raise ValueError(
|
| 94 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 95 |
+
)
|
| 96 |
+
embedding = self.embedder(pixel_values)
|
| 97 |
+
embedding = self.pooler(embedding)
|
| 98 |
+
return embedding
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut -> RTDetrResNetChortCut
|
| 102 |
+
class RTDetrResNetShortCut(nn.Module):
|
| 103 |
+
"""
|
| 104 |
+
ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
|
| 105 |
+
downsample the input using `stride=2`.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
|
| 111 |
+
self.normalization = nn.BatchNorm2d(out_channels)
|
| 112 |
+
|
| 113 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 114 |
+
hidden_state = self.convolution(input)
|
| 115 |
+
hidden_state = self.normalization(hidden_state)
|
| 116 |
+
return hidden_state
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class RTDetrResNetBasicLayer(nn.Module):
|
| 120 |
+
"""
|
| 121 |
+
A classic ResNet's residual layer composed by two `3x3` convolutions.
|
| 122 |
+
See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L34.
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
config: RTDetrResNetConfig,
|
| 128 |
+
in_channels: int,
|
| 129 |
+
out_channels: int,
|
| 130 |
+
stride: int = 1,
|
| 131 |
+
should_apply_shortcut: bool = False,
|
| 132 |
+
):
|
| 133 |
+
super().__init__()
|
| 134 |
+
if in_channels != out_channels:
|
| 135 |
+
self.shortcut = (
|
| 136 |
+
nn.Sequential(
|
| 137 |
+
*[nn.AvgPool2d(2, 2, 0, ceil_mode=True), RTDetrResNetShortCut(in_channels, out_channels, stride=1)]
|
| 138 |
+
)
|
| 139 |
+
if should_apply_shortcut
|
| 140 |
+
else nn.Identity()
|
| 141 |
+
)
|
| 142 |
+
else:
|
| 143 |
+
self.shortcut = (
|
| 144 |
+
RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
|
| 145 |
+
if should_apply_shortcut
|
| 146 |
+
else nn.Identity()
|
| 147 |
+
)
|
| 148 |
+
self.layer = nn.Sequential(
|
| 149 |
+
RTDetrResNetConvLayer(in_channels, out_channels, stride=stride),
|
| 150 |
+
RTDetrResNetConvLayer(out_channels, out_channels, activation=None),
|
| 151 |
+
)
|
| 152 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 153 |
+
|
| 154 |
+
def forward(self, hidden_state):
|
| 155 |
+
residual = hidden_state
|
| 156 |
+
hidden_state = self.layer(hidden_state)
|
| 157 |
+
residual = self.shortcut(residual)
|
| 158 |
+
hidden_state += residual
|
| 159 |
+
hidden_state = self.activation(hidden_state)
|
| 160 |
+
return hidden_state
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class RTDetrResNetBottleNeckLayer(nn.Module):
|
| 164 |
+
"""
|
| 165 |
+
A classic RTDetrResNet's bottleneck layer composed by three `3x3` convolutions.
|
| 166 |
+
|
| 167 |
+
The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
|
| 168 |
+
convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If
|
| 169 |
+
`downsample_in_bottleneck` is true, downsample will be in the first layer instead of the second layer.
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
def __init__(
|
| 173 |
+
self,
|
| 174 |
+
config: RTDetrResNetConfig,
|
| 175 |
+
in_channels: int,
|
| 176 |
+
out_channels: int,
|
| 177 |
+
stride: int = 1,
|
| 178 |
+
):
|
| 179 |
+
super().__init__()
|
| 180 |
+
reduction = 4
|
| 181 |
+
should_apply_shortcut = in_channels != out_channels or stride != 1
|
| 182 |
+
reduces_channels = out_channels // reduction
|
| 183 |
+
if stride == 2:
|
| 184 |
+
self.shortcut = nn.Sequential(
|
| 185 |
+
*[
|
| 186 |
+
nn.AvgPool2d(2, 2, 0, ceil_mode=True),
|
| 187 |
+
RTDetrResNetShortCut(in_channels, out_channels, stride=1)
|
| 188 |
+
if should_apply_shortcut
|
| 189 |
+
else nn.Identity(),
|
| 190 |
+
]
|
| 191 |
+
)
|
| 192 |
+
else:
|
| 193 |
+
self.shortcut = (
|
| 194 |
+
RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
|
| 195 |
+
if should_apply_shortcut
|
| 196 |
+
else nn.Identity()
|
| 197 |
+
)
|
| 198 |
+
self.layer = nn.Sequential(
|
| 199 |
+
RTDetrResNetConvLayer(
|
| 200 |
+
in_channels, reduces_channels, kernel_size=1, stride=stride if config.downsample_in_bottleneck else 1
|
| 201 |
+
),
|
| 202 |
+
RTDetrResNetConvLayer(
|
| 203 |
+
reduces_channels, reduces_channels, stride=stride if not config.downsample_in_bottleneck else 1
|
| 204 |
+
),
|
| 205 |
+
RTDetrResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None),
|
| 206 |
+
)
|
| 207 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 208 |
+
|
| 209 |
+
def forward(self, hidden_state):
|
| 210 |
+
residual = hidden_state
|
| 211 |
+
hidden_state = self.layer(hidden_state)
|
| 212 |
+
residual = self.shortcut(residual)
|
| 213 |
+
hidden_state += residual
|
| 214 |
+
hidden_state = self.activation(hidden_state)
|
| 215 |
+
return hidden_state
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class RTDetrResNetStage(nn.Module):
|
| 219 |
+
"""
|
| 220 |
+
A RTDetrResNet stage composed by stacked layers.
|
| 221 |
+
"""
|
| 222 |
+
|
| 223 |
+
def __init__(
|
| 224 |
+
self,
|
| 225 |
+
config: RTDetrResNetConfig,
|
| 226 |
+
in_channels: int,
|
| 227 |
+
out_channels: int,
|
| 228 |
+
stride: int = 2,
|
| 229 |
+
depth: int = 2,
|
| 230 |
+
):
|
| 231 |
+
super().__init__()
|
| 232 |
+
|
| 233 |
+
layer = RTDetrResNetBottleNeckLayer if config.layer_type == "bottleneck" else RTDetrResNetBasicLayer
|
| 234 |
+
|
| 235 |
+
if config.layer_type == "bottleneck":
|
| 236 |
+
first_layer = layer(
|
| 237 |
+
config,
|
| 238 |
+
in_channels,
|
| 239 |
+
out_channels,
|
| 240 |
+
stride=stride,
|
| 241 |
+
)
|
| 242 |
+
else:
|
| 243 |
+
first_layer = layer(config, in_channels, out_channels, stride=stride, should_apply_shortcut=True)
|
| 244 |
+
self.layers = nn.Sequential(
|
| 245 |
+
first_layer, *[layer(config, out_channels, out_channels) for _ in range(depth - 1)]
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 249 |
+
hidden_state = input
|
| 250 |
+
for layer in self.layers:
|
| 251 |
+
hidden_state = layer(hidden_state)
|
| 252 |
+
return hidden_state
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# Copied from transformers.models.resnet.modeling_resnet.ResNetEncoder with ResNet->RTDetrResNet
|
| 256 |
+
class RTDetrResNetEncoder(nn.Module):
|
| 257 |
+
def __init__(self, config: RTDetrResNetConfig):
|
| 258 |
+
super().__init__()
|
| 259 |
+
self.stages = nn.ModuleList([])
|
| 260 |
+
# based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input
|
| 261 |
+
self.stages.append(
|
| 262 |
+
RTDetrResNetStage(
|
| 263 |
+
config,
|
| 264 |
+
config.embedding_size,
|
| 265 |
+
config.hidden_sizes[0],
|
| 266 |
+
stride=2 if config.downsample_in_first_stage else 1,
|
| 267 |
+
depth=config.depths[0],
|
| 268 |
+
)
|
| 269 |
+
)
|
| 270 |
+
in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
|
| 271 |
+
for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
|
| 272 |
+
self.stages.append(RTDetrResNetStage(config, in_channels, out_channels, depth=depth))
|
| 273 |
+
|
| 274 |
+
def forward(
|
| 275 |
+
self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
|
| 276 |
+
) -> BaseModelOutputWithNoAttention:
|
| 277 |
+
hidden_states = () if output_hidden_states else None
|
| 278 |
+
|
| 279 |
+
for stage_module in self.stages:
|
| 280 |
+
if output_hidden_states:
|
| 281 |
+
hidden_states = hidden_states + (hidden_state,)
|
| 282 |
+
|
| 283 |
+
hidden_state = stage_module(hidden_state)
|
| 284 |
+
|
| 285 |
+
if output_hidden_states:
|
| 286 |
+
hidden_states = hidden_states + (hidden_state,)
|
| 287 |
+
|
| 288 |
+
if not return_dict:
|
| 289 |
+
return tuple(v for v in [hidden_state, hidden_states] if v is not None)
|
| 290 |
+
|
| 291 |
+
return BaseModelOutputWithNoAttention(
|
| 292 |
+
last_hidden_state=hidden_state,
|
| 293 |
+
hidden_states=hidden_states,
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@auto_docstring
|
| 298 |
+
# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RTDetrResNet
|
| 299 |
+
class RTDetrResNetPreTrainedModel(PreTrainedModel):
|
| 300 |
+
config: RTDetrResNetConfig
|
| 301 |
+
base_model_prefix = "resnet"
|
| 302 |
+
main_input_name = "pixel_values"
|
| 303 |
+
_no_split_modules = ["RTDetrResNetConvLayer", "RTDetrResNetShortCut"]
|
| 304 |
+
|
| 305 |
+
def _init_weights(self, module):
|
| 306 |
+
if isinstance(module, nn.Conv2d):
|
| 307 |
+
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
| 308 |
+
# copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
|
| 309 |
+
elif isinstance(module, nn.Linear):
|
| 310 |
+
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
| 311 |
+
if module.bias is not None:
|
| 312 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
| 313 |
+
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
| 314 |
+
nn.init.uniform_(module.bias, -bound, bound)
|
| 315 |
+
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
| 316 |
+
nn.init.constant_(module.weight, 1)
|
| 317 |
+
nn.init.constant_(module.bias, 0)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@auto_docstring(
|
| 321 |
+
custom_intro="""
|
| 322 |
+
ResNet backbone, to be used with frameworks like RTDETR.
|
| 323 |
+
"""
|
| 324 |
+
)
|
| 325 |
+
class RTDetrResNetBackbone(RTDetrResNetPreTrainedModel, BackboneMixin):
|
| 326 |
+
has_attentions = False
|
| 327 |
+
|
| 328 |
+
def __init__(self, config):
|
| 329 |
+
super().__init__(config)
|
| 330 |
+
super()._init_backbone(config)
|
| 331 |
+
|
| 332 |
+
self.num_features = [config.embedding_size] + config.hidden_sizes
|
| 333 |
+
self.embedder = RTDetrResNetEmbeddings(config)
|
| 334 |
+
self.encoder = RTDetrResNetEncoder(config)
|
| 335 |
+
|
| 336 |
+
# initialize weights and apply final processing
|
| 337 |
+
self.post_init()
|
| 338 |
+
|
| 339 |
+
@auto_docstring
|
| 340 |
+
def forward(
|
| 341 |
+
self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
|
| 342 |
+
) -> BackboneOutput:
|
| 343 |
+
r"""
|
| 344 |
+
Examples:
|
| 345 |
+
|
| 346 |
+
```python
|
| 347 |
+
>>> from transformers import RTDetrResNetConfig, RTDetrResNetBackbone
|
| 348 |
+
>>> import torch
|
| 349 |
+
from ...utils.deprecation import deprecate_kwarg
|
| 350 |
+
from ...utils.deprecation import deprecate_kwarg
|
| 351 |
+
from ...utils.deprecation import deprecate_kwarg
|
| 352 |
+
from ...utils.deprecation import deprecate_kwarg
|
| 353 |
+
from ...utils.deprecation import deprecate_kwarg
|
| 354 |
+
|
| 355 |
+
>>> config = RTDetrResNetConfig()
|
| 356 |
+
>>> model = RTDetrResNetBackbone(config)
|
| 357 |
+
|
| 358 |
+
>>> pixel_values = torch.randn(1, 3, 224, 224)
|
| 359 |
+
|
| 360 |
+
>>> with torch.no_grad():
|
| 361 |
+
... outputs = model(pixel_values)
|
| 362 |
+
|
| 363 |
+
>>> feature_maps = outputs.feature_maps
|
| 364 |
+
>>> list(feature_maps[-1].shape)
|
| 365 |
+
[1, 2048, 7, 7]
|
| 366 |
+
```"""
|
| 367 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 368 |
+
output_hidden_states = (
|
| 369 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
embedding_output = self.embedder(pixel_values)
|
| 373 |
+
|
| 374 |
+
outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
|
| 375 |
+
|
| 376 |
+
hidden_states = outputs.hidden_states
|
| 377 |
+
|
| 378 |
+
feature_maps = ()
|
| 379 |
+
for idx, stage in enumerate(self.stage_names):
|
| 380 |
+
if stage in self.out_features:
|
| 381 |
+
feature_maps += (hidden_states[idx],)
|
| 382 |
+
|
| 383 |
+
if not return_dict:
|
| 384 |
+
output = (feature_maps,)
|
| 385 |
+
if output_hidden_states:
|
| 386 |
+
output += (outputs.hidden_states,)
|
| 387 |
+
return output
|
| 388 |
+
|
| 389 |
+
return BackboneOutput(
|
| 390 |
+
feature_maps=feature_maps,
|
| 391 |
+
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
| 392 |
+
attentions=None,
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
__all__ = [
|
| 397 |
+
"RTDetrResNetBackbone",
|
| 398 |
+
"RTDetrResNetPreTrainedModel",
|
| 399 |
+
]
|
phivenv/Lib/site-packages/transformers/models/rt_detr/modular_rt_detr.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pathlib
|
| 2 |
+
from typing import Optional, Union
|
| 3 |
+
|
| 4 |
+
from transformers.models.detr.image_processing_detr_fast import DetrFastImageProcessorKwargs, DetrImageProcessorFast
|
| 5 |
+
|
| 6 |
+
from ...image_processing_utils import BatchFeature
|
| 7 |
+
from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, get_max_height_width
|
| 8 |
+
from ...image_transforms import center_to_corners_format
|
| 9 |
+
from ...image_utils import (
|
| 10 |
+
IMAGENET_DEFAULT_MEAN,
|
| 11 |
+
IMAGENET_DEFAULT_STD,
|
| 12 |
+
AnnotationFormat,
|
| 13 |
+
AnnotationType,
|
| 14 |
+
ChannelDimension,
|
| 15 |
+
ImageInput,
|
| 16 |
+
PILImageResampling,
|
| 17 |
+
get_image_size,
|
| 18 |
+
validate_annotations,
|
| 19 |
+
)
|
| 20 |
+
from ...processing_utils import Unpack
|
| 21 |
+
from ...utils import (
|
| 22 |
+
TensorType,
|
| 23 |
+
is_torch_available,
|
| 24 |
+
is_torchvision_available,
|
| 25 |
+
is_torchvision_v2_available,
|
| 26 |
+
logging,
|
| 27 |
+
requires_backends,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if is_torch_available():
|
| 32 |
+
import torch
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if is_torchvision_v2_available():
|
| 36 |
+
from torchvision.transforms.v2 import functional as F
|
| 37 |
+
elif is_torchvision_available():
|
| 38 |
+
from torchvision.transforms import functional as F
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__)
|
| 42 |
+
|
| 43 |
+
SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def prepare_coco_detection_annotation(
|
| 47 |
+
image,
|
| 48 |
+
target,
|
| 49 |
+
return_segmentation_masks: bool = False,
|
| 50 |
+
input_data_format: Optional[Union[ChannelDimension, str]] = None,
|
| 51 |
+
):
|
| 52 |
+
"""
|
| 53 |
+
Convert the target in COCO format into the format expected by RT-DETR.
|
| 54 |
+
"""
|
| 55 |
+
image_height, image_width = image.size()[-2:]
|
| 56 |
+
|
| 57 |
+
image_id = target["image_id"]
|
| 58 |
+
image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)
|
| 59 |
+
|
| 60 |
+
# Get all COCO annotations for the given image.
|
| 61 |
+
annotations = target["annotations"]
|
| 62 |
+
classes = []
|
| 63 |
+
area = []
|
| 64 |
+
boxes = []
|
| 65 |
+
keypoints = []
|
| 66 |
+
for obj in annotations:
|
| 67 |
+
if "iscrowd" not in obj or obj["iscrowd"] == 0:
|
| 68 |
+
classes.append(obj["category_id"])
|
| 69 |
+
area.append(obj["area"])
|
| 70 |
+
boxes.append(obj["bbox"])
|
| 71 |
+
if "keypoints" in obj:
|
| 72 |
+
keypoints.append(obj["keypoints"])
|
| 73 |
+
|
| 74 |
+
classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
|
| 75 |
+
area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
|
| 76 |
+
iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
|
| 77 |
+
# guard against no boxes via resizing
|
| 78 |
+
boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
|
| 79 |
+
boxes[:, 2:] += boxes[:, :2]
|
| 80 |
+
boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
|
| 81 |
+
boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
|
| 82 |
+
|
| 83 |
+
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
|
| 84 |
+
|
| 85 |
+
new_target = {
|
| 86 |
+
"image_id": image_id,
|
| 87 |
+
"class_labels": classes[keep],
|
| 88 |
+
"boxes": boxes[keep],
|
| 89 |
+
"area": area[keep],
|
| 90 |
+
"iscrowd": iscrowd[keep],
|
| 91 |
+
"orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
if keypoints:
|
| 95 |
+
keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
|
| 96 |
+
# Apply the keep mask here to filter the relevant annotations
|
| 97 |
+
keypoints = keypoints[keep]
|
| 98 |
+
num_keypoints = keypoints.shape[0]
|
| 99 |
+
keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
|
| 100 |
+
new_target["keypoints"] = keypoints
|
| 101 |
+
|
| 102 |
+
return new_target
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class RTDetrFastImageProcessorKwargs(DetrFastImageProcessorKwargs):
|
| 106 |
+
pass
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class RTDetrImageProcessorFast(DetrImageProcessorFast):
|
| 110 |
+
resample = PILImageResampling.BILINEAR
|
| 111 |
+
image_mean = IMAGENET_DEFAULT_MEAN
|
| 112 |
+
image_std = IMAGENET_DEFAULT_STD
|
| 113 |
+
format = AnnotationFormat.COCO_DETECTION
|
| 114 |
+
do_convert_annotations = True
|
| 115 |
+
do_resize = True
|
| 116 |
+
do_rescale = True
|
| 117 |
+
do_normalize = False
|
| 118 |
+
do_pad = False
|
| 119 |
+
size = {"height": 640, "width": 640}
|
| 120 |
+
default_to_square = False
|
| 121 |
+
model_input_names = ["pixel_values", "pixel_mask"]
|
| 122 |
+
valid_kwargs = RTDetrFastImageProcessorKwargs
|
| 123 |
+
|
| 124 |
+
def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorKwargs]) -> None:
|
| 125 |
+
# Backwards compatibility
|
| 126 |
+
do_convert_annotations = kwargs.get("do_convert_annotations")
|
| 127 |
+
do_normalize = kwargs.get("do_normalize")
|
| 128 |
+
if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
|
| 129 |
+
self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
|
| 130 |
+
|
| 131 |
+
BaseImageProcessorFast.__init__(self, **kwargs)
|
| 132 |
+
|
| 133 |
+
def preprocess(
|
| 134 |
+
self,
|
| 135 |
+
images: ImageInput,
|
| 136 |
+
annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
|
| 137 |
+
masks_path: Optional[Union[str, pathlib.Path]] = None,
|
| 138 |
+
**kwargs: Unpack[RTDetrFastImageProcessorKwargs],
|
| 139 |
+
) -> BatchFeature:
|
| 140 |
+
return BaseImageProcessorFast.preprocess(self, images, annotations, masks_path, **kwargs)
|
| 141 |
+
|
| 142 |
+
def prepare_annotation(
|
| 143 |
+
self,
|
| 144 |
+
image: torch.Tensor,
|
| 145 |
+
target: dict,
|
| 146 |
+
format: Optional[AnnotationFormat] = None,
|
| 147 |
+
return_segmentation_masks: Optional[bool] = None,
|
| 148 |
+
masks_path: Optional[Union[str, pathlib.Path]] = None,
|
| 149 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 150 |
+
) -> dict:
|
| 151 |
+
format = format if format is not None else self.format
|
| 152 |
+
|
| 153 |
+
if format == AnnotationFormat.COCO_DETECTION:
|
| 154 |
+
return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
|
| 155 |
+
target = prepare_coco_detection_annotation(
|
| 156 |
+
image, target, return_segmentation_masks, input_data_format=input_data_format
|
| 157 |
+
)
|
| 158 |
+
else:
|
| 159 |
+
raise ValueError(f"Format {format} is not supported.")
|
| 160 |
+
return target
|
| 161 |
+
|
| 162 |
+
def _preprocess(
|
| 163 |
+
self,
|
| 164 |
+
images: list["torch.Tensor"],
|
| 165 |
+
annotations: Optional[Union[AnnotationType, list[AnnotationType]]],
|
| 166 |
+
masks_path: Optional[Union[str, pathlib.Path]],
|
| 167 |
+
return_segmentation_masks: bool,
|
| 168 |
+
do_resize: bool,
|
| 169 |
+
size: SizeDict,
|
| 170 |
+
interpolation: Optional["F.InterpolationMode"],
|
| 171 |
+
do_rescale: bool,
|
| 172 |
+
rescale_factor: float,
|
| 173 |
+
do_normalize: bool,
|
| 174 |
+
do_convert_annotations: bool,
|
| 175 |
+
image_mean: Optional[Union[float, list[float]]],
|
| 176 |
+
image_std: Optional[Union[float, list[float]]],
|
| 177 |
+
do_pad: bool,
|
| 178 |
+
pad_size: Optional[dict[str, int]],
|
| 179 |
+
format: Optional[Union[str, AnnotationFormat]],
|
| 180 |
+
return_tensors: Optional[Union[str, TensorType]],
|
| 181 |
+
**kwargs,
|
| 182 |
+
) -> BatchFeature:
|
| 183 |
+
"""
|
| 184 |
+
Preprocess an image or a batch of images so that it can be used by the model.
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
if annotations is not None and isinstance(annotations, dict):
|
| 188 |
+
annotations = [annotations]
|
| 189 |
+
|
| 190 |
+
if annotations is not None and len(images) != len(annotations):
|
| 191 |
+
raise ValueError(
|
| 192 |
+
f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
format = AnnotationFormat(format)
|
| 196 |
+
if annotations is not None:
|
| 197 |
+
validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
|
| 198 |
+
|
| 199 |
+
data = {}
|
| 200 |
+
processed_images = []
|
| 201 |
+
processed_annotations = []
|
| 202 |
+
pixel_masks = [] # Initialize pixel_masks here
|
| 203 |
+
for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
|
| 204 |
+
# prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
|
| 205 |
+
if annotations is not None:
|
| 206 |
+
annotation = self.prepare_annotation(
|
| 207 |
+
image,
|
| 208 |
+
annotation,
|
| 209 |
+
format,
|
| 210 |
+
return_segmentation_masks=return_segmentation_masks,
|
| 211 |
+
masks_path=masks_path,
|
| 212 |
+
input_data_format=ChannelDimension.FIRST,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
if do_resize:
|
| 216 |
+
resized_image = self.resize(image, size=size, interpolation=interpolation)
|
| 217 |
+
if annotations is not None:
|
| 218 |
+
annotation = self.resize_annotation(
|
| 219 |
+
annotation,
|
| 220 |
+
orig_size=image.size()[-2:],
|
| 221 |
+
target_size=resized_image.size()[-2:],
|
| 222 |
+
)
|
| 223 |
+
image = resized_image
|
| 224 |
+
# Fused rescale and normalize
|
| 225 |
+
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
|
| 226 |
+
if do_convert_annotations and annotations is not None:
|
| 227 |
+
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
|
| 228 |
+
|
| 229 |
+
processed_images.append(image)
|
| 230 |
+
processed_annotations.append(annotation)
|
| 231 |
+
images = processed_images
|
| 232 |
+
annotations = processed_annotations if annotations is not None else None
|
| 233 |
+
|
| 234 |
+
if do_pad:
|
| 235 |
+
# depends on all resized image shapes so we need another loop
|
| 236 |
+
if pad_size is not None:
|
| 237 |
+
padded_size = (pad_size["height"], pad_size["width"])
|
| 238 |
+
else:
|
| 239 |
+
padded_size = get_max_height_width(images)
|
| 240 |
+
|
| 241 |
+
padded_images = []
|
| 242 |
+
padded_annotations = []
|
| 243 |
+
for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
|
| 244 |
+
# Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
|
| 245 |
+
if padded_size == image.size()[-2:]:
|
| 246 |
+
padded_images.append(image)
|
| 247 |
+
pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
|
| 248 |
+
padded_annotations.append(annotation)
|
| 249 |
+
continue
|
| 250 |
+
image, pixel_mask, annotation = self.pad(
|
| 251 |
+
image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
|
| 252 |
+
)
|
| 253 |
+
padded_images.append(image)
|
| 254 |
+
padded_annotations.append(annotation)
|
| 255 |
+
pixel_masks.append(pixel_mask)
|
| 256 |
+
images = padded_images
|
| 257 |
+
annotations = padded_annotations if annotations is not None else None
|
| 258 |
+
data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})
|
| 259 |
+
|
| 260 |
+
data.update({"pixel_values": torch.stack(images, dim=0)})
|
| 261 |
+
        encoded_inputs = BatchFeature(data, tensor_type=return_tensors)

        if annotations is not None:
            encoded_inputs["labels"] = [
                BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
            ]
        return encoded_inputs

    def post_process_object_detection(
        self,
        outputs,
        threshold: float = 0.5,
        target_sizes: Union[TensorType, list[tuple]] = None,
        use_focal_loss: bool = True,
    ):
        """
        Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format. Only supports PyTorch.

        Args:
            outputs ([`DetrObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.5):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
            use_focal_loss (`bool` defaults to `True`):
                Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
                to compute the scores of each detection, otherwise, a softmax function is used.

        Returns:
            `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        requires_backends(self, ["torch"])
        out_logits, out_bbox = outputs.logits, outputs.pred_boxes
        # convert from relative cxcywh to absolute xyxy
        boxes = center_to_corners_format(out_bbox)
        if target_sizes is not None:
            if len(out_logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )
            if isinstance(target_sizes, list):
                img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
            else:
                img_h, img_w = target_sizes.unbind(1)
            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]

        num_top_queries = out_logits.shape[1]
        num_classes = out_logits.shape[2]

        if use_focal_loss:
            scores = torch.nn.functional.sigmoid(out_logits)
            scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
            labels = index % num_classes
            index = index // num_classes
            boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
        else:
            scores = torch.nn.functional.softmax(out_logits)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            if scores.shape[1] > num_top_queries:
                scores, index = torch.topk(scores, num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))

        results = []
        for score, label, box in zip(scores, labels, boxes):
            results.append(
                {
                    "scores": score[score > threshold],
                    "labels": label[score > threshold],
                    "boxes": box[score > threshold],
                }
            )

        return results

    def from_dict():
        raise NotImplementedError("No need to override this method for RT-DETR yet.")

    def post_process():
        raise NotImplementedError("Post-processing is not implemented for RT-DETR yet.")

    def post_process_segmentation():
        raise NotImplementedError("Segmentation post-processing is not implemented for RT-DETR yet.")

    def post_process_instance():
        raise NotImplementedError("Instance post-processing is not implemented for RT-DETR yet.")

    def post_process_panoptic():
        raise NotImplementedError("Panoptic post-processing is not implemented for RT-DETR yet.")

    def post_process_instance_segmentation():
        raise NotImplementedError("Segmentation post-processing is not implemented for RT-DETR yet.")

    def post_process_semantic_segmentation():
        raise NotImplementedError("Semantic segmentation post-processing is not implemented for RT-DETR yet.")

    def post_process_panoptic_segmentation():
        raise NotImplementedError("Panoptic segmentation post-processing is not implemented for RT-DETR yet.")


__all__ = ["RTDetrImageProcessorFast"]
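The docstring above fully specifies the post-processing contract, so a short end-to-end sketch may help. Everything below is illustrative: the local image path is hypothetical and the checkpoint id is simply the one referenced elsewhere in this diff, not a requirement.

```python
# Minimal usage sketch (assumes PyTorch, Pillow and network access; the
# checkpoint id and image file are illustrative only).
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection

image = Image.open("street.jpg")  # hypothetical local image
processor = AutoImageProcessor.from_pretrained("PekingU/rtdetr_r18vd", use_fast=True)
model = AutoModelForObjectDetection.from_pretrained("PekingU/rtdetr_r18vd")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes maps the normalized cxcywh boxes back to pixel xyxy coordinates
results = processor.post_process_object_detection(
    outputs, threshold=0.5, target_sizes=[image.size[::-1]]
)[0]
print(results["scores"].shape, results["labels"].shape, results["boxes"].shape)
```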
phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__init__.py
ADDED
@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_rt_detr_v2 import *
    from .modeling_rt_detr_v2 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
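As a side note, the `_LazyModule` indirection above means the heavy submodules (and their torch dependency) are only imported on first attribute access. A minimal sketch, assuming an installed transformers build that ships this package:

```python
# Sketch: the lazy module defers loading modeling_rt_detr_v2 until an
# attribute is first accessed; the bare import stays cheap.
import transformers.models.rt_detr_v2 as rt_detr_v2  # only the _LazyModule shim loads here

config_cls = rt_detr_v2.RTDetrV2Config  # first access triggers the real import
print(config_cls().model_type)  # -> "rt_detr_v2"
```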
phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (540 Bytes)

phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/configuration_rt_detr_v2.cpython-39.pyc
ADDED
Binary file (15.2 kB)

phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/modeling_rt_detr_v2.cpython-39.pyc
ADDED
Binary file (64.3 kB)

phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/modular_rt_detr_v2.cpython-39.pyc
ADDED
Binary file (22.9 kB)

phivenv/Lib/site-packages/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py
ADDED
@@ -0,0 +1,387 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_rt_detr_v2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING


logger = logging.get_logger(__name__)


class RTDetrV2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RTDetrV2Model`]. It is used to instantiate a
    RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the RT-DETR architecture.

    e.g. [PekingU/rtdetr_r18vd](https://huggingface.co/PekingU/rtdetr_r18vd)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        initializer_range (`float`, *optional*, defaults to 0.01):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_bias_prior_prob (`float`, *optional*):
            The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
            If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the batch normalization layers.
        backbone_config (`Dict`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
            The configuration of the backbone model.
        backbone (`str`, *optional*):
            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
            Whether to use pretrained weights for the backbone.
        use_timm_backbone (`bool`, *optional*, defaults to `False`):
            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
            library.
        freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
            Whether to freeze the batch normalization layers in the backbone.
        backbone_kwargs (`dict`, *optional*):
            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
        encoder_hidden_dim (`int`, *optional*, defaults to 256):
            Dimension of the layers in hybrid encoder.
        encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
            Multi level features input for encoder.
        feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
            Strides used in each feature map.
        encoder_layers (`int`, *optional*, defaults to 1):
            Total of layers to be used by the encoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 1024):
            Dimension of the "intermediate" (often named feed-forward) layer in decoder.
        encoder_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        dropout (`float`, *optional*, defaults to 0.0):
            The ratio for all dropout layers.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
            Indexes of the projected layers to be used in the encoder.
        positional_encoding_temperature (`int`, *optional*, defaults to 10000):
            The temperature parameter used to create the positional encodings.
        encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        activation_function (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        eval_size (`tuple[int, int]`, *optional*):
            Height and width used to compute the effective height and width of the position embeddings after taking
            into account the stride.
        normalize_before (`bool`, *optional*, defaults to `False`):
            Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
            feed-forward modules.
        hidden_expansion (`float`, *optional*, defaults to 1.0):
            Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
        d_model (`int`, *optional*, defaults to 256):
            Dimension of the layers exclude hybrid encoder.
        num_queries (`int`, *optional*, defaults to 300):
            Number of object queries.
        decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
            Multi level features dimension for decoder
        decoder_ffn_dim (`int`, *optional*, defaults to 1024):
            Dimension of the "intermediate" (often named feed-forward) layer in decoder.
        num_feature_levels (`int`, *optional*, defaults to 3):
            The number of input feature levels.
        decoder_n_points (`int`, *optional*, defaults to 4):
            The number of sampled keys in each feature level for each attention head in the decoder.
        decoder_layers (`int`, *optional*, defaults to 6):
            Number of decoder layers.
        decoder_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
            The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_denoising (`int`, *optional*, defaults to 100):
            The total number of denoising tasks or queries to be used for contrastive denoising.
        label_noise_ratio (`float`, *optional*, defaults to 0.5):
            The fraction of denoising labels to which random noise should be added.
        box_noise_scale (`float`, *optional*, defaults to 1.0):
            Scale or magnitude of noise to be added to the bounding boxes.
        learn_initial_query (`bool`, *optional*, defaults to `False`):
            Indicates whether the initial query embeddings for the decoder should be learned during training
        anchor_image_size (`tuple[int, int]`, *optional*):
            Height and width of the input image used during evaluation to generate the bounding box anchors. If None, automatic generate anchor is applied.
        with_box_refine (`bool`, *optional*, defaults to `True`):
            Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
            based on the predictions from the previous layer.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Whether the architecture has an encoder decoder structure.
        matcher_alpha (`float`, *optional*, defaults to 0.25):
            Parameter alpha used by the Hungarian Matcher.
        matcher_gamma (`float`, *optional*, defaults to 2.0):
            Parameter gamma used by the Hungarian Matcher.
        matcher_class_cost (`float`, *optional*, defaults to 2.0):
            The relative weight of the class loss used by the Hungarian Matcher.
        matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
            The relative weight of the bounding box loss used by the Hungarian Matcher.
        matcher_giou_cost (`float`, *optional*, defaults to 2.0):
            The relative weight of the giou loss of used by the Hungarian Matcher.
        use_focal_loss (`bool`, *optional*, defaults to `True`):
            Parameter informing if focal loss should be used.
        auxiliary_loss (`bool`, *optional*, defaults to `True`):
            Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
        focal_loss_alpha (`float`, *optional*, defaults to 0.75):
            Parameter alpha used to compute the focal loss.
        focal_loss_gamma (`float`, *optional*, defaults to 2.0):
            Parameter gamma used to compute the focal loss.
        weight_loss_vfl (`float`, *optional*, defaults to 1.0):
            Relative weight of the varifocal loss in the object detection loss.
        weight_loss_bbox (`float`, *optional*, defaults to 5.0):
            Relative weight of the L1 bounding box loss in the object detection loss.
        weight_loss_giou (`float`, *optional*, defaults to 2.0):
            Relative weight of the generalized IoU loss in the object detection loss.
        eos_coefficient (`float`, *optional*, defaults to 0.0001):
            Relative classification weight of the 'no-object' class in the object detection loss.
        decoder_n_levels (`int`, *optional*, defaults to 3):
            The number of feature levels used by the decoder.
        decoder_offset_scale (`float`, *optional*, defaults to 0.5):
            Scaling factor applied to the attention offsets in the decoder.
        decoder_method (`str`, *optional*, defaults to `"default"`):
            The method to use for the decoder: `"default"` or `"discrete"`.

    Examples:

    ```python
    >>> from transformers import RTDetrV2Config, RTDetrV2Model

    >>> # Initializing a RT-DETR configuration
    >>> configuration = RTDetrV2Config()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = RTDetrV2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "rt_detr_v2"
    layer_types = ["basic", "bottleneck"]
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "encoder_attention_heads",
    }

    def __init__(
        self,
        initializer_range=0.01,
        initializer_bias_prior_prob=None,
        layer_norm_eps=1e-5,
        batch_norm_eps=1e-5,
        # backbone
        backbone_config=None,
        backbone=None,
        use_pretrained_backbone=False,
        use_timm_backbone=False,
        freeze_backbone_batch_norms=True,
        backbone_kwargs=None,
        # encoder HybridEncoder
        encoder_hidden_dim=256,
        encoder_in_channels=[512, 1024, 2048],
        feat_strides=[8, 16, 32],
        encoder_layers=1,
        encoder_ffn_dim=1024,
        encoder_attention_heads=8,
        dropout=0.0,
        activation_dropout=0.0,
        encode_proj_layers=[2],
        positional_encoding_temperature=10000,
        encoder_activation_function="gelu",
        activation_function="silu",
        eval_size=None,
        normalize_before=False,
        hidden_expansion=1.0,
        # decoder RTDetrV2Transformer
        d_model=256,
        num_queries=300,
        decoder_in_channels=[256, 256, 256],
        decoder_ffn_dim=1024,
        num_feature_levels=3,
        decoder_n_points=4,
        decoder_layers=6,
        decoder_attention_heads=8,
        decoder_activation_function="relu",
        attention_dropout=0.0,
        num_denoising=100,
        label_noise_ratio=0.5,
        box_noise_scale=1.0,
        learn_initial_query=False,
        anchor_image_size=None,
        with_box_refine=True,
        is_encoder_decoder=True,
        # Loss
        matcher_alpha=0.25,
        matcher_gamma=2.0,
        matcher_class_cost=2.0,
        matcher_bbox_cost=5.0,
        matcher_giou_cost=2.0,
        use_focal_loss=True,
        auxiliary_loss=True,
        focal_loss_alpha=0.75,
        focal_loss_gamma=2.0,
        weight_loss_vfl=1.0,
        weight_loss_bbox=5.0,
        weight_loss_giou=2.0,
        eos_coefficient=1e-4,
        decoder_n_levels=3,  # default value
        decoder_offset_scale=0.5,  # default value
        decoder_method="default",
        **kwargs,
    ):
        super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
        self.initializer_range = initializer_range
        self.initializer_bias_prior_prob = initializer_bias_prior_prob
        self.layer_norm_eps = layer_norm_eps
        self.batch_norm_eps = batch_norm_eps
        # backbone
        if backbone_config is None and backbone is None:
            logger.info(
                "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetrV2-ResNet` backbone."
            )
            backbone_model_type = "rt_detr_resnet"
            config_class = CONFIG_MAPPING[backbone_model_type]
            # this will map it to RTDetrResNetConfig
            # note: we can instead create RTDetrV2ResNetConfig but it will be exactly the same as V1
            # and we would need to create RTDetrV2ResNetModel
            backbone_config = config_class(
                num_channels=3,
                embedding_size=64,
                hidden_sizes=[256, 512, 1024, 2048],
                depths=[3, 4, 6, 3],
                layer_type="bottleneck",
                hidden_act="relu",
                downsample_in_first_stage=False,
                downsample_in_bottleneck=False,
                out_features=None,
                out_indices=[2, 3, 4],
            )
        elif isinstance(backbone_config, dict):
            backbone_model_type = backbone_config.pop("model_type")
            config_class = CONFIG_MAPPING[backbone_model_type]
            backbone_config = config_class.from_dict(backbone_config)

        verify_backbone_config_arguments(
            use_timm_backbone=use_timm_backbone,
            use_pretrained_backbone=use_pretrained_backbone,
            backbone=backbone,
            backbone_config=backbone_config,
            backbone_kwargs=backbone_kwargs,
        )

        self.backbone_config = backbone_config
        self.backbone = backbone
        self.use_pretrained_backbone = use_pretrained_backbone
        self.use_timm_backbone = use_timm_backbone
        self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
        self.backbone_kwargs = backbone_kwargs
        # encoder
        self.encoder_hidden_dim = encoder_hidden_dim
        self.encoder_in_channels = encoder_in_channels
        self.feat_strides = feat_strides
        self.encoder_ffn_dim = encoder_ffn_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        self.encode_proj_layers = encode_proj_layers
        self.encoder_layers = encoder_layers
        self.positional_encoding_temperature = positional_encoding_temperature
        self.eval_size = eval_size
        self.normalize_before = normalize_before
        self.encoder_activation_function = encoder_activation_function
        self.activation_function = activation_function
        self.hidden_expansion = hidden_expansion
        self.num_queries = num_queries
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_in_channels = decoder_in_channels
        self.num_feature_levels = num_feature_levels
        self.decoder_n_points = decoder_n_points
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_activation_function = decoder_activation_function
        self.attention_dropout = attention_dropout
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.learn_initial_query = learn_initial_query
        self.anchor_image_size = anchor_image_size
        self.auxiliary_loss = auxiliary_loss
        self.with_box_refine = with_box_refine
        # Loss
        self.matcher_alpha = matcher_alpha
        self.matcher_gamma = matcher_gamma
        self.matcher_class_cost = matcher_class_cost
        self.matcher_bbox_cost = matcher_bbox_cost
        self.matcher_giou_cost = matcher_giou_cost
        self.use_focal_loss = use_focal_loss
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma
        self.weight_loss_vfl = weight_loss_vfl
        self.weight_loss_bbox = weight_loss_bbox
        self.weight_loss_giou = weight_loss_giou
        self.eos_coefficient = eos_coefficient

        if not hasattr(self, "d_model"):
            self.d_model = d_model

        if not hasattr(self, "encoder_attention_heads"):
            self.encoder_attention_heads = encoder_attention_heads
        # add the new attributes with the given values or defaults
        self.decoder_n_levels = decoder_n_levels
        self.decoder_offset_scale = decoder_offset_scale
        self.decoder_method = decoder_method

    @property
    def sub_configs(self):
        return (
            {"backbone_config": type(self.backbone_config)}
            if getattr(self, "backbone_config", None) is not None
            else {}
        )

    @classmethod
    def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
        """Instantiate a [`RTDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model
        configuration.

        Args:
            backbone_config ([`PretrainedConfig`]):
                The backbone configuration.

        Returns:
            [`RTDetrV2Config`]: An instance of a configuration object
        """
        return cls(
            backbone_config=backbone_config,
            **kwargs,
        )


__all__ = ["RTDetrV2Config"]
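A small sketch of overriding the v2-specific decoder options documented above; the values are illustrative only, not tuned recommendations.

```python
# Sketch: the v2-specific decoder knobs can be overridden at construction time.
from transformers import RTDetrV2Config

config = RTDetrV2Config(
    decoder_n_points=8,          # more sampled keys per level per head
    decoder_offset_scale=0.25,   # shrink the learned sampling offsets
    decoder_method="discrete",   # grid-snapped sampling instead of bilinear
)
print(config.decoder_n_points, config.decoder_offset_scale, config.decoder_method)
```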
phivenv/Lib/site-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py
ADDED
@@ -0,0 +1,1998 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_rt_detr_v2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ...activations import ACT2CLS, ACT2FN
from ...image_transforms import center_to_corners_format, corners_to_center_format
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import compile_compatible_method_lru_cache
from ...utils import ModelOutput, auto_docstring, is_torchdynamo_compiling, torch_int
from ...utils.backbone_utils import load_backbone
from .configuration_rt_detr_v2 import RTDetrV2Config


def multi_scale_deformable_attention_v2(
    value: Tensor,
    value_spatial_shapes: Tensor,
    sampling_locations: Tensor,
    attention_weights: Tensor,
    num_points_list: list[int],
    method="default",
) -> Tensor:
    batch_size, _, num_heads, hidden_dim = value.shape
    _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape
    value_list = (
        value.permute(0, 2, 3, 1)
        .flatten(0, 1)
        .split([height * width for height, width in value_spatial_shapes], dim=-1)
    )
    # sampling_offsets [8, 480, 8, 12, 2]
    if method == "default":
        sampling_grids = 2 * sampling_locations - 1
    elif method == "discrete":
        sampling_grids = sampling_locations
    sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
    sampling_grids = sampling_grids.split(num_points_list, dim=-2)
    sampling_value_list = []
    for level_id, (height, width) in enumerate(value_spatial_shapes):
        # batch_size, height*width, num_heads, hidden_dim
        # -> batch_size, height*width, num_heads*hidden_dim
        # -> batch_size, num_heads*hidden_dim, height*width
        # -> batch_size*num_heads, hidden_dim, height, width
        value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width)
        # batch_size, num_queries, num_heads, num_points, 2
        # -> batch_size, num_heads, num_queries, num_points, 2
        # -> batch_size*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[level_id]
        # batch_size*num_heads, hidden_dim, num_queries, num_points
        if method == "default":
            sampling_value_l_ = nn.functional.grid_sample(
                value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
            )
        elif method == "discrete":
            sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to(
                torch.int64
            )

            # Separate clamping for x and y coordinates
            sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1)
            sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1)

            # Combine the clamped coordinates
            sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1)
            sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2)
            sampling_idx = (
                torch.arange(sampling_coord.shape[0], device=value.device)
                .unsqueeze(-1)
                .repeat(1, sampling_coord.shape[1])
            )
            sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]]
            sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape(
                batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id]
            )
        sampling_value_list.append(sampling_value_l_)
    # (batch_size, num_queries, num_heads, num_levels, num_points)
    # -> (batch_size, num_heads, num_queries, num_levels, num_points)
    # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.permute(0, 2, 1, 3).reshape(
        batch_size * num_heads, 1, num_queries, sum(num_points_list)
    )
    output = (
        (torch.concat(sampling_value_list, dim=-1) * attention_weights)
        .sum(-1)
        .view(batch_size, num_heads * hidden_dim, num_queries)
    )
    return output.transpose(1, 2).contiguous()
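A shape-only smoke test for `multi_scale_deformable_attention_v2` as diffed above may clarify its tensor contract; it assumes the function can be imported from `transformers.models.rt_detr_v2.modeling_rt_detr_v2` (or is run inside the module), and the sizes are arbitrary.

```python
# Sketch: exercise the "default" (bilinear grid_sample) path with random tensors.
import torch
from transformers.models.rt_detr_v2.modeling_rt_detr_v2 import multi_scale_deformable_attention_v2

bs, heads, dim, queries, points = 2, 8, 32, 10, 4
shapes = [(32, 32), (16, 16), (8, 8)]                # three feature levels
seq = sum(h * w for h, w in shapes)                  # flattened encoder length

value = torch.randn(bs, seq, heads, dim)
# sampling locations: (bs, queries, heads, levels * points, 2), normalized (x, y)
locs = torch.rand(bs, queries, heads, len(shapes) * points, 2)
weights = torch.softmax(torch.randn(bs, queries, heads, len(shapes) * points), dim=-1)

out = multi_scale_deformable_attention_v2(value, shapes, locs, weights, [points] * len(shapes))
print(out.shape)  # torch.Size([2, 10, 256]) == (bs, queries, heads * dim)
```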
# the main change
class RTDetrV2MultiscaleDeformableAttention(nn.Module):
    """
    RTDetrV2 version of multiscale deformable attention, extending the base implementation
    with improved offset handling and initialization.
    """

    def __init__(self, config: RTDetrV2Config):
        super().__init__()
        num_heads = config.decoder_attention_heads
        n_points = config.decoder_n_points

        if config.d_model % num_heads != 0:
            raise ValueError(
                f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
            )
        dim_per_head = config.d_model // num_heads
        # check if dim_per_head is power of 2
        if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
            warnings.warn(
                "You'd better set embed_dim (d_model) in RTDetrV2MultiscaleDeformableAttention to make the"
                " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
                " implementation."
            )

        self.im2col_step = 64

        self.d_model = config.d_model

        # V2-specific attributes
        self.n_levels = config.decoder_n_levels
        self.n_heads = num_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
        self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
        self.value_proj = nn.Linear(config.d_model, config.d_model)
        self.output_proj = nn.Linear(config.d_model, config.d_model)

        self.offset_scale = config.decoder_offset_scale
        self.method = config.decoder_method

        # Initialize n_points list and scale
        n_points_list = [self.n_points for _ in range(self.n_levels)]
        self.n_points_list = n_points_list
        n_points_scale = [1 / n for n in n_points_list for _ in range(n)]
        self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings: Optional[torch.Tensor] = None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        output_attentions: bool = False,
    ):
        # Process inputs up to sampling locations calculation using parent class logic
        if position_embeddings is not None:
            hidden_states = hidden_states + position_embeddings

        batch_size, num_queries, _ = hidden_states.shape
        batch_size, sequence_length, _ = encoder_hidden_states.shape
        if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
            raise ValueError(
                "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
            )

        value = self.value_proj(encoder_hidden_states)
        if attention_mask is not None:
            value = value.masked_fill(~attention_mask[..., None], float(0))
        value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)

        # V2-specific sampling offsets shape
        sampling_offsets = self.sampling_offsets(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels * self.n_points, 2
        )

        attention_weights = self.attention_weights(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
        )
        attention_weights = F.softmax(attention_weights, -1)

        # V2-specific sampling locations calculation
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
            )
        elif reference_points.shape[-1] == 4:
            n_points_scale = self.n_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1)
            offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale
            sampling_locations = reference_points[:, :, None, :, :2] + offset
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

        # V2-specific attention implementation choice
        output = multi_scale_deformable_attention_v2(
            value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method
        )

        output = self.output_proj(output)
        return output, attention_weights
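For orientation, a sketch that instantiates the module above from a default config and inspects the projection sizes it registers; the import path assumes a transformers build that ships RT-DETRv2, and no forward pass is attempted here.

```python
# Sketch: with the defaults (8 heads, 3 levels, 4 points, d_model=256) the
# offset head predicts 8 * 3 * 4 * 2 = 192 channels per query.
from transformers import RTDetrV2Config
from transformers.models.rt_detr_v2.modeling_rt_detr_v2 import RTDetrV2MultiscaleDeformableAttention

cfg = RTDetrV2Config()
attn = RTDetrV2MultiscaleDeformableAttention(cfg)
print(attn.sampling_offsets.out_features)   # 192
print(attn.attention_weights.out_features)  # 96
print(attn.n_points_scale.shape)            # torch.Size([12]) -> one 1/n_points entry per sampled point
```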
class RTDetrV2MultiheadAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
        return tensor if position_embeddings is None else tensor + position_embeddings

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        batch_size, target_len, embed_dim = hidden_states.size()
        # add position embeddings to the hidden states before projecting to queries and keys
        if position_embeddings is not None:
            hidden_states_original = hidden_states
            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

        # get queries, keys and values
        query_states = self.q_proj(hidden_states) * self.scaling
        key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size)
        value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size)

        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
        query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        source_len = key_states.size(1)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
                f" {attn_weights.size()}"
            )

        # expand attention_mask
        if attention_mask is not None:
            # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size())

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, target_len, source_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
            if attention_mask.dtype == torch.bool:
                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
                    attention_mask, -torch.inf
                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, target_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
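A minimal forward-pass sketch for the self-attention block above; shapes are illustrative. Note that query position embeddings are passed explicitly, since the values are projected from the un-embedded hidden states and the un-embedded copy is only taken when position embeddings are provided.

```python
# Sketch: decoder-query self-attention with positional embeddings on Q/K only.
import torch
from transformers.models.rt_detr_v2.modeling_rt_detr_v2 import RTDetrV2MultiheadAttention

attn = RTDetrV2MultiheadAttention(embed_dim=256, num_heads=8, dropout=0.0)
queries = torch.randn(2, 300, 256)     # (batch, num_queries, d_model)
query_pos = torch.randn(2, 300, 256)   # query position embeddings
out, weights = attn(queries, position_embeddings=query_pos, output_attentions=True)
print(out.shape, weights.shape)        # (2, 300, 256) (2, 8, 300, 300)
```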
| 345 |
+
|
| 346 |
+
class RTDetrV2DecoderLayer(nn.Module):
    def __init__(self, config: RTDetrV2Config):
        super().__init__()
        # self-attention
        self.self_attn = RTDetrV2MultiheadAttention(
            embed_dim=config.d_model,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.decoder_activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        # override only the encoder attention module with v2 version
        self.encoder_attn = RTDetrV2MultiscaleDeformableAttention(config)
        self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        # feedforward neural networks
        self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
        self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[torch.Tensor] = None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(seq_len, batch, embed_dim)`.
            position_embeddings (`torch.FloatTensor`, *optional*):
                Position embeddings that are added to the queries and keys in the self-attention layer.
            reference_points (`torch.FloatTensor`, *optional*):
                Reference points.
            spatial_shapes (`torch.LongTensor`, *optional*):
                Spatial shapes.
            level_start_index (`torch.LongTensor`, *optional*):
                Level start index.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=encoder_attention_mask,
            position_embeddings=position_embeddings,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        second_residual = hidden_states

        # Cross-Attention
        cross_attn_weights = None
        hidden_states, cross_attn_weights = self.encoder_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            position_embeddings=position_embeddings,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = second_residual + hidden_states

        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs

@auto_docstring
class RTDetrV2PreTrainedModel(PreTrainedModel):
    config: RTDetrV2Config
    base_model_prefix = "rt_detr_v2"
    main_input_name = "pixel_values"
    _no_split_modules = [r"RTDetrV2HybridEncoder", r"RTDetrV2DecoderLayer"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (RTDetrV2ForObjectDetection, RTDetrV2Decoder)):
            if module.class_embed is not None:
                for layer in module.class_embed:
                    prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
                    bias = float(-math.log((1 - prior_prob) / prior_prob))
                    nn.init.xavier_uniform_(layer.weight)
                    nn.init.constant_(layer.bias, bias)

            if module.bbox_embed is not None:
                for layer in module.bbox_embed:
                    nn.init.constant_(layer.layers[-1].weight, 0)
                    nn.init.constant_(layer.layers[-1].bias, 0)

        elif isinstance(module, RTDetrV2MultiscaleDeformableAttention):
            nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
            default_dtype = torch.get_default_dtype()
            thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
                2.0 * math.pi / module.n_heads
            )
            grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
            grid_init = (
                (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
                .view(module.n_heads, 1, 1, 2)
                .repeat(1, module.n_levels, module.n_points, 1)
            )
            for i in range(module.n_points):
                grid_init[:, :, i, :] *= i + 1
            with torch.no_grad():
                module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
            nn.init.constant_(module.attention_weights.weight.data, 0.0)
            nn.init.constant_(module.attention_weights.bias.data, 0.0)
            nn.init.xavier_uniform_(module.value_proj.weight.data)
            nn.init.constant_(module.value_proj.bias.data, 0.0)
            nn.init.xavier_uniform_(module.output_proj.weight.data)
            nn.init.constant_(module.output_proj.bias.data, 0.0)

        elif isinstance(module, RTDetrV2Model):
            prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
            bias = float(-math.log((1 - prior_prob) / prior_prob))
            nn.init.xavier_uniform_(module.enc_score_head.weight)
            nn.init.constant_(module.enc_score_head.bias, bias)

        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()

        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

        if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
            nn.init.xavier_uniform_(module.weight_embedding.weight)
        if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
            nn.init.xavier_uniform_(module.denoising_class_embed.weight)

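# The classification bias used in `_init_weights` above follows the focal-loss prior
# initialization: with prior_prob = 1 / (num_labels + 1), bias = -log((1 - prior_prob) / prior_prob),
# so the initial sigmoid scores start near prior_prob. For example (illustrative numbers only),
# num_labels = 80 gives prior_prob ~= 0.0123 and bias ~= -4.38.
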
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the RTDetrV2Decoder. This class adds two attributes to
    BaseModelOutputWithCrossAttentions, namely:
    - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
    - a stacked tensor of intermediate reference points.
    """
)
class RTDetrV2DecoderOutput(ModelOutput):
    r"""
    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
        Stacked intermediate hidden states (output of each layer of the decoder).
    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
        Stacked intermediate logits (logits of each layer of the decoder).
    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
        Stacked intermediate reference points (reference points of each layer of the decoder).
    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked initial reference points (initial reference points of each layer of the decoder).
    cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
        used to compute the weighted average in the cross-attention heads.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    intermediate_hidden_states: Optional[torch.FloatTensor] = None
    intermediate_logits: Optional[torch.FloatTensor] = None
    intermediate_reference_points: Optional[torch.FloatTensor] = None
    intermediate_predicted_corners: Optional[torch.FloatTensor] = None
    initial_reference_points: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[tuple[torch.FloatTensor]] = None

def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

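# The helper above is the clamped logit function: it undoes `torch.sigmoid` for inputs in (0, 1),
# which is how the decoder moves between normalized box coordinates and unbounded refinement space.
# A minimal sketch of that round trip (the `_demo_` name is illustrative, not library API):
def _demo_inverse_sigmoid_roundtrip():
    boxes = torch.tensor([[0.25, 0.5, 0.1, 0.2]])  # normalized (cx, cy, w, h)
    logits = inverse_sigmoid(boxes)
    # add a small predicted offset in logit space, then map back to (0, 1)
    refined = torch.sigmoid(logits + 0.05)
    return refined  # stays a valid normalized box, close to the original
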
class RTDetrV2Decoder(RTDetrV2PreTrainedModel):
    def __init__(self, config: RTDetrV2Config):
        super().__init__(config)

        self.dropout = config.dropout
        self.layers = nn.ModuleList([RTDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)])
        self.query_pos_head = RTDetrV2MLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2)

        # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
        self.bbox_embed = None
        self.class_embed = None

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings=None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        valid_ratios=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
                The query embeddings that are passed into the decoder.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
                in `[0, 1]`:
                - 1 for pixels that are real (i.e. **not masked**),
                - 0 for pixels that are padding (i.e. **masked**).
            position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
                Position embeddings that are added to the queries and keys in each self-attention layer.
            reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
                Reference point in range `[0, 1]`, top-left (0, 0), bottom-right (1, 1), including padding area.
            spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
                Spatial shapes of the feature maps.
            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
                Indexes for the start of each feature level. In range `[0, sequence_length]`.
            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
                Ratio of valid area in each feature level.

            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is not None:
            hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        intermediate = ()
        intermediate_reference_points = ()
        intermediate_logits = ()

        reference_points = F.sigmoid(reference_points)

        # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L252
        for idx, decoder_layer in enumerate(self.layers):
            reference_points_input = reference_points.unsqueeze(2)
            position_embeddings = self.query_pos_head(reference_points)

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                encoder_hidden_states=encoder_hidden_states,
                reference_points=reference_points_input,
                spatial_shapes=spatial_shapes,
                spatial_shapes_list=spatial_shapes_list,
                level_start_index=level_start_index,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            # hack implementation for iterative bounding box refinement
            if self.bbox_embed is not None:
                predicted_corners = self.bbox_embed[idx](hidden_states)
                new_reference_points = F.sigmoid(predicted_corners + inverse_sigmoid(reference_points))
                reference_points = new_reference_points.detach()

            intermediate += (hidden_states,)
            intermediate_reference_points += (
                (new_reference_points,) if self.bbox_embed is not None else (reference_points,)
            )

            if self.class_embed is not None:
                logits = self.class_embed[idx](hidden_states)
                intermediate_logits += (logits,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # Keep batch_size as first dimension
        intermediate = torch.stack(intermediate, dim=1)
        intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
        if self.class_embed is not None:
            intermediate_logits = torch.stack(intermediate_logits, dim=1)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    intermediate,
                    intermediate_logits,
                    intermediate_reference_points,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return RTDetrV2DecoderOutput(
            last_hidden_state=hidden_states,
            intermediate_hidden_states=intermediate,
            intermediate_logits=intermediate_logits,
            intermediate_reference_points=intermediate_reference_points,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )

@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of the RT-DETR encoder-decoder model.
    """
)
class RTDetrV2ModelOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
        Stacked intermediate hidden states (output of each layer of the decoder).
    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
        Stacked intermediate logits (logits of each layer of the decoder).
    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate reference points (reference points of each layer of the decoder).
    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Initial reference points used for the first decoder layer.
    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Initial reference points sent through the Transformer decoder.
    enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
        Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
        picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e.
        foreground and background).
    enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`):
        Logits of predicted bounding boxes coordinates in the encoder stage.
    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
        picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
        foreground and background).
    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Logits of predicted bounding boxes coordinates in the first stage.
    denoising_meta_values (`dict`):
        Extra dictionary for the denoising related values.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    intermediate_hidden_states: Optional[torch.FloatTensor] = None
    intermediate_logits: Optional[torch.FloatTensor] = None
    intermediate_reference_points: Optional[torch.FloatTensor] = None
    intermediate_predicted_corners: Optional[torch.FloatTensor] = None
    initial_reference_points: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    init_reference_points: Optional[torch.FloatTensor] = None
    enc_topk_logits: Optional[torch.FloatTensor] = None
    enc_topk_bboxes: Optional[torch.FloatTensor] = None
    enc_outputs_class: Optional[torch.FloatTensor] = None
    enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
    denoising_meta_values: Optional[dict] = None

class RTDetrV2FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than
    torchvision.models.resnet[18,34,50,101] produce nans.
    """

    def __init__(self, n):
        super().__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        # move reshapes to the beginning
        # to make it user-friendly
        weight = self.weight.reshape(1, -1, 1, 1)
        bias = self.bias.reshape(1, -1, 1, 1)
        running_var = self.running_var.reshape(1, -1, 1, 1)
        running_mean = self.running_mean.reshape(1, -1, 1, 1)
        epsilon = 1e-5
        scale = weight * (running_var + epsilon).rsqrt()
        bias = bias - running_mean * scale
        return x * scale + bias

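# RTDetrV2FrozenBatchNorm2d applies y = (x - running_mean) / sqrt(running_var + eps) * weight + bias
# with all four tensors kept as buffers, so nothing is updated during training. A minimal sketch
# comparing it against an eval-mode nn.BatchNorm2d (illustrative check, not part of the model code):
def _demo_frozen_batch_norm_matches_eval_bn():
    frozen = RTDetrV2FrozenBatchNorm2d(8)
    regular = nn.BatchNorm2d(8).eval()
    x = torch.randn(2, 8, 4, 4)
    # both start from weight=1, bias=0, mean=0, var=1, so the outputs should be nearly identical
    return torch.allclose(frozen(x), regular(x), atol=1e-5)
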
def replace_batch_norm(model):
    r"""
    Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrV2FrozenBatchNorm2d`.

    Args:
        model (torch.nn.Module):
            input model
    """
    for name, module in model.named_children():
        if isinstance(module, nn.BatchNorm2d):
            new_module = RTDetrV2FrozenBatchNorm2d(module.num_features)

            if module.weight.device != torch.device("meta"):
                new_module.weight.data.copy_(module.weight)
                new_module.bias.data.copy_(module.bias)
                new_module.running_mean.data.copy_(module.running_mean)
                new_module.running_var.data.copy_(module.running_var)

            model._modules[name] = new_module

        if len(list(module.children())) > 0:
            replace_batch_norm(module)

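# `replace_batch_norm` mutates the module tree in place, swapping every nn.BatchNorm2d for the
# frozen variant while copying its statistics. A small sketch on a toy module (hypothetical
# layers, not the RT-DETR backbone itself):
def _demo_replace_batch_norm():
    toy = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
    replace_batch_norm(toy)
    return isinstance(toy[1], RTDetrV2FrozenBatchNorm2d)  # True after the swap
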
class RTDetrV2ConvEncoder(nn.Module):
    """
    Convolutional backbone using the modeling_rt_detr_v2_resnet.py.

    nn.BatchNorm2d layers are replaced by RTDetrV2FrozenBatchNorm2d as defined above.
    https://github.com/lyuwenyu/RT-DETR/blob/main/RTDetrV2_pytorch/src/nn/backbone/presnet.py#L142
    """

    def __init__(self, config):
        super().__init__()

        backbone = load_backbone(config)

        if config.freeze_backbone_batch_norms:
            # replace batch norm by frozen batch norm
            with torch.no_grad():
                replace_batch_norm(backbone)
        self.model = backbone
        self.intermediate_channel_sizes = self.model.channels

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        # send pixel_values through the model to get list of feature maps
        features = self.model(pixel_values).feature_maps

        out = []
        for feature_map in features:
            # downsample pixel_mask to match shape of corresponding feature_map
            mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
            out.append((feature_map, mask))
        return out

class RTDetrV2ConvNormLayer(nn.Module):
    def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2 if padding is None else padding,
            bias=False,
        )
        self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
        self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()

    def forward(self, hidden_state):
        hidden_state = self.conv(hidden_state)
        hidden_state = self.norm(hidden_state)
        hidden_state = self.activation(hidden_state)
        return hidden_state

class RTDetrV2EncoderLayer(nn.Module):
    def __init__(self, config: RTDetrV2Config):
        super().__init__()
        self.normalize_before = config.normalize_before

        # self-attention
        self.self_attn = RTDetrV2MultiheadAttention(
            embed_dim=config.encoder_hidden_dim,
            num_heads=config.num_attention_heads,
            dropout=config.dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.encoder_activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim)
        self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        **kwargs,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            position_embeddings (`torch.FloatTensor`, *optional*):
                Object queries (also called content embeddings), to be added to the hidden states.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        if self.normalize_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        if not self.normalize_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        if self.normalize_before:
            hidden_states = self.final_layer_norm(hidden_states)
        residual = hidden_states

        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)

        hidden_states = self.fc2(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        hidden_states = residual + hidden_states
        if not self.normalize_before:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.training:
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

class RTDetrV2RepVggBlock(nn.Module):
    """
    RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
    """

    def __init__(self, config: RTDetrV2Config):
        super().__init__()

        activation = config.activation_function
        hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion)
        self.conv1 = RTDetrV2ConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1)
        self.conv2 = RTDetrV2ConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0)
        self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()

    def forward(self, x):
        y = self.conv1(x) + self.conv2(x)
        return self.activation(y)

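# The RepVGG block sums a 3x3 branch and a 1x1 branch over the same input, so the spatial size and
# channel count are preserved. A minimal shape check, assuming the input already has
# encoder_hidden_dim * hidden_expansion channels (illustrative values, not a real forward path):
def _demo_repvgg_block_shapes(config: RTDetrV2Config):
    block = RTDetrV2RepVggBlock(config)
    channels = int(config.encoder_hidden_dim * config.hidden_expansion)
    x = torch.randn(1, channels, 32, 32)
    return block(x).shape == x.shape  # True: both branches use stride 1 and matching padding
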
class RTDetrV2CSPRepLayer(nn.Module):
    """
    Cross Stage Partial (CSP) network layer with RepVGG blocks.
    """

    def __init__(self, config: RTDetrV2Config):
        super().__init__()

        in_channels = config.encoder_hidden_dim * 2
        out_channels = config.encoder_hidden_dim
        num_blocks = 3
        activation = config.activation_function

        hidden_channels = int(out_channels * config.hidden_expansion)
        self.conv1 = RTDetrV2ConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
        self.conv2 = RTDetrV2ConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
        self.bottlenecks = nn.Sequential(*[RTDetrV2RepVggBlock(config) for _ in range(num_blocks)])
        if hidden_channels != out_channels:
            self.conv3 = RTDetrV2ConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
        else:
            self.conv3 = nn.Identity()

    def forward(self, hidden_state):
        hidden_state_1 = self.conv1(hidden_state)
        hidden_state_1 = self.bottlenecks(hidden_state_1)
        hidden_state_2 = self.conv2(hidden_state)
        return self.conv3(hidden_state_1 + hidden_state_2)

class RTDetrV2Encoder(nn.Module):
    def __init__(self, config: RTDetrV2Config):
        super().__init__()

        self.layers = nn.ModuleList([RTDetrV2EncoderLayer(config) for _ in range(config.encoder_layers)])

    def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor:
        hidden_states = src
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=src_mask,
                position_embeddings=pos_embed,
                output_attentions=output_attentions,
            )
        return hidden_states

class RTDetrV2HybridEncoder(nn.Module):
    """
    Decoder consisting of a projection layer, a set of `RTDetrV2Encoder`, a top-down Feature Pyramid Network
    (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://huggingface.co/papers/2304.08069

    Args:
        config: RTDetrV2Config
    """

    def __init__(self, config: RTDetrV2Config):
        super().__init__()
        self.config = config
        self.in_channels = config.encoder_in_channels
        self.feat_strides = config.feat_strides
        self.encoder_hidden_dim = config.encoder_hidden_dim
        self.encode_proj_layers = config.encode_proj_layers
        self.positional_encoding_temperature = config.positional_encoding_temperature
        self.eval_size = config.eval_size
        self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
        self.out_strides = self.feat_strides
        self.num_fpn_stages = len(self.in_channels) - 1
        self.num_pan_stages = len(self.in_channels) - 1
        activation = config.activation_function

        # encoder transformer
        self.encoder = nn.ModuleList([RTDetrV2Encoder(config) for _ in range(len(self.encode_proj_layers))])

        # top-down FPN
        self.lateral_convs = nn.ModuleList()
        self.fpn_blocks = nn.ModuleList()
        for _ in range(self.num_fpn_stages):
            lateral_conv = RTDetrV2ConvNormLayer(
                config,
                in_channels=self.encoder_hidden_dim,
                out_channels=self.encoder_hidden_dim,
                kernel_size=1,
                stride=1,
                activation=activation,
            )
            fpn_block = RTDetrV2CSPRepLayer(config)
            self.lateral_convs.append(lateral_conv)
            self.fpn_blocks.append(fpn_block)

        # bottom-up PAN
        self.downsample_convs = nn.ModuleList()
        self.pan_blocks = nn.ModuleList()
        for _ in range(self.num_pan_stages):
            downsample_conv = RTDetrV2ConvNormLayer(
                config,
                in_channels=self.encoder_hidden_dim,
                out_channels=self.encoder_hidden_dim,
                kernel_size=3,
                stride=2,
                activation=activation,
            )
            pan_block = RTDetrV2CSPRepLayer(config)
            self.downsample_convs.append(downsample_conv)
            self.pan_blocks.append(pan_block)

    @staticmethod
    def build_2d_sincos_position_embedding(
        width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
    ):
        grid_w = torch.arange(torch_int(width), device=device).to(dtype)
        grid_h = torch.arange(torch_int(height), device=device).to(dtype)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
        if embed_dim % 4 != 0:
            raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
        omega = 1.0 / (temperature**omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]
        out_h = grid_h.flatten()[..., None] @ omega[None]

        return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]

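    # The static method above produces one row of sin/cos features per spatial location, laid out
    # as [sin(w), cos(w), sin(h), cos(h)] chunks of embed_dim // 4 each. A quick shape sketch with
    # illustrative sizes (not tied to any checkpoint):
    #   pos = RTDetrV2HybridEncoder.build_2d_sincos_position_embedding(20, 20, embed_dim=256)
    #   pos.shape == torch.Size([1, 400, 256])
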
    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        position_embeddings=None,
        spatial_shapes=None,
        level_start_index=None,
        valid_ratios=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
                - 1 for pixel features that are real (i.e. **not masked**),
                - 0 for pixel features that are padding (i.e. **masked**).
                [What are attention masks?](../glossary#attention-mask)
            position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Position embeddings that are added to the queries and keys in each self-attention layer.
            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
                Spatial shapes of each feature map.
            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
                Starting index of each feature map.
            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
                Ratio of valid area in each feature level.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = inputs_embeds

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # encoder
        if self.config.encoder_layers > 0:
            for i, enc_ind in enumerate(self.encode_proj_layers):
                if output_hidden_states:
                    encoder_states = encoder_states + (hidden_states[enc_ind],)
                height, width = hidden_states[enc_ind].shape[2:]
                # flatten [batch, channel, height, width] to [batch, height*width, channel]
                src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
                if self.training or self.eval_size is None:
                    pos_embed = self.build_2d_sincos_position_embedding(
                        width,
                        height,
                        self.encoder_hidden_dim,
                        self.positional_encoding_temperature,
                        device=src_flatten.device,
                        dtype=src_flatten.dtype,
                    )
                else:
                    pos_embed = None

                layer_outputs = self.encoder[i](
                    src_flatten,
                    pos_embed=pos_embed,
                    output_attentions=output_attentions,
                )
                hidden_states[enc_ind] = (
                    layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
                )

                if output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1],)

            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states[enc_ind],)

        # top-down FPN
        fpn_feature_maps = [hidden_states[-1]]
        for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
            backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1]
            top_fpn_feature_map = fpn_feature_maps[-1]
            # apply lateral block
            top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
            fpn_feature_maps[-1] = top_fpn_feature_map
            # apply fpn block
            top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
            fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
            new_fpn_feature_map = fpn_block(fused_feature_map)
            fpn_feature_maps.append(new_fpn_feature_map)

        fpn_feature_maps = fpn_feature_maps[::-1]

        # bottom-up PAN
        pan_feature_maps = [fpn_feature_maps[0]]
        for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
            top_pan_feature_map = pan_feature_maps[-1]
            fpn_feature_map = fpn_feature_maps[idx + 1]
            downsampled_feature_map = downsample_conv(top_pan_feature_map)
            fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
            new_pan_feature_map = pan_block(fused_feature_map)
            pan_feature_maps.append(new_pan_feature_map)

        if not return_dict:
            return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions
        )

def get_contrastive_denoising_training_group(
    targets,
    num_classes,
    num_queries,
    class_embed,
    num_denoising_queries=100,
    label_noise_ratio=0.5,
    box_noise_scale=1.0,
):
    """
    Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.

    Args:
        targets (`list[dict]`):
            The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
        num_classes (`int`):
            Total number of classes in the dataset.
        num_queries (`int`):
            Number of query slots in the transformer.
        class_embed (`callable`):
            A function or a model layer to embed class labels.
        num_denoising_queries (`int`, *optional*, defaults to 100):
            Number of denoising queries.
        label_noise_ratio (`float`, *optional*, defaults to 0.5):
            Ratio of noise applied to labels.
        box_noise_scale (`float`, *optional*, defaults to 1.0):
            Scale of noise applied to bounding boxes.

    Returns:
        `tuple` comprising various elements:
        - **input_query_class** (`torch.FloatTensor`) --
          Class queries with applied label noise.
        - **input_query_bbox** (`torch.FloatTensor`) --
          Bounding box queries with applied box noise.
        - **attn_mask** (`torch.FloatTensor`) --
          Attention mask for separating denoising and reconstruction queries.
        - **denoising_meta_values** (`dict`) --
          Metadata including denoising positive indices, number of groups, and split sizes.
    """

    if num_denoising_queries <= 0:
        return None, None, None, None

    num_ground_truths = [len(t["class_labels"]) for t in targets]
    device = targets[0]["class_labels"].device

    max_gt_num = max(num_ground_truths)
    if max_gt_num == 0:
        return None, None, None, None

    num_groups_denoising_queries = num_denoising_queries // max_gt_num
    num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
    # pad gt to max_num of a batch
    batch_size = len(num_ground_truths)

    input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
    input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
    pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)

    for i in range(batch_size):
        num_gt = num_ground_truths[i]
        if num_gt > 0:
            input_query_class[i, :num_gt] = targets[i]["class_labels"]
            input_query_bbox[i, :num_gt] = targets[i]["boxes"]
            pad_gt_mask[i, :num_gt] = 1
    # each group has positive and negative queries.
    input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
    input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
    # positive and negative mask
    negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
    negative_gt_mask[:, max_gt_num:] = 1
    negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
    positive_gt_mask = 1 - negative_gt_mask
    # contrastive denoising training positive index
    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
    denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
    denoise_positive_idx = torch.split(
        denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
    )
    # total denoising queries
    num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)

    if label_noise_ratio > 0:
        mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
        # randomly put a new one here
        new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
        input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)

    if box_noise_scale > 0:
        known_bbox = center_to_corners_format(input_query_bbox)
        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
        rand_part = torch.rand_like(input_query_bbox)
        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        input_query_bbox = corners_to_center_format(known_bbox)
        input_query_bbox = inverse_sigmoid(input_query_bbox)

    input_query_class = class_embed(input_query_class)

    target_size = num_denoising_queries + num_queries
    attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
    # match query cannot see the reconstruction
    attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf

    # reconstructions cannot see each other
    for i in range(num_groups_denoising_queries):
        idx_block_start = max_gt_num * 2 * i
        idx_block_end = max_gt_num * 2 * (i + 1)
        attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
        attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf

    denoising_meta_values = {
        "dn_positive_idx": denoise_positive_idx,
        "dn_num_group": num_groups_denoising_queries,
        "dn_num_split": [num_denoising_queries, num_queries],
    }

    return input_query_class, input_query_bbox, attn_mask, denoising_meta_values

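# A minimal sketch of how the denoising group builder might be driven with dummy targets; the
# tensors and the tiny embedding below are illustrative stand-ins, not values used by the model:
def _demo_denoising_group():
    targets = [
        {"class_labels": torch.tensor([1, 3]), "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2], [0.3, 0.4, 0.1, 0.1]])}
    ]
    class_embed = nn.Embedding(81, 256)  # num_classes + 1 rows, hypothetical hidden size
    queries, boxes, attn_mask, meta = get_contrastive_denoising_training_group(
        targets, num_classes=80, num_queries=300, class_embed=class_embed
    )
    # 100 // 2 = 50 groups, each holding a positive and a negative copy of the 2 ground truths
    return queries.shape, boxes.shape, attn_mask.shape, meta["dn_num_group"]
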
@auto_docstring(
|
| 1363 |
+
custom_intro="""
|
| 1364 |
+
RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
|
| 1365 |
+
"""
|
| 1366 |
+
)
|
| 1367 |
+
class RTDetrV2Model(RTDetrV2PreTrainedModel):
|
| 1368 |
+
def __init__(self, config: RTDetrV2Config):
|
| 1369 |
+
super().__init__(config)
|
| 1370 |
+
|
| 1371 |
+
# Create backbone
|
| 1372 |
+
self.backbone = RTDetrV2ConvEncoder(config)
|
| 1373 |
+
intermediate_channel_sizes = self.backbone.intermediate_channel_sizes
|
| 1374 |
+
|
| 1375 |
+
# Create encoder input projection layers
|
| 1376 |
+
# https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/hybrid_encoder.py#L212
|
| 1377 |
+
num_backbone_outs = len(intermediate_channel_sizes)
|
| 1378 |
+
encoder_input_proj_list = []
|
| 1379 |
+
for _ in range(num_backbone_outs):
|
| 1380 |
+
in_channels = intermediate_channel_sizes[_]
|
| 1381 |
+
encoder_input_proj_list.append(
|
| 1382 |
+
nn.Sequential(
|
| 1383 |
+
nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
|
| 1384 |
+
nn.BatchNorm2d(config.encoder_hidden_dim),
|
| 1385 |
+
)
|
| 1386 |
+
)
|
| 1387 |
+
self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list)
|
| 1388 |
+
|
| 1389 |
        # Create encoder
        self.encoder = RTDetrV2HybridEncoder(config)

        # denoising part
        if config.num_denoising > 0:
            self.denoising_class_embed = nn.Embedding(
                config.num_labels + 1, config.d_model, padding_idx=config.num_labels
            )

        # decoder embedding
        if config.learn_initial_query:
            self.weight_embedding = nn.Embedding(config.num_queries, config.d_model)

        # encoder head
        self.enc_output = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
        )
        self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
        self.enc_bbox_head = RTDetrV2MLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3)

        # init encoder output anchors and valid_mask
        if config.anchor_image_size:
            self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)

        # Create decoder input projection layers
        # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L412
        num_backbone_outs = len(config.decoder_in_channels)
        decoder_input_proj_list = []
        for _ in range(num_backbone_outs):
            in_channels = config.decoder_in_channels[_]
            decoder_input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
                    nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
                )
            )
        for _ in range(config.num_feature_levels - num_backbone_outs):
            decoder_input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
                )
            )
            in_channels = config.d_model
        self.decoder_input_proj = nn.ModuleList(decoder_input_proj_list)
        # decoder
        self.decoder = RTDetrV2Decoder(config)

        self.post_init()

    def get_encoder(self):
        return self.encoder

    def freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad_(False)

    def unfreeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad_(True)

    @compile_compatible_method_lru_cache(maxsize=32)
    def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
        if spatial_shapes is None:
            spatial_shapes = [
                [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
                for s in self.config.feat_strides
            ]
        anchors = []
        for level, (height, width) in enumerate(spatial_shapes):
            grid_y, grid_x = torch.meshgrid(
                torch.arange(end=height, device=device).to(dtype),
                torch.arange(end=width, device=device).to(dtype),
                indexing="ij",
            )
            grid_xy = torch.stack([grid_x, grid_y], -1)
            grid_xy = grid_xy.unsqueeze(0) + 0.5
            grid_xy[..., 0] /= width
            grid_xy[..., 1] /= height
            wh = torch.ones_like(grid_xy) * grid_size * (2.0**level)
            anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
        # define the valid range for anchor coordinates
        eps = 1e-2
        anchors = torch.concat(anchors, 1)
        valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
        anchors = torch.log(anchors / (1 - anchors))
        anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device))

        return anchors, valid_mask

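    # A minimal sketch of the anchor construction above, for a single hypothetical 2x2 feature
    # level (assumes only torch; `grid_size` plays the role of the base anchor width/height):
    #
    #     ys, xs = torch.meshgrid(torch.arange(2), torch.arange(2), indexing="ij")
    #     centers = (torch.stack([xs, ys], -1).float() + 0.5) / 2.0           # normalized (cx, cy)
    #     anchors = torch.cat([centers, torch.full_like(centers, 0.05)], -1)  # (cx, cy, w, h) in [0, 1]
    #     anchor_logits = torch.log(anchors / (1 - anchors))                  # inverse sigmoid, as above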
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[list[dict]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.FloatTensor], RTDetrV2ModelOutput]:
        r"""
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation.
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, RTDetrV2Model
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/RTDetrV2_r50vd")
        >>> model = RTDetrV2Model.from_pretrained("PekingU/RTDetrV2_r50vd")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 300, 256]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, num_channels, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones(((batch_size, height, width)), device=device)

        features = self.backbone(pixel_values, pixel_mask)

        proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                proj_feats,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if output_hidden_states else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2
                else encoder_outputs[1]
                if output_attentions
                else None,
            )

        # Equivalent to def _get_encoder_input
        # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L412
        sources = []
        for level, source in enumerate(encoder_outputs[0]):
            sources.append(self.decoder_input_proj[level](source))

        # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
        if self.config.num_feature_levels > len(sources):
            _len_sources = len(sources)
            sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1])
            for i in range(_len_sources + 1, self.config.num_feature_levels):
                sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1]))

        # Prepare encoder inputs (by flattening)
        source_flatten = []
        spatial_shapes_list = []
        spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long)
        for level, source in enumerate(sources):
            height, width = source.shape[-2:]
            spatial_shapes[level, 0] = height
            spatial_shapes[level, 1] = width
            spatial_shapes_list.append((height, width))
            source = source.flatten(2).transpose(1, 2)
            source_flatten.append(source)
        source_flatten = torch.cat(source_flatten, 1)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

        # prepare denoising training
        if self.training and self.config.num_denoising > 0 and labels is not None:
            (
                denoising_class,
                denoising_bbox_unact,
                attention_mask,
                denoising_meta_values,
            ) = get_contrastive_denoising_training_group(
                targets=labels,
                num_classes=self.config.num_labels,
                num_queries=self.config.num_queries,
                class_embed=self.denoising_class_embed,
                num_denoising_queries=self.config.num_denoising,
                label_noise_ratio=self.config.label_noise_ratio,
                box_noise_scale=self.config.box_noise_scale,
            )
        else:
            denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None

        batch_size = len(source_flatten)
        device = source_flatten.device
        dtype = source_flatten.dtype

        # prepare input for decoder
        if self.training or self.config.anchor_image_size is None:
            # Pass spatial_shapes as tuple to make it hashable and make sure
            # lru_cache is working for generate_anchors()
            spatial_shapes_tuple = tuple(spatial_shapes_list)
            anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
        else:
            anchors, valid_mask = self.anchors, self.valid_mask
        anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)

        # use the valid_mask to selectively retain values in the feature map where the mask is `True`
        memory = valid_mask.to(source_flatten.dtype) * source_flatten

        output_memory = self.enc_output(memory)

        enc_outputs_class = self.enc_score_head(output_memory)
        enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors

        _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1)

        reference_points_unact = enc_outputs_coord_logits.gather(
            dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1])
        )

        enc_topk_bboxes = F.sigmoid(reference_points_unact)
        if denoising_bbox_unact is not None:
            reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)

        enc_topk_logits = enc_outputs_class.gather(
            dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
        )

        # extract region features
        if self.config.learn_initial_query:
            target = self.weight_embedding.tile([batch_size, 1, 1])
        else:
            target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
            target = target.detach()

        if denoising_class is not None:
            target = torch.concat([denoising_class, target], 1)

        init_reference_points = reference_points_unact.detach()

        # decoder
        decoder_outputs = self.decoder(
            inputs_embeds=target,
            encoder_hidden_states=source_flatten,
            encoder_attention_mask=attention_mask,
            reference_points=init_reference_points,
            spatial_shapes=spatial_shapes,
            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            enc_outputs = tuple(
                value
                for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits]
                if value is not None
            )
            dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values])
            tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs

            return tuple_outputs

        return RTDetrV2ModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
            intermediate_logits=decoder_outputs.intermediate_logits,
            intermediate_reference_points=decoder_outputs.intermediate_reference_points,
            intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners,
            initial_reference_points=decoder_outputs.initial_reference_points,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            init_reference_points=init_reference_points,
            enc_topk_logits=enc_topk_logits,
            enc_topk_bboxes=enc_topk_bboxes,
            enc_outputs_class=enc_outputs_class,
            enc_outputs_coord_logits=enc_outputs_coord_logits,
            denoising_meta_values=denoising_meta_values,
        )


# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
class RTDetrV2MLPPredictionHead(nn.Module):
    """
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
    Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_paddle/ppdet/modeling/transformers/utils.py#L453
    """

    def __init__(self, config, input_dim, d_model, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [d_model] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

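# A minimal sketch of how the prediction head above is typically sized (hypothetical values,
# mirroring the `num_layers=3` usage earlier in this file): a d_model-wide MLP whose last layer
# emits 4 numbers per query, interpreted as (center_x, center_y, width, height) logits.
#
#     head = RTDetrV2MLPPredictionHead(config, input_dim=256, d_model=256, output_dim=4, num_layers=3)
#     boxes = head(torch.randn(1, 300, 256))  # -> shape (1, 300, 4)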
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`RTDetrV2ForObjectDetection`].
    """
)
class RTDetrV2ObjectDetectionOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~RTDetrV2ImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized (absolute) bounding boxes.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
        Stacked intermediate hidden states (output of each layer of the decoder).
    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`):
        Stacked intermediate logits (logits of each layer of the decoder).
    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate reference points (reference points of each layer of the decoder).
    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
        Stacked initial reference points (initial reference points of each layer of the decoder).
    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Initial reference points sent through the Transformer decoder.
    enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Classification logits of the top-k proposals selected in the encoder.
    enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Normalized (sigmoid) bounding box coordinates of the top-k proposals selected in the encoder.
    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Predicted bounding box scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
        picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
        foreground and background).
    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
        Logits of predicted bounding box coordinates in the first stage.
    denoising_meta_values (`dict`):
        Extra dictionary holding the denoising-related values.
    """

    loss: Optional[torch.FloatTensor] = None
    loss_dict: Optional[dict] = None
    logits: Optional[torch.FloatTensor] = None
    pred_boxes: Optional[torch.FloatTensor] = None
    auxiliary_outputs: Optional[list[dict]] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    intermediate_hidden_states: Optional[torch.FloatTensor] = None
    intermediate_logits: Optional[torch.FloatTensor] = None
    intermediate_reference_points: Optional[torch.FloatTensor] = None
    intermediate_predicted_corners: Optional[torch.FloatTensor] = None
    initial_reference_points: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    init_reference_points: Optional[tuple[torch.FloatTensor]] = None
    enc_topk_logits: Optional[torch.FloatTensor] = None
    enc_topk_bboxes: Optional[torch.FloatTensor] = None
    enc_outputs_class: Optional[torch.FloatTensor] = None
    enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
    denoising_meta_values: Optional[dict] = None


@auto_docstring(
    custom_intro="""
    RT-DETR Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further
    decoded into scores and classes.
    """
)
class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel):
    # When using clones, all layers > 0 will be clones, but layer 0 *is* required
    _tied_weights_keys = ["bbox_embed", "class_embed"]
    # We can't initialize the model on meta device as some weights are modified during the initialization
    _no_split_modules = None

    def __init__(self, config: RTDetrV2Config):
        super().__init__(config)
        # RTDETR encoder-decoder model
        self.model = RTDetrV2Model(config)

        # Detection heads on top
        class_embed = partial(nn.Linear, config.d_model, config.num_labels)
        bbox_embed = partial(RTDetrV2MLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3)

        self.class_embed = nn.ModuleList([class_embed() for _ in range(config.decoder_layers)])
        self.bbox_embed = nn.ModuleList([bbox_embed() for _ in range(config.decoder_layers)])

        self.model.decoder.class_embed = self.class_embed
        self.model.decoder.bbox_embed = self.bbox_embed

        # Initialize weights and apply final processing
        self.post_init()

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[list[dict]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.FloatTensor], RTDetrV2ObjectDetectionOutput]:
        r"""
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation.
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Examples:

        ```python
        >>> from transformers import RTDetrV2ImageProcessor, RTDetrV2ForObjectDetection
        >>> from PIL import Image
        >>> import requests
        >>> import torch

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = RTDetrV2ImageProcessor.from_pretrained("PekingU/RTDetrV2_r50vd")
        >>> model = RTDetrV2ForObjectDetection.from_pretrained("PekingU/RTDetrV2_r50vd")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 300, 80]

        >>> boxes = outputs.pred_boxes
        >>> list(boxes.shape)
        [1, 300, 4]

        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
        ...     0
        ... ]

        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected sofa with confidence 0.97 at location [0.14, 0.38, 640.13, 476.21]
        Detected cat with confidence 0.96 at location [343.38, 24.28, 640.14, 371.5]
        Detected cat with confidence 0.958 at location [13.23, 54.18, 318.98, 472.22]
        Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48]
        Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            pixel_values,
            pixel_mask=pixel_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        denoising_meta_values = (
            outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None
        )

        outputs_class = outputs.intermediate_logits if return_dict else outputs[2]
        outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3]
        predicted_corners = outputs.intermediate_predicted_corners if return_dict else outputs[4]
        initial_reference_points = outputs.initial_reference_points if return_dict else outputs[5]

        logits = outputs_class[:, -1]
        pred_boxes = outputs_coord[:, -1]

        loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
        if labels is not None:
            enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
            enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits,
                labels,
                self.device,
                pred_boxes,
                self.config,
                outputs_class,
                outputs_coord,
                enc_topk_logits=enc_topk_logits,
                enc_topk_bboxes=enc_topk_bboxes,
                denoising_meta_values=denoising_meta_values,
                predicted_corners=predicted_corners,
                initial_reference_points=initial_reference_points,
                **kwargs,
            )

        if not return_dict:
            if auxiliary_outputs is not None:
                output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs
            else:
                output = (logits, pred_boxes) + outputs
            return ((loss, loss_dict) + output) if loss is not None else output

        return RTDetrV2ObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=outputs.last_hidden_state,
            intermediate_hidden_states=outputs.intermediate_hidden_states,
            intermediate_logits=outputs.intermediate_logits,
            intermediate_reference_points=outputs.intermediate_reference_points,
            intermediate_predicted_corners=outputs.intermediate_predicted_corners,
            initial_reference_points=outputs.initial_reference_points,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
            init_reference_points=outputs.init_reference_points,
            enc_topk_logits=outputs.enc_topk_logits,
            enc_topk_bboxes=outputs.enc_topk_bboxes,
            enc_outputs_class=outputs.enc_outputs_class,
            enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
            denoising_meta_values=outputs.denoising_meta_values,
        )


__all__ = ["RTDetrV2Model", "RTDetrV2PreTrainedModel", "RTDetrV2ForObjectDetection"]
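# Hypothetical sketch of the `labels` format expected by the forward passes above: one dict per
# image, with integer class ids and boxes as normalized (center_x, center_y, width, height),
# following the DETR convention.
#
#     labels = [
#         {
#             "class_labels": torch.tensor([1, 3]),
#             "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.3], [0.25, 0.4, 0.1, 0.1]]),
#         }
#     ]
#     outputs = model(**inputs, labels=labels)  # populates `loss` and `loss_dict`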
phivenv/Lib/site-packages/transformers/models/rt_detr_v2/modular_rt_detr_v2.py
ADDED
@@ -0,0 +1,636 @@
# coding=utf-8
# Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from functools import partial
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ...configuration_utils import PretrainedConfig
from ...utils import is_torchdynamo_compiling, logging
from ...utils.backbone_utils import (
    verify_backbone_config_arguments,
)
from ..auto import CONFIG_MAPPING
from ..rt_detr.modeling_rt_detr import (
    RTDetrDecoder,
    RTDetrDecoderLayer,
    RTDetrForObjectDetection,
    RTDetrMLPPredictionHead,
    RTDetrModel,
    RTDetrPreTrainedModel,
)


logger = logging.get_logger(__name__)


class RTDetrV2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RTDetrV2Model`]. It is used to instantiate a
    RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the RT-DETR architecture.

    e.g. [PekingU/rtdetr_r18vd](https://huggingface.co/PekingU/rtdetr_r18vd)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        initializer_range (`float`, *optional*, defaults to 0.01):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_bias_prior_prob (`float`, *optional*):
            The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
            If `None`, `prior_prob` is computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the batch normalization layers.
        backbone_config (`Dict`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
            The configuration of the backbone model.
        backbone (`str`, *optional*):
            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
            Whether to use pretrained weights for the backbone.
        use_timm_backbone (`bool`, *optional*, defaults to `False`):
            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
            library.
        freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
            Whether to freeze the batch normalization layers in the backbone.
        backbone_kwargs (`dict`, *optional*):
            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
        encoder_hidden_dim (`int`, *optional*, defaults to 256):
            Dimension of the layers in the hybrid encoder.
        encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
            Multi-level feature inputs for the encoder.
        feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
            Strides used in each feature map.
        encoder_layers (`int`, *optional*, defaults to 1):
            Total number of layers used by the encoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 1024):
            Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
        encoder_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        dropout (`float`, *optional*, defaults to 0.0):
            The ratio for all dropout layers.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
            Indexes of the projected layers to be used in the encoder.
        positional_encoding_temperature (`int`, *optional*, defaults to 10000):
            The temperature parameter used to create the positional encodings.
        encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        activation_function (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        eval_size (`tuple[int, int]`, *optional*):
            Height and width used to compute the effective height and width of the position embeddings after taking
            into account the stride.
        normalize_before (`bool`, *optional*, defaults to `False`):
            Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
            feed-forward modules.
        hidden_expansion (`float`, *optional*, defaults to 1.0):
            Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
        d_model (`int`, *optional*, defaults to 256):
            Dimension of the layers, excluding the hybrid encoder.
        num_queries (`int`, *optional*, defaults to 300):
            Number of object queries.
        decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
            Multi-level feature dimensions for the decoder.
        decoder_ffn_dim (`int`, *optional*, defaults to 1024):
            Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
        num_feature_levels (`int`, *optional*, defaults to 3):
            The number of input feature levels.
        decoder_n_points (`int`, *optional*, defaults to 4):
            The number of sampled keys in each feature level for each attention head in the decoder.
        decoder_layers (`int`, *optional*, defaults to 6):
            Number of decoder layers.
        decoder_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
            The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_denoising (`int`, *optional*, defaults to 100):
            The total number of denoising tasks or queries to be used for contrastive denoising.
        label_noise_ratio (`float`, *optional*, defaults to 0.5):
            The fraction of denoising labels to which random noise should be added.
        box_noise_scale (`float`, *optional*, defaults to 1.0):
            Scale or magnitude of noise to be added to the bounding boxes.
        learn_initial_query (`bool`, *optional*, defaults to `False`):
            Indicates whether the initial query embeddings for the decoder should be learned during training.
        anchor_image_size (`tuple[int, int]`, *optional*):
            Height and width of the input image used during evaluation to generate the bounding box anchors. If
            `None`, anchors are generated automatically from the input feature maps.
        with_box_refine (`bool`, *optional*, defaults to `True`):
            Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
            based on the predictions from the previous layer.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Whether the architecture has an encoder decoder structure.
        matcher_alpha (`float`, *optional*, defaults to 0.25):
            Parameter alpha used by the Hungarian Matcher.
        matcher_gamma (`float`, *optional*, defaults to 2.0):
            Parameter gamma used by the Hungarian Matcher.
        matcher_class_cost (`float`, *optional*, defaults to 2.0):
            The relative weight of the class loss used by the Hungarian Matcher.
        matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
            The relative weight of the bounding box loss used by the Hungarian Matcher.
        matcher_giou_cost (`float`, *optional*, defaults to 2.0):
            The relative weight of the giou loss used by the Hungarian Matcher.
        use_focal_loss (`bool`, *optional*, defaults to `True`):
            Whether the focal loss should be used.
        auxiliary_loss (`bool`, *optional*, defaults to `True`):
            Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
        focal_loss_alpha (`float`, *optional*, defaults to 0.75):
            Parameter alpha used to compute the focal loss.
        focal_loss_gamma (`float`, *optional*, defaults to 2.0):
            Parameter gamma used to compute the focal loss.
        weight_loss_vfl (`float`, *optional*, defaults to 1.0):
            Relative weight of the varifocal loss in the object detection loss.
        weight_loss_bbox (`float`, *optional*, defaults to 5.0):
            Relative weight of the L1 bounding box loss in the object detection loss.
        weight_loss_giou (`float`, *optional*, defaults to 2.0):
            Relative weight of the generalized IoU loss in the object detection loss.
        eos_coefficient (`float`, *optional*, defaults to 0.0001):
            Relative classification weight of the 'no-object' class in the object detection loss.
        decoder_n_levels (`int`, *optional*, defaults to 3):
            The number of feature levels used by the decoder.
        decoder_offset_scale (`float`, *optional*, defaults to 0.5):
            Scaling factor applied to the attention offsets in the decoder.
        decoder_method (`str`, *optional*, defaults to `"default"`):
            The method to use for the decoder: `"default"` or `"discrete"`.

    Examples:

    ```python
    >>> from transformers import RTDetrV2Config, RTDetrV2Model

    >>> # Initializing a RT-DETR configuration
    >>> configuration = RTDetrV2Config()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = RTDetrV2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "rt_detr_v2"
    layer_types = ["basic", "bottleneck"]
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "encoder_attention_heads",
    }

    def __init__(
        self,
        initializer_range=0.01,
        initializer_bias_prior_prob=None,
        layer_norm_eps=1e-5,
        batch_norm_eps=1e-5,
        # backbone
        backbone_config=None,
        backbone=None,
        use_pretrained_backbone=False,
        use_timm_backbone=False,
        freeze_backbone_batch_norms=True,
        backbone_kwargs=None,
        # encoder HybridEncoder
        encoder_hidden_dim=256,
        encoder_in_channels=[512, 1024, 2048],
        feat_strides=[8, 16, 32],
        encoder_layers=1,
        encoder_ffn_dim=1024,
        encoder_attention_heads=8,
        dropout=0.0,
        activation_dropout=0.0,
        encode_proj_layers=[2],
        positional_encoding_temperature=10000,
        encoder_activation_function="gelu",
        activation_function="silu",
        eval_size=None,
        normalize_before=False,
        hidden_expansion=1.0,
        # decoder RTDetrV2Transformer
        d_model=256,
        num_queries=300,
        decoder_in_channels=[256, 256, 256],
        decoder_ffn_dim=1024,
        num_feature_levels=3,
        decoder_n_points=4,
        decoder_layers=6,
        decoder_attention_heads=8,
        decoder_activation_function="relu",
        attention_dropout=0.0,
        num_denoising=100,
        label_noise_ratio=0.5,
        box_noise_scale=1.0,
        learn_initial_query=False,
        anchor_image_size=None,
        with_box_refine=True,
        is_encoder_decoder=True,
        # Loss
        matcher_alpha=0.25,
        matcher_gamma=2.0,
        matcher_class_cost=2.0,
        matcher_bbox_cost=5.0,
        matcher_giou_cost=2.0,
        use_focal_loss=True,
        auxiliary_loss=True,
        focal_loss_alpha=0.75,
        focal_loss_gamma=2.0,
        weight_loss_vfl=1.0,
        weight_loss_bbox=5.0,
        weight_loss_giou=2.0,
        eos_coefficient=1e-4,
        decoder_n_levels=3,  # default value
        decoder_offset_scale=0.5,  # default value
        decoder_method="default",
        **kwargs,
    ):
        super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
        self.initializer_range = initializer_range
        self.initializer_bias_prior_prob = initializer_bias_prior_prob
        self.layer_norm_eps = layer_norm_eps
        self.batch_norm_eps = batch_norm_eps
        # backbone
        if backbone_config is None and backbone is None:
            logger.info(
                "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetrV2-ResNet` backbone."
            )
            backbone_model_type = "rt_detr_resnet"
            config_class = CONFIG_MAPPING[backbone_model_type]
            # this will map it to RTDetrResNetConfig
            # note: we can instead create RTDetrV2ResNetConfig but it will be exactly the same as V1
            # and we would need to create RTDetrV2ResNetModel
            backbone_config = config_class(
                num_channels=3,
                embedding_size=64,
                hidden_sizes=[256, 512, 1024, 2048],
                depths=[3, 4, 6, 3],
                layer_type="bottleneck",
                hidden_act="relu",
                downsample_in_first_stage=False,
                downsample_in_bottleneck=False,
                out_features=None,
                out_indices=[2, 3, 4],
            )
        elif isinstance(backbone_config, dict):
            backbone_model_type = backbone_config.pop("model_type")
            config_class = CONFIG_MAPPING[backbone_model_type]
            backbone_config = config_class.from_dict(backbone_config)

        verify_backbone_config_arguments(
            use_timm_backbone=use_timm_backbone,
            use_pretrained_backbone=use_pretrained_backbone,
            backbone=backbone,
            backbone_config=backbone_config,
            backbone_kwargs=backbone_kwargs,
        )

        self.backbone_config = backbone_config
        self.backbone = backbone
        self.use_pretrained_backbone = use_pretrained_backbone
        self.use_timm_backbone = use_timm_backbone
        self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
        self.backbone_kwargs = backbone_kwargs
        # encoder
        self.encoder_hidden_dim = encoder_hidden_dim
        self.encoder_in_channels = encoder_in_channels
        self.feat_strides = feat_strides
        self.encoder_ffn_dim = encoder_ffn_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        self.encode_proj_layers = encode_proj_layers
        self.encoder_layers = encoder_layers
        self.positional_encoding_temperature = positional_encoding_temperature
        self.eval_size = eval_size
        self.normalize_before = normalize_before
        self.encoder_activation_function = encoder_activation_function
        self.activation_function = activation_function
        self.hidden_expansion = hidden_expansion
        self.num_queries = num_queries
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_in_channels = decoder_in_channels
        self.num_feature_levels = num_feature_levels
        self.decoder_n_points = decoder_n_points
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_activation_function = decoder_activation_function
        self.attention_dropout = attention_dropout
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.learn_initial_query = learn_initial_query
        self.anchor_image_size = anchor_image_size
        self.auxiliary_loss = auxiliary_loss
        self.with_box_refine = with_box_refine
        # Loss
        self.matcher_alpha = matcher_alpha
        self.matcher_gamma = matcher_gamma
        self.matcher_class_cost = matcher_class_cost
        self.matcher_bbox_cost = matcher_bbox_cost
        self.matcher_giou_cost = matcher_giou_cost
        self.use_focal_loss = use_focal_loss
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma
        self.weight_loss_vfl = weight_loss_vfl
        self.weight_loss_bbox = weight_loss_bbox
        self.weight_loss_giou = weight_loss_giou
        self.eos_coefficient = eos_coefficient

        if not hasattr(self, "d_model"):
            self.d_model = d_model

        if not hasattr(self, "encoder_attention_heads"):
            self.encoder_attention_heads = encoder_attention_heads
        # add the new attributes with the given values or defaults
        self.decoder_n_levels = decoder_n_levels
        self.decoder_offset_scale = decoder_offset_scale
        self.decoder_method = decoder_method

    @property
    def sub_configs(self):
        return (
            {"backbone_config": type(self.backbone_config)}
            if getattr(self, "backbone_config", None) is not None
            else {}
        )

    @classmethod
    def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
        """Instantiate a [`RTDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model
        configuration.

        Args:
            backbone_config ([`PretrainedConfig`]):
                The backbone configuration.

        Returns:
            [`RTDetrV2Config`]: An instance of a configuration object
        """
        return cls(
            backbone_config=backbone_config,
            **kwargs,
        )

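# A minimal sketch of `from_backbone_configs` in use (hypothetical, assuming the RT-DETR ResNet
# backbone config that ships with the library is available as `RTDetrResNetConfig`):
#
#     from transformers import RTDetrResNetConfig
#     config = RTDetrV2Config.from_backbone_configs(RTDetrResNetConfig(out_indices=[2, 3, 4]))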
def multi_scale_deformable_attention_v2(
    value: Tensor,
    value_spatial_shapes: Tensor,
    sampling_locations: Tensor,
    attention_weights: Tensor,
    num_points_list: list[int],
    method="default",
) -> Tensor:
    batch_size, _, num_heads, hidden_dim = value.shape
    _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape
    value_list = (
        value.permute(0, 2, 3, 1)
        .flatten(0, 1)
        .split([height * width for height, width in value_spatial_shapes], dim=-1)
    )
    # sampling_offsets [8, 480, 8, 12, 2]
    if method == "default":
        sampling_grids = 2 * sampling_locations - 1
    elif method == "discrete":
        sampling_grids = sampling_locations
    sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
    sampling_grids = sampling_grids.split(num_points_list, dim=-2)
    sampling_value_list = []
    for level_id, (height, width) in enumerate(value_spatial_shapes):
        # batch_size, height*width, num_heads, hidden_dim
        # -> batch_size, height*width, num_heads*hidden_dim
        # -> batch_size, num_heads*hidden_dim, height*width
        # -> batch_size*num_heads, hidden_dim, height, width
        value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width)
        # batch_size, num_queries, num_heads, num_points, 2
        # -> batch_size, num_heads, num_queries, num_points, 2
        # -> batch_size*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[level_id]
        # batch_size*num_heads, hidden_dim, num_queries, num_points
        if method == "default":
            sampling_value_l_ = nn.functional.grid_sample(
                value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
            )
        elif method == "discrete":
            sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to(
                torch.int64
            )

            # Separate clamping for x and y coordinates
            sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1)
            sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1)

            # Combine the clamped coordinates
            sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1)
            sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2)
            sampling_idx = (
                torch.arange(sampling_coord.shape[0], device=value.device)
                .unsqueeze(-1)
                .repeat(1, sampling_coord.shape[1])
            )
            sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]]
            sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape(
                batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id]
            )
        sampling_value_list.append(sampling_value_l_)
    # (batch_size, num_queries, num_heads, num_levels, num_points)
    # -> (batch_size, num_heads, num_queries, num_levels, num_points)
    # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.permute(0, 2, 1, 3).reshape(
        batch_size * num_heads, 1, num_queries, sum(num_points_list)
    )
    output = (
        (torch.concat(sampling_value_list, dim=-1) * attention_weights)
        .sum(-1)
        .view(batch_size, num_heads * hidden_dim, num_queries)
    )
    return output.transpose(1, 2).contiguous()

# the main change
|
| 473 |
+
class RTDetrV2MultiscaleDeformableAttention(nn.Module):
|
| 474 |
+
"""
|
| 475 |
+
RTDetrV2 version of multiscale deformable attention, extending the base implementation
|
| 476 |
+
with improved offset handling and initialization.
|
| 477 |
+
"""
|
| 478 |
+
|
| 479 |
+
def __init__(self, config: RTDetrV2Config):
|
| 480 |
+
super().__init__()
|
| 481 |
+
num_heads = config.decoder_attention_heads
|
| 482 |
+
n_points = config.decoder_n_points
|
| 483 |
+
|
| 484 |
+
if config.d_model % num_heads != 0:
|
| 485 |
+
raise ValueError(
|
| 486 |
+
f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
|
| 487 |
+
)
|
| 488 |
+
dim_per_head = config.d_model // num_heads
|
| 489 |
+
# check if dim_per_head is power of 2
|
| 490 |
+
if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
|
| 491 |
+
warnings.warn(
|
| 492 |
+
"You'd better set embed_dim (d_model) in RTDetrV2MultiscaleDeformableAttention to make the"
|
| 493 |
+
" dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
|
| 494 |
+
" implementation."
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
self.im2col_step = 64
|
| 498 |
+
|
| 499 |
+
self.d_model = config.d_model
|
| 500 |
+
|
| 501 |
+
# V2-specific attributes
|
| 502 |
+
self.n_levels = config.decoder_n_levels
|
| 503 |
+
self.n_heads = num_heads
|
| 504 |
+
self.n_points = n_points
|
| 505 |
+
|
| 506 |
+
self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
|
| 507 |
+
self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
|
| 508 |
+
self.value_proj = nn.Linear(config.d_model, config.d_model)
|
| 509 |
+
self.output_proj = nn.Linear(config.d_model, config.d_model)
|
| 510 |
+
|
| 511 |
+
self.offset_scale = config.decoder_offset_scale
|
| 512 |
+
self.method = config.decoder_method
|
| 513 |
+
|
| 514 |
+
# Initialize n_points list and scale
|
| 515 |
+
n_points_list = [self.n_points for _ in range(self.n_levels)]
|
| 516 |
+
self.n_points_list = n_points_list
|
| 517 |
+
n_points_scale = [1 / n for n in n_points_list for _ in range(n)]
|
| 518 |
+
self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32))
|
| 519 |
+
|
| 520 |
+
def forward(
|
| 521 |
+
self,
|
| 522 |
+
hidden_states: torch.Tensor,
|
| 523 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 524 |
+
encoder_hidden_states=None,
|
| 525 |
+
encoder_attention_mask=None,
|
| 526 |
+
position_embeddings: Optional[torch.Tensor] = None,
|
| 527 |
+
reference_points=None,
|
| 528 |
+
spatial_shapes=None,
|
| 529 |
+
spatial_shapes_list=None,
|
| 530 |
+
level_start_index=None,
|
| 531 |
+
output_attentions: bool = False,
|
| 532 |
+
):
|
| 533 |
+
# Process inputs up to sampling locations calculation using parent class logic
|
| 534 |
+
if position_embeddings is not None:
|
| 535 |
+
hidden_states = hidden_states + position_embeddings
|
| 536 |
+
|
| 537 |
+
batch_size, num_queries, _ = hidden_states.shape
|
| 538 |
+
batch_size, sequence_length, _ = encoder_hidden_states.shape
|
| 539 |
+
if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
|
| 540 |
+
raise ValueError(
|
| 541 |
+
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
value = self.value_proj(encoder_hidden_states)
|
| 545 |
+
if attention_mask is not None:
|
| 546 |
+
value = value.masked_fill(~attention_mask[..., None], float(0))
|
| 547 |
+
value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
|
| 548 |
+
|
| 549 |
+
# V2-specific sampling offsets shape
|
| 550 |
+
sampling_offsets = self.sampling_offsets(hidden_states).view(
|
| 551 |
+
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points, 2
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
attention_weights = self.attention_weights(hidden_states).view(
|
| 555 |
+
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
|
| 556 |
+
)
|
| 557 |
+
attention_weights = F.softmax(attention_weights, -1)
|
| 558 |
+
|
| 559 |
+
# V2-specific sampling locations calculation
|
| 560 |
+
if reference_points.shape[-1] == 2:
|
| 561 |
+
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
|
| 562 |
+
sampling_locations = (
|
| 563 |
+
reference_points[:, :, None, :, None, :]
|
| 564 |
+
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
|
| 565 |
+
)
|
| 566 |
+
elif reference_points.shape[-1] == 4:
|
| 567 |
+
n_points_scale = self.n_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1)
|
| 568 |
+
offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale
|
| 569 |
+
sampling_locations = reference_points[:, :, None, :, :2] + offset
|
| 570 |
+
else:
|
| 571 |
+
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
|
| 572 |
+
|
| 573 |
+
# V2-specific attention implementation choice
|
| 574 |
+
output = multi_scale_deformable_attention_v2(
|
| 575 |
+
value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
output = self.output_proj(output)
|
| 579 |
+
return output, attention_weights
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
class RTDetrV2DecoderLayer(RTDetrDecoderLayer):
|
| 583 |
+
def __init__(self, config: RTDetrV2Config):
|
| 584 |
+
# initialize parent class
|
| 585 |
+
super().__init__(config)
|
| 586 |
+
# override only the encoder attention module with v2 version
|
| 587 |
+
self.encoder_attn = RTDetrV2MultiscaleDeformableAttention(config)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
class RTDetrV2PreTrainedModel(RTDetrPreTrainedModel):
|
| 591 |
+
pass
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
class RTDetrV2Decoder(RTDetrDecoder):
|
| 595 |
+
def __init__(self, config: RTDetrV2Config):
|
| 596 |
+
super().__init__(config)
|
| 597 |
+
self.layers = nn.ModuleList([RTDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)])
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class RTDetrV2Model(RTDetrModel):
|
| 601 |
+
def __init__(self, config: RTDetrV2Config):
|
| 602 |
+
super().__init__(config)
|
| 603 |
+
# decoder
|
| 604 |
+
self.decoder = RTDetrV2Decoder(config)
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
class RTDetrV2MLPPredictionHead(RTDetrMLPPredictionHead):
|
| 608 |
+
pass
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
class RTDetrV2ForObjectDetection(RTDetrForObjectDetection, RTDetrV2PreTrainedModel):
|
| 612 |
+
def __init__(self, config: RTDetrV2Config):
|
| 613 |
+
RTDetrV2PreTrainedModel.__init__(self, config)
|
| 614 |
+
# RTDETR encoder-decoder model
|
| 615 |
+
self.model = RTDetrV2Model(config)
|
| 616 |
+
|
| 617 |
+
# Detection heads on top
|
| 618 |
+
class_embed = partial(nn.Linear, config.d_model, config.num_labels)
|
| 619 |
+
bbox_embed = partial(RTDetrV2MLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3)
|
| 620 |
+
|
| 621 |
+
self.class_embed = nn.ModuleList([class_embed() for _ in range(config.decoder_layers)])
|
| 622 |
+
self.bbox_embed = nn.ModuleList([bbox_embed() for _ in range(config.decoder_layers)])
|
| 623 |
+
|
| 624 |
+
self.model.decoder.class_embed = self.class_embed
|
| 625 |
+
self.model.decoder.bbox_embed = self.bbox_embed
|
| 626 |
+
|
| 627 |
+
# Initialize weights and apply final processing
|
| 628 |
+
self.post_init()
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
__all__ = [
|
| 632 |
+
"RTDetrV2Config",
|
| 633 |
+
"RTDetrV2Model",
|
| 634 |
+
"RTDetrV2PreTrainedModel",
|
| 635 |
+
"RTDetrV2ForObjectDetection",
|
| 636 |
+
]
|
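A minimal sketch of exercising `multi_scale_deformable_attention_v2` from the file above with toy tensors; the import path and the toy shapes are assumptions for illustration, not part of the uploaded file.

```python
# Hedged sketch: toy call of the v2 multiscale deformable attention above.
# The module path below is an assumption (the generated modeling file normally exposes it).
import torch

from transformers.models.rt_detr_v2.modeling_rt_detr_v2 import multi_scale_deformable_attention_v2

batch, heads, dim = 1, 8, 32                      # hidden_dim per head = 32
spatial_shapes = [(8, 8), (4, 4)]                 # two feature levels
seq_len = sum(h * w for h, w in spatial_shapes)   # 80 flattened tokens
queries, points_per_level = 10, 4
total_points = points_per_level * len(spatial_shapes)

value = torch.randn(batch, seq_len, heads, dim)
sampling_locations = torch.rand(batch, queries, heads, total_points, 2)  # normalized [0, 1]
attention_weights = torch.softmax(torch.randn(batch, queries, heads, total_points), dim=-1)

out = multi_scale_deformable_attention_v2(
    value, spatial_shapes, sampling_locations, attention_weights,
    num_points_list=[points_per_level] * len(spatial_shapes), method="default",
)
print(out.shape)  # torch.Size([1, 10, 256]) = (batch, queries, heads * dim)
```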
phivenv/Lib/site-packages/transformers/models/rwkv/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_rwkv import *
    from .modeling_rwkv import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
phivenv/Lib/site-packages/transformers/models/rwkv/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (522 Bytes).
phivenv/Lib/site-packages/transformers/models/rwkv/__pycache__/configuration_rwkv.cpython-39.pyc
ADDED
Binary file (4.36 kB).
phivenv/Lib/site-packages/transformers/models/rwkv/__pycache__/modeling_rwkv.cpython-39.pyc
ADDED
Binary file (23.5 kB).
phivenv/Lib/site-packages/transformers/models/rwkv/configuration_rwkv.py
ADDED
@@ -0,0 +1,120 @@
# coding=utf-8
# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RWKV configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class RwkvConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`RwkvModel`]. It is used to instantiate a RWKV
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the RWKV-4
    [RWKV/rwkv-4-169m-pile](https://huggingface.co/RWKV/rwkv-4-169m-pile) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50277):
            Vocabulary size of the RWKV model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`RwkvModel`].
        context_length (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model can be used with in a single forward (using it in RNN mode
            lets us use any sequence length).
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the embeddings and hidden states.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the model.
        attention_hidden_size (`int`, *optional*):
            Dimensionality of the attention hidden states. Will default to `hidden_size` if unset.
        intermediate_size (`int`, *optional*):
            Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon to use in the layer normalization layers.
        bos_token_id (`int`, *optional*, defaults to 0):
            The id of the beginning of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer
            as GPTNeoX.
        eos_token_id (`int`, *optional*, defaults to 0):
            The id of the end of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer as
            GPTNeoX.
        rescale_every (`int`, *optional*, defaults to 6):
            At inference, the hidden states (and weights of the corresponding output layers) are divided by 2 every
            `rescale_every` layer. If set to 0 or a negative number, no rescale is done.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether or not to tie the word embeddings with the input token embeddings.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last state.


    Example:

    ```python
    >>> from transformers import RwkvConfig, RwkvModel

    >>> # Initializing a Rwkv configuration
    >>> configuration = RwkvConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = RwkvModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "rwkv"
    attribute_map = {"max_position_embeddings": "context_length"}

    def __init__(
        self,
        vocab_size=50277,
        context_length=1024,
        hidden_size=4096,
        num_hidden_layers=32,
        attention_hidden_size=None,
        intermediate_size=None,
        layer_norm_epsilon=1e-5,
        bos_token_id=0,
        eos_token_id=0,
        rescale_every=6,
        tie_word_embeddings=False,
        use_cache=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size
        self.intermediate_size = intermediate_size if intermediate_size is not None else 4 * hidden_size
        self.layer_norm_epsilon = layer_norm_epsilon
        self.rescale_every = rescale_every
        self.use_cache = use_cache

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(
            tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
        )


__all__ = ["RwkvConfig"]
phivenv/Lib/site-packages/transformers/models/rwkv/modeling_rwkv.py
ADDED
@@ -0,0 +1,798 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 Bo Peng and HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch RWKV model."""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Optional, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import nn
|
| 26 |
+
|
| 27 |
+
from ...generation import GenerationMixin
|
| 28 |
+
from ...modeling_layers import GradientCheckpointingLayer
|
| 29 |
+
from ...modeling_utils import PreTrainedModel
|
| 30 |
+
from ...utils import (
|
| 31 |
+
ModelOutput,
|
| 32 |
+
auto_docstring,
|
| 33 |
+
is_bitsandbytes_available,
|
| 34 |
+
is_ninja_available,
|
| 35 |
+
is_torch_cuda_available,
|
| 36 |
+
logging,
|
| 37 |
+
)
|
| 38 |
+
from .configuration_rwkv import RwkvConfig
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
rwkv_cuda_kernel = None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def load_wkv_cuda_kernel(context_length):
|
| 48 |
+
from torch.utils.cpp_extension import load as load_kernel
|
| 49 |
+
|
| 50 |
+
global rwkv_cuda_kernel
|
| 51 |
+
|
| 52 |
+
kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "rwkv"
|
| 53 |
+
cuda_kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu", "wkv_cuda_bf16.cu"]]
|
| 54 |
+
|
| 55 |
+
# Only load the kernel if it's not been loaded yet or if we changed the context length
|
| 56 |
+
if rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == context_length:
|
| 57 |
+
return
|
| 58 |
+
|
| 59 |
+
logger.info(f"Loading CUDA kernel for RWKV at context length of {context_length}.")
|
| 60 |
+
|
| 61 |
+
flags = [
|
| 62 |
+
"-res-usage",
|
| 63 |
+
"--maxrregcount 60",
|
| 64 |
+
"--use_fast_math",
|
| 65 |
+
"-O3",
|
| 66 |
+
"-Xptxas -O3",
|
| 67 |
+
"--extra-device-vectorization",
|
| 68 |
+
f"-DTmax={context_length}",
|
| 69 |
+
]
|
| 70 |
+
rwkv_cuda_kernel = load_kernel(
|
| 71 |
+
name=f"wkv_{context_length}",
|
| 72 |
+
sources=cuda_kernel_files,
|
| 73 |
+
verbose=(logging.get_verbosity() == logging.DEBUG),
|
| 74 |
+
extra_cuda_cflags=flags,
|
| 75 |
+
)
|
| 76 |
+
rwkv_cuda_kernel.max_seq_length = context_length
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class RwkvLinearAttention(torch.autograd.Function):
|
| 80 |
+
@staticmethod
|
| 81 |
+
def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False):
|
| 82 |
+
batch_size, seq_len, hidden_size = key.size()
|
| 83 |
+
if seq_len > rwkv_cuda_kernel.max_seq_length:
|
| 84 |
+
raise ValueError(
|
| 85 |
+
f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of "
|
| 86 |
+
f"{rwkv_cuda_kernel.max_seq_length} with this model."
|
| 87 |
+
)
|
| 88 |
+
if batch_size * hidden_size % min(hidden_size, 32) != 0:
|
| 89 |
+
raise ValueError(
|
| 90 |
+
f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round "
|
| 91 |
+
f"multiple of {min(hidden_size, 32)}."
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
ctx.input_dtype = key.dtype
|
| 95 |
+
|
| 96 |
+
if (
|
| 97 |
+
time_decay.device.type != "cuda"
|
| 98 |
+
or time_first.device.type != "cuda"
|
| 99 |
+
or key.device.type != "cuda"
|
| 100 |
+
or value.device.type != "cuda"
|
| 101 |
+
):
|
| 102 |
+
raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.")
|
| 103 |
+
|
| 104 |
+
time_decay = -torch.exp(time_decay.float().contiguous())
|
| 105 |
+
if key.dtype == torch.float16:
|
| 106 |
+
time_first = time_first.float()
|
| 107 |
+
key = key.float()
|
| 108 |
+
value = value.float()
|
| 109 |
+
time_first = time_first.contiguous()
|
| 110 |
+
key = key.contiguous()
|
| 111 |
+
value = value.contiguous()
|
| 112 |
+
# The CUDA kernel will fill this tensor.
|
| 113 |
+
output = torch.empty_like(key, memory_format=torch.contiguous_format)
|
| 114 |
+
if return_state or state is not None:
|
| 115 |
+
if state is None:
|
| 116 |
+
state = torch.zeros(
|
| 117 |
+
batch_size,
|
| 118 |
+
hidden_size,
|
| 119 |
+
3,
|
| 120 |
+
dtype=torch.float32,
|
| 121 |
+
device=key.device,
|
| 122 |
+
memory_format=torch.contiguous_format,
|
| 123 |
+
)
|
| 124 |
+
state[:, :, 2] -= 1e38
|
| 125 |
+
else:
|
| 126 |
+
state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous()
|
| 127 |
+
if key.dtype == torch.bfloat16:
|
| 128 |
+
forward_func = rwkv_cuda_kernel.forward_with_state_bf16
|
| 129 |
+
else:
|
| 130 |
+
forward_func = rwkv_cuda_kernel.forward_with_state
|
| 131 |
+
forward_func(time_decay, time_first, key, value, output, state)
|
| 132 |
+
else:
|
| 133 |
+
forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward
|
| 134 |
+
forward_func(time_decay, time_first, key, value, output)
|
| 135 |
+
|
| 136 |
+
ctx.save_for_backward(time_decay, time_first, key, value, output)
|
| 137 |
+
|
| 138 |
+
if state is not None:
|
| 139 |
+
state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)]
|
| 140 |
+
|
| 141 |
+
return output.to(ctx.input_dtype), state
|
| 142 |
+
|
| 143 |
+
@staticmethod
|
| 144 |
+
# g stands for grad
|
| 145 |
+
def backward(ctx, g_output, g_state=None):
|
| 146 |
+
input_dtype = ctx.input_dtype
|
| 147 |
+
|
| 148 |
+
time_decay, time_first, key, value, output = ctx.saved_tensors
|
| 149 |
+
# The CUDA kernel will fill those tensors.
|
| 150 |
+
g_time_decay = torch.empty_like(
|
| 151 |
+
time_decay,
|
| 152 |
+
memory_format=torch.contiguous_format,
|
| 153 |
+
dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
|
| 154 |
+
)
|
| 155 |
+
g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format)
|
| 156 |
+
g_key = torch.empty_like(key, memory_format=torch.contiguous_format)
|
| 157 |
+
g_value = torch.empty_like(value, memory_format=torch.contiguous_format)
|
| 158 |
+
|
| 159 |
+
if input_dtype == torch.float16:
|
| 160 |
+
g_output = g_output.float()
|
| 161 |
+
backward_func = rwkv_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else rwkv_cuda_kernel.backward
|
| 162 |
+
backward_func(
|
| 163 |
+
time_decay,
|
| 164 |
+
time_first,
|
| 165 |
+
key,
|
| 166 |
+
value,
|
| 167 |
+
output,
|
| 168 |
+
g_output.contiguous(),
|
| 169 |
+
g_time_decay,
|
| 170 |
+
g_time_first,
|
| 171 |
+
g_key,
|
| 172 |
+
g_value,
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
return (
|
| 176 |
+
g_time_decay.to(input_dtype),
|
| 177 |
+
g_time_first.to(input_dtype),
|
| 178 |
+
g_key.to(input_dtype),
|
| 179 |
+
g_value.to(input_dtype),
|
| 180 |
+
None,
|
| 181 |
+
None,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False):
|
| 186 |
+
# For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
|
| 187 |
+
# within a torch.no_grad.
|
| 188 |
+
_, seq_length, _ = key.size()
|
| 189 |
+
output = torch.zeros_like(key)
|
| 190 |
+
|
| 191 |
+
if state is None:
|
| 192 |
+
num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
|
| 193 |
+
den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
|
| 194 |
+
max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
|
| 195 |
+
else:
|
| 196 |
+
num_state, den_state, max_state = state
|
| 197 |
+
# For numerical stability
|
| 198 |
+
# real_numerator_state = num_state * torch.exp(max_state)
|
| 199 |
+
# real_denominator_state = den_state * torch.exp(max_state)
|
| 200 |
+
|
| 201 |
+
time_decay = -torch.exp(time_decay)
|
| 202 |
+
|
| 203 |
+
for current_index in range(seq_length):
|
| 204 |
+
current_key = key[:, current_index].float()
|
| 205 |
+
current_value = value[:, current_index]
|
| 206 |
+
|
| 207 |
+
# wkv computation at time t
|
| 208 |
+
max_for_output = torch.maximum(max_state, current_key + time_first)
|
| 209 |
+
e1 = torch.exp(max_state - max_for_output)
|
| 210 |
+
e2 = torch.exp(current_key + time_first - max_for_output)
|
| 211 |
+
numerator = e1 * num_state + e2 * current_value
|
| 212 |
+
denominator = e1 * den_state + e2
|
| 213 |
+
output[:, current_index] = (numerator / denominator).to(output.dtype)
|
| 214 |
+
|
| 215 |
+
# Update state for next iteration
|
| 216 |
+
max_for_state = torch.maximum(max_state + time_decay, current_key)
|
| 217 |
+
e1 = torch.exp(max_state + time_decay - max_for_state)
|
| 218 |
+
e2 = torch.exp(current_key - max_for_state)
|
| 219 |
+
num_state = e1 * num_state + e2 * current_value
|
| 220 |
+
den_state = e1 * den_state + e2
|
| 221 |
+
max_state = max_for_state
|
| 222 |
+
|
| 223 |
+
if return_state or state is not None:
|
| 224 |
+
state = [num_state, den_state, max_state]
|
| 225 |
+
|
| 226 |
+
return output, state
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def rwkv_linear_attention(time_decay, time_first, key, value, state=None, return_state=False):
|
| 230 |
+
no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value])
|
| 231 |
+
# Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
|
| 232 |
+
# in this case).
|
| 233 |
+
one_token = key.size(1) == 1
|
| 234 |
+
if rwkv_cuda_kernel is None or no_cuda or one_token:
|
| 235 |
+
return rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state)
|
| 236 |
+
else:
|
| 237 |
+
return RwkvLinearAttention.apply(time_decay, time_first, key, value, state, return_state)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class RwkvSelfAttention(nn.Module):
|
| 241 |
+
def __init__(self, config, layer_id=0):
|
| 242 |
+
super().__init__()
|
| 243 |
+
self.config = config
|
| 244 |
+
kernel_loaded = rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == config.context_length
|
| 245 |
+
if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
|
| 246 |
+
try:
|
| 247 |
+
load_wkv_cuda_kernel(config.context_length)
|
| 248 |
+
except Exception:
|
| 249 |
+
logger.info("Could not load the custom CUDA kernel for RWKV attention.")
|
| 250 |
+
self.layer_id = layer_id
|
| 251 |
+
hidden_size = config.hidden_size
|
| 252 |
+
attention_hidden_size = (
|
| 253 |
+
config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size
|
| 254 |
+
)
|
| 255 |
+
self.attention_hidden_size = attention_hidden_size
|
| 256 |
+
|
| 257 |
+
self.time_decay = nn.Parameter(torch.empty(attention_hidden_size))
|
| 258 |
+
self.time_first = nn.Parameter(torch.empty(attention_hidden_size))
|
| 259 |
+
|
| 260 |
+
self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
|
| 261 |
+
self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size))
|
| 262 |
+
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))
|
| 263 |
+
|
| 264 |
+
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
| 265 |
+
self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False)
|
| 266 |
+
self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False)
|
| 267 |
+
self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
|
| 268 |
+
self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
|
| 269 |
+
|
| 270 |
+
# TODO: maybe jit, otherwise move inside forward
|
| 271 |
+
def extract_key_value(self, hidden, state=None):
|
| 272 |
+
# Mix hidden with the previous timestep to produce key, value, receptance
|
| 273 |
+
if hidden.size(1) == 1 and state is not None:
|
| 274 |
+
shifted = state[1][:, :, self.layer_id]
|
| 275 |
+
else:
|
| 276 |
+
shifted = self.time_shift(hidden)
|
| 277 |
+
if state is not None:
|
| 278 |
+
shifted[:, 0] = state[1][:, :, self.layer_id]
|
| 279 |
+
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
|
| 280 |
+
value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
|
| 281 |
+
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
|
| 282 |
+
|
| 283 |
+
key = self.key(key)
|
| 284 |
+
value = self.value(value)
|
| 285 |
+
receptance = torch.sigmoid(self.receptance(receptance))
|
| 286 |
+
if state is not None:
|
| 287 |
+
state[1][:, :, self.layer_id] = hidden[:, -1]
|
| 288 |
+
return receptance, key, value, state
|
| 289 |
+
|
| 290 |
+
def forward(self, hidden, state=None, use_cache=False):
|
| 291 |
+
receptance, key, value, state = self.extract_key_value(hidden, state=state)
|
| 292 |
+
layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
|
| 293 |
+
rwkv, layer_state = rwkv_linear_attention(
|
| 294 |
+
self.time_decay,
|
| 295 |
+
self.time_first,
|
| 296 |
+
key,
|
| 297 |
+
value,
|
| 298 |
+
state=layer_state,
|
| 299 |
+
return_state=use_cache,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
if layer_state is not None:
|
| 303 |
+
state[2][:, :, self.layer_id] = layer_state[0]
|
| 304 |
+
state[3][:, :, self.layer_id] = layer_state[1]
|
| 305 |
+
state[4][:, :, self.layer_id] = layer_state[2]
|
| 306 |
+
|
| 307 |
+
return self.output(receptance * rwkv), state
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class RwkvFeedForward(nn.Module):
|
| 311 |
+
def __init__(self, config, layer_id=0):
|
| 312 |
+
super().__init__()
|
| 313 |
+
self.config = config
|
| 314 |
+
self.layer_id = layer_id
|
| 315 |
+
hidden_size = config.hidden_size
|
| 316 |
+
intermediate_size = (
|
| 317 |
+
config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
| 321 |
+
self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
|
| 322 |
+
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))
|
| 323 |
+
|
| 324 |
+
self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 325 |
+
self.receptance = nn.Linear(hidden_size, hidden_size, bias=False)
|
| 326 |
+
self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
|
| 327 |
+
|
| 328 |
+
def forward(self, hidden, state=None):
|
| 329 |
+
if hidden.size(1) == 1 and state is not None:
|
| 330 |
+
shifted = state[0][:, :, self.layer_id]
|
| 331 |
+
else:
|
| 332 |
+
shifted = self.time_shift(hidden)
|
| 333 |
+
if state is not None:
|
| 334 |
+
shifted[:, 0] = state[0][:, :, self.layer_id]
|
| 335 |
+
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
|
| 336 |
+
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
|
| 337 |
+
|
| 338 |
+
key = torch.square(torch.relu(self.key(key)))
|
| 339 |
+
value = self.value(key)
|
| 340 |
+
receptance = torch.sigmoid(self.receptance(receptance))
|
| 341 |
+
|
| 342 |
+
if state is not None:
|
| 343 |
+
state[0][:, :, self.layer_id] = hidden[:, -1]
|
| 344 |
+
|
| 345 |
+
return receptance * value, state
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class RwkvBlock(GradientCheckpointingLayer):
|
| 349 |
+
def __init__(self, config, layer_id):
|
| 350 |
+
super().__init__()
|
| 351 |
+
self.config = config
|
| 352 |
+
self.layer_id = layer_id
|
| 353 |
+
|
| 354 |
+
if layer_id == 0:
|
| 355 |
+
self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
| 356 |
+
|
| 357 |
+
self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
| 358 |
+
self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
| 359 |
+
|
| 360 |
+
self.attention = RwkvSelfAttention(config, layer_id)
|
| 361 |
+
self.feed_forward = RwkvFeedForward(config, layer_id)
|
| 362 |
+
|
| 363 |
+
def forward(self, hidden, state=None, use_cache=False, output_attentions=False):
|
| 364 |
+
if self.layer_id == 0:
|
| 365 |
+
hidden = self.pre_ln(hidden)
|
| 366 |
+
|
| 367 |
+
attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache)
|
| 368 |
+
hidden = hidden + attention
|
| 369 |
+
|
| 370 |
+
feed_forward, state = self.feed_forward(self.ln2(hidden), state=state)
|
| 371 |
+
hidden = hidden + feed_forward
|
| 372 |
+
|
| 373 |
+
outputs = (hidden, state)
|
| 374 |
+
if output_attentions:
|
| 375 |
+
outputs += (attention,)
|
| 376 |
+
else:
|
| 377 |
+
outputs += (None,)
|
| 378 |
+
|
| 379 |
+
return outputs
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
@auto_docstring
|
| 383 |
+
class RwkvPreTrainedModel(PreTrainedModel):
|
| 384 |
+
config: RwkvConfig
|
| 385 |
+
base_model_prefix = "rwkv"
|
| 386 |
+
_no_split_modules = ["RwkvBlock"]
|
| 387 |
+
_keep_in_fp32_modules = ["time_decay", "time_first"]
|
| 388 |
+
supports_gradient_checkpointing = True
|
| 389 |
+
_is_stateful = True
|
| 390 |
+
|
| 391 |
+
def _init_weights(self, module: nn.Module):
|
| 392 |
+
"""Initialize the weights."""
|
| 393 |
+
if isinstance(module, RwkvSelfAttention):
|
| 394 |
+
layer_id = module.layer_id
|
| 395 |
+
num_hidden_layers = module.config.num_hidden_layers
|
| 396 |
+
hidden_size = module.config.hidden_size
|
| 397 |
+
attention_hidden_size = module.attention_hidden_size
|
| 398 |
+
|
| 399 |
+
ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
|
| 400 |
+
ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
|
| 401 |
+
|
| 402 |
+
time_weight = torch.tensor(
|
| 403 |
+
[i / hidden_size for i in range(hidden_size)],
|
| 404 |
+
dtype=module.time_mix_key.dtype,
|
| 405 |
+
device=module.time_mix_key.device,
|
| 406 |
+
)
|
| 407 |
+
time_weight = time_weight[None, None, :]
|
| 408 |
+
|
| 409 |
+
decay_speed = [
|
| 410 |
+
-5 + 8 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
|
| 411 |
+
for h in range(attention_hidden_size)
|
| 412 |
+
]
|
| 413 |
+
decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device)
|
| 414 |
+
zigzag = (
|
| 415 |
+
torch.tensor(
|
| 416 |
+
[(i + 1) % 3 - 1 for i in range(attention_hidden_size)],
|
| 417 |
+
dtype=module.time_first.dtype,
|
| 418 |
+
device=module.time_first.device,
|
| 419 |
+
)
|
| 420 |
+
* 0.5
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
module.time_decay.data = decay_speed
|
| 424 |
+
module.time_first.data = torch.ones_like(module.time_first * math.log(0.3) + zigzag)
|
| 425 |
+
|
| 426 |
+
module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
|
| 427 |
+
module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
|
| 428 |
+
module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
|
| 429 |
+
elif isinstance(module, RwkvFeedForward):
|
| 430 |
+
layer_id = module.layer_id
|
| 431 |
+
num_hidden_layers = module.config.num_hidden_layers
|
| 432 |
+
hidden_size = module.config.hidden_size
|
| 433 |
+
|
| 434 |
+
ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
|
| 435 |
+
|
| 436 |
+
time_weight = torch.tensor(
|
| 437 |
+
[i / hidden_size for i in range(hidden_size)],
|
| 438 |
+
dtype=module.time_mix_key.dtype,
|
| 439 |
+
device=module.time_mix_key.device,
|
| 440 |
+
)
|
| 441 |
+
time_weight = time_weight[None, None, :]
|
| 442 |
+
|
| 443 |
+
module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
|
| 444 |
+
module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
|
| 445 |
+
elif isinstance(module, nn.Linear):
|
| 446 |
+
shape = module.weight.data.shape
|
| 447 |
+
gain = 1.0
|
| 448 |
+
scale = 1.0 # extra scale for gain
|
| 449 |
+
if module.bias is not None:
|
| 450 |
+
module.bias.data.zero_()
|
| 451 |
+
if shape[0] > shape[1]:
|
| 452 |
+
gain = math.sqrt(shape[0] / shape[1])
|
| 453 |
+
if shape[0] == self.config.vocab_size and shape[1] == self.config.hidden_size: # final projection?
|
| 454 |
+
scale = 0.5
|
| 455 |
+
|
| 456 |
+
gain *= scale
|
| 457 |
+
nn.init.orthogonal_(module.weight, gain=gain)
|
| 458 |
+
elif isinstance(module, nn.Embedding):
|
| 459 |
+
shape = module.weight.data.shape
|
| 460 |
+
gain = 1e-4 * math.sqrt(max(shape[0], shape[1]))
|
| 461 |
+
nn.init.orthogonal_(module.weight, gain=gain)
|
| 462 |
+
elif isinstance(module, nn.LayerNorm):
|
| 463 |
+
module.weight.data.fill_(1.0)
|
| 464 |
+
module.bias.data.zero_()
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
@dataclass
|
| 468 |
+
@auto_docstring(
|
| 469 |
+
custom_intro="""
|
| 470 |
+
Class for the RWKV model outputs.
|
| 471 |
+
"""
|
| 472 |
+
)
|
| 473 |
+
class RwkvOutput(ModelOutput):
|
| 474 |
+
r"""
|
| 475 |
+
state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
|
| 476 |
+
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
|
| 477 |
+
avoid providing the old `input_ids`.
|
| 478 |
+
"""
|
| 479 |
+
|
| 480 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 481 |
+
state: Optional[list[torch.FloatTensor]] = None
|
| 482 |
+
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 483 |
+
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
@dataclass
|
| 487 |
+
@auto_docstring(
|
| 488 |
+
custom_intro="""
|
| 489 |
+
Base class for causal language model (or autoregressive) outputs.
|
| 490 |
+
"""
|
| 491 |
+
)
|
| 492 |
+
class RwkvCausalLMOutput(ModelOutput):
|
| 493 |
+
r"""
|
| 494 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 495 |
+
Language modeling loss (for next-token prediction).
|
| 496 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 497 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 498 |
+
state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
|
| 499 |
+
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
|
| 500 |
+
avoid providing the old `input_ids`.
|
| 501 |
+
"""
|
| 502 |
+
|
| 503 |
+
loss: Optional[torch.FloatTensor] = None
|
| 504 |
+
logits: Optional[torch.FloatTensor] = None
|
| 505 |
+
state: Optional[list[torch.FloatTensor]] = None
|
| 506 |
+
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 507 |
+
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
@auto_docstring
|
| 511 |
+
class RwkvModel(RwkvPreTrainedModel):
|
| 512 |
+
def __init__(self, config):
|
| 513 |
+
super().__init__(config)
|
| 514 |
+
|
| 515 |
+
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
| 516 |
+
self.blocks = nn.ModuleList([RwkvBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
|
| 517 |
+
self.ln_out = nn.LayerNorm(config.hidden_size)
|
| 518 |
+
|
| 519 |
+
self.layers_are_rescaled = False
|
| 520 |
+
|
| 521 |
+
self.gradient_checkpointing = False
|
| 522 |
+
|
| 523 |
+
# Initialize weights and apply final processing
|
| 524 |
+
self.post_init()
|
| 525 |
+
|
| 526 |
+
def get_input_embeddings(self):
|
| 527 |
+
return self.embeddings
|
| 528 |
+
|
| 529 |
+
def set_input_embeddings(self, new_embeddings):
|
| 530 |
+
self.embeddings = new_embeddings
|
| 531 |
+
|
| 532 |
+
@auto_docstring
|
| 533 |
+
def forward(
|
| 534 |
+
self,
|
| 535 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 536 |
+
attention_mask: Optional[torch.LongTensor] = None, # noqa
|
| 537 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 538 |
+
state: Optional[list[torch.FloatTensor]] = None,
|
| 539 |
+
use_cache: Optional[bool] = None,
|
| 540 |
+
output_attentions: Optional[bool] = None,
|
| 541 |
+
output_hidden_states: Optional[bool] = None,
|
| 542 |
+
return_dict: Optional[bool] = None,
|
| 543 |
+
) -> Union[tuple, RwkvOutput]:
|
| 544 |
+
r"""
|
| 545 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
| 546 |
+
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
|
| 547 |
+
`past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
|
| 548 |
+
sequence tokens in the vocabulary.
|
| 549 |
+
|
| 550 |
+
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
|
| 551 |
+
`input_ids`.
|
| 552 |
+
|
| 553 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 554 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 555 |
+
|
| 556 |
+
[What are input IDs?](../glossary#input-ids)
|
| 557 |
+
state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
|
| 558 |
+
If passed along, the model uses the previous state in all the blocks (which will give the output for the
|
| 559 |
+
`input_ids` provided as if the model add `state_input_ids + input_ids` as context).
|
| 560 |
+
use_cache (`bool`, *optional*):
|
| 561 |
+
If set to `True`, the last state is returned and can be used to quickly generate the next logits.
|
| 562 |
+
"""
|
| 563 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 564 |
+
output_hidden_states = (
|
| 565 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 566 |
+
)
|
| 567 |
+
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
|
| 568 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 569 |
+
|
| 570 |
+
if attention_mask is not None:
|
| 571 |
+
logger.warning_once("`attention_mask` was passed, but it is unused in this model.")
|
| 572 |
+
|
| 573 |
+
if self.training == self.layers_are_rescaled:
|
| 574 |
+
self._rescale_layers()
|
| 575 |
+
|
| 576 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 577 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 578 |
+
elif input_ids is None and inputs_embeds is None:
|
| 579 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 580 |
+
|
| 581 |
+
if inputs_embeds is None:
|
| 582 |
+
inputs_embeds = self.embeddings(input_ids)
|
| 583 |
+
|
| 584 |
+
if use_cache and state is None:
|
| 585 |
+
shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers)
|
| 586 |
+
state = [
|
| 587 |
+
torch.zeros(
|
| 588 |
+
*shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=inputs_embeds.device
|
| 589 |
+
)
|
| 590 |
+
for i in range(5)
|
| 591 |
+
]
|
| 592 |
+
state[4] -= 1e30
|
| 593 |
+
|
| 594 |
+
if self.gradient_checkpointing and self.training:
|
| 595 |
+
if use_cache:
|
| 596 |
+
logger.warning_once(
|
| 597 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 598 |
+
)
|
| 599 |
+
use_cache = False
|
| 600 |
+
|
| 601 |
+
hidden_states = inputs_embeds
|
| 602 |
+
|
| 603 |
+
all_self_attentions = () if output_attentions else None
|
| 604 |
+
all_hidden_states = () if output_hidden_states else None
|
| 605 |
+
for idx, block in enumerate(self.blocks):
|
| 606 |
+
hidden_states, state, attentions = block(
|
| 607 |
+
hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
if (
|
| 611 |
+
self.layers_are_rescaled
|
| 612 |
+
and self.config.rescale_every > 0
|
| 613 |
+
and (idx + 1) % self.config.rescale_every == 0
|
| 614 |
+
):
|
| 615 |
+
hidden_states = hidden_states / 2
|
| 616 |
+
|
| 617 |
+
if output_hidden_states:
|
| 618 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 619 |
+
|
| 620 |
+
if output_attentions:
|
| 621 |
+
all_self_attentions = all_self_attentions + (attentions,)
|
| 622 |
+
|
| 623 |
+
hidden_states = self.ln_out(hidden_states)
|
| 624 |
+
|
| 625 |
+
if output_hidden_states:
|
| 626 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 627 |
+
|
| 628 |
+
if not return_dict:
|
| 629 |
+
return tuple(x for x in [hidden_states, state, all_hidden_states, all_self_attentions] if x is not None)
|
| 630 |
+
|
| 631 |
+
return RwkvOutput(
|
| 632 |
+
last_hidden_state=hidden_states,
|
| 633 |
+
state=state,
|
| 634 |
+
hidden_states=all_hidden_states,
|
| 635 |
+
attentions=all_self_attentions,
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
def _rescale_layers(self):
|
| 639 |
+
# Layers should be rescaled for inference only.
|
| 640 |
+
if self.layers_are_rescaled == (not self.training):
|
| 641 |
+
return
|
| 642 |
+
if self.config.rescale_every > 0:
|
| 643 |
+
with torch.no_grad():
|
| 644 |
+
for block_id, block in enumerate(self.blocks):
|
| 645 |
+
if self.training:
|
| 646 |
+
block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
| 647 |
+
block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
| 648 |
+
else:
|
| 649 |
+
# Deal with quantization statistics
|
| 650 |
+
if hasattr(block.attention.output.weight, "SCB"):
|
| 651 |
+
block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
| 652 |
+
block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
| 653 |
+
elif hasattr(block.attention.output.weight, "quant_state"):
|
                            self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
                            self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
                        else:
                            block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
                            block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))

        self.layers_are_rescaled = not self.training

    def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
        r"""
        Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
        be quantized again.
        """
        if not is_bitsandbytes_available():
            raise ImportError("Please install bitsandbytes to use this method.")
        import bitsandbytes as bnb

        dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)

        dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))

        # re-quantize the model:
        # we need to put it first on CPU then back to the device
        # this will create an overhead :/
        # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
        # bugs with bnb
        quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
        setattr(target_layer, "weight", quant_weight)


@auto_docstring(
    custom_intro="""
    The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """
)
class RwkvForCausalLM(RwkvPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.rwkv = RwkvModel(config)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_embeddings):
        self.head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, use_cache=None, **kwargs):
        # Overwritten -- this model uses `state`, but doesn't have a cache (`past_key_values`)

        # only keep the last token of `input_ids` if the state is passed along.
        if state is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and state is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs["state"] = state
        model_inputs["use_cache"] = use_cache
        return model_inputs

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        state: Optional[list[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple, RwkvCausalLMOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model adds `state_input_ids + input_ids` as context).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        use_cache (`bool`, *optional*):
            If set to `True`, the last state is returned and can be used to quickly generate the next logits.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        rwkv_outputs = self.rwkv(
            input_ids,
            inputs_embeds=inputs_embeds,
            state=state,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = rwkv_outputs[0]

        logits = self.head(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + rwkv_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return RwkvCausalLMOutput(
            loss=loss,
            logits=logits,
            state=rwkv_outputs.state,
            hidden_states=rwkv_outputs.hidden_states,
            attentions=rwkv_outputs.attentions,
        )


__all__ = ["RwkvForCausalLM", "RwkvModel", "RwkvPreTrainedModel"]
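Minimal usage sketch for the class above (not part of the file; the checkpoint name "RWKV/rwkv-4-169m-pile" is an assumed example id, any RWKV checkpoint on the Hub works the same way):

import torch
from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

inputs = tokenizer("The RWKV architecture", return_tensors="pt")

# One forward pass returns logits plus the recurrent `state` (there is no key/value cache).
with torch.no_grad():
    outputs = model(**inputs, use_cache=True)
state = outputs.state

# The state is fed back to continue from where the prompt left off, one token at a time.
next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
with torch.no_grad():
    outputs = model(input_ids=next_token, state=state, use_cache=True)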
phivenv/Lib/site-packages/transformers/models/sam/__init__.py
ADDED
@@ -0,0 +1,31 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_sam import *
    from .image_processing_sam import *
    from .image_processing_sam_fast import *
    from .modeling_sam import *
    from .modeling_tf_sam import *
    from .processing_sam import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
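Minimal sketch of what the lazy `__init__.py` above does at runtime (illustrative, not part of the file; the printed class name is an implementation detail of transformers):

import transformers.models.sam as sam_module

print(type(sam_module).__name__)  # "_LazyModule" -- the submodules are not imported yet
cfg_cls = sam_module.SamConfig    # attribute access triggers the real import of configuration_sam
print(cfg_cls.model_type)         # "sam"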
phivenv/Lib/site-packages/transformers/models/sam/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (641 Bytes).

phivenv/Lib/site-packages/transformers/models/sam/__pycache__/configuration_sam.cpython-39.pyc
ADDED
Binary file (12.7 kB).

phivenv/Lib/site-packages/transformers/models/sam/__pycache__/image_processing_sam.cpython-39.pyc
ADDED
Binary file (49.9 kB).

phivenv/Lib/site-packages/transformers/models/sam/__pycache__/image_processing_sam_fast.cpython-39.pyc
ADDED
Binary file (26.4 kB).

phivenv/Lib/site-packages/transformers/models/sam/__pycache__/modeling_sam.cpython-39.pyc
ADDED
Binary file (46.7 kB).

phivenv/Lib/site-packages/transformers/models/sam/__pycache__/modeling_tf_sam.cpython-39.pyc
ADDED
Binary file (55.4 kB).

phivenv/Lib/site-packages/transformers/models/sam/__pycache__/processing_sam.cpython-39.pyc
ADDED
Binary file (8.46 kB).
phivenv/Lib/site-packages/transformers/models/sam/configuration_sam.py
ADDED
@@ -0,0 +1,337 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SAM model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class SamPromptEncoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SamPromptEncoder`]. The [`SamPromptEncoder`]
    module is used to encode the input 2D points and bounding boxes. Instantiating a configuration defaults will yield
    a similar configuration to that of the SAM-vit-h
    [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 256):
            Dimensionality of the hidden states.
        image_size (`int`, *optional*, defaults to 1024):
            The expected output resolution of the image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        mask_input_channels (`int`, *optional*, defaults to 16):
            The number of channels to be fed to the `MaskDecoder` module.
        num_point_embeddings (`int`, *optional*, defaults to 4):
            The number of point embeddings to be used.
        hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function in the encoder and pooler.
    """

    base_config_key = "prompt_encoder_config"

    def __init__(
        self,
        hidden_size=256,
        image_size=1024,
        patch_size=16,
        mask_input_channels=16,
        num_point_embeddings=4,
        hidden_act="gelu",
        layer_norm_eps=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.image_embedding_size = image_size // patch_size
        self.mask_input_channels = mask_input_channels
        self.num_point_embeddings = num_point_embeddings
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps

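# Editorial note (illustrative comment, not in the upstream file): with the defaults above,
# image_embedding_size = image_size // patch_size = 1024 // 16 = 64, i.e. the prompt encoder
# positions points and boxes on a 64 x 64 grid that matches the vision encoder's patch grid.
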
class SamMaskDecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SamMaskDecoder`]. It is used to instantiate a SAM
    mask decoder to the specified arguments, defining the model architecture. Instantiating a configuration defaults
    will yield a similar configuration to that of the SAM-vit-h
    [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 256):
            Dimensionality of the hidden states.
        hidden_act (`str`, *optional*, defaults to `"relu"`):
            The non-linear activation function used inside the `SamMaskDecoder` module.
        mlp_dim (`int`, *optional*, defaults to 2048):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 2):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        attention_downsample_rate (`int`, *optional*, defaults to 2):
            The downsampling rate of the attention layer.
        num_multimask_outputs (`int`, *optional*, defaults to 3):
            The number of outputs from the `SamMaskDecoder` module. In the Segment Anything paper, this is set to 3.
        iou_head_depth (`int`, *optional*, defaults to 3):
            The number of layers in the IoU head module.
        iou_head_hidden_dim (`int`, *optional*, defaults to 256):
            The dimensionality of the hidden states in the IoU head module.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.

    """

    base_config_key = "mask_decoder_config"

    def __init__(
        self,
        hidden_size=256,
        hidden_act="relu",
        mlp_dim=2048,
        num_hidden_layers=2,
        num_attention_heads=8,
        attention_downsample_rate=2,
        num_multimask_outputs=3,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
        layer_norm_eps=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.mlp_dim = mlp_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_downsample_rate = attention_downsample_rate
        self.num_multimask_outputs = num_multimask_outputs
        self.iou_head_depth = iou_head_depth
        self.iou_head_hidden_dim = iou_head_hidden_dim
        self.layer_norm_eps = layer_norm_eps


class SamVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SamVisionModel`]. It is used to instantiate a SAM
    vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
    defaults will yield a similar configuration to that of the SAM ViT-h
    [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        output_channels (`int`, *optional*, defaults to 256):
            Dimensionality of the output channels in the Patch Encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input image.
        image_size (`int`, *optional*, defaults to 1024):
            Expected resolution. Target size of the resized input image.
        patch_size (`int`, *optional*, defaults to 16):
            Size of the patches to be extracted from the input image.
        hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string)
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 1e-10):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to query, key, value projections.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of mlp hidden dim to embedding dim.
        use_abs_pos (`bool`, *optional*, defaults to `True`):
            Whether to use absolute position embedding.
        use_rel_pos (`bool`, *optional*, defaults to `True`):
            Whether to use relative position embedding.
        window_size (`int`, *optional*, defaults to 14):
            Window size for relative position.
        global_attn_indexes (`list[int]`, *optional*, defaults to `[2, 5, 8, 11]`):
            The indexes of the global attention layers.
        num_pos_feats (`int`, *optional*, defaults to 128):
            The dimensionality of the position embedding.
        mlp_dim (`int`, *optional*):
            The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio *
            hidden_size`.

    Example:

    ```python
    >>> from transformers import (
    ...     SamVisionConfig,
    ...     SamVisionModel,
    ... )

    >>> # Initializing a SamVisionConfig with `"facebook/sam-vit-huge"` style configuration
    >>> configuration = SamVisionConfig()

    >>> # Initializing a SamVisionModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
    >>> model = SamVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    base_config_key = "vision_config"
    model_type = "sam_vision_model"

    def __init__(
        self,
        hidden_size=768,
        output_channels=256,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=1024,
        patch_size=16,
        hidden_act="gelu",
        layer_norm_eps=1e-06,
        attention_dropout=0.0,
        initializer_range=1e-10,
        qkv_bias=True,
        mlp_ratio=4.0,
        use_abs_pos=True,
        use_rel_pos=True,
        window_size=14,
        global_attn_indexes=[2, 5, 8, 11],
        num_pos_feats=128,
        mlp_dim=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.output_channels = output_channels
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.qkv_bias = qkv_bias
        self.mlp_ratio = mlp_ratio
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        self.window_size = window_size
        self.global_attn_indexes = global_attn_indexes
        self.num_pos_feats = num_pos_feats
        self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim

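# Editorial note (illustrative comment, not in the upstream file): when `mlp_dim` is left as None,
# the vision encoder uses int(hidden_size * mlp_ratio); with the defaults above that is
# int(768 * 4.0) = 3072 for each transformer block's feed-forward layer.
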
+
class SamConfig(PretrainedConfig):
|
| 255 |
+
r"""
|
| 256 |
+
[`SamConfig`] is the configuration class to store the configuration of a [`SamModel`]. It is used to instantiate a
|
| 257 |
+
SAM model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder
|
| 258 |
+
configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
|
| 259 |
+
SAM-ViT-H [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
|
| 260 |
+
|
| 261 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 262 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 263 |
+
|
| 264 |
+
Args:
|
| 265 |
+
vision_config (Union[`dict`, `SamVisionConfig`], *optional*):
|
| 266 |
+
Dictionary of configuration options used to initialize [`SamVisionConfig`].
|
| 267 |
+
prompt_encoder_config (Union[`dict`, `SamPromptEncoderConfig`], *optional*):
|
| 268 |
+
Dictionary of configuration options used to initialize [`SamPromptEncoderConfig`].
|
| 269 |
+
mask_decoder_config (Union[`dict`, `SamMaskDecoderConfig`], *optional*):
|
| 270 |
+
Dictionary of configuration options used to initialize [`SamMaskDecoderConfig`].
|
| 271 |
+
|
| 272 |
+
kwargs (*optional*):
|
| 273 |
+
Dictionary of keyword arguments.
|
| 274 |
+
|
| 275 |
+
Example:
|
| 276 |
+
|
| 277 |
+
```python
|
| 278 |
+
>>> from transformers import (
|
| 279 |
+
... SamVisionConfig,
|
| 280 |
+
... SamPromptEncoderConfig,
|
| 281 |
+
... SamMaskDecoderConfig,
|
| 282 |
+
... SamModel,
|
| 283 |
+
... )
|
| 284 |
+
|
| 285 |
+
>>> # Initializing a SamConfig with `"facebook/sam-vit-huge"` style configuration
|
| 286 |
+
>>> configuration = SamConfig()
|
| 287 |
+
|
| 288 |
+
>>> # Initializing a SamModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
|
| 289 |
+
>>> model = SamModel(configuration)
|
| 290 |
+
|
| 291 |
+
>>> # Accessing the model configuration
|
| 292 |
+
>>> configuration = model.config
|
| 293 |
+
|
| 294 |
+
>>> # We can also initialize a SamConfig from a SamVisionConfig, SamPromptEncoderConfig, and SamMaskDecoderConfig
|
| 295 |
+
|
| 296 |
+
>>> # Initializing SAM vision, SAM Q-Former and language model configurations
|
| 297 |
+
>>> vision_config = SamVisionConfig()
|
| 298 |
+
>>> prompt_encoder_config = SamPromptEncoderConfig()
|
| 299 |
+
>>> mask_decoder_config = SamMaskDecoderConfig()
|
| 300 |
+
|
| 301 |
+
>>> config = SamConfig(vision_config, prompt_encoder_config, mask_decoder_config)
|
| 302 |
+
```"""
|
| 303 |
+
|
| 304 |
+
model_type = "sam"
|
| 305 |
+
sub_configs = {
|
| 306 |
+
"prompt_encoder_config": SamPromptEncoderConfig,
|
| 307 |
+
"mask_decoder_config": SamMaskDecoderConfig,
|
| 308 |
+
"vision_config": SamVisionConfig,
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
def __init__(
|
| 312 |
+
self,
|
| 313 |
+
vision_config=None,
|
| 314 |
+
prompt_encoder_config=None,
|
| 315 |
+
mask_decoder_config=None,
|
| 316 |
+
initializer_range=0.02,
|
| 317 |
+
**kwargs,
|
| 318 |
+
):
|
| 319 |
+
super().__init__(**kwargs)
|
| 320 |
+
vision_config = vision_config if vision_config is not None else {}
|
| 321 |
+
prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
|
| 322 |
+
mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
|
| 323 |
+
|
| 324 |
+
if isinstance(vision_config, SamVisionConfig):
|
| 325 |
+
vision_config = vision_config.to_dict()
|
| 326 |
+
if isinstance(prompt_encoder_config, SamPromptEncoderConfig):
|
| 327 |
+
prompt_encoder_config = prompt_encoder_config.to_dict()
|
| 328 |
+
if isinstance(mask_decoder_config, SamMaskDecoderConfig):
|
| 329 |
+
mask_decoder_config = mask_decoder_config.to_dict()
|
| 330 |
+
|
| 331 |
+
self.vision_config = SamVisionConfig(**vision_config)
|
| 332 |
+
self.prompt_encoder_config = SamPromptEncoderConfig(**prompt_encoder_config)
|
| 333 |
+
self.mask_decoder_config = SamMaskDecoderConfig(**mask_decoder_config)
|
| 334 |
+
self.initializer_range = initializer_range
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
__all__ = ["SamConfig", "SamMaskDecoderConfig", "SamPromptEncoderConfig", "SamVisionConfig"]
|
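Minimal sketch of the dict-to-sub-config coercion done in `SamConfig.__init__` above (illustrative, not part of the file):

from transformers import SamConfig

config = SamConfig(vision_config={"num_hidden_layers": 2})
print(type(config.vision_config).__name__)       # "SamVisionConfig", built from the plain dict
print(config.vision_config.num_hidden_layers)    # 2
print(config.prompt_encoder_config.hidden_size)  # 256, from the prompt encoder defaults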
phivenv/Lib/site-packages/transformers/models/sam/image_processing_sam.py
ADDED
@@ -0,0 +1,1499 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for SAM."""

import math
from copy import deepcopy
from itertools import product
from typing import Any, Optional, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from ...utils import (
    TensorType,
    filter_out_non_signature_kwargs,
    is_tf_available,
    is_torch_available,
    is_torchvision_available,
    logging,
    requires_backends,
)


if is_torch_available():
    import torch
    import torch.nn.functional as F

if is_torchvision_available():
    from torchvision.ops.boxes import batched_nms

if is_tf_available():
    import tensorflow as tf
    from tensorflow.experimental import numpy as tnp

    from ...tf_utils import flatten, shape_list

logger = logging.get_logger(__name__)


class SamImageProcessor(BaseImageProcessor):
    r"""
    Constructs a SAM image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`):
            Size of the output image after resizing. Resizes the longest edge of the image to match
            `size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `size` parameter in the
            `preprocess` method.
        mask_size (`dict`, *optional*, defaults to `{"longest_edge": 256}`):
            Size of the output segmentation map after resizing. Resizes the longest edge of the image to match
            `size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `mask_size` parameter
            in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
            `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the
            `preprocess` method.
        pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
            Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess`
            method.
        mask_pad_size (`dict`, *optional*, defaults to `{"height": 256, "width": 256}`):
            Size of the output segmentation map after padding. Can be overridden by the `mask_pad_size` parameter in
            the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[dict[str, int]] = None,
        mask_size: Optional[dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_pad: bool = True,
        pad_size: Optional[int] = None,
        mask_pad_size: Optional[int] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"longest_edge": 1024}
        size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size

        pad_size = pad_size if pad_size is not None else {"height": 1024, "width": 1024}
        pad_size = get_size_dict(pad_size, default_to_square=True)

        mask_size = mask_size if mask_size is not None else {"longest_edge": 256}
        mask_size = (
            get_size_dict(max_size=mask_size, default_to_square=False)
            if not isinstance(mask_size, dict)
            else mask_size
        )

        mask_pad_size = mask_pad_size if mask_pad_size is not None else {"height": 256, "width": 256}
        mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)

        self.do_resize = do_resize
        self.size = size
        self.mask_size = mask_size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
        self.do_pad = do_pad
        self.pad_size = pad_size
        self.mask_pad_size = mask_pad_size
        self.do_convert_rgb = do_convert_rgb

    def pad_image(
        self,
        image: np.ndarray,
        pad_size: dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Pad an image to `(pad_size["height"], pad_size["width"])` with zeros to the right and bottom.

        Args:
            image (`np.ndarray`):
                Image to pad.
            pad_size (`dict[str, int]`):
                Size of the output image after padding.
            data_format (`str` or `ChannelDimension`, *optional*):
                The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
                `data_format` of the `image` will be used.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        output_height, output_width = pad_size["height"], pad_size["width"]
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)

        pad_width = output_width - input_width
        pad_height = output_height - input_height

        padded_image = pad(
            image,
            ((0, pad_height), (0, pad_width)),
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
        return padded_image

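    # Worked example for `pad_image` (illustrative comment, not in the upstream file): a resized
    # image of 768 x 1024 (height x width) with the default pad_size {"height": 1024, "width": 1024}
    # gets pad_height = 1024 - 768 = 256 and pad_width = 1024 - 1024 = 0, so 256 rows of zeros are
    # appended at the bottom and nothing on the right; pixel coordinates of the original content
    # are therefore unchanged by the padding step.
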
    def _get_preprocess_shape(self, old_shape: tuple[int, int], longest_edge: int):
        """
        Compute the output size given input size and target long side length.
        """
        oldh, oldw = old_shape
        scale = longest_edge * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        newh = int(newh + 0.5)
        neww = int(neww + 0.5)
        return (newh, neww)

    def resize(
        self,
        image: np.ndarray,
        size: dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`dict[str, int]`):
                Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest
                edge of the image will be resized to the specified size, while the other edge will be resized to
                maintain the aspect ratio.
            resample:
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Returns:
            `np.ndarray`: The resized image.
        """
        size = get_size_dict(size)
        if "longest_edge" not in size:
            raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
        input_size = get_image_size(image, channel_dim=input_data_format)
        output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])
        return resize(
            image,
            size=(output_height, output_width),
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

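    # Worked example for `_get_preprocess_shape` / `resize` (illustrative comment, not in the
    # upstream file): old_shape = (768, 2048) with longest_edge = 1024 gives
    # scale = 1024 / 2048 = 0.5, so newh = int(768 * 0.5 + 0.5) = 384 and
    # neww = int(2048 * 0.5 + 0.5) = 1024. The longest side lands exactly on 1024 and the aspect
    # ratio is preserved; `pad_image` then brings the result up to the square 1024 x 1024 input.
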
+
**kwargs,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
def _preprocess(
|
| 264 |
+
self,
|
| 265 |
+
image: ImageInput,
|
| 266 |
+
do_resize: bool,
|
| 267 |
+
do_rescale: bool,
|
| 268 |
+
do_normalize: bool,
|
| 269 |
+
size: Optional[dict[str, int]] = None,
|
| 270 |
+
resample: PILImageResampling = None,
|
| 271 |
+
rescale_factor: Optional[float] = None,
|
| 272 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 273 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 274 |
+
do_pad: Optional[bool] = None,
|
| 275 |
+
pad_size: Optional[dict[str, int]] = None,
|
| 276 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 277 |
+
):
|
| 278 |
+
if do_resize:
|
| 279 |
+
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
| 280 |
+
reshaped_input_size = get_image_size(image, channel_dim=input_data_format)
|
| 281 |
+
|
| 282 |
+
if do_rescale:
|
| 283 |
+
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
| 284 |
+
|
| 285 |
+
if do_normalize:
|
| 286 |
+
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 287 |
+
|
| 288 |
+
if do_pad:
|
| 289 |
+
image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format)
|
| 290 |
+
|
| 291 |
+
return image, reshaped_input_size
|
| 292 |
+
|
| 293 |
+
def _preprocess_image(
|
| 294 |
+
self,
|
| 295 |
+
image: ImageInput,
|
| 296 |
+
do_resize: Optional[bool] = None,
|
| 297 |
+
size: Optional[dict[str, int]] = None,
|
| 298 |
+
resample: PILImageResampling = None,
|
| 299 |
+
do_rescale: Optional[bool] = None,
|
| 300 |
+
rescale_factor: Optional[float] = None,
|
| 301 |
+
do_normalize: Optional[bool] = None,
|
| 302 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 303 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 304 |
+
do_pad: Optional[bool] = None,
|
| 305 |
+
pad_size: Optional[dict[str, int]] = None,
|
| 306 |
+
do_convert_rgb: Optional[bool] = None,
|
| 307 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 308 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 309 |
+
) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
|
| 310 |
+
# PIL RGBA images are converted to RGB
|
| 311 |
+
if do_convert_rgb:
|
| 312 |
+
image = convert_to_rgb(image)
|
| 313 |
+
|
| 314 |
+
# All transformations expect numpy arrays.
|
| 315 |
+
image = to_numpy_array(image)
|
| 316 |
+
|
| 317 |
+
if do_rescale and is_scaled_image(image):
|
| 318 |
+
logger.warning_once(
|
| 319 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 320 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
if input_data_format is None:
|
| 324 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 325 |
+
|
| 326 |
+
original_size = get_image_size(image, channel_dim=input_data_format)
|
| 327 |
+
|
| 328 |
+
image, reshaped_input_size = self._preprocess(
|
| 329 |
+
image=image,
|
| 330 |
+
do_resize=do_resize,
|
| 331 |
+
size=size,
|
| 332 |
+
resample=resample,
|
| 333 |
+
do_rescale=do_rescale,
|
| 334 |
+
rescale_factor=rescale_factor,
|
| 335 |
+
do_normalize=do_normalize,
|
| 336 |
+
image_mean=image_mean,
|
| 337 |
+
image_std=image_std,
|
| 338 |
+
do_pad=do_pad,
|
| 339 |
+
pad_size=pad_size,
|
| 340 |
+
input_data_format=input_data_format,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
if data_format is not None:
|
| 344 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 345 |
+
|
| 346 |
+
return image, original_size, reshaped_input_size
|
| 347 |
+
|
| 348 |
+
def _preprocess_mask(
|
| 349 |
+
self,
|
| 350 |
+
segmentation_map: ImageInput,
|
| 351 |
+
do_resize: Optional[bool] = None,
|
| 352 |
+
mask_size: Optional[dict[str, int]] = None,
|
| 353 |
+
do_pad: Optional[bool] = None,
|
| 354 |
+
mask_pad_size: Optional[dict[str, int]] = None,
|
| 355 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 356 |
+
) -> np.ndarray:
|
| 357 |
+
segmentation_map = to_numpy_array(segmentation_map)
|
| 358 |
+
|
| 359 |
+
# Add channel dimension if missing - needed for certain transformations
|
| 360 |
+
if segmentation_map.ndim == 2:
|
| 361 |
+
added_channel_dim = True
|
| 362 |
+
segmentation_map = segmentation_map[None, ...]
|
| 363 |
+
input_data_format = ChannelDimension.FIRST
|
| 364 |
+
else:
|
| 365 |
+
added_channel_dim = False
|
| 366 |
+
if input_data_format is None:
|
| 367 |
+
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
|
| 368 |
+
|
| 369 |
+
original_size = get_image_size(segmentation_map, channel_dim=input_data_format)
|
| 370 |
+
|
| 371 |
+
segmentation_map, _ = self._preprocess(
|
| 372 |
+
image=segmentation_map,
|
| 373 |
+
do_resize=do_resize,
|
| 374 |
+
size=mask_size,
|
| 375 |
+
resample=PILImageResampling.NEAREST,
|
| 376 |
+
do_rescale=False,
|
| 377 |
+
do_normalize=False,
|
| 378 |
+
do_pad=do_pad,
|
| 379 |
+
pad_size=mask_pad_size,
|
| 380 |
+
input_data_format=input_data_format,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# Remove extra channel dimension if added for processing
|
| 384 |
+
if added_channel_dim:
|
| 385 |
+
segmentation_map = segmentation_map.squeeze(0)
|
| 386 |
+
segmentation_map = segmentation_map.astype(np.int64)
|
| 387 |
+
|
| 388 |
+
return segmentation_map, original_size
|
| 389 |
+
|
| 390 |
+
def __call__(self, images, segmentation_maps=None, **kwargs):
|
| 391 |
+
# Overrides the `__call__` method of the `BaseImageProcessor` class such that the images and segmentation maps can both
|
| 392 |
+
# be passed in as positional arguments.
|
| 393 |
+
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
|
| 394 |
+
|
| 395 |
+
@filter_out_non_signature_kwargs()
|
| 396 |
+
def preprocess(
|
| 397 |
+
self,
|
| 398 |
+
images: ImageInput,
|
| 399 |
+
segmentation_maps: Optional[ImageInput] = None,
|
| 400 |
+
do_resize: Optional[bool] = None,
|
| 401 |
+
size: Optional[dict[str, int]] = None,
|
| 402 |
+
mask_size: Optional[dict[str, int]] = None,
|
| 403 |
+
resample: Optional["PILImageResampling"] = None,
|
| 404 |
+
do_rescale: Optional[bool] = None,
|
| 405 |
+
rescale_factor: Optional[Union[int, float]] = None,
|
| 406 |
+
do_normalize: Optional[bool] = None,
|
| 407 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 408 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 409 |
+
do_pad: Optional[bool] = None,
|
| 410 |
+
pad_size: Optional[dict[str, int]] = None,
|
| 411 |
+
mask_pad_size: Optional[dict[str, int]] = None,
|
| 412 |
+
do_convert_rgb: Optional[bool] = None,
|
| 413 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 414 |
+
data_format: ChannelDimension = ChannelDimension.FIRST,
|
| 415 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 416 |
+
):
|
| 417 |
+
"""
|
| 418 |
+
Preprocess an image or batch of images.
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
images (`ImageInput`):
|
| 422 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 423 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 424 |
+
segmentation_maps (`ImageInput`, *optional*):
|
| 425 |
+
Segmentation map to preprocess.
|
| 426 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 427 |
+
Whether to resize the image.
|
| 428 |
+
size (`dict[str, int]`, *optional*, defaults to `self.size`):
|
| 429 |
+
Controls the size of the image after `resize`. The longest edge of the image is resized to
|
| 430 |
+
`size["longest_edge"]` whilst preserving the aspect ratio.
|
| 431 |
+
mask_size (`dict[str, int]`, *optional*, defaults to `self.mask_size`):
|
| 432 |
+
Controls the size of the segmentation map after `resize`. The longest edge of the image is resized to
|
| 433 |
+
`size["longest_edge"]` whilst preserving the aspect ratio.
|
| 434 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
| 435 |
+
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
|
| 436 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 437 |
+
Whether to rescale the image pixel values by rescaling factor.
|
| 438 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
|
| 439 |
+
Rescale factor to apply to the image pixel values.
|
| 440 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 441 |
+
Whether to normalize the image.
|
| 442 |
+
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
| 443 |
+
Image mean to normalize the image by if `do_normalize` is set to `True`.
|
| 444 |
+
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
| 445 |
+
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
|
| 446 |
+
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
| 447 |
+
Whether to pad the image.
|
| 448 |
+
pad_size (`dict[str, int]`, *optional*, defaults to `self.pad_size`):
|
| 449 |
+
Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and
|
| 450 |
+
`pad_size["width"]` if `do_pad` is set to `True`.
|
| 451 |
+
mask_pad_size (`dict[str, int]`, *optional*, defaults to `self.mask_pad_size`):
|
| 452 |
+
Controls the size of the padding applied to the segmentation map. The image is padded to
|
| 453 |
+
`mask_pad_size["height"]` and `mask_pad_size["width"]` if `do_pad` is set to `True`.
|
| 454 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 455 |
+
Whether to convert the image to RGB.
|
| 456 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 457 |
+
The type of tensors to return. Can be one of:
|
| 458 |
+
- Unset: Return a list of `np.ndarray`.
|
| 459 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 460 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 461 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 462 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 463 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 464 |
+
The channel dimension format for the output image. Can be one of:
|
| 465 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 466 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 467 |
+
- Unset: Use the channel dimension format of the input image.
|
| 468 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 469 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 470 |
+
from the input image. Can be one of:
|
| 471 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 472 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 473 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 474 |
+
"""
|
| 475 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 476 |
+
size = size if size is not None else self.size
|
| 477 |
+
size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
|
| 478 |
+
mask_size = mask_size if mask_size is not None else self.mask_size
|
| 479 |
+
mask_size = (
|
| 480 |
+
get_size_dict(max_size=mask_size, default_to_square=False)
|
| 481 |
+
if not isinstance(mask_size, dict)
|
| 482 |
+
else mask_size
|
| 483 |
+
)
|
| 484 |
+
resample = resample if resample is not None else self.resample
|
| 485 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 486 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 487 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 488 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 489 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 490 |
+
do_pad = do_pad if do_pad is not None else self.do_pad
|
| 491 |
+
pad_size = pad_size if pad_size is not None else self.pad_size
|
| 492 |
+
pad_size = get_size_dict(pad_size, default_to_square=True)
|
| 493 |
+
mask_pad_size = mask_pad_size if mask_pad_size is not None else self.mask_pad_size
|
| 494 |
+
mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)
|
| 495 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 496 |
+
|
| 497 |
+
images = make_list_of_images(images)
|
| 498 |
+
|
| 499 |
+
if not valid_images(images):
|
| 500 |
+
raise ValueError(
|
| 501 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 502 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
if segmentation_maps is not None:
|
| 506 |
+
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
|
| 507 |
+
|
| 508 |
+
if not valid_images(segmentation_maps):
|
| 509 |
+
raise ValueError(
|
| 510 |
+
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 511 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 512 |
+
)
|
| 513 |
+
validate_preprocess_arguments(
|
| 514 |
+
do_rescale=do_rescale,
|
| 515 |
+
rescale_factor=rescale_factor,
|
| 516 |
+
do_normalize=do_normalize,
|
| 517 |
+
image_mean=image_mean,
|
| 518 |
+
image_std=image_std,
|
| 519 |
+
do_pad=do_pad,
|
| 520 |
+
size_divisibility=pad_size, # Here _preprocess needs do_pad and pad_size.
|
| 521 |
+
do_resize=do_resize,
|
| 522 |
+
size=size,
|
| 523 |
+
resample=resample,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
images, original_sizes, reshaped_input_sizes = zip(
|
| 527 |
+
*(
|
| 528 |
+
self._preprocess_image(
|
| 529 |
+
image=img,
|
| 530 |
+
do_resize=do_resize,
|
| 531 |
+
size=size,
|
| 532 |
+
resample=resample,
|
| 533 |
+
do_rescale=do_rescale,
|
| 534 |
+
rescale_factor=rescale_factor,
|
| 535 |
+
do_normalize=do_normalize,
|
| 536 |
+
image_mean=image_mean,
|
| 537 |
+
image_std=image_std,
|
| 538 |
+
do_pad=do_pad,
|
| 539 |
+
pad_size=pad_size,
|
| 540 |
+
do_convert_rgb=do_convert_rgb,
|
| 541 |
+
data_format=data_format,
|
| 542 |
+
input_data_format=input_data_format,
|
| 543 |
+
)
|
| 544 |
+
for img in images
|
| 545 |
+
)
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
data = {
|
| 549 |
+
"pixel_values": images,
|
| 550 |
+
"original_sizes": original_sizes,
|
| 551 |
+
"reshaped_input_sizes": reshaped_input_sizes,
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
if segmentation_maps is not None:
|
| 555 |
+
segmentation_maps, original_mask_sizes = zip(
|
| 556 |
+
*(
|
| 557 |
+
self._preprocess_mask(
|
| 558 |
+
segmentation_map=mask,
|
| 559 |
+
do_resize=do_resize,
|
| 560 |
+
mask_size=mask_size,
|
| 561 |
+
do_pad=do_pad,
|
| 562 |
+
mask_pad_size=mask_pad_size,
|
| 563 |
+
input_data_format=input_data_format,
|
| 564 |
+
)
|
| 565 |
+
for mask in segmentation_maps
|
| 566 |
+
)
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
# masks should start out the same size as input images
|
| 570 |
+
assert all(
|
| 571 |
+
original_im_size == original_mask_size
|
| 572 |
+
for original_im_size, original_mask_size in zip(original_sizes, original_mask_sizes)
|
| 573 |
+
), "Segmentation maps should be the same size as input images."
|
| 574 |
+
|
| 575 |
+
data["labels"] = segmentation_maps
|
| 576 |
+
|
| 577 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 578 |
+
|
| 579 |
+
def post_process_masks(
|
| 580 |
+
self,
|
| 581 |
+
masks,
|
| 582 |
+
original_sizes,
|
| 583 |
+
reshaped_input_sizes,
|
| 584 |
+
mask_threshold=0.0,
|
| 585 |
+
binarize=True,
|
| 586 |
+
pad_size=None,
|
| 587 |
+
return_tensors="pt",
|
| 588 |
+
):
|
| 589 |
+
"""
|
| 590 |
+
Remove padding and upscale masks to the original image size.
|
| 591 |
+
|
| 592 |
+
Args:
|
| 593 |
+
masks (`Union[list[torch.Tensor], list[np.ndarray], list[tf.Tensor]]`):
|
| 594 |
+
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
|
| 595 |
+
original_sizes (`Union[torch.Tensor, tf.Tensor, list[tuple[int,int]]]`):
|
| 596 |
+
The original sizes of each image before it was resized to the model's expected input shape, in (height,
|
| 597 |
+
width) format.
|
| 598 |
+
reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, list[tuple[int,int]]]`):
|
| 599 |
+
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
|
| 600 |
+
mask_threshold (`float`, *optional*, defaults to 0.0):
|
| 601 |
+
The threshold to use for binarizing the masks.
|
| 602 |
+
binarize (`bool`, *optional*, defaults to `True`):
|
| 603 |
+
Whether to binarize the masks.
|
| 604 |
+
pad_size (`int`, *optional*, defaults to `self.pad_size`):
|
| 605 |
+
The target size the images were padded to before being passed to the model. If None, the target size is
|
| 606 |
+
assumed to be the processor's `pad_size`.
|
| 607 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
| 608 |
+
If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors.
|
| 609 |
+
Returns:
|
| 610 |
+
(`Union[torch.Tensor, tf.Tensor]`): Batched masks in (batch_size, num_channels, height, width) format, where
|
| 611 |
+
(height, width) is given by original_size.
|
| 612 |
+
"""
|
| 613 |
+
if return_tensors == "pt":
|
| 614 |
+
return self._post_process_masks_pt(
|
| 615 |
+
masks=masks,
|
| 616 |
+
original_sizes=original_sizes,
|
| 617 |
+
reshaped_input_sizes=reshaped_input_sizes,
|
| 618 |
+
mask_threshold=mask_threshold,
|
| 619 |
+
binarize=binarize,
|
| 620 |
+
pad_size=pad_size,
|
| 621 |
+
)
|
| 622 |
+
elif return_tensors == "tf":
|
| 623 |
+
return self._post_process_masks_tf(
|
| 624 |
+
masks=masks,
|
| 625 |
+
original_sizes=original_sizes,
|
| 626 |
+
reshaped_input_sizes=reshaped_input_sizes,
|
| 627 |
+
mask_threshold=mask_threshold,
|
| 628 |
+
binarize=binarize,
|
| 629 |
+
pad_size=pad_size,
|
| 630 |
+
)
|
| 631 |
+
else:
|
| 632 |
+
raise ValueError("return_tensors must be either 'pt' or 'tf'")
|
| 633 |
+
|
| 634 |
+
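# Illustrative usage sketch (added comment, not part of the original file). A typical
# call pairs `preprocess` with a SAM model and then maps the low-resolution mask logits
# back to the input resolution; the model attribute names below are assumptions for
# illustration only:
#
#     inputs = processor(raw_image, input_points=[[[450, 600]]], return_tensors="pt")
#     outputs = model(**inputs)
#     masks = processor.post_process_masks(
#         outputs.pred_masks.cpu(),
#         inputs["original_sizes"].cpu(),
#         inputs["reshaped_input_sizes"].cpu(),
#     )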
def _post_process_masks_pt(
|
| 635 |
+
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
|
| 636 |
+
):
|
| 637 |
+
"""
|
| 638 |
+
Remove padding and upscale masks to the original image size.
|
| 639 |
+
|
| 640 |
+
Args:
|
| 641 |
+
masks (`Union[list[torch.Tensor], list[np.ndarray]]`):
|
| 642 |
+
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
|
| 643 |
+
original_sizes (`Union[torch.Tensor, list[tuple[int,int]]]`):
|
| 644 |
+
The original sizes of each image before it was resized to the model's expected input shape, in (height,
|
| 645 |
+
width) format.
|
| 646 |
+
reshaped_input_sizes (`Union[torch.Tensor, list[tuple[int,int]]]`):
|
| 647 |
+
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
|
| 648 |
+
mask_threshold (`float`, *optional*, defaults to 0.0):
|
| 649 |
+
The threshold to use for binarizing the masks.
|
| 650 |
+
binarize (`bool`, *optional*, defaults to `True`):
|
| 651 |
+
Whether to binarize the masks.
|
| 652 |
+
pad_size (`int`, *optional*, defaults to `self.pad_size`):
|
| 653 |
+
The target size the images were padded to before being passed to the model. If None, the target size is
|
| 654 |
+
assumed to be the processor's `pad_size`.
|
| 655 |
+
Returns:
|
| 656 |
+
(`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
|
| 657 |
+
is given by original_size.
|
| 658 |
+
"""
|
| 659 |
+
requires_backends(self, ["torch"])
|
| 660 |
+
pad_size = self.pad_size if pad_size is None else pad_size
|
| 661 |
+
target_image_size = (pad_size["height"], pad_size["width"])
|
| 662 |
+
if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
|
| 663 |
+
original_sizes = original_sizes.tolist()
|
| 664 |
+
if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
|
| 665 |
+
reshaped_input_sizes = reshaped_input_sizes.tolist()
|
| 666 |
+
output_masks = []
|
| 667 |
+
for i, original_size in enumerate(original_sizes):
|
| 668 |
+
if isinstance(masks[i], np.ndarray):
|
| 669 |
+
masks[i] = torch.from_numpy(masks[i])
|
| 670 |
+
elif not isinstance(masks[i], torch.Tensor):
|
| 671 |
+
raise TypeError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
|
| 672 |
+
interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
|
| 673 |
+
interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
|
| 674 |
+
interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
|
| 675 |
+
if binarize:
|
| 676 |
+
interpolated_mask = interpolated_mask > mask_threshold
|
| 677 |
+
output_masks.append(interpolated_mask)
|
| 678 |
+
|
| 679 |
+
return output_masks
|
| 680 |
+
|
| 681 |
+
def _post_process_masks_tf(
|
| 682 |
+
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
|
| 683 |
+
):
|
| 684 |
+
"""
|
| 685 |
+
Remove padding and upscale masks to the original image size.
|
| 686 |
+
|
| 687 |
+
Args:
|
| 688 |
+
masks (`tf.Tensor`):
|
| 689 |
+
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
|
| 690 |
+
original_sizes (`tf.Tensor`):
|
| 691 |
+
The original size of the images before resizing for input to the model, in (height, width) format.
|
| 692 |
+
reshaped_input_sizes (`tf.Tensor`):
|
| 693 |
+
The size of the image input to the model, in (height, width) format. Used to remove padding.
|
| 694 |
+
mask_threshold (`float`, *optional*, defaults to 0.0):
|
| 695 |
+
The threshold to use for binarizing the masks.
|
| 696 |
+
binarize (`bool`, *optional*, defaults to `True`):
|
| 697 |
+
Whether to binarize the masks.
|
| 698 |
+
pad_size (`int`, *optional*, defaults to `self.pad_size`):
|
| 699 |
+
The target size the images were padded to before being passed to the model. If None, the target size is
|
| 700 |
+
assumed to be the processor's `pad_size`.
|
| 701 |
+
Returns:
|
| 702 |
+
(`tf.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width) is
|
| 703 |
+
given by original_size.
|
| 704 |
+
"""
|
| 705 |
+
requires_backends(self, ["tf"])
|
| 706 |
+
pad_size = self.pad_size if pad_size is None else pad_size
|
| 707 |
+
target_image_size = (pad_size["height"], pad_size["width"])
|
| 708 |
+
|
| 709 |
+
output_masks = []
|
| 710 |
+
for i, original_size in enumerate(original_sizes):
|
| 711 |
+
# tf.image expects NHWC, we transpose the NCHW inputs for it
|
| 712 |
+
mask = tf.transpose(masks[i], perm=[0, 2, 3, 1])
|
| 713 |
+
interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear")
|
| 714 |
+
interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :]
|
| 715 |
+
interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear")
|
| 716 |
+
if binarize:
|
| 717 |
+
interpolated_mask = interpolated_mask > mask_threshold
|
| 718 |
+
# And then we transpose them back at the end
|
| 719 |
+
output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2]))
|
| 720 |
+
|
| 721 |
+
return output_masks
|
| 722 |
+
|
| 723 |
+
def post_process_for_mask_generation(
|
| 724 |
+
self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"
|
| 725 |
+
):
|
| 726 |
+
"""
|
| 727 |
+
Post-processes the generated masks by running the Non Maximum Suppression algorithm on the predicted masks.
|
| 728 |
+
|
| 729 |
+
Args:
|
| 730 |
+
all_masks (`Union[list[torch.Tensor], list[tf.Tensor]]`):
|
| 731 |
+
List of all predicted segmentation masks
|
| 732 |
+
all_scores (`Union[list[torch.Tensor], list[tf.Tensor]]`):
|
| 733 |
+
List of all predicted iou scores
|
| 734 |
+
all_boxes (`Union[list[torch.Tensor], list[tf.Tensor]]`):
|
| 735 |
+
List of all bounding boxes of the predicted masks
|
| 736 |
+
crops_nms_thresh (`float`):
|
| 737 |
+
Threshold for NMS (Non Maximum Suppression) algorithm.
|
| 738 |
+
return_tensors (`str`, *optional*, defaults to `pt`):
|
| 739 |
+
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
| 740 |
+
"""
|
| 741 |
+
if return_tensors == "pt":
|
| 742 |
+
return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
|
| 743 |
+
elif return_tensors == "tf":
|
| 744 |
+
return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh)
|
| 745 |
+
|
| 746 |
+
def generate_crop_boxes(
|
| 747 |
+
self,
|
| 748 |
+
image,
|
| 749 |
+
target_size,
|
| 750 |
+
crop_n_layers: int = 0,
|
| 751 |
+
overlap_ratio: float = 512 / 1500,
|
| 752 |
+
points_per_crop: Optional[int] = 32,
|
| 753 |
+
crop_n_points_downscale_factor: Optional[list[int]] = 1,
|
| 754 |
+
device: Optional["torch.device"] = None,
|
| 755 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 756 |
+
return_tensors: str = "pt",
|
| 757 |
+
):
|
| 758 |
+
"""
|
| 759 |
+
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
| 760 |
+
|
| 761 |
+
Args:
|
| 762 |
+
image (`np.array`):
|
| 763 |
+
Input original image
|
| 764 |
+
target_size (`int`):
|
| 765 |
+
Target size of the resized image
|
| 766 |
+
crop_n_layers (`int`, *optional*, defaults to 0):
|
| 767 |
+
If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
|
| 768 |
+
each layer has 2**i_layer number of image crops.
|
| 769 |
+
overlap_ratio (`float`, *optional*, defaults to 512/1500):
|
| 770 |
+
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
|
| 771 |
+
the image length. Later layers with more crops scale down this overlap.
|
| 772 |
+
points_per_crop (`int`, *optional*, defaults to 32):
|
| 773 |
+
Number of points to sample from each crop.
|
| 774 |
+
crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1):
|
| 775 |
+
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
| 776 |
+
device (`torch.device`, *optional*, defaults to None):
|
| 777 |
+
Device to use for the computation. If None, cpu will be used.
|
| 778 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 779 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 780 |
+
return_tensors (`str`, *optional*, defaults to `pt`):
|
| 781 |
+
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
| 782 |
+
"""
|
| 783 |
+
crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
|
| 784 |
+
image,
|
| 785 |
+
target_size,
|
| 786 |
+
crop_n_layers,
|
| 787 |
+
overlap_ratio,
|
| 788 |
+
points_per_crop,
|
| 789 |
+
crop_n_points_downscale_factor,
|
| 790 |
+
input_data_format,
|
| 791 |
+
)
|
| 792 |
+
if return_tensors == "pt":
|
| 793 |
+
if device is None:
|
| 794 |
+
device = torch.device("cpu")
|
| 795 |
+
crop_boxes = torch.tensor(crop_boxes, device=device)
|
| 796 |
+
points_per_crop = torch.tensor(points_per_crop, device=device)
|
| 797 |
+
# cropped_images stays as np
|
| 798 |
+
input_labels = torch.tensor(input_labels, device=device)
|
| 799 |
+
|
| 800 |
+
elif return_tensors == "tf":
|
| 801 |
+
if device is not None:
|
| 802 |
+
raise ValueError("device is not a supported argument when return_tensors is tf!")
|
| 803 |
+
crop_boxes = tf.convert_to_tensor(crop_boxes)
|
| 804 |
+
points_per_crop = tf.convert_to_tensor(points_per_crop)
|
| 805 |
+
# cropped_images stays as np
|
| 806 |
+
input_labels = tf.convert_to_tensor(input_labels)
|
| 807 |
+
else:
|
| 808 |
+
raise ValueError("return_tensors must be either 'pt' or 'tf'.")
|
| 809 |
+
return crop_boxes, points_per_crop, cropped_images, input_labels
|
| 810 |
+
|
| 811 |
+
def filter_masks(
|
| 812 |
+
self,
|
| 813 |
+
masks,
|
| 814 |
+
iou_scores,
|
| 815 |
+
original_size,
|
| 816 |
+
cropped_box_image,
|
| 817 |
+
pred_iou_thresh=0.88,
|
| 818 |
+
stability_score_thresh=0.95,
|
| 819 |
+
mask_threshold=0,
|
| 820 |
+
stability_score_offset=1,
|
| 821 |
+
return_tensors="pt",
|
| 822 |
+
):
|
| 823 |
+
"""
|
| 824 |
+
Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
|
| 825 |
+
that the IoU scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
|
| 826 |
+
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
|
| 827 |
+
bounding boxes and pads the predicted masks if necessary.
|
| 828 |
+
|
| 829 |
+
Args:
|
| 830 |
+
masks (`Union[torch.Tensor, tf.Tensor]`):
|
| 831 |
+
Input masks.
|
| 832 |
+
iou_scores (`Union[torch.Tensor, tf.Tensor]`):
|
| 833 |
+
List of IoU scores.
|
| 834 |
+
original_size (`tuple[int,int]`):
|
| 835 |
+
Size of the original image.
|
| 836 |
+
cropped_box_image (`np.array`):
|
| 837 |
+
The cropped image.
|
| 838 |
+
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
| 839 |
+
The threshold for the iou scores.
|
| 840 |
+
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
| 841 |
+
The threshold for the stability score.
|
| 842 |
+
mask_threshold (`float`, *optional*, defaults to 0):
|
| 843 |
+
The threshold for the predicted masks.
|
| 844 |
+
stability_score_offset (`float`, *optional*, defaults to 1):
|
| 845 |
+
The offset for the stability score used in the `_compute_stability_score` method.
|
| 846 |
+
return_tensors (`str`, *optional*, defaults to `pt`):
|
| 847 |
+
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
| 848 |
+
"""
|
| 849 |
+
if return_tensors == "pt":
|
| 850 |
+
return self._filter_masks_pt(
|
| 851 |
+
masks=masks,
|
| 852 |
+
iou_scores=iou_scores,
|
| 853 |
+
original_size=original_size,
|
| 854 |
+
cropped_box_image=cropped_box_image,
|
| 855 |
+
pred_iou_thresh=pred_iou_thresh,
|
| 856 |
+
stability_score_thresh=stability_score_thresh,
|
| 857 |
+
mask_threshold=mask_threshold,
|
| 858 |
+
stability_score_offset=stability_score_offset,
|
| 859 |
+
)
|
| 860 |
+
elif return_tensors == "tf":
|
| 861 |
+
return self._filter_masks_tf(
|
| 862 |
+
masks=masks,
|
| 863 |
+
iou_scores=iou_scores,
|
| 864 |
+
original_size=original_size,
|
| 865 |
+
cropped_box_image=cropped_box_image,
|
| 866 |
+
pred_iou_thresh=pred_iou_thresh,
|
| 867 |
+
stability_score_thresh=stability_score_thresh,
|
| 868 |
+
mask_threshold=mask_threshold,
|
| 869 |
+
stability_score_offset=stability_score_offset,
|
| 870 |
+
)
|
| 871 |
+
|
| 872 |
+
def _filter_masks_pt(
|
| 873 |
+
self,
|
| 874 |
+
masks,
|
| 875 |
+
iou_scores,
|
| 876 |
+
original_size,
|
| 877 |
+
cropped_box_image,
|
| 878 |
+
pred_iou_thresh=0.88,
|
| 879 |
+
stability_score_thresh=0.95,
|
| 880 |
+
mask_threshold=0,
|
| 881 |
+
stability_score_offset=1,
|
| 882 |
+
):
|
| 883 |
+
"""
|
| 884 |
+
Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
|
| 885 |
+
that the IoU scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
|
| 886 |
+
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
|
| 887 |
+
bounding boxes and pads the predicted masks if necessary.
|
| 888 |
+
|
| 889 |
+
Args:
|
| 890 |
+
masks (`torch.Tensor`):
|
| 891 |
+
Input masks.
|
| 892 |
+
iou_scores (`torch.Tensor`):
|
| 893 |
+
List of IoU scores.
|
| 894 |
+
original_size (`tuple[int,int]`):
|
| 895 |
+
Size of the original image.
|
| 896 |
+
cropped_box_image (`np.array`):
|
| 897 |
+
The cropped image.
|
| 898 |
+
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
| 899 |
+
The threshold for the iou scores.
|
| 900 |
+
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
| 901 |
+
The threshold for the stability score.
|
| 902 |
+
mask_threshold (`float`, *optional*, defaults to 0):
|
| 903 |
+
The threshold for the predicted masks.
|
| 904 |
+
stability_score_offset (`float`, *optional*, defaults to 1):
|
| 905 |
+
The offset for the stability score used in the `_compute_stability_score` method.
|
| 906 |
+
|
| 907 |
+
"""
|
| 908 |
+
requires_backends(self, ["torch"])
|
| 909 |
+
original_height, original_width = original_size
|
| 910 |
+
iou_scores = iou_scores.flatten(0, 1)
|
| 911 |
+
masks = masks.flatten(0, 1)
|
| 912 |
+
|
| 913 |
+
if masks.shape[0] != iou_scores.shape[0]:
|
| 914 |
+
raise ValueError("masks and iou_scores must have the same batch size.")
|
| 915 |
+
|
| 916 |
+
if masks.device != iou_scores.device:
|
| 917 |
+
iou_scores = iou_scores.to(masks.device)
|
| 918 |
+
|
| 919 |
+
batch_size = masks.shape[0]
|
| 920 |
+
|
| 921 |
+
keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
|
| 922 |
+
|
| 923 |
+
if pred_iou_thresh > 0.0:
|
| 924 |
+
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
|
| 925 |
+
|
| 926 |
+
# compute stability score
|
| 927 |
+
if stability_score_thresh > 0.0:
|
| 928 |
+
stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset)
|
| 929 |
+
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
|
| 930 |
+
|
| 931 |
+
scores = iou_scores[keep_mask]
|
| 932 |
+
masks = masks[keep_mask]
|
| 933 |
+
|
| 934 |
+
# binarize masks
|
| 935 |
+
masks = masks > mask_threshold
|
| 936 |
+
converted_boxes = _batched_mask_to_box(masks)
|
| 937 |
+
|
| 938 |
+
keep_mask = ~_is_box_near_crop_edge(
|
| 939 |
+
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
|
| 940 |
+
)
|
| 941 |
+
|
| 942 |
+
scores = scores[keep_mask]
|
| 943 |
+
masks = masks[keep_mask]
|
| 944 |
+
converted_boxes = converted_boxes[keep_mask]
|
| 945 |
+
|
| 946 |
+
masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
|
| 947 |
+
# conversion to rle is necessary to run non-maximum suppression
|
| 948 |
+
masks = _mask_to_rle_pytorch(masks)
|
| 949 |
+
|
| 950 |
+
return masks, scores, converted_boxes
|
| 951 |
+
|
| 952 |
+
def _filter_masks_tf(
|
| 953 |
+
self,
|
| 954 |
+
masks,
|
| 955 |
+
iou_scores,
|
| 956 |
+
original_size,
|
| 957 |
+
cropped_box_image,
|
| 958 |
+
pred_iou_thresh=0.88,
|
| 959 |
+
stability_score_thresh=0.95,
|
| 960 |
+
mask_threshold=0,
|
| 961 |
+
stability_score_offset=1,
|
| 962 |
+
):
|
| 963 |
+
"""
|
| 964 |
+
Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
|
| 965 |
+
that the IoU scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
|
| 966 |
+
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
|
| 967 |
+
bounding boxes and pads the predicted masks if necessary.
|
| 968 |
+
|
| 969 |
+
Args:
|
| 970 |
+
masks (`tf.Tensor`):
|
| 971 |
+
Input masks.
|
| 972 |
+
iou_scores (`tf.Tensor`):
|
| 973 |
+
List of IoU scores.
|
| 974 |
+
original_size (`tuple[int,int]`):
|
| 975 |
+
Size of the original image.
|
| 976 |
+
cropped_box_image (`np.array`):
|
| 977 |
+
The cropped image.
|
| 978 |
+
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
| 979 |
+
The threshold for the iou scores.
|
| 980 |
+
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
| 981 |
+
The threshold for the stability score.
|
| 982 |
+
mask_threshold (`float`, *optional*, defaults to 0):
|
| 983 |
+
The threshold for the predicted masks.
|
| 984 |
+
stability_score_offset (`float`, *optional*, defaults to 1):
|
| 985 |
+
The offset for the stability score used in the `_compute_stability_score` method.
|
| 986 |
+
|
| 987 |
+
"""
|
| 988 |
+
requires_backends(self, ["tf"])
|
| 989 |
+
original_height, original_width = original_size
|
| 990 |
+
iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], *iou_scores.shape[2:]])
|
| 991 |
+
masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], *masks.shape[2:]])
|
| 992 |
+
|
| 993 |
+
if masks.shape[0] != iou_scores.shape[0]:
|
| 994 |
+
raise ValueError("masks and iou_scores must have the same batch size.")
|
| 995 |
+
|
| 996 |
+
batch_size = masks.shape[0]
|
| 997 |
+
|
| 998 |
+
keep_mask = tf.ones(batch_size, dtype=tf.bool)
|
| 999 |
+
|
| 1000 |
+
if pred_iou_thresh > 0.0:
|
| 1001 |
+
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
|
| 1002 |
+
|
| 1003 |
+
# compute stability score
|
| 1004 |
+
if stability_score_thresh > 0.0:
|
| 1005 |
+
stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset)
|
| 1006 |
+
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
|
| 1007 |
+
|
| 1008 |
+
scores = iou_scores[keep_mask]
|
| 1009 |
+
masks = masks[keep_mask]
|
| 1010 |
+
|
| 1011 |
+
# binarize masks
|
| 1012 |
+
masks = masks > mask_threshold
|
| 1013 |
+
converted_boxes = _batched_mask_to_box_tf(masks)
|
| 1014 |
+
|
| 1015 |
+
keep_mask = ~_is_box_near_crop_edge_tf(
|
| 1016 |
+
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
|
| 1017 |
+
)
|
| 1018 |
+
|
| 1019 |
+
scores = scores[keep_mask]
|
| 1020 |
+
masks = masks[keep_mask]
|
| 1021 |
+
converted_boxes = converted_boxes[keep_mask]
|
| 1022 |
+
|
| 1023 |
+
masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width)
|
| 1024 |
+
# conversion to rle is necessary to run non-maximum suppression
|
| 1025 |
+
masks = _mask_to_rle_tf(masks)
|
| 1026 |
+
|
| 1027 |
+
return masks, scores, converted_boxes
|
| 1028 |
+
|
| 1029 |
+
|
| 1030 |
+
def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
|
| 1031 |
+
# One mask is always contained inside the other.
|
| 1032 |
+
# Save memory by preventing unnecessary cast to torch.int64
|
| 1033 |
+
intersections = (
|
| 1034 |
+
(masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
| 1035 |
+
)
|
| 1036 |
+
unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
| 1037 |
+
stability_scores = intersections / unions
|
| 1038 |
+
return stability_scores
|
| 1039 |
+
|
| 1040 |
+
|
| 1041 |
+
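# Minimal sketch (added for illustration, not part of the original file; assumes torch is
# installed): the stability score is the IoU between the mask binarized at
# `threshold + offset` and at `threshold - offset`, so confident masks barely change when
# the cutoff moves while noisy ones collapse.
def _example_stability_score():
    confident = torch.full((16, 16), 10.0)               # strong, saturated mask logits
    uncertain = torch.empty(16, 16).uniform_(-0.5, 0.5)  # logits hovering around zero
    masks = torch.stack([confident, uncertain])          # fake (2, 16, 16) mask logits
    scores = _compute_stability_score_pt(masks, mask_threshold=0.0, stability_score_offset=1.0)
    # The confident map scores 1.0; the noisy map scores 0.0 with this offset.
    return scores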
def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int):
|
| 1042 |
+
# Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure
|
| 1043 |
+
# we get the right division results.
|
| 1044 |
+
intersections = tf.math.count_nonzero(
|
| 1045 |
+
masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32
|
| 1046 |
+
)
|
| 1047 |
+
unions = tf.math.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32)
|
| 1048 |
+
stability_scores = intersections / unions
|
| 1049 |
+
return stability_scores
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
def _build_point_grid(n_per_side: int) -> np.ndarray:
|
| 1053 |
+
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
|
| 1054 |
+
offset = 1 / (2 * n_per_side)
|
| 1055 |
+
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
| 1056 |
+
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
|
| 1057 |
+
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
|
| 1058 |
+
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
|
| 1059 |
+
return points
|
| 1060 |
+
|
| 1061 |
+
|
| 1062 |
+
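# Small illustration (added, not part of the original file): the grid lives on
# [0, 1] x [0, 1], so multiplying by the crop size in (width, height) order yields pixel
# coordinates, which is what `_generate_crop_images` does further below.
def _example_point_grid():
    grid = _build_point_grid(4)                  # (16, 2) points, evenly spaced in [0, 1]
    assert grid.shape == (16, 2)
    pixel_points = grid * np.array([640, 480])   # scale to a 640x480 (width, height) crop
    return pixel_points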
def _normalize_coordinates(
|
| 1063 |
+
target_size: int, coords: np.ndarray, original_size: tuple[int, int], is_bounding_box=False
|
| 1064 |
+
) -> np.ndarray:
|
| 1065 |
+
"""
|
| 1066 |
+
Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width)
|
| 1067 |
+
format.
|
| 1068 |
+
"""
|
| 1069 |
+
old_height, old_width = original_size
|
| 1070 |
+
|
| 1071 |
+
scale = target_size * 1.0 / max(old_height, old_width)
|
| 1072 |
+
new_height, new_width = old_height * scale, old_width * scale
|
| 1073 |
+
new_width = int(new_width + 0.5)
|
| 1074 |
+
new_height = int(new_height + 0.5)
|
| 1075 |
+
|
| 1076 |
+
coords = deepcopy(coords).astype(float)
|
| 1077 |
+
|
| 1078 |
+
if is_bounding_box:
|
| 1079 |
+
coords = coords.reshape(-1, 2, 2)
|
| 1080 |
+
|
| 1081 |
+
coords[..., 0] = coords[..., 0] * (new_width / old_width)
|
| 1082 |
+
coords[..., 1] = coords[..., 1] * (new_height / old_height)
|
| 1083 |
+
|
| 1084 |
+
if is_bounding_box:
|
| 1085 |
+
coords = coords.reshape(-1, 4)
|
| 1086 |
+
|
| 1087 |
+
return coords
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
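# Worked example (added for illustration, not part of the original file): with
# `target_size=1024` and an original (height, width) of (480, 640), the longest side maps
# to 1024 pixels, so a point at (x=320, y=240) lands at (512, 384) after resizing.
def _example_normalize_coordinates():
    point = np.array([[320.0, 240.0]])  # (x, y) in the original 640x480 image
    resized = _normalize_coordinates(1024, point, original_size=(480, 640))
    # scale = 1024 / 640 = 1.6 -> resized image is (768, 1024) and the point becomes (512, 384)
    return resized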
def _generate_crop_boxes(
|
| 1091 |
+
image,
|
| 1092 |
+
target_size: int, # Is it tuple here?
|
| 1093 |
+
crop_n_layers: int = 0,
|
| 1094 |
+
overlap_ratio: float = 512 / 1500,
|
| 1095 |
+
points_per_crop: Optional[int] = 32,
|
| 1096 |
+
crop_n_points_downscale_factor: Optional[list[int]] = 1,
|
| 1097 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 1098 |
+
) -> tuple[list[list[int]], list[int]]:
|
| 1099 |
+
"""
|
| 1100 |
+
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
| 1101 |
+
|
| 1102 |
+
Args:
|
| 1103 |
+
image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
|
| 1104 |
+
Image to generate crops for.
|
| 1105 |
+
target_size (`int`):
|
| 1106 |
+
Size of the smallest crop.
|
| 1107 |
+
crop_n_layers (`int`, *optional*):
|
| 1108 |
+
If `crop_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers
|
| 1109 |
+
to run, where each layer has 2**i_layer number of image crops.
|
| 1110 |
+
overlap_ratio (`int`, *optional*):
|
| 1111 |
+
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
|
| 1112 |
+
image length. Later layers with more crops scale down this overlap.
|
| 1113 |
+
points_per_crop (`int`, *optional*):
|
| 1114 |
+
Number of points to sample per crop.
|
| 1115 |
+
crop_n_points_downscale_factor (`int`, *optional*):
|
| 1116 |
+
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
| 1117 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 1118 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 1119 |
+
"""
|
| 1120 |
+
|
| 1121 |
+
if isinstance(image, list):
|
| 1122 |
+
raise TypeError("Only one image is allowed for crop generation.")
|
| 1123 |
+
image = to_numpy_array(image)
|
| 1124 |
+
original_size = get_image_size(image, input_data_format)
|
| 1125 |
+
|
| 1126 |
+
points_grid = []
|
| 1127 |
+
for i in range(crop_n_layers + 1):
|
| 1128 |
+
n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
|
| 1129 |
+
points_grid.append(_build_point_grid(n_points))
|
| 1130 |
+
|
| 1131 |
+
crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
|
| 1132 |
+
|
| 1133 |
+
cropped_images, point_grid_per_crop = _generate_crop_images(
|
| 1134 |
+
crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format
|
| 1135 |
+
)
|
| 1136 |
+
crop_boxes = np.array(crop_boxes)
|
| 1137 |
+
crop_boxes = crop_boxes.astype(np.float32)
|
| 1138 |
+
points_per_crop = np.array([point_grid_per_crop])
|
| 1139 |
+
points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
|
| 1140 |
+
|
| 1141 |
+
input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64)
|
| 1142 |
+
|
| 1143 |
+
return crop_boxes, points_per_crop, cropped_images, input_labels
|
| 1144 |
+
|
| 1145 |
+
|
| 1146 |
+
def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
|
| 1147 |
+
"""
|
| 1148 |
+
Generates 2 ** (layer_idx + 1) crops per side for each layer in crop_n_layers. Crops are in the XYWH format: The XYWH format
|
| 1149 |
+
consists of the following required indices:
|
| 1150 |
+
- X: X coordinate of the top left of the bounding box
|
| 1151 |
+
- Y: Y coordinate of the top left of the bounding box
|
| 1152 |
+
- W: width of the bounding box
|
| 1153 |
+
- H: height of the bounding box
|
| 1154 |
+
"""
|
| 1155 |
+
crop_boxes, layer_idxs = [], []
|
| 1156 |
+
im_height, im_width = original_size
|
| 1157 |
+
short_side = min(im_height, im_width)
|
| 1158 |
+
|
| 1159 |
+
# Original image
|
| 1160 |
+
crop_boxes.append([0, 0, im_width, im_height])
|
| 1161 |
+
layer_idxs.append(0)
|
| 1162 |
+
for i_layer in range(crop_n_layers):
|
| 1163 |
+
n_crops_per_side = 2 ** (i_layer + 1)
|
| 1164 |
+
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
|
| 1165 |
+
|
| 1166 |
+
crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
|
| 1167 |
+
crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
|
| 1168 |
+
|
| 1169 |
+
crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
|
| 1170 |
+
crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
|
| 1171 |
+
|
| 1172 |
+
for left, top in product(crop_box_x0, crop_box_y0):
|
| 1173 |
+
box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
|
| 1174 |
+
crop_boxes.append(box)
|
| 1175 |
+
layer_idxs.append(i_layer + 1)
|
| 1176 |
+
|
| 1177 |
+
return crop_boxes, layer_idxs
|
| 1178 |
+
|
| 1179 |
+
|
| 1180 |
+
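# Quick sanity check (added for illustration, not part of the original file): the first
# box is always the full image, and layer i contributes (2**i)**2 overlapping crops.
def _example_per_layer_crops():
    crop_boxes, layer_idxs = _generate_per_layer_crops(
        crop_n_layers=1, overlap_ratio=512 / 1500, original_size=(480, 640)
    )
    # 1 full-image box plus 2x2 crops for the single extra layer
    assert len(crop_boxes) == 5 and layer_idxs == [0, 1, 1, 1, 1]
    return crop_boxes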
def _generate_crop_images(
|
| 1181 |
+
crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
|
| 1182 |
+
):
|
| 1183 |
+
"""
|
| 1184 |
+
Takes as input the bounding boxes that are used to crop the image. Based on the crops, the corresponding points are
|
| 1185 |
+
also passed.
|
| 1186 |
+
"""
|
| 1187 |
+
cropped_images = []
|
| 1188 |
+
total_points_per_crop = []
|
| 1189 |
+
for i, crop_box in enumerate(crop_boxes):
|
| 1190 |
+
left, top, right, bottom = crop_box
|
| 1191 |
+
|
| 1192 |
+
channel_dim = infer_channel_dimension_format(image, input_data_format)
|
| 1193 |
+
if channel_dim == ChannelDimension.LAST:
|
| 1194 |
+
cropped_im = image[top:bottom, left:right, :]
|
| 1195 |
+
else:
|
| 1196 |
+
cropped_im = image[:, top:bottom, left:right]
|
| 1197 |
+
|
| 1198 |
+
cropped_images.append(cropped_im)
|
| 1199 |
+
|
| 1200 |
+
cropped_im_size = get_image_size(cropped_im, channel_dim)
|
| 1201 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
| 1202 |
+
|
| 1203 |
+
points = points_grid[layer_idxs[i]] * points_scale
|
| 1204 |
+
normalized_points = _normalize_coordinates(target_size, points, original_size)
|
| 1205 |
+
total_points_per_crop.append(normalized_points)
|
| 1206 |
+
|
| 1207 |
+
return cropped_images, total_points_per_crop
|
| 1208 |
+
|
| 1209 |
+
|
| 1210 |
+
def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int):
|
| 1211 |
+
left, top, right, bottom = crop_box
|
| 1212 |
+
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
|
| 1213 |
+
return masks
|
| 1214 |
+
# Coordinate transform masks
|
| 1215 |
+
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
|
| 1216 |
+
pad = (left, pad_x - left, top, pad_y - top)
|
| 1217 |
+
return torch.nn.functional.pad(masks, pad, value=0)
|
| 1218 |
+
|
| 1219 |
+
|
| 1220 |
+
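# Small illustration (added, not part of the original file; assumes torch is installed):
# a mask predicted on a crop is zero-padded so it lines up with the full image again.
def _example_pad_masks():
    crop_box = [1, 1, 4, 3]            # (left, top, right, bottom) -> a 3x2 crop
    crop_masks = torch.ones(1, 2, 3)   # one mask the size of the crop
    full = _pad_masks(crop_masks, crop_box, orig_height=5, orig_width=6)
    assert full.shape == (1, 5, 6) and full[0, 1:3, 1:4].all()
    return full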
def _pad_masks_tf(masks, crop_box: list[int], orig_height: int, orig_width: int):
|
| 1221 |
+
left, top, right, bottom = crop_box
|
| 1222 |
+
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
|
| 1223 |
+
return masks
|
| 1224 |
+
# Coordinate transform masks
|
| 1225 |
+
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
|
| 1226 |
+
pad = (left, pad_x - left, top, pad_y - top)
|
| 1227 |
+
return tf.pad(masks, pad, constant_values=0)
|
| 1228 |
+
|
| 1229 |
+
|
| 1230 |
+
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
|
| 1231 |
+
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
| 1232 |
+
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
| 1233 |
+
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
|
| 1234 |
+
|
| 1235 |
+
left, top, _, _ = crop_box
|
| 1236 |
+
offset = torch.tensor([[left, top, left, top]], device=boxes.device)
|
| 1237 |
+
# Check if boxes has a channel dimension
|
| 1238 |
+
if len(boxes.shape) == 3:
|
| 1239 |
+
offset = offset.unsqueeze(1)
|
| 1240 |
+
boxes = (boxes + offset).float()
|
| 1241 |
+
|
| 1242 |
+
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
|
| 1243 |
+
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
|
| 1244 |
+
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
|
| 1245 |
+
return torch.any(near_crop_edge, dim=1)
|
| 1246 |
+
|
| 1247 |
+
|
| 1248 |
+
def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
|
| 1249 |
+
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
| 1250 |
+
crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
|
| 1251 |
+
orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)
|
| 1252 |
+
|
| 1253 |
+
left, top, _, _ = crop_box
|
| 1254 |
+
offset = tf.convert_to_tensor([[left, top, left, top]])
|
| 1255 |
+
# Check if boxes has a channel dimension
|
| 1256 |
+
if len(boxes.shape) == 3:
|
| 1257 |
+
offset = tf.expand_dims(offset, 1)
|
| 1258 |
+
boxes = tf.cast(boxes + offset, tf.float32)
|
| 1259 |
+
|
| 1260 |
+
near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
|
| 1261 |
+
near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
|
| 1262 |
+
near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
|
| 1263 |
+
return tf.reduce_any(near_crop_edge, axis=1)
|
| 1264 |
+
|
| 1265 |
+
|
| 1266 |
+
def _batched_mask_to_box(masks: "torch.Tensor"):
|
| 1267 |
+
"""
|
| 1268 |
+
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
|
| 1269 |
+
corresponds to the following required indices:
|
| 1270 |
+
- LEFT: left hand side of the bounding box
|
| 1271 |
+
- TOP: top of the bounding box
|
| 1272 |
+
- RIGHT: right of the bounding box
|
| 1273 |
+
- BOTTOM: bottom of the bounding box
|
| 1274 |
+
|
| 1275 |
+
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
|
| 1276 |
+
is channel_1 x channel_2 x ... x 4.
|
| 1277 |
+
|
| 1278 |
+
Args:
|
| 1279 |
+
- masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
|
| 1280 |
+
"""
|
| 1281 |
+
# torch.max below raises an error on empty inputs, just skip in this case
|
| 1282 |
+
|
| 1283 |
+
if torch.numel(masks) == 0:
|
| 1284 |
+
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
|
| 1285 |
+
|
| 1286 |
+
# Normalize shape to Cxheightxwidth
|
| 1287 |
+
shape = masks.shape
|
| 1288 |
+
height, width = shape[-2:]
|
| 1289 |
+
|
| 1290 |
+
# Get top and bottom edges
|
| 1291 |
+
in_height, _ = torch.max(masks, dim=-1)
|
| 1292 |
+
in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
|
| 1293 |
+
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
|
| 1294 |
+
in_height_coords = in_height_coords + height * (~in_height)
|
| 1295 |
+
top_edges, _ = torch.min(in_height_coords, dim=-1)
|
| 1296 |
+
|
| 1297 |
+
# Get left and right edges
|
| 1298 |
+
in_width, _ = torch.max(masks, dim=-2)
|
| 1299 |
+
in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
|
| 1300 |
+
right_edges, _ = torch.max(in_width_coords, dim=-1)
|
| 1301 |
+
in_width_coords = in_width_coords + width * (~in_width)
|
| 1302 |
+
left_edges, _ = torch.min(in_width_coords, dim=-1)
|
| 1303 |
+
|
| 1304 |
+
# If the mask is empty the right edge will be to the left of the left edge.
|
| 1305 |
+
# Replace these boxes with [0, 0, 0, 0]
|
| 1306 |
+
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
| 1307 |
+
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
|
| 1308 |
+
out = out * (~empty_filter).unsqueeze(-1)
|
| 1309 |
+
|
| 1310 |
+
# Return to original shape
|
| 1311 |
+
out = out.reshape(*shape[:-2], 4)
|
| 1312 |
+
return out
|
| 1313 |
+
|
| 1314 |
+
|
| 1315 |
+
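# Worked example (added for illustration, not part of the original file; assumes torch is
# installed): the returned boxes are XYXY pixel indices (left, top, right, bottom) of the
# mask's bounding box.
def _example_mask_to_box():
    masks = torch.zeros(1, 1, 5, 5, dtype=torch.bool)
    masks[0, 0, 1:4, 2:5] = True          # rows 1-3, columns 2-4 are foreground
    boxes = _batched_mask_to_box(masks)   # shape (1, 1, 4)
    assert boxes[0, 0].tolist() == [2, 1, 4, 3]
    return boxes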
def _batched_mask_to_box_tf(masks: "tf.Tensor"):
|
| 1316 |
+
"""
|
| 1317 |
+
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
|
| 1318 |
+
corresponds to the following required indices:
|
| 1319 |
+
- LEFT: left hand side of the bounding box
|
| 1320 |
+
- TOP: top of the bounding box
|
| 1321 |
+
- RIGHT: right of the bounding box
|
| 1322 |
+
- BOTTOM: bottom of the bounding box
|
| 1323 |
+
|
| 1324 |
+
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
|
| 1325 |
+
is channel_1 x channel_2 x ... x 4.
|
| 1326 |
+
|
| 1327 |
+
Args:
|
| 1328 |
+
- masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`)
|
| 1329 |
+
"""
|
| 1330 |
+
|
| 1331 |
+
if tf.size(masks) == 0:
|
| 1332 |
+
return tf.zeros([*masks.shape[:-2], 4])
|
| 1333 |
+
|
| 1334 |
+
# Normalize shape to Cxheightxwidth
|
| 1335 |
+
shape = shape_list(masks)
|
| 1336 |
+
height, width = shape[-2:]
|
| 1337 |
+
|
| 1338 |
+
# Get top and bottom edges
|
| 1339 |
+
in_height = tf.reduce_max(masks, axis=-1)
|
| 1340 |
+
in_height_coords = in_height * tf.range(height)[None, :]
|
| 1341 |
+
bottom_edges = tf.reduce_max(in_height_coords, axis=-1)
|
| 1342 |
+
in_height_coords = in_height_coords + height * (~in_height)
|
| 1343 |
+
top_edges = tf.reduce_min(in_height_coords, axis=-1)
|
| 1344 |
+
|
| 1345 |
+
# Get left and right edges
|
| 1346 |
+
in_width = tf.reduce_max(masks, axis=-2)
|
| 1347 |
+
in_width_coords = in_width * tf.range(width)[None, :]
|
| 1348 |
+
right_edges = tf.reduce_max(in_width_coords, axis=-1)
|
| 1349 |
+
in_width_coords = in_width_coords + width * (~in_width)
|
| 1350 |
+
left_edges = tf.reduce_min(in_width_coords, axis=-1)
|
| 1351 |
+
|
| 1352 |
+
# If the mask is empty the right edge will be to the left of the left edge.
|
| 1353 |
+
# Replace these boxes with [0, 0, 0, 0]
|
| 1354 |
+
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
| 1355 |
+
out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
|
| 1356 |
+
out = out * tf.expand_dims(~empty_filter, -1)
|
| 1357 |
+
|
| 1358 |
+
# Return to original shape
|
| 1359 |
+
out = tf.reshape(out, [*shape[:-2], 4])
|
| 1360 |
+
return out
|
| 1361 |
+
|
| 1362 |
+
|
| 1363 |
+
def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
|
| 1364 |
+
"""
|
| 1365 |
+
Encodes masks to run-length encoding (RLE), in the format expected by pycocotools.
|
| 1366 |
+
"""
|
| 1367 |
+
# Put in fortran order and flatten height and width
|
| 1368 |
+
batch_size, height, width = input_mask.shape
|
| 1369 |
+
input_mask = input_mask.permute(0, 2, 1).flatten(1)
|
| 1370 |
+
|
| 1371 |
+
# Compute change indices
|
| 1372 |
+
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
|
| 1373 |
+
change_indices = diff.nonzero()
|
| 1374 |
+
|
| 1375 |
+
# Encode run length
|
| 1376 |
+
out = []
|
| 1377 |
+
for i in range(batch_size):
|
| 1378 |
+
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
|
| 1379 |
+
if len(cur_idxs) == 0:
|
| 1380 |
+
# No changes => either all 0 or all 1
|
| 1381 |
+
# If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
|
| 1382 |
+
if input_mask[i, 0] == 0:
|
| 1383 |
+
out.append({"size": [height, width], "counts": [height * width]})
|
| 1384 |
+
else:
|
| 1385 |
+
out.append({"size": [height, width], "counts": [0, height * width]})
|
| 1386 |
+
continue
|
| 1387 |
+
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
| 1388 |
+
counts = [] if input_mask[i, 0] == 0 else [0]
|
| 1389 |
+
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
|
| 1390 |
+
out.append({"size": [height, width], "counts": counts})
|
| 1391 |
+
return out
|
| 1392 |
+
|
| 1393 |
+
|
| 1394 |
+
def _mask_to_rle_tf(input_mask: "tf.Tensor"):
|
| 1395 |
+
"""
|
| 1396 |
+
Encodes masks to run-length encoding (RLE), in the format expected by pycocotools.
|
| 1397 |
+
"""
|
| 1398 |
+
# Put in fortran order and flatten height and width
|
| 1399 |
+
batch_size, height, width = input_mask.shape
|
| 1400 |
+
input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1)
|
| 1401 |
+
|
| 1402 |
+
# Compute change indices
|
| 1403 |
+
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
|
| 1404 |
+
change_indices = tf.where(diff)
|
| 1405 |
+
|
| 1406 |
+
# Encode run length
|
| 1407 |
+
out = []
|
| 1408 |
+
for i in range(batch_size):
|
| 1409 |
+
cur_idxs = change_indices[change_indices[:, 0] == i][:, 1] + 1
|
| 1410 |
+
if len(cur_idxs) == 0:
|
| 1411 |
+
# No changes => either all 0 or all 1
|
| 1412 |
+
# If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
|
| 1413 |
+
if input_mask[i, 0] == 0:
|
| 1414 |
+
out.append({"size": [height, width], "counts": [height * width]})
|
| 1415 |
+
else:
|
| 1416 |
+
out.append({"size": [height, width], "counts": [0, height * width]})
|
| 1417 |
+
continue
|
| 1418 |
+
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
| 1419 |
+
counts = [] if input_mask[i, 0] == 0 else [0]
|
| 1420 |
+
counts += (
|
| 1421 |
+
[cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()]
|
| 1422 |
+
)
|
| 1423 |
+
out.append({"size": [height, width], "counts": counts})
|
| 1424 |
+
return out
|
| 1425 |
+
|
| 1426 |
+
|
| 1427 |
+
def _rle_to_mask(rle: dict[str, Any]) -> np.ndarray:
|
| 1428 |
+
"""Compute a binary mask from an uncompressed RLE."""
|
| 1429 |
+
height, width = rle["size"]
|
| 1430 |
+
mask = np.empty(height * width, dtype=bool)
|
| 1431 |
+
idx = 0
|
| 1432 |
+
parity = False
|
| 1433 |
+
for count in rle["counts"]:
|
| 1434 |
+
mask[idx : idx + count] = parity
|
| 1435 |
+
idx += count
|
| 1436 |
+
parity = not parity
|
| 1437 |
+
mask = mask.reshape(width, height)
|
| 1438 |
+
return mask.transpose() # Reshape to original shape
|
| 1439 |
+
|
| 1440 |
+
|
| 1441 |
+
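# Round-trip sketch (added, not part of the original file; assumes torch is installed):
# `_mask_to_rle_pytorch` and `_rle_to_mask` are inverses for binary masks.
def _example_rle_roundtrip():
    mask = torch.zeros(1, 4, 4, dtype=torch.bool)
    mask[0, 1:3, 1:3] = True
    rle = _mask_to_rle_pytorch(mask)[0]   # {"size": [4, 4], "counts": [5, 2, 2, 2, 5]}
    recovered = _rle_to_mask(rle)
    assert (recovered == mask[0].numpy()).all()
    return rle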
def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
|
| 1442 |
+
"""
|
| 1443 |
+
Perform NMS (Non Maximum Suppression) on the outputs.
|
| 1444 |
+
|
| 1445 |
+
Args:
|
| 1446 |
+
rle_masks (`torch.Tensor`):
|
| 1447 |
+
binary masks in the RLE format
|
| 1448 |
+
iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
|
| 1449 |
+
iou_scores predicted by the model
|
| 1450 |
+
mask_boxes (`torch.Tensor`):
|
| 1451 |
+
The bounding boxes corresponding to segmentation masks
|
| 1452 |
+
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
|
| 1453 |
+
NMS threshold.
|
| 1454 |
+
"""
|
| 1455 |
+
keep_by_nms = batched_nms(
|
| 1456 |
+
boxes=mask_boxes.float(),
|
| 1457 |
+
scores=iou_scores,
|
| 1458 |
+
idxs=torch.zeros(mask_boxes.shape[0]),
|
| 1459 |
+
iou_threshold=amg_crops_nms_thresh,
|
| 1460 |
+
)
|
| 1461 |
+
|
| 1462 |
+
iou_scores = iou_scores[keep_by_nms]
|
| 1463 |
+
rle_masks = [rle_masks[i] for i in keep_by_nms]
|
| 1464 |
+
mask_boxes = mask_boxes[keep_by_nms]
|
| 1465 |
+
masks = [_rle_to_mask(rle) for rle in rle_masks]
|
| 1466 |
+
|
| 1467 |
+
return masks, iou_scores, rle_masks, mask_boxes
|
| 1468 |
+
|
| 1469 |
+
|
| 1470 |
+
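# Illustration (added, not part of the original file; assumes torch and torchvision are
# installed so that `batched_nms` is available): an exact duplicate detection is removed
# by NMS, keeping the higher-scoring copy.
def _example_postprocess_for_mg():
    masks = torch.zeros(3, 8, 8, dtype=torch.bool)
    masks[0, 0:4, 0:4] = True
    masks[1, 0:4, 0:4] = True             # exact duplicate of mask 0, with a lower score
    masks[2, 4:8, 4:8] = True
    rle_masks = _mask_to_rle_pytorch(masks)
    boxes = _batched_mask_to_box(masks).float()
    scores = torch.tensor([0.9, 0.8, 0.95])
    kept_masks, kept_scores, kept_rles, kept_boxes = _postprocess_for_mg(
        rle_masks, scores, boxes, amg_crops_nms_thresh=0.7
    )
    return kept_scores                    # tensor([0.95, 0.90]); the duplicate is gone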
def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
|
| 1471 |
+
"""
|
| 1472 |
+
Perform NMS (Non Maximum Suppression) on the outputs.
|
| 1473 |
+
|
| 1474 |
+
Args:
|
| 1475 |
+
rle_masks (`tf.Tensor`):
|
| 1476 |
+
binary masks in the RLE format
|
| 1477 |
+
iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
|
| 1478 |
+
iou_scores predicted by the model
|
| 1479 |
+
mask_boxes (`tf.Tensor`):
|
| 1480 |
+
The bounding boxes corresponding to segmentation masks
|
| 1481 |
+
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
|
| 1482 |
+
NMS threshold.
|
| 1483 |
+
"""
|
| 1484 |
+
keep_by_nms = tf.image.combined_non_max_suppression(
|
| 1485 |
+
boxes=mask_boxes.float(),
|
| 1486 |
+
scores=iou_scores,
|
| 1487 |
+
idxs=torch.zeros(mask_boxes.shape[0]),
|
| 1488 |
+
iou_threshold=amg_crops_nms_thresh,
|
| 1489 |
+
)
|
| 1490 |
+
|
| 1491 |
+
iou_scores = iou_scores[keep_by_nms]
|
| 1492 |
+
rle_masks = [rle_masks[i] for i in keep_by_nms]
|
| 1493 |
+
mask_boxes = mask_boxes[keep_by_nms]
|
| 1494 |
+
masks = [_rle_to_mask(rle) for rle in rle_masks]
|
| 1495 |
+
|
| 1496 |
+
return masks, iou_scores, rle_masks, mask_boxes
|
| 1497 |
+
|
| 1498 |
+
|
| 1499 |
+
__all__ = ["SamImageProcessor"]
|
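# End-to-end sketch (added for illustration, not part of the original file; assumes torch
# is installed): run `preprocess` on a random image, then post-process fake low-resolution
# mask logits standing in for real model outputs.
def _example_end_to_end():
    processor = SamImageProcessor()
    image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    inputs = processor.preprocess(image, return_tensors="pt")
    # Stand-in for a SAM model's `pred_masks`: (batch, num_masks, 256, 256) logits
    fake_masks = [torch.randn(1, 3, 256, 256)]
    masks = processor.post_process_masks(
        fake_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
    )
    return masks[0].shape  # (1, 3, 480, 640): binarized masks at the original resolution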
phivenv/Lib/site-packages/transformers/models/sam/image_processing_sam_fast.py
ADDED
|
@@ -0,0 +1,829 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for SAM."""

import math
from copy import deepcopy
from itertools import product
from typing import Any, Optional, Union

import numpy as np
import torch

from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    SizeDict,
    pil_torch_interpolation_mapping,
)
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    auto_docstring,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
)


if is_torch_available():
    import torch
    from torch.nn import functional as F

if is_torchvision_v2_available():
    from torchvision.ops.boxes import batched_nms
    from torchvision.transforms.v2 import functional as F_t
elif is_torchvision_available():
    from torchvision.ops.boxes import batched_nms
    from torchvision.transforms import functional as F_t


class SamFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    r"""
    do_pad (`bool`, *optional*, defaults to `True`):
        Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
        method. If `True`, padding will be applied to the bottom and right of the image with zeros.
    pad_size (`dict[str, int]`, *optional*):
        The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
        provided for preprocessing.
    mask_size (`dict[str, int]`, *optional*):
        The size `{"longest_edge": int}` to resize the segmentation maps to.
    mask_pad_size (`dict[str, int]`, *optional*):
        The size `{"height": int, "width": int}` to pad the segmentation maps to. Must be larger than any segmentation
        map size provided for preprocessing.
    """

    mask_size: Optional[dict[str, int]]
    do_pad: Optional[bool]
    pad_size: Optional[dict[str, int]]
    mask_pad_size: Optional[dict[str, int]]


@auto_docstring
class SamImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_DEFAULT_MEAN
    image_std = IMAGENET_DEFAULT_STD
    size = {"longest_edge": 1024}
    mask_size = {"longest_edge": 256}
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True

    valid_kwargs = SamFastImageProcessorKwargs

    do_pad = True
    pad_size = {"height": 1024, "width": 1024}
    mask_pad_size = {"height": 256, "width": 256}

    def __init__(self, **kwargs: Unpack[SamFastImageProcessorKwargs]):
        super().__init__(**kwargs)

    def pad_image(self, images: "torch.Tensor", pad_size: SizeDict):
        """Pad images to the specified size."""
        output_height, output_width = pad_size.height, pad_size.width
        input_height, input_width = images.shape[-2:]
        pad_width = output_width - input_width
        pad_height = output_height - input_height
        padding = (0, 0, pad_width, pad_height)
        return F_t.pad(images, padding)

    def _get_preprocess_shape(self, old_shape: tuple[int, int], longest_edge: int):
        """
        Compute the output size given input size and target long side length.
        """
        oldh, oldw = old_shape
        scale = longest_edge * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        newh = int(newh + 0.5)
        neww = int(neww + 0.5)
        return (newh, neww)

    def resize(
        self, image: "torch.Tensor", size: SizeDict, interpolation: Optional["F_t.InterpolationMode"], **kwargs
    ) -> "torch.Tensor":
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`dict[str, int]`):
                Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest
                edge of the image will be resized to the specified size, while the other edge will be resized to
                maintain the aspect ratio.
            interpolation:
                `F_t.InterpolationMode` filter to use when resizing the image e.g. `F_t.InterpolationMode.BICUBIC`.

        Returns:
            `torch.Tensor`: The resized image.
        """
        if not size.longest_edge:
            raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
        input_size = image.shape[-2:]
        output_height, output_width = self._get_preprocess_shape(input_size, size.longest_edge)
        return super().resize(
            image, size=SizeDict(height=output_height, width=output_width), interpolation=interpolation, **kwargs
        )

    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        pad_size: Optional[SizeDict] = None,
        mask_size: Optional[SizeDict] = None,
        mask_pad_size: Optional[SizeDict] = None,
        default_to_square: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        data_format: Optional[ChannelDimension] = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated
        Can be overridden by subclasses to customize the processing of kwargs.
        """
        if kwargs is None:
            kwargs = {}
        if size is not None:
            size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
        if pad_size is not None:
            pad_size = SizeDict(**get_size_dict(pad_size, param_name="pad_size"))
        if mask_size is not None:
            mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size"))
        if mask_pad_size is not None:
            mask_pad_size = SizeDict(**get_size_dict(mask_pad_size, param_name="mask_pad_size"))
        if isinstance(image_mean, list):
            image_mean = tuple(image_mean)
        if isinstance(image_std, list):
            image_std = tuple(image_std)
        if data_format is None:
            data_format = ChannelDimension.FIRST

        kwargs["size"] = size
        kwargs["pad_size"] = pad_size
        kwargs["mask_size"] = mask_size
        kwargs["mask_pad_size"] = mask_pad_size
        kwargs["image_mean"] = image_mean
        kwargs["image_std"] = image_std
        kwargs["data_format"] = data_format

        # torch resize uses interpolation instead of resample
        # Check if resample is an int before checking if it's an instance of PILImageResampling
        # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
        # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
        resample = kwargs.pop("resample")
        kwargs["interpolation"] = (
            pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
        )

        return kwargs

    @auto_docstring
    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        **kwargs: Unpack[SamFastImageProcessorKwargs],
    ) -> BatchFeature:
        r"""
        segmentation_maps (`ImageInput`, *optional*):
            The segmentation maps to preprocess.
        """
        return super().preprocess(images, segmentation_maps, **kwargs)

    def _preprocess_image_like_inputs(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput],
        do_convert_rgb: bool,
        input_data_format: ChannelDimension,
        device: Optional[Union[str, "torch.device"]] = None,
        **kwargs: Unpack[SamFastImageProcessorKwargs],
    ) -> BatchFeature:
        """
        Preprocess image-like inputs.
        """
        images = self._prepare_image_like_inputs(
            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )
        original_sizes = [image.shape[-2:] for image in images]
        images_kwargs = kwargs.copy()
        pixel_values = self._preprocess(images, **images_kwargs)
        reshaped_input_sizes = [image.shape[-2:] for image in images]
        data = {
            "pixel_values": pixel_values,
            "original_sizes": original_sizes,
            "reshaped_input_sizes": reshaped_input_sizes,
        }

        if segmentation_maps is not None:
            processed_segmentation_maps = self._prepare_image_like_inputs(
                images=segmentation_maps,
                expected_ndims=2,
                do_convert_rgb=False,
                input_data_format=ChannelDimension.FIRST,
            )

            segmentation_maps_kwargs = kwargs.copy()
            segmentation_maps_kwargs.update(
                {
                    "do_normalize": False,
                    "do_rescale": False,
                    "interpolation": F_t.InterpolationMode.NEAREST_EXACT
                    if is_torchvision_v2_available()
                    else F_t.InterpolationMode.NEAREST,
                    "size": segmentation_maps_kwargs.pop("mask_size"),
                    "pad_size": segmentation_maps_kwargs.pop("mask_pad_size"),
                }
            )
            processed_segmentation_maps = self._preprocess(
                images=processed_segmentation_maps, **segmentation_maps_kwargs
            )
            data["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64)

        return BatchFeature(data=data, tensor_type=kwargs["return_tensors"])

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F_t.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        do_pad: bool,
        pad_size: SizeDict,
        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> Union["torch.Tensor", list["torch.Tensor"]]:
        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            if do_pad:
                stacked_images = self.pad_image(stacked_images, pad_size)
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

        return processed_images

    def generate_crop_boxes(
        self,
        image: "torch.Tensor",
        target_size,
        crop_n_layers: int = 0,
        overlap_ratio: float = 512 / 1500,
        points_per_crop: Optional[int] = 32,
        crop_n_points_downscale_factor: Optional[list[int]] = 1,
        device: Optional["torch.device"] = None,
    ):
        """
        Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.

        Args:
            image (`torch.Tensor`):
                Input original image
            target_size (`int`):
                Target size of the resized image
            crop_n_layers (`int`, *optional*, defaults to 0):
                If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
                each layer has 2**i_layer number of image crops.
            overlap_ratio (`float`, *optional*, defaults to 512/1500):
                Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
                the image length. Later layers with more crops scale down this overlap.
            points_per_crop (`int`, *optional*, defaults to 32):
                Number of points to sample from each crop.
            crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1):
                The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
            device (`torch.device`, *optional*, defaults to None):
                Device to use for the computation. If None, cpu will be used.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
            return_tensors (`str`, *optional*, defaults to `pt`):
                If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
        """
        image = self._process_image(image)
        crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
            image,
            target_size,
            crop_n_layers,
            overlap_ratio,
            points_per_crop,
            crop_n_points_downscale_factor,
        )
        if device is None:
            device = torch.device("cpu")
        crop_boxes = crop_boxes.to(device)
        points_per_crop = points_per_crop.to(device)
        # cropped_images stays as torch.Tensor
        input_labels = input_labels.to(device)

        return crop_boxes, points_per_crop, cropped_images, input_labels

    def filter_masks(
        self,
        masks,
        iou_scores,
        original_size,
        cropped_box_image,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        mask_threshold=0,
        stability_score_offset=1,
    ):
        """
        Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
        that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
        score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
        bounding boxes and pad the predicted masks if necessary.

        Args:
            masks (`torch.Tensor`):
                Input masks.
            iou_scores (`torch.Tensor`):
                List of IoU scores.
            original_size (`tuple[int,int]`):
                Size of the original image.
            cropped_box_image (`torch.Tensor`):
                The cropped image.
            pred_iou_thresh (`float`, *optional*, defaults to 0.88):
                The threshold for the iou scores.
            stability_score_thresh (`float`, *optional*, defaults to 0.95):
                The threshold for the stability score.
            mask_threshold (`float`, *optional*, defaults to 0):
                The threshold for the predicted masks.
            stability_score_offset (`float`, *optional*, defaults to 1):
                The offset for the stability score used in the `_compute_stability_score` method.

        """
        original_height, original_width = original_size
        iou_scores = iou_scores.flatten(0, 1)
        masks = masks.flatten(0, 1)

        if masks.shape[0] != iou_scores.shape[0]:
            raise ValueError("masks and iou_scores must have the same batch size.")

        if masks.device != iou_scores.device:
            iou_scores = iou_scores.to(masks.device)

        batch_size = masks.shape[0]

        keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)

        if pred_iou_thresh > 0.0:
            keep_mask = keep_mask & (iou_scores > pred_iou_thresh)

        # compute stability score
        if stability_score_thresh > 0.0:
            stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
            keep_mask = keep_mask & (stability_scores > stability_score_thresh)

        scores = iou_scores[keep_mask]
        masks = masks[keep_mask]

        # binarize masks
        masks = masks > mask_threshold
        converted_boxes = _batched_mask_to_box(masks)

        keep_mask = ~_is_box_near_crop_edge(
            converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
        )

        scores = scores[keep_mask]
        masks = masks[keep_mask]
        converted_boxes = converted_boxes[keep_mask]

        masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
        # conversion to rle is necessary to run non-maximum suppression
        masks = _mask_to_rle(masks)

        return masks, scores, converted_boxes

    def post_process_masks(
        self,
        masks,
        original_sizes,
        reshaped_input_sizes,
        mask_threshold=0.0,
        binarize=True,
        pad_size=None,
    ):
        """
        Remove padding and upscale masks to the original image size.

        Args:
            masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
                Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
            original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The original sizes of each image before it was resized to the model's expected input shape, in (height,
                width) format.
            reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                The threshold to use for binarizing the masks.
            binarize (`bool`, *optional*, defaults to `True`):
                Whether to binarize the masks.
            pad_size (`int`, *optional*, defaults to `self.pad_size`):
                The target size the images were padded to before being passed to the model. If None, the target size is
                assumed to be the processor's `pad_size`.
        Returns:
            (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width)
            is given by original_size.
        """
        pad_size = self.size if pad_size is None else pad_size
        target_image_size = (pad_size["height"], pad_size["width"])
        if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
            original_sizes = original_sizes.tolist()
        if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
            reshaped_input_sizes = reshaped_input_sizes.tolist()

        output_masks = []
        for i, original_size in enumerate(original_sizes):
            if isinstance(masks[i], np.ndarray):
                masks[i] = torch.from_numpy(masks[i])
            elif not isinstance(masks[i], torch.Tensor):
                raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
            interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
            interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
            interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
            if binarize:
                interpolated_mask = interpolated_mask > mask_threshold
            output_masks.append(interpolated_mask)

        return output_masks

    def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
        """
        Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.

        Args:
            all_masks (`torch.Tensor`):
                List of all predicted segmentation masks
            all_scores (`torch.Tensor`):
                List of all predicted iou scores
            all_boxes (`torch.Tensor`):
                List of all bounding boxes of the predicted masks
            crops_nms_thresh (`float`):
                Threshold for NMS (Non Maximum Suppression) algorithm.
        """
        return _post_process_for_mask_generation(all_masks, all_scores, all_boxes, crops_nms_thresh)


def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
    # One mask is always contained inside the other.
    # Save memory by preventing unnecessary cast to torch.int64
    intersections = (
        (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    )
    unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    stability_scores = intersections / unions
    return stability_scores
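

# --- Illustrative sketch (not part of the original module): how the stability score behaves on a
# toy logit mask. Thresholding at mask_threshold +/- stability_score_offset gives a "tight" and a
# "loose" binarization; the score is the ratio of their areas, so crisp masks score close to 1.
# The tensor values below are made up purely for illustration.
def _example_stability_score():
    logits = torch.tensor(
        [
            [
                [0.2, 1.5, 2.0, -3.0],
                [1.2, 0.9, 2.5, -1.0],
                [-2.0, 1.1, 0.4, -0.5],
                [-1.5, -2.2, -0.8, -1.9],
            ]
        ]
    )
    # With mask_threshold=0 and offset=1, the tight mask counts logits > 1 and the loose mask logits > -1.
    score = _compute_stability_score(logits, mask_threshold=0.0, stability_score_offset=1)
    return score  # one value per mask, in (0, 1]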


def _mask_to_rle(input_mask: "torch.Tensor"):
    """
    Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
    """
    # Put in fortran order and flatten height and width
    batch_size, height, width = input_mask.shape
    input_mask = input_mask.permute(0, 2, 1).flatten(1)

    # Compute change indices
    diff = input_mask[:, 1:] ^ input_mask[:, :-1]
    change_indices = diff.nonzero()

    # Encode run length
    out = []
    for i in range(batch_size):
        cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
        if len(cur_idxs) == 0:
            # No changes => either all 0 or all 1
            # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
            if input_mask[i, 0] == 0:
                out.append({"size": [height, width], "counts": [height * width]})
            else:
                out.append({"size": [height, width], "counts": [0, height * width]})
            continue
        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
        counts = [] if input_mask[i, 0] == 0 else [0]
        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
        out.append({"size": [height, width], "counts": counts})
    return out


def _batched_mask_to_box(masks: "torch.Tensor"):
    """
    Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
    corresponds the following required indices:
        - LEFT: left hand side of the bounding box
        - TOP: top of the bounding box
        - RIGHT: right of the bounding box
        - BOTTOM: bottom of the bounding box

    Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
    is channel_1 x channel_2 x ... x 4.

    Args:
        - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
    """
    # torch.max below raises an error on empty inputs, just skip in this case

    if torch.numel(masks) == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

    # Normalize shape to Cxheightxwidth
    shape = masks.shape
    height, width = shape[-2:]

    # Get top and bottom edges
    in_height, _ = torch.max(masks, dim=-1)
    in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
    in_height_coords = in_height_coords + height * (~in_height)
    top_edges, _ = torch.min(in_height_coords, dim=-1)

    # Get left and right edges
    in_width, _ = torch.max(masks, dim=-2)
    in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
    right_edges, _ = torch.max(in_width_coords, dim=-1)
    in_width_coords = in_width_coords + width * (~in_width)
    left_edges, _ = torch.min(in_width_coords, dim=-1)

    # If the mask is empty the right edge will be to the left of the left edge.
    # Replace these boxes with [0, 0, 0, 0]
    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
    out = out * (~empty_filter).unsqueeze(-1)

    # Return to original shape
    out = out.reshape(*shape[:-2], 4)
    return out
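

# --- Illustrative sketch (not part of the original module): a toy binary mask and the box that
# `_batched_mask_to_box` recovers from it. Shapes follow the docstring above: (..., height, width)
# in, (..., 4) out.
def _example_mask_to_box():
    mask = torch.zeros(1, 1, 6, 8, dtype=torch.bool)
    mask[0, 0, 2:5, 3:7] = True  # a 3x4 blob whose top-left corner is at (row=2, col=3)
    box = _batched_mask_to_box(mask)
    # box[0, 0] is [left, top, right, bottom] == [3, 2, 6, 4] for this blob
    return box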


def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
    """Filter masks at the edge of a crop, but not at the edge of the original image."""
    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)

    left, top, _, _ = crop_box
    offset = torch.tensor([[left, top, left, top]], device=boxes.device)
    # Check if boxes has a channel dimension
    if len(boxes.shape) == 3:
        offset = offset.unsqueeze(1)
    boxes = (boxes + offset).float()

    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
    return torch.any(near_crop_edge, dim=1)


def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int):
    left, top, right, bottom = crop_box
    if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
        return masks
    # Coordinate transform masks
    pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
    pad = (left, pad_x - left, top, pad_y - top)
    return torch.nn.functional.pad(masks, pad, value=0)


def _generate_crop_boxes(
    image,
    target_size: int,  # Is it tuple here?
    crop_n_layers: int = 0,
    overlap_ratio: float = 512 / 1500,
    points_per_crop: Optional[int] = 32,
    crop_n_points_downscale_factor: Optional[list[int]] = 1,
) -> tuple[list[list[int]], list[int]]:
    """
    Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.

    Args:
        image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
            Image to generate crops for.
        target_size (`int`):
            Size of the smallest crop.
        crop_n_layers (`int`, *optional*):
            If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers
            to run, where each layer has 2**i_layer number of image crops.
        overlap_ratio (`int`, *optional*):
            Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
            image length. Later layers with more crops scale down this overlap.
        points_per_crop (`int`, *optional*):
            Number of points to sample per crop.
        crop_n_points_downscale_factor (`int`, *optional*):
            The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred.
    """

    if isinstance(image, list):
        raise ValueError("Only one image is allowed for crop generation.")
    original_size = image.shape[-2:]

    points_grid = []
    for i in range(crop_n_layers + 1):
        n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
        points_grid.append(_build_point_grid(n_points))

    crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)

    cropped_images, point_grid_per_crop = _generate_crop_images(
        crop_boxes, image, points_grid, layer_idxs, target_size, original_size
    )
    crop_boxes = torch.tensor(crop_boxes)
    crop_boxes = crop_boxes.float()
    points_per_crop = torch.stack(point_grid_per_crop)
    points_per_crop = points_per_crop.unsqueeze(0).permute(0, 2, 1, 3)
    cropped_images = torch.stack(cropped_images)

    input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.int64)

    return crop_boxes, points_per_crop, cropped_images, input_labels


def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
    """
    Generates 2 ** (layers idx + 1) crops for each crop_n_layers. Crops are in the XYWH format : The XYWH format
    consists of the following required indices:
        - X: X coordinate of the top left of the bounding box
        - Y: Y coordinate of the top left of the bounding box
        - W: width of the bounding box
        - H: height of the bounding box
    """
    crop_boxes, layer_idxs = [], []
    im_height, im_width = original_size
    short_side = min(im_height, im_width)

    # Original image
    crop_boxes.append([0, 0, im_width, im_height])
    layer_idxs.append(0)
    for i_layer in range(crop_n_layers):
        n_crops_per_side = 2 ** (i_layer + 1)
        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))

        crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
        crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))

        crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
        crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]

        for left, top in product(crop_box_x0, crop_box_y0):
            box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
            crop_boxes.append(box)
            layer_idxs.append(i_layer + 1)

    return crop_boxes, layer_idxs
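

# --- Illustrative sketch (not part of the original module): with crop_n_layers=1,
# `_generate_per_layer_crops` returns the full-image box followed by a 2x2 grid of overlapping
# crops, i.e. 1 + (2**1)**2 = 5 boxes with layer_idxs == [0, 1, 1, 1, 1].
def _example_per_layer_crops():
    crop_boxes, layer_idxs = _generate_per_layer_crops(
        crop_n_layers=1, overlap_ratio=512 / 1500, original_size=(600, 800)
    )
    # crop_boxes[0] is [0, 0, 800, 600] (the whole image); each of the other four crops spans
    # roughly half of the image plus an overlap of int(overlap_ratio * short_side) pixels.
    return crop_boxes, layer_idxs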


def _build_point_grid(n_per_side: int) -> torch.Tensor:
    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
    offset = 1 / (2 * n_per_side)
    points_one_side = torch.linspace(offset, 1 - offset, n_per_side)
    points_x = torch.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = torch.tile(points_one_side[:, None], (1, n_per_side))
    points = torch.stack([points_x, points_y], dim=-1).reshape(-1, 2)
    return points


def _generate_crop_images(
    crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
):
    """
    Takes as an input bounding boxes that are used to crop the image. Based in the crops, the corresponding points are
    also passed.
    """
    cropped_images = []
    total_points_per_crop = []
    for i, crop_box in enumerate(crop_boxes):
        left, top, right, bottom = crop_box
        cropped_im = image[:, top:bottom, left:right]

        cropped_images.append(cropped_im)

        cropped_im_size = cropped_im.shape[-2:]
        points_scale = torch.tensor(cropped_im_size).flip(dims=(0,)).unsqueeze(0)

        points = points_grid[layer_idxs[i]] * points_scale
        normalized_points = _normalize_coordinates(target_size, points, original_size)
        total_points_per_crop.append(normalized_points)

    return cropped_images, total_points_per_crop


def _normalize_coordinates(
    target_size: int, coords: torch.Tensor, original_size: tuple[int, int], is_bounding_box=False
) -> torch.Tensor:
    """
    Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width)
    format.
    """
    old_height, old_width = original_size

    scale = target_size * 1.0 / max(old_height, old_width)
    new_height, new_width = old_height * scale, old_width * scale
    new_width = int(new_width + 0.5)
    new_height = int(new_height + 0.5)

    coords = deepcopy(coords).float()

    if is_bounding_box:
        coords = coords.reshape(-1, 2, 2)

    coords[..., 0] = coords[..., 0] * (new_width / old_width)
    coords[..., 1] = coords[..., 1] * (new_height / old_height)

    if is_bounding_box:
        coords = coords.reshape(-1, 4)

    return coords
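

# --- Illustrative sketch (not part of the original module): `_normalize_coordinates` rescales
# point prompts from the original image frame into the frame of the longest-edge-resized image.
def _example_normalize_coordinates():
    # A point at (x=400, y=300) in a 600x800 image, remapped for a target longest edge of 1024.
    points = torch.tensor([[400.0, 300.0]])
    rescaled = _normalize_coordinates(1024, points, original_size=(600, 800))
    # scale = 1024 / 800 = 1.28, so the point lands at (512.0, 384.0)
    return rescaled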


def _rle_to_mask(rle: dict[str, Any]) -> torch.Tensor:
    """Compute a binary mask from an uncompressed RLE."""
    height, width = rle["size"]
    mask = torch.empty(height * width, dtype=bool)
    idx = 0
    parity = False
    for count in rle["counts"]:
        mask[idx : idx + count] = parity
        idx += count
        parity = not parity
    mask = mask.reshape(width, height)
    return mask.transpose(0, 1)  # Reshape to original shape
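

# --- Illustrative sketch (not part of the original module): `_mask_to_rle` and `_rle_to_mask`
# round-trip a binary mask through the uncompressed, column-major RLE used by pycocotools.
def _example_rle_roundtrip():
    mask = torch.zeros(1, 4, 5, dtype=torch.bool)
    mask[0, 1:3, 2:4] = True  # a small 2x2 blob
    rle = _mask_to_rle(mask)[0]  # {"size": [4, 5], "counts": [...]}: alternating runs of 0s and 1s down the columns
    recovered = _rle_to_mask(rle)
    assert torch.equal(recovered, mask[0])
    return rle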


def _post_process_for_mask_generation(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
    """
    Perform NMS (Non Maximum Suppression) on the outputs.

    Args:
        rle_masks (`torch.Tensor`):
            binary masks in the RLE format
        iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
            iou_scores predicted by the model
        mask_boxes (`torch.Tensor`):
            The bounding boxes corresponding to segmentation masks
        amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
            NMS threshold.
    """
    keep_by_nms = batched_nms(
        boxes=mask_boxes.float(),
        scores=iou_scores,
        idxs=torch.zeros(mask_boxes.shape[0]),
        iou_threshold=amg_crops_nms_thresh,
    )

    iou_scores = iou_scores[keep_by_nms]
    rle_masks = [rle_masks[i] for i in keep_by_nms]
    mask_boxes = mask_boxes[keep_by_nms]
    masks = [_rle_to_mask(rle) for rle in rle_masks]

    return masks, iou_scores, rle_masks, mask_boxes


__all__ = ["SamImageProcessorFast"]
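

# --- Illustrative, non-authoritative usage sketch (not part of the original module). It assumes a
# SAM checkpoint such as "facebook/sam-vit-base", a local image file, and the standard SamModel API;
# the checkpoint name and file path are assumptions, not guaranteed by this file.
def _example_processor_usage():
    from PIL import Image

    from transformers import SamImageProcessorFast, SamModel

    image = Image.open("example.jpg").convert("RGB")  # hypothetical local file
    processor = SamImageProcessorFast()
    model = SamModel.from_pretrained("facebook/sam-vit-base")

    inputs = processor.preprocess(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"])
    # Low-resolution predicted masks are un-padded and upscaled back to the original image size here.
    masks = processor.post_process_masks(
        outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
    )
    return masks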
phivenv/Lib/site-packages/transformers/models/sam/modeling_sam.py
ADDED
@@ -0,0 +1,1368 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch SAM model."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Callable, Optional, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from torch import Tensor, nn
|
| 25 |
+
|
| 26 |
+
from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs
|
| 27 |
+
|
| 28 |
+
from ...activations import ACT2FN
|
| 29 |
+
from ...modeling_layers import GradientCheckpointingLayer
|
| 30 |
+
from ...modeling_outputs import BaseModelOutput
|
| 31 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 32 |
+
from ...processing_utils import Unpack
|
| 33 |
+
from ...utils import (
|
| 34 |
+
ModelOutput,
|
| 35 |
+
auto_docstring,
|
| 36 |
+
logging,
|
| 37 |
+
)
|
| 38 |
+
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass
|
| 45 |
+
@auto_docstring(
|
| 46 |
+
custom_intro="""
|
| 47 |
+
Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
|
| 48 |
+
layer to the pooler_output.
|
| 49 |
+
"""
|
| 50 |
+
)
|
| 51 |
+
class SamVisionEncoderOutput(ModelOutput):
|
| 52 |
+
r"""
|
| 53 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
| 54 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
| 58 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 59 |
+
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 60 |
+
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
|
| 64 |
+
@auto_docstring(
|
| 65 |
+
custom_intro="""
|
| 66 |
+
Base class for Segment-Anything model's output
|
| 67 |
+
"""
|
| 68 |
+
)
|
| 69 |
+
class SamImageSegmentationOutput(ModelOutput):
|
| 70 |
+
r"""
|
| 71 |
+
iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
|
| 72 |
+
The iou scores of the predicted masks.
|
| 73 |
+
pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
|
| 74 |
+
The predicted low resolutions masks. Needs to be post-processed by the processor
|
| 75 |
+
vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 76 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 77 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 78 |
+
|
| 79 |
+
Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
|
| 80 |
+
vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 81 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 82 |
+
sequence_length)`.
|
| 83 |
+
|
| 84 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 85 |
+
heads.
|
| 86 |
+
mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 87 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 88 |
+
sequence_length)`.
|
| 89 |
+
|
| 90 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 91 |
+
heads.
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
iou_scores: Optional[torch.FloatTensor] = None
|
| 95 |
+
pred_masks: Optional[torch.FloatTensor] = None
|
| 96 |
+
vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 97 |
+
vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 98 |
+
mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class SamPatchEmbeddings(nn.Module):
|
| 102 |
+
"""
|
| 103 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
| 104 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
| 105 |
+
Transformer.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(self, config):
|
| 109 |
+
super().__init__()
|
| 110 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 111 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
| 112 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
| 113 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 114 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
| 115 |
+
self.image_size = image_size
|
| 116 |
+
self.patch_size = patch_size
|
| 117 |
+
self.num_channels = num_channels
|
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
        return embeddings


class SamMLPBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
        self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
        self.act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.lin1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.lin2(hidden_states)
        return hidden_states


# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
class SamLayerNorm(nn.LayerNorm):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
        super().__init__(normalized_shape, eps=eps, **kwargs)
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {data_format}")
        self.data_format = data_format

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
        """
        if self.data_format == "channels_first":
            features = features.permute(0, 2, 3, 1)
            features = super().forward(features)
            features = features.permute(0, 3, 1, 2)
        else:
            features = super().forward(features)
        return features


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

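# Illustrative note (not part of the upstream file): `eager_attention_forward` expects `query`/`key`/`value`
# of shape (batch, num_heads, seq_len, head_dim) and returns the attention output transposed back to
# (batch, seq_len, num_heads, head_dim), which is the layout that `SamAttention._recombine_heads` below consumes.
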
class SamAttention(nn.Module):
    """
    SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values.
    """

    def __init__(self, config, downsample_rate=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate

        self.internal_dim = config.hidden_size // downsample_rate
        self.num_attention_heads = config.num_attention_heads
        if self.internal_dim % config.num_attention_heads != 0:
            raise ValueError("num_attention_heads must divide hidden_size.")
        self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5

        self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)

        self.is_causal = False

    def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
        batch, point_batch_size, n_tokens, channel = hidden_states.shape
        c_per_head = channel // num_attention_heads
        hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
        return hidden_states.transpose(1, 2)

    def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
        batch, n_tokens, n_heads, c_per_head = hidden_states.shape
        return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_similarity: Optional[Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tensor:
        # Input projections
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        point_batch_size = query.shape[1]
        # Separate into heads
        query = self._separate_heads(query, self.num_attention_heads)
        key = self._separate_heads(key, self.num_attention_heads)
        value = self._separate_heads(value, self.num_attention_heads)

        # SamAttention
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=attention_similarity,
            dropout=0.0,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )

        attn_output = self._recombine_heads(attn_output, point_batch_size)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights

class SamTwoWayAttentionBlock(nn.Module):
    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
        """
        A transformer block with four layers:
            (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
            sparse inputs (4) cross attention of dense inputs -> sparse inputs

        Arguments:
            config (`SamMaskDecoderConfig`):
                The configuration file used to instantiate the block
            attention_downsample_rate (*optional*, int, defaults to 2):
                The downsample ratio of the block used to reduce the inner dim of the attention.
            skip_first_layer_pe (*optional*, bool, defaults to `False`):
                Whether or not to skip the addition of the query_point_embedding on the first layer.
        """
        super().__init__()

        self.hidden_size = config.hidden_size
        self.layer_norm_eps = config.layer_norm_eps

        self.self_attn = SamAttention(config, downsample_rate=1)
        self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

        self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
        self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

        self.mlp = SamMLPBlock(config)
        self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

        self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        query_point_embedding: Tensor,
        key_point_embedding: Tensor,
        attention_similarity: Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ):
        # Self attention block
        if self.skip_first_layer_pe:
            queries, _ = self.self_attn(query=queries, key=queries, value=queries)
        else:
            query = queries + query_point_embedding
            attn_out, _ = self.self_attn(query=query, key=query, value=queries)
            queries = queries + attn_out
        queries = self.layer_norm1(queries)

        # Cross attention block, tokens attending to image embedding
        query = queries + query_point_embedding
        key = keys + key_point_embedding

        attn_out, _ = self.cross_attn_token_to_image(
            query=query, key=key, value=keys, attention_similarity=attention_similarity
        )
        queries = queries + attn_out

        queries = self.layer_norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.layer_norm3(queries)

        # Cross attention block, image embedding attending to tokens
        query = queries + query_point_embedding
        key = keys + key_point_embedding

        attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
        keys = keys + attn_out

        keys = self.layer_norm4(keys)
        return queries, keys, attn_out

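# Illustrative note (not from the upstream file): each SamTwoWayAttentionBlock updates both streams once per
# block -- the prompt tokens ("queries") first self-attend, then attend to the flattened image tokens ("keys"),
# pass through an MLP, and finally the image tokens attend back to the prompt tokens.
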
class SamTwoWayTransformer(nn.Module):
    def __init__(self, config: SamMaskDecoderConfig):
        super().__init__()
        self.config = config

        self.num_hidden_layers = config.num_hidden_layers
        self.layers = nn.ModuleList()

        for i in range(self.num_hidden_layers):
            self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))

        self.final_attn_token_to_image = SamAttention(config)
        self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        point_embeddings: Tensor,
        image_embeddings: Tensor,
        image_positional_embeddings: Tensor,
        attention_similarity: Tensor,
        target_embedding=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutput]:
        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")

        image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
        image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)

        # Prepare queries
        queries = point_embeddings
        keys = image_embeddings

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            if target_embedding is not None:
                queries += target_embedding

            queries, keys, _ = layer(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                attention_similarity=attention_similarity,
                **kwargs,
            )
        # Apply the final attention layer from the points to the image
        query = queries + point_embeddings
        key = keys + image_positional_embeddings

        attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)

        queries = queries + attn_out
        queries = self.layer_norm_final_attn(queries)
        return queries, keys


class SamFeedForward(nn.Module):
    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
    ):
        super().__init__()
        self.num_layers = num_layers
        self.activation = nn.ReLU()
        self.proj_in = nn.Linear(input_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
        self.sigmoid_output = sigmoid_output

    def forward(self, hidden_states):
        hidden_states = self.proj_in(hidden_states)
        hidden_states = self.activation(hidden_states)
        for layer in self.layers:
            hidden_states = self.activation(layer(hidden_states))

        hidden_states = self.proj_out(hidden_states)
        if self.sigmoid_output:
            hidden_states = F.sigmoid(hidden_states)
        return hidden_states

class SamMaskDecoder(nn.Module):
    def __init__(self, config: SamMaskDecoderConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        self.num_multimask_outputs = config.num_multimask_outputs
        self.num_mask_tokens = config.num_multimask_outputs + 1

        self.iou_token = nn.Embedding(1, self.hidden_size)
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)

        self.transformer = SamTwoWayTransformer(config)

        # should we create a new class for this?
        self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
        self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
        self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
        self.activation = nn.GELU()

        mlps_list = []
        for _ in range(self.num_mask_tokens):
            mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
        self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)

        self.iou_prediction_head = SamFeedForward(
            self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        attention_similarity: Optional[torch.Tensor] = None,
        target_embedding: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (`torch.Tensor`):
                the embeddings from the image encoder
            image_positional_embeddings (`torch.Tensor`):
                positional encoding with the shape of image_embeddings
            sparse_prompt_embeddings (`torch.Tensor`):
                The embeddings of the points and boxes
            dense_prompt_embeddings (`torch.Tensor`):
                the embeddings of the mask inputs
            multimask_output (bool):
                Whether to return multiple masks or a single mask.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

        if sparse_prompt_embeddings is not None:
            tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
        else:
            tokens = output_tokens
        point_embeddings = tokens.to(self.iou_token.weight.dtype)

        # Expand per-image data in batch direction to be per-point
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)

        # Run the transformer, the image_positional_embeddings are consumed here
        point_embedding, image_embeddings = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
        )
        iou_token_out = point_embedding[:, :, 0, :]
        mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        image_embeddings = image_embeddings.transpose(2, 3).reshape(
            batch_size * point_batch_size, num_channels, height, width
        )

        upscaled_embedding = self.upscale_conv1(image_embeddings)
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))

        hyper_in_list = []
        for i in range(self.num_mask_tokens):
            current_mlp = self.output_hypernetworks_mlps[i]
            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
        hyper_in = torch.stack(hyper_in_list, dim=2)

        _, num_channels, height, width = upscaled_embedding.shape
        upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
        masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, :, mask_slice, :, :]
        iou_pred = iou_pred[:, :, mask_slice]
        return masks, iou_pred

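# Illustrative note (not from the upstream file): the decoder returns `masks` of shape
# (batch_size, point_batch_size, num_masks, 4 * height, 4 * width) -- each ConvTranspose2d above upscales the
# image embedding by 2 -- and `iou_pred` of shape (batch_size, point_batch_size, num_masks), where num_masks is
# 3 when `multimask_output=True` and 1 otherwise.
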
+
class SamPositionalEmbedding(nn.Module):
|
| 551 |
+
def __init__(self, config):
|
| 552 |
+
super().__init__()
|
| 553 |
+
self.scale = config.hidden_size // 2
|
| 554 |
+
self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
|
| 555 |
+
|
| 556 |
+
def forward(self, input_coords, input_shape=None):
|
| 557 |
+
"""Positionally encode points that are normalized to [0,1]."""
|
| 558 |
+
coordinates = input_coords.clone()
|
| 559 |
+
|
| 560 |
+
if input_shape is not None:
|
| 561 |
+
coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
|
| 562 |
+
coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
|
| 563 |
+
|
| 564 |
+
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
| 565 |
+
coordinates = 2 * coordinates - 1
|
| 566 |
+
coordinates = coordinates.to(self.positional_embedding.dtype)
|
| 567 |
+
coordinates = coordinates @ self.positional_embedding
|
| 568 |
+
coordinates = 2 * np.pi * coordinates
|
| 569 |
+
# outputs d_1 x ... x d_n x channel shape
|
| 570 |
+
return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
class SamMaskEmbedding(nn.Module):
|
| 574 |
+
def __init__(self, config: SamPromptEncoderConfig):
|
| 575 |
+
super().__init__()
|
| 576 |
+
self.mask_input_channels = config.mask_input_channels // 4
|
| 577 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 578 |
+
self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
|
| 579 |
+
self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
|
| 580 |
+
self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
|
| 581 |
+
self.layer_norm1 = SamLayerNorm(
|
| 582 |
+
self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
|
| 583 |
+
)
|
| 584 |
+
self.layer_norm2 = SamLayerNorm(
|
| 585 |
+
self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
def forward(self, masks):
|
| 589 |
+
hidden_states = self.conv1(masks)
|
| 590 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 591 |
+
hidden_states = self.activation(hidden_states)
|
| 592 |
+
|
| 593 |
+
hidden_states = self.conv2(hidden_states)
|
| 594 |
+
hidden_states = self.layer_norm2(hidden_states)
|
| 595 |
+
hidden_states = self.activation(hidden_states)
|
| 596 |
+
dense_embeddings = self.conv3(hidden_states)
|
| 597 |
+
return dense_embeddings
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class SamPromptEncoder(nn.Module):
|
| 601 |
+
def __init__(self, config: SamConfig):
|
| 602 |
+
super().__init__()
|
| 603 |
+
self.shared_embedding = SamPositionalEmbedding(config.vision_config)
|
| 604 |
+
config = config.prompt_encoder_config
|
| 605 |
+
self.mask_embed = SamMaskEmbedding(config)
|
| 606 |
+
self.no_mask_embed = nn.Embedding(1, config.hidden_size)
|
| 607 |
+
|
| 608 |
+
self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
|
| 609 |
+
self.input_image_size = config.image_size
|
| 610 |
+
|
| 611 |
+
self.point_embed = nn.ModuleList(
|
| 612 |
+
[nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)]
|
| 613 |
+
)
|
| 614 |
+
self.hidden_size = config.hidden_size
|
| 615 |
+
self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
|
| 616 |
+
|
| 617 |
+
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
|
| 618 |
+
"""Embeds point prompts."""
|
| 619 |
+
points = points + 0.5 # Shift to center of pixel
|
| 620 |
+
if pad:
|
| 621 |
+
target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
|
| 622 |
+
target_labels_shape = (points.shape[0], points.shape[1], 1)
|
| 623 |
+
padding_point = torch.zeros(target_point_shape, device=points.device)
|
| 624 |
+
padding_label = -torch.ones(target_labels_shape, device=labels.device)
|
| 625 |
+
points = torch.cat([points, padding_point], dim=2)
|
| 626 |
+
labels = torch.cat([labels, padding_label], dim=2)
|
| 627 |
+
input_shape = (self.input_image_size, self.input_image_size)
|
| 628 |
+
point_embedding = self.shared_embedding(points, input_shape)
|
| 629 |
+
|
| 630 |
+
# torch.where and expanding the labels tensor is required by the ONNX export
|
| 631 |
+
point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
|
| 632 |
+
|
| 633 |
+
# This is required for the ONNX export. The dtype, device need to be explicitly
|
| 634 |
+
# specified as otherwise torch.onnx.export interprets as double
|
| 635 |
+
point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding))
|
| 636 |
+
|
| 637 |
+
point_embedding = torch.where(
|
| 638 |
+
(labels == 0)[:, :, :, None],
|
| 639 |
+
point_embedding + self.point_embed[0].weight[None, None, :, :],
|
| 640 |
+
point_embedding,
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
point_embedding = torch.where(
|
| 644 |
+
(labels == 1)[:, :, :, None],
|
| 645 |
+
point_embedding + self.point_embed[1].weight[None, None, :, :],
|
| 646 |
+
point_embedding,
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
return point_embedding
|
| 650 |
+
|
| 651 |
+
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
| 652 |
+
"""Embeds box prompts."""
|
| 653 |
+
boxes = boxes + 0.5 # Shift to center of pixel
|
| 654 |
+
batch_size, nb_boxes = boxes.shape[:2]
|
| 655 |
+
coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
|
| 656 |
+
input_shape = (self.input_image_size, self.input_image_size)
|
| 657 |
+
corner_embedding = self.shared_embedding(coords, input_shape)
|
| 658 |
+
corner_embedding[:, :, 0, :] += self.point_embed[2].weight
|
| 659 |
+
corner_embedding[:, :, 1, :] += self.point_embed[3].weight
|
| 660 |
+
return corner_embedding
|
| 661 |
+
|
| 662 |
+
def forward(
|
| 663 |
+
self,
|
| 664 |
+
input_points: Optional[tuple[torch.Tensor, torch.Tensor]],
|
| 665 |
+
input_labels: Optional[torch.Tensor],
|
| 666 |
+
input_boxes: Optional[torch.Tensor],
|
| 667 |
+
input_masks: Optional[torch.Tensor],
|
| 668 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 669 |
+
"""
|
| 670 |
+
Embeds different types of prompts, returning both sparse and dense embeddings.
|
| 671 |
+
|
| 672 |
+
Args:
|
| 673 |
+
points (`torch.Tensor`, *optional*):
|
| 674 |
+
point coordinates and labels to embed.
|
| 675 |
+
boxes (`torch.Tensor`, *optional*):
|
| 676 |
+
boxes to embed
|
| 677 |
+
masks (`torch.Tensor`, *optional*):
|
| 678 |
+
masks to embed
|
| 679 |
+
"""
|
| 680 |
+
sparse_embeddings = None
|
| 681 |
+
batch_size = 1
|
| 682 |
+
if input_points is not None:
|
| 683 |
+
batch_size = input_points.shape[0]
|
| 684 |
+
if input_labels is None:
|
| 685 |
+
raise ValueError("If points are provided, labels must also be provided.")
|
| 686 |
+
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
|
| 687 |
+
sparse_embeddings = point_embeddings
|
| 688 |
+
if input_boxes is not None:
|
| 689 |
+
batch_size = input_boxes.shape[0]
|
| 690 |
+
box_embeddings = self._embed_boxes(input_boxes)
|
| 691 |
+
if sparse_embeddings is None:
|
| 692 |
+
sparse_embeddings = box_embeddings
|
| 693 |
+
else:
|
| 694 |
+
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
|
| 695 |
+
if input_masks is not None:
|
| 696 |
+
dense_embeddings = self.mask_embed(input_masks)
|
| 697 |
+
else:
|
| 698 |
+
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
| 699 |
+
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
return sparse_embeddings, dense_embeddings
|
| 703 |
+
|
| 704 |
+
|
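# Illustrative note (not from the upstream file): `sparse_embeddings` collects one token per point or box corner
# (or stays `None` when no sparse prompt is given), while `dense_embeddings` is always a
# (batch_size, hidden_size, image_embedding_size, image_embedding_size) map -- either the embedded input mask or
# the broadcast `no_mask_embed` weight.
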
class SamVisionAttention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(self, config, window_size):
        super().__init__()
        input_size = (
            (config.image_size // config.patch_size, config.image_size // config.patch_size)
            if window_size == 0
            else (window_size, window_size)
        )

        self.num_attention_heads = config.num_attention_heads
        head_dim = config.hidden_size // config.num_attention_heads
        self.scale = head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)

        self.use_rel_pos = config.use_rel_pos
        if self.use_rel_pos:
            if input_size is None:
                raise ValueError("Input size must be provided if using relative positional encoding.")

            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
        """
        Get relative positional embeddings according to the relative positions of
        query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`torch.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions.
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

        # Scale the coords with short length if shapes for q and k are different.
        q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
        k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

        return rel_pos_resized[relative_coords.long()]

    def get_decomposed_rel_pos(
        self,
        query: torch.Tensor,
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
        q_size: tuple[int, int],
        k_size: tuple[int, int],
    ) -> torch.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            query (`torch.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`torch.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`torch.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            decomposed_rel_pos (`torch.Tensor`):
                decomposed relative position embeddings.
        """
        query_height, query_width = q_size
        key_height, key_width = k_size
        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)

        batch_size, _, dim = query.shape
        reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)

        decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]

        return decomposed_rel_pos

    def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, height, width, _ = hidden_states.shape
        # qkv with shape (3, batch_size, nHead, height * width, channel)
        qkv = (
            self.qkv(hidden_states)
            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
            .permute(2, 0, 3, 1, 4)
        )
        # q, k, v with shape (batch_size * nHead, height * width, channel)
        query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            decomposed_rel_pos = self.get_decomposed_rel_pos(
                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )
            decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
            attn_weights = attn_weights + decomposed_rel_pos

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)

        attn_output = self.proj(attn_output)
        return attn_output, attn_weights

+
class SamVisionSdpaAttention(SamVisionAttention):
|
| 839 |
+
"""
|
| 840 |
+
Multi-head Attention block with relative position embeddings.
|
| 841 |
+
Using SDPA instead of the default attention.
|
| 842 |
+
"""
|
| 843 |
+
|
| 844 |
+
def __init__(self, config, window_size):
|
| 845 |
+
super().__init__(config, window_size)
|
| 846 |
+
|
| 847 |
+
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
|
| 848 |
+
if output_attentions:
|
| 849 |
+
logger.warning_once(
|
| 850 |
+
"`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
| 851 |
+
"`output_attentions=True`. Falling back to the manual attention implementation, but "
|
| 852 |
+
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
| 853 |
+
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 854 |
+
)
|
| 855 |
+
return super().forward(
|
| 856 |
+
hidden_states=hidden_states,
|
| 857 |
+
output_attentions=output_attentions,
|
| 858 |
+
)
|
| 859 |
+
|
| 860 |
+
batch_size, height, width, _ = hidden_states.shape
|
| 861 |
+
# qkv with shape (3, B, nHead, H * W, C)
|
| 862 |
+
qkv = (
|
| 863 |
+
self.qkv(hidden_states)
|
| 864 |
+
.reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
|
| 865 |
+
.permute(2, 0, 3, 1, 4)
|
| 866 |
+
)
|
| 867 |
+
# q, k, v with shape (B * nHead, H * W, C)
|
| 868 |
+
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
|
| 869 |
+
|
| 870 |
+
attn_bias = None
|
| 871 |
+
if self.use_rel_pos:
|
| 872 |
+
decomposed_rel_pos = self.get_decomposed_rel_pos(
|
| 873 |
+
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
|
| 874 |
+
)
|
| 875 |
+
decomposed_rel_pos = decomposed_rel_pos.reshape(
|
| 876 |
+
batch_size, self.num_attention_heads, height * width, height * width
|
| 877 |
+
)
|
| 878 |
+
attn_bias = decomposed_rel_pos
|
| 879 |
+
|
| 880 |
+
query = query.view(batch_size, self.num_attention_heads, height * width, -1)
|
| 881 |
+
key = key.view(batch_size, self.num_attention_heads, height * width, -1)
|
| 882 |
+
value = value.view(batch_size, self.num_attention_heads, height * width, -1)
|
| 883 |
+
|
| 884 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
|
| 885 |
+
|
| 886 |
+
attn_output = (
|
| 887 |
+
attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
|
| 888 |
+
.permute(0, 2, 3, 1, 4)
|
| 889 |
+
.reshape(batch_size, height, width, -1)
|
| 890 |
+
)
|
| 891 |
+
|
| 892 |
+
attn_output = self.proj(attn_output)
|
| 893 |
+
return attn_output, None
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
SAM_VISION_ATTENTION_CLASSES = {
|
| 897 |
+
"eager": SamVisionAttention,
|
| 898 |
+
"sdpa": SamVisionSdpaAttention,
|
| 899 |
+
}
|
| 900 |
+
|
| 901 |
+
|
class SamVisionLayer(GradientCheckpointingLayer):
    def __init__(self, config, window_size):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SamMLPBlock(config)
        self.window_size = window_size

    def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]:
        """
        Partition into non-overlapping windows with padding if needed.

        Args:
            hidden_states (tensor):
                input tokens with [batch_size, height, width, channel].
            window_size (int):
                window size.

        Returns:
            windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
            (pad_height, pad_width): padded height and width before partition
        """
        batch_size, height, width, channel = hidden_states.shape

        pad_h = (window_size - height % window_size) % window_size
        pad_w = (window_size - width % window_size) % window_size
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
        pad_height, pad_width = height + pad_h, width + pad_w

        hidden_states = hidden_states.reshape(
            batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
        )
        windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
        return windows, (pad_height, pad_width)

    def window_unpartition(
        self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int]
    ) -> torch.Tensor:
        """
        Window unpartition into original sequences and removing padding.

        Args:
            windows (tensor):
                input tokens with [batch_size * num_windows, window_size, window_size, channel].
            window_size (int):
                window size.
            padding_shape (Tuple):
                padded height and width (pad_height, pad_width).
            original_shape (Tuple):
                original height and width (height, width) before padding.

        Returns:
            hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
        """
        pad_height, pad_width = padding_shape
        height, width = original_shape
        batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
        hidden_states = windows.reshape(
            batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
        )
        hidden_states = (
            hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
        )

        hidden_states = hidden_states[:, :height, :width, :].contiguous()
        return hidden_states

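    # Illustrative example (not from the upstream file): for a 1024x1024 input with 16x16 patches the token grid
    # is 64x64; with window_size=14 it is padded to 70x70, split into (70/14)^2 = 25 windows per image, attended
    # per window, and then unpartitioned back to 64x64 by cropping the padding.
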
    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.FloatTensor]:
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        # Window partition
        if self.window_size > 0:
            height, width = hidden_states.shape[1], hidden_states.shape[2]
            hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)

        hidden_states, attn_weights = self.attn(
            hidden_states=hidden_states,
        )
        # Reverse window partition
        if self.window_size > 0:
            hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))

        hidden_states = residual + hidden_states
        layernorm_output = self.layer_norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(layernorm_output)
        return hidden_states


class SamVisionNeck(nn.Module):
    def __init__(self, config: SamVisionConfig):
        super().__init__()
        self.config = config

        self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
        self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")
        self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
        self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first")

    def forward(self, hidden_states):
        hidden_states = hidden_states.permute(0, 3, 1, 2)
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.layer_norm1(hidden_states)

        hidden_states = self.conv2(hidden_states)
        hidden_states = self.layer_norm2(hidden_states)
        return hidden_states

+
@auto_docstring
|
| 1007 |
+
class SamPreTrainedModel(PreTrainedModel):
|
| 1008 |
+
config: SamConfig
|
| 1009 |
+
base_model_prefix = "sam"
|
| 1010 |
+
main_input_name = "pixel_values"
|
| 1011 |
+
_no_split_modules = ["SamVisionAttention"]
|
| 1012 |
+
supports_gradient_checkpointing = True
|
| 1013 |
+
_supports_sdpa = True
|
| 1014 |
+
|
| 1015 |
+
def _init_weights(self, module: nn.Module):
|
| 1016 |
+
super()._init_weights(module)
|
| 1017 |
+
if isinstance(module, SamVisionAttention):
|
| 1018 |
+
if module.use_rel_pos:
|
| 1019 |
+
module.rel_pos_h.data.zero_()
|
| 1020 |
+
module.rel_pos_w.data.zero_()
|
| 1021 |
+
elif isinstance(module, SamVisionEncoder):
|
| 1022 |
+
if self.config.use_abs_pos:
|
| 1023 |
+
module.pos_embed.data.zero_()
|
| 1024 |
+
|
| 1025 |
+
|
| 1026 |
+
class SamVisionEncoder(SamPreTrainedModel):
|
| 1027 |
+
_can_record_outputs = {"hidden_states": SamVisionLayer, "attentions": SamVisionAttention}
|
| 1028 |
+
|
| 1029 |
+
def __init__(self, config: SamVisionConfig):
|
| 1030 |
+
super().__init__(config)
|
| 1031 |
+
self.config = config
|
| 1032 |
+
self.image_size = config.image_size
|
| 1033 |
+
self.patch_embed = SamPatchEmbeddings(config)
|
| 1034 |
+
|
| 1035 |
+
self.pos_embed = None
|
| 1036 |
+
if config.use_abs_pos:
|
| 1037 |
+
# Initialize absolute positional embedding with pretrain image size.
|
| 1038 |
+
self.pos_embed = nn.Parameter(
|
| 1039 |
+
torch.zeros(
|
| 1040 |
+
1,
|
| 1041 |
+
config.image_size // config.patch_size,
|
| 1042 |
+
config.image_size // config.patch_size,
|
| 1043 |
+
config.hidden_size,
|
| 1044 |
+
)
|
| 1045 |
+
)
|
| 1046 |
+
|
| 1047 |
+
self.layers = nn.ModuleList()
|
| 1048 |
+
for i in range(config.num_hidden_layers):
|
| 1049 |
+
layer = SamVisionLayer(
|
| 1050 |
+
config,
|
| 1051 |
+
window_size=config.window_size if i not in config.global_attn_indexes else 0,
|
| 1052 |
+
)
|
| 1053 |
+
self.layers.append(layer)
|
| 1054 |
+
|
| 1055 |
+
self.neck = SamVisionNeck(config)
|
| 1056 |
+
|
| 1057 |
+
self.gradient_checkpointing = False
|
| 1058 |
+
|
| 1059 |
+
def get_input_embeddings(self):
|
| 1060 |
+
return self.patch_embed
|
| 1061 |
+
|
| 1062 |
+
@check_model_inputs
|
| 1063 |
+
def forward(
|
| 1064 |
+
self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs]
|
| 1065 |
+
) -> SamVisionEncoderOutput:
|
| 1066 |
+
if pixel_values is None:
|
| 1067 |
+
raise ValueError("You have to specify pixel_values")
|
| 1068 |
+
|
| 1069 |
+
hidden_states = self.patch_embed(pixel_values)
|
| 1070 |
+
if self.pos_embed is not None:
|
| 1071 |
+
hidden_states = hidden_states + self.pos_embed
|
| 1072 |
+
for layer_module in self.layers:
|
| 1073 |
+
hidden_states = layer_module(hidden_states)
|
| 1074 |
+
hidden_states = self.neck(hidden_states)
|
| 1075 |
+
return SamVisionEncoderOutput(
|
| 1076 |
+
last_hidden_state=hidden_states,
|
| 1077 |
+
)
|
| 1078 |
+
|
| 1079 |
+
|
| 1080 |
+
@auto_docstring(
|
| 1081 |
+
custom_intro="""
|
| 1082 |
+
The vision model from Sam without any head or projection on top.
|
| 1083 |
+
"""
|
| 1084 |
+
)
|
| 1085 |
+
class SamVisionModel(SamPreTrainedModel):
|
| 1086 |
+
config: SamVisionConfig
|
| 1087 |
+
main_input_name = "pixel_values"
|
| 1088 |
+
|
| 1089 |
+
def __init__(self, config: SamVisionConfig):
|
| 1090 |
+
super().__init__(config)
|
| 1091 |
+
self.vision_encoder = SamVisionEncoder(config)
|
| 1092 |
+
self.post_init()
|
| 1093 |
+
|
| 1094 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 1095 |
+
return self.vision_encoder.patch_embed
|
| 1096 |
+
|
| 1097 |
+
@auto_docstring
|
| 1098 |
+
def forward(
|
| 1099 |
+
self,
|
| 1100 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1101 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1102 |
+
) -> Union[tuple, SamVisionEncoderOutput]:
|
| 1103 |
+
return self.vision_encoder(pixel_values, **kwargs)
|
| 1104 |
+
|
| 1105 |
+
|
@auto_docstring(
    custom_intro="""
    Segment Anything Model (SAM) for generating segmentation masks, given an input image and
    input points and labels, boxes, or masks.
    """
)
class SamModel(SamPreTrainedModel):
    _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
    # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
    _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
    _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(SamTwoWayAttentionBlock, index=2)}

    def __init__(self, config: SamConfig):
        super().__init__(config)
        self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)

        self.vision_encoder = SamVisionEncoder(config.vision_config)
        self.prompt_encoder = SamPromptEncoder(config)
        # The module using it is not a PreTrainedModel subclass so we need this
        config.mask_decoder_config._attn_implementation = config._attn_implementation
        self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)

        self.post_init()

    def _tie_weights(self):
        self.prompt_encoder.shared_embedding.positional_embedding.data = (
            self.shared_image_embedding.positional_embedding.data
        )

    def get_input_embeddings(self):
        return self.vision_encoder.get_input_embeddings()

    def get_image_wide_positional_embeddings(self):
        size = self.config.prompt_encoder_config.image_embedding_size
        target_device = self.shared_image_embedding.positional_embedding.device
        target_dtype = self.shared_image_embedding.positional_embedding.dtype
        grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / size
        x_embed = x_embed / size

        positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
        return positional_embedding.permute(2, 0, 1).unsqueeze(0)  # channel x height x width

    @torch.no_grad()
    def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs]):
        r"""
        Returns the image embeddings by passing the pixel values through the vision encoder.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Input pixel values
        """
        vision_output = self.vision_encoder(
            pixel_values,
            **kwargs,
        )
        image_embeddings = vision_output[0]
        return image_embeddings

    @torch.no_grad()
    def get_prompt_embeddings(
        self,
        input_points: Optional[torch.FloatTensor] = None,
        input_labels: Optional[torch.LongTensor] = None,
        input_boxes: Optional[torch.FloatTensor] = None,
        input_masks: Optional[torch.LongTensor] = None,
    ):
        r"""
        Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.

        Args:
            input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
                Optional input points for the prompt encoder. The padding of the points is automatically done by the
                processor. `point_batch_size` refers to the number of masks that we want the model to predict per
                point. The model will output `point_batch_size` times 3 masks in total.
            input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
                Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
                processor, or can be fed by the user.
            input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
                processor. Users can also manually pass the input boxes.
            input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
                Optional input masks for the prompt encoder.
        """
        prompt_output = self.prompt_encoder(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )
        return prompt_output

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_points: Optional[torch.FloatTensor] = None,
        input_labels: Optional[torch.LongTensor] = None,
        input_boxes: Optional[torch.FloatTensor] = None,
        input_masks: Optional[torch.LongTensor] = None,
        image_embeddings: Optional[torch.FloatTensor] = None,
        multimask_output: bool = True,
        attention_similarity: Optional[torch.FloatTensor] = None,
        target_embedding: Optional[torch.FloatTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SamImageSegmentationOutput:
        r"""
        input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
            Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
            better results. The points can be obtained by passing a list of list of list to the processor that will
            create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
            second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
            per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
            multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
            coordinates of the point. If a different number of points is passed either for each image, or for each
            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
            computation of the embedding will be skipped for these points using the labels.
        input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
            Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
            official implementation, there are 3 types of labels

            - `1`: the point is a point that contains the object of interest
            - `0`: the point is a point that does not contain the object of interest
            - `-1`: the point corresponds to the background

            We added the label:

            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder

            The padding labels should be automatically done by the processor.
        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
            Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
            much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
            that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
            size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
            In the order (`x1`, `y1`, `x2`, `y2`):

            - `x1`: the x coordinate of the top left point of the input box
            - `y1`: the y coordinate of the top left point of the input box
            - `x2`: the x coordinate of the bottom right point of the input box
            - `y2`: the y coordinate of the bottom right point of the input box
        input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
            SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
            generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be
            manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
            Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
            efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
            method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
        multimask_output (`bool`, *optional*):
            In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
            bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
            "best" mask, by specifying `multimask_output=False`.
        attention_similarity (`torch.FloatTensor`, *optional*):
            Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
            model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
        target_embedding (`torch.FloatTensor`, *optional*):
            Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
            the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoModel, AutoProcessor

        >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
        >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")

        >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
        >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
        >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
        >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")

        >>> # Get segmentation mask
        >>> outputs = model(**inputs)

        >>> # Postprocess masks
        >>> masks = processor.post_process_masks(
        ...     outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
        ... )
        ```
        """
        if pixel_values is None and image_embeddings is None:
            raise ValueError("Either pixel_values or image_embeddings must be provided.")

        if pixel_values is not None and image_embeddings is not None:
            raise ValueError("Only one of pixel_values and image_embeddings can be provided.")

        if input_points is not None and len(input_points.shape) != 4:
            raise ValueError(
                "The input_points must be a 4D tensor of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`,",
                f" got {input_points.shape}.",
            )
        if input_boxes is not None and len(input_boxes.shape) != 3:
            raise ValueError(
                "The input_boxes must be a 3D tensor of shape `batch_size`, `nb_boxes`, `4`,",
                f" got {input_boxes.shape}.",
            )
        if input_points is not None and input_boxes is not None:
            point_batch_size = input_points.shape[1]
            box_batch_size = input_boxes.shape[1]
            if point_batch_size != box_batch_size:
                raise ValueError(
                    f"You should provide as many bounding boxes as input points per box. Got {point_batch_size} and {box_batch_size}."
                )

        image_positional_embeddings = self.get_image_wide_positional_embeddings()
        # repeat with batch size
        batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)

        vision_attentions = None
        vision_hidden_states = None

        if pixel_values is not None:
            vision_outputs: SamVisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs)
            image_embeddings = vision_outputs.last_hidden_state
            vision_hidden_states = vision_outputs.hidden_states
            vision_attentions = vision_outputs.attentions

        if input_points is not None and input_labels is None:
            input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

        if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
            raise ValueError(
                "The batch size of the image embeddings and the input points must be the same. ",
                f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively.",
                " if you want to pass multiple points for the same image, make sure that you passed ",
                " input_points of shape (batch_size, point_batch_size, num_points_per_image, 2) and ",
                " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
            )

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )

        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
        )

        return SamImageSegmentationOutput(
            iou_scores=iou_predictions,
            pred_masks=low_res_masks,
            vision_hidden_states=vision_hidden_states,
            vision_attentions=vision_attentions,
        )


__all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]
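
# Illustrative usage note (not from the upstream file): as described in the `forward` docstring, image embeddings
# can be precomputed once and reused across many prompts, e.g.:
#     image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
#     outputs = model(input_points=inputs["input_points"], image_embeddings=image_embeddings)
# assuming `inputs` was produced by the SAM processor as in the example above.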