cranky-coder08 commited on
Commit
db52997
·
verified ·
1 Parent(s): 86f2519

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. phivenv/Lib/site-packages/transformers/models/__pycache__/__init__.cpython-39.pyc +0 -0
  2. phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/__init__.cpython-39.pyc +0 -0
  3. phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/modeling_flax_roformer.cpython-39.pyc +0 -0
  4. phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/modeling_roformer.cpython-39.pyc +0 -0
  5. phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/modeling_tf_roformer.cpython-39.pyc +0 -0
  6. phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/tokenization_roformer.cpython-39.pyc +0 -0
  7. phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/tokenization_roformer_fast.cpython-39.pyc +0 -0
  8. phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/tokenization_utils.cpython-39.pyc +0 -0
  9. phivenv/Lib/site-packages/transformers/models/rt_detr/__init__.py +33 -0
  10. phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/__init__.cpython-39.pyc +0 -0
  11. phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/configuration_rt_detr.cpython-39.pyc +0 -0
  12. phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/configuration_rt_detr_resnet.cpython-39.pyc +0 -0
  13. phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/image_processing_rt_detr.cpython-39.pyc +0 -0
  14. phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/image_processing_rt_detr_fast.cpython-39.pyc +0 -0
  15. phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/modeling_rt_detr.cpython-39.pyc +0 -0
  16. phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/modeling_rt_detr_resnet.cpython-39.pyc +0 -0
  17. phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/modular_rt_detr.cpython-39.pyc +0 -0
  18. phivenv/Lib/site-packages/transformers/models/rt_detr/configuration_rt_detr.py +372 -0
  19. phivenv/Lib/site-packages/transformers/models/rt_detr/configuration_rt_detr_resnet.py +114 -0
  20. phivenv/Lib/site-packages/transformers/models/rt_detr/image_processing_rt_detr.py +1103 -0
  21. phivenv/Lib/site-packages/transformers/models/rt_detr/image_processing_rt_detr_fast.py +590 -0
  22. phivenv/Lib/site-packages/transformers/models/rt_detr/modeling_rt_detr.py +2013 -0
  23. phivenv/Lib/site-packages/transformers/models/rt_detr/modeling_rt_detr_resnet.py +399 -0
  24. phivenv/Lib/site-packages/transformers/models/rt_detr/modular_rt_detr.py +365 -0
  25. phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__init__.py +29 -0
  26. phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/__init__.cpython-39.pyc +0 -0
  27. phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/configuration_rt_detr_v2.cpython-39.pyc +0 -0
  28. phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/modeling_rt_detr_v2.cpython-39.pyc +0 -0
  29. phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/modular_rt_detr_v2.cpython-39.pyc +0 -0
  30. phivenv/Lib/site-packages/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +387 -0
  31. phivenv/Lib/site-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +1998 -0
  32. phivenv/Lib/site-packages/transformers/models/rt_detr_v2/modular_rt_detr_v2.py +636 -0
  33. phivenv/Lib/site-packages/transformers/models/rwkv/__init__.py +27 -0
  34. phivenv/Lib/site-packages/transformers/models/rwkv/__pycache__/__init__.cpython-39.pyc +0 -0
  35. phivenv/Lib/site-packages/transformers/models/rwkv/__pycache__/configuration_rwkv.cpython-39.pyc +0 -0
  36. phivenv/Lib/site-packages/transformers/models/rwkv/__pycache__/modeling_rwkv.cpython-39.pyc +0 -0
  37. phivenv/Lib/site-packages/transformers/models/rwkv/configuration_rwkv.py +120 -0
  38. phivenv/Lib/site-packages/transformers/models/rwkv/modeling_rwkv.py +798 -0
  39. phivenv/Lib/site-packages/transformers/models/sam/__init__.py +31 -0
  40. phivenv/Lib/site-packages/transformers/models/sam/__pycache__/__init__.cpython-39.pyc +0 -0
  41. phivenv/Lib/site-packages/transformers/models/sam/__pycache__/configuration_sam.cpython-39.pyc +0 -0
  42. phivenv/Lib/site-packages/transformers/models/sam/__pycache__/image_processing_sam.cpython-39.pyc +0 -0
  43. phivenv/Lib/site-packages/transformers/models/sam/__pycache__/image_processing_sam_fast.cpython-39.pyc +0 -0
  44. phivenv/Lib/site-packages/transformers/models/sam/__pycache__/modeling_sam.cpython-39.pyc +0 -0
  45. phivenv/Lib/site-packages/transformers/models/sam/__pycache__/modeling_tf_sam.cpython-39.pyc +0 -0
  46. phivenv/Lib/site-packages/transformers/models/sam/__pycache__/processing_sam.cpython-39.pyc +0 -0
  47. phivenv/Lib/site-packages/transformers/models/sam/configuration_sam.py +337 -0
  48. phivenv/Lib/site-packages/transformers/models/sam/image_processing_sam.py +1499 -0
  49. phivenv/Lib/site-packages/transformers/models/sam/image_processing_sam_fast.py +829 -0
  50. phivenv/Lib/site-packages/transformers/models/sam/modeling_sam.py +1368 -0
phivenv/Lib/site-packages/transformers/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (7.91 kB). View file
 
phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (671 Bytes). View file
 
phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/modeling_flax_roformer.cpython-39.pyc ADDED
Binary file (28.7 kB). View file
 
phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/modeling_roformer.cpython-39.pyc ADDED
Binary file (41.7 kB). View file
 
phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/modeling_tf_roformer.cpython-39.pyc ADDED
Binary file (47 kB). View file
 
phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/tokenization_roformer.cpython-39.pyc ADDED
Binary file (16.9 kB). View file
 
phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/tokenization_roformer_fast.cpython-39.pyc ADDED
Binary file (4.66 kB). View file
 
phivenv/Lib/site-packages/transformers/models/roformer/__pycache__/tokenization_utils.cpython-39.pyc ADDED
Binary file (1.62 kB). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import TYPE_CHECKING
17
+
18
+ from ...utils import _LazyModule
19
+ from ...utils.import_utils import define_import_structure
20
+
21
+
22
+ if TYPE_CHECKING:
23
+ from .configuration_rt_detr import *
24
+ from .configuration_rt_detr_resnet import *
25
+ from .image_processing_rt_detr import *
26
+ from .image_processing_rt_detr_fast import *
27
+ from .modeling_rt_detr import *
28
+ from .modeling_rt_detr_resnet import *
29
+ else:
30
+ import sys
31
+
32
+ _file = globals()["__file__"]
33
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (683 Bytes). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/configuration_rt_detr.cpython-39.pyc ADDED
Binary file (15.1 kB). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/configuration_rt_detr_resnet.cpython-39.pyc ADDED
Binary file (5.02 kB). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/image_processing_rt_detr.cpython-39.pyc ADDED
Binary file (39.6 kB). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/image_processing_rt_detr_fast.cpython-39.pyc ADDED
Binary file (18.4 kB). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/modeling_rt_detr.cpython-39.pyc ADDED
Binary file (64.1 kB). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/modeling_rt_detr_resnet.cpython-39.pyc ADDED
Binary file (11.9 kB). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr/__pycache__/modular_rt_detr.cpython-39.pyc ADDED
Binary file (11.3 kB). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr/configuration_rt_detr.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """RT-DETR model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+ from ...utils.backbone_utils import verify_backbone_config_arguments
20
+ from ..auto import CONFIG_MAPPING
21
+ from .configuration_rt_detr_resnet import RTDetrResNetConfig
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class RTDetrConfig(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`RTDetrModel`]. It is used to instantiate a
30
+ RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
31
+ with the defaults will yield a similar configuration to that of the RT-DETR
32
+ [PekingU/rtdetr_r50vd](https://huggingface.co/PekingU/rtdetr_r50vd) architecture.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+ Args:
38
+ initializer_range (`float`, *optional*, defaults to 0.01):
39
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
40
+ initializer_bias_prior_prob (`float`, *optional*):
41
+ The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
42
+ If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
43
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
44
+ The epsilon used by the layer normalization layers.
45
+ batch_norm_eps (`float`, *optional*, defaults to 1e-05):
46
+ The epsilon used by the batch normalization layers.
47
+ backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
48
+ The configuration of the backbone model.
49
+ backbone (`str`, *optional*):
50
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
51
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
52
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
53
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
54
+ Whether to use pretrained weights for the backbone.
55
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
56
+ Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
57
+ library.
58
+ freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
59
+ Whether to freeze the batch normalization layers in the backbone.
60
+ backbone_kwargs (`dict`, *optional*):
61
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
62
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
63
+ encoder_hidden_dim (`int`, *optional*, defaults to 256):
64
+ Dimension of the layers in hybrid encoder.
65
+ encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
66
+ Multi level features input for encoder.
67
+ feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
68
+ Strides used in each feature map.
69
+ encoder_layers (`int`, *optional*, defaults to 1):
70
+ Total of layers to be used by the encoder.
71
+ encoder_ffn_dim (`int`, *optional*, defaults to 1024):
72
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
73
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
74
+ Number of attention heads for each attention layer in the Transformer encoder.
75
+ dropout (`float`, *optional*, defaults to 0.0):
76
+ The ratio for all dropout layers.
77
+ activation_dropout (`float`, *optional*, defaults to 0.0):
78
+ The dropout ratio for activations inside the fully connected layer.
79
+ encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
80
+ Indexes of the projected layers to be used in the encoder.
81
+ positional_encoding_temperature (`int`, *optional*, defaults to 10000):
82
+ The temperature parameter used to create the positional encodings.
83
+ encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
84
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
85
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
86
+ activation_function (`str`, *optional*, defaults to `"silu"`):
87
+ The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
88
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
89
+ eval_size (`tuple[int, int]`, *optional*):
90
+ Height and width used to computes the effective height and width of the position embeddings after taking
91
+ into account the stride.
92
+ normalize_before (`bool`, *optional*, defaults to `False`):
93
+ Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
94
+ feed-forward modules.
95
+ hidden_expansion (`float`, *optional*, defaults to 1.0):
96
+ Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
97
+ d_model (`int`, *optional*, defaults to 256):
98
+ Dimension of the layers exclude hybrid encoder.
99
+ num_queries (`int`, *optional*, defaults to 300):
100
+ Number of object queries.
101
+ decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
102
+ Multi level features dimension for decoder
103
+ decoder_ffn_dim (`int`, *optional*, defaults to 1024):
104
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
105
+ num_feature_levels (`int`, *optional*, defaults to 3):
106
+ The number of input feature levels.
107
+ decoder_n_points (`int`, *optional*, defaults to 4):
108
+ The number of sampled keys in each feature level for each attention head in the decoder.
109
+ decoder_layers (`int`, *optional*, defaults to 6):
110
+ Number of decoder layers.
111
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
112
+ Number of attention heads for each attention layer in the Transformer decoder.
113
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
114
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
115
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
116
+ attention_dropout (`float`, *optional*, defaults to 0.0):
117
+ The dropout ratio for the attention probabilities.
118
+ num_denoising (`int`, *optional*, defaults to 100):
119
+ The total number of denoising tasks or queries to be used for contrastive denoising.
120
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
121
+ The fraction of denoising labels to which random noise should be added.
122
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
123
+ Scale or magnitude of noise to be added to the bounding boxes.
124
+ learn_initial_query (`bool`, *optional*, defaults to `False`):
125
+ Indicates whether the initial query embeddings for the decoder should be learned during training
126
+ anchor_image_size (`tuple[int, int]`, *optional*):
127
+ Height and width of the input image used during evaluation to generate the bounding box anchors. If None, automatic generate anchor is applied.
128
+ disable_custom_kernels (`bool`, *optional*, defaults to `True`):
129
+ Whether to disable custom kernels.
130
+ with_box_refine (`bool`, *optional*, defaults to `True`):
131
+ Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
132
+ based on the predictions from the previous layer.
133
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
134
+ Whether the architecture has an encoder decoder structure.
135
+ matcher_alpha (`float`, *optional*, defaults to 0.25):
136
+ Parameter alpha used by the Hungarian Matcher.
137
+ matcher_gamma (`float`, *optional*, defaults to 2.0):
138
+ Parameter gamma used by the Hungarian Matcher.
139
+ matcher_class_cost (`float`, *optional*, defaults to 2.0):
140
+ The relative weight of the class loss used by the Hungarian Matcher.
141
+ matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
142
+ The relative weight of the bounding box loss used by the Hungarian Matcher.
143
+ matcher_giou_cost (`float`, *optional*, defaults to 2.0):
144
+ The relative weight of the giou loss of used by the Hungarian Matcher.
145
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
146
+ Parameter informing if focal focal should be used.
147
+ auxiliary_loss (`bool`, *optional*, defaults to `True`):
148
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
149
+ focal_loss_alpha (`float`, *optional*, defaults to 0.75):
150
+ Parameter alpha used to compute the focal loss.
151
+ focal_loss_gamma (`float`, *optional*, defaults to 2.0):
152
+ Parameter gamma used to compute the focal loss.
153
+ weight_loss_vfl (`float`, *optional*, defaults to 1.0):
154
+ Relative weight of the varifocal loss in the object detection loss.
155
+ weight_loss_bbox (`float`, *optional*, defaults to 5.0):
156
+ Relative weight of the L1 bounding box loss in the object detection loss.
157
+ weight_loss_giou (`float`, *optional*, defaults to 2.0):
158
+ Relative weight of the generalized IoU loss in the object detection loss.
159
+ eos_coefficient (`float`, *optional*, defaults to 0.0001):
160
+ Relative classification weight of the 'no-object' class in the object detection loss.
161
+
162
+ Examples:
163
+
164
+ ```python
165
+ >>> from transformers import RTDetrConfig, RTDetrModel
166
+
167
+ >>> # Initializing a RT-DETR configuration
168
+ >>> configuration = RTDetrConfig()
169
+
170
+ >>> # Initializing a model (with random weights) from the configuration
171
+ >>> model = RTDetrModel(configuration)
172
+
173
+ >>> # Accessing the model configuration
174
+ >>> configuration = model.config
175
+ ```"""
176
+
177
+ model_type = "rt_detr"
178
+ layer_types = ["basic", "bottleneck"]
179
+ attribute_map = {
180
+ "hidden_size": "d_model",
181
+ "num_attention_heads": "encoder_attention_heads",
182
+ }
183
+
184
+ def __init__(
185
+ self,
186
+ initializer_range=0.01,
187
+ initializer_bias_prior_prob=None,
188
+ layer_norm_eps=1e-5,
189
+ batch_norm_eps=1e-5,
190
+ # backbone
191
+ backbone_config=None,
192
+ backbone=None,
193
+ use_pretrained_backbone=False,
194
+ use_timm_backbone=False,
195
+ freeze_backbone_batch_norms=True,
196
+ backbone_kwargs=None,
197
+ # encoder HybridEncoder
198
+ encoder_hidden_dim=256,
199
+ encoder_in_channels=[512, 1024, 2048],
200
+ feat_strides=[8, 16, 32],
201
+ encoder_layers=1,
202
+ encoder_ffn_dim=1024,
203
+ encoder_attention_heads=8,
204
+ dropout=0.0,
205
+ activation_dropout=0.0,
206
+ encode_proj_layers=[2],
207
+ positional_encoding_temperature=10000,
208
+ encoder_activation_function="gelu",
209
+ activation_function="silu",
210
+ eval_size=None,
211
+ normalize_before=False,
212
+ hidden_expansion=1.0,
213
+ # decoder RTDetrTransformer
214
+ d_model=256,
215
+ num_queries=300,
216
+ decoder_in_channels=[256, 256, 256],
217
+ decoder_ffn_dim=1024,
218
+ num_feature_levels=3,
219
+ decoder_n_points=4,
220
+ decoder_layers=6,
221
+ decoder_attention_heads=8,
222
+ decoder_activation_function="relu",
223
+ attention_dropout=0.0,
224
+ num_denoising=100,
225
+ label_noise_ratio=0.5,
226
+ box_noise_scale=1.0,
227
+ learn_initial_query=False,
228
+ anchor_image_size=None,
229
+ disable_custom_kernels=True,
230
+ with_box_refine=True,
231
+ is_encoder_decoder=True,
232
+ # Loss
233
+ matcher_alpha=0.25,
234
+ matcher_gamma=2.0,
235
+ matcher_class_cost=2.0,
236
+ matcher_bbox_cost=5.0,
237
+ matcher_giou_cost=2.0,
238
+ use_focal_loss=True,
239
+ auxiliary_loss=True,
240
+ focal_loss_alpha=0.75,
241
+ focal_loss_gamma=2.0,
242
+ weight_loss_vfl=1.0,
243
+ weight_loss_bbox=5.0,
244
+ weight_loss_giou=2.0,
245
+ eos_coefficient=1e-4,
246
+ **kwargs,
247
+ ):
248
+ self.initializer_range = initializer_range
249
+ self.initializer_bias_prior_prob = initializer_bias_prior_prob
250
+ self.layer_norm_eps = layer_norm_eps
251
+ self.batch_norm_eps = batch_norm_eps
252
+ # backbone
253
+ if backbone_config is None and backbone is None:
254
+ logger.info(
255
+ "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetr-ResNet` backbone."
256
+ )
257
+ backbone_config = RTDetrResNetConfig(
258
+ num_channels=3,
259
+ embedding_size=64,
260
+ hidden_sizes=[256, 512, 1024, 2048],
261
+ depths=[3, 4, 6, 3],
262
+ layer_type="bottleneck",
263
+ hidden_act="relu",
264
+ downsample_in_first_stage=False,
265
+ downsample_in_bottleneck=False,
266
+ out_features=None,
267
+ out_indices=[2, 3, 4],
268
+ )
269
+ elif isinstance(backbone_config, dict):
270
+ backbone_model_type = backbone_config.pop("model_type")
271
+ config_class = CONFIG_MAPPING[backbone_model_type]
272
+ backbone_config = config_class.from_dict(backbone_config)
273
+
274
+ verify_backbone_config_arguments(
275
+ use_timm_backbone=use_timm_backbone,
276
+ use_pretrained_backbone=use_pretrained_backbone,
277
+ backbone=backbone,
278
+ backbone_config=backbone_config,
279
+ backbone_kwargs=backbone_kwargs,
280
+ )
281
+
282
+ self.backbone_config = backbone_config
283
+ self.backbone = backbone
284
+ self.use_pretrained_backbone = use_pretrained_backbone
285
+ self.use_timm_backbone = use_timm_backbone
286
+ self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
287
+ self.backbone_kwargs = backbone_kwargs
288
+ # encoder
289
+ self.encoder_hidden_dim = encoder_hidden_dim
290
+ self.encoder_in_channels = encoder_in_channels
291
+ self.feat_strides = feat_strides
292
+ self.encoder_attention_heads = encoder_attention_heads
293
+ self.encoder_ffn_dim = encoder_ffn_dim
294
+ self.dropout = dropout
295
+ self.activation_dropout = activation_dropout
296
+ self.encode_proj_layers = encode_proj_layers
297
+ self.encoder_layers = encoder_layers
298
+ self.positional_encoding_temperature = positional_encoding_temperature
299
+ self.eval_size = eval_size
300
+ self.normalize_before = normalize_before
301
+ self.encoder_activation_function = encoder_activation_function
302
+ self.activation_function = activation_function
303
+ self.hidden_expansion = hidden_expansion
304
+ # decoder
305
+ self.d_model = d_model
306
+ self.num_queries = num_queries
307
+ self.decoder_ffn_dim = decoder_ffn_dim
308
+ self.decoder_in_channels = decoder_in_channels
309
+ self.num_feature_levels = num_feature_levels
310
+ self.decoder_n_points = decoder_n_points
311
+ self.decoder_layers = decoder_layers
312
+ self.decoder_attention_heads = decoder_attention_heads
313
+ self.decoder_activation_function = decoder_activation_function
314
+ self.attention_dropout = attention_dropout
315
+ self.num_denoising = num_denoising
316
+ self.label_noise_ratio = label_noise_ratio
317
+ self.box_noise_scale = box_noise_scale
318
+ self.learn_initial_query = learn_initial_query
319
+ self.anchor_image_size = anchor_image_size
320
+ self.auxiliary_loss = auxiliary_loss
321
+ self.disable_custom_kernels = disable_custom_kernels
322
+ self.with_box_refine = with_box_refine
323
+ # Loss
324
+ self.matcher_alpha = matcher_alpha
325
+ self.matcher_gamma = matcher_gamma
326
+ self.matcher_class_cost = matcher_class_cost
327
+ self.matcher_bbox_cost = matcher_bbox_cost
328
+ self.matcher_giou_cost = matcher_giou_cost
329
+ self.use_focal_loss = use_focal_loss
330
+ self.focal_loss_alpha = focal_loss_alpha
331
+ self.focal_loss_gamma = focal_loss_gamma
332
+ self.weight_loss_vfl = weight_loss_vfl
333
+ self.weight_loss_bbox = weight_loss_bbox
334
+ self.weight_loss_giou = weight_loss_giou
335
+ self.eos_coefficient = eos_coefficient
336
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
337
+
338
+ @property
339
+ def num_attention_heads(self) -> int:
340
+ return self.encoder_attention_heads
341
+
342
+ @property
343
+ def hidden_size(self) -> int:
344
+ return self.d_model
345
+
346
+ @property
347
+ def sub_configs(self):
348
+ return (
349
+ {"backbone_config": type(self.backbone_config)}
350
+ if getattr(self, "backbone_config", None) is not None
351
+ else {}
352
+ )
353
+
354
+ @classmethod
355
+ def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
356
+ """Instantiate a [`RTDetrConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
357
+ configuration.
358
+
359
+ Args:
360
+ backbone_config ([`PretrainedConfig`]):
361
+ The backbone configuration.
362
+
363
+ Returns:
364
+ [`RTDetrConfig`]: An instance of a configuration object
365
+ """
366
+ return cls(
367
+ backbone_config=backbone_config,
368
+ **kwargs,
369
+ )
370
+
371
+
372
+ __all__ = ["RTDetrConfig"]
phivenv/Lib/site-packages/transformers/models/rt_detr/configuration_rt_detr_resnet.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """RT-DETR ResNet model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+ from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class RTDetrResNetConfig(BackboneConfigMixin, PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`RTDetrResnetBackbone`]. It is used to instantiate an
28
+ ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of the ResNet
30
+ [microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+ Args:
36
+ num_channels (`int`, *optional*, defaults to 3):
37
+ The number of input channels.
38
+ embedding_size (`int`, *optional*, defaults to 64):
39
+ Dimensionality (hidden size) for the embedding layer.
40
+ hidden_sizes (`list[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`):
41
+ Dimensionality (hidden size) at each stage.
42
+ depths (`list[int]`, *optional*, defaults to `[3, 4, 6, 3]`):
43
+ Depth (number of layers) for each stage.
44
+ layer_type (`str`, *optional*, defaults to `"bottleneck"`):
45
+ The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or
46
+ `"bottleneck"` (used for larger models like resnet-50 and above).
47
+ hidden_act (`str`, *optional*, defaults to `"relu"`):
48
+ The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
49
+ are supported.
50
+ downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
51
+ If `True`, the first stage will downsample the inputs using a `stride` of 2.
52
+ downsample_in_bottleneck (`bool`, *optional*, defaults to `False`):
53
+ If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2.
54
+ out_features (`list[str]`, *optional*):
55
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
56
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
57
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
58
+ same order as defined in the `stage_names` attribute.
59
+ out_indices (`list[int]`, *optional*):
60
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
61
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
62
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
63
+ same order as defined in the `stage_names` attribute.
64
+
65
+ Example:
66
+ ```python
67
+ >>> from transformers import RTDetrResNetConfig, RTDetrResnetBackbone
68
+
69
+ >>> # Initializing a ResNet resnet-50 style configuration
70
+ >>> configuration = RTDetrResNetConfig()
71
+
72
+ >>> # Initializing a model (with random weights) from the resnet-50 style configuration
73
+ >>> model = RTDetrResnetBackbone(configuration)
74
+
75
+ >>> # Accessing the model configuration
76
+ >>> configuration = model.config
77
+ ```
78
+ """
79
+
80
+ model_type = "rt_detr_resnet"
81
+ layer_types = ["basic", "bottleneck"]
82
+
83
+ def __init__(
84
+ self,
85
+ num_channels=3,
86
+ embedding_size=64,
87
+ hidden_sizes=[256, 512, 1024, 2048],
88
+ depths=[3, 4, 6, 3],
89
+ layer_type="bottleneck",
90
+ hidden_act="relu",
91
+ downsample_in_first_stage=False,
92
+ downsample_in_bottleneck=False,
93
+ out_features=None,
94
+ out_indices=None,
95
+ **kwargs,
96
+ ):
97
+ super().__init__(**kwargs)
98
+ if layer_type not in self.layer_types:
99
+ raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}")
100
+ self.num_channels = num_channels
101
+ self.embedding_size = embedding_size
102
+ self.hidden_sizes = hidden_sizes
103
+ self.depths = depths
104
+ self.layer_type = layer_type
105
+ self.hidden_act = hidden_act
106
+ self.downsample_in_first_stage = downsample_in_first_stage
107
+ self.downsample_in_bottleneck = downsample_in_bottleneck
108
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
109
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
110
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
111
+ )
112
+
113
+
114
+ __all__ = ["RTDetrResNetConfig"]
phivenv/Lib/site-packages/transformers/models/rt_detr/image_processing_rt_detr.py ADDED
@@ -0,0 +1,1103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for RT-DETR."""
16
+
17
+ import pathlib
18
+ from collections.abc import Iterable
19
+ from typing import Any, Callable, Optional, Union
20
+
21
+ import numpy as np
22
+
23
+ from ...feature_extraction_utils import BatchFeature
24
+ from ...image_processing_utils import BaseImageProcessor, get_size_dict
25
+ from ...image_transforms import (
26
+ PaddingMode,
27
+ center_to_corners_format,
28
+ corners_to_center_format,
29
+ pad,
30
+ rescale,
31
+ resize,
32
+ to_channel_dimension_format,
33
+ )
34
+ from ...image_utils import (
35
+ IMAGENET_DEFAULT_MEAN,
36
+ IMAGENET_DEFAULT_STD,
37
+ AnnotationFormat,
38
+ AnnotationType,
39
+ ChannelDimension,
40
+ ImageInput,
41
+ PILImageResampling,
42
+ get_image_size,
43
+ infer_channel_dimension_format,
44
+ is_scaled_image,
45
+ make_list_of_images,
46
+ to_numpy_array,
47
+ valid_images,
48
+ validate_annotations,
49
+ validate_preprocess_arguments,
50
+ )
51
+ from ...utils import (
52
+ filter_out_non_signature_kwargs,
53
+ is_flax_available,
54
+ is_jax_tensor,
55
+ is_tf_available,
56
+ is_tf_tensor,
57
+ is_torch_available,
58
+ is_torch_tensor,
59
+ logging,
60
+ requires_backends,
61
+ )
62
+ from ...utils.generic import TensorType
63
+
64
+
65
+ if is_torch_available():
66
+ import torch
67
+
68
+
69
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
70
+
71
+ SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,)
72
+
73
+
74
+ # Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio
75
+ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]:
76
+ """
77
+ Computes the output image size given the input image size and the desired output size.
78
+
79
+ Args:
80
+ image_size (`tuple[int, int]`):
81
+ The input image size.
82
+ size (`int`):
83
+ The desired output size.
84
+ max_size (`int`, *optional*):
85
+ The maximum allowed output size.
86
+ """
87
+ height, width = image_size
88
+ raw_size = None
89
+ if max_size is not None:
90
+ min_original_size = float(min((height, width)))
91
+ max_original_size = float(max((height, width)))
92
+ if max_original_size / min_original_size * size > max_size:
93
+ raw_size = max_size * min_original_size / max_original_size
94
+ size = int(round(raw_size))
95
+
96
+ if (height <= width and height == size) or (width <= height and width == size):
97
+ oh, ow = height, width
98
+ elif width < height:
99
+ ow = size
100
+ if max_size is not None and raw_size is not None:
101
+ oh = int(raw_size * height / width)
102
+ else:
103
+ oh = int(size * height / width)
104
+ else:
105
+ oh = size
106
+ if max_size is not None and raw_size is not None:
107
+ ow = int(raw_size * width / height)
108
+ else:
109
+ ow = int(size * width / height)
110
+
111
+ return (oh, ow)
112
+
113
+
114
+ # Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
115
+ def get_resize_output_image_size(
116
+ input_image: np.ndarray,
117
+ size: Union[int, tuple[int, int], list[int]],
118
+ max_size: Optional[int] = None,
119
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
120
+ ) -> tuple[int, int]:
121
+ """
122
+ Computes the output image size given the input image size and the desired output size. If the desired output size
123
+ is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output
124
+ image size is computed by keeping the aspect ratio of the input image size.
125
+
126
+ Args:
127
+ input_image (`np.ndarray`):
128
+ The image to resize.
129
+ size (`int` or `tuple[int, int]` or `list[int]`):
130
+ The desired output size.
131
+ max_size (`int`, *optional*):
132
+ The maximum allowed output size.
133
+ input_data_format (`ChannelDimension` or `str`, *optional*):
134
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
135
+ """
136
+ image_size = get_image_size(input_image, input_data_format)
137
+ if isinstance(size, (list, tuple)):
138
+ return size
139
+
140
+ return get_size_with_aspect_ratio(image_size, size, max_size)
141
+
142
+
143
+ # Copied from transformers.models.detr.image_processing_detr.get_image_size_for_max_height_width
144
+ def get_image_size_for_max_height_width(
145
+ input_image: np.ndarray,
146
+ max_height: int,
147
+ max_width: int,
148
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
149
+ ) -> tuple[int, int]:
150
+ """
151
+ Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
152
+ Important, even if image_height < max_height and image_width < max_width, the image will be resized
153
+ to at least one of the edges be equal to max_height or max_width.
154
+ For example:
155
+ - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
156
+ - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
157
+ Args:
158
+ input_image (`np.ndarray`):
159
+ The image to resize.
160
+ max_height (`int`):
161
+ The maximum allowed height.
162
+ max_width (`int`):
163
+ The maximum allowed width.
164
+ input_data_format (`ChannelDimension` or `str`, *optional*):
165
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
166
+ """
167
+ image_size = get_image_size(input_image, input_data_format)
168
+ height, width = image_size
169
+ height_scale = max_height / height
170
+ width_scale = max_width / width
171
+ min_scale = min(height_scale, width_scale)
172
+ new_height = int(height * min_scale)
173
+ new_width = int(width * min_scale)
174
+ return new_height, new_width
175
+
176
+
177
+ # Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn
178
+ def get_numpy_to_framework_fn(arr) -> Callable:
179
+ """
180
+ Returns a function that converts a numpy array to the framework of the input array.
181
+
182
+ Args:
183
+ arr (`np.ndarray`): The array to convert.
184
+ """
185
+ if isinstance(arr, np.ndarray):
186
+ return np.array
187
+ if is_tf_available() and is_tf_tensor(arr):
188
+ import tensorflow as tf
189
+
190
+ return tf.convert_to_tensor
191
+ if is_torch_available() and is_torch_tensor(arr):
192
+ import torch
193
+
194
+ return torch.tensor
195
+ if is_flax_available() and is_jax_tensor(arr):
196
+ import jax.numpy as jnp
197
+
198
+ return jnp.array
199
+ raise ValueError(f"Cannot convert arrays of type {type(arr)}")
200
+
201
+
202
+ # Copied from transformers.models.detr.image_processing_detr.safe_squeeze
203
+ def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
204
+ """
205
+ Squeezes an array, but only if the axis specified has dim 1.
206
+ """
207
+ if axis is None:
208
+ return arr.squeeze()
209
+
210
+ try:
211
+ return arr.squeeze(axis=axis)
212
+ except ValueError:
213
+ return arr
214
+
215
+
216
+ # Copied from transformers.models.detr.image_processing_detr.normalize_annotation
217
+ def normalize_annotation(annotation: dict, image_size: tuple[int, int]) -> dict:
218
+ image_height, image_width = image_size
219
+ norm_annotation = {}
220
+ for key, value in annotation.items():
221
+ if key == "boxes":
222
+ boxes = value
223
+ boxes = corners_to_center_format(boxes)
224
+ boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32)
225
+ norm_annotation[key] = boxes
226
+ else:
227
+ norm_annotation[key] = value
228
+ return norm_annotation
229
+
230
+
231
+ # Copied from transformers.models.detr.image_processing_detr.max_across_indices
232
+ def max_across_indices(values: Iterable[Any]) -> list[Any]:
233
+ """
234
+ Return the maximum value across all indices of an iterable of values.
235
+ """
236
+ return [max(values_i) for values_i in zip(*values)]
237
+
238
+
239
+ # Copied from transformers.models.detr.image_processing_detr.get_max_height_width
240
+ def get_max_height_width(
241
+ images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
242
+ ) -> list[int]:
243
+ """
244
+ Get the maximum height and width across all images in a batch.
245
+ """
246
+ if input_data_format is None:
247
+ input_data_format = infer_channel_dimension_format(images[0])
248
+
249
+ if input_data_format == ChannelDimension.FIRST:
250
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
251
+ elif input_data_format == ChannelDimension.LAST:
252
+ max_height, max_width, _ = max_across_indices([img.shape for img in images])
253
+ else:
254
+ raise ValueError(f"Invalid channel dimension format: {input_data_format}")
255
+ return (max_height, max_width)
256
+
257
+
258
+ # Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
259
+ def make_pixel_mask(
260
+ image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
261
+ ) -> np.ndarray:
262
+ """
263
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
264
+
265
+ Args:
266
+ image (`np.ndarray`):
267
+ Image to make the pixel mask for.
268
+ output_size (`tuple[int, int]`):
269
+ Output size of the mask.
270
+ """
271
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
272
+ mask = np.zeros(output_size, dtype=np.int64)
273
+ mask[:input_height, :input_width] = 1
274
+ return mask
275
+
276
+
277
+ def prepare_coco_detection_annotation(
278
+ image,
279
+ target,
280
+ return_segmentation_masks: bool = False,
281
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
282
+ ):
283
+ """
284
+ Convert the target in COCO format into the format expected by RTDETR.
285
+ """
286
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
287
+
288
+ image_id = target["image_id"]
289
+ image_id = np.asarray([image_id], dtype=np.int64)
290
+
291
+ # Get all COCO annotations for the given image.
292
+ annotations = target["annotations"]
293
+ annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0]
294
+
295
+ classes = [obj["category_id"] for obj in annotations]
296
+ classes = np.asarray(classes, dtype=np.int64)
297
+
298
+ # for conversion to coco api
299
+ area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32)
300
+ iscrowd = np.asarray([obj.get("iscrowd", 0) for obj in annotations], dtype=np.int64)
301
+
302
+ boxes = [obj["bbox"] for obj in annotations]
303
+ # guard against no boxes via resizing
304
+ boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
305
+ boxes[:, 2:] += boxes[:, :2]
306
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
307
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
308
+
309
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
310
+
311
+ new_target = {}
312
+ new_target["image_id"] = image_id
313
+ new_target["class_labels"] = classes[keep]
314
+ new_target["boxes"] = boxes[keep]
315
+ new_target["area"] = area[keep]
316
+ new_target["iscrowd"] = iscrowd[keep]
317
+ new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64)
318
+
319
+ if annotations and "keypoints" in annotations[0]:
320
+ keypoints = [obj["keypoints"] for obj in annotations]
321
+ # Converting the filtered keypoints list to a numpy array
322
+ keypoints = np.asarray(keypoints, dtype=np.float32)
323
+ # Apply the keep mask here to filter the relevant annotations
324
+ keypoints = keypoints[keep]
325
+ num_keypoints = keypoints.shape[0]
326
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
327
+ new_target["keypoints"] = keypoints
328
+
329
+ return new_target
330
+
331
+
332
+ # Copied from transformers.models.detr.image_processing_detr.resize_annotation
333
+ def resize_annotation(
334
+ annotation: dict[str, Any],
335
+ orig_size: tuple[int, int],
336
+ target_size: tuple[int, int],
337
+ threshold: float = 0.5,
338
+ resample: PILImageResampling = PILImageResampling.NEAREST,
339
+ ):
340
+ """
341
+ Resizes an annotation to a target size.
342
+
343
+ Args:
344
+ annotation (`dict[str, Any]`):
345
+ The annotation dictionary.
346
+ orig_size (`tuple[int, int]`):
347
+ The original size of the input image.
348
+ target_size (`tuple[int, int]`):
349
+ The target size of the image, as returned by the preprocessing `resize` step.
350
+ threshold (`float`, *optional*, defaults to 0.5):
351
+ The threshold used to binarize the segmentation masks.
352
+ resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`):
353
+ The resampling filter to use when resizing the masks.
354
+ """
355
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size))
356
+ ratio_height, ratio_width = ratios
357
+
358
+ new_annotation = {}
359
+ new_annotation["size"] = target_size
360
+
361
+ for key, value in annotation.items():
362
+ if key == "boxes":
363
+ boxes = value
364
+ scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
365
+ new_annotation["boxes"] = scaled_boxes
366
+ elif key == "area":
367
+ area = value
368
+ scaled_area = area * (ratio_width * ratio_height)
369
+ new_annotation["area"] = scaled_area
370
+ elif key == "masks":
371
+ masks = value[:, None]
372
+ masks = np.array([resize(mask, target_size, resample=resample) for mask in masks])
373
+ masks = masks.astype(np.float32)
374
+ masks = masks[:, 0] > threshold
375
+ new_annotation["masks"] = masks
376
+ elif key == "size":
377
+ new_annotation["size"] = target_size
378
+ else:
379
+ new_annotation[key] = value
380
+
381
+ return new_annotation
382
+
383
+
384
+ class RTDetrImageProcessor(BaseImageProcessor):
385
+ r"""
386
+ Constructs a RT-DETR image processor.
387
+
388
+ Args:
389
+ format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
390
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
391
+ do_resize (`bool`, *optional*, defaults to `True`):
392
+ Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
393
+ overridden by the `do_resize` parameter in the `preprocess` method.
394
+ size (`dict[str, int]` *optional*, defaults to `{"height": 640, "width": 640}`):
395
+ Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
396
+ in the `preprocess` method. Available options are:
397
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
398
+ Do NOT keep the aspect ratio.
399
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
400
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
401
+ less or equal to `longest_edge`.
402
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
403
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
404
+ `max_width`.
405
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
406
+ Resampling filter to use if resizing the image.
407
+ do_rescale (`bool`, *optional*, defaults to `True`):
408
+ Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
409
+ `do_rescale` parameter in the `preprocess` method.
410
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
411
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
412
+ `preprocess` method.
413
+ Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
414
+ `preprocess` method.
415
+ do_normalize (`bool`, *optional*, defaults to `False`):
416
+ Whether to normalize the image.
417
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
418
+ Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
419
+ channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
420
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
421
+ Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
422
+ for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
423
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
424
+ Controls whether to convert the annotations to the format expected by the DETR model. Converts the
425
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
426
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
427
+ do_pad (`bool`, *optional*, defaults to `False`):
428
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
429
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
430
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
431
+ Otherwise, the image will be padded to the maximum height and width of the batch.
432
+ pad_size (`dict[str, int]`, *optional*):
433
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
434
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
435
+ height and width in the batch.
436
+ """
437
+
438
+ model_input_names = ["pixel_values", "pixel_mask"]
439
+
440
+ def __init__(
441
+ self,
442
+ format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
443
+ do_resize: bool = True,
444
+ size: Optional[dict[str, int]] = None,
445
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
446
+ do_rescale: bool = True,
447
+ rescale_factor: Union[int, float] = 1 / 255,
448
+ do_normalize: bool = False,
449
+ image_mean: Optional[Union[float, list[float]]] = None,
450
+ image_std: Optional[Union[float, list[float]]] = None,
451
+ do_convert_annotations: bool = True,
452
+ do_pad: bool = False,
453
+ pad_size: Optional[dict[str, int]] = None,
454
+ **kwargs,
455
+ ) -> None:
456
+ size = size if size is not None else {"height": 640, "width": 640}
457
+ size = get_size_dict(size, default_to_square=False)
458
+
459
+ if do_convert_annotations is None:
460
+ do_convert_annotations = do_normalize
461
+
462
+ super().__init__(**kwargs)
463
+ self.format = format
464
+ self.do_resize = do_resize
465
+ self.size = size
466
+ self.resample = resample
467
+ self.do_rescale = do_rescale
468
+ self.rescale_factor = rescale_factor
469
+ self.do_normalize = do_normalize
470
+ self.do_convert_annotations = do_convert_annotations
471
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
472
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
473
+ self.do_pad = do_pad
474
+ self.pad_size = pad_size
475
+
476
+ def prepare_annotation(
477
+ self,
478
+ image: np.ndarray,
479
+ target: dict,
480
+ format: Optional[AnnotationFormat] = None,
481
+ return_segmentation_masks: Optional[bool] = None,
482
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
483
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
484
+ ) -> dict:
485
+ """
486
+ Prepare an annotation for feeding into RTDETR model.
487
+ """
488
+ format = format if format is not None else self.format
489
+
490
+ if format == AnnotationFormat.COCO_DETECTION:
491
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
492
+ target = prepare_coco_detection_annotation(
493
+ image, target, return_segmentation_masks, input_data_format=input_data_format
494
+ )
495
+ else:
496
+ raise ValueError(f"Format {format} is not supported.")
497
+ return target
498
+
499
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize
500
+ def resize(
501
+ self,
502
+ image: np.ndarray,
503
+ size: dict[str, int],
504
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
505
+ data_format: Optional[ChannelDimension] = None,
506
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
507
+ **kwargs,
508
+ ) -> np.ndarray:
509
+ """
510
+ Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
511
+ int, smaller edge of the image will be matched to this number.
512
+
513
+ Args:
514
+ image (`np.ndarray`):
515
+ Image to resize.
516
+ size (`dict[str, int]`):
517
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
518
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
519
+ Do NOT keep the aspect ratio.
520
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
521
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
522
+ less or equal to `longest_edge`.
523
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
524
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
525
+ `max_width`.
526
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
527
+ Resampling filter to use if resizing the image.
528
+ data_format (`str` or `ChannelDimension`, *optional*):
529
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
530
+ image is used.
531
+ input_data_format (`ChannelDimension` or `str`, *optional*):
532
+ The channel dimension format of the input image. If not provided, it will be inferred.
533
+ """
534
+ if "max_size" in kwargs:
535
+ logger.warning_once(
536
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
537
+ "Please specify in `size['longest_edge'] instead`.",
538
+ )
539
+ max_size = kwargs.pop("max_size")
540
+ else:
541
+ max_size = None
542
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
543
+ if "shortest_edge" in size and "longest_edge" in size:
544
+ new_size = get_resize_output_image_size(
545
+ image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
546
+ )
547
+ elif "max_height" in size and "max_width" in size:
548
+ new_size = get_image_size_for_max_height_width(
549
+ image, size["max_height"], size["max_width"], input_data_format=input_data_format
550
+ )
551
+ elif "height" in size and "width" in size:
552
+ new_size = (size["height"], size["width"])
553
+ else:
554
+ raise ValueError(
555
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
556
+ f" {size.keys()}."
557
+ )
558
+ image = resize(
559
+ image,
560
+ size=new_size,
561
+ resample=resample,
562
+ data_format=data_format,
563
+ input_data_format=input_data_format,
564
+ **kwargs,
565
+ )
566
+ return image
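# --- Editor's note: illustrative snippet, not part of the uploaded file. ---
# The three `size` dictionaries accepted by `resize` above, with arbitrary example values and
# hypothetical variable names:
size_exact = {"height": 640, "width": 640}                   # exact (height, width), aspect ratio NOT kept
size_bounded = {"shortest_edge": 800, "longest_edge": 1333}  # keep aspect ratio, bounded by both edges
size_capped = {"max_height": 800, "max_width": 800}          # keep aspect ratio, capped per dimension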
567
+
568
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation
569
+ def resize_annotation(
570
+ self,
571
+ annotation,
572
+ orig_size,
573
+ size,
574
+ resample: PILImageResampling = PILImageResampling.NEAREST,
575
+ ) -> dict:
576
+ """
577
+ Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched
578
+ to this number.
579
+ """
580
+ return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample)
581
+
582
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
583
+ def rescale(
584
+ self,
585
+ image: np.ndarray,
586
+ rescale_factor: float,
587
+ data_format: Optional[Union[str, ChannelDimension]] = None,
588
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
589
+ ) -> np.ndarray:
590
+ """
591
+ Rescale the image by the given factor. image = image * rescale_factor.
592
+
593
+ Args:
594
+ image (`np.ndarray`):
595
+ Image to rescale.
596
+ rescale_factor (`float`):
597
+ The value to use for rescaling.
598
+ data_format (`str` or `ChannelDimension`, *optional*):
599
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
600
+ image is used. Can be one of:
601
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
602
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
603
+ input_data_format (`str` or `ChannelDimension`, *optional*):
604
+ The channel dimension format for the input image. If unset, is inferred from the input image. Can be
605
+ one of:
606
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
607
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
608
+ """
609
+ return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
610
+
611
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
612
+ def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
613
+ """
614
+ Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to
615
+ `[center_x, center_y, width, height]` format and from absolute to relative pixel values.
616
+ """
617
+ return normalize_annotation(annotation, image_size=image_size)
618
+
619
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._update_annotation_for_padded_image
620
+ def _update_annotation_for_padded_image(
621
+ self,
622
+ annotation: dict,
623
+ input_image_size: tuple[int, int],
624
+ output_image_size: tuple[int, int],
625
+ padding,
626
+ update_bboxes,
627
+ ) -> dict:
628
+ """
629
+ Update the annotation for a padded image.
630
+ """
631
+ new_annotation = {}
632
+ new_annotation["size"] = output_image_size
633
+
634
+ for key, value in annotation.items():
635
+ if key == "masks":
636
+ masks = value
637
+ masks = pad(
638
+ masks,
639
+ padding,
640
+ mode=PaddingMode.CONSTANT,
641
+ constant_values=0,
642
+ input_data_format=ChannelDimension.FIRST,
643
+ )
644
+ masks = safe_squeeze(masks, 1)
645
+ new_annotation["masks"] = masks
646
+ elif key == "boxes" and update_bboxes:
647
+ boxes = value
648
+ boxes *= np.asarray(
649
+ [
650
+ input_image_size[1] / output_image_size[1],
651
+ input_image_size[0] / output_image_size[0],
652
+ input_image_size[1] / output_image_size[1],
653
+ input_image_size[0] / output_image_size[0],
654
+ ]
655
+ )
656
+ new_annotation["boxes"] = boxes
657
+ elif key == "size":
658
+ new_annotation["size"] = output_image_size
659
+ else:
660
+ new_annotation[key] = value
661
+ return new_annotation
662
+
663
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
664
+ def _pad_image(
665
+ self,
666
+ image: np.ndarray,
667
+ output_size: tuple[int, int],
668
+ annotation: Optional[dict[str, Any]] = None,
669
+ constant_values: Union[float, Iterable[float]] = 0,
670
+ data_format: Optional[ChannelDimension] = None,
671
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
672
+ update_bboxes: bool = True,
673
+ ) -> np.ndarray:
674
+ """
675
+ Pad an image with zeros to the given size.
676
+ """
677
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
678
+ output_height, output_width = output_size
679
+
680
+ pad_bottom = output_height - input_height
681
+ pad_right = output_width - input_width
682
+ padding = ((0, pad_bottom), (0, pad_right))
683
+ padded_image = pad(
684
+ image,
685
+ padding,
686
+ mode=PaddingMode.CONSTANT,
687
+ constant_values=constant_values,
688
+ data_format=data_format,
689
+ input_data_format=input_data_format,
690
+ )
691
+ if annotation is not None:
692
+ annotation = self._update_annotation_for_padded_image(
693
+ annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
694
+ )
695
+ return padded_image, annotation
696
+
697
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad
698
+ def pad(
699
+ self,
700
+ images: list[np.ndarray],
701
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
702
+ constant_values: Union[float, Iterable[float]] = 0,
703
+ return_pixel_mask: bool = True,
704
+ return_tensors: Optional[Union[str, TensorType]] = None,
705
+ data_format: Optional[ChannelDimension] = None,
706
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
707
+ update_bboxes: bool = True,
708
+ pad_size: Optional[dict[str, int]] = None,
709
+ ) -> BatchFeature:
710
+ """
711
+ Pads a batch of images to the bottom and right with zeros, up to the size of the largest height and width
712
+ in the batch and optionally returns their corresponding pixel mask.
713
+
714
+ Args:
715
+ images (list[`np.ndarray`]):
716
+ Images to pad.
717
+ annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
718
+ Annotations to transform according to the padding that is applied to the images.
719
+ constant_values (`float` or `Iterable[float]`, *optional*):
720
+ The value to use for the padding if `mode` is `"constant"`.
721
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
722
+ Whether to return a pixel mask.
723
+ return_tensors (`str` or `TensorType`, *optional*):
724
+ The type of tensors to return. Can be one of:
725
+ - Unset: Return a list of `np.ndarray`.
726
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
727
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
728
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
729
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
730
+ data_format (`str` or `ChannelDimension`, *optional*):
731
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
732
+ input_data_format (`ChannelDimension` or `str`, *optional*):
733
+ The channel dimension format of the input image. If not provided, it will be inferred.
734
+ update_bboxes (`bool`, *optional*, defaults to `True`):
735
+ Whether to update the bounding boxes in the annotations to match the padded images. If the
736
+ bounding boxes have not been converted to relative coordinates and `(center_x, center_y, width, height)`
737
+ format, the bounding boxes will not be updated.
738
+ pad_size (`dict[str, int]`, *optional*):
739
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
740
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
741
+ height and width in the batch.
742
+ """
743
+ pad_size = pad_size if pad_size is not None else self.pad_size
744
+ if pad_size is not None:
745
+ padded_size = (pad_size["height"], pad_size["width"])
746
+ else:
747
+ padded_size = get_max_height_width(images, input_data_format=input_data_format)
748
+
749
+ annotation_list = annotations if annotations is not None else [None] * len(images)
750
+ padded_images = []
751
+ padded_annotations = []
752
+ for image, annotation in zip(images, annotation_list):
753
+ padded_image, padded_annotation = self._pad_image(
754
+ image,
755
+ padded_size,
756
+ annotation,
757
+ constant_values=constant_values,
758
+ data_format=data_format,
759
+ input_data_format=input_data_format,
760
+ update_bboxes=update_bboxes,
761
+ )
762
+ padded_images.append(padded_image)
763
+ padded_annotations.append(padded_annotation)
764
+
765
+ data = {"pixel_values": padded_images}
766
+
767
+ if return_pixel_mask:
768
+ masks = [
769
+ make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
770
+ for image in images
771
+ ]
772
+ data["pixel_mask"] = masks
773
+
774
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
775
+
776
+ if annotations is not None:
777
+ encoded_inputs["labels"] = [
778
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations
779
+ ]
780
+
781
+ return encoded_inputs
782
+
783
+ @filter_out_non_signature_kwargs()
784
+ def preprocess(
785
+ self,
786
+ images: ImageInput,
787
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
788
+ return_segmentation_masks: Optional[bool] = None,
789
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
790
+ do_resize: Optional[bool] = None,
791
+ size: Optional[dict[str, int]] = None,
792
+ resample=None, # PILImageResampling
793
+ do_rescale: Optional[bool] = None,
794
+ rescale_factor: Optional[Union[int, float]] = None,
795
+ do_normalize: Optional[bool] = None,
796
+ do_convert_annotations: Optional[bool] = None,
797
+ image_mean: Optional[Union[float, list[float]]] = None,
798
+ image_std: Optional[Union[float, list[float]]] = None,
799
+ do_pad: Optional[bool] = None,
800
+ format: Optional[Union[str, AnnotationFormat]] = None,
801
+ return_tensors: Optional[Union[TensorType, str]] = None,
802
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
803
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
804
+ pad_size: Optional[dict[str, int]] = None,
805
+ ) -> BatchFeature:
806
+ """
807
+ Preprocess an image or a batch of images so that it can be used by the model.
808
+
809
+ Args:
810
+ images (`ImageInput`):
811
+ Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
812
+ from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
813
+ annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
814
+ List of annotations associated with the image or batch of images. If annotation is for object
815
+ detection, the annotations should be a dictionary with the following keys:
816
+ - "image_id" (`int`): The image id.
817
+ - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a
818
+ dictionary. An image can have no annotations, in which case the list should be empty.
819
+ If annotation is for segmentation, the annotations should be a dictionary with the following keys:
820
+ - "image_id" (`int`): The image id.
821
+ - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary.
822
+ An image can have no segments, in which case the list should be empty.
823
+ - "file_name" (`str`): The file name of the image.
824
+ return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
825
+ Whether to return segmentation masks.
826
+ masks_path (`str` or `pathlib.Path`, *optional*):
827
+ Path to the directory containing the segmentation masks.
828
+ do_resize (`bool`, *optional*, defaults to self.do_resize):
829
+ Whether to resize the image.
830
+ size (`dict[str, int]`, *optional*, defaults to self.size):
831
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
832
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
833
+ Do NOT keep the aspect ratio.
834
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
835
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
836
+ less or equal to `longest_edge`.
837
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
838
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
839
+ `max_width`.
840
+ resample (`PILImageResampling`, *optional*, defaults to self.resample):
841
+ Resampling filter to use when resizing the image.
842
+ do_rescale (`bool`, *optional*, defaults to self.do_rescale):
843
+ Whether to rescale the image.
844
+ rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
845
+ Rescale factor to use when rescaling the image.
846
+ do_normalize (`bool`, *optional*, defaults to self.do_normalize):
847
+ Whether to normalize the image.
848
+ do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
849
+ Whether to convert the annotations to the format expected by the model. Converts the bounding
850
+ boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
851
+ and in relative coordinates.
852
+ image_mean (`float` or `list[float]`, *optional*, defaults to self.image_mean):
853
+ Mean to use when normalizing the image.
854
+ image_std (`float` or `list[float]`, *optional*, defaults to self.image_std):
855
+ Standard deviation to use when normalizing the image.
856
+ do_pad (`bool`, *optional*, defaults to self.do_pad):
857
+ Whether to pad the image. If `True`, padding will be applied to the bottom and right of
858
+ the image with zeros. If `pad_size` is provided, the image will be padded to the specified
859
+ dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
860
+ format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
861
+ Format of the annotations.
862
+ return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
863
+ Type of tensors to return. If `None`, will return the list of images.
864
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
865
+ The channel dimension format for the output image. Can be one of:
866
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
867
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
868
+ - Unset: Use the channel dimension format of the input image.
869
+ input_data_format (`ChannelDimension` or `str`, *optional*):
870
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
871
+ from the input image. Can be one of:
872
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
873
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
874
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
875
+ pad_size (`dict[str, int]`, *optional*):
876
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
877
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
878
+ height and width in the batch.
879
+ """
880
+ do_resize = self.do_resize if do_resize is None else do_resize
881
+ size = self.size if size is None else size
882
+ size = get_size_dict(size=size, default_to_square=True)
883
+ resample = self.resample if resample is None else resample
884
+ do_rescale = self.do_rescale if do_rescale is None else do_rescale
885
+ rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
886
+ do_normalize = self.do_normalize if do_normalize is None else do_normalize
887
+ image_mean = self.image_mean if image_mean is None else image_mean
888
+ image_std = self.image_std if image_std is None else image_std
889
+ do_convert_annotations = (
890
+ self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
891
+ )
892
+ do_pad = self.do_pad if do_pad is None else do_pad
893
+ pad_size = self.pad_size if pad_size is None else pad_size
894
+ format = self.format if format is None else format
895
+
896
+ images = make_list_of_images(images)
897
+
898
+ if not valid_images(images):
899
+ raise ValueError(
900
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
901
+ "torch.Tensor, tf.Tensor or jax.ndarray."
902
+ )
903
+
904
+ # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated.
905
+
906
+ validate_preprocess_arguments(
907
+ do_rescale=do_rescale,
908
+ rescale_factor=rescale_factor,
909
+ do_normalize=do_normalize,
910
+ image_mean=image_mean,
911
+ image_std=image_std,
912
+ do_resize=do_resize,
913
+ size=size,
914
+ resample=resample,
915
+ )
916
+
917
+ if annotations is not None and isinstance(annotations, dict):
918
+ annotations = [annotations]
919
+
920
+ if annotations is not None and len(images) != len(annotations):
921
+ raise ValueError(
922
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
923
+ )
924
+
925
+ format = AnnotationFormat(format)
926
+ if annotations is not None:
927
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
928
+
929
+ images = make_list_of_images(images)
930
+ if not valid_images(images):
931
+ raise ValueError(
932
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
933
+ "torch.Tensor, tf.Tensor or jax.ndarray."
934
+ )
935
+
936
+ # All transformations expect numpy arrays
937
+ images = [to_numpy_array(image) for image in images]
938
+
939
+ if do_rescale and is_scaled_image(images[0]):
940
+ logger.warning_once(
941
+ "It looks like you are trying to rescale already rescaled images. If the input"
942
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
943
+ )
944
+
945
+ if input_data_format is None:
946
+ # We assume that all images have the same channel dimension format.
947
+ input_data_format = infer_channel_dimension_format(images[0])
948
+
949
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
950
+ if annotations is not None:
951
+ prepared_images = []
952
+ prepared_annotations = []
953
+ for image, target in zip(images, annotations):
954
+ target = self.prepare_annotation(
955
+ image,
956
+ target,
957
+ format,
958
+ return_segmentation_masks=return_segmentation_masks,
959
+ masks_path=masks_path,
960
+ input_data_format=input_data_format,
961
+ )
962
+ prepared_images.append(image)
963
+ prepared_annotations.append(target)
964
+ images = prepared_images
965
+ annotations = prepared_annotations
966
+ del prepared_images, prepared_annotations
967
+
968
+ # transformations
969
+ if do_resize:
970
+ if annotations is not None:
971
+ resized_images, resized_annotations = [], []
972
+ for image, target in zip(images, annotations):
973
+ orig_size = get_image_size(image, input_data_format)
974
+ resized_image = self.resize(
975
+ image, size=size, resample=resample, input_data_format=input_data_format
976
+ )
977
+ resized_annotation = self.resize_annotation(
978
+ target, orig_size, get_image_size(resized_image, input_data_format)
979
+ )
980
+ resized_images.append(resized_image)
981
+ resized_annotations.append(resized_annotation)
982
+ images = resized_images
983
+ annotations = resized_annotations
984
+ del resized_images, resized_annotations
985
+ else:
986
+ images = [
987
+ self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
988
+ for image in images
989
+ ]
990
+
991
+ if do_rescale:
992
+ images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
993
+
994
+ if do_normalize:
995
+ images = [
996
+ self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
997
+ ]
998
+
999
+ if do_convert_annotations and annotations is not None:
1000
+ annotations = [
1001
+ self.normalize_annotation(annotation, get_image_size(image, input_data_format))
1002
+ for annotation, image in zip(annotations, images)
1003
+ ]
1004
+
1005
+ if do_pad:
1006
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
1007
+ encoded_inputs = self.pad(
1008
+ images,
1009
+ annotations=annotations,
1010
+ return_pixel_mask=True,
1011
+ data_format=data_format,
1012
+ input_data_format=input_data_format,
1013
+ update_bboxes=do_convert_annotations,
1014
+ return_tensors=return_tensors,
1015
+ pad_size=pad_size,
1016
+ )
1017
+ else:
1018
+ images = [
1019
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
1020
+ for image in images
1021
+ ]
1022
+ encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
1023
+ if annotations is not None:
1024
+ encoded_inputs["labels"] = [
1025
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
1026
+ ]
1027
+
1028
+ return encoded_inputs
1029
+
1030
+ def post_process_object_detection(
1031
+ self,
1032
+ outputs,
1033
+ threshold: float = 0.5,
1034
+ target_sizes: Union[TensorType, list[tuple]] = None,
1035
+ use_focal_loss: bool = True,
1036
+ ):
1037
+ """
1038
+ Converts the raw output of [`RTDetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
1039
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
1040
+
1041
+ Args:
1042
+ outputs ([`RTDetrObjectDetectionOutput`]):
1043
+ Raw outputs of the model.
1044
+ threshold (`float`, *optional*, defaults to 0.5):
1045
+ Score threshold to keep object detection predictions.
1046
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
1047
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
1048
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
1049
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
1050
+ Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
1051
+ to compute the scores of each detection, otherwise, a softmax function is used.
1052
+
1053
+ Returns:
1054
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
1055
+ in the batch as predicted by the model.
1056
+ """
1057
+ requires_backends(self, ["torch"])
1058
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
1059
+ # convert from relative cxcywh to absolute xyxy
1060
+ boxes = center_to_corners_format(out_bbox)
1061
+ if target_sizes is not None:
1062
+ if len(out_logits) != len(target_sizes):
1063
+ raise ValueError(
1064
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
1065
+ )
1066
+ if isinstance(target_sizes, list):
1067
+ img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
1068
+ else:
1069
+ img_h, img_w = target_sizes.unbind(1)
1070
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
1071
+ boxes = boxes * scale_fct[:, None, :]
1072
+
1073
+ num_top_queries = out_logits.shape[1]
1074
+ num_classes = out_logits.shape[2]
1075
+
1076
+ if use_focal_loss:
1077
+ scores = torch.nn.functional.sigmoid(out_logits)
1078
+ scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
1079
+ labels = index % num_classes
1080
+ index = index // num_classes
1081
+ boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
1082
+ else:
1083
+ scores = torch.nn.functional.softmax(out_logits)[:, :, :-1]
1084
+ scores, labels = scores.max(dim=-1)
1085
+ if scores.shape[1] > num_top_queries:
1086
+ scores, index = torch.topk(scores, num_top_queries, dim=-1)
1087
+ labels = torch.gather(labels, dim=1, index=index)
1088
+ boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
1089
+
1090
+ results = []
1091
+ for score, label, box in zip(scores, labels, boxes):
1092
+ results.append(
1093
+ {
1094
+ "scores": score[score > threshold],
1095
+ "labels": label[score > threshold],
1096
+ "boxes": box[score > threshold],
1097
+ }
1098
+ )
1099
+
1100
+ return results
1101
+
1102
+
1103
+ __all__ = ["RTDetrImageProcessor"]
phivenv/Lib/site-packages/transformers/models/rt_detr/image_processing_rt_detr_fast.py ADDED
@@ -0,0 +1,590 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/rt_detr/modular_rt_detr.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_rt_detr.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ import pathlib
8
+ from typing import Any, Optional, Union
9
+
10
+ from ...image_processing_utils import BatchFeature
11
+ from ...image_processing_utils_fast import (
12
+ BaseImageProcessorFast,
13
+ DefaultFastImageProcessorKwargs,
14
+ SizeDict,
15
+ get_image_size_for_max_height_width,
16
+ get_max_height_width,
17
+ safe_squeeze,
18
+ )
19
+ from ...image_transforms import center_to_corners_format, corners_to_center_format
20
+ from ...image_utils import (
21
+ IMAGENET_DEFAULT_MEAN,
22
+ IMAGENET_DEFAULT_STD,
23
+ AnnotationFormat,
24
+ AnnotationType,
25
+ ChannelDimension,
26
+ ImageInput,
27
+ PILImageResampling,
28
+ get_image_size,
29
+ validate_annotations,
30
+ )
31
+ from ...processing_utils import Unpack
32
+ from ...utils import (
33
+ TensorType,
34
+ auto_docstring,
35
+ is_torch_available,
36
+ is_torchvision_available,
37
+ is_torchvision_v2_available,
38
+ requires_backends,
39
+ )
40
+ from ...utils.import_utils import requires
41
+ from .image_processing_rt_detr import get_size_with_aspect_ratio
42
+
43
+
44
+ if is_torch_available():
45
+ import torch
46
+
47
+
48
+ if is_torchvision_v2_available():
49
+ from torchvision.transforms.v2 import functional as F
50
+ elif is_torchvision_available():
51
+ from torchvision.transforms import functional as F
52
+
53
+
54
+ class RTDetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
55
+ r"""
56
+ format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
57
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
58
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
59
+ Controls whether to convert the annotations to the format expected by the RT_DETR model. Converts the
60
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
61
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
62
+ do_pad (`bool`, *optional*, defaults to `False`):
63
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
64
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
65
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
66
+ Otherwise, the image will be padded to the maximum height and width of the batch.
67
+ pad_size (`dict[str, int]`, *optional*):
68
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
69
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
70
+ height and width in the batch.
71
+ return_segmentation_masks (`bool`, *optional*, defaults to `False`):
72
+ Whether to return segmentation masks.
73
+ """
74
+
75
+ format: Optional[Union[str, AnnotationFormat]]
76
+ do_convert_annotations: Optional[bool]
77
+ do_pad: Optional[bool]
78
+ pad_size: Optional[dict[str, int]]
79
+ return_segmentation_masks: Optional[bool]
80
+
81
+
82
+ SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
83
+
84
+
85
+ def prepare_coco_detection_annotation(
86
+ image,
87
+ target,
88
+ return_segmentation_masks: bool = False,
89
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
90
+ ):
91
+ """
92
+ Convert the target in COCO format into the format expected by RT-DETR.
93
+ """
94
+ image_height, image_width = image.size()[-2:]
95
+
96
+ image_id = target["image_id"]
97
+ image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)
98
+
99
+ # Get all COCO annotations for the given image.
100
+ annotations = target["annotations"]
101
+ classes = []
102
+ area = []
103
+ boxes = []
104
+ keypoints = []
105
+ for obj in annotations:
106
+ if "iscrowd" not in obj or obj["iscrowd"] == 0:
107
+ classes.append(obj["category_id"])
108
+ area.append(obj["area"])
109
+ boxes.append(obj["bbox"])
110
+ if "keypoints" in obj:
111
+ keypoints.append(obj["keypoints"])
112
+
113
+ classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
114
+ area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
115
+ iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
116
+ # guard against no boxes via resizing
117
+ boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
118
+ boxes[:, 2:] += boxes[:, :2]
119
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
120
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
121
+
122
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
123
+
124
+ new_target = {
125
+ "image_id": image_id,
126
+ "class_labels": classes[keep],
127
+ "boxes": boxes[keep],
128
+ "area": area[keep],
129
+ "iscrowd": iscrowd[keep],
130
+ "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
131
+ }
132
+
133
+ if keypoints:
134
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
135
+ # Apply the keep mask here to filter the relevant annotations
136
+ keypoints = keypoints[keep]
137
+ num_keypoints = keypoints.shape[0]
138
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
139
+ new_target["keypoints"] = keypoints
140
+
141
+ return new_target
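# --- Editor's note: illustrative example, not part of the generated file. ---
# `prepare_coco_detection_annotation` above expects a raw COCO-style target like the made-up one below and
# turns the xywh `bbox` entries into clipped xyxy tensors plus "class_labels", "area", "iscrowd" and
# "orig_size", dropping crowd annotations and degenerate boxes along the way.
example_target = {
    "image_id": 42,
    "annotations": [
        {"category_id": 1, "bbox": [10.0, 20.0, 30.0, 40.0], "area": 1200.0, "iscrowd": 0},
        {"category_id": 3, "bbox": [50.0, 60.0, 20.0, 10.0], "area": 200.0, "iscrowd": 0},
    ],
}
# prepare_coco_detection_annotation(image_tensor, example_target) returns a dict of torch tensors keyed by
# "image_id", "class_labels", "boxes" (xyxy, clipped to the image), "area", "iscrowd" and "orig_size".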
142
+
143
+
144
+ @auto_docstring
145
+ @requires(backends=("torchvision", "torch"))
146
+ class RTDetrImageProcessorFast(BaseImageProcessorFast):
147
+ resample = PILImageResampling.BILINEAR
148
+ image_mean = IMAGENET_DEFAULT_MEAN
149
+ image_std = IMAGENET_DEFAULT_STD
150
+ format = AnnotationFormat.COCO_DETECTION
151
+ do_resize = True
152
+ do_rescale = True
153
+ do_normalize = False
154
+ do_pad = False
155
+ size = {"height": 640, "width": 640}
156
+ default_to_square = False
157
+ model_input_names = ["pixel_values", "pixel_mask"]
158
+ valid_kwargs = RTDetrFastImageProcessorKwargs
159
+ do_convert_annotations = True
160
+
161
+ def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorKwargs]) -> None:
162
+ # Backwards compatibility
163
+ do_convert_annotations = kwargs.get("do_convert_annotations")
164
+ do_normalize = kwargs.get("do_normalize")
165
+ if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
166
+ self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
167
+
168
+ super().__init__(**kwargs)
169
+
170
+ def prepare_annotation(
171
+ self,
172
+ image: torch.Tensor,
173
+ target: dict,
174
+ format: Optional[AnnotationFormat] = None,
175
+ return_segmentation_masks: Optional[bool] = None,
176
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
177
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
178
+ ) -> dict:
179
+ """
180
+ Prepare an annotation for feeding into the RT-DETR model.
181
+ """
182
+ format = format if format is not None else self.format
183
+
184
+ if format == AnnotationFormat.COCO_DETECTION:
185
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
186
+ target = prepare_coco_detection_annotation(
187
+ image, target, return_segmentation_masks, input_data_format=input_data_format
188
+ )
189
+ else:
190
+ raise ValueError(f"Format {format} is not supported.")
191
+ return target
192
+
193
+ def resize(
194
+ self,
195
+ image: torch.Tensor,
196
+ size: SizeDict,
197
+ interpolation: "F.InterpolationMode" = None,
198
+ **kwargs,
199
+ ) -> torch.Tensor:
200
+ """
201
+ Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
202
+ int, smaller edge of the image will be matched to this number.
203
+
204
+ Args:
205
+ image (`torch.Tensor`):
206
+ Image to resize.
207
+ size (`SizeDict`):
208
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
209
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
210
+ Do NOT keep the aspect ratio.
211
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
212
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
213
+ less or equal to `longest_edge`.
214
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
215
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
216
+ `max_width`.
217
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
218
+ Resampling filter to use if resizing the image.
219
+ """
220
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
221
+ if size.shortest_edge and size.longest_edge:
222
+ # Resize the image so that the shortest edge or the longest edge is of the given size
223
+ # while maintaining the aspect ratio of the original image.
224
+ new_size = get_size_with_aspect_ratio(
225
+ image.size()[-2:],
226
+ size["shortest_edge"],
227
+ size["longest_edge"],
228
+ )
229
+ elif size.max_height and size.max_width:
230
+ new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"])
231
+ elif size.height and size.width:
232
+ new_size = (size["height"], size["width"])
233
+ else:
234
+ raise ValueError(
235
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
236
+ f" {size.keys()}."
237
+ )
238
+
239
+ image = F.resize(
240
+ image,
241
+ size=new_size,
242
+ interpolation=interpolation,
243
+ **kwargs,
244
+ )
245
+ return image
246
+
247
+ def resize_annotation(
248
+ self,
249
+ annotation: dict[str, Any],
250
+ orig_size: tuple[int, int],
251
+ target_size: tuple[int, int],
252
+ threshold: float = 0.5,
253
+ interpolation: "F.InterpolationMode" = None,
254
+ ):
255
+ """
256
+ Resizes an annotation to a target size.
257
+
258
+ Args:
259
+ annotation (`dict[str, Any]`):
260
+ The annotation dictionary.
261
+ orig_size (`tuple[int, int]`):
262
+ The original size of the input image.
263
+ target_size (`tuple[int, int]`):
264
+ The target size of the image, as returned by the preprocessing `resize` step.
265
+ threshold (`float`, *optional*, defaults to 0.5):
266
+ The threshold used to binarize the segmentation masks.
267
+ interpolation (`InterpolationMode`, *optional*, defaults to `F.InterpolationMode.NEAREST_EXACT`):
268
+ The resampling filter to use when resizing the masks.
269
+ """
270
+ interpolation = (
271
+ interpolation
272
+ if interpolation is not None
273
+ else F.InterpolationMode.NEAREST_EXACT
274
+ if is_torchvision_v2_available()
275
+ else F.InterpolationMode.NEAREST
276
+ )
277
+ ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)]
278
+
279
+ new_annotation = {}
280
+ new_annotation["size"] = target_size
281
+
282
+ for key, value in annotation.items():
283
+ if key == "boxes":
284
+ boxes = value
285
+ scaled_boxes = boxes * torch.as_tensor(
286
+ [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device
287
+ )
288
+ new_annotation["boxes"] = scaled_boxes
289
+ elif key == "area":
290
+ area = value
291
+ scaled_area = area * (ratio_width * ratio_height)
292
+ new_annotation["area"] = scaled_area
293
+ elif key == "masks":
294
+ masks = value[:, None]
295
+ masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks]
296
+ masks = torch.stack(masks).to(torch.float32)
297
+ masks = masks[:, 0] > threshold
298
+ new_annotation["masks"] = masks
299
+ elif key == "size":
300
+ new_annotation["size"] = target_size
301
+ else:
302
+ new_annotation[key] = value
303
+
304
+ return new_annotation
305
+
306
+ def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
307
+ image_height, image_width = image_size
308
+ norm_annotation = {}
309
+ for key, value in annotation.items():
310
+ if key == "boxes":
311
+ boxes = value
312
+ boxes = corners_to_center_format(boxes)
313
+ boxes /= torch.as_tensor(
314
+ [image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device
315
+ )
316
+ norm_annotation[key] = boxes
317
+ else:
318
+ norm_annotation[key] = value
319
+ return norm_annotation
320
+
321
+ def _update_annotation_for_padded_image(
322
+ self,
323
+ annotation: dict,
324
+ input_image_size: tuple[int, int],
325
+ output_image_size: tuple[int, int],
326
+ padding,
327
+ update_bboxes,
328
+ ) -> dict:
329
+ """
330
+ Update the annotation for a padded image.
331
+ """
332
+ new_annotation = {}
333
+ new_annotation["size"] = output_image_size
334
+ ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size))
335
+
336
+ for key, value in annotation.items():
337
+ if key == "masks":
338
+ masks = value
339
+ masks = F.pad(
340
+ masks,
341
+ padding,
342
+ fill=0,
343
+ )
344
+ masks = safe_squeeze(masks, 1)
345
+ new_annotation["masks"] = masks
346
+ elif key == "boxes" and update_bboxes:
347
+ boxes = value
348
+ boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device)
349
+ new_annotation["boxes"] = boxes
350
+ elif key == "size":
351
+ new_annotation["size"] = output_image_size
352
+ else:
353
+ new_annotation[key] = value
354
+ return new_annotation
355
+
356
+ def pad(
357
+ self,
358
+ image: torch.Tensor,
359
+ padded_size: tuple[int, int],
360
+ annotation: Optional[dict[str, Any]] = None,
361
+ update_bboxes: bool = True,
362
+ fill: int = 0,
363
+ ):
364
+ original_size = image.size()[-2:]
365
+ padding_bottom = padded_size[0] - original_size[0]
366
+ padding_right = padded_size[1] - original_size[1]
367
+ if padding_bottom < 0 or padding_right < 0:
368
+ raise ValueError(
369
+ f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
370
+ f"original size. Got padded size: {padded_size}, original size: {original_size}."
371
+ )
372
+ if original_size != padded_size:
373
+ padding = [0, 0, padding_right, padding_bottom]
374
+ image = F.pad(image, padding, fill=fill)
375
+ if annotation is not None:
376
+ annotation = self._update_annotation_for_padded_image(
377
+ annotation, original_size, padded_size, padding, update_bboxes
378
+ )
379
+
380
+ # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
381
+ pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device)
382
+ pixel_mask[: original_size[0], : original_size[1]] = 1
383
+
384
+ return image, pixel_mask, annotation
385
+
386
+ @auto_docstring
387
+ def preprocess(
388
+ self,
389
+ images: ImageInput,
390
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
391
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
392
+ **kwargs: Unpack[RTDetrFastImageProcessorKwargs],
393
+ ) -> BatchFeature:
394
+ r"""
395
+ annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
396
+ List of annotations associated with the image or batch of images. If annotation is for object
397
+ detection, the annotations should be a dictionary with the following keys:
398
+ - "image_id" (`int`): The image id.
399
+ - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a
400
+ dictionary. An image can have no annotations, in which case the list should be empty.
401
+ If annotation is for segmentation, the annotations should be a dictionary with the following keys:
402
+ - "image_id" (`int`): The image id.
403
+ - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary.
404
+ An image can have no segments, in which case the list should be empty.
405
+ - "file_name" (`str`): The file name of the image.
406
+ masks_path (`str` or `pathlib.Path`, *optional*):
407
+ Path to the directory containing the segmentation masks.
408
+ """
409
+ return super().preprocess(images, annotations, masks_path, **kwargs)
410
+
411
+ def _preprocess(
412
+ self,
413
+ images: list["torch.Tensor"],
414
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]],
415
+ masks_path: Optional[Union[str, pathlib.Path]],
416
+ return_segmentation_masks: bool,
417
+ do_resize: bool,
418
+ size: SizeDict,
419
+ interpolation: Optional["F.InterpolationMode"],
420
+ do_rescale: bool,
421
+ rescale_factor: float,
422
+ do_normalize: bool,
423
+ do_convert_annotations: bool,
424
+ image_mean: Optional[Union[float, list[float]]],
425
+ image_std: Optional[Union[float, list[float]]],
426
+ do_pad: bool,
427
+ pad_size: Optional[dict[str, int]],
428
+ format: Optional[Union[str, AnnotationFormat]],
429
+ return_tensors: Optional[Union[str, TensorType]],
430
+ **kwargs,
431
+ ) -> BatchFeature:
432
+ """
433
+ Preprocess an image or a batch of images so that it can be used by the model.
434
+ """
435
+
436
+ if annotations is not None and isinstance(annotations, dict):
437
+ annotations = [annotations]
438
+
439
+ if annotations is not None and len(images) != len(annotations):
440
+ raise ValueError(
441
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
442
+ )
443
+
444
+ format = AnnotationFormat(format)
445
+ if annotations is not None:
446
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
447
+
448
+ data = {}
449
+ processed_images = []
450
+ processed_annotations = []
451
+ pixel_masks = [] # Initialize pixel_masks here
452
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
453
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
454
+ if annotations is not None:
455
+ annotation = self.prepare_annotation(
456
+ image,
457
+ annotation,
458
+ format,
459
+ return_segmentation_masks=return_segmentation_masks,
460
+ masks_path=masks_path,
461
+ input_data_format=ChannelDimension.FIRST,
462
+ )
463
+
464
+ if do_resize:
465
+ resized_image = self.resize(image, size=size, interpolation=interpolation)
466
+ if annotations is not None:
467
+ annotation = self.resize_annotation(
468
+ annotation,
469
+ orig_size=image.size()[-2:],
470
+ target_size=resized_image.size()[-2:],
471
+ )
472
+ image = resized_image
473
+ # Fused rescale and normalize
474
+ image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
475
+ if do_convert_annotations and annotations is not None:
476
+ annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
477
+
478
+ processed_images.append(image)
479
+ processed_annotations.append(annotation)
480
+ images = processed_images
481
+ annotations = processed_annotations if annotations is not None else None
482
+
483
+ if do_pad:
484
+ # depends on all resized image shapes so we need another loop
485
+ if pad_size is not None:
486
+ padded_size = (pad_size["height"], pad_size["width"])
487
+ else:
488
+ padded_size = get_max_height_width(images)
489
+
490
+ padded_images = []
491
+ padded_annotations = []
492
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
493
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
494
+ if padded_size == image.size()[-2:]:
495
+ padded_images.append(image)
496
+ pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
497
+ padded_annotations.append(annotation)
498
+ continue
499
+ image, pixel_mask, annotation = self.pad(
500
+ image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
501
+ )
502
+ padded_images.append(image)
503
+ padded_annotations.append(annotation)
504
+ pixel_masks.append(pixel_mask)
505
+ images = padded_images
506
+ annotations = padded_annotations if annotations is not None else None
507
+ data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})
508
+
509
+ data.update({"pixel_values": torch.stack(images, dim=0)})
510
+ encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
511
+ if annotations is not None:
512
+ encoded_inputs["labels"] = [
513
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
514
+ ]
515
+ return encoded_inputs
516
+
517
+ def post_process_object_detection(
518
+ self,
519
+ outputs,
520
+ threshold: float = 0.5,
521
+ target_sizes: Union[TensorType, list[tuple]] = None,
522
+ use_focal_loss: bool = True,
523
+ ):
524
+ """
525
+ Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
526
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
527
+
528
+ Args:
529
+ outputs ([`DetrObjectDetectionOutput`]):
530
+ Raw outputs of the model.
531
+ threshold (`float`, *optional*, defaults to 0.5):
532
+ Score threshold to keep object detection predictions.
533
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
534
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
535
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
536
+ use_focal_loss (`bool` defaults to `True`):
537
+ Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
538
+ to compute the scores of each detection, otherwise, a softmax function is used.
539
+
540
+ Returns:
541
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
542
+ in the batch as predicted by the model.
543
+ """
544
+ requires_backends(self, ["torch"])
545
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
546
+ # convert from relative cxcywh to absolute xyxy
547
+ boxes = center_to_corners_format(out_bbox)
548
+ if target_sizes is not None:
549
+ if len(out_logits) != len(target_sizes):
550
+ raise ValueError(
551
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
552
+ )
553
+ if isinstance(target_sizes, list):
554
+ img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
555
+ else:
556
+ img_h, img_w = target_sizes.unbind(1)
557
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
558
+ boxes = boxes * scale_fct[:, None, :]
559
+
560
+ num_top_queries = out_logits.shape[1]
561
+ num_classes = out_logits.shape[2]
562
+
563
+ if use_focal_loss:
564
+ scores = torch.nn.functional.sigmoid(out_logits)
565
+ scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
566
+ labels = index % num_classes
567
+ index = index // num_classes
568
+ boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
569
+ else:
570
+ scores = torch.nn.functional.softmax(out_logits)[:, :, :-1]
571
+ scores, labels = scores.max(dim=-1)
572
+ if scores.shape[1] > num_top_queries:
573
+ scores, index = torch.topk(scores, num_top_queries, dim=-1)
574
+ labels = torch.gather(labels, dim=1, index=index)
575
+ boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
576
+
577
+ results = []
578
+ for score, label, box in zip(scores, labels, boxes):
579
+ results.append(
580
+ {
581
+ "scores": score[score > threshold],
582
+ "labels": label[score > threshold],
583
+ "boxes": box[score > threshold],
584
+ }
585
+ )
586
+
587
+ return results
588
+
589
+
590
+ __all__ = ["RTDetrImageProcessorFast"]
phivenv/Lib/site-packages/transformers/models/rt_detr/modeling_rt_detr.py ADDED
@@ -0,0 +1,2013 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Baidu Inc and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch RT-DETR model."""
16
+
17
+ import math
18
+ import warnings
19
+ from dataclasses import dataclass
20
+ from functools import partial
21
+ from typing import Optional, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from torch import Tensor, nn
26
+
27
+ from ...activations import ACT2CLS, ACT2FN
28
+ from ...image_transforms import center_to_corners_format, corners_to_center_format
29
+ from ...integrations import use_kernel_forward_from_hub
30
+ from ...modeling_outputs import BaseModelOutput
31
+ from ...modeling_utils import PreTrainedModel
32
+ from ...pytorch_utils import compile_compatible_method_lru_cache
33
+ from ...utils import (
34
+ ModelOutput,
35
+ auto_docstring,
36
+ logging,
37
+ torch_int,
38
+ )
39
+ from ...utils.backbone_utils import load_backbone
40
+ from .configuration_rt_detr import RTDetrConfig
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+ # TODO: Replace all occurrences of the checkpoint with the final one
47
+
48
+
49
+ @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
50
+ # Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention
51
+ class MultiScaleDeformableAttention(nn.Module):
52
+ def forward(
53
+ self,
54
+ value: Tensor,
55
+ value_spatial_shapes: Tensor,
56
+ value_spatial_shapes_list: list[tuple],
57
+ level_start_index: Tensor,
58
+ sampling_locations: Tensor,
59
+ attention_weights: Tensor,
60
+ im2col_step: int,
61
+ ):
62
+ batch_size, _, num_heads, hidden_dim = value.shape
63
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
64
+ value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
65
+ sampling_grids = 2 * sampling_locations - 1
66
+ sampling_value_list = []
67
+ for level_id, (height, width) in enumerate(value_spatial_shapes_list):
68
+ # batch_size, height*width, num_heads, hidden_dim
69
+ # -> batch_size, height*width, num_heads*hidden_dim
70
+ # -> batch_size, num_heads*hidden_dim, height*width
71
+ # -> batch_size*num_heads, hidden_dim, height, width
72
+ value_l_ = (
73
+ value_list[level_id]
74
+ .flatten(2)
75
+ .transpose(1, 2)
76
+ .reshape(batch_size * num_heads, hidden_dim, height, width)
77
+ )
78
+ # batch_size, num_queries, num_heads, num_points, 2
79
+ # -> batch_size, num_heads, num_queries, num_points, 2
80
+ # -> batch_size*num_heads, num_queries, num_points, 2
81
+ sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
82
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
83
+ sampling_value_l_ = nn.functional.grid_sample(
84
+ value_l_,
85
+ sampling_grid_l_,
86
+ mode="bilinear",
87
+ padding_mode="zeros",
88
+ align_corners=False,
89
+ )
90
+ sampling_value_list.append(sampling_value_l_)
91
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
92
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
93
+ # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
94
+ attention_weights = attention_weights.transpose(1, 2).reshape(
95
+ batch_size * num_heads, 1, num_queries, num_levels * num_points
96
+ )
97
+ output = (
98
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
99
+ .sum(-1)
100
+ .view(batch_size, num_heads * hidden_dim, num_queries)
101
+ )
102
+ return output.transpose(1, 2).contiguous()
103
+
104
+
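The pure-PyTorch forward above hinges on `nn.functional.grid_sample` and on mapping sampling locations from `[0, 1]` to the `[-1, 1]` grid convention via `2 * sampling_locations - 1`. A minimal, self-contained sketch of that convention (toy tensors, names chosen only for illustration):

import torch
import torch.nn.functional as F

# One toy feature map of shape (batch, channels, height, width).
feature = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# A single sampling location in [0, 1] coordinates, (x, y) = (0.5, 0.5), i.e. the map center.
location_01 = torch.tensor([[[[0.5, 0.5]]]])  # (batch, out_height, out_width, 2)

# grid_sample expects coordinates in [-1, 1], hence the `2 * sampling_locations - 1` above.
grid = 2 * location_01 - 1
sampled = F.grid_sample(feature, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
print(sampled)  # tensor([[[[7.5000]]]]) -> bilinear average of the four central pixels 5, 6, 9, 10

With `align_corners=False`, the normalized center `(0.5, 0.5)` lands at pixel coordinate `(1.5, 1.5)` of the 4x4 map, which is why the result is the average of the four central values.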
105
+ @dataclass
106
+ @auto_docstring(
107
+ custom_intro="""
108
+ Base class for outputs of the RTDetrDecoder. This class adds two attributes to
109
+ BaseModelOutputWithCrossAttentions, namely:
110
+ - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
111
+ - a stacked tensor of intermediate reference points.
112
+ """
113
+ )
114
+ class RTDetrDecoderOutput(ModelOutput):
115
+ r"""
116
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
117
+ Stacked intermediate hidden states (output of each layer of the decoder).
118
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
119
+ Stacked intermediate logits (logits of each layer of the decoder).
120
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
121
+ Stacked intermediate reference points (reference points of each layer of the decoder).
122
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
123
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
124
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
125
+ Stacked initial reference points (initial reference points of each layer of the decoder).
126
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
127
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
128
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
129
+ used to compute the weighted average in the cross-attention heads.
130
+ """
131
+
132
+ last_hidden_state: Optional[torch.FloatTensor] = None
133
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
134
+ intermediate_logits: Optional[torch.FloatTensor] = None
135
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
136
+ intermediate_predicted_corners: Optional[torch.FloatTensor] = None
137
+ initial_reference_points: Optional[torch.FloatTensor] = None
138
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
139
+ attentions: Optional[tuple[torch.FloatTensor]] = None
140
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
141
+
142
+
143
+ @dataclass
144
+ @auto_docstring(
145
+ custom_intro="""
146
+ Base class for outputs of the RT-DETR encoder-decoder model.
147
+ """
148
+ )
149
+ class RTDetrModelOutput(ModelOutput):
150
+ r"""
151
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
152
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
153
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
154
+ Stacked intermediate hidden states (output of each layer of the decoder).
155
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
156
+ Stacked intermediate logits (logits of each layer of the decoder).
157
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
158
+ Stacked intermediate reference points (reference points of each layer of the decoder).
159
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
160
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
161
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
162
+ Initial reference points used for the first decoder layer.
163
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
164
+ Initial reference points sent through the Transformer decoder.
165
+ enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
166
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
167
+ picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e.
168
+ foreground and background).
169
+ enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`):
170
+ Logits of predicted bounding boxes coordinates in the encoder stage.
171
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
172
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
173
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
174
+ foreground and background).
175
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
176
+ Logits of predicted bounding boxes coordinates in the first stage.
177
+ denoising_meta_values (`dict`):
178
+ Extra dictionary for the denoising related values.
179
+ """
180
+
181
+ last_hidden_state: Optional[torch.FloatTensor] = None
182
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
183
+ intermediate_logits: Optional[torch.FloatTensor] = None
184
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
185
+ intermediate_predicted_corners: Optional[torch.FloatTensor] = None
186
+ initial_reference_points: Optional[torch.FloatTensor] = None
187
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
188
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
189
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
190
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
191
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
192
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
193
+ init_reference_points: Optional[torch.FloatTensor] = None
194
+ enc_topk_logits: Optional[torch.FloatTensor] = None
195
+ enc_topk_bboxes: Optional[torch.FloatTensor] = None
196
+ enc_outputs_class: Optional[torch.FloatTensor] = None
197
+ enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
198
+ denoising_meta_values: Optional[dict] = None
199
+
200
+
201
+ @dataclass
202
+ @auto_docstring(
203
+ custom_intro="""
204
+ Output type of [`RTDetrForObjectDetection`].
205
+ """
206
+ )
207
+ class RTDetrObjectDetectionOutput(ModelOutput):
208
+ r"""
209
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
210
+ Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
211
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
212
+ scale-invariant IoU loss.
213
+ loss_dict (`Dict`, *optional*):
214
+ A dictionary containing the individual losses. Useful for logging.
215
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
216
+ Classification logits (including no-object) for all queries.
217
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
218
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
219
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
220
+ possible padding). You can use [`~RTDetrImageProcessor.post_process_object_detection`] to retrieve the
221
+ unnormalized (absolute) bounding boxes.
222
+ auxiliary_outputs (`list[Dict]`, *optional*):
223
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
224
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
225
+ `pred_boxes`) for each decoder layer.
226
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
227
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
228
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
229
+ Stacked intermediate hidden states (output of each layer of the decoder).
230
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`):
231
+ Stacked intermediate logits (logits of each layer of the decoder).
232
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
233
+ Stacked intermediate reference points (reference points of each layer of the decoder).
234
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
235
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
236
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
237
+ Stacked initial reference points (initial reference points of each layer of the decoder).
238
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
239
+ Initial reference points sent through the Transformer decoder.
240
+ enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
241
+ Classification logits for the top `config.two_stage_num_proposals` region proposals selected in the encoder stage.
242
+ enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
243
+ Logits of predicted bounding boxes coordinates in the encoder.
244
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
245
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
246
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
247
+ foreground and background).
248
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
249
+ Logits of predicted bounding boxes coordinates in the first stage.
250
+ denoising_meta_values (`dict`):
251
+ Extra dictionary for the denoising related values.
252
+ """
253
+
254
+ loss: Optional[torch.FloatTensor] = None
255
+ loss_dict: Optional[dict] = None
256
+ logits: Optional[torch.FloatTensor] = None
257
+ pred_boxes: Optional[torch.FloatTensor] = None
258
+ auxiliary_outputs: Optional[list[dict]] = None
259
+ last_hidden_state: Optional[torch.FloatTensor] = None
260
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
261
+ intermediate_logits: Optional[torch.FloatTensor] = None
262
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
263
+ intermediate_predicted_corners: Optional[torch.FloatTensor] = None
264
+ initial_reference_points: Optional[torch.FloatTensor] = None
265
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
266
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
267
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
268
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
269
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
270
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
271
+ init_reference_points: Optional[tuple[torch.FloatTensor]] = None
272
+ enc_topk_logits: Optional[torch.FloatTensor] = None
273
+ enc_topk_bboxes: Optional[torch.FloatTensor] = None
274
+ enc_outputs_class: Optional[torch.FloatTensor] = None
275
+ enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
276
+ denoising_meta_values: Optional[dict] = None
277
+
278
+
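For orientation, a short usage sketch that produces the output documented above and converts the normalized `(center_x, center_y, width, height)` boxes back to absolute corner coordinates; the `PekingU/rtdetr_r50vd` checkpoint, the COCO sample image, and the `0.5` threshold are illustrative choices, assuming they are available in your environment:

import requests
import torch
from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)  # RTDetrObjectDetectionOutput

# target_sizes is (height, width) per image; boxes come back as absolute (x0, y0, x1, y1).
results = image_processor.post_process_object_detection(
    outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.5
)
for score, label, box in zip(results[0]["scores"], results[0]["labels"], results[0]["boxes"]):
    print(model.config.id2label[label.item()], round(score.item(), 3), [round(c, 1) for c in box.tolist()])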
279
+ def _get_clones(partial_module, N):
280
+ return nn.ModuleList([partial_module() for i in range(N)])
281
+
282
+
283
+ # Copied from transformers.models.conditional_detr.modeling_conditional_detr.inverse_sigmoid
284
+ def inverse_sigmoid(x, eps=1e-5):
285
+ x = x.clamp(min=0, max=1)
286
+ x1 = x.clamp(min=eps)
287
+ x2 = (1 - x).clamp(min=eps)
288
+ return torch.log(x1 / x2)
289
+
290
+
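`inverse_sigmoid` is a clamped logit, i.e. the inverse of `torch.sigmoid` away from the clamping boundaries. A quick numeric check, mirroring the arithmetic above in a standalone copy:

import torch

def inverse_sigmoid(x, eps=1e-5):
    # same arithmetic as above: a numerically clamped logit
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

probs = torch.tensor([0.1, 0.5, 0.9])
logits = inverse_sigmoid(probs)
print(logits)                 # tensor([-2.1972,  0.0000,  2.1972])
print(torch.sigmoid(logits))  # recovers [0.1, 0.5, 0.9] up to floating point error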
291
+ # Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->RTDetr
292
+ class RTDetrFrozenBatchNorm2d(nn.Module):
293
+ """
294
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
295
+
296
+ Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which models other than
297
+ torchvision.models.resnet[18,34,50,101] produce nans.
298
+ """
299
+
300
+ def __init__(self, n):
301
+ super().__init__()
302
+ self.register_buffer("weight", torch.ones(n))
303
+ self.register_buffer("bias", torch.zeros(n))
304
+ self.register_buffer("running_mean", torch.zeros(n))
305
+ self.register_buffer("running_var", torch.ones(n))
306
+
307
+ def _load_from_state_dict(
308
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
309
+ ):
310
+ num_batches_tracked_key = prefix + "num_batches_tracked"
311
+ if num_batches_tracked_key in state_dict:
312
+ del state_dict[num_batches_tracked_key]
313
+
314
+ super()._load_from_state_dict(
315
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
316
+ )
317
+
318
+ def forward(self, x):
319
+ # move reshapes to the beginning
320
+ # to make it user-friendly
321
+ weight = self.weight.reshape(1, -1, 1, 1)
322
+ bias = self.bias.reshape(1, -1, 1, 1)
323
+ running_var = self.running_var.reshape(1, -1, 1, 1)
324
+ running_mean = self.running_mean.reshape(1, -1, 1, 1)
325
+ epsilon = 1e-5
326
+ scale = weight * (running_var + epsilon).rsqrt()
327
+ bias = bias - running_mean * scale
328
+ return x * scale + bias
329
+
330
+
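The forward above folds the frozen running statistics and affine parameters into a single per-channel scale and shift. A self-contained sanity check, assuming arbitrary statistics, that this matches standard batch norm evaluated in inference mode with the same `eps`:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_channels, eps = 8, 1e-5
x = torch.randn(2, num_channels, 4, 4)
weight, bias = torch.randn(num_channels), torch.randn(num_channels)
running_mean, running_var = torch.randn(num_channels), torch.rand(num_channels) + 0.5

# Folded form used by RTDetrFrozenBatchNorm2d.forward: y = x * scale + shift
scale = (weight * (running_var + eps).rsqrt()).reshape(1, -1, 1, 1)
shift = bias.reshape(1, -1, 1, 1) - running_mean.reshape(1, -1, 1, 1) * scale
frozen_out = x * scale + shift

# Reference: standard batch norm with the same frozen statistics, inference mode
reference = F.batch_norm(x, running_mean, running_var, weight, bias, training=False, eps=eps)
print(torch.allclose(frozen_out, reference, atol=1e-5))  # True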
331
+ # Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->RTDetr
332
+ def replace_batch_norm(model):
333
+ r"""
334
+ Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrFrozenBatchNorm2d`.
335
+
336
+ Args:
337
+ model (torch.nn.Module):
338
+ input model
339
+ """
340
+ for name, module in model.named_children():
341
+ if isinstance(module, nn.BatchNorm2d):
342
+ new_module = RTDetrFrozenBatchNorm2d(module.num_features)
343
+
344
+ if module.weight.device != torch.device("meta"):
345
+ new_module.weight.data.copy_(module.weight)
346
+ new_module.bias.data.copy_(module.bias)
347
+ new_module.running_mean.data.copy_(module.running_mean)
348
+ new_module.running_var.data.copy_(module.running_var)
349
+
350
+ model._modules[name] = new_module
351
+
352
+ if len(list(module.children())) > 0:
353
+ replace_batch_norm(module)
354
+
355
+
356
+ def get_contrastive_denoising_training_group(
357
+ targets,
358
+ num_classes,
359
+ num_queries,
360
+ class_embed,
361
+ num_denoising_queries=100,
362
+ label_noise_ratio=0.5,
363
+ box_noise_scale=1.0,
364
+ ):
365
+ """
366
+ Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.
367
+
368
+ Args:
369
+ targets (`list[dict]`):
370
+ The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
371
+ num_classes (`int`):
372
+ Total number of classes in the dataset.
373
+ num_queries (`int`):
374
+ Number of query slots in the transformer.
375
+ class_embed (`callable`):
376
+ A function or a model layer to embed class labels.
377
+ num_denoising_queries (`int`, *optional*, defaults to 100):
378
+ Number of denoising queries.
379
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
380
+ Ratio of noise applied to labels.
381
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
382
+ Scale of noise applied to bounding boxes.
383
+ Returns:
384
+ `tuple` comprising various elements:
385
+ - **input_query_class** (`torch.FloatTensor`) --
386
+ Class queries with applied label noise.
387
+ - **input_query_bbox** (`torch.FloatTensor`) --
388
+ Bounding box queries with applied box noise.
389
+ - **attn_mask** (`torch.FloatTensor`) --
390
+ Attention mask for separating denoising and reconstruction queries.
391
+ - **denoising_meta_values** (`dict`) --
392
+ Metadata including denoising positive indices, number of groups, and split sizes.
393
+ """
394
+
395
+ if num_denoising_queries <= 0:
396
+ return None, None, None, None
397
+
398
+ num_ground_truths = [len(t["class_labels"]) for t in targets]
399
+ device = targets[0]["class_labels"].device
400
+
401
+ max_gt_num = max(num_ground_truths)
402
+ if max_gt_num == 0:
403
+ return None, None, None, None
404
+
405
+ num_groups_denoising_queries = num_denoising_queries // max_gt_num
406
+ num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
407
+ # pad gt to max_num of a batch
408
+ batch_size = len(num_ground_truths)
409
+
410
+ input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
411
+ input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
412
+ pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)
413
+
414
+ for i in range(batch_size):
415
+ num_gt = num_ground_truths[i]
416
+ if num_gt > 0:
417
+ input_query_class[i, :num_gt] = targets[i]["class_labels"]
418
+ input_query_bbox[i, :num_gt] = targets[i]["boxes"]
419
+ pad_gt_mask[i, :num_gt] = 1
420
+ # each group has positive and negative queries.
421
+ input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
422
+ input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
423
+ pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
424
+ # positive and negative mask
425
+ negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
426
+ negative_gt_mask[:, max_gt_num:] = 1
427
+ negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
428
+ positive_gt_mask = 1 - negative_gt_mask
429
+ # contrastive denoising training positive index
430
+ positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
431
+ denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
432
+ denoise_positive_idx = torch.split(
433
+ denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
434
+ )
435
+ # total denoising queries
436
+ num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
437
+
438
+ if label_noise_ratio > 0:
439
+ mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
440
+ # randomly put a new one here
441
+ new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
442
+ input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
443
+
444
+ if box_noise_scale > 0:
445
+ known_bbox = center_to_corners_format(input_query_bbox)
446
+ diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
447
+ rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
448
+ rand_part = torch.rand_like(input_query_bbox)
449
+ rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
450
+ rand_part *= rand_sign
451
+ known_bbox += rand_part * diff
452
+ known_bbox.clip_(min=0.0, max=1.0)
453
+ input_query_bbox = corners_to_center_format(known_bbox)
454
+ input_query_bbox = inverse_sigmoid(input_query_bbox)
455
+
456
+ input_query_class = class_embed(input_query_class)
457
+
458
+ target_size = num_denoising_queries + num_queries
459
+ attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
460
+ # match query cannot see the reconstruction
461
+ attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
462
+
463
+ # reconstructions cannot see each other
464
+ for i in range(num_groups_denoising_queries):
465
+ idx_block_start = max_gt_num * 2 * i
466
+ idx_block_end = max_gt_num * 2 * (i + 1)
467
+ attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
468
+ attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf
469
+
470
+ denoising_meta_values = {
471
+ "dn_positive_idx": denoise_positive_idx,
472
+ "dn_num_group": num_groups_denoising_queries,
473
+ "dn_num_split": [num_denoising_queries, num_queries],
474
+ }
475
+
476
+ return input_query_class, input_query_bbox, attn_mask, denoising_meta_values
477
+
478
+
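A shape-level sketch of how this helper is typically driven (dummy targets and an `nn.Embedding` standing in for `class_embed`); the numbers are illustrative and the import path assumes the file lands at `transformers/models/rt_detr/modeling_rt_detr.py` as in this commit:

import torch
from torch import nn
from transformers.models.rt_detr.modeling_rt_detr import get_contrastive_denoising_training_group

num_classes, d_model, num_queries = 80, 256, 300
targets = [
    {"class_labels": torch.tensor([3, 7]), "boxes": torch.rand(2, 4)},
    {"class_labels": torch.tensor([1]), "boxes": torch.rand(1, 4)},
]
class_embed = nn.Embedding(num_classes + 1, d_model)  # the extra row is the padding / "no object" class

query_class, query_bbox, attn_mask, meta = get_contrastive_denoising_training_group(
    targets, num_classes, num_queries, class_embed, num_denoising_queries=100
)

# max_gt_num = 2 -> 50 groups, each with one positive and one negative copy per padded object slot
print(query_class.shape)     # torch.Size([2, 200, 256])
print(query_bbox.shape)      # torch.Size([2, 200, 4])
print(attn_mask.shape)       # torch.Size([500, 500])  (200 denoising + 300 matching queries)
print(meta["dn_num_split"])  # [200, 300]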
479
+ class RTDetrConvEncoder(nn.Module):
480
+ """
481
+ Convolutional backbone, by default the ResNet variant implemented in modeling_rt_detr_resnet.py.
482
+
483
+ nn.BatchNorm2d layers are replaced by RTDetrFrozenBatchNorm2d as defined above.
484
+ https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/nn/backbone/presnet.py#L142
485
+ """
486
+
487
+ def __init__(self, config):
488
+ super().__init__()
489
+
490
+ backbone = load_backbone(config)
491
+
492
+ if config.freeze_backbone_batch_norms:
493
+ # replace batch norm by frozen batch norm
494
+ with torch.no_grad():
495
+ replace_batch_norm(backbone)
496
+ self.model = backbone
497
+ self.intermediate_channel_sizes = self.model.channels
498
+
499
+ def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
500
+ # send pixel_values through the model to get list of feature maps
501
+ features = self.model(pixel_values).feature_maps
502
+
503
+ out = []
504
+ for feature_map in features:
505
+ # downsample pixel_mask to match shape of corresponding feature_map
506
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
507
+ out.append((feature_map, mask))
508
+ return out
509
+
510
+
511
+ class RTDetrConvNormLayer(nn.Module):
512
+ def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
513
+ super().__init__()
514
+ self.conv = nn.Conv2d(
515
+ in_channels,
516
+ out_channels,
517
+ kernel_size,
518
+ stride,
519
+ padding=(kernel_size - 1) // 2 if padding is None else padding,
520
+ bias=False,
521
+ )
522
+ self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
523
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
524
+
525
+ def forward(self, hidden_state):
526
+ hidden_state = self.conv(hidden_state)
527
+ hidden_state = self.norm(hidden_state)
528
+ hidden_state = self.activation(hidden_state)
529
+ return hidden_state
530
+
531
+
532
+ class RTDetrEncoderLayer(nn.Module):
533
+ def __init__(self, config: RTDetrConfig):
534
+ super().__init__()
535
+ self.normalize_before = config.normalize_before
536
+
537
+ # self-attention
538
+ self.self_attn = RTDetrMultiheadAttention(
539
+ embed_dim=config.encoder_hidden_dim,
540
+ num_heads=config.num_attention_heads,
541
+ dropout=config.dropout,
542
+ )
543
+ self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
544
+ self.dropout = config.dropout
545
+ self.activation_fn = ACT2FN[config.encoder_activation_function]
546
+ self.activation_dropout = config.activation_dropout
547
+ self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim)
548
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim)
549
+ self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
550
+
551
+ def forward(
552
+ self,
553
+ hidden_states: torch.Tensor,
554
+ attention_mask: torch.Tensor,
555
+ position_embeddings: Optional[torch.Tensor] = None,
556
+ output_attentions: bool = False,
557
+ **kwargs,
558
+ ):
559
+ """
560
+ Args:
561
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
562
+ attention_mask (`torch.FloatTensor`): attention mask of size
563
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
564
+ values.
565
+ position_embeddings (`torch.FloatTensor`, *optional*):
566
+ Object queries (also called content embeddings), to be added to the hidden states.
567
+ output_attentions (`bool`, *optional*):
568
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
569
+ returned tensors for more detail.
570
+ """
571
+ residual = hidden_states
572
+ if self.normalize_before:
573
+ hidden_states = self.self_attn_layer_norm(hidden_states)
574
+
575
+ hidden_states, attn_weights = self.self_attn(
576
+ hidden_states=hidden_states,
577
+ attention_mask=attention_mask,
578
+ position_embeddings=position_embeddings,
579
+ output_attentions=output_attentions,
580
+ )
581
+
582
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
583
+ hidden_states = residual + hidden_states
584
+ if not self.normalize_before:
585
+ hidden_states = self.self_attn_layer_norm(hidden_states)
586
+
587
+ if self.normalize_before:
588
+ hidden_states = self.final_layer_norm(hidden_states)
589
+ residual = hidden_states
590
+
591
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
592
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
593
+
594
+ hidden_states = self.fc2(hidden_states)
595
+
596
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
597
+
598
+ hidden_states = residual + hidden_states
599
+ if not self.normalize_before:
600
+ hidden_states = self.final_layer_norm(hidden_states)
601
+
602
+ if self.training:
603
+ if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
604
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
605
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
606
+
607
+ outputs = (hidden_states,)
608
+
609
+ if output_attentions:
610
+ outputs += (attn_weights,)
611
+
612
+ return outputs
613
+
614
+
615
+ class RTDetrRepVggBlock(nn.Module):
616
+ """
617
+ RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
618
+ """
619
+
620
+ def __init__(self, config: RTDetrConfig):
621
+ super().__init__()
622
+
623
+ activation = config.activation_function
624
+ hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion)
625
+ self.conv1 = RTDetrConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1)
626
+ self.conv2 = RTDetrConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0)
627
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
628
+
629
+ def forward(self, x):
630
+ y = self.conv1(x) + self.conv2(x)
631
+ return self.activation(y)
632
+
633
+
634
+ class RTDetrCSPRepLayer(nn.Module):
635
+ """
636
+ Cross Stage Partial (CSP) network layer with RepVGG blocks.
637
+ """
638
+
639
+ def __init__(self, config: RTDetrConfig):
640
+ super().__init__()
641
+
642
+ in_channels = config.encoder_hidden_dim * 2
643
+ out_channels = config.encoder_hidden_dim
644
+ num_blocks = 3
645
+ activation = config.activation_function
646
+
647
+ hidden_channels = int(out_channels * config.hidden_expansion)
648
+ self.conv1 = RTDetrConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
649
+ self.conv2 = RTDetrConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
650
+ self.bottlenecks = nn.Sequential(*[RTDetrRepVggBlock(config) for _ in range(num_blocks)])
651
+ if hidden_channels != out_channels:
652
+ self.conv3 = RTDetrConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
653
+ else:
654
+ self.conv3 = nn.Identity()
655
+
656
+ def forward(self, hidden_state):
657
+ hidden_state_1 = self.conv1(hidden_state)
658
+ hidden_state_1 = self.bottlenecks(hidden_state_1)
659
+ hidden_state_2 = self.conv2(hidden_state)
660
+ return self.conv3(hidden_state_1 + hidden_state_2)
661
+
662
+
663
+ # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->RTDetr
664
+ class RTDetrMultiscaleDeformableAttention(nn.Module):
665
+ """
666
+ Multiscale deformable attention as proposed in Deformable DETR.
667
+ """
668
+
669
+ def __init__(self, config: RTDetrConfig, num_heads: int, n_points: int):
670
+ super().__init__()
671
+
672
+ self.attn = MultiScaleDeformableAttention()
673
+
674
+ if config.d_model % num_heads != 0:
675
+ raise ValueError(
676
+ f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
677
+ )
678
+ dim_per_head = config.d_model // num_heads
679
+ # check if dim_per_head is power of 2
680
+ if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
681
+ warnings.warn(
682
+ "You'd better set embed_dim (d_model) in RTDetrMultiscaleDeformableAttention to make the"
683
+ " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
684
+ " implementation."
685
+ )
686
+
687
+ self.im2col_step = 64
688
+
689
+ self.d_model = config.d_model
690
+ self.n_levels = config.num_feature_levels
691
+ self.n_heads = num_heads
692
+ self.n_points = n_points
693
+
694
+ self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
695
+ self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
696
+ self.value_proj = nn.Linear(config.d_model, config.d_model)
697
+ self.output_proj = nn.Linear(config.d_model, config.d_model)
698
+
699
+ self.disable_custom_kernels = config.disable_custom_kernels
700
+
701
+ def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
702
+ return tensor if position_embeddings is None else tensor + position_embeddings
703
+
704
+ def forward(
705
+ self,
706
+ hidden_states: torch.Tensor,
707
+ attention_mask: Optional[torch.Tensor] = None,
708
+ encoder_hidden_states=None,
709
+ encoder_attention_mask=None,
710
+ position_embeddings: Optional[torch.Tensor] = None,
711
+ reference_points=None,
712
+ spatial_shapes=None,
713
+ spatial_shapes_list=None,
714
+ level_start_index=None,
715
+ output_attentions: bool = False,
716
+ ):
717
+ # add position embeddings to the hidden states before projecting to queries and keys
718
+ if position_embeddings is not None:
719
+ hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
720
+
721
+ batch_size, num_queries, _ = hidden_states.shape
722
+ batch_size, sequence_length, _ = encoder_hidden_states.shape
723
+ total_elements = sum(height * width for height, width in spatial_shapes_list)
724
+ if total_elements != sequence_length:
725
+ raise ValueError(
726
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
727
+ )
728
+
729
+ value = self.value_proj(encoder_hidden_states)
730
+ if attention_mask is not None:
731
+ # we invert the attention_mask
732
+ value = value.masked_fill(~attention_mask[..., None], float(0))
733
+ value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
734
+ sampling_offsets = self.sampling_offsets(hidden_states).view(
735
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
736
+ )
737
+ attention_weights = self.attention_weights(hidden_states).view(
738
+ batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
739
+ )
740
+ attention_weights = F.softmax(attention_weights, -1).view(
741
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
742
+ )
743
+ # batch_size, num_queries, n_heads, n_levels, n_points, 2
744
+ num_coordinates = reference_points.shape[-1]
745
+ if num_coordinates == 2:
746
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
747
+ sampling_locations = (
748
+ reference_points[:, :, None, :, None, :]
749
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
750
+ )
751
+ elif num_coordinates == 4:
752
+ sampling_locations = (
753
+ reference_points[:, :, None, :, None, :2]
754
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
755
+ )
756
+ else:
757
+ raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
758
+
759
+ output = self.attn(
760
+ value,
761
+ spatial_shapes,
762
+ spatial_shapes_list,
763
+ level_start_index,
764
+ sampling_locations,
765
+ attention_weights,
766
+ self.im2col_step,
767
+ )
768
+
769
+ output = self.output_proj(output)
770
+
771
+ return output, attention_weights
772
+
773
+
774
+ class RTDetrMultiheadAttention(nn.Module):
775
+ """
776
+ Multi-headed attention from 'Attention Is All You Need' paper.
777
+
778
+ Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
779
+ """
780
+
781
+ def __init__(
782
+ self,
783
+ embed_dim: int,
784
+ num_heads: int,
785
+ dropout: float = 0.0,
786
+ bias: bool = True,
787
+ ):
788
+ super().__init__()
789
+ self.embed_dim = embed_dim
790
+ self.num_heads = num_heads
791
+ self.dropout = dropout
792
+ self.head_dim = embed_dim // num_heads
793
+ if self.head_dim * num_heads != self.embed_dim:
794
+ raise ValueError(
795
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
796
+ f" {num_heads})."
797
+ )
798
+ self.scaling = self.head_dim**-0.5
799
+
800
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
801
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
802
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
803
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
804
+
805
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
806
+ return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
807
+
808
+ def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
809
+ return tensor if position_embeddings is None else tensor + position_embeddings
810
+
811
+ def forward(
812
+ self,
813
+ hidden_states: torch.Tensor,
814
+ attention_mask: Optional[torch.Tensor] = None,
815
+ position_embeddings: Optional[torch.Tensor] = None,
816
+ output_attentions: bool = False,
817
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
818
+ """Input shape: Batch x Time x Channel"""
819
+
820
+ batch_size, target_len, embed_dim = hidden_states.size()
821
+ # add position embeddings to the hidden states before projecting to queries and keys
822
+ if position_embeddings is not None:
823
+ hidden_states_original = hidden_states
824
+ hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
825
+
826
+ # get queries, keys and values
827
+ query_states = self.q_proj(hidden_states) * self.scaling
828
+ key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size)
829
+ value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size)
830
+
831
+ proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
832
+ query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape)
833
+ key_states = key_states.view(*proj_shape)
834
+ value_states = value_states.view(*proj_shape)
835
+
836
+ source_len = key_states.size(1)
837
+
838
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
839
+
840
+ if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
841
+ raise ValueError(
842
+ f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
843
+ f" {attn_weights.size()}"
844
+ )
845
+
846
+ # expand attention_mask
847
+ if attention_mask is not None:
848
+ # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
849
+ attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size())
850
+
851
+ if attention_mask is not None:
852
+ if attention_mask.size() != (batch_size, 1, target_len, source_len):
853
+ raise ValueError(
854
+ f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
855
+ f" {attention_mask.size()}"
856
+ )
857
+ if attention_mask.dtype == torch.bool:
858
+ attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
859
+ attention_mask, -torch.inf
860
+ )
861
+ attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
862
+ attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
863
+
864
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
865
+
866
+ if output_attentions:
867
+ # this operation is a bit awkward, but it's required to
868
+ # make sure that attn_weights keeps its gradient.
869
+ # In order to do so, attn_weights have to reshaped
870
+ # twice and have to be reused in the following
871
+ attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
872
+ attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
873
+ else:
874
+ attn_weights_reshaped = None
875
+
876
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
877
+
878
+ attn_output = torch.bmm(attn_probs, value_states)
879
+
880
+ if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
881
+ raise ValueError(
882
+ f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
883
+ f" {attn_output.size()}"
884
+ )
885
+
886
+ attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
887
+ attn_output = attn_output.transpose(1, 2)
888
+ attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
889
+
890
+ attn_output = self.out_proj(attn_output)
891
+
892
+ return attn_output, attn_weights_reshaped
893
+
894
+
895
+ class RTDetrDecoderLayer(nn.Module):
896
+ def __init__(self, config: RTDetrConfig):
897
+ super().__init__()
898
+ # self-attention
899
+ self.self_attn = RTDetrMultiheadAttention(
900
+ embed_dim=config.d_model,
901
+ num_heads=config.decoder_attention_heads,
902
+ dropout=config.attention_dropout,
903
+ )
904
+ self.dropout = config.dropout
905
+ self.activation_fn = ACT2FN[config.decoder_activation_function]
906
+ self.activation_dropout = config.activation_dropout
907
+
908
+ self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
909
+ # cross-attention
910
+ self.encoder_attn = RTDetrMultiscaleDeformableAttention(
911
+ config,
912
+ num_heads=config.decoder_attention_heads,
913
+ n_points=config.decoder_n_points,
914
+ )
915
+ self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
916
+ # feedforward neural networks
917
+ self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
918
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
919
+ self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
920
+
921
+ def forward(
922
+ self,
923
+ hidden_states: torch.Tensor,
924
+ position_embeddings: Optional[torch.Tensor] = None,
925
+ reference_points=None,
926
+ spatial_shapes=None,
927
+ spatial_shapes_list=None,
928
+ level_start_index=None,
929
+ encoder_hidden_states: Optional[torch.Tensor] = None,
930
+ encoder_attention_mask: Optional[torch.Tensor] = None,
931
+ output_attentions: Optional[bool] = False,
932
+ ):
933
+ """
934
+ Args:
935
+ hidden_states (`torch.FloatTensor`):
936
+ Input to the layer of shape `(seq_len, batch, embed_dim)`.
937
+ position_embeddings (`torch.FloatTensor`, *optional*):
938
+ Position embeddings that are added to the queries and keys in the self-attention layer.
939
+ reference_points (`torch.FloatTensor`, *optional*):
940
+ Reference points.
941
+ spatial_shapes (`torch.LongTensor`, *optional*):
942
+ Spatial shapes.
943
+ level_start_index (`torch.LongTensor`, *optional*):
944
+ Level start index.
945
+ encoder_hidden_states (`torch.FloatTensor`):
946
+ cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
947
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
948
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
949
+ values.
950
+ output_attentions (`bool`, *optional*):
951
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
952
+ returned tensors for more detail.
953
+ """
954
+ residual = hidden_states
955
+
956
+ # Self Attention
957
+ hidden_states, self_attn_weights = self.self_attn(
958
+ hidden_states=hidden_states,
959
+ attention_mask=encoder_attention_mask,
960
+ position_embeddings=position_embeddings,
961
+ output_attentions=output_attentions,
962
+ )
963
+
964
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
965
+ hidden_states = residual + hidden_states
966
+ hidden_states = self.self_attn_layer_norm(hidden_states)
967
+
968
+ second_residual = hidden_states
969
+
970
+ # Cross-Attention
971
+ cross_attn_weights = None
972
+ hidden_states, cross_attn_weights = self.encoder_attn(
973
+ hidden_states=hidden_states,
974
+ encoder_hidden_states=encoder_hidden_states,
975
+ position_embeddings=position_embeddings,
976
+ reference_points=reference_points,
977
+ spatial_shapes=spatial_shapes,
978
+ spatial_shapes_list=spatial_shapes_list,
979
+ level_start_index=level_start_index,
980
+ output_attentions=output_attentions,
981
+ )
982
+
983
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
984
+ hidden_states = second_residual + hidden_states
985
+
986
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
987
+
988
+ # Fully Connected
989
+ residual = hidden_states
990
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
991
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
992
+ hidden_states = self.fc2(hidden_states)
993
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
994
+ hidden_states = residual + hidden_states
995
+ hidden_states = self.final_layer_norm(hidden_states)
996
+
997
+ outputs = (hidden_states,)
998
+
999
+ if output_attentions:
1000
+ outputs += (self_attn_weights, cross_attn_weights)
1001
+
1002
+ return outputs
1003
+
1004
+
1005
+ @auto_docstring
1006
+ class RTDetrPreTrainedModel(PreTrainedModel):
1007
+ config: RTDetrConfig
1008
+ base_model_prefix = "rt_detr"
1009
+ main_input_name = "pixel_values"
1010
+ _no_split_modules = [r"RTDetrHybridEncoder", r"RTDetrDecoderLayer"]
1011
+
1012
+ def _init_weights(self, module):
1013
+ """Initialize the weights"""
1014
+ if isinstance(module, (RTDetrForObjectDetection, RTDetrDecoder)):
1015
+ if module.class_embed is not None:
1016
+ for layer in module.class_embed:
1017
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
1018
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
1019
+ nn.init.xavier_uniform_(layer.weight)
1020
+ nn.init.constant_(layer.bias, bias)
1021
+
1022
+ if module.bbox_embed is not None:
1023
+ for layer in module.bbox_embed:
1024
+ nn.init.constant_(layer.layers[-1].weight, 0)
1025
+ nn.init.constant_(layer.layers[-1].bias, 0)
1026
+
1027
+ elif isinstance(module, RTDetrMultiscaleDeformableAttention):
1028
+ nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
1029
+ default_dtype = torch.get_default_dtype()
1030
+ thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
1031
+ 2.0 * math.pi / module.n_heads
1032
+ )
1033
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
1034
+ grid_init = (
1035
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
1036
+ .view(module.n_heads, 1, 1, 2)
1037
+ .repeat(1, module.n_levels, module.n_points, 1)
1038
+ )
1039
+ for i in range(module.n_points):
1040
+ grid_init[:, :, i, :] *= i + 1
1041
+ with torch.no_grad():
1042
+ module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
1043
+ nn.init.constant_(module.attention_weights.weight.data, 0.0)
1044
+ nn.init.constant_(module.attention_weights.bias.data, 0.0)
1045
+ nn.init.xavier_uniform_(module.value_proj.weight.data)
1046
+ nn.init.constant_(module.value_proj.bias.data, 0.0)
1047
+ nn.init.xavier_uniform_(module.output_proj.weight.data)
1048
+ nn.init.constant_(module.output_proj.bias.data, 0.0)
1049
+
1050
+ elif isinstance(module, RTDetrModel):
1051
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
1052
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
1053
+ nn.init.xavier_uniform_(module.enc_score_head.weight)
1054
+ nn.init.constant_(module.enc_score_head.bias, bias)
1055
+
1056
+ elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
1057
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1058
+ if module.bias is not None:
1059
+ module.bias.data.zero_()
1060
+
1061
+ elif isinstance(module, nn.LayerNorm):
1062
+ module.weight.data.fill_(1.0)
1063
+ module.bias.data.zero_()
1064
+
1065
+ if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
1066
+ nn.init.xavier_uniform_(module.weight_embedding.weight)
1067
+ if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
1068
+ nn.init.xavier_uniform_(module.denoising_class_embed.weight)
1069
+
1070
+
1071
+ class RTDetrEncoder(nn.Module):
1072
+ def __init__(self, config: RTDetrConfig):
1073
+ super().__init__()
1074
+
1075
+ self.layers = nn.ModuleList([RTDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
1076
+
1077
+ def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor:
1078
+ hidden_states = src
1079
+ for layer in self.layers:
1080
+ hidden_states = layer(
1081
+ hidden_states,
1082
+ attention_mask=src_mask,
1083
+ position_embeddings=pos_embed,
1084
+ output_attentions=output_attentions,
1085
+ )
1086
+ return hidden_states
1087
+
1088
+
1089
+ class RTDetrHybridEncoder(nn.Module):
1090
+ """
1091
+ Hybrid encoder consisting of a projection layer, a set of `RTDetrEncoder`, a top-down Feature Pyramid Network
1092
+ (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://huggingface.co/papers/2304.08069
1093
+
1094
+ Args:
1095
+ config: RTDetrConfig
1096
+ """
1097
+
1098
+ def __init__(self, config: RTDetrConfig):
1099
+ super().__init__()
1100
+ self.config = config
1101
+ self.in_channels = config.encoder_in_channels
1102
+ self.feat_strides = config.feat_strides
1103
+ self.encoder_hidden_dim = config.encoder_hidden_dim
1104
+ self.encode_proj_layers = config.encode_proj_layers
1105
+ self.positional_encoding_temperature = config.positional_encoding_temperature
1106
+ self.eval_size = config.eval_size
1107
+ self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
1108
+ self.out_strides = self.feat_strides
1109
+ self.num_fpn_stages = len(self.in_channels) - 1
1110
+ self.num_pan_stages = len(self.in_channels) - 1
1111
+ activation = config.activation_function
1112
+
1113
+ # encoder transformer
1114
+ self.encoder = nn.ModuleList([RTDetrEncoder(config) for _ in range(len(self.encode_proj_layers))])
1115
+
1116
+ # top-down FPN
1117
+ self.lateral_convs = nn.ModuleList()
1118
+ self.fpn_blocks = nn.ModuleList()
1119
+ for _ in range(self.num_fpn_stages):
1120
+ lateral_conv = RTDetrConvNormLayer(
1121
+ config,
1122
+ in_channels=self.encoder_hidden_dim,
1123
+ out_channels=self.encoder_hidden_dim,
1124
+ kernel_size=1,
1125
+ stride=1,
1126
+ activation=activation,
1127
+ )
1128
+ fpn_block = RTDetrCSPRepLayer(config)
1129
+ self.lateral_convs.append(lateral_conv)
1130
+ self.fpn_blocks.append(fpn_block)
1131
+
1132
+ # bottom-up PAN
1133
+ self.downsample_convs = nn.ModuleList()
1134
+ self.pan_blocks = nn.ModuleList()
1135
+ for _ in range(self.num_pan_stages):
1136
+ downsample_conv = RTDetrConvNormLayer(
1137
+ config,
1138
+ in_channels=self.encoder_hidden_dim,
1139
+ out_channels=self.encoder_hidden_dim,
1140
+ kernel_size=3,
1141
+ stride=2,
1142
+ activation=activation,
1143
+ )
1144
+ pan_block = RTDetrCSPRepLayer(config)
1145
+ self.downsample_convs.append(downsample_conv)
1146
+ self.pan_blocks.append(pan_block)
1147
+
1148
+ @staticmethod
1149
+ def build_2d_sincos_position_embedding(
1150
+ width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
1151
+ ):
1152
+ grid_w = torch.arange(torch_int(width), device=device).to(dtype)
1153
+ grid_h = torch.arange(torch_int(height), device=device).to(dtype)
1154
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
1155
+ if embed_dim % 4 != 0:
1156
+ raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
1157
+ pos_dim = embed_dim // 4
1158
+ omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
1159
+ omega = 1.0 / (temperature**omega)
1160
+
1161
+ out_w = grid_w.flatten()[..., None] @ omega[None]
1162
+ out_h = grid_h.flatten()[..., None] @ omega[None]
1163
+
1164
+ return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]
1165
+
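A standalone mirror of the static method above, useful for checking the expected shape: one `embed_dim`-dimensional embedding per spatial position of the flattened feature map (the helper name below is only for illustration):

import torch

def build_2d_sincos_position_embedding(width, height, embed_dim=256, temperature=10000.0):
    # standalone mirror of the static method above
    grid_w = torch.arange(int(width), dtype=torch.float32)
    grid_h = torch.arange(int(height), dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1.0 / (temperature**omega)
    out_w = grid_w.flatten()[..., None] @ omega[None]
    out_h = grid_h.flatten()[..., None] @ omega[None]
    return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]

# e.g. a 20x20 feature map with a 256-dim encoder: one embedding per flattened spatial position
pos_embed = build_2d_sincos_position_embedding(20, 20, embed_dim=256)
print(pos_embed.shape)  # torch.Size([1, 400, 256])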
1166
+ def forward(
1167
+ self,
1168
+ inputs_embeds=None,
1169
+ attention_mask=None,
1170
+ position_embeddings=None,
1171
+ spatial_shapes=None,
1172
+ level_start_index=None,
1173
+ valid_ratios=None,
1174
+ output_attentions=None,
1175
+ output_hidden_states=None,
1176
+ return_dict=None,
1177
+ ):
1178
+ r"""
1179
+ Args:
1180
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1181
+ Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
1182
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1183
+ Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
1184
+ - 1 for pixel features that are real (i.e. **not masked**),
1185
+ - 0 for pixel features that are padding (i.e. **masked**).
1186
+ [What are attention masks?](../glossary#attention-mask)
1187
+ position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1188
+ Position embeddings that are added to the queries and keys in each self-attention layer.
1189
+ spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
1190
+ Spatial shapes of each feature map.
1191
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
1192
+ Starting index of each feature map.
1193
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
1194
+ Ratio of valid area in each feature level.
1195
+ output_attentions (`bool`, *optional*):
1196
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1197
+ returned tensors for more detail.
1198
+ output_hidden_states (`bool`, *optional*):
1199
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1200
+ for more detail.
1201
+ return_dict (`bool`, *optional*):
1202
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
1203
+ """
1204
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1205
+ output_hidden_states = (
1206
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1207
+ )
1208
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1209
+
1210
+ hidden_states = inputs_embeds
1211
+
1212
+ encoder_states = () if output_hidden_states else None
1213
+ all_attentions = () if output_attentions else None
1214
+
1215
+ # encoder
1216
+ if self.config.encoder_layers > 0:
1217
+ for i, enc_ind in enumerate(self.encode_proj_layers):
1218
+ if output_hidden_states:
1219
+ encoder_states = encoder_states + (hidden_states[enc_ind],)
1220
+ height, width = hidden_states[enc_ind].shape[2:]
1221
+ # flatten [batch, channel, height, width] to [batch, height*width, channel]
1222
+ src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
1223
+ if self.training or self.eval_size is None:
1224
+ pos_embed = self.build_2d_sincos_position_embedding(
1225
+ width,
1226
+ height,
1227
+ self.encoder_hidden_dim,
1228
+ self.positional_encoding_temperature,
1229
+ device=src_flatten.device,
1230
+ dtype=src_flatten.dtype,
1231
+ )
1232
+ else:
1233
+ pos_embed = None
1234
+
1235
+ layer_outputs = self.encoder[i](
1236
+ src_flatten,
1237
+ pos_embed=pos_embed,
1238
+ output_attentions=output_attentions,
1239
+ )
1240
+ hidden_states[enc_ind] = (
1241
+ layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
1242
+ )
1243
+
1244
+ if output_attentions:
1245
+ all_attentions = all_attentions + (layer_outputs[1],)
1246
+
1247
+ if output_hidden_states:
1248
+ encoder_states = encoder_states + (hidden_states[enc_ind],)
1249
+
1250
+ # top-down FPN
1251
+ fpn_feature_maps = [hidden_states[-1]]
1252
+ for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
1253
+ backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1]
1254
+ top_fpn_feature_map = fpn_feature_maps[-1]
1255
+ # apply lateral block
1256
+ top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
1257
+ fpn_feature_maps[-1] = top_fpn_feature_map
1258
+ # apply fpn block
1259
+ top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
1260
+ fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
1261
+ new_fpn_feature_map = fpn_block(fused_feature_map)
1262
+ fpn_feature_maps.append(new_fpn_feature_map)
1263
+
1264
+ fpn_feature_maps = fpn_feature_maps[::-1]
1265
+
1266
+ # bottom-up PAN
1267
+ pan_feature_maps = [fpn_feature_maps[0]]
1268
+ for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
1269
+ top_pan_feature_map = pan_feature_maps[-1]
1270
+ fpn_feature_map = fpn_feature_maps[idx + 1]
1271
+ downsampled_feature_map = downsample_conv(top_pan_feature_map)
1272
+ fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
1273
+ new_pan_feature_map = pan_block(fused_feature_map)
1274
+ pan_feature_maps.append(new_pan_feature_map)
1275
+
1276
+ if not return_dict:
1277
+ return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None)
1278
+ return BaseModelOutput(
1279
+ last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions
1280
+ )
1281
+
1282
+
1283
+ class RTDetrDecoder(RTDetrPreTrainedModel):
1284
+ def __init__(self, config: RTDetrConfig):
1285
+ super().__init__(config)
1286
+
1287
+ self.dropout = config.dropout
1288
+ self.layers = nn.ModuleList([RTDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
1289
+ self.query_pos_head = RTDetrMLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2)
1290
+
1291
+ # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
1292
+ self.bbox_embed = None
1293
+ self.class_embed = None
1294
+
1295
+ # Initialize weights and apply final processing
1296
+ self.post_init()
1297
+
1298
+ def forward(
1299
+ self,
1300
+ inputs_embeds=None,
1301
+ encoder_hidden_states=None,
1302
+ encoder_attention_mask=None,
1303
+ position_embeddings=None,
1304
+ reference_points=None,
1305
+ spatial_shapes=None,
1306
+ spatial_shapes_list=None,
1307
+ level_start_index=None,
1308
+ valid_ratios=None,
1309
+ output_attentions=None,
1310
+ output_hidden_states=None,
1311
+ return_dict=None,
1312
+ ):
1313
+ r"""
1314
+ Args:
1315
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
1316
+ The query embeddings that are passed into the decoder.
1317
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1318
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1319
+ of the decoder.
1320
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1321
+ Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
1322
+ in `[0, 1]`:
1323
+ - 1 for pixels that are real (i.e. **not masked**),
1324
+ - 0 for pixels that are padding (i.e. **masked**).
1325
+ position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1326
+ Position embeddings that are added to the queries and keys in each self-attention layer.
1327
+ reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
1328
+ Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
1329
+ spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
1330
+ Spatial shapes of the feature maps.
1331
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
1332
+ Indexes for the start of each feature level. In range `[0, sequence_length]`.
1333
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
1334
+ Ratio of valid area in each feature level.
1335
+
1336
+ output_attentions (`bool`, *optional*):
1337
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1338
+ returned tensors for more detail.
1339
+ output_hidden_states (`bool`, *optional*):
1340
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1341
+ for more detail.
1342
+ return_dict (`bool`, *optional*):
1343
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
1344
+ """
1345
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1346
+ output_hidden_states = (
1347
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1348
+ )
1349
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1350
+
1351
+ if inputs_embeds is not None:
1352
+ hidden_states = inputs_embeds
1353
+
1354
+ # decoder layers
1355
+ all_hidden_states = () if output_hidden_states else None
1356
+ all_self_attns = () if output_attentions else None
1357
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1358
+ intermediate = ()
1359
+ intermediate_reference_points = ()
1360
+ intermediate_logits = ()
1361
+
1362
+ reference_points = F.sigmoid(reference_points)
1363
+
1364
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L252
1365
+ for idx, decoder_layer in enumerate(self.layers):
1366
+ reference_points_input = reference_points.unsqueeze(2)
1367
+ position_embeddings = self.query_pos_head(reference_points)
1368
+
1369
+ if output_hidden_states:
1370
+ all_hidden_states += (hidden_states,)
1371
+
1372
+ layer_outputs = decoder_layer(
1373
+ hidden_states,
1374
+ position_embeddings=position_embeddings,
1375
+ encoder_hidden_states=encoder_hidden_states,
1376
+ reference_points=reference_points_input,
1377
+ spatial_shapes=spatial_shapes,
1378
+ spatial_shapes_list=spatial_shapes_list,
1379
+ level_start_index=level_start_index,
1380
+ encoder_attention_mask=encoder_attention_mask,
1381
+ output_attentions=output_attentions,
1382
+ )
1383
+
1384
+ hidden_states = layer_outputs[0]
1385
+
1386
+ # hack implementation for iterative bounding box refinement
1387
+ if self.bbox_embed is not None:
1388
+ predicted_corners = self.bbox_embed[idx](hidden_states)
1389
+ new_reference_points = F.sigmoid(predicted_corners + inverse_sigmoid(reference_points))
1390
+ reference_points = new_reference_points.detach()
1391
+
1392
+ intermediate += (hidden_states,)
1393
+ intermediate_reference_points += (
1394
+ (new_reference_points,) if self.bbox_embed is not None else (reference_points,)
1395
+ )
1396
+
1397
+ if self.class_embed is not None:
1398
+ logits = self.class_embed[idx](hidden_states)
1399
+ intermediate_logits += (logits,)
1400
+
1401
+ if output_attentions:
1402
+ all_self_attns += (layer_outputs[1],)
1403
+
1404
+ if encoder_hidden_states is not None:
1405
+ all_cross_attentions += (layer_outputs[2],)
1406
+
1407
+ # Keep batch_size as first dimension
1408
+ intermediate = torch.stack(intermediate, dim=1)
1409
+ intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
1410
+ if self.class_embed is not None:
1411
+ intermediate_logits = torch.stack(intermediate_logits, dim=1)
1412
+
1413
+ # add hidden states from the last decoder layer
1414
+ if output_hidden_states:
1415
+ all_hidden_states += (hidden_states,)
1416
+
1417
+ if not return_dict:
1418
+ return tuple(
1419
+ v
1420
+ for v in [
1421
+ hidden_states,
1422
+ intermediate,
1423
+ intermediate_logits,
1424
+ intermediate_reference_points,
1425
+ all_hidden_states,
1426
+ all_self_attns,
1427
+ all_cross_attentions,
1428
+ ]
1429
+ if v is not None
1430
+ )
1431
+ return RTDetrDecoderOutput(
1432
+ last_hidden_state=hidden_states,
1433
+ intermediate_hidden_states=intermediate,
1434
+ intermediate_logits=intermediate_logits,
1435
+ intermediate_reference_points=intermediate_reference_points,
1436
+ hidden_states=all_hidden_states,
1437
+ attentions=all_self_attns,
1438
+ cross_attentions=all_cross_attentions,
1439
+ )
1440
+
1441
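The iterative box refinement above adds each layer's predicted offset in logit space; a minimal sketch of that per-layer update, assuming the usual clamped `inverse_sigmoid` helper (illustrative values):

```python
import torch

def inverse_sigmoid(x, eps=1e-5):
    # logit with clamping, mirroring the helper used by the decoder above
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

reference_points = torch.tensor([[0.30, 0.40, 0.20, 0.10]])  # (cx, cy, w, h) in [0, 1]
predicted_delta = torch.tensor([[0.05, -0.10, 0.00, 0.20]])  # raw box-head output for this layer

# The offset is applied in logit space and squashed back through a sigmoid,
# so the refined boxes always stay inside [0, 1].
new_reference_points = torch.sigmoid(predicted_delta + inverse_sigmoid(reference_points))
print(new_reference_points)
```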
+
1442
+ # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
1443
+ class RTDetrMLPPredictionHead(nn.Module):
1444
+ """
1445
+ Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
1446
+ height and width of a bounding box w.r.t. an image.
1447
+
1448
+ Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
1449
+ Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_paddle/ppdet/modeling/transformers/utils.py#L453
1450
+
1451
+ """
1452
+
1453
+ def __init__(self, config, input_dim, d_model, output_dim, num_layers):
1454
+ super().__init__()
1455
+ self.num_layers = num_layers
1456
+ h = [d_model] * (num_layers - 1)
1457
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
1458
+
1459
+ def forward(self, x):
1460
+ for i, layer in enumerate(self.layers):
1461
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
1462
+ return x
1463
+
1464
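A quick check of the layer shapes the `zip` construction above produces, e.g. for a 3-layer box head with `d_model=256` and 4 box outputs (a sketch of the same pattern, not the module itself):

```python
import torch
from torch import nn

input_dim, d_model, output_dim, num_layers = 256, 256, 4, 3
h = [d_model] * (num_layers - 1)
# zip([input_dim] + h, h + [output_dim]) pairs consecutive sizes: 256->256, 256->256, 256->4
layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
print([(layer.in_features, layer.out_features) for layer in layers])

x = torch.randn(2, 300, input_dim)
for i, layer in enumerate(layers):
    # ReLU on every layer except the last one
    x = torch.relu(layer(x)) if i < num_layers - 1 else layer(x)
print(x.shape)  # torch.Size([2, 300, 4])
```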
+
1465
+ @auto_docstring(
1466
+ custom_intro="""
1467
+ RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
1468
+ """
1469
+ )
1470
+ class RTDetrModel(RTDetrPreTrainedModel):
1471
+ def __init__(self, config: RTDetrConfig):
1472
+ super().__init__(config)
1473
+
1474
+ # Create backbone
1475
+ self.backbone = RTDetrConvEncoder(config)
1476
+ intermediate_channel_sizes = self.backbone.intermediate_channel_sizes
1477
+
1478
+ # Create encoder input projection layers
1479
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py#L212
1480
+ num_backbone_outs = len(intermediate_channel_sizes)
1481
+ encoder_input_proj_list = []
1482
+ for _ in range(num_backbone_outs):
1483
+ in_channels = intermediate_channel_sizes[_]
1484
+ encoder_input_proj_list.append(
1485
+ nn.Sequential(
1486
+ nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
1487
+ nn.BatchNorm2d(config.encoder_hidden_dim),
1488
+ )
1489
+ )
1490
+ self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list)
1491
+
1492
+ # Create encoder
1493
+ self.encoder = RTDetrHybridEncoder(config)
1494
+
1495
+ # denoising part
1496
+ if config.num_denoising > 0:
1497
+ self.denoising_class_embed = nn.Embedding(
1498
+ config.num_labels + 1, config.d_model, padding_idx=config.num_labels
1499
+ )
1500
+
1501
+ # decoder embedding
1502
+ if config.learn_initial_query:
1503
+ self.weight_embedding = nn.Embedding(config.num_queries, config.d_model)
1504
+
1505
+ # encoder head
1506
+ self.enc_output = nn.Sequential(
1507
+ nn.Linear(config.d_model, config.d_model),
1508
+ nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
1509
+ )
1510
+ self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
1511
+ self.enc_bbox_head = RTDetrMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3)
1512
+
1513
+ # init encoder output anchors and valid_mask
1514
+ if config.anchor_image_size:
1515
+ self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)
1516
+
1517
+ # Create decoder input projection layers
1518
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
1519
+ num_backbone_outs = len(config.decoder_in_channels)
1520
+ decoder_input_proj_list = []
1521
+ for _ in range(num_backbone_outs):
1522
+ in_channels = config.decoder_in_channels[_]
1523
+ decoder_input_proj_list.append(
1524
+ nn.Sequential(
1525
+ nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
1526
+ nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
1527
+ )
1528
+ )
1529
+ for _ in range(config.num_feature_levels - num_backbone_outs):
1530
+ decoder_input_proj_list.append(
1531
+ nn.Sequential(
1532
+ nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False),
1533
+ nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
1534
+ )
1535
+ )
1536
+ in_channels = config.d_model
1537
+ self.decoder_input_proj = nn.ModuleList(decoder_input_proj_list)
1538
+
1539
+ # decoder
1540
+ self.decoder = RTDetrDecoder(config)
1541
+
1542
+ self.post_init()
1543
+
1544
+ def get_encoder(self):
1545
+ return self.encoder
1546
+
1547
+ def freeze_backbone(self):
1548
+ for param in self.backbone.parameters():
1549
+ param.requires_grad_(False)
1550
+
1551
+ def unfreeze_backbone(self):
1552
+ for param in self.backbone.parameters():
1553
+ param.requires_grad_(True)
1554
+
1555
+ @compile_compatible_method_lru_cache(maxsize=32)
1556
+ def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
1557
+ if spatial_shapes is None:
1558
+ spatial_shapes = [
1559
+ [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
1560
+ for s in self.config.feat_strides
1561
+ ]
1562
+ anchors = []
1563
+ for level, (height, width) in enumerate(spatial_shapes):
1564
+ grid_y, grid_x = torch.meshgrid(
1565
+ torch.arange(end=height, device=device).to(dtype),
1566
+ torch.arange(end=width, device=device).to(dtype),
1567
+ indexing="ij",
1568
+ )
1569
+ grid_xy = torch.stack([grid_x, grid_y], -1)
1570
+ grid_xy = grid_xy.unsqueeze(0) + 0.5
1571
+ grid_xy[..., 0] /= width
1572
+ grid_xy[..., 1] /= height
1573
+ wh = torch.ones_like(grid_xy) * grid_size * (2.0**level)
1574
+ anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
1575
+ # define the valid range for anchor coordinates
1576
+ eps = 1e-2
1577
+ anchors = torch.concat(anchors, 1)
1578
+ valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
1579
+ anchors = torch.log(anchors / (1 - anchors))
1580
+ anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device))
1581
+
1582
+ return anchors, valid_mask
1583
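A minimal sketch of the anchor generation above for a single 4x4 feature level, including the logit-space transform and the validity mask (illustrative values only):

```python
import torch

height, width, grid_size, level = 4, 4, 0.05, 0
grid_y, grid_x = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
grid_xy = torch.stack([grid_x, grid_y], -1).float().unsqueeze(0) + 0.5
grid_xy[..., 0] /= width   # normalized cell-center x
grid_xy[..., 1] /= height  # normalized cell-center y
wh = torch.ones_like(grid_xy) * grid_size * (2.0**level)
anchors = torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)

# Anchors are converted to logit space so that adding the box-head output and applying
# a sigmoid yields valid normalized boxes; near-border anchors are masked out.
eps = 1e-2
valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)
anchor_logits = torch.log(anchors / (1 - anchors))
print(anchors.shape, valid_mask.shape)  # torch.Size([1, 16, 4]) torch.Size([1, 16, 1])
```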
+
1584
+ @auto_docstring
1585
+ def forward(
1586
+ self,
1587
+ pixel_values: torch.FloatTensor,
1588
+ pixel_mask: Optional[torch.LongTensor] = None,
1589
+ encoder_outputs: Optional[torch.FloatTensor] = None,
1590
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1591
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1592
+ labels: Optional[list[dict]] = None,
1593
+ output_attentions: Optional[bool] = None,
1594
+ output_hidden_states: Optional[bool] = None,
1595
+ return_dict: Optional[bool] = None,
1596
+ ) -> Union[tuple[torch.FloatTensor], RTDetrModelOutput]:
1597
+ r"""
1598
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1599
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1600
+ can choose to directly pass a flattened representation of an image.
1601
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1602
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1603
+ embedded representation.
1604
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
1605
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
1606
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
1607
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
1608
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
1609
+
1610
+ Examples:
1611
+
1612
+ ```python
1613
+ >>> from transformers import AutoImageProcessor, RTDetrModel
1614
+ >>> from PIL import Image
1615
+ >>> import requests
1616
+
1617
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1618
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1619
+
1620
+ >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
1621
+ >>> model = RTDetrModel.from_pretrained("PekingU/rtdetr_r50vd")
1622
+
1623
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1624
+
1625
+ >>> outputs = model(**inputs)
1626
+
1627
+ >>> last_hidden_states = outputs.last_hidden_state
1628
+ >>> list(last_hidden_states.shape)
1629
+ [1, 300, 256]
1630
+ ```"""
1631
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1632
+ output_hidden_states = (
1633
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1634
+ )
1635
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1636
+
1637
+ batch_size, num_channels, height, width = pixel_values.shape
1638
+ device = pixel_values.device
1639
+
1640
+ if pixel_mask is None:
1641
+ pixel_mask = torch.ones(((batch_size, height, width)), device=device)
1642
+
1643
+ features = self.backbone(pixel_values, pixel_mask)
1644
+
1645
+ proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
1646
+
1647
+ if encoder_outputs is None:
1648
+ encoder_outputs = self.encoder(
1649
+ proj_feats,
1650
+ output_attentions=output_attentions,
1651
+ output_hidden_states=output_hidden_states,
1652
+ return_dict=return_dict,
1653
+ )
1654
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1655
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1656
+ encoder_outputs = BaseModelOutput(
1657
+ last_hidden_state=encoder_outputs[0],
1658
+ hidden_states=encoder_outputs[1] if output_hidden_states else None,
1659
+ attentions=encoder_outputs[2]
1660
+ if len(encoder_outputs) > 2
1661
+ else encoder_outputs[1]
1662
+ if output_attentions
1663
+ else None,
1664
+ )
1665
+
1666
+ # Equivalent to def _get_encoder_input
1667
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
1668
+ sources = []
1669
+ for level, source in enumerate(encoder_outputs[0]):
1670
+ sources.append(self.decoder_input_proj[level](source))
1671
+
1672
+ # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
1673
+ if self.config.num_feature_levels > len(sources):
1674
+ _len_sources = len(sources)
1675
+ sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1])
1676
+ for i in range(_len_sources + 1, self.config.num_feature_levels):
1677
+ sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1]))
1678
+
1679
+ # Prepare encoder inputs (by flattening)
1680
+ source_flatten = []
1681
+ spatial_shapes_list = []
1682
+ spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long)
1683
+ for level, source in enumerate(sources):
1684
+ height, width = source.shape[-2:]
1685
+ spatial_shapes[level, 0] = height
1686
+ spatial_shapes[level, 1] = width
1687
+ spatial_shapes_list.append((height, width))
1688
+ source = source.flatten(2).transpose(1, 2)
1689
+ source_flatten.append(source)
1690
+ source_flatten = torch.cat(source_flatten, 1)
1691
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
1692
+
1693
+ # prepare denoising training
1694
+ if self.training and self.config.num_denoising > 0 and labels is not None:
1695
+ (
1696
+ denoising_class,
1697
+ denoising_bbox_unact,
1698
+ attention_mask,
1699
+ denoising_meta_values,
1700
+ ) = get_contrastive_denoising_training_group(
1701
+ targets=labels,
1702
+ num_classes=self.config.num_labels,
1703
+ num_queries=self.config.num_queries,
1704
+ class_embed=self.denoising_class_embed,
1705
+ num_denoising_queries=self.config.num_denoising,
1706
+ label_noise_ratio=self.config.label_noise_ratio,
1707
+ box_noise_scale=self.config.box_noise_scale,
1708
+ )
1709
+ else:
1710
+ denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None
1711
+
1712
+ batch_size = len(source_flatten)
1713
+ device = source_flatten.device
1714
+ dtype = source_flatten.dtype
1715
+
1716
+ # prepare input for decoder
1717
+ if self.training or self.config.anchor_image_size is None:
1718
+ # Pass spatial_shapes as tuple to make it hashable and make sure
1719
+ # lru_cache is working for generate_anchors()
1720
+ spatial_shapes_tuple = tuple(spatial_shapes_list)
1721
+ anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
1722
+ else:
1723
+ anchors, valid_mask = self.anchors, self.valid_mask
1724
+ anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)
1725
+
1726
+ # use the valid_mask to selectively retain values in the feature map where the mask is `True`
1727
+ memory = valid_mask.to(source_flatten.dtype) * source_flatten
1728
+
1729
+ output_memory = self.enc_output(memory)
1730
+
1731
+ enc_outputs_class = self.enc_score_head(output_memory)
1732
+ enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors
1733
+
1734
+ _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1)
1735
+
1736
+ reference_points_unact = enc_outputs_coord_logits.gather(
1737
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1])
1738
+ )
1739
+
1740
+ enc_topk_bboxes = F.sigmoid(reference_points_unact)
1741
+ if denoising_bbox_unact is not None:
1742
+ reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
1743
+
1744
+ enc_topk_logits = enc_outputs_class.gather(
1745
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
1746
+ )
1747
+
1748
+ # extract region features
1749
+ if self.config.learn_initial_query:
1750
+ target = self.weight_embedding.tile([batch_size, 1, 1])
1751
+ else:
1752
+ target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
1753
+ target = target.detach()
1754
+
1755
+ if denoising_class is not None:
1756
+ target = torch.concat([denoising_class, target], 1)
1757
+
1758
+ init_reference_points = reference_points_unact.detach()
1759
+
1760
+ # decoder
1761
+ decoder_outputs = self.decoder(
1762
+ inputs_embeds=target,
1763
+ encoder_hidden_states=source_flatten,
1764
+ encoder_attention_mask=attention_mask,
1765
+ reference_points=init_reference_points,
1766
+ spatial_shapes=spatial_shapes,
1767
+ spatial_shapes_list=spatial_shapes_list,
1768
+ level_start_index=level_start_index,
1769
+ output_attentions=output_attentions,
1770
+ output_hidden_states=output_hidden_states,
1771
+ return_dict=return_dict,
1772
+ )
1773
+
1774
+ if not return_dict:
1775
+ enc_outputs = tuple(
1776
+ value
1777
+ for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits]
1778
+ if value is not None
1779
+ )
1780
+ dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values])
1781
+ tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs
1782
+
1783
+ return tuple_outputs
1784
+
1785
+ return RTDetrModelOutput(
1786
+ last_hidden_state=decoder_outputs.last_hidden_state,
1787
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
1788
+ intermediate_logits=decoder_outputs.intermediate_logits,
1789
+ intermediate_reference_points=decoder_outputs.intermediate_reference_points,
1790
+ intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners,
1791
+ initial_reference_points=decoder_outputs.initial_reference_points,
1792
+ decoder_hidden_states=decoder_outputs.hidden_states,
1793
+ decoder_attentions=decoder_outputs.attentions,
1794
+ cross_attentions=decoder_outputs.cross_attentions,
1795
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1796
+ encoder_hidden_states=encoder_outputs.hidden_states,
1797
+ encoder_attentions=encoder_outputs.attentions,
1798
+ init_reference_points=init_reference_points,
1799
+ enc_topk_logits=enc_topk_logits,
1800
+ enc_topk_bboxes=enc_topk_bboxes,
1801
+ enc_outputs_class=enc_outputs_class,
1802
+ enc_outputs_coord_logits=enc_outputs_coord_logits,
1803
+ denoising_meta_values=denoising_meta_values,
1804
+ )
1805
+
1806
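A small sketch of the encoder-side query selection performed in the forward pass above: the positions with the highest best-class score are picked, and their box logits are gathered as initial reference points (random tensors for illustration):

```python
import torch

batch_size, seq_len, num_labels, num_queries = 2, 10, 3, 4
enc_outputs_class = torch.randn(batch_size, seq_len, num_labels)
enc_outputs_coord_logits = torch.randn(batch_size, seq_len, 4)

# Select the num_queries positions whose best class score is highest, then gather
# the corresponding box logits to initialize the decoder reference points.
_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, num_queries, dim=1)
reference_points_unact = enc_outputs_coord_logits.gather(
    dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1])
)
print(topk_ind.shape, reference_points_unact.shape)  # (2, 4) and (2, 4, 4)
```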
+
1807
+ @auto_docstring(
1808
+ custom_intro="""
1809
+ RT-DETR Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further
1810
+ decoded into scores and classes.
1811
+ """
1812
+ )
1813
+ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
1814
+ # When using clones, all layers > 0 will be clones, but layer 0 *is* required
1815
+ _tied_weights_keys = ["bbox_embed", "class_embed"]
1816
+ # We can't initialize the model on meta device as some weights are modified during the initialization
1817
+ _no_split_modules = None
1818
+
1819
+ def __init__(self, config: RTDetrConfig):
1820
+ super().__init__(config)
1821
+
1822
+ # RTDETR encoder-decoder model
1823
+ self.model = RTDetrModel(config)
1824
+
1825
+ # Detection heads on top
1826
+ self.class_embed = partial(nn.Linear, config.d_model, config.num_labels)
1827
+ self.bbox_embed = partial(RTDetrMLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3)
1828
+
1829
+ # if two-stage, the last class_embed and bbox_embed is for region proposal generation
1830
+ num_pred = config.decoder_layers
1831
+ if config.with_box_refine:
1832
+ self.class_embed = _get_clones(self.class_embed, num_pred)
1833
+ self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
1834
+ else:
1835
+ self.class_embed = nn.ModuleList([self.class_embed() for _ in range(num_pred)])
1836
+ self.bbox_embed = nn.ModuleList([self.bbox_embed() for _ in range(num_pred)])
1837
+
1838
+ # hack implementation for iterative bounding box refinement
1839
+ self.model.decoder.class_embed = self.class_embed
1840
+ self.model.decoder.bbox_embed = self.bbox_embed
1841
+
1842
+ # Initialize weights and apply final processing
1843
+ self.post_init()
1844
+
1845
+ @torch.jit.unused
1846
+ def _set_aux_loss(self, outputs_class, outputs_coord):
1847
+ # this is a workaround to make torchscript happy, as torchscript
1848
+ # doesn't support dictionary with non-homogeneous values, such
1849
+ # as a dict having both a Tensor and a list.
1850
+ return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
1851
+
1852
+ @auto_docstring
1853
+ def forward(
1854
+ self,
1855
+ pixel_values: torch.FloatTensor,
1856
+ pixel_mask: Optional[torch.LongTensor] = None,
1857
+ encoder_outputs: Optional[torch.FloatTensor] = None,
1858
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1859
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1860
+ labels: Optional[list[dict]] = None,
1861
+ output_attentions: Optional[bool] = None,
1862
+ output_hidden_states: Optional[bool] = None,
1863
+ return_dict: Optional[bool] = None,
1864
+ **kwargs,
1865
+ ) -> Union[tuple[torch.FloatTensor], RTDetrObjectDetectionOutput]:
1866
+ r"""
1867
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1868
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1869
+ can choose to directly pass a flattened representation of an image.
1870
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1871
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1872
+ embedded representation.
1873
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
1874
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
1875
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
1876
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
1877
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
1878
+
1879
+ Examples:
1880
+
1881
+ ```python
1882
+ >>> from transformers import RTDetrImageProcessor, RTDetrForObjectDetection
1883
+ >>> from PIL import Image
1884
+ >>> import requests
1885
+ >>> import torch
1886
+
1887
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1888
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1889
+
1890
+ >>> image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
1891
+ >>> model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
1892
+
1893
+ >>> # prepare image for the model
1894
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1895
+
1896
+ >>> # forward pass
1897
+ >>> outputs = model(**inputs)
1898
+
1899
+ >>> logits = outputs.logits
1900
+ >>> list(logits.shape)
1901
+ [1, 300, 80]
1902
+
1903
+ >>> boxes = outputs.pred_boxes
1904
+ >>> list(boxes.shape)
1905
+ [1, 300, 4]
1906
+
1907
+ >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
1908
+ >>> target_sizes = torch.tensor([image.size[::-1]])
1909
+ >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
1910
+ ... 0
1911
+ ... ]
1912
+
1913
+ >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
1914
+ ... box = [round(i, 2) for i in box.tolist()]
1915
+ ... print(
1916
+ ... f"Detected {model.config.id2label[label.item()]} with confidence "
1917
+ ... f"{round(score.item(), 3)} at location {box}"
1918
+ ... )
1919
+ Detected sofa with confidence 0.97 at location [0.14, 0.38, 640.13, 476.21]
1920
+ Detected cat with confidence 0.96 at location [343.38, 24.28, 640.14, 371.5]
1921
+ Detected cat with confidence 0.958 at location [13.23, 54.18, 318.98, 472.22]
1922
+ Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48]
1923
+ Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99]
1924
+ ```"""
1925
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1926
+ output_hidden_states = (
1927
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1928
+ )
1929
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1930
+
1931
+ outputs = self.model(
1932
+ pixel_values,
1933
+ pixel_mask=pixel_mask,
1934
+ encoder_outputs=encoder_outputs,
1935
+ inputs_embeds=inputs_embeds,
1936
+ decoder_inputs_embeds=decoder_inputs_embeds,
1937
+ labels=labels,
1938
+ output_attentions=output_attentions,
1939
+ output_hidden_states=output_hidden_states,
1940
+ return_dict=return_dict,
1941
+ )
1942
+
1943
+ denoising_meta_values = (
1944
+ outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None
1945
+ )
1946
+
1947
+ outputs_class = outputs.intermediate_logits if return_dict else outputs[2]
1948
+ outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3]
1949
+ predicted_corners = outputs.intermediate_predicted_corners if return_dict else outputs[4]
1950
+ initial_reference_points = outputs.initial_reference_points if return_dict else outputs[5]
1951
+
1952
+ logits = outputs_class[:, -1]
1953
+ pred_boxes = outputs_coord[:, -1]
1954
+
1955
+ loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
1956
+ if labels is not None:
1957
+ enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
1958
+ enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
1959
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
1960
+ logits,
1961
+ labels,
1962
+ self.device,
1963
+ pred_boxes,
1964
+ self.config,
1965
+ outputs_class,
1966
+ outputs_coord,
1967
+ enc_topk_logits=enc_topk_logits,
1968
+ enc_topk_bboxes=enc_topk_bboxes,
1969
+ denoising_meta_values=denoising_meta_values,
1970
+ predicted_corners=predicted_corners,
1971
+ initial_reference_points=initial_reference_points,
1972
+ **kwargs,
1973
+ )
1974
+
1975
+ if not return_dict:
1976
+ if auxiliary_outputs is not None:
1977
+ output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs
1978
+ else:
1979
+ output = (logits, pred_boxes) + outputs
1980
+ return ((loss, loss_dict) + output) if loss is not None else output
1981
+
1982
+ return RTDetrObjectDetectionOutput(
1983
+ loss=loss,
1984
+ loss_dict=loss_dict,
1985
+ logits=logits,
1986
+ pred_boxes=pred_boxes,
1987
+ auxiliary_outputs=auxiliary_outputs,
1988
+ last_hidden_state=outputs.last_hidden_state,
1989
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
1990
+ intermediate_logits=outputs.intermediate_logits,
1991
+ intermediate_reference_points=outputs.intermediate_reference_points,
1992
+ intermediate_predicted_corners=outputs.intermediate_predicted_corners,
1993
+ initial_reference_points=outputs.initial_reference_points,
1994
+ decoder_hidden_states=outputs.decoder_hidden_states,
1995
+ decoder_attentions=outputs.decoder_attentions,
1996
+ cross_attentions=outputs.cross_attentions,
1997
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1998
+ encoder_hidden_states=outputs.encoder_hidden_states,
1999
+ encoder_attentions=outputs.encoder_attentions,
2000
+ init_reference_points=outputs.init_reference_points,
2001
+ enc_topk_logits=outputs.enc_topk_logits,
2002
+ enc_topk_bboxes=outputs.enc_topk_bboxes,
2003
+ enc_outputs_class=outputs.enc_outputs_class,
2004
+ enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
2005
+ denoising_meta_values=outputs.denoising_meta_values,
2006
+ )
2007
+
2008
+
2009
+ __all__ = [
2010
+ "RTDetrForObjectDetection",
2011
+ "RTDetrModel",
2012
+ "RTDetrPreTrainedModel",
2013
+ ]
phivenv/Lib/site-packages/transformers/models/rt_detr/modeling_rt_detr_resnet.py ADDED
@@ -0,0 +1,399 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ PyTorch RT-DETR specific ResNet model. The main difference from the Hugging Face ResNet model is that this RTDetrResNet model forces the use of a shortcut at the first layer of the ResNet-18/34 models.
17
+ See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L126 for details.
18
+ """
19
+
20
+ import math
21
+ from typing import Optional
22
+
23
+ from torch import Tensor, nn
24
+
25
+ from ...activations import ACT2FN
26
+ from ...modeling_outputs import BackboneOutput, BaseModelOutputWithNoAttention
27
+ from ...modeling_utils import PreTrainedModel
28
+ from ...utils import auto_docstring, logging
29
+ from ...utils.backbone_utils import BackboneMixin
30
+ from .configuration_rt_detr_resnet import RTDetrResNetConfig
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ # Copied from transformers.models.resnet.modeling_resnet.ResNetConvLayer -> RTDetrResNetConvLayer
37
+ class RTDetrResNetConvLayer(nn.Module):
38
+ def __init__(
39
+ self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
40
+ ):
41
+ super().__init__()
42
+ self.convolution = nn.Conv2d(
43
+ in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
44
+ )
45
+ self.normalization = nn.BatchNorm2d(out_channels)
46
+ self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
47
+
48
+ def forward(self, input: Tensor) -> Tensor:
49
+ hidden_state = self.convolution(input)
50
+ hidden_state = self.normalization(hidden_state)
51
+ hidden_state = self.activation(hidden_state)
52
+ return hidden_state
53
+
54
+
55
+ class RTDetrResNetEmbeddings(nn.Module):
56
+ """
57
+ ResNet Embeddings (stem) composed of a deep aggressive convolution.
58
+ """
59
+
60
+ def __init__(self, config: RTDetrResNetConfig):
61
+ super().__init__()
62
+ self.embedder = nn.Sequential(
63
+ *[
64
+ RTDetrResNetConvLayer(
65
+ config.num_channels,
66
+ config.embedding_size // 2,
67
+ kernel_size=3,
68
+ stride=2,
69
+ activation=config.hidden_act,
70
+ ),
71
+ RTDetrResNetConvLayer(
72
+ config.embedding_size // 2,
73
+ config.embedding_size // 2,
74
+ kernel_size=3,
75
+ stride=1,
76
+ activation=config.hidden_act,
77
+ ),
78
+ RTDetrResNetConvLayer(
79
+ config.embedding_size // 2,
80
+ config.embedding_size,
81
+ kernel_size=3,
82
+ stride=1,
83
+ activation=config.hidden_act,
84
+ ),
85
+ ]
86
+ )
87
+ self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
88
+ self.num_channels = config.num_channels
89
+
90
+ def forward(self, pixel_values: Tensor) -> Tensor:
91
+ num_channels = pixel_values.shape[1]
92
+ if num_channels != self.num_channels:
93
+ raise ValueError(
94
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
95
+ )
96
+ embedding = self.embedder(pixel_values)
97
+ embedding = self.pooler(embedding)
98
+ return embedding
99
+
100
+
101
+ # Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut -> RTDetrResNetShortCut
102
+ class RTDetrResNetShortCut(nn.Module):
103
+ """
104
+ ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
105
+ downsample the input using `stride=2`.
106
+ """
107
+
108
+ def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
109
+ super().__init__()
110
+ self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
111
+ self.normalization = nn.BatchNorm2d(out_channels)
112
+
113
+ def forward(self, input: Tensor) -> Tensor:
114
+ hidden_state = self.convolution(input)
115
+ hidden_state = self.normalization(hidden_state)
116
+ return hidden_state
117
+
118
+
119
+ class RTDetrResNetBasicLayer(nn.Module):
120
+ """
121
+ A classic ResNet residual layer composed of two `3x3` convolutions.
122
+ See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L34.
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ config: RTDetrResNetConfig,
128
+ in_channels: int,
129
+ out_channels: int,
130
+ stride: int = 1,
131
+ should_apply_shortcut: bool = False,
132
+ ):
133
+ super().__init__()
134
+ if in_channels != out_channels:
135
+ self.shortcut = (
136
+ nn.Sequential(
137
+ *[nn.AvgPool2d(2, 2, 0, ceil_mode=True), RTDetrResNetShortCut(in_channels, out_channels, stride=1)]
138
+ )
139
+ if should_apply_shortcut
140
+ else nn.Identity()
141
+ )
142
+ else:
143
+ self.shortcut = (
144
+ RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
145
+ if should_apply_shortcut
146
+ else nn.Identity()
147
+ )
148
+ self.layer = nn.Sequential(
149
+ RTDetrResNetConvLayer(in_channels, out_channels, stride=stride),
150
+ RTDetrResNetConvLayer(out_channels, out_channels, activation=None),
151
+ )
152
+ self.activation = ACT2FN[config.hidden_act]
153
+
154
+ def forward(self, hidden_state):
155
+ residual = hidden_state
156
+ hidden_state = self.layer(hidden_state)
157
+ residual = self.shortcut(residual)
158
+ hidden_state += residual
159
+ hidden_state = self.activation(hidden_state)
160
+ return hidden_state
161
+
162
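A compact residual block in the spirit of the basic layer above, with an AvgPool + `1x1` convolution shortcut when the spatial size or channel count changes (a sketch under those assumptions, not the library implementation):

```python
import torch
from torch import nn

class TinyResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main branch: two 3x3 convolutions, the first one optionally downsampling.
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # Shortcut branch: AvgPool for downsampling, then a 1x1 projection if needed.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.AvgPool2d(2, 2, 0, ceil_mode=True) if stride != 1 else nn.Identity(),
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        return torch.relu(self.body(x) + self.shortcut(x))

print(TinyResidualBlock(64, 128, stride=2)(torch.randn(1, 64, 32, 32)).shape)  # (1, 128, 16, 16)
```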
+
163
+ class RTDetrResNetBottleNeckLayer(nn.Module):
164
+ """
165
+ A classic RTDetrResNet bottleneck layer composed of three convolutions (`1x1`, `3x3`, `1x1`).
166
+
167
+ The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
168
+ convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If
169
+ `downsample_in_bottleneck` is true, downsample will be in the first layer instead of the second layer.
170
+ """
171
+
172
+ def __init__(
173
+ self,
174
+ config: RTDetrResNetConfig,
175
+ in_channels: int,
176
+ out_channels: int,
177
+ stride: int = 1,
178
+ ):
179
+ super().__init__()
180
+ reduction = 4
181
+ should_apply_shortcut = in_channels != out_channels or stride != 1
182
+ reduces_channels = out_channels // reduction
183
+ if stride == 2:
184
+ self.shortcut = nn.Sequential(
185
+ *[
186
+ nn.AvgPool2d(2, 2, 0, ceil_mode=True),
187
+ RTDetrResNetShortCut(in_channels, out_channels, stride=1)
188
+ if should_apply_shortcut
189
+ else nn.Identity(),
190
+ ]
191
+ )
192
+ else:
193
+ self.shortcut = (
194
+ RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
195
+ if should_apply_shortcut
196
+ else nn.Identity()
197
+ )
198
+ self.layer = nn.Sequential(
199
+ RTDetrResNetConvLayer(
200
+ in_channels, reduces_channels, kernel_size=1, stride=stride if config.downsample_in_bottleneck else 1
201
+ ),
202
+ RTDetrResNetConvLayer(
203
+ reduces_channels, reduces_channels, stride=stride if not config.downsample_in_bottleneck else 1
204
+ ),
205
+ RTDetrResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None),
206
+ )
207
+ self.activation = ACT2FN[config.hidden_act]
208
+
209
+ def forward(self, hidden_state):
210
+ residual = hidden_state
211
+ hidden_state = self.layer(hidden_state)
212
+ residual = self.shortcut(residual)
213
+ hidden_state += residual
214
+ hidden_state = self.activation(hidden_state)
215
+ return hidden_state
216
+
217
+
218
+ class RTDetrResNetStage(nn.Module):
219
+ """
220
+ A RTDetrResNet stage composed by stacked layers.
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ config: RTDetrResNetConfig,
226
+ in_channels: int,
227
+ out_channels: int,
228
+ stride: int = 2,
229
+ depth: int = 2,
230
+ ):
231
+ super().__init__()
232
+
233
+ layer = RTDetrResNetBottleNeckLayer if config.layer_type == "bottleneck" else RTDetrResNetBasicLayer
234
+
235
+ if config.layer_type == "bottleneck":
236
+ first_layer = layer(
237
+ config,
238
+ in_channels,
239
+ out_channels,
240
+ stride=stride,
241
+ )
242
+ else:
243
+ first_layer = layer(config, in_channels, out_channels, stride=stride, should_apply_shortcut=True)
244
+ self.layers = nn.Sequential(
245
+ first_layer, *[layer(config, out_channels, out_channels) for _ in range(depth - 1)]
246
+ )
247
+
248
+ def forward(self, input: Tensor) -> Tensor:
249
+ hidden_state = input
250
+ for layer in self.layers:
251
+ hidden_state = layer(hidden_state)
252
+ return hidden_state
253
+
254
+
255
+ # Copied from transformers.models.resnet.modeling_resnet.ResNetEncoder with ResNet->RTDetrResNet
256
+ class RTDetrResNetEncoder(nn.Module):
257
+ def __init__(self, config: RTDetrResNetConfig):
258
+ super().__init__()
259
+ self.stages = nn.ModuleList([])
260
+ # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input
261
+ self.stages.append(
262
+ RTDetrResNetStage(
263
+ config,
264
+ config.embedding_size,
265
+ config.hidden_sizes[0],
266
+ stride=2 if config.downsample_in_first_stage else 1,
267
+ depth=config.depths[0],
268
+ )
269
+ )
270
+ in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
271
+ for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
272
+ self.stages.append(RTDetrResNetStage(config, in_channels, out_channels, depth=depth))
273
+
274
+ def forward(
275
+ self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
276
+ ) -> BaseModelOutputWithNoAttention:
277
+ hidden_states = () if output_hidden_states else None
278
+
279
+ for stage_module in self.stages:
280
+ if output_hidden_states:
281
+ hidden_states = hidden_states + (hidden_state,)
282
+
283
+ hidden_state = stage_module(hidden_state)
284
+
285
+ if output_hidden_states:
286
+ hidden_states = hidden_states + (hidden_state,)
287
+
288
+ if not return_dict:
289
+ return tuple(v for v in [hidden_state, hidden_states] if v is not None)
290
+
291
+ return BaseModelOutputWithNoAttention(
292
+ last_hidden_state=hidden_state,
293
+ hidden_states=hidden_states,
294
+ )
295
+
296
+
297
+ @auto_docstring
298
+ # Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RTDetrResNet
299
+ class RTDetrResNetPreTrainedModel(PreTrainedModel):
300
+ config: RTDetrResNetConfig
301
+ base_model_prefix = "resnet"
302
+ main_input_name = "pixel_values"
303
+ _no_split_modules = ["RTDetrResNetConvLayer", "RTDetrResNetShortCut"]
304
+
305
+ def _init_weights(self, module):
306
+ if isinstance(module, nn.Conv2d):
307
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
308
+ # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
309
+ elif isinstance(module, nn.Linear):
310
+ nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
311
+ if module.bias is not None:
312
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
313
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
314
+ nn.init.uniform_(module.bias, -bound, bound)
315
+ elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
316
+ nn.init.constant_(module.weight, 1)
317
+ nn.init.constant_(module.bias, 0)
318
+
319
+
320
+ @auto_docstring(
321
+ custom_intro="""
322
+ ResNet backbone, to be used with frameworks like RTDETR.
323
+ """
324
+ )
325
+ class RTDetrResNetBackbone(RTDetrResNetPreTrainedModel, BackboneMixin):
326
+ has_attentions = False
327
+
328
+ def __init__(self, config):
329
+ super().__init__(config)
330
+ super()._init_backbone(config)
331
+
332
+ self.num_features = [config.embedding_size] + config.hidden_sizes
333
+ self.embedder = RTDetrResNetEmbeddings(config)
334
+ self.encoder = RTDetrResNetEncoder(config)
335
+
336
+ # initialize weights and apply final processing
337
+ self.post_init()
338
+
339
+ @auto_docstring
340
+ def forward(
341
+ self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
342
+ ) -> BackboneOutput:
343
+ r"""
344
+ Examples:
345
+
346
+ ```python
347
+ >>> from transformers import RTDetrResNetConfig, RTDetrResNetBackbone
348
+ >>> import torch
349
+
355
+ >>> config = RTDetrResNetConfig()
356
+ >>> model = RTDetrResNetBackbone(config)
357
+
358
+ >>> pixel_values = torch.randn(1, 3, 224, 224)
359
+
360
+ >>> with torch.no_grad():
361
+ ... outputs = model(pixel_values)
362
+
363
+ >>> feature_maps = outputs.feature_maps
364
+ >>> list(feature_maps[-1].shape)
365
+ [1, 2048, 7, 7]
366
+ ```"""
367
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
368
+ output_hidden_states = (
369
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
370
+ )
371
+
372
+ embedding_output = self.embedder(pixel_values)
373
+
374
+ outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
375
+
376
+ hidden_states = outputs.hidden_states
377
+
378
+ feature_maps = ()
379
+ for idx, stage in enumerate(self.stage_names):
380
+ if stage in self.out_features:
381
+ feature_maps += (hidden_states[idx],)
382
+
383
+ if not return_dict:
384
+ output = (feature_maps,)
385
+ if output_hidden_states:
386
+ output += (outputs.hidden_states,)
387
+ return output
388
+
389
+ return BackboneOutput(
390
+ feature_maps=feature_maps,
391
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
392
+ attentions=None,
393
+ )
394
+
395
+
396
+ __all__ = [
397
+ "RTDetrResNetBackbone",
398
+ "RTDetrResNetPreTrainedModel",
399
+ ]
phivenv/Lib/site-packages/transformers/models/rt_detr/modular_rt_detr.py ADDED
@@ -0,0 +1,365 @@
1
+ import pathlib
2
+ from typing import Optional, Union
3
+
4
+ from transformers.models.detr.image_processing_detr_fast import DetrFastImageProcessorKwargs, DetrImageProcessorFast
5
+
6
+ from ...image_processing_utils import BatchFeature
7
+ from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, get_max_height_width
8
+ from ...image_transforms import center_to_corners_format
9
+ from ...image_utils import (
10
+ IMAGENET_DEFAULT_MEAN,
11
+ IMAGENET_DEFAULT_STD,
12
+ AnnotationFormat,
13
+ AnnotationType,
14
+ ChannelDimension,
15
+ ImageInput,
16
+ PILImageResampling,
17
+ get_image_size,
18
+ validate_annotations,
19
+ )
20
+ from ...processing_utils import Unpack
21
+ from ...utils import (
22
+ TensorType,
23
+ is_torch_available,
24
+ is_torchvision_available,
25
+ is_torchvision_v2_available,
26
+ logging,
27
+ requires_backends,
28
+ )
29
+
30
+
31
+ if is_torch_available():
32
+ import torch
33
+
34
+
35
+ if is_torchvision_v2_available():
36
+ from torchvision.transforms.v2 import functional as F
37
+ elif is_torchvision_available():
38
+ from torchvision.transforms import functional as F
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,)
44
+
45
+
46
+ def prepare_coco_detection_annotation(
47
+ image,
48
+ target,
49
+ return_segmentation_masks: bool = False,
50
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
51
+ ):
52
+ """
53
+ Convert the target in COCO format into the format expected by RT-DETR.
54
+ """
55
+ image_height, image_width = image.size()[-2:]
56
+
57
+ image_id = target["image_id"]
58
+ image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)
59
+
60
+ # Get all COCO annotations for the given image.
61
+ annotations = target["annotations"]
62
+ classes = []
63
+ area = []
64
+ boxes = []
65
+ keypoints = []
66
+ for obj in annotations:
67
+ if "iscrowd" not in obj or obj["iscrowd"] == 0:
68
+ classes.append(obj["category_id"])
69
+ area.append(obj["area"])
70
+ boxes.append(obj["bbox"])
71
+ if "keypoints" in obj:
72
+ keypoints.append(obj["keypoints"])
73
+
74
+ classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
75
+ area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
76
+ iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
77
+ # guard against no boxes via resizing
78
+ boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
79
+ boxes[:, 2:] += boxes[:, :2]
80
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
81
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
82
+
83
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
84
+
85
+ new_target = {
86
+ "image_id": image_id,
87
+ "class_labels": classes[keep],
88
+ "boxes": boxes[keep],
89
+ "area": area[keep],
90
+ "iscrowd": iscrowd[keep],
91
+ "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
92
+ }
93
+
94
+ if keypoints:
95
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
96
+ # Apply the keep mask here to filter the relevant annotations
97
+ keypoints = keypoints[keep]
98
+ num_keypoints = keypoints.shape[0]
99
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
100
+ new_target["keypoints"] = keypoints
101
+
102
+ return new_target
103
+
104
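A minimal sketch of the box handling above: COCO `(x, y, width, height)` boxes are converted to corner format, clipped to the image, and degenerate boxes are dropped (illustrative values):

```python
import torch

image_height, image_width = 480, 640
boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0], [600.0, 400.0, 100.0, 100.0]])

boxes[:, 2:] += boxes[:, :2]                                   # xywh -> xyxy
boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)   # clip x coordinates
boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)  # clip y coordinates

# Keep only boxes that still have positive width and height after clipping.
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
print(boxes[keep])
```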
+
105
+ class RTDetrFastImageProcessorKwargs(DetrFastImageProcessorKwargs):
106
+ pass
107
+
108
+
109
+ class RTDetrImageProcessorFast(DetrImageProcessorFast):
110
+ resample = PILImageResampling.BILINEAR
111
+ image_mean = IMAGENET_DEFAULT_MEAN
112
+ image_std = IMAGENET_DEFAULT_STD
113
+ format = AnnotationFormat.COCO_DETECTION
114
+ do_convert_annotations = True
115
+ do_resize = True
116
+ do_rescale = True
117
+ do_normalize = False
118
+ do_pad = False
119
+ size = {"height": 640, "width": 640}
120
+ default_to_square = False
121
+ model_input_names = ["pixel_values", "pixel_mask"]
122
+ valid_kwargs = RTDetrFastImageProcessorKwargs
123
+
124
+ def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorKwargs]) -> None:
125
+ # Backwards compatibility
126
+ do_convert_annotations = kwargs.get("do_convert_annotations")
127
+ do_normalize = kwargs.get("do_normalize")
128
+ if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
129
+ self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
130
+
131
+ BaseImageProcessorFast.__init__(self, **kwargs)
132
+
133
+ def preprocess(
134
+ self,
135
+ images: ImageInput,
136
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
137
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
138
+ **kwargs: Unpack[RTDetrFastImageProcessorKwargs],
139
+ ) -> BatchFeature:
140
+ return BaseImageProcessorFast.preprocess(self, images, annotations, masks_path, **kwargs)
141
+
142
+ def prepare_annotation(
143
+ self,
144
+ image: torch.Tensor,
145
+ target: dict,
146
+ format: Optional[AnnotationFormat] = None,
147
+ return_segmentation_masks: Optional[bool] = None,
148
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
149
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
150
+ ) -> dict:
151
+ format = format if format is not None else self.format
152
+
153
+ if format == AnnotationFormat.COCO_DETECTION:
154
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
155
+ target = prepare_coco_detection_annotation(
156
+ image, target, return_segmentation_masks, input_data_format=input_data_format
157
+ )
158
+ else:
159
+ raise ValueError(f"Format {format} is not supported.")
160
+ return target
161
+
162
+ def _preprocess(
163
+ self,
164
+ images: list["torch.Tensor"],
165
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]],
166
+ masks_path: Optional[Union[str, pathlib.Path]],
167
+ return_segmentation_masks: bool,
168
+ do_resize: bool,
169
+ size: SizeDict,
170
+ interpolation: Optional["F.InterpolationMode"],
171
+ do_rescale: bool,
172
+ rescale_factor: float,
173
+ do_normalize: bool,
174
+ do_convert_annotations: bool,
175
+ image_mean: Optional[Union[float, list[float]]],
176
+ image_std: Optional[Union[float, list[float]]],
177
+ do_pad: bool,
178
+ pad_size: Optional[dict[str, int]],
179
+ format: Optional[Union[str, AnnotationFormat]],
180
+ return_tensors: Optional[Union[str, TensorType]],
181
+ **kwargs,
182
+ ) -> BatchFeature:
183
+ """
184
+ Preprocess an image or a batch of images so that it can be used by the model.
185
+ """
186
+
187
+ if annotations is not None and isinstance(annotations, dict):
188
+ annotations = [annotations]
189
+
190
+ if annotations is not None and len(images) != len(annotations):
191
+ raise ValueError(
192
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
193
+ )
194
+
195
+ format = AnnotationFormat(format)
196
+ if annotations is not None:
197
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
198
+
199
+ data = {}
200
+ processed_images = []
201
+ processed_annotations = []
202
+ pixel_masks = [] # Initialize pixel_masks here
203
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
204
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
205
+ if annotations is not None:
206
+ annotation = self.prepare_annotation(
207
+ image,
208
+ annotation,
209
+ format,
210
+ return_segmentation_masks=return_segmentation_masks,
211
+ masks_path=masks_path,
212
+ input_data_format=ChannelDimension.FIRST,
213
+ )
214
+
215
+ if do_resize:
216
+ resized_image = self.resize(image, size=size, interpolation=interpolation)
217
+ if annotations is not None:
218
+ annotation = self.resize_annotation(
219
+ annotation,
220
+ orig_size=image.size()[-2:],
221
+ target_size=resized_image.size()[-2:],
222
+ )
223
+ image = resized_image
224
+ # Fused rescale and normalize
225
+ image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
226
+ if do_convert_annotations and annotations is not None:
227
+ annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
228
+
229
+ processed_images.append(image)
230
+ processed_annotations.append(annotation)
231
+ images = processed_images
232
+ annotations = processed_annotations if annotations is not None else None
233
+
234
+ if do_pad:
235
+ # depends on all resized image shapes so we need another loop
236
+ if pad_size is not None:
237
+ padded_size = (pad_size["height"], pad_size["width"])
238
+ else:
239
+ padded_size = get_max_height_width(images)
240
+
241
+ padded_images = []
242
+ padded_annotations = []
243
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
244
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
245
+ if padded_size == image.size()[-2:]:
246
+ padded_images.append(image)
247
+ pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
248
+ padded_annotations.append(annotation)
249
+ continue
250
+ image, pixel_mask, annotation = self.pad(
251
+ image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
252
+ )
253
+ padded_images.append(image)
254
+ padded_annotations.append(annotation)
255
+ pixel_masks.append(pixel_mask)
256
+ images = padded_images
257
+ annotations = padded_annotations if annotations is not None else None
258
+ data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})
259
+
260
+ data.update({"pixel_values": torch.stack(images, dim=0)})
261
+ encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
262
+ if annotations is not None:
263
+ encoded_inputs["labels"] = [
264
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
265
+ ]
266
+ return encoded_inputs
267
+
268
+ def post_process_object_detection(
269
+ self,
270
+ outputs,
271
+ threshold: float = 0.5,
272
+ target_sizes: Union[TensorType, list[tuple]] = None,
273
+ use_focal_loss: bool = True,
274
+ ):
275
+ """
276
+ Converts the raw output of [`RTDetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
277
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
278
+
279
+ Args:
280
+ outputs ([`RTDetrObjectDetectionOutput`]):
281
+ Raw outputs of the model.
282
+ threshold (`float`, *optional*, defaults to 0.5):
283
+ Score threshold to keep object detection predictions.
284
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
285
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
286
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
287
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
288
+ Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
289
+ to compute the scores of each detection, otherwise, a softmax function is used.
290
+
291
+ Returns:
292
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
293
+ in the batch as predicted by the model.
294
+ """
295
+ requires_backends(self, ["torch"])
296
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
297
+ # convert from relative cxcywh to absolute xyxy
298
+ boxes = center_to_corners_format(out_bbox)
299
+ if target_sizes is not None:
300
+ if len(out_logits) != len(target_sizes):
301
+ raise ValueError(
302
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
303
+ )
304
+ if isinstance(target_sizes, list):
305
+ img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
306
+ else:
307
+ img_h, img_w = target_sizes.unbind(1)
308
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
309
+ boxes = boxes * scale_fct[:, None, :]
310
+
311
+ num_top_queries = out_logits.shape[1]
312
+ num_classes = out_logits.shape[2]
313
+
314
+ if use_focal_loss:
315
+ scores = torch.nn.functional.sigmoid(out_logits)
316
+ scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
317
+ labels = index % num_classes
318
+ index = index // num_classes
319
+ boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
320
+ else:
321
+ scores = torch.nn.functional.softmax(out_logits)[:, :, :-1]
322
+ scores, labels = scores.max(dim=-1)
323
+ if scores.shape[1] > num_top_queries:
324
+ scores, index = torch.topk(scores, num_top_queries, dim=-1)
325
+ labels = torch.gather(labels, dim=1, index=index)
326
+ boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
327
+
328
+ results = []
329
+ for score, label, box in zip(scores, labels, boxes):
330
+ results.append(
331
+ {
332
+ "scores": score[score > threshold],
333
+ "labels": label[score > threshold],
334
+ "boxes": box[score > threshold],
335
+ }
336
+ )
337
+
338
+ return results
339
+
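Note on the focal-loss branch above: the per-class sigmoid scores are flattened across queries and classes before `topk`, so each selected index encodes both a query and a class via `labels = index % num_classes` and `query = index // num_classes`. With illustrative numbers, say 300 queries and 80 classes, a flattened index of 1234 maps to query 15 (1234 // 80) and class 34 (1234 % 80).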
340
+ def from_dict():
341
+ raise NotImplementedError("No need to override this method for RT-DETR yet.")
342
+
343
+ def post_process():
344
+ raise NotImplementedError("Post-processing is not implemented for RT-DETR yet.")
345
+
346
+ def post_process_segmentation():
347
+ raise NotImplementedError("Segmentation post-processing is not implemented for RT-DETR yet.")
348
+
349
+ def post_process_instance():
350
+ raise NotImplementedError("Instance post-processing is not implemented for RT-DETR yet.")
351
+
352
+ def post_process_panoptic():
353
+ raise NotImplementedError("Panoptic post-processing is not implemented for RT-DETR yet.")
354
+
355
+ def post_process_instance_segmentation():
356
+ raise NotImplementedError("Segmentation post-processing is not implemented for RT-DETR yet.")
357
+
358
+ def post_process_semantic_segmentation():
359
+ raise NotImplementedError("Semantic segmentation post-processing is not implemented for RT-DETR yet.")
360
+
361
+ def post_process_panoptic_segmentation():
362
+ raise NotImplementedError("Panoptic segmentation post-processing is not implemented for RT-DETR yet.")
363
+
364
+
365
+ __all__ = ["RTDetrImageProcessorFast"]
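End-to-end, the processor defined above resizes to 640x640 with bilinear interpolation, rescales without ImageNet normalization or padding, and decodes detections with `post_process_object_detection`. A minimal usage sketch, assuming the publicly released `PekingU/rtdetr_r50vd` checkpoint and a local `example.jpg` (both are assumptions, not part of this file):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, RTDetrForObjectDetection

image = Image.open("example.jpg")  # hypothetical input image
processor = AutoImageProcessor.from_pretrained("PekingU/rtdetr_r50vd", use_fast=True)
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Convert the normalized cxcywh predictions back to absolute xyxy boxes.
results = processor.post_process_object_detection(
    outputs, threshold=0.5, target_sizes=torch.tensor([image.size[::-1]])
)
for score, label, box in zip(results[0]["scores"], results[0]["labels"], results[0]["boxes"]):
    print(model.config.id2label[label.item()], round(score.item(), 3), box.tolist())
```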
phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import TYPE_CHECKING
17
+
18
+ from ...utils import _LazyModule
19
+ from ...utils.import_utils import define_import_structure
20
+
21
+
22
+ if TYPE_CHECKING:
23
+ from .configuration_rt_detr_v2 import *
24
+ from .modeling_rt_detr_v2 import *
25
+ else:
26
+ import sys
27
+
28
+ _file = globals()["__file__"]
29
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
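The `_LazyModule` indirection above defers the actual submodule imports until an attribute is first accessed. A rough illustration of the underlying idea (a module-level `__getattr__`, PEP 562), not the actual `_LazyModule` implementation:

```python
# Illustration only: defer an import until the attribute is first requested.
import importlib

_LAZY = {"RTDetrV2Config": "transformers.models.rt_detr_v2.configuration_rt_detr_v2"}

def __getattr__(name):
    if name in _LAZY:
        module = importlib.import_module(_LAZY[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```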
phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (540 Bytes). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/configuration_rt_detr_v2.cpython-39.pyc ADDED
Binary file (15.2 kB). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/modeling_rt_detr_v2.cpython-39.pyc ADDED
Binary file (64.3 kB). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr_v2/__pycache__/modular_rt_detr_v2.cpython-39.pyc ADDED
Binary file (22.9 kB). View file
 
phivenv/Lib/site-packages/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py ADDED
@@ -0,0 +1,387 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_rt_detr_v2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from ...configuration_utils import PretrainedConfig
23
+ from ...utils import logging
24
+ from ...utils.backbone_utils import verify_backbone_config_arguments
25
+ from ..auto import CONFIG_MAPPING
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class RTDetrV2Config(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`RTDetrV2Model`]. It is used to instantiate a
34
+ RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
35
+ with the defaults will yield a similar configuration to that of the RT-DETR architecture.
36
+
37
+ e.g. [PekingU/rtdetr_r18vd](https://huggingface.co/PekingU/rtdetr_r18vd)
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+ Args:
43
+ initializer_range (`float`, *optional*, defaults to 0.01):
44
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
45
+ initializer_bias_prior_prob (`float`, *optional*):
46
+ The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
47
+ If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
48
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
49
+ The epsilon used by the layer normalization layers.
50
+ batch_norm_eps (`float`, *optional*, defaults to 1e-05):
51
+ The epsilon used by the batch normalization layers.
52
+ backbone_config (`Dict`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
53
+ The configuration of the backbone model.
54
+ backbone (`str`, *optional*):
55
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
56
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
57
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
58
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
59
+ Whether to use pretrained weights for the backbone.
60
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
61
+ Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
62
+ library.
63
+ freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
64
+ Whether to freeze the batch normalization layers in the backbone.
65
+ backbone_kwargs (`dict`, *optional*):
66
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
67
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
68
+ encoder_hidden_dim (`int`, *optional*, defaults to 256):
69
+ Dimension of the layers in hybrid encoder.
70
+ encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
71
+ Channel dimensions of the multi-level features fed to the encoder.
72
+ feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
73
+ Strides used in each feature map.
74
+ encoder_layers (`int`, *optional*, defaults to 1):
75
+ Total of layers to be used by the encoder.
76
+ encoder_ffn_dim (`int`, *optional*, defaults to 1024):
77
+ Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
78
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
79
+ Number of attention heads for each attention layer in the Transformer encoder.
80
+ dropout (`float`, *optional*, defaults to 0.0):
81
+ The ratio for all dropout layers.
82
+ activation_dropout (`float`, *optional*, defaults to 0.0):
83
+ The dropout ratio for activations inside the fully connected layer.
84
+ encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
85
+ Indexes of the projected layers to be used in the encoder.
86
+ positional_encoding_temperature (`int`, *optional*, defaults to 10000):
87
+ The temperature parameter used to create the positional encodings.
88
+ encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
89
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
90
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
91
+ activation_function (`str`, *optional*, defaults to `"silu"`):
92
+ The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
93
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
94
+ eval_size (`tuple[int, int]`, *optional*):
95
+ Height and width used to compute the effective height and width of the position embeddings after taking
96
+ into account the stride.
97
+ normalize_before (`bool`, *optional*, defaults to `False`):
98
+ Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
99
+ feed-forward modules.
100
+ hidden_expansion (`float`, *optional*, defaults to 1.0):
101
+ Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
102
+ d_model (`int`, *optional*, defaults to 256):
103
+ Dimension of the layers, excluding the hybrid encoder.
104
+ num_queries (`int`, *optional*, defaults to 300):
105
+ Number of object queries.
106
+ decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
107
+ Channel dimensions of the multi-level features fed to the decoder.
108
+ decoder_ffn_dim (`int`, *optional*, defaults to 1024):
109
+ Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
110
+ num_feature_levels (`int`, *optional*, defaults to 3):
111
+ The number of input feature levels.
112
+ decoder_n_points (`int`, *optional*, defaults to 4):
113
+ The number of sampled keys in each feature level for each attention head in the decoder.
114
+ decoder_layers (`int`, *optional*, defaults to 6):
115
+ Number of decoder layers.
116
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
117
+ Number of attention heads for each attention layer in the Transformer decoder.
118
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
119
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
120
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
121
+ attention_dropout (`float`, *optional*, defaults to 0.0):
122
+ The dropout ratio for the attention probabilities.
123
+ num_denoising (`int`, *optional*, defaults to 100):
124
+ The total number of denoising tasks or queries to be used for contrastive denoising.
125
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
126
+ The fraction of denoising labels to which random noise should be added.
127
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
128
+ Scale or magnitude of noise to be added to the bounding boxes.
129
+ learn_initial_query (`bool`, *optional*, defaults to `False`):
130
+ Indicates whether the initial query embeddings for the decoder should be learned during training.
131
+ anchor_image_size (`tuple[int, int]`, *optional*):
132
+ Height and width of the input image used during evaluation to generate the bounding box anchors. If `None`, anchors are generated automatically.
133
+ with_box_refine (`bool`, *optional*, defaults to `True`):
134
+ Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
135
+ based on the predictions from the previous layer.
136
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
137
+ Whether the architecture has an encoder decoder structure.
138
+ matcher_alpha (`float`, *optional*, defaults to 0.25):
139
+ Parameter alpha used by the Hungarian Matcher.
140
+ matcher_gamma (`float`, *optional*, defaults to 2.0):
141
+ Parameter gamma used by the Hungarian Matcher.
142
+ matcher_class_cost (`float`, *optional*, defaults to 2.0):
143
+ The relative weight of the class loss used by the Hungarian Matcher.
144
+ matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
145
+ The relative weight of the bounding box loss used by the Hungarian Matcher.
146
+ matcher_giou_cost (`float`, *optional*, defaults to 2.0):
147
+ The relative weight of the GIoU loss used by the Hungarian Matcher.
148
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
149
+ Parameter informing if focal loss should be used.
150
+ auxiliary_loss (`bool`, *optional*, defaults to `True`):
151
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
152
+ focal_loss_alpha (`float`, *optional*, defaults to 0.75):
153
+ Parameter alpha used to compute the focal loss.
154
+ focal_loss_gamma (`float`, *optional*, defaults to 2.0):
155
+ Parameter gamma used to compute the focal loss.
156
+ weight_loss_vfl (`float`, *optional*, defaults to 1.0):
157
+ Relative weight of the varifocal loss in the object detection loss.
158
+ weight_loss_bbox (`float`, *optional*, defaults to 5.0):
159
+ Relative weight of the L1 bounding box loss in the object detection loss.
160
+ weight_loss_giou (`float`, *optional*, defaults to 2.0):
161
+ Relative weight of the generalized IoU loss in the object detection loss.
162
+ eos_coefficient (`float`, *optional*, defaults to 0.0001):
163
+ Relative classification weight of the 'no-object' class in the object detection loss.
164
+ decoder_n_levels (`int`, *optional*, defaults to 3):
165
+ The number of feature levels used by the decoder.
166
+ decoder_offset_scale (`float`, *optional*, defaults to 0.5):
167
+ Scaling factor applied to the attention offsets in the decoder.
168
+ decoder_method (`str`, *optional*, defaults to `"default"`):
169
+ The method to use for the decoder: `"default"` or `"discrete"`.
170
+
171
+ Examples:
172
+
173
+ ```python
174
+ >>> from transformers import RTDetrV2Config, RTDetrV2Model
175
+
176
+ >>> # Initializing a RT-DETR configuration
177
+ >>> configuration = RTDetrV2Config()
178
+
179
+ >>> # Initializing a model (with random weights) from the configuration
180
+ >>> model = RTDetrV2Model(configuration)
181
+
182
+ >>> # Accessing the model configuration
183
+ >>> configuration = model.config
184
+ ```
185
+ """
186
+
187
+ model_type = "rt_detr_v2"
188
+ layer_types = ["basic", "bottleneck"]
189
+ attribute_map = {
190
+ "hidden_size": "d_model",
191
+ "num_attention_heads": "encoder_attention_heads",
192
+ }
193
+
194
+ def __init__(
195
+ self,
196
+ initializer_range=0.01,
197
+ initializer_bias_prior_prob=None,
198
+ layer_norm_eps=1e-5,
199
+ batch_norm_eps=1e-5,
200
+ # backbone
201
+ backbone_config=None,
202
+ backbone=None,
203
+ use_pretrained_backbone=False,
204
+ use_timm_backbone=False,
205
+ freeze_backbone_batch_norms=True,
206
+ backbone_kwargs=None,
207
+ # encoder HybridEncoder
208
+ encoder_hidden_dim=256,
209
+ encoder_in_channels=[512, 1024, 2048],
210
+ feat_strides=[8, 16, 32],
211
+ encoder_layers=1,
212
+ encoder_ffn_dim=1024,
213
+ encoder_attention_heads=8,
214
+ dropout=0.0,
215
+ activation_dropout=0.0,
216
+ encode_proj_layers=[2],
217
+ positional_encoding_temperature=10000,
218
+ encoder_activation_function="gelu",
219
+ activation_function="silu",
220
+ eval_size=None,
221
+ normalize_before=False,
222
+ hidden_expansion=1.0,
223
+ # decoder RTDetrV2Transformer
224
+ d_model=256,
225
+ num_queries=300,
226
+ decoder_in_channels=[256, 256, 256],
227
+ decoder_ffn_dim=1024,
228
+ num_feature_levels=3,
229
+ decoder_n_points=4,
230
+ decoder_layers=6,
231
+ decoder_attention_heads=8,
232
+ decoder_activation_function="relu",
233
+ attention_dropout=0.0,
234
+ num_denoising=100,
235
+ label_noise_ratio=0.5,
236
+ box_noise_scale=1.0,
237
+ learn_initial_query=False,
238
+ anchor_image_size=None,
239
+ with_box_refine=True,
240
+ is_encoder_decoder=True,
241
+ # Loss
242
+ matcher_alpha=0.25,
243
+ matcher_gamma=2.0,
244
+ matcher_class_cost=2.0,
245
+ matcher_bbox_cost=5.0,
246
+ matcher_giou_cost=2.0,
247
+ use_focal_loss=True,
248
+ auxiliary_loss=True,
249
+ focal_loss_alpha=0.75,
250
+ focal_loss_gamma=2.0,
251
+ weight_loss_vfl=1.0,
252
+ weight_loss_bbox=5.0,
253
+ weight_loss_giou=2.0,
254
+ eos_coefficient=1e-4,
255
+ decoder_n_levels=3, # default value
256
+ decoder_offset_scale=0.5, # default value
257
+ decoder_method="default",
258
+ **kwargs,
259
+ ):
260
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
261
+ self.initializer_range = initializer_range
262
+ self.initializer_bias_prior_prob = initializer_bias_prior_prob
263
+ self.layer_norm_eps = layer_norm_eps
264
+ self.batch_norm_eps = batch_norm_eps
265
+ # backbone
266
+ if backbone_config is None and backbone is None:
267
+ logger.info(
268
+ "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetrV2-ResNet` backbone."
269
+ )
270
+ backbone_model_type = "rt_detr_resnet"
271
+ config_class = CONFIG_MAPPING[backbone_model_type]
272
+ # this will map it to RTDetrResNetConfig
273
+ # note: we can instead create RTDetrV2ResNetConfig but it will be exactly the same as V1
274
+ # and we would need to create RTDetrV2ResNetModel
275
+ backbone_config = config_class(
276
+ num_channels=3,
277
+ embedding_size=64,
278
+ hidden_sizes=[256, 512, 1024, 2048],
279
+ depths=[3, 4, 6, 3],
280
+ layer_type="bottleneck",
281
+ hidden_act="relu",
282
+ downsample_in_first_stage=False,
283
+ downsample_in_bottleneck=False,
284
+ out_features=None,
285
+ out_indices=[2, 3, 4],
286
+ )
287
+ elif isinstance(backbone_config, dict):
288
+ backbone_model_type = backbone_config.pop("model_type")
289
+ config_class = CONFIG_MAPPING[backbone_model_type]
290
+ backbone_config = config_class.from_dict(backbone_config)
291
+
292
+ verify_backbone_config_arguments(
293
+ use_timm_backbone=use_timm_backbone,
294
+ use_pretrained_backbone=use_pretrained_backbone,
295
+ backbone=backbone,
296
+ backbone_config=backbone_config,
297
+ backbone_kwargs=backbone_kwargs,
298
+ )
299
+
300
+ self.backbone_config = backbone_config
301
+ self.backbone = backbone
302
+ self.use_pretrained_backbone = use_pretrained_backbone
303
+ self.use_timm_backbone = use_timm_backbone
304
+ self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
305
+ self.backbone_kwargs = backbone_kwargs
306
+ # encoder
307
+ self.encoder_hidden_dim = encoder_hidden_dim
308
+ self.encoder_in_channels = encoder_in_channels
309
+ self.feat_strides = feat_strides
310
+ self.encoder_ffn_dim = encoder_ffn_dim
311
+ self.dropout = dropout
312
+ self.activation_dropout = activation_dropout
313
+ self.encode_proj_layers = encode_proj_layers
314
+ self.encoder_layers = encoder_layers
315
+ self.positional_encoding_temperature = positional_encoding_temperature
316
+ self.eval_size = eval_size
317
+ self.normalize_before = normalize_before
318
+ self.encoder_activation_function = encoder_activation_function
319
+ self.activation_function = activation_function
320
+ self.hidden_expansion = hidden_expansion
321
+ self.num_queries = num_queries
322
+ self.decoder_ffn_dim = decoder_ffn_dim
323
+ self.decoder_in_channels = decoder_in_channels
324
+ self.num_feature_levels = num_feature_levels
325
+ self.decoder_n_points = decoder_n_points
326
+ self.decoder_layers = decoder_layers
327
+ self.decoder_attention_heads = decoder_attention_heads
328
+ self.decoder_activation_function = decoder_activation_function
329
+ self.attention_dropout = attention_dropout
330
+ self.num_denoising = num_denoising
331
+ self.label_noise_ratio = label_noise_ratio
332
+ self.box_noise_scale = box_noise_scale
333
+ self.learn_initial_query = learn_initial_query
334
+ self.anchor_image_size = anchor_image_size
335
+ self.auxiliary_loss = auxiliary_loss
336
+ self.with_box_refine = with_box_refine
337
+ # Loss
338
+ self.matcher_alpha = matcher_alpha
339
+ self.matcher_gamma = matcher_gamma
340
+ self.matcher_class_cost = matcher_class_cost
341
+ self.matcher_bbox_cost = matcher_bbox_cost
342
+ self.matcher_giou_cost = matcher_giou_cost
343
+ self.use_focal_loss = use_focal_loss
344
+ self.focal_loss_alpha = focal_loss_alpha
345
+ self.focal_loss_gamma = focal_loss_gamma
346
+ self.weight_loss_vfl = weight_loss_vfl
347
+ self.weight_loss_bbox = weight_loss_bbox
348
+ self.weight_loss_giou = weight_loss_giou
349
+ self.eos_coefficient = eos_coefficient
350
+
351
+ if not hasattr(self, "d_model"):
352
+ self.d_model = d_model
353
+
354
+ if not hasattr(self, "encoder_attention_heads"):
355
+ self.encoder_attention_heads = encoder_attention_heads
356
+ # add the new attributes with the given values or defaults
357
+ self.decoder_n_levels = decoder_n_levels
358
+ self.decoder_offset_scale = decoder_offset_scale
359
+ self.decoder_method = decoder_method
360
+
361
+ @property
362
+ def sub_configs(self):
363
+ return (
364
+ {"backbone_config": type(self.backbone_config)}
365
+ if getattr(self, "backbone_config", None) is not None
366
+ else {}
367
+ )
368
+
369
+ @classmethod
370
+ def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
371
+ """Instantiate a [`RTDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model
372
+ configuration.
373
+
374
+ Args:
375
+ backbone_config ([`PretrainedConfig`]):
376
+ The backbone configuration.
377
+
378
+ Returns:
379
+ [`RTDetrV2Config`]: An instance of a configuration object
380
+ """
381
+ return cls(
382
+ backbone_config=backbone_config,
383
+ **kwargs,
384
+ )
385
+
386
+
387
+ __all__ = ["RTDetrV2Config"]
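A short configuration sketch exercising the v2-specific knobs documented above (`decoder_n_points`, `decoder_n_levels`, `decoder_offset_scale`, `decoder_method`); the values are illustrative defaults, not tuned settings:

```python
from transformers import RTDetrV2Config, RTDetrV2ForObjectDetection

config = RTDetrV2Config(
    decoder_n_points=4,
    decoder_n_levels=3,
    decoder_offset_scale=0.5,
    decoder_method="discrete",  # "default" -> bilinear grid_sample, "discrete" -> integer sampling
)
model = RTDetrV2ForObjectDetection(config)  # randomly initialized weights
print(model.config.decoder_method)
```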
phivenv/Lib/site-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py ADDED
@@ -0,0 +1,1998 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_rt_detr_v2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ import math
22
+ import warnings
23
+ from dataclasses import dataclass
24
+ from functools import partial
25
+ from typing import Optional, Union
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ from torch import Tensor, nn
30
+
31
+ from ...activations import ACT2CLS, ACT2FN
32
+ from ...image_transforms import center_to_corners_format, corners_to_center_format
33
+ from ...modeling_outputs import BaseModelOutput
34
+ from ...modeling_utils import PreTrainedModel
35
+ from ...pytorch_utils import compile_compatible_method_lru_cache
36
+ from ...utils import ModelOutput, auto_docstring, is_torchdynamo_compiling, torch_int
37
+ from ...utils.backbone_utils import load_backbone
38
+ from .configuration_rt_detr_v2 import RTDetrV2Config
39
+
40
+
41
+ def multi_scale_deformable_attention_v2(
42
+ value: Tensor,
43
+ value_spatial_shapes: Tensor,
44
+ sampling_locations: Tensor,
45
+ attention_weights: Tensor,
46
+ num_points_list: list[int],
47
+ method="default",
48
+ ) -> Tensor:
49
+ batch_size, _, num_heads, hidden_dim = value.shape
50
+ _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape
51
+ value_list = (
52
+ value.permute(0, 2, 3, 1)
53
+ .flatten(0, 1)
54
+ .split([height * width for height, width in value_spatial_shapes], dim=-1)
55
+ )
56
+ # sampling_offsets [8, 480, 8, 12, 2]
57
+ if method == "default":
58
+ sampling_grids = 2 * sampling_locations - 1
59
+ elif method == "discrete":
60
+ sampling_grids = sampling_locations
61
+ sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
62
+ sampling_grids = sampling_grids.split(num_points_list, dim=-2)
63
+ sampling_value_list = []
64
+ for level_id, (height, width) in enumerate(value_spatial_shapes):
65
+ # batch_size, height*width, num_heads, hidden_dim
66
+ # -> batch_size, height*width, num_heads*hidden_dim
67
+ # -> batch_size, num_heads*hidden_dim, height*width
68
+ # -> batch_size*num_heads, hidden_dim, height, width
69
+ value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width)
70
+ # batch_size, num_queries, num_heads, num_points, 2
71
+ # -> batch_size, num_heads, num_queries, num_points, 2
72
+ # -> batch_size*num_heads, num_queries, num_points, 2
73
+ sampling_grid_l_ = sampling_grids[level_id]
74
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
75
+ if method == "default":
76
+ sampling_value_l_ = nn.functional.grid_sample(
77
+ value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
78
+ )
79
+ elif method == "discrete":
80
+ sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to(
81
+ torch.int64
82
+ )
83
+
84
+ # Separate clamping for x and y coordinates
85
+ sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1)
86
+ sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1)
87
+
88
+ # Combine the clamped coordinates
89
+ sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1)
90
+ sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2)
91
+ sampling_idx = (
92
+ torch.arange(sampling_coord.shape[0], device=value.device)
93
+ .unsqueeze(-1)
94
+ .repeat(1, sampling_coord.shape[1])
95
+ )
96
+ sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]]
97
+ sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape(
98
+ batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id]
99
+ )
100
+ sampling_value_list.append(sampling_value_l_)
101
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
102
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
103
+ # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
104
+ attention_weights = attention_weights.permute(0, 2, 1, 3).reshape(
105
+ batch_size * num_heads, 1, num_queries, sum(num_points_list)
106
+ )
107
+ output = (
108
+ (torch.concat(sampling_value_list, dim=-1) * attention_weights)
109
+ .sum(-1)
110
+ .view(batch_size, num_heads * hidden_dim, num_queries)
111
+ )
112
+ return output.transpose(1, 2).contiguous()
113
+
114
+
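A minimal shape sketch for `multi_scale_deformable_attention_v2`, assuming it is imported from the generated `modeling_rt_detr_v2` module; the sizes are toy values chosen only to make the bookkeeping visible (batch 2, 8 heads, head dim 32, three levels of 8x8, 4x4 and 2x2, 10 queries, 4 points per level):

```python
import torch
from transformers.models.rt_detr_v2.modeling_rt_detr_v2 import multi_scale_deformable_attention_v2

spatial_shapes_list = [(8, 8), (4, 4), (2, 2)]        # flattened length: 64 + 16 + 4 = 84
value = torch.rand(2, 84, 8, 32)                      # (batch, seq_len, num_heads, head_dim)
sampling_locations = torch.rand(2, 10, 8, 12, 2)      # 3 levels * 4 points = 12 sampling points
attention_weights = torch.softmax(torch.rand(2, 10, 8, 12), dim=-1)

out = multi_scale_deformable_attention_v2(
    value, spatial_shapes_list, sampling_locations, attention_weights, num_points_list=[4, 4, 4]
)
print(out.shape)  # torch.Size([2, 10, 256]): (batch, num_queries, num_heads * head_dim)
```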
115
+ # the main change
116
+ class RTDetrV2MultiscaleDeformableAttention(nn.Module):
117
+ """
118
+ RTDetrV2 version of multiscale deformable attention, extending the base implementation
119
+ with improved offset handling and initialization.
120
+ """
121
+
122
+ def __init__(self, config: RTDetrV2Config):
123
+ super().__init__()
124
+ num_heads = config.decoder_attention_heads
125
+ n_points = config.decoder_n_points
126
+
127
+ if config.d_model % num_heads != 0:
128
+ raise ValueError(
129
+ f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
130
+ )
131
+ dim_per_head = config.d_model // num_heads
132
+ # check if dim_per_head is power of 2
133
+ if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
134
+ warnings.warn(
135
+ "You'd better set embed_dim (d_model) in RTDetrV2MultiscaleDeformableAttention to make the"
136
+ " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
137
+ " implementation."
138
+ )
139
+
140
+ self.im2col_step = 64
141
+
142
+ self.d_model = config.d_model
143
+
144
+ # V2-specific attributes
145
+ self.n_levels = config.decoder_n_levels
146
+ self.n_heads = num_heads
147
+ self.n_points = n_points
148
+
149
+ self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
150
+ self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
151
+ self.value_proj = nn.Linear(config.d_model, config.d_model)
152
+ self.output_proj = nn.Linear(config.d_model, config.d_model)
153
+
154
+ self.offset_scale = config.decoder_offset_scale
155
+ self.method = config.decoder_method
156
+
157
+ # Initialize n_points list and scale
158
+ n_points_list = [self.n_points for _ in range(self.n_levels)]
159
+ self.n_points_list = n_points_list
160
+ n_points_scale = [1 / n for n in n_points_list for _ in range(n)]
161
+ self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32))
162
+
163
+ def forward(
164
+ self,
165
+ hidden_states: torch.Tensor,
166
+ attention_mask: Optional[torch.Tensor] = None,
167
+ encoder_hidden_states=None,
168
+ encoder_attention_mask=None,
169
+ position_embeddings: Optional[torch.Tensor] = None,
170
+ reference_points=None,
171
+ spatial_shapes=None,
172
+ spatial_shapes_list=None,
173
+ level_start_index=None,
174
+ output_attentions: bool = False,
175
+ ):
176
+ # Process inputs up to sampling locations calculation using parent class logic
177
+ if position_embeddings is not None:
178
+ hidden_states = hidden_states + position_embeddings
179
+
180
+ batch_size, num_queries, _ = hidden_states.shape
181
+ batch_size, sequence_length, _ = encoder_hidden_states.shape
182
+ if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
183
+ raise ValueError(
184
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
185
+ )
186
+
187
+ value = self.value_proj(encoder_hidden_states)
188
+ if attention_mask is not None:
189
+ value = value.masked_fill(~attention_mask[..., None], float(0))
190
+ value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
191
+
192
+ # V2-specific sampling offsets shape
193
+ sampling_offsets = self.sampling_offsets(hidden_states).view(
194
+ batch_size, num_queries, self.n_heads, self.n_levels * self.n_points, 2
195
+ )
196
+
197
+ attention_weights = self.attention_weights(hidden_states).view(
198
+ batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
199
+ )
200
+ attention_weights = F.softmax(attention_weights, -1)
201
+
202
+ # V2-specific sampling locations calculation
203
+ if reference_points.shape[-1] == 2:
204
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
205
+ sampling_locations = (
206
+ reference_points[:, :, None, :, None, :]
207
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
208
+ )
209
+ elif reference_points.shape[-1] == 4:
210
+ n_points_scale = self.n_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1)
211
+ offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale
212
+ sampling_locations = reference_points[:, :, None, :, :2] + offset
213
+ else:
214
+ raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
215
+
216
+ # V2-specific attention implementation choice
217
+ output = multi_scale_deformable_attention_v2(
218
+ value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method
219
+ )
220
+
221
+ output = self.output_proj(output)
222
+ return output, attention_weights
223
+
224
+
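Note on the 4-d reference-point branch in the forward pass above: for a normalized reference box `(cx, cy, w, h)`, each sampling location is computed as `p = (cx, cy) + offset * (1 / n_points) * (w, h) * decoder_offset_scale`, i.e. the predicted offsets are scaled by the box size (and by `decoder_offset_scale`, 0.5 by default) rather than by the feature-map resolution used in the 2-d branch.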
225
+ class RTDetrV2MultiheadAttention(nn.Module):
226
+ """
227
+ Multi-headed attention from 'Attention Is All You Need' paper.
228
+
229
+ Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ embed_dim: int,
235
+ num_heads: int,
236
+ dropout: float = 0.0,
237
+ bias: bool = True,
238
+ ):
239
+ super().__init__()
240
+ self.embed_dim = embed_dim
241
+ self.num_heads = num_heads
242
+ self.dropout = dropout
243
+ self.head_dim = embed_dim // num_heads
244
+ if self.head_dim * num_heads != self.embed_dim:
245
+ raise ValueError(
246
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
247
+ f" {num_heads})."
248
+ )
249
+ self.scaling = self.head_dim**-0.5
250
+
251
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
252
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
253
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
254
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
255
+
256
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
257
+ return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
258
+
259
+ def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
260
+ return tensor if position_embeddings is None else tensor + position_embeddings
261
+
262
+ def forward(
263
+ self,
264
+ hidden_states: torch.Tensor,
265
+ attention_mask: Optional[torch.Tensor] = None,
266
+ position_embeddings: Optional[torch.Tensor] = None,
267
+ output_attentions: bool = False,
268
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
269
+ """Input shape: Batch x Time x Channel"""
270
+
271
+ batch_size, target_len, embed_dim = hidden_states.size()
272
+ # add position embeddings to the hidden states before projecting to queries and keys
273
+ if position_embeddings is not None:
274
+ hidden_states_original = hidden_states
275
+ hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
276
+
277
+ # get queries, keys and values
278
+ query_states = self.q_proj(hidden_states) * self.scaling
279
+ key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size)
280
+ value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size)
281
+
282
+ proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
283
+ query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape)
284
+ key_states = key_states.view(*proj_shape)
285
+ value_states = value_states.view(*proj_shape)
286
+
287
+ source_len = key_states.size(1)
288
+
289
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
290
+
291
+ if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
292
+ raise ValueError(
293
+ f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
294
+ f" {attn_weights.size()}"
295
+ )
296
+
297
+ # expand attention_mask
298
+ if attention_mask is not None:
299
+ # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
300
+ attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size())
301
+
302
+ if attention_mask is not None:
303
+ if attention_mask.size() != (batch_size, 1, target_len, source_len):
304
+ raise ValueError(
305
+ f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
306
+ f" {attention_mask.size()}"
307
+ )
308
+ if attention_mask.dtype == torch.bool:
309
+ attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
310
+ attention_mask, -torch.inf
311
+ )
312
+ attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
313
+ attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
314
+
315
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
316
+
317
+ if output_attentions:
318
+ # this operation is a bit awkward, but it's required to
319
+ # make sure that attn_weights keeps its gradient.
320
+ # In order to do so, attn_weights have to reshaped
321
+ # twice and have to be reused in the following
322
+ attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
323
+ attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
324
+ else:
325
+ attn_weights_reshaped = None
326
+
327
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
328
+
329
+ attn_output = torch.bmm(attn_probs, value_states)
330
+
331
+ if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
332
+ raise ValueError(
333
+ f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
334
+ f" {attn_output.size()}"
335
+ )
336
+
337
+ attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
338
+ attn_output = attn_output.transpose(1, 2)
339
+ attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
340
+
341
+ attn_output = self.out_proj(attn_output)
342
+
343
+ return attn_output, attn_weights_reshaped
344
+
345
+
346
+ class RTDetrV2DecoderLayer(nn.Module):
347
+ def __init__(self, config: RTDetrV2Config):
348
+ super().__init__()
349
+ # self-attention
350
+ self.self_attn = RTDetrV2MultiheadAttention(
351
+ embed_dim=config.d_model,
352
+ num_heads=config.decoder_attention_heads,
353
+ dropout=config.attention_dropout,
354
+ )
355
+ self.dropout = config.dropout
356
+ self.activation_fn = ACT2FN[config.decoder_activation_function]
357
+ self.activation_dropout = config.activation_dropout
358
+
359
+ self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
360
+ # override only the encoder attention module with v2 version
361
+ self.encoder_attn = RTDetrV2MultiscaleDeformableAttention(config)
362
+ self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
363
+ # feedforward neural networks
364
+ self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
365
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
366
+ self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
367
+
368
+ def forward(
369
+ self,
370
+ hidden_states: torch.Tensor,
371
+ position_embeddings: Optional[torch.Tensor] = None,
372
+ reference_points=None,
373
+ spatial_shapes=None,
374
+ spatial_shapes_list=None,
375
+ level_start_index=None,
376
+ encoder_hidden_states: Optional[torch.Tensor] = None,
377
+ encoder_attention_mask: Optional[torch.Tensor] = None,
378
+ output_attentions: Optional[bool] = False,
379
+ ):
380
+ """
381
+ Args:
382
+ hidden_states (`torch.FloatTensor`):
383
+ Input to the layer of shape `(seq_len, batch, embed_dim)`.
384
+ position_embeddings (`torch.FloatTensor`, *optional*):
385
+ Position embeddings that are added to the queries and keys in the self-attention layer.
386
+ reference_points (`torch.FloatTensor`, *optional*):
387
+ Reference points.
388
+ spatial_shapes (`torch.LongTensor`, *optional*):
389
+ Spatial shapes.
390
+ level_start_index (`torch.LongTensor`, *optional*):
391
+ Level start index.
392
+ encoder_hidden_states (`torch.FloatTensor`):
393
+ cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
394
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
395
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
396
+ values.
397
+ output_attentions (`bool`, *optional*):
398
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
399
+ returned tensors for more detail.
400
+ """
401
+ residual = hidden_states
402
+
403
+ # Self Attention
404
+ hidden_states, self_attn_weights = self.self_attn(
405
+ hidden_states=hidden_states,
406
+ attention_mask=encoder_attention_mask,
407
+ position_embeddings=position_embeddings,
408
+ output_attentions=output_attentions,
409
+ )
410
+
411
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
412
+ hidden_states = residual + hidden_states
413
+ hidden_states = self.self_attn_layer_norm(hidden_states)
414
+
415
+ second_residual = hidden_states
416
+
417
+ # Cross-Attention
418
+ cross_attn_weights = None
419
+ hidden_states, cross_attn_weights = self.encoder_attn(
420
+ hidden_states=hidden_states,
421
+ encoder_hidden_states=encoder_hidden_states,
422
+ position_embeddings=position_embeddings,
423
+ reference_points=reference_points,
424
+ spatial_shapes=spatial_shapes,
425
+ spatial_shapes_list=spatial_shapes_list,
426
+ level_start_index=level_start_index,
427
+ output_attentions=output_attentions,
428
+ )
429
+
430
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
431
+ hidden_states = second_residual + hidden_states
432
+
433
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
434
+
435
+ # Fully Connected
436
+ residual = hidden_states
437
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
438
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
439
+ hidden_states = self.fc2(hidden_states)
440
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
441
+ hidden_states = residual + hidden_states
442
+ hidden_states = self.final_layer_norm(hidden_states)
443
+
444
+ outputs = (hidden_states,)
445
+
446
+ if output_attentions:
447
+ outputs += (self_attn_weights, cross_attn_weights)
448
+
449
+ return outputs
450
+
451
+
452
+ @auto_docstring
453
+ class RTDetrV2PreTrainedModel(PreTrainedModel):
454
+ config: RTDetrV2Config
455
+ base_model_prefix = "rt_detr_v2"
456
+ main_input_name = "pixel_values"
457
+ _no_split_modules = [r"RTDetrV2HybridEncoder", r"RTDetrV2DecoderLayer"]
458
+
459
+ def _init_weights(self, module):
460
+ """Initialize the weights"""
461
+ if isinstance(module, (RTDetrV2ForObjectDetection, RTDetrV2Decoder)):
462
+ if module.class_embed is not None:
463
+ for layer in module.class_embed:
464
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
465
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
466
+ nn.init.xavier_uniform_(layer.weight)
467
+ nn.init.constant_(layer.bias, bias)
468
+
469
+ if module.bbox_embed is not None:
470
+ for layer in module.bbox_embed:
471
+ nn.init.constant_(layer.layers[-1].weight, 0)
472
+ nn.init.constant_(layer.layers[-1].bias, 0)
473
+
474
+ elif isinstance(module, RTDetrV2MultiscaleDeformableAttention):
475
+ nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
476
+ default_dtype = torch.get_default_dtype()
477
+ thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
478
+ 2.0 * math.pi / module.n_heads
479
+ )
480
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
481
+ grid_init = (
482
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
483
+ .view(module.n_heads, 1, 1, 2)
484
+ .repeat(1, module.n_levels, module.n_points, 1)
485
+ )
486
+ for i in range(module.n_points):
487
+ grid_init[:, :, i, :] *= i + 1
488
+ with torch.no_grad():
489
+ module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
490
+ nn.init.constant_(module.attention_weights.weight.data, 0.0)
491
+ nn.init.constant_(module.attention_weights.bias.data, 0.0)
492
+ nn.init.xavier_uniform_(module.value_proj.weight.data)
493
+ nn.init.constant_(module.value_proj.bias.data, 0.0)
494
+ nn.init.xavier_uniform_(module.output_proj.weight.data)
495
+ nn.init.constant_(module.output_proj.bias.data, 0.0)
496
+
497
+ elif isinstance(module, RTDetrV2Model):
498
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
499
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
500
+ nn.init.xavier_uniform_(module.enc_score_head.weight)
501
+ nn.init.constant_(module.enc_score_head.bias, bias)
502
+
503
+ elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
504
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
505
+ if module.bias is not None:
506
+ module.bias.data.zero_()
507
+
508
+ elif isinstance(module, nn.LayerNorm):
509
+ module.weight.data.fill_(1.0)
510
+ module.bias.data.zero_()
511
+
512
+ if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
513
+ nn.init.xavier_uniform_(module.weight_embedding.weight)
514
+ if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
515
+ nn.init.xavier_uniform_(module.denoising_class_embed.weight)
516
+
517
+
518
+ @dataclass
519
+ @auto_docstring(
520
+ custom_intro="""
521
+ Base class for outputs of the RTDetrV2Decoder. This class adds two attributes to
522
+ BaseModelOutputWithCrossAttentions, namely:
523
+ - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
524
+ - a stacked tensor of intermediate reference points.
525
+ """
526
+ )
527
+ class RTDetrV2DecoderOutput(ModelOutput):
528
+ r"""
529
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
530
+ Stacked intermediate hidden states (output of each layer of the decoder).
531
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
532
+ Stacked intermediate logits (logits of each layer of the decoder).
533
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
534
+ Stacked intermediate reference points (reference points of each layer of the decoder).
535
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
536
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
537
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
538
+ Stacked initial reference points (initial reference points of each layer of the decoder).
539
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
540
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
541
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
542
+ used to compute the weighted average in the cross-attention heads.
543
+ """
544
+
545
+ last_hidden_state: Optional[torch.FloatTensor] = None
546
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
547
+ intermediate_logits: Optional[torch.FloatTensor] = None
548
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
549
+ intermediate_predicted_corners: Optional[torch.FloatTensor] = None
550
+ initial_reference_points: Optional[torch.FloatTensor] = None
551
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
552
+ attentions: Optional[tuple[torch.FloatTensor]] = None
553
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
554
+
555
+
556
+ def inverse_sigmoid(x, eps=1e-5):
557
+ x = x.clamp(min=0, max=1)
558
+ x1 = x.clamp(min=eps)
559
+ x2 = (1 - x).clamp(min=eps)
560
+ return torch.log(x1 / x2)
561
+
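
A quick aside on the helper above: `inverse_sigmoid` is the clamped logit function, so it round-trips with `torch.sigmoid` and is what lets the decoder add box deltas in unbounded space. A minimal standalone check (my own sketch, not part of the diff):

```python
import torch

# Hypothetical standalone copy of the helper above, for illustration only.
def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

boxes = torch.tensor([0.1, 0.5, 0.9])   # normalized coordinates in [0, 1]
logits = inverse_sigmoid(boxes)          # unbounded "unactivated" coordinates
roundtrip = torch.sigmoid(logits)        # back to [0, 1]
assert torch.allclose(roundtrip, boxes, atol=1e-4)
```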
562
+
563
+ class RTDetrV2Decoder(RTDetrV2PreTrainedModel):
564
+ def __init__(self, config: RTDetrV2Config):
565
+ super().__init__(config)
566
+
567
+ self.dropout = config.dropout
568
+ self.layers = nn.ModuleList([RTDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)])
569
+ self.query_pos_head = RTDetrV2MLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2)
570
+
571
+ # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
572
+ self.bbox_embed = None
573
+ self.class_embed = None
574
+
575
+ # Initialize weights and apply final processing
576
+ self.post_init()
577
+
578
+ def forward(
579
+ self,
580
+ inputs_embeds=None,
581
+ encoder_hidden_states=None,
582
+ encoder_attention_mask=None,
583
+ position_embeddings=None,
584
+ reference_points=None,
585
+ spatial_shapes=None,
586
+ spatial_shapes_list=None,
587
+ level_start_index=None,
588
+ valid_ratios=None,
589
+ output_attentions=None,
590
+ output_hidden_states=None,
591
+ return_dict=None,
592
+ ):
593
+ r"""
594
+ Args:
595
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
596
+ The query embeddings that are passed into the decoder.
597
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
598
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
599
+ of the decoder.
600
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
601
+ Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
602
+ in `[0, 1]`:
603
+ - 1 for pixels that are real (i.e. **not masked**),
604
+ - 0 for pixels that are padding (i.e. **masked**).
605
+ position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
606
+ Position embeddings that are added to the queries and keys in each self-attention layer.
607
+ reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
608
+ Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
609
+ spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
610
+ Spatial shapes of the feature maps.
611
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
612
+ Indexes for the start of each feature level. In range `[0, sequence_length]`.
613
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
614
+ Ratio of valid area in each feature level.
615
+
616
+ output_attentions (`bool`, *optional*):
617
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
618
+ returned tensors for more detail.
619
+ output_hidden_states (`bool`, *optional*):
620
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
621
+ for more detail.
622
+ return_dict (`bool`, *optional*):
623
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
624
+ """
625
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
626
+ output_hidden_states = (
627
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
628
+ )
629
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
630
+
631
+ if inputs_embeds is not None:
632
+ hidden_states = inputs_embeds
633
+
634
+ # decoder layers
635
+ all_hidden_states = () if output_hidden_states else None
636
+ all_self_attns = () if output_attentions else None
637
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
638
+ intermediate = ()
639
+ intermediate_reference_points = ()
640
+ intermediate_logits = ()
641
+
642
+ reference_points = F.sigmoid(reference_points)
643
+
644
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L252
645
+ for idx, decoder_layer in enumerate(self.layers):
646
+ reference_points_input = reference_points.unsqueeze(2)
647
+ position_embeddings = self.query_pos_head(reference_points)
648
+
649
+ if output_hidden_states:
650
+ all_hidden_states += (hidden_states,)
651
+
652
+ layer_outputs = decoder_layer(
653
+ hidden_states,
654
+ position_embeddings=position_embeddings,
655
+ encoder_hidden_states=encoder_hidden_states,
656
+ reference_points=reference_points_input,
657
+ spatial_shapes=spatial_shapes,
658
+ spatial_shapes_list=spatial_shapes_list,
659
+ level_start_index=level_start_index,
660
+ encoder_attention_mask=encoder_attention_mask,
661
+ output_attentions=output_attentions,
662
+ )
663
+
664
+ hidden_states = layer_outputs[0]
665
+
666
+ # hack implementation for iterative bounding box refinement
667
+ if self.bbox_embed is not None:
668
+ predicted_corners = self.bbox_embed[idx](hidden_states)
669
+ new_reference_points = F.sigmoid(predicted_corners + inverse_sigmoid(reference_points))
670
+ reference_points = new_reference_points.detach()
671
+
672
+ intermediate += (hidden_states,)
673
+ intermediate_reference_points += (
674
+ (new_reference_points,) if self.bbox_embed is not None else (reference_points,)
675
+ )
676
+
677
+ if self.class_embed is not None:
678
+ logits = self.class_embed[idx](hidden_states)
679
+ intermediate_logits += (logits,)
680
+
681
+ if output_attentions:
682
+ all_self_attns += (layer_outputs[1],)
683
+
684
+ if encoder_hidden_states is not None:
685
+ all_cross_attentions += (layer_outputs[2],)
686
+
687
+ # Keep batch_size as first dimension
688
+ intermediate = torch.stack(intermediate, dim=1)
689
+ intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
690
+ if self.class_embed is not None:
691
+ intermediate_logits = torch.stack(intermediate_logits, dim=1)
692
+
693
+ # add hidden states from the last decoder layer
694
+ if output_hidden_states:
695
+ all_hidden_states += (hidden_states,)
696
+
697
+ if not return_dict:
698
+ return tuple(
699
+ v
700
+ for v in [
701
+ hidden_states,
702
+ intermediate,
703
+ intermediate_logits,
704
+ intermediate_reference_points,
705
+ all_hidden_states,
706
+ all_self_attns,
707
+ all_cross_attentions,
708
+ ]
709
+ if v is not None
710
+ )
711
+ return RTDetrV2DecoderOutput(
712
+ last_hidden_state=hidden_states,
713
+ intermediate_hidden_states=intermediate,
714
+ intermediate_logits=intermediate_logits,
715
+ intermediate_reference_points=intermediate_reference_points,
716
+ hidden_states=all_hidden_states,
717
+ attentions=all_self_attns,
718
+ cross_attentions=all_cross_attentions,
719
+ )
720
+
721
+
722
+ @dataclass
723
+ @auto_docstring(
724
+ custom_intro="""
725
+ Base class for outputs of the RT-DETR encoder-decoder model.
726
+ """
727
+ )
728
+ class RTDetrV2ModelOutput(ModelOutput):
729
+ r"""
730
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
731
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
732
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
733
+ Stacked intermediate hidden states (output of each layer of the decoder).
734
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
735
+ Stacked intermediate logits (logits of each layer of the decoder).
736
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
737
+ Stacked intermediate reference points (reference points of each layer of the decoder).
738
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
739
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
740
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
741
+ Initial reference points used for the first decoder layer.
742
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
743
+ Initial reference points sent through the Transformer decoder.
744
+ enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
745
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
746
+ picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e.
747
+ foreground and background).
748
+ enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`):
749
+ Logits of predicted bounding boxes coordinates in the encoder stage.
750
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
751
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
752
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
753
+ foreground and background).
754
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
755
+ Logits of predicted bounding boxes coordinates in the first stage.
756
+ denoising_meta_values (`dict`):
757
+ Extra dictionary for the denoising related values.
758
+ """
759
+
760
+ last_hidden_state: Optional[torch.FloatTensor] = None
761
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
762
+ intermediate_logits: Optional[torch.FloatTensor] = None
763
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
764
+ intermediate_predicted_corners: Optional[torch.FloatTensor] = None
765
+ initial_reference_points: Optional[torch.FloatTensor] = None
766
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
767
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
768
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
769
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
770
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
771
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
772
+ init_reference_points: Optional[torch.FloatTensor] = None
773
+ enc_topk_logits: Optional[torch.FloatTensor] = None
774
+ enc_topk_bboxes: Optional[torch.FloatTensor] = None
775
+ enc_outputs_class: Optional[torch.FloatTensor] = None
776
+ enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
777
+ denoising_meta_values: Optional[dict] = None
778
+
779
+
780
+ class RTDetrV2FrozenBatchNorm2d(nn.Module):
781
+ """
782
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
783
+
784
+ Copy-paste from torchvision.misc.ops with an added eps before rsqrt, without which models other than
785
+ torchvision.models.resnet[18,34,50,101] produce NaNs.
786
+ """
787
+
788
+ def __init__(self, n):
789
+ super().__init__()
790
+ self.register_buffer("weight", torch.ones(n))
791
+ self.register_buffer("bias", torch.zeros(n))
792
+ self.register_buffer("running_mean", torch.zeros(n))
793
+ self.register_buffer("running_var", torch.ones(n))
794
+
795
+ def _load_from_state_dict(
796
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
797
+ ):
798
+ num_batches_tracked_key = prefix + "num_batches_tracked"
799
+ if num_batches_tracked_key in state_dict:
800
+ del state_dict[num_batches_tracked_key]
801
+
802
+ super()._load_from_state_dict(
803
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
804
+ )
805
+
806
+ def forward(self, x):
807
+ # move reshapes to the beginning
808
+ # to make it user-friendly
809
+ weight = self.weight.reshape(1, -1, 1, 1)
810
+ bias = self.bias.reshape(1, -1, 1, 1)
811
+ running_var = self.running_var.reshape(1, -1, 1, 1)
812
+ running_mean = self.running_mean.reshape(1, -1, 1, 1)
813
+ epsilon = 1e-5
814
+ scale = weight * (running_var + epsilon).rsqrt()
815
+ bias = bias - running_mean * scale
816
+ return x * scale + bias
817
+
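
As a sanity check on the frozen layer above (my own sketch, not part of the file): with the running statistics copied over, it should match `nn.BatchNorm2d` in eval mode up to the hard-coded `eps`.

```python
import torch
from torch import nn

# Minimal re-implementation of the frozen layer above, for a standalone comparison.
class FrozenBatchNorm2d(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def forward(self, x):
        weight = self.weight.reshape(1, -1, 1, 1)
        bias = self.bias.reshape(1, -1, 1, 1)
        running_var = self.running_var.reshape(1, -1, 1, 1)
        running_mean = self.running_mean.reshape(1, -1, 1, 1)
        scale = weight * (running_var + 1e-5).rsqrt()
        return x * scale + (bias - running_mean * scale)

bn = nn.BatchNorm2d(8, eps=1e-5).eval()
bn.running_mean.normal_()
bn.running_var.uniform_(0.5, 2.0)

frozen = FrozenBatchNorm2d(8)
frozen.load_state_dict({k: v for k, v in bn.state_dict().items() if k != "num_batches_tracked"})

x = torch.randn(2, 8, 4, 4)
assert torch.allclose(frozen(x), bn(x), atol=1e-5)  # same normalization, just with frozen buffers
```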
818
+
819
+ def replace_batch_norm(model):
820
+ r"""
821
+ Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrV2FrozenBatchNorm2d`.
822
+
823
+ Args:
824
+ model (torch.nn.Module):
825
+ input model
826
+ """
827
+ for name, module in model.named_children():
828
+ if isinstance(module, nn.BatchNorm2d):
829
+ new_module = RTDetrV2FrozenBatchNorm2d(module.num_features)
830
+
831
+ if module.weight.device != torch.device("meta"):
832
+ new_module.weight.data.copy_(module.weight)
833
+ new_module.bias.data.copy_(module.bias)
834
+ new_module.running_mean.data.copy_(module.running_mean)
835
+ new_module.running_var.data.copy_(module.running_var)
836
+
837
+ model._modules[name] = new_module
838
+
839
+ if len(list(module.children())) > 0:
840
+ replace_batch_norm(module)
841
+
842
+
843
+ class RTDetrV2ConvEncoder(nn.Module):
844
+ """
845
+ Convolutional backbone using the modeling_rt_detr_v2_resnet.py.
846
+
847
+ nn.BatchNorm2d layers are replaced by RTDetrV2FrozenBatchNorm2d as defined above.
848
+ https://github.com/lyuwenyu/RT-DETR/blob/main/RTDetrV2_pytorch/src/nn/backbone/presnet.py#L142
849
+ """
850
+
851
+ def __init__(self, config):
852
+ super().__init__()
853
+
854
+ backbone = load_backbone(config)
855
+
856
+ if config.freeze_backbone_batch_norms:
857
+ # replace batch norm by frozen batch norm
858
+ with torch.no_grad():
859
+ replace_batch_norm(backbone)
860
+ self.model = backbone
861
+ self.intermediate_channel_sizes = self.model.channels
862
+
863
+ def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
864
+ # send pixel_values through the model to get list of feature maps
865
+ features = self.model(pixel_values).feature_maps
866
+
867
+ out = []
868
+ for feature_map in features:
869
+ # downsample pixel_mask to match shape of corresponding feature_map
870
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
871
+ out.append((feature_map, mask))
872
+ return out
873
+
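
The backbone wrapper above downsamples the boolean `pixel_mask` with nearest-neighbor interpolation so each feature level gets a matching padding mask. A toy illustration of that one line (shapes and stride are my own assumptions):

```python
import torch
from torch import nn

pixel_mask = torch.zeros(2, 64, 64, dtype=torch.bool)
pixel_mask[0, :, :48] = True   # first image is padded on the right
pixel_mask[1] = True           # second image has no padding

feature_map = torch.randn(2, 256, 8, 8)  # a hypothetical stride-8 feature map
mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
print(mask.shape)     # torch.Size([2, 8, 8])
print(mask[0].sum())  # 48 of 64 positions remain valid for the padded image
```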
874
+
875
+ class RTDetrV2ConvNormLayer(nn.Module):
876
+ def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
877
+ super().__init__()
878
+ self.conv = nn.Conv2d(
879
+ in_channels,
880
+ out_channels,
881
+ kernel_size,
882
+ stride,
883
+ padding=(kernel_size - 1) // 2 if padding is None else padding,
884
+ bias=False,
885
+ )
886
+ self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
887
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
888
+
889
+ def forward(self, hidden_state):
890
+ hidden_state = self.conv(hidden_state)
891
+ hidden_state = self.norm(hidden_state)
892
+ hidden_state = self.activation(hidden_state)
893
+ return hidden_state
894
+
895
+
896
+ class RTDetrV2EncoderLayer(nn.Module):
897
+ def __init__(self, config: RTDetrV2Config):
898
+ super().__init__()
899
+ self.normalize_before = config.normalize_before
900
+
901
+ # self-attention
902
+ self.self_attn = RTDetrV2MultiheadAttention(
903
+ embed_dim=config.encoder_hidden_dim,
904
+ num_heads=config.num_attention_heads,
905
+ dropout=config.dropout,
906
+ )
907
+ self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
908
+ self.dropout = config.dropout
909
+ self.activation_fn = ACT2FN[config.encoder_activation_function]
910
+ self.activation_dropout = config.activation_dropout
911
+ self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim)
912
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim)
913
+ self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
914
+
915
+ def forward(
916
+ self,
917
+ hidden_states: torch.Tensor,
918
+ attention_mask: torch.Tensor,
919
+ position_embeddings: Optional[torch.Tensor] = None,
920
+ output_attentions: bool = False,
921
+ **kwargs,
922
+ ):
923
+ """
924
+ Args:
925
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
926
+ attention_mask (`torch.FloatTensor`): attention mask of size
927
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
928
+ values.
929
+ position_embeddings (`torch.FloatTensor`, *optional*):
930
+ Object queries (also called content embeddings), to be added to the hidden states.
931
+ output_attentions (`bool`, *optional*):
932
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
933
+ returned tensors for more detail.
934
+ """
935
+ residual = hidden_states
936
+ if self.normalize_before:
937
+ hidden_states = self.self_attn_layer_norm(hidden_states)
938
+
939
+ hidden_states, attn_weights = self.self_attn(
940
+ hidden_states=hidden_states,
941
+ attention_mask=attention_mask,
942
+ position_embeddings=position_embeddings,
943
+ output_attentions=output_attentions,
944
+ )
945
+
946
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
947
+ hidden_states = residual + hidden_states
948
+ if not self.normalize_before:
949
+ hidden_states = self.self_attn_layer_norm(hidden_states)
950
+
951
+ if self.normalize_before:
952
+ hidden_states = self.final_layer_norm(hidden_states)
953
+ residual = hidden_states
954
+
955
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
956
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
957
+
958
+ hidden_states = self.fc2(hidden_states)
959
+
960
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
961
+
962
+ hidden_states = residual + hidden_states
963
+ if not self.normalize_before:
964
+ hidden_states = self.final_layer_norm(hidden_states)
965
+
966
+ if self.training:
967
+ if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
968
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
969
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
970
+
971
+ outputs = (hidden_states,)
972
+
973
+ if output_attentions:
974
+ outputs += (attn_weights,)
975
+
976
+ return outputs
977
+
978
+
979
+ class RTDetrV2RepVggBlock(nn.Module):
980
+ """
981
+ RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
982
+ """
983
+
984
+ def __init__(self, config: RTDetrV2Config):
985
+ super().__init__()
986
+
987
+ activation = config.activation_function
988
+ hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion)
989
+ self.conv1 = RTDetrV2ConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1)
990
+ self.conv2 = RTDetrV2ConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0)
991
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
992
+
993
+ def forward(self, x):
994
+ y = self.conv1(x) + self.conv2(x)
995
+ return self.activation(y)
996
+
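
The RepVGG paper cited above is built on the observation that a parallel 3x3 + 1x1 branch pair can be fused into a single 3x3 convolution at inference time. A simplified sketch of that identity (my own, ignoring the BatchNorm folding that the real `RTDetrV2RepVggBlock` would also need):

```python
import torch
from torch import nn
import torch.nn.functional as F

channels = 16
conv3 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
conv1 = nn.Conv2d(channels, channels, 1, padding=0, bias=False)

x = torch.randn(2, channels, 32, 32)
y_parallel = conv3(x) + conv1(x)

# Fuse: pad the 1x1 kernel to 3x3 (value at the center) and add it to the 3x3 kernel.
fused = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
fused.weight.data = conv3.weight.data + F.pad(conv1.weight.data, [1, 1, 1, 1])
y_fused = fused(x)

assert torch.allclose(y_parallel, y_fused, atol=1e-5)
```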
997
+
998
+ class RTDetrV2CSPRepLayer(nn.Module):
999
+ """
1000
+ Cross Stage Partial (CSP) network layer with RepVGG blocks.
1001
+ """
1002
+
1003
+ def __init__(self, config: RTDetrV2Config):
1004
+ super().__init__()
1005
+
1006
+ in_channels = config.encoder_hidden_dim * 2
1007
+ out_channels = config.encoder_hidden_dim
1008
+ num_blocks = 3
1009
+ activation = config.activation_function
1010
+
1011
+ hidden_channels = int(out_channels * config.hidden_expansion)
1012
+ self.conv1 = RTDetrV2ConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
1013
+ self.conv2 = RTDetrV2ConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
1014
+ self.bottlenecks = nn.Sequential(*[RTDetrV2RepVggBlock(config) for _ in range(num_blocks)])
1015
+ if hidden_channels != out_channels:
1016
+ self.conv3 = RTDetrV2ConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
1017
+ else:
1018
+ self.conv3 = nn.Identity()
1019
+
1020
+ def forward(self, hidden_state):
1021
+ hidden_state_1 = self.conv1(hidden_state)
1022
+ hidden_state_1 = self.bottlenecks(hidden_state_1)
1023
+ hidden_state_2 = self.conv2(hidden_state)
1024
+ return self.conv3(hidden_state_1 + hidden_state_2)
1025
+
1026
+
1027
+ class RTDetrV2Encoder(nn.Module):
1028
+ def __init__(self, config: RTDetrV2Config):
1029
+ super().__init__()
1030
+
1031
+ self.layers = nn.ModuleList([RTDetrV2EncoderLayer(config) for _ in range(config.encoder_layers)])
1032
+
1033
+ def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor:
1034
+ hidden_states = src
1035
+ for layer in self.layers:
1036
+ hidden_states = layer(
1037
+ hidden_states,
1038
+ attention_mask=src_mask,
1039
+ position_embeddings=pos_embed,
1040
+ output_attentions=output_attentions,
1041
+ )
1042
+ return hidden_states
1043
+
1044
+
1045
+ class RTDetrV2HybridEncoder(nn.Module):
1046
+ """
1047
+ Hybrid encoder consisting of a projection layer, a set of `RTDetrV2Encoder`, a top-down Feature Pyramid Network
1048
+ (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://huggingface.co/papers/2304.08069
1049
+
1050
+ Args:
1051
+ config: RTDetrV2Config
1052
+ """
1053
+
1054
+ def __init__(self, config: RTDetrV2Config):
1055
+ super().__init__()
1056
+ self.config = config
1057
+ self.in_channels = config.encoder_in_channels
1058
+ self.feat_strides = config.feat_strides
1059
+ self.encoder_hidden_dim = config.encoder_hidden_dim
1060
+ self.encode_proj_layers = config.encode_proj_layers
1061
+ self.positional_encoding_temperature = config.positional_encoding_temperature
1062
+ self.eval_size = config.eval_size
1063
+ self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
1064
+ self.out_strides = self.feat_strides
1065
+ self.num_fpn_stages = len(self.in_channels) - 1
1066
+ self.num_pan_stages = len(self.in_channels) - 1
1067
+ activation = config.activation_function
1068
+
1069
+ # encoder transformer
1070
+ self.encoder = nn.ModuleList([RTDetrV2Encoder(config) for _ in range(len(self.encode_proj_layers))])
1071
+
1072
+ # top-down FPN
1073
+ self.lateral_convs = nn.ModuleList()
1074
+ self.fpn_blocks = nn.ModuleList()
1075
+ for _ in range(self.num_fpn_stages):
1076
+ lateral_conv = RTDetrV2ConvNormLayer(
1077
+ config,
1078
+ in_channels=self.encoder_hidden_dim,
1079
+ out_channels=self.encoder_hidden_dim,
1080
+ kernel_size=1,
1081
+ stride=1,
1082
+ activation=activation,
1083
+ )
1084
+ fpn_block = RTDetrV2CSPRepLayer(config)
1085
+ self.lateral_convs.append(lateral_conv)
1086
+ self.fpn_blocks.append(fpn_block)
1087
+
1088
+ # bottom-up PAN
1089
+ self.downsample_convs = nn.ModuleList()
1090
+ self.pan_blocks = nn.ModuleList()
1091
+ for _ in range(self.num_pan_stages):
1092
+ downsample_conv = RTDetrV2ConvNormLayer(
1093
+ config,
1094
+ in_channels=self.encoder_hidden_dim,
1095
+ out_channels=self.encoder_hidden_dim,
1096
+ kernel_size=3,
1097
+ stride=2,
1098
+ activation=activation,
1099
+ )
1100
+ pan_block = RTDetrV2CSPRepLayer(config)
1101
+ self.downsample_convs.append(downsample_conv)
1102
+ self.pan_blocks.append(pan_block)
1103
+
1104
+ @staticmethod
1105
+ def build_2d_sincos_position_embedding(
1106
+ width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
1107
+ ):
1108
+ grid_w = torch.arange(torch_int(width), device=device).to(dtype)
1109
+ grid_h = torch.arange(torch_int(height), device=device).to(dtype)
1110
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
1111
+ if embed_dim % 4 != 0:
1112
+ raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
1113
+ pos_dim = embed_dim // 4
1114
+ omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
1115
+ omega = 1.0 / (temperature**omega)
1116
+
1117
+ out_w = grid_w.flatten()[..., None] @ omega[None]
1118
+ out_h = grid_h.flatten()[..., None] @ omega[None]
1119
+
1120
+ return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]
1121
+
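
For the static method above, the returned embedding has one row per flattened spatial position, split into four sine/cosine bands. A self-contained shape check (same construction, rewritten without the `torch_int`/device plumbing, which are assumptions here):

```python
import torch

def build_2d_sincos_position_embedding(width, height, embed_dim=256, temperature=10000.0):
    grid_w = torch.arange(int(width), dtype=torch.float32)
    grid_h = torch.arange(int(height), dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1.0 / (temperature**omega)
    out_w = grid_w.flatten()[..., None] @ omega[None]
    out_h = grid_h.flatten()[..., None] @ omega[None]
    return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]

pos = build_2d_sincos_position_embedding(width=20, height=20, embed_dim=256)
print(pos.shape)  # torch.Size([1, 400, 256]): one embedding per flattened position
```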
1122
+ def forward(
1123
+ self,
1124
+ inputs_embeds=None,
1125
+ attention_mask=None,
1126
+ position_embeddings=None,
1127
+ spatial_shapes=None,
1128
+ level_start_index=None,
1129
+ valid_ratios=None,
1130
+ output_attentions=None,
1131
+ output_hidden_states=None,
1132
+ return_dict=None,
1133
+ ):
1134
+ r"""
1135
+ Args:
1136
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1137
+ Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
1138
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1139
+ Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
1140
+ - 1 for pixel features that are real (i.e. **not masked**),
1141
+ - 0 for pixel features that are padding (i.e. **masked**).
1142
+ [What are attention masks?](../glossary#attention-mask)
1143
+ position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1144
+ Position embeddings that are added to the queries and keys in each self-attention layer.
1145
+ spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
1146
+ Spatial shapes of each feature map.
1147
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
1148
+ Starting index of each feature map.
1149
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
1150
+ Ratio of valid area in each feature level.
1151
+ output_attentions (`bool`, *optional*):
1152
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1153
+ returned tensors for more detail.
1154
+ output_hidden_states (`bool`, *optional*):
1155
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1156
+ for more detail.
1157
+ return_dict (`bool`, *optional*):
1158
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
1159
+ """
1160
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1161
+ output_hidden_states = (
1162
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1163
+ )
1164
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1165
+
1166
+ hidden_states = inputs_embeds
1167
+
1168
+ encoder_states = () if output_hidden_states else None
1169
+ all_attentions = () if output_attentions else None
1170
+
1171
+ # encoder
1172
+ if self.config.encoder_layers > 0:
1173
+ for i, enc_ind in enumerate(self.encode_proj_layers):
1174
+ if output_hidden_states:
1175
+ encoder_states = encoder_states + (hidden_states[enc_ind],)
1176
+ height, width = hidden_states[enc_ind].shape[2:]
1177
+ # flatten [batch, channel, height, width] to [batch, height*width, channel]
1178
+ src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
1179
+ if self.training or self.eval_size is None:
1180
+ pos_embed = self.build_2d_sincos_position_embedding(
1181
+ width,
1182
+ height,
1183
+ self.encoder_hidden_dim,
1184
+ self.positional_encoding_temperature,
1185
+ device=src_flatten.device,
1186
+ dtype=src_flatten.dtype,
1187
+ )
1188
+ else:
1189
+ pos_embed = None
1190
+
1191
+ layer_outputs = self.encoder[i](
1192
+ src_flatten,
1193
+ pos_embed=pos_embed,
1194
+ output_attentions=output_attentions,
1195
+ )
1196
+ hidden_states[enc_ind] = (
1197
+ layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
1198
+ )
1199
+
1200
+ if output_attentions:
1201
+ all_attentions = all_attentions + (layer_outputs[1],)
1202
+
1203
+ if output_hidden_states:
1204
+ encoder_states = encoder_states + (hidden_states[enc_ind],)
1205
+
1206
+ # top-down FPN
1207
+ fpn_feature_maps = [hidden_states[-1]]
1208
+ for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
1209
+ backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1]
1210
+ top_fpn_feature_map = fpn_feature_maps[-1]
1211
+ # apply lateral block
1212
+ top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
1213
+ fpn_feature_maps[-1] = top_fpn_feature_map
1214
+ # apply fpn block
1215
+ top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
1216
+ fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
1217
+ new_fpn_feature_map = fpn_block(fused_feature_map)
1218
+ fpn_feature_maps.append(new_fpn_feature_map)
1219
+
1220
+ fpn_feature_maps = fpn_feature_maps[::-1]
1221
+
1222
+ # bottom-up PAN
1223
+ pan_feature_maps = [fpn_feature_maps[0]]
1224
+ for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
1225
+ top_pan_feature_map = pan_feature_maps[-1]
1226
+ fpn_feature_map = fpn_feature_maps[idx + 1]
1227
+ downsampled_feature_map = downsample_conv(top_pan_feature_map)
1228
+ fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
1229
+ new_pan_feature_map = pan_block(fused_feature_map)
1230
+ pan_feature_maps.append(new_pan_feature_map)
1231
+
1232
+ if not return_dict:
1233
+ return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None)
1234
+ return BaseModelOutput(
1235
+ last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions
1236
+ )
1237
+
1238
+
1239
+ def get_contrastive_denoising_training_group(
1240
+ targets,
1241
+ num_classes,
1242
+ num_queries,
1243
+ class_embed,
1244
+ num_denoising_queries=100,
1245
+ label_noise_ratio=0.5,
1246
+ box_noise_scale=1.0,
1247
+ ):
1248
+ """
1249
+ Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.
1250
+
1251
+ Args:
1252
+ targets (`list[dict]`):
1253
+ The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
1254
+ num_classes (`int`):
1255
+ Total number of classes in the dataset.
1256
+ num_queries (`int`):
1257
+ Number of query slots in the transformer.
1258
+ class_embed (`callable`):
1259
+ A function or a model layer to embed class labels.
1260
+ num_denoising_queries (`int`, *optional*, defaults to 100):
1261
+ Number of denoising queries.
1262
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
1263
+ Ratio of noise applied to labels.
1264
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
1265
+ Scale of noise applied to bounding boxes.
1266
+ Returns:
1267
+ `tuple` comprising various elements:
1268
+ - **input_query_class** (`torch.FloatTensor`) --
1269
+ Class queries with applied label noise.
1270
+ - **input_query_bbox** (`torch.FloatTensor`) --
1271
+ Bounding box queries with applied box noise.
1272
+ - **attn_mask** (`torch.FloatTensor`) --
1273
+ Attention mask for separating denoising and reconstruction queries.
1274
+ - **denoising_meta_values** (`dict`) --
1275
+ Metadata including denoising positive indices, number of groups, and split sizes.
1276
+ """
1277
+
1278
+ if num_denoising_queries <= 0:
1279
+ return None, None, None, None
1280
+
1281
+ num_ground_truths = [len(t["class_labels"]) for t in targets]
1282
+ device = targets[0]["class_labels"].device
1283
+
1284
+ max_gt_num = max(num_ground_truths)
1285
+ if max_gt_num == 0:
1286
+ return None, None, None, None
1287
+
1288
+ num_groups_denoising_queries = num_denoising_queries // max_gt_num
1289
+ num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
1290
+ # pad gt to max_num of a batch
1291
+ batch_size = len(num_ground_truths)
1292
+
1293
+ input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
1294
+ input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
1295
+ pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)
1296
+
1297
+ for i in range(batch_size):
1298
+ num_gt = num_ground_truths[i]
1299
+ if num_gt > 0:
1300
+ input_query_class[i, :num_gt] = targets[i]["class_labels"]
1301
+ input_query_bbox[i, :num_gt] = targets[i]["boxes"]
1302
+ pad_gt_mask[i, :num_gt] = 1
1303
+ # each group has positive and negative queries.
1304
+ input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
1305
+ input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
1306
+ pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
1307
+ # positive and negative mask
1308
+ negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
1309
+ negative_gt_mask[:, max_gt_num:] = 1
1310
+ negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
1311
+ positive_gt_mask = 1 - negative_gt_mask
1312
+ # contrastive denoising training positive index
1313
+ positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
1314
+ denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
1315
+ denoise_positive_idx = torch.split(
1316
+ denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
1317
+ )
1318
+ # total denoising queries
1319
+ num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
1320
+
1321
+ if label_noise_ratio > 0:
1322
+ mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
1323
+ # randomly put a new one here
1324
+ new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
1325
+ input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
1326
+
1327
+ if box_noise_scale > 0:
1328
+ known_bbox = center_to_corners_format(input_query_bbox)
1329
+ diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
1330
+ rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
1331
+ rand_part = torch.rand_like(input_query_bbox)
1332
+ rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
1333
+ rand_part *= rand_sign
1334
+ known_bbox += rand_part * diff
1335
+ known_bbox.clip_(min=0.0, max=1.0)
1336
+ input_query_bbox = corners_to_center_format(known_bbox)
1337
+ input_query_bbox = inverse_sigmoid(input_query_bbox)
1338
+
1339
+ input_query_class = class_embed(input_query_class)
1340
+
1341
+ target_size = num_denoising_queries + num_queries
1342
+ attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
1343
+ # match query cannot see the reconstruction
1344
+ attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
1345
+
1346
+ # reconstructions cannot see each other
1347
+ for i in range(num_groups_denoising_queries):
1348
+ idx_block_start = max_gt_num * 2 * i
1349
+ idx_block_end = max_gt_num * 2 * (i + 1)
1350
+ attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
1351
+ attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf
1352
+
1353
+ denoising_meta_values = {
1354
+ "dn_positive_idx": denoise_positive_idx,
1355
+ "dn_num_group": num_groups_denoising_queries,
1356
+ "dn_num_split": [num_denoising_queries, num_queries],
1357
+ }
1358
+
1359
+ return input_query_class, input_query_bbox, attn_mask, denoising_meta_values
1360
+
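
The most subtle part of the function above is the attention mask: matching queries may not attend to the denoising queries, and the denoising groups may not attend to each other. A toy sketch of just that mask structure, with made-up sizes (my own illustration, not a call into the library):

```python
import torch

max_gt_num, num_groups, num_queries = 2, 3, 5
num_denoising_queries = max_gt_num * 2 * num_groups   # positive + negative queries per group
target_size = num_denoising_queries + num_queries

attn_mask = torch.zeros(target_size, target_size)
# Matching queries cannot see the denoising (reconstruction) queries.
attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
# Denoising groups cannot see each other.
for i in range(num_groups):
    start, end = max_gt_num * 2 * i, max_gt_num * 2 * (i + 1)
    attn_mask[start:end, :start] = -torch.inf
    attn_mask[start:end, end:num_denoising_queries] = -torch.inf

print(attn_mask.isinf().sum())  # number of blocked (query, key) pairs
```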
1361
+
1362
+ @auto_docstring(
1363
+ custom_intro="""
1364
+ RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
1365
+ """
1366
+ )
1367
+ class RTDetrV2Model(RTDetrV2PreTrainedModel):
1368
+ def __init__(self, config: RTDetrV2Config):
1369
+ super().__init__(config)
1370
+
1371
+ # Create backbone
1372
+ self.backbone = RTDetrV2ConvEncoder(config)
1373
+ intermediate_channel_sizes = self.backbone.intermediate_channel_sizes
1374
+
1375
+ # Create encoder input projection layers
1376
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/hybrid_encoder.py#L212
1377
+ num_backbone_outs = len(intermediate_channel_sizes)
1378
+ encoder_input_proj_list = []
1379
+ for _ in range(num_backbone_outs):
1380
+ in_channels = intermediate_channel_sizes[_]
1381
+ encoder_input_proj_list.append(
1382
+ nn.Sequential(
1383
+ nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
1384
+ nn.BatchNorm2d(config.encoder_hidden_dim),
1385
+ )
1386
+ )
1387
+ self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list)
1388
+
1389
+ # Create encoder
1390
+ self.encoder = RTDetrV2HybridEncoder(config)
1391
+
1392
+ # denoising part
1393
+ if config.num_denoising > 0:
1394
+ self.denoising_class_embed = nn.Embedding(
1395
+ config.num_labels + 1, config.d_model, padding_idx=config.num_labels
1396
+ )
1397
+
1398
+ # decoder embedding
1399
+ if config.learn_initial_query:
1400
+ self.weight_embedding = nn.Embedding(config.num_queries, config.d_model)
1401
+
1402
+ # encoder head
1403
+ self.enc_output = nn.Sequential(
1404
+ nn.Linear(config.d_model, config.d_model),
1405
+ nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
1406
+ )
1407
+ self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
1408
+ self.enc_bbox_head = RTDetrV2MLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3)
1409
+
1410
+ # init encoder output anchors and valid_mask
1411
+ if config.anchor_image_size:
1412
+ self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)
1413
+
1414
+ # Create decoder input projection layers
1415
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L412
1416
+ num_backbone_outs = len(config.decoder_in_channels)
1417
+ decoder_input_proj_list = []
1418
+ for _ in range(num_backbone_outs):
1419
+ in_channels = config.decoder_in_channels[_]
1420
+ decoder_input_proj_list.append(
1421
+ nn.Sequential(
1422
+ nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
1423
+ nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
1424
+ )
1425
+ )
1426
+ for _ in range(config.num_feature_levels - num_backbone_outs):
1427
+ decoder_input_proj_list.append(
1428
+ nn.Sequential(
1429
+ nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False),
1430
+ nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
1431
+ )
1432
+ )
1433
+ in_channels = config.d_model
1434
+ self.decoder_input_proj = nn.ModuleList(decoder_input_proj_list)
1435
+ # decoder
1436
+ self.decoder = RTDetrV2Decoder(config)
1437
+
1438
+ self.post_init()
1439
+
1440
+ def get_encoder(self):
1441
+ return self.encoder
1442
+
1443
+ def freeze_backbone(self):
1444
+ for param in self.backbone.parameters():
1445
+ param.requires_grad_(False)
1446
+
1447
+ def unfreeze_backbone(self):
1448
+ for param in self.backbone.parameters():
1449
+ param.requires_grad_(True)
1450
+
1451
+ @compile_compatible_method_lru_cache(maxsize=32)
1452
+ def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
1453
+ if spatial_shapes is None:
1454
+ spatial_shapes = [
1455
+ [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
1456
+ for s in self.config.feat_strides
1457
+ ]
1458
+ anchors = []
1459
+ for level, (height, width) in enumerate(spatial_shapes):
1460
+ grid_y, grid_x = torch.meshgrid(
1461
+ torch.arange(end=height, device=device).to(dtype),
1462
+ torch.arange(end=width, device=device).to(dtype),
1463
+ indexing="ij",
1464
+ )
1465
+ grid_xy = torch.stack([grid_x, grid_y], -1)
1466
+ grid_xy = grid_xy.unsqueeze(0) + 0.5
1467
+ grid_xy[..., 0] /= width
1468
+ grid_xy[..., 1] /= height
1469
+ wh = torch.ones_like(grid_xy) * grid_size * (2.0**level)
1470
+ anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
1471
+ # define the valid range for anchor coordinates
1472
+ eps = 1e-2
1473
+ anchors = torch.concat(anchors, 1)
1474
+ valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
1475
+ anchors = torch.log(anchors / (1 - anchors))
1476
+ anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device))
1477
+
1478
+ return anchors, valid_mask
1479
+
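
`generate_anchors` above places one (center, size) anchor per position of each feature level, masks out anchors too close to the border, and stores the coordinates in inverse-sigmoid (logit) space so they can be added to the box head output. A toy sketch for a single 4x4 level under the same conventions (level index, `grid_size`, and `eps` follow the code above; everything else is illustrative):

```python
import torch

height = width = 4
grid_y, grid_x = torch.meshgrid(
    torch.arange(height, dtype=torch.float32),
    torch.arange(width, dtype=torch.float32),
    indexing="ij",
)
grid_xy = torch.stack([grid_x, grid_y], -1).unsqueeze(0) + 0.5   # cell centers
grid_xy[..., 0] /= width
grid_xy[..., 1] /= height
wh = torch.ones_like(grid_xy) * 0.05                             # grid_size at level 0
anchors = torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)

eps = 1e-2
valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)
anchors_logit = torch.log(anchors / (1 - anchors))               # inverse sigmoid, as in the model
print(anchors.shape, valid_mask.shape)                           # torch.Size([1, 16, 4]) torch.Size([1, 16, 1])
```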
1480
+ @auto_docstring
1481
+ def forward(
1482
+ self,
1483
+ pixel_values: torch.FloatTensor,
1484
+ pixel_mask: Optional[torch.LongTensor] = None,
1485
+ encoder_outputs: Optional[torch.FloatTensor] = None,
1486
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1487
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1488
+ labels: Optional[list[dict]] = None,
1489
+ output_attentions: Optional[bool] = None,
1490
+ output_hidden_states: Optional[bool] = None,
1491
+ return_dict: Optional[bool] = None,
1492
+ ) -> Union[tuple[torch.FloatTensor], RTDetrV2ModelOutput]:
1493
+ r"""
1494
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1495
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1496
+ can choose to directly pass a flattened representation of an image.
1497
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1498
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1499
+ embedded representation.
1500
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
1501
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
1502
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
1503
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
1504
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
1505
+
1506
+ Examples:
1507
+
1508
+ ```python
1509
+ >>> from transformers import AutoImageProcessor, RTDetrV2Model
1510
+ >>> from PIL import Image
1511
+ >>> import requests
1512
+
1513
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1514
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1515
+
1516
+ >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/RTDetrV2_r50vd")
1517
+ >>> model = RTDetrV2Model.from_pretrained("PekingU/RTDetrV2_r50vd")
1518
+
1519
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1520
+
1521
+ >>> outputs = model(**inputs)
1522
+
1523
+ >>> last_hidden_states = outputs.last_hidden_state
1524
+ >>> list(last_hidden_states.shape)
1525
+ [1, 300, 256]
1526
+ ```"""
1527
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1528
+ output_hidden_states = (
1529
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1530
+ )
1531
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1532
+
1533
+ batch_size, num_channels, height, width = pixel_values.shape
1534
+ device = pixel_values.device
1535
+
1536
+ if pixel_mask is None:
1537
+ pixel_mask = torch.ones(((batch_size, height, width)), device=device)
1538
+
1539
+ features = self.backbone(pixel_values, pixel_mask)
1540
+
1541
+ proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
1542
+
1543
+ if encoder_outputs is None:
1544
+ encoder_outputs = self.encoder(
1545
+ proj_feats,
1546
+ output_attentions=output_attentions,
1547
+ output_hidden_states=output_hidden_states,
1548
+ return_dict=return_dict,
1549
+ )
1550
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1551
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1552
+ encoder_outputs = BaseModelOutput(
1553
+ last_hidden_state=encoder_outputs[0],
1554
+ hidden_states=encoder_outputs[1] if output_hidden_states else None,
1555
+ attentions=encoder_outputs[2]
1556
+ if len(encoder_outputs) > 2
1557
+ else encoder_outputs[1]
1558
+ if output_attentions
1559
+ else None,
1560
+ )
1561
+
1562
+ # Equivalent to def _get_encoder_input
1563
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L412
1564
+ sources = []
1565
+ for level, source in enumerate(encoder_outputs[0]):
1566
+ sources.append(self.decoder_input_proj[level](source))
1567
+
1568
+ # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
1569
+ if self.config.num_feature_levels > len(sources):
1570
+ _len_sources = len(sources)
1571
+ sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1])
1572
+ for i in range(_len_sources + 1, self.config.num_feature_levels):
1573
+ sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1]))
1574
+
1575
+ # Prepare encoder inputs (by flattening)
1576
+ source_flatten = []
1577
+ spatial_shapes_list = []
1578
+ spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long)
1579
+ for level, source in enumerate(sources):
1580
+ height, width = source.shape[-2:]
1581
+ spatial_shapes[level, 0] = height
1582
+ spatial_shapes[level, 1] = width
1583
+ spatial_shapes_list.append((height, width))
1584
+ source = source.flatten(2).transpose(1, 2)
1585
+ source_flatten.append(source)
1586
+ source_flatten = torch.cat(source_flatten, 1)
1587
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
1588
+
1589
+ # prepare denoising training
1590
+ if self.training and self.config.num_denoising > 0 and labels is not None:
1591
+ (
1592
+ denoising_class,
1593
+ denoising_bbox_unact,
1594
+ attention_mask,
1595
+ denoising_meta_values,
1596
+ ) = get_contrastive_denoising_training_group(
1597
+ targets=labels,
1598
+ num_classes=self.config.num_labels,
1599
+ num_queries=self.config.num_queries,
1600
+ class_embed=self.denoising_class_embed,
1601
+ num_denoising_queries=self.config.num_denoising,
1602
+ label_noise_ratio=self.config.label_noise_ratio,
1603
+ box_noise_scale=self.config.box_noise_scale,
1604
+ )
1605
+ else:
1606
+ denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None
1607
+
1608
+ batch_size = len(source_flatten)
1609
+ device = source_flatten.device
1610
+ dtype = source_flatten.dtype
1611
+
1612
+ # prepare input for decoder
1613
+ if self.training or self.config.anchor_image_size is None:
1614
+ # Pass spatial_shapes as tuple to make it hashable and make sure
1615
+ # lru_cache is working for generate_anchors()
1616
+ spatial_shapes_tuple = tuple(spatial_shapes_list)
1617
+ anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
1618
+ else:
1619
+ anchors, valid_mask = self.anchors, self.valid_mask
1620
+ anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)
1621
+
1622
+ # use the valid_mask to selectively retain values in the feature map where the mask is `True`
1623
+ memory = valid_mask.to(source_flatten.dtype) * source_flatten
1624
+
1625
+ output_memory = self.enc_output(memory)
1626
+
1627
+ enc_outputs_class = self.enc_score_head(output_memory)
1628
+ enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors
1629
+
1630
+ _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1)
1631
+
1632
+ reference_points_unact = enc_outputs_coord_logits.gather(
1633
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1])
1634
+ )
1635
+
1636
+ enc_topk_bboxes = F.sigmoid(reference_points_unact)
1637
+ if denoising_bbox_unact is not None:
1638
+ reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
1639
+
1640
+ enc_topk_logits = enc_outputs_class.gather(
1641
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
1642
+ )
1643
+
1644
+ # extract region features
1645
+ if self.config.learn_initial_query:
1646
+ target = self.weight_embedding.tile([batch_size, 1, 1])
1647
+ else:
1648
+ target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
1649
+ target = target.detach()
1650
+
1651
+ if denoising_class is not None:
1652
+ target = torch.concat([denoising_class, target], 1)
1653
+
1654
+ init_reference_points = reference_points_unact.detach()
1655
+
1656
+ # decoder
1657
+ decoder_outputs = self.decoder(
1658
+ inputs_embeds=target,
1659
+ encoder_hidden_states=source_flatten,
1660
+ encoder_attention_mask=attention_mask,
1661
+ reference_points=init_reference_points,
1662
+ spatial_shapes=spatial_shapes,
1663
+ spatial_shapes_list=spatial_shapes_list,
1664
+ level_start_index=level_start_index,
1665
+ output_attentions=output_attentions,
1666
+ output_hidden_states=output_hidden_states,
1667
+ return_dict=return_dict,
1668
+ )
1669
+
1670
+ if not return_dict:
1671
+ enc_outputs = tuple(
1672
+ value
1673
+ for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits]
1674
+ if value is not None
1675
+ )
1676
+ dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values])
1677
+ tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs
1678
+
1679
+ return tuple_outputs
1680
+
1681
+ return RTDetrV2ModelOutput(
1682
+ last_hidden_state=decoder_outputs.last_hidden_state,
1683
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
1684
+ intermediate_logits=decoder_outputs.intermediate_logits,
1685
+ intermediate_reference_points=decoder_outputs.intermediate_reference_points,
1686
+ intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners,
1687
+ initial_reference_points=decoder_outputs.initial_reference_points,
1688
+ decoder_hidden_states=decoder_outputs.hidden_states,
1689
+ decoder_attentions=decoder_outputs.attentions,
1690
+ cross_attentions=decoder_outputs.cross_attentions,
1691
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1692
+ encoder_hidden_states=encoder_outputs.hidden_states,
1693
+ encoder_attentions=encoder_outputs.attentions,
1694
+ init_reference_points=init_reference_points,
1695
+ enc_topk_logits=enc_topk_logits,
1696
+ enc_topk_bboxes=enc_topk_bboxes,
1697
+ enc_outputs_class=enc_outputs_class,
1698
+ enc_outputs_coord_logits=enc_outputs_coord_logits,
1699
+ denoising_meta_values=denoising_meta_values,
1700
+ )
1701
+
1702
+
1703
+ # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
1704
+ class RTDetrV2MLPPredictionHead(nn.Module):
1705
+ """
1706
+ Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
1707
+ height and width of a bounding box w.r.t. an image.
1708
+
1709
+ Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
1710
+ Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_paddle/ppdet/modeling/transformers/utils.py#L453
1711
+
1712
+ """
1713
+
1714
+ def __init__(self, config, input_dim, d_model, output_dim, num_layers):
1715
+ super().__init__()
1716
+ self.num_layers = num_layers
1717
+ h = [d_model] * (num_layers - 1)
1718
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
1719
+
1720
+ def forward(self, x):
1721
+ for i, layer in enumerate(self.layers):
1722
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
1723
+ return x
1724
+
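
The prediction head above is a plain MLP with ReLU between the hidden layers and no activation on the output. A minimal standalone equivalent (my own sketch, dropping the unused `config` argument) showing the shapes used for box regression:

```python
import torch
from torch import nn

class MLPPredictionHead(nn.Module):
    """input_dim -> d_model -> ... -> output_dim, ReLU between hidden layers only."""

    def __init__(self, input_dim, d_model, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [d_model] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

# Box head as used in the decoder: hidden states -> 4 box values per query.
bbox_head = MLPPredictionHead(input_dim=256, d_model=256, output_dim=4, num_layers=3)
queries = torch.randn(2, 300, 256)   # (batch, num_queries, d_model)
print(bbox_head(queries).shape)      # torch.Size([2, 300, 4])
```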
1725
+
1726
+ @dataclass
1727
+ @auto_docstring(
1728
+ custom_intro="""
1729
+ Output type of [`RTDetrV2ForObjectDetection`].
1730
+ """
1731
+ )
1732
+ class RTDetrV2ObjectDetectionOutput(ModelOutput):
1733
+ r"""
1734
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
1735
+ Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
1736
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
1737
+ scale-invariant IoU loss.
1738
+ loss_dict (`Dict`, *optional*):
1739
+ A dictionary containing the individual losses. Useful for logging.
1740
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
1741
+ Classification logits (including no-object) for all queries.
1742
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
1743
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
1744
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
1745
+ possible padding). You can use [`~RTDetrV2ImageProcessor.post_process_object_detection`] to retrieve the
1746
+ unnormalized (absolute) bounding boxes.
1747
+ auxiliary_outputs (`list[Dict]`, *optional*):
1748
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
1749
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
1750
+ `pred_boxes`) for each decoder layer.
1751
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
1752
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
1753
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
1754
+ Stacked intermediate hidden states (output of each layer of the decoder).
1755
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`):
1756
+ Stacked intermediate logits (logits of each layer of the decoder).
1757
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
1758
+ Stacked intermediate reference points (reference points of each layer of the decoder).
1759
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
1760
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
1761
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
1762
+ Stacked initial reference points (initial reference points of each layer of the decoder).
1763
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
1764
+ Initial reference points sent through the Transformer decoder.
1765
+ enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
1766
+ Classification logits of the top-k region proposals selected in the encoder.
1767
+ enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
1768
+ Normalized coordinates of the top-k bounding boxes predicted by the encoder.
1769
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
1770
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
1771
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
1772
+ foreground and background).
1773
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
1774
+ Logits of predicted bounding boxes coordinates in the first stage.
1775
+ denoising_meta_values (`dict`):
1776
+ Extra dictionary holding the denoising-related values.
1777
+ """
1778
+
1779
+ loss: Optional[torch.FloatTensor] = None
1780
+ loss_dict: Optional[dict] = None
1781
+ logits: Optional[torch.FloatTensor] = None
1782
+ pred_boxes: Optional[torch.FloatTensor] = None
1783
+ auxiliary_outputs: Optional[list[dict]] = None
1784
+ last_hidden_state: Optional[torch.FloatTensor] = None
1785
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
1786
+ intermediate_logits: Optional[torch.FloatTensor] = None
1787
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
1788
+ intermediate_predicted_corners: Optional[torch.FloatTensor] = None
1789
+ initial_reference_points: Optional[torch.FloatTensor] = None
1790
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
1791
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
1792
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
1793
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
1794
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
1795
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
1796
+ init_reference_points: Optional[tuple[torch.FloatTensor]] = None
1797
+ enc_topk_logits: Optional[torch.FloatTensor] = None
1798
+ enc_topk_bboxes: Optional[torch.FloatTensor] = None
1799
+ enc_outputs_class: Optional[torch.FloatTensor] = None
1800
+ enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
1801
+ denoising_meta_values: Optional[dict] = None
1802
+
1803
+
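The docstring above notes that `pred_boxes` are normalized `(center_x, center_y, width, height)` values in `[0, 1]`. A minimal sketch of the conversion to absolute `(xmin, ymin, xmax, ymax)` coordinates, shown only to illustrate what `post_process_object_detection` does on top of score thresholding (this is not the library's implementation):

```python
import torch

def boxes_to_absolute_xyxy(pred_boxes: torch.Tensor, image_height: int, image_width: int) -> torch.Tensor:
    """Convert normalized (cx, cy, w, h) boxes to absolute (xmin, ymin, xmax, ymax)."""
    cx, cy, w, h = pred_boxes.unbind(-1)
    boxes = torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)
    scale = torch.tensor([image_width, image_height, image_width, image_height], dtype=boxes.dtype)
    return boxes * scale

# e.g. a dummy (batch, num_queries, 4) prediction for a 480x640 image
absolute = boxes_to_absolute_xyxy(torch.rand(1, 300, 4), image_height=480, image_width=640)
```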
1804
+ @auto_docstring(
1805
+ custom_intro="""
1806
+ RT-DETR Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further
1807
+ decoded into scores and classes.
1808
+ """
1809
+ )
1810
+ class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel):
1811
+ # When using clones, all layers > 0 will be clones, but layer 0 *is* required
1812
+ _tied_weights_keys = ["bbox_embed", "class_embed"]
1813
+ # We can't initialize the model on meta device as some weights are modified during the initialization
1814
+ _no_split_modules = None
1815
+
1816
+ def __init__(self, config: RTDetrV2Config):
1817
+ super().__init__(config)
1818
+ # RTDETR encoder-decoder model
1819
+ self.model = RTDetrV2Model(config)
1820
+
1821
+ # Detection heads on top
1822
+ class_embed = partial(nn.Linear, config.d_model, config.num_labels)
1823
+ bbox_embed = partial(RTDetrV2MLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3)
1824
+
1825
+ self.class_embed = nn.ModuleList([class_embed() for _ in range(config.decoder_layers)])
1826
+ self.bbox_embed = nn.ModuleList([bbox_embed() for _ in range(config.decoder_layers)])
1827
+
1828
+ self.model.decoder.class_embed = self.class_embed
1829
+ self.model.decoder.bbox_embed = self.bbox_embed
1830
+
1831
+ # Initialize weights and apply final processing
1832
+ self.post_init()
1833
+
1834
+ @torch.jit.unused
1835
+ def _set_aux_loss(self, outputs_class, outputs_coord):
1836
+ # this is a workaround to make torchscript happy, as torchscript
1837
+ # doesn't support dictionary with non-homogeneous values, such
1838
+ # as a dict having both a Tensor and a list.
1839
+ return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
1840
+
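To make the structure of `_set_aux_loss`'s return value concrete, here is a small illustration with dummy per-layer tensors (hypothetical sizes, not taken from a real forward pass):

```python
import torch

num_aux_layers, batch_size, num_queries, num_labels = 5, 1, 300, 80
outputs_class = torch.rand(num_aux_layers, batch_size, num_queries, num_labels)
outputs_coord = torch.rand(num_aux_layers, batch_size, num_queries, 4)

# One {"logits", "pred_boxes"} dict per auxiliary decoder layer.
aux = [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
print(len(aux), aux[0]["logits"].shape, aux[0]["pred_boxes"].shape)
# 5 torch.Size([1, 300, 80]) torch.Size([1, 300, 4])
```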
1841
+ @auto_docstring
1842
+ def forward(
1843
+ self,
1844
+ pixel_values: torch.FloatTensor,
1845
+ pixel_mask: Optional[torch.LongTensor] = None,
1846
+ encoder_outputs: Optional[torch.FloatTensor] = None,
1847
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1848
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1849
+ labels: Optional[list[dict]] = None,
1850
+ output_attentions: Optional[bool] = None,
1851
+ output_hidden_states: Optional[bool] = None,
1852
+ return_dict: Optional[bool] = None,
1853
+ **kwargs,
1854
+ ) -> Union[tuple[torch.FloatTensor], RTDetrV2ObjectDetectionOutput]:
1855
+ r"""
1856
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1857
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1858
+ can choose to directly pass a flattened representation of an image.
1859
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1860
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1861
+ embedded representation.
1862
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
1863
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
1864
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
1865
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
1866
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
1867
+
1868
+ Examples:
1869
+
1870
+ ```python
1871
+ >>> from transformers import RTDetrV2ImageProcessor, RTDetrV2ForObjectDetection
1872
+ >>> from PIL import Image
1873
+ >>> import requests
1874
+ >>> import torch
1875
+
1876
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1877
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1878
+
1879
+ >>> image_processor = RTDetrV2ImageProcessor.from_pretrained("PekingU/RTDetrV2_r50vd")
1880
+ >>> model = RTDetrV2ForObjectDetection.from_pretrained("PekingU/RTDetrV2_r50vd")
1881
+
1882
+ >>> # prepare image for the model
1883
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1884
+
1885
+ >>> # forward pass
1886
+ >>> outputs = model(**inputs)
1887
+
1888
+ >>> logits = outputs.logits
1889
+ >>> list(logits.shape)
1890
+ [1, 300, 80]
1891
+
1892
+ >>> boxes = outputs.pred_boxes
1893
+ >>> list(boxes.shape)
1894
+ [1, 300, 4]
1895
+
1896
+ >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
1897
+ >>> target_sizes = torch.tensor([image.size[::-1]])
1898
+ >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
1899
+ ... 0
1900
+ ... ]
1901
+
1902
+ >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
1903
+ ... box = [round(i, 2) for i in box.tolist()]
1904
+ ... print(
1905
+ ... f"Detected {model.config.id2label[label.item()]} with confidence "
1906
+ ... f"{round(score.item(), 3)} at location {box}"
1907
+ ... )
1908
+ Detected sofa with confidence 0.97 at location [0.14, 0.38, 640.13, 476.21]
1909
+ Detected cat with confidence 0.96 at location [343.38, 24.28, 640.14, 371.5]
1910
+ Detected cat with confidence 0.958 at location [13.23, 54.18, 318.98, 472.22]
1911
+ Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48]
1912
+ Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99]
1913
+ ```"""
1914
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1915
+ output_hidden_states = (
1916
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1917
+ )
1918
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1919
+
1920
+ outputs = self.model(
1921
+ pixel_values,
1922
+ pixel_mask=pixel_mask,
1923
+ encoder_outputs=encoder_outputs,
1924
+ inputs_embeds=inputs_embeds,
1925
+ decoder_inputs_embeds=decoder_inputs_embeds,
1926
+ labels=labels,
1927
+ output_attentions=output_attentions,
1928
+ output_hidden_states=output_hidden_states,
1929
+ return_dict=return_dict,
1930
+ )
1931
+
1932
+ denoising_meta_values = (
1933
+ outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None
1934
+ )
1935
+
1936
+ outputs_class = outputs.intermediate_logits if return_dict else outputs[2]
1937
+ outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3]
1938
+ predicted_corners = outputs.intermediate_predicted_corners if return_dict else outputs[4]
1939
+ initial_reference_points = outputs.initial_reference_points if return_dict else outputs[5]
1940
+
1941
+ logits = outputs_class[:, -1]
1942
+ pred_boxes = outputs_coord[:, -1]
1943
+
1944
+ loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
1945
+ if labels is not None:
1946
+ enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
1947
+ enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
1948
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
1949
+ logits,
1950
+ labels,
1951
+ self.device,
1952
+ pred_boxes,
1953
+ self.config,
1954
+ outputs_class,
1955
+ outputs_coord,
1956
+ enc_topk_logits=enc_topk_logits,
1957
+ enc_topk_bboxes=enc_topk_bboxes,
1958
+ denoising_meta_values=denoising_meta_values,
1959
+ predicted_corners=predicted_corners,
1960
+ initial_reference_points=initial_reference_points,
1961
+ **kwargs,
1962
+ )
1963
+
1964
+ if not return_dict:
1965
+ if auxiliary_outputs is not None:
1966
+ output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs
1967
+ else:
1968
+ output = (logits, pred_boxes) + outputs
1969
+ return ((loss, loss_dict) + output) if loss is not None else output
1970
+
1971
+ return RTDetrV2ObjectDetectionOutput(
1972
+ loss=loss,
1973
+ loss_dict=loss_dict,
1974
+ logits=logits,
1975
+ pred_boxes=pred_boxes,
1976
+ auxiliary_outputs=auxiliary_outputs,
1977
+ last_hidden_state=outputs.last_hidden_state,
1978
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
1979
+ intermediate_logits=outputs.intermediate_logits,
1980
+ intermediate_reference_points=outputs.intermediate_reference_points,
1981
+ intermediate_predicted_corners=outputs.intermediate_predicted_corners,
1982
+ initial_reference_points=outputs.initial_reference_points,
1983
+ decoder_hidden_states=outputs.decoder_hidden_states,
1984
+ decoder_attentions=outputs.decoder_attentions,
1985
+ cross_attentions=outputs.cross_attentions,
1986
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1987
+ encoder_hidden_states=outputs.encoder_hidden_states,
1988
+ encoder_attentions=outputs.encoder_attentions,
1989
+ init_reference_points=outputs.init_reference_points,
1990
+ enc_topk_logits=outputs.enc_topk_logits,
1991
+ enc_topk_bboxes=outputs.enc_topk_bboxes,
1992
+ enc_outputs_class=outputs.enc_outputs_class,
1993
+ enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
1994
+ denoising_meta_values=outputs.denoising_meta_values,
1995
+ )
1996
+
1997
+
1998
+ __all__ = ["RTDetrV2Model", "RTDetrV2PreTrainedModel", "RTDetrV2ForObjectDetection"]
phivenv/Lib/site-packages/transformers/models/rt_detr_v2/modular_rt_detr_v2.py ADDED
@@ -0,0 +1,636 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import warnings
16
+ from functools import partial
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch import Tensor, nn
22
+
23
+ from ...configuration_utils import PretrainedConfig
24
+ from ...utils import is_torchdynamo_compiling, logging
25
+ from ...utils.backbone_utils import (
26
+ verify_backbone_config_arguments,
27
+ )
28
+ from ..auto import CONFIG_MAPPING
29
+ from ..rt_detr.modeling_rt_detr import (
30
+ RTDetrDecoder,
31
+ RTDetrDecoderLayer,
32
+ RTDetrForObjectDetection,
33
+ RTDetrMLPPredictionHead,
34
+ RTDetrModel,
35
+ RTDetrPreTrainedModel,
36
+ )
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ class RTDetrV2Config(PretrainedConfig):
43
+ r"""
44
+ This is the configuration class to store the configuration of a [`RTDetrV2Model`]. It is used to instantiate a
45
+ RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
46
+ with the defaults will yield a similar configuration to that of the RT-DETR architecture.
47
+
48
+ e.g. [PekingU/rtdetr_r18vd](https://huggingface.co/PekingU/rtdetr_r18vd)
49
+
50
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
51
+ documentation from [`PretrainedConfig`] for more information.
52
+
53
+ Args:
54
+ initializer_range (`float`, *optional*, defaults to 0.01):
55
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
56
+ initializer_bias_prior_prob (`float`, *optional*):
57
+ The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
58
+ If `None`, `prior_prob` is computed as `prior_prob = 1 / (num_labels + 1)` when initializing the model weights.
59
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
60
+ The epsilon used by the layer normalization layers.
61
+ batch_norm_eps (`float`, *optional*, defaults to 1e-05):
62
+ The epsilon used by the batch normalization layers.
63
+ backbone_config (`Dict`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
64
+ The configuration of the backbone model.
65
+ backbone (`str`, *optional*):
66
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
67
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
68
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
69
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
70
+ Whether to use pretrained weights for the backbone.
71
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
72
+ Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
73
+ library.
74
+ freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
75
+ Whether to freeze the batch normalization layers in the backbone.
76
+ backbone_kwargs (`dict`, *optional*):
77
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
78
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
79
+ encoder_hidden_dim (`int`, *optional*, defaults to 256):
80
+ Dimension of the layers in hybrid encoder.
81
+ encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
82
+ Multi-level feature inputs for the encoder.
83
+ feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
84
+ Strides used in each feature map.
85
+ encoder_layers (`int`, *optional*, defaults to 1):
86
+ Total of layers to be used by the encoder.
87
+ encoder_ffn_dim (`int`, *optional*, defaults to 1024):
88
+ Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
89
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
90
+ Number of attention heads for each attention layer in the Transformer encoder.
91
+ dropout (`float`, *optional*, defaults to 0.0):
92
+ The ratio for all dropout layers.
93
+ activation_dropout (`float`, *optional*, defaults to 0.0):
94
+ The dropout ratio for activations inside the fully connected layer.
95
+ encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
96
+ Indexes of the projected layers to be used in the encoder.
97
+ positional_encoding_temperature (`int`, *optional*, defaults to 10000):
98
+ The temperature parameter used to create the positional encodings.
99
+ encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
100
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
101
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
102
+ activation_function (`str`, *optional*, defaults to `"silu"`):
103
+ The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
104
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
105
+ eval_size (`tuple[int, int]`, *optional*):
106
+ Height and width used to compute the effective height and width of the position embeddings after taking
107
+ into account the stride.
108
+ normalize_before (`bool`, *optional*, defaults to `False`):
109
+ Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
110
+ feed-forward modules.
111
+ hidden_expansion (`float`, *optional*, defaults to 1.0):
112
+ Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
113
+ d_model (`int`, *optional*, defaults to 256):
114
+ Dimension of the layers, excluding the hybrid encoder.
115
+ num_queries (`int`, *optional*, defaults to 300):
116
+ Number of object queries.
117
+ decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
118
+ Multi-level feature dimensions for the decoder.
119
+ decoder_ffn_dim (`int`, *optional*, defaults to 1024):
120
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
121
+ num_feature_levels (`int`, *optional*, defaults to 3):
122
+ The number of input feature levels.
123
+ decoder_n_points (`int`, *optional*, defaults to 4):
124
+ The number of sampled keys in each feature level for each attention head in the decoder.
125
+ decoder_layers (`int`, *optional*, defaults to 6):
126
+ Number of decoder layers.
127
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
128
+ Number of attention heads for each attention layer in the Transformer decoder.
129
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
130
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
131
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
132
+ attention_dropout (`float`, *optional*, defaults to 0.0):
133
+ The dropout ratio for the attention probabilities.
134
+ num_denoising (`int`, *optional*, defaults to 100):
135
+ The total number of denoising tasks or queries to be used for contrastive denoising.
136
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
137
+ The fraction of denoising labels to which random noise should be added.
138
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
139
+ Scale or magnitude of noise to be added to the bounding boxes.
140
+ learn_initial_query (`bool`, *optional*, defaults to `False`):
141
+ Indicates whether the initial query embeddings for the decoder should be learned during training
142
+ anchor_image_size (`tuple[int, int]`, *optional*):
143
+ Height and width of the input image used during evaluation to generate the bounding box anchors. If `None`, anchors are generated automatically.
144
+ with_box_refine (`bool`, *optional*, defaults to `True`):
145
+ Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
146
+ based on the predictions from the previous layer.
147
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
148
+ Whether the architecture has an encoder decoder structure.
149
+ matcher_alpha (`float`, *optional*, defaults to 0.25):
150
+ Parameter alpha used by the Hungarian Matcher.
151
+ matcher_gamma (`float`, *optional*, defaults to 2.0):
152
+ Parameter gamma used by the Hungarian Matcher.
153
+ matcher_class_cost (`float`, *optional*, defaults to 2.0):
154
+ The relative weight of the class loss used by the Hungarian Matcher.
155
+ matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
156
+ The relative weight of the bounding box loss used by the Hungarian Matcher.
157
+ matcher_giou_cost (`float`, *optional*, defaults to 2.0):
158
+ The relative weight of the GIoU loss used by the Hungarian Matcher.
159
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
160
+ Parameter informing if focal loss should be used.
161
+ auxiliary_loss (`bool`, *optional*, defaults to `True`):
162
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
163
+ focal_loss_alpha (`float`, *optional*, defaults to 0.75):
164
+ Parameter alpha used to compute the focal loss.
165
+ focal_loss_gamma (`float`, *optional*, defaults to 2.0):
166
+ Parameter gamma used to compute the focal loss.
167
+ weight_loss_vfl (`float`, *optional*, defaults to 1.0):
168
+ Relative weight of the varifocal loss in the object detection loss.
169
+ weight_loss_bbox (`float`, *optional*, defaults to 5.0):
170
+ Relative weight of the L1 bounding box loss in the object detection loss.
171
+ weight_loss_giou (`float`, *optional*, defaults to 2.0):
172
+ Relative weight of the generalized IoU loss in the object detection loss.
173
+ eos_coefficient (`float`, *optional*, defaults to 0.0001):
174
+ Relative classification weight of the 'no-object' class in the object detection loss.
175
+ decoder_n_levels (`int`, *optional*, defaults to 3):
176
+ The number of feature levels used by the decoder.
177
+ decoder_offset_scale (`float`, *optional*, defaults to 0.5):
178
+ Scaling factor applied to the attention offsets in the decoder.
179
+ decoder_method (`str`, *optional*, defaults to `"default"`):
180
+ The method to use for the decoder: `"default"` or `"discrete"`.
181
+
182
+ Examples:
183
+
184
+ ```python
185
+ >>> from transformers import RTDetrV2Config, RTDetrV2Model
186
+
187
+ >>> # Initializing a RT-DETR configuration
188
+ >>> configuration = RTDetrV2Config()
189
+
190
+ >>> # Initializing a model (with random weights) from the configuration
191
+ >>> model = RTDetrV2Model(configuration)
192
+
193
+ >>> # Accessing the model configuration
194
+ >>> configuration = model.config
195
+ ```
196
+ """
197
+
198
+ model_type = "rt_detr_v2"
199
+ layer_types = ["basic", "bottleneck"]
200
+ attribute_map = {
201
+ "hidden_size": "d_model",
202
+ "num_attention_heads": "encoder_attention_heads",
203
+ }
204
+
205
+ def __init__(
206
+ self,
207
+ initializer_range=0.01,
208
+ initializer_bias_prior_prob=None,
209
+ layer_norm_eps=1e-5,
210
+ batch_norm_eps=1e-5,
211
+ # backbone
212
+ backbone_config=None,
213
+ backbone=None,
214
+ use_pretrained_backbone=False,
215
+ use_timm_backbone=False,
216
+ freeze_backbone_batch_norms=True,
217
+ backbone_kwargs=None,
218
+ # encoder HybridEncoder
219
+ encoder_hidden_dim=256,
220
+ encoder_in_channels=[512, 1024, 2048],
221
+ feat_strides=[8, 16, 32],
222
+ encoder_layers=1,
223
+ encoder_ffn_dim=1024,
224
+ encoder_attention_heads=8,
225
+ dropout=0.0,
226
+ activation_dropout=0.0,
227
+ encode_proj_layers=[2],
228
+ positional_encoding_temperature=10000,
229
+ encoder_activation_function="gelu",
230
+ activation_function="silu",
231
+ eval_size=None,
232
+ normalize_before=False,
233
+ hidden_expansion=1.0,
234
+ # decoder RTDetrV2Transformer
235
+ d_model=256,
236
+ num_queries=300,
237
+ decoder_in_channels=[256, 256, 256],
238
+ decoder_ffn_dim=1024,
239
+ num_feature_levels=3,
240
+ decoder_n_points=4,
241
+ decoder_layers=6,
242
+ decoder_attention_heads=8,
243
+ decoder_activation_function="relu",
244
+ attention_dropout=0.0,
245
+ num_denoising=100,
246
+ label_noise_ratio=0.5,
247
+ box_noise_scale=1.0,
248
+ learn_initial_query=False,
249
+ anchor_image_size=None,
250
+ with_box_refine=True,
251
+ is_encoder_decoder=True,
252
+ # Loss
253
+ matcher_alpha=0.25,
254
+ matcher_gamma=2.0,
255
+ matcher_class_cost=2.0,
256
+ matcher_bbox_cost=5.0,
257
+ matcher_giou_cost=2.0,
258
+ use_focal_loss=True,
259
+ auxiliary_loss=True,
260
+ focal_loss_alpha=0.75,
261
+ focal_loss_gamma=2.0,
262
+ weight_loss_vfl=1.0,
263
+ weight_loss_bbox=5.0,
264
+ weight_loss_giou=2.0,
265
+ eos_coefficient=1e-4,
266
+ decoder_n_levels=3, # default value
267
+ decoder_offset_scale=0.5, # default value
268
+ decoder_method="default",
269
+ **kwargs,
270
+ ):
271
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
272
+ self.initializer_range = initializer_range
273
+ self.initializer_bias_prior_prob = initializer_bias_prior_prob
274
+ self.layer_norm_eps = layer_norm_eps
275
+ self.batch_norm_eps = batch_norm_eps
276
+ # backbone
277
+ if backbone_config is None and backbone is None:
278
+ logger.info(
279
+ "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetrV2-ResNet` backbone."
280
+ )
281
+ backbone_model_type = "rt_detr_resnet"
282
+ config_class = CONFIG_MAPPING[backbone_model_type]
283
+ # this will map it to RTDetrResNetConfig
284
+ # note: we can instead create RTDetrV2ResNetConfig but it will be exactly the same as V1
285
+ # and we would need to create RTDetrV2ResNetModel
286
+ backbone_config = config_class(
287
+ num_channels=3,
288
+ embedding_size=64,
289
+ hidden_sizes=[256, 512, 1024, 2048],
290
+ depths=[3, 4, 6, 3],
291
+ layer_type="bottleneck",
292
+ hidden_act="relu",
293
+ downsample_in_first_stage=False,
294
+ downsample_in_bottleneck=False,
295
+ out_features=None,
296
+ out_indices=[2, 3, 4],
297
+ )
298
+ elif isinstance(backbone_config, dict):
299
+ backbone_model_type = backbone_config.pop("model_type")
300
+ config_class = CONFIG_MAPPING[backbone_model_type]
301
+ backbone_config = config_class.from_dict(backbone_config)
302
+
303
+ verify_backbone_config_arguments(
304
+ use_timm_backbone=use_timm_backbone,
305
+ use_pretrained_backbone=use_pretrained_backbone,
306
+ backbone=backbone,
307
+ backbone_config=backbone_config,
308
+ backbone_kwargs=backbone_kwargs,
309
+ )
310
+
311
+ self.backbone_config = backbone_config
312
+ self.backbone = backbone
313
+ self.use_pretrained_backbone = use_pretrained_backbone
314
+ self.use_timm_backbone = use_timm_backbone
315
+ self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
316
+ self.backbone_kwargs = backbone_kwargs
317
+ # encoder
318
+ self.encoder_hidden_dim = encoder_hidden_dim
319
+ self.encoder_in_channels = encoder_in_channels
320
+ self.feat_strides = feat_strides
321
+ self.encoder_ffn_dim = encoder_ffn_dim
322
+ self.dropout = dropout
323
+ self.activation_dropout = activation_dropout
324
+ self.encode_proj_layers = encode_proj_layers
325
+ self.encoder_layers = encoder_layers
326
+ self.positional_encoding_temperature = positional_encoding_temperature
327
+ self.eval_size = eval_size
328
+ self.normalize_before = normalize_before
329
+ self.encoder_activation_function = encoder_activation_function
330
+ self.activation_function = activation_function
331
+ self.hidden_expansion = hidden_expansion
332
+ self.num_queries = num_queries
333
+ self.decoder_ffn_dim = decoder_ffn_dim
334
+ self.decoder_in_channels = decoder_in_channels
335
+ self.num_feature_levels = num_feature_levels
336
+ self.decoder_n_points = decoder_n_points
337
+ self.decoder_layers = decoder_layers
338
+ self.decoder_attention_heads = decoder_attention_heads
339
+ self.decoder_activation_function = decoder_activation_function
340
+ self.attention_dropout = attention_dropout
341
+ self.num_denoising = num_denoising
342
+ self.label_noise_ratio = label_noise_ratio
343
+ self.box_noise_scale = box_noise_scale
344
+ self.learn_initial_query = learn_initial_query
345
+ self.anchor_image_size = anchor_image_size
346
+ self.auxiliary_loss = auxiliary_loss
347
+ self.with_box_refine = with_box_refine
348
+ # Loss
349
+ self.matcher_alpha = matcher_alpha
350
+ self.matcher_gamma = matcher_gamma
351
+ self.matcher_class_cost = matcher_class_cost
352
+ self.matcher_bbox_cost = matcher_bbox_cost
353
+ self.matcher_giou_cost = matcher_giou_cost
354
+ self.use_focal_loss = use_focal_loss
355
+ self.focal_loss_alpha = focal_loss_alpha
356
+ self.focal_loss_gamma = focal_loss_gamma
357
+ self.weight_loss_vfl = weight_loss_vfl
358
+ self.weight_loss_bbox = weight_loss_bbox
359
+ self.weight_loss_giou = weight_loss_giou
360
+ self.eos_coefficient = eos_coefficient
361
+
362
+ if not hasattr(self, "d_model"):
363
+ self.d_model = d_model
364
+
365
+ if not hasattr(self, "encoder_attention_heads"):
366
+ self.encoder_attention_heads = encoder_attention_heads
367
+ # add the new attributes with the given values or defaults
368
+ self.decoder_n_levels = decoder_n_levels
369
+ self.decoder_offset_scale = decoder_offset_scale
370
+ self.decoder_method = decoder_method
371
+
372
+ @property
373
+ def sub_configs(self):
374
+ return (
375
+ {"backbone_config": type(self.backbone_config)}
376
+ if getattr(self, "backbone_config", None) is not None
377
+ else {}
378
+ )
379
+
380
+ @classmethod
381
+ def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
382
+ """Instantiate a [`RTDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model
383
+ configuration.
384
+
385
+ Args:
386
+ backbone_config ([`PretrainedConfig`]):
387
+ The backbone configuration.
388
+
389
+ Returns:
390
+ [`RTDetrV2Config`]: An instance of a configuration object
391
+ """
392
+ return cls(
393
+ backbone_config=backbone_config,
394
+ **kwargs,
395
+ )
396
+
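Since the docstring example above only shows the default configuration, here is a short sketch of overriding the v2-specific decoder arguments this class introduces; the values are illustrative, not recommended settings:

```python
from transformers import RTDetrV2Config

# Override only the v2-specific deformable-attention knobs; everything else keeps its default.
config = RTDetrV2Config(
    decoder_n_levels=3,          # feature levels used by the decoder cross-attention
    decoder_offset_scale=0.5,    # scaling applied to the predicted sampling offsets
    decoder_method="discrete",   # "default" (bilinear grid_sample) or "discrete" (integer sampling)
)
print(config.decoder_method, config.decoder_n_levels)
```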
397
+
398
+ def multi_scale_deformable_attention_v2(
399
+ value: Tensor,
400
+ value_spatial_shapes: Tensor,
401
+ sampling_locations: Tensor,
402
+ attention_weights: Tensor,
403
+ num_points_list: list[int],
404
+ method="default",
405
+ ) -> Tensor:
406
+ batch_size, _, num_heads, hidden_dim = value.shape
407
+ _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape
408
+ value_list = (
409
+ value.permute(0, 2, 3, 1)
410
+ .flatten(0, 1)
411
+ .split([height * width for height, width in value_spatial_shapes], dim=-1)
412
+ )
413
+ # sampling_offsets [8, 480, 8, 12, 2]
414
+ if method == "default":
415
+ sampling_grids = 2 * sampling_locations - 1
416
+ elif method == "discrete":
417
+ sampling_grids = sampling_locations
418
+ sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
419
+ sampling_grids = sampling_grids.split(num_points_list, dim=-2)
420
+ sampling_value_list = []
421
+ for level_id, (height, width) in enumerate(value_spatial_shapes):
422
+ # batch_size, height*width, num_heads, hidden_dim
423
+ # -> batch_size, height*width, num_heads*hidden_dim
424
+ # -> batch_size, num_heads*hidden_dim, height*width
425
+ # -> batch_size*num_heads, hidden_dim, height, width
426
+ value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width)
427
+ # batch_size, num_queries, num_heads, num_points, 2
428
+ # -> batch_size, num_heads, num_queries, num_points, 2
429
+ # -> batch_size*num_heads, num_queries, num_points, 2
430
+ sampling_grid_l_ = sampling_grids[level_id]
431
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
432
+ if method == "default":
433
+ sampling_value_l_ = nn.functional.grid_sample(
434
+ value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
435
+ )
436
+ elif method == "discrete":
437
+ sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to(
438
+ torch.int64
439
+ )
440
+
441
+ # Separate clamping for x and y coordinates
442
+ sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1)
443
+ sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1)
444
+
445
+ # Combine the clamped coordinates
446
+ sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1)
447
+ sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2)
448
+ sampling_idx = (
449
+ torch.arange(sampling_coord.shape[0], device=value.device)
450
+ .unsqueeze(-1)
451
+ .repeat(1, sampling_coord.shape[1])
452
+ )
453
+ sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]]
454
+ sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape(
455
+ batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id]
456
+ )
457
+ sampling_value_list.append(sampling_value_l_)
458
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
459
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
460
+ # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
461
+ attention_weights = attention_weights.permute(0, 2, 1, 3).reshape(
462
+ batch_size * num_heads, 1, num_queries, sum(num_points_list)
463
+ )
464
+ output = (
465
+ (torch.concat(sampling_value_list, dim=-1) * attention_weights)
466
+ .sum(-1)
467
+ .view(batch_size, num_heads * hidden_dim, num_queries)
468
+ )
469
+ return output.transpose(1, 2).contiguous()
470
+
471
+
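As a shape sanity check for the function above, the following sketch builds dummy inputs with the layouts its callers use (a flattened multi-level `value`, per-level `(height, width)` pairs, and per-level point counts). It assumes the generated `modeling_rt_detr_v2` module exposes the same function name, as the modular file implies:

```python
import torch
from transformers.models.rt_detr_v2.modeling_rt_detr_v2 import multi_scale_deformable_attention_v2

batch_size, num_heads, head_dim, num_queries = 2, 8, 32, 300
spatial_shapes_list = [(8, 8), (4, 4), (2, 2)]              # (height, width) per feature level
num_points_list = [4, 4, 4]                                 # sampling points per level
seq_len = sum(h * w for h, w in spatial_shapes_list)        # flattened encoder sequence length
total_points = sum(num_points_list)

value = torch.rand(batch_size, seq_len, num_heads, head_dim)
sampling_locations = torch.rand(batch_size, num_queries, num_heads, total_points, 2)
attention_weights = torch.rand(batch_size, num_queries, num_heads, total_points).softmax(-1)

out = multi_scale_deformable_attention_v2(
    value, spatial_shapes_list, sampling_locations, attention_weights, num_points_list
)
print(out.shape)  # torch.Size([2, 300, 256]) -> (batch, num_queries, num_heads * head_dim)
```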
472
+ # the main change
473
+ class RTDetrV2MultiscaleDeformableAttention(nn.Module):
474
+ """
475
+ RTDetrV2 version of multiscale deformable attention, extending the base implementation
476
+ with improved offset handling and initialization.
477
+ """
478
+
479
+ def __init__(self, config: RTDetrV2Config):
480
+ super().__init__()
481
+ num_heads = config.decoder_attention_heads
482
+ n_points = config.decoder_n_points
483
+
484
+ if config.d_model % num_heads != 0:
485
+ raise ValueError(
486
+ f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
487
+ )
488
+ dim_per_head = config.d_model // num_heads
489
+ # check if dim_per_head is power of 2
490
+ if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
491
+ warnings.warn(
492
+ "You'd better set embed_dim (d_model) in RTDetrV2MultiscaleDeformableAttention to make the"
493
+ " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
494
+ " implementation."
495
+ )
496
+
497
+ self.im2col_step = 64
498
+
499
+ self.d_model = config.d_model
500
+
501
+ # V2-specific attributes
502
+ self.n_levels = config.decoder_n_levels
503
+ self.n_heads = num_heads
504
+ self.n_points = n_points
505
+
506
+ self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
507
+ self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
508
+ self.value_proj = nn.Linear(config.d_model, config.d_model)
509
+ self.output_proj = nn.Linear(config.d_model, config.d_model)
510
+
511
+ self.offset_scale = config.decoder_offset_scale
512
+ self.method = config.decoder_method
513
+
514
+ # Initialize n_points list and scale
515
+ n_points_list = [self.n_points for _ in range(self.n_levels)]
516
+ self.n_points_list = n_points_list
517
+ n_points_scale = [1 / n for n in n_points_list for _ in range(n)]
518
+ self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32))
519
+
520
+ def forward(
521
+ self,
522
+ hidden_states: torch.Tensor,
523
+ attention_mask: Optional[torch.Tensor] = None,
524
+ encoder_hidden_states=None,
525
+ encoder_attention_mask=None,
526
+ position_embeddings: Optional[torch.Tensor] = None,
527
+ reference_points=None,
528
+ spatial_shapes=None,
529
+ spatial_shapes_list=None,
530
+ level_start_index=None,
531
+ output_attentions: bool = False,
532
+ ):
533
+ # Process inputs up to sampling locations calculation using parent class logic
534
+ if position_embeddings is not None:
535
+ hidden_states = hidden_states + position_embeddings
536
+
537
+ batch_size, num_queries, _ = hidden_states.shape
538
+ batch_size, sequence_length, _ = encoder_hidden_states.shape
539
+ if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
540
+ raise ValueError(
541
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
542
+ )
543
+
544
+ value = self.value_proj(encoder_hidden_states)
545
+ if attention_mask is not None:
546
+ value = value.masked_fill(~attention_mask[..., None], float(0))
547
+ value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
548
+
549
+ # V2-specific sampling offsets shape
550
+ sampling_offsets = self.sampling_offsets(hidden_states).view(
551
+ batch_size, num_queries, self.n_heads, self.n_levels * self.n_points, 2
552
+ )
553
+
554
+ attention_weights = self.attention_weights(hidden_states).view(
555
+ batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
556
+ )
557
+ attention_weights = F.softmax(attention_weights, -1)
558
+
559
+ # V2-specific sampling locations calculation
560
+ if reference_points.shape[-1] == 2:
561
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
562
+ sampling_locations = (
563
+ reference_points[:, :, None, :, None, :]
564
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
565
+ )
566
+ elif reference_points.shape[-1] == 4:
567
+ n_points_scale = self.n_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1)
568
+ offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale
569
+ sampling_locations = reference_points[:, :, None, :, :2] + offset
570
+ else:
571
+ raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
572
+
573
+ # V2-specific attention implementation choice
574
+ output = multi_scale_deformable_attention_v2(
575
+ value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method
576
+ )
577
+
578
+ output = self.output_proj(output)
579
+ return output, attention_weights
580
+
581
+
582
+ class RTDetrV2DecoderLayer(RTDetrDecoderLayer):
583
+ def __init__(self, config: RTDetrV2Config):
584
+ # initialize parent class
585
+ super().__init__(config)
586
+ # override only the encoder attention module with v2 version
587
+ self.encoder_attn = RTDetrV2MultiscaleDeformableAttention(config)
588
+
589
+
590
+ class RTDetrV2PreTrainedModel(RTDetrPreTrainedModel):
591
+ pass
592
+
593
+
594
+ class RTDetrV2Decoder(RTDetrDecoder):
595
+ def __init__(self, config: RTDetrV2Config):
596
+ super().__init__(config)
597
+ self.layers = nn.ModuleList([RTDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)])
598
+
599
+
600
+ class RTDetrV2Model(RTDetrModel):
601
+ def __init__(self, config: RTDetrV2Config):
602
+ super().__init__(config)
603
+ # decoder
604
+ self.decoder = RTDetrV2Decoder(config)
605
+
606
+
607
+ class RTDetrV2MLPPredictionHead(RTDetrMLPPredictionHead):
608
+ pass
609
+
610
+
611
+ class RTDetrV2ForObjectDetection(RTDetrForObjectDetection, RTDetrV2PreTrainedModel):
612
+ def __init__(self, config: RTDetrV2Config):
613
+ RTDetrV2PreTrainedModel.__init__(self, config)
614
+ # RTDETR encoder-decoder model
615
+ self.model = RTDetrV2Model(config)
616
+
617
+ # Detection heads on top
618
+ class_embed = partial(nn.Linear, config.d_model, config.num_labels)
619
+ bbox_embed = partial(RTDetrV2MLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3)
620
+
621
+ self.class_embed = nn.ModuleList([class_embed() for _ in range(config.decoder_layers)])
622
+ self.bbox_embed = nn.ModuleList([bbox_embed() for _ in range(config.decoder_layers)])
623
+
624
+ self.model.decoder.class_embed = self.class_embed
625
+ self.model.decoder.bbox_embed = self.bbox_embed
626
+
627
+ # Initialize weights and apply final processing
628
+ self.post_init()
629
+
630
+
631
+ __all__ = [
632
+ "RTDetrV2Config",
633
+ "RTDetrV2Model",
634
+ "RTDetrV2PreTrainedModel",
635
+ "RTDetrV2ForObjectDetection",
636
+ ]
phivenv/Lib/site-packages/transformers/models/rwkv/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_rwkv import *
22
+ from .modeling_rwkv import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
phivenv/Lib/site-packages/transformers/models/rwkv/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (522 Bytes).
 
phivenv/Lib/site-packages/transformers/models/rwkv/__pycache__/configuration_rwkv.cpython-39.pyc ADDED
Binary file (4.36 kB).
 
phivenv/Lib/site-packages/transformers/models/rwkv/__pycache__/modeling_rwkv.cpython-39.pyc ADDED
Binary file (23.5 kB).
 
phivenv/Lib/site-packages/transformers/models/rwkv/configuration_rwkv.py ADDED
@@ -0,0 +1,120 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """RWKV configuration"""
17
+
18
+ from ...configuration_utils import PretrainedConfig
19
+ from ...utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class RwkvConfig(PretrainedConfig):
26
+ """
27
+ This is the configuration class to store the configuration of a [`RwkvModel`]. It is used to instantiate a RWKV
28
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
29
+ defaults will yield a similar configuration to that of the RWKV-4
30
+ [RWKV/rwkv-4-169m-pile](https://huggingface.co/RWKV/rwkv-4-169m-pile) architecture.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 50277):
38
+ Vocabulary size of the RWKV model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`RwkvModel`].
40
+ context_length (`int`, *optional*, defaults to 1024):
41
+ The maximum sequence length that this model can be used with in a single forward (using it in RNN mode
42
+ lets you use any sequence length).
43
+ hidden_size (`int`, *optional*, defaults to 4096):
44
+ Dimensionality of the embeddings and hidden states.
45
+ num_hidden_layers (`int`, *optional*, defaults to 32):
46
+ Number of hidden layers in the model.
47
+ attention_hidden_size (`int`, *optional*):
48
+ Dimensionality of the attention hidden states. Will default to `hidden_size` if unset.
49
+ intermediate_size (`int`, *optional*):
50
+ Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset.
51
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
52
+ The epsilon to use in the layer normalization layers.
53
+ bos_token_id (`int`, *optional*, defaults to 0):
54
+ The id of the beginning of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer
55
+ as GPTNeoX.
56
+ eos_token_id (`int`, *optional*, defaults to 0):
57
+ The id of the end of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer as
58
+ GPTNeoX.
59
+ rescale_every (`int`, *optional*, defaults to 6):
60
+ At inference, the hidden states (and weights of the corresponding output layers) are divided by 2 every
61
+ `rescale_every` layer. If set to 0 or a negative number, no rescale is done.
62
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
63
+ Whether or not to tie the word embeddings with the input token embeddings.
64
+ use_cache (`bool`, *optional*, defaults to `True`):
65
+ Whether or not the model should return the last state.
66
+
67
+
68
+ Example:
69
+
70
+ ```python
71
+ >>> from transformers import RwkvConfig, RwkvModel
72
+
73
+ >>> # Initializing a Rwkv configuration
74
+ >>> configuration = RwkvConfig()
75
+
76
+ >>> # Initializing a model (with random weights) from the configuration
77
+ >>> model = RwkvModel(configuration)
78
+
79
+ >>> # Accessing the model configuration
80
+ >>> configuration = model.config
81
+ ```"""
82
+
83
+ model_type = "rwkv"
84
+ attribute_map = {"max_position_embeddings": "context_length"}
85
+
86
+ def __init__(
87
+ self,
88
+ vocab_size=50277,
89
+ context_length=1024,
90
+ hidden_size=4096,
91
+ num_hidden_layers=32,
92
+ attention_hidden_size=None,
93
+ intermediate_size=None,
94
+ layer_norm_epsilon=1e-5,
95
+ bos_token_id=0,
96
+ eos_token_id=0,
97
+ rescale_every=6,
98
+ tie_word_embeddings=False,
99
+ use_cache=True,
100
+ **kwargs,
101
+ ):
102
+ self.vocab_size = vocab_size
103
+ self.context_length = context_length
104
+ self.hidden_size = hidden_size
105
+ self.num_hidden_layers = num_hidden_layers
106
+ self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size
107
+ self.intermediate_size = intermediate_size if intermediate_size is not None else 4 * hidden_size
108
+ self.layer_norm_epsilon = layer_norm_epsilon
109
+ self.rescale_every = rescale_every
110
+ self.use_cache = use_cache
111
+
112
+ self.bos_token_id = bos_token_id
113
+ self.eos_token_id = eos_token_id
114
+
115
+ super().__init__(
116
+ tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
117
+ )
118
+
119
+
120
+ __all__ = ["RwkvConfig"]
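One practical reading of the `rescale_every` description in the docstring above: at inference, layer `i` effectively works with output weights divided by `2 ** (i // rescale_every)`. A small sketch of that arithmetic (illustrative only; the actual rescaling is applied inside `modeling_rwkv.py`):

```python
from transformers import RwkvConfig

config = RwkvConfig(num_hidden_layers=32, rescale_every=6)
# Divisor applied to each layer's output weights at inference time (per the docstring's "divided by 2 every rescale_every layer").
divisors = [2 ** (layer_id // config.rescale_every) for layer_id in range(config.num_hidden_layers)]
print(divisors[:8])   # [1, 1, 1, 1, 1, 1, 2, 2]
print(divisors[-1])   # 32 for the last layer: 2 ** (31 // 6) = 2 ** 5
```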
phivenv/Lib/site-packages/transformers/models/rwkv/modeling_rwkv.py ADDED
@@ -0,0 +1,798 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Bo Peng and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch RWKV model."""
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from pathlib import Path
21
+ from typing import Optional, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+
27
+ from ...generation import GenerationMixin
28
+ from ...modeling_layers import GradientCheckpointingLayer
29
+ from ...modeling_utils import PreTrainedModel
30
+ from ...utils import (
31
+ ModelOutput,
32
+ auto_docstring,
33
+ is_bitsandbytes_available,
34
+ is_ninja_available,
35
+ is_torch_cuda_available,
36
+ logging,
37
+ )
38
+ from .configuration_rwkv import RwkvConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ rwkv_cuda_kernel = None
45
+
46
+
47
+ def load_wkv_cuda_kernel(context_length):
48
+ from torch.utils.cpp_extension import load as load_kernel
49
+
50
+ global rwkv_cuda_kernel
51
+
52
+ kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "rwkv"
53
+ cuda_kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu", "wkv_cuda_bf16.cu"]]
54
+
55
+ # Only load the kernel if it's not been loaded yet or if we changed the context length
56
+ if rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == context_length:
57
+ return
58
+
59
+ logger.info(f"Loading CUDA kernel for RWKV at context length of {context_length}.")
60
+
61
+ flags = [
62
+ "-res-usage",
63
+ "--maxrregcount 60",
64
+ "--use_fast_math",
65
+ "-O3",
66
+ "-Xptxas -O3",
67
+ "--extra-device-vectorization",
68
+ f"-DTmax={context_length}",
69
+ ]
70
+ rwkv_cuda_kernel = load_kernel(
71
+ name=f"wkv_{context_length}",
72
+ sources=cuda_kernel_files,
73
+ verbose=(logging.get_verbosity() == logging.DEBUG),
74
+ extra_cuda_cflags=flags,
75
+ )
76
+ rwkv_cuda_kernel.max_seq_length = context_length
77
+
78
+
79
+ class RwkvLinearAttention(torch.autograd.Function):
80
+ @staticmethod
81
+ def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False):
82
+ batch_size, seq_len, hidden_size = key.size()
83
+ if seq_len > rwkv_cuda_kernel.max_seq_length:
84
+ raise ValueError(
85
+ f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of "
86
+ f"{rwkv_cuda_kernel.max_seq_length} with this model."
87
+ )
88
+ if batch_size * hidden_size % min(hidden_size, 32) != 0:
89
+ raise ValueError(
90
+ f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round "
91
+ f"multiple of {min(hidden_size, 32)}."
92
+ )
93
+
94
+ ctx.input_dtype = key.dtype
95
+
96
+ if (
97
+ time_decay.device.type != "cuda"
98
+ or time_first.device.type != "cuda"
99
+ or key.device.type != "cuda"
100
+ or value.device.type != "cuda"
101
+ ):
102
+ raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.")
103
+
104
+ time_decay = -torch.exp(time_decay.float().contiguous())
105
+ if key.dtype == torch.float16:
106
+ time_first = time_first.float()
107
+ key = key.float()
108
+ value = value.float()
109
+ time_first = time_first.contiguous()
110
+ key = key.contiguous()
111
+ value = value.contiguous()
112
+ # The CUDA kernel will fill this tensor.
113
+ output = torch.empty_like(key, memory_format=torch.contiguous_format)
114
+ if return_state or state is not None:
115
+ if state is None:
116
+ state = torch.zeros(
117
+ batch_size,
118
+ hidden_size,
119
+ 3,
120
+ dtype=torch.float32,
121
+ device=key.device,
122
+ memory_format=torch.contiguous_format,
123
+ )
124
+ state[:, :, 2] -= 1e38
125
+ else:
126
+ state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous()
127
+ if key.dtype == torch.bfloat16:
128
+ forward_func = rwkv_cuda_kernel.forward_with_state_bf16
129
+ else:
130
+ forward_func = rwkv_cuda_kernel.forward_with_state
131
+ forward_func(time_decay, time_first, key, value, output, state)
132
+ else:
133
+ forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward
134
+ forward_func(time_decay, time_first, key, value, output)
135
+
136
+ ctx.save_for_backward(time_decay, time_first, key, value, output)
137
+
138
+ if state is not None:
139
+ state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)]
140
+
141
+ return output.to(ctx.input_dtype), state
142
+
143
+ @staticmethod
144
+ # g stands for grad
145
+ def backward(ctx, g_output, g_state=None):
146
+ input_dtype = ctx.input_dtype
147
+
148
+ time_decay, time_first, key, value, output = ctx.saved_tensors
149
+ # The CUDA kernel will fill those tensors.
150
+ g_time_decay = torch.empty_like(
151
+ time_decay,
152
+ memory_format=torch.contiguous_format,
153
+ dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
154
+ )
155
+ g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format)
156
+ g_key = torch.empty_like(key, memory_format=torch.contiguous_format)
157
+ g_value = torch.empty_like(value, memory_format=torch.contiguous_format)
158
+
159
+ if input_dtype == torch.float16:
160
+ g_output = g_output.float()
161
+ backward_func = rwkv_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else rwkv_cuda_kernel.backward
162
+ backward_func(
163
+ time_decay,
164
+ time_first,
165
+ key,
166
+ value,
167
+ output,
168
+ g_output.contiguous(),
169
+ g_time_decay,
170
+ g_time_first,
171
+ g_key,
172
+ g_value,
173
+ )
174
+
175
+ return (
176
+ g_time_decay.to(input_dtype),
177
+ g_time_first.to(input_dtype),
178
+ g_key.to(input_dtype),
179
+ g_value.to(input_dtype),
180
+ None,
181
+ None,
182
+ )
183
+
184
+
185
+ def rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False):
186
+ # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
187
+ # within a `torch.no_grad()` context.
188
+ _, seq_length, _ = key.size()
189
+ output = torch.zeros_like(key)
190
+
191
+ if state is None:
192
+ num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
193
+ den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
194
+ max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
195
+ else:
196
+ num_state, den_state, max_state = state
197
+ # For numerical stability
198
+ # real_numerator_state = num_state * torch.exp(max_state)
199
+ # real_denominator_state = den_state * torch.exp(max_state)
200
+
201
+ time_decay = -torch.exp(time_decay)
202
+
203
+ for current_index in range(seq_length):
204
+ current_key = key[:, current_index].float()
205
+ current_value = value[:, current_index]
206
+
207
+ # wkv computation at time t
208
+ max_for_output = torch.maximum(max_state, current_key + time_first)
209
+ e1 = torch.exp(max_state - max_for_output)
210
+ e2 = torch.exp(current_key + time_first - max_for_output)
211
+ numerator = e1 * num_state + e2 * current_value
212
+ denominator = e1 * den_state + e2
213
+ output[:, current_index] = (numerator / denominator).to(output.dtype)
214
+
215
+ # Update state for next iteration
216
+ max_for_state = torch.maximum(max_state + time_decay, current_key)
217
+ e1 = torch.exp(max_state + time_decay - max_for_state)
218
+ e2 = torch.exp(current_key - max_for_state)
219
+ num_state = e1 * num_state + e2 * current_value
220
+ den_state = e1 * den_state + e2
221
+ max_state = max_for_state
222
+
223
+ if return_state or state is not None:
224
+ state = [num_state, den_state, max_state]
225
+
226
+ return output, state
227
+
228
+
229
+ def rwkv_linear_attention(time_decay, time_first, key, value, state=None, return_state=False):
230
+ no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value])
231
+ # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
232
+ # in this case).
233
+ one_token = key.size(1) == 1
234
+ if rwkv_cuda_kernel is None or no_cuda or one_token:
235
+ return rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state)
236
+ else:
237
+ return RwkvLinearAttention.apply(time_decay, time_first, key, value, state, return_state)
238
+
239
+
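+ # Minimal dispatch sketch (illustrative only, assuming the CUDA kernel has not been loaded, so the CPU
+ # fallback above is taken):
+ # >>> B, T, H = 2, 4, 8
+ # >>> out, state = rwkv_linear_attention(
+ # ...     torch.zeros(H), torch.zeros(H), torch.randn(B, T, H), torch.randn(B, T, H), return_state=True
+ # ... )
+ # >>> out.shape, len(state)
+ # (torch.Size([2, 4, 8]), 3)
+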
240
+ class RwkvSelfAttention(nn.Module):
241
+ def __init__(self, config, layer_id=0):
242
+ super().__init__()
243
+ self.config = config
244
+ kernel_loaded = rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == config.context_length
245
+ if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
246
+ try:
247
+ load_wkv_cuda_kernel(config.context_length)
248
+ except Exception:
249
+ logger.info("Could not load the custom CUDA kernel for RWKV attention.")
250
+ self.layer_id = layer_id
251
+ hidden_size = config.hidden_size
252
+ attention_hidden_size = (
253
+ config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size
254
+ )
255
+ self.attention_hidden_size = attention_hidden_size
256
+
257
+ self.time_decay = nn.Parameter(torch.empty(attention_hidden_size))
258
+ self.time_first = nn.Parameter(torch.empty(attention_hidden_size))
259
+
260
+ self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
261
+ self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size))
262
+ self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))
263
+
264
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
265
+ self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False)
266
+ self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False)
267
+ self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
268
+ self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
269
+
270
+ # TODO: maybe jit, otherwise move inside forward
271
+ def extract_key_value(self, hidden, state=None):
272
+ # Mix hidden with the previous timestep to produce key, value, receptance
273
+ if hidden.size(1) == 1 and state is not None:
274
+ shifted = state[1][:, :, self.layer_id]
275
+ else:
276
+ shifted = self.time_shift(hidden)
277
+ if state is not None:
278
+ shifted[:, 0] = state[1][:, :, self.layer_id]
279
+ key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
280
+ value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
281
+ receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
282
+
283
+ key = self.key(key)
284
+ value = self.value(value)
285
+ receptance = torch.sigmoid(self.receptance(receptance))
286
+ if state is not None:
287
+ state[1][:, :, self.layer_id] = hidden[:, -1]
288
+ return receptance, key, value, state
289
+
290
+ def forward(self, hidden, state=None, use_cache=False):
291
+ receptance, key, value, state = self.extract_key_value(hidden, state=state)
292
+ layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
293
+ rwkv, layer_state = rwkv_linear_attention(
294
+ self.time_decay,
295
+ self.time_first,
296
+ key,
297
+ value,
298
+ state=layer_state,
299
+ return_state=use_cache,
300
+ )
301
+
302
+ if layer_state is not None:
303
+ state[2][:, :, self.layer_id] = layer_state[0]
304
+ state[3][:, :, self.layer_id] = layer_state[1]
305
+ state[4][:, :, self.layer_id] = layer_state[2]
306
+
307
+ return self.output(receptance * rwkv), state
308
+
309
+
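+ # Token-shift sketch (illustrative): `time_shift` pads one zero step at the front of the time axis and
+ # drops the last step, so position t is mixed with position t - 1 by the `time_mix_*` coefficients.
+ # >>> x = torch.arange(3.0).view(1, 3, 1)
+ # >>> nn.ZeroPad2d((0, 0, 1, -1))(x).squeeze()
+ # tensor([0., 0., 1.])
+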
310
+ class RwkvFeedForward(nn.Module):
311
+ def __init__(self, config, layer_id=0):
312
+ super().__init__()
313
+ self.config = config
314
+ self.layer_id = layer_id
315
+ hidden_size = config.hidden_size
316
+ intermediate_size = (
317
+ config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size
318
+ )
319
+
320
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
321
+ self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
322
+ self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))
323
+
324
+ self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
325
+ self.receptance = nn.Linear(hidden_size, hidden_size, bias=False)
326
+ self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
327
+
328
+ def forward(self, hidden, state=None):
329
+ if hidden.size(1) == 1 and state is not None:
330
+ shifted = state[0][:, :, self.layer_id]
331
+ else:
332
+ shifted = self.time_shift(hidden)
333
+ if state is not None:
334
+ shifted[:, 0] = state[0][:, :, self.layer_id]
335
+ key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
336
+ receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
337
+
338
+ key = torch.square(torch.relu(self.key(key)))
339
+ value = self.value(key)
340
+ receptance = torch.sigmoid(self.receptance(receptance))
341
+
342
+ if state is not None:
343
+ state[0][:, :, self.layer_id] = hidden[:, -1]
344
+
345
+ return receptance * value, state
346
+
347
+
348
+ class RwkvBlock(GradientCheckpointingLayer):
349
+ def __init__(self, config, layer_id):
350
+ super().__init__()
351
+ self.config = config
352
+ self.layer_id = layer_id
353
+
354
+ if layer_id == 0:
355
+ self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
356
+
357
+ self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
358
+ self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
359
+
360
+ self.attention = RwkvSelfAttention(config, layer_id)
361
+ self.feed_forward = RwkvFeedForward(config, layer_id)
362
+
363
+ def forward(self, hidden, state=None, use_cache=False, output_attentions=False):
364
+ if self.layer_id == 0:
365
+ hidden = self.pre_ln(hidden)
366
+
367
+ attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache)
368
+ hidden = hidden + attention
369
+
370
+ feed_forward, state = self.feed_forward(self.ln2(hidden), state=state)
371
+ hidden = hidden + feed_forward
372
+
373
+ outputs = (hidden, state)
374
+ if output_attentions:
375
+ outputs += (attention,)
376
+ else:
377
+ outputs += (None,)
378
+
379
+ return outputs
380
+
381
+
382
+ @auto_docstring
383
+ class RwkvPreTrainedModel(PreTrainedModel):
384
+ config: RwkvConfig
385
+ base_model_prefix = "rwkv"
386
+ _no_split_modules = ["RwkvBlock"]
387
+ _keep_in_fp32_modules = ["time_decay", "time_first"]
388
+ supports_gradient_checkpointing = True
389
+ _is_stateful = True
390
+
391
+ def _init_weights(self, module: nn.Module):
392
+ """Initialize the weights."""
393
+ if isinstance(module, RwkvSelfAttention):
394
+ layer_id = module.layer_id
395
+ num_hidden_layers = module.config.num_hidden_layers
396
+ hidden_size = module.config.hidden_size
397
+ attention_hidden_size = module.attention_hidden_size
398
+
399
+ ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
400
+ ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
401
+
402
+ time_weight = torch.tensor(
403
+ [i / hidden_size for i in range(hidden_size)],
404
+ dtype=module.time_mix_key.dtype,
405
+ device=module.time_mix_key.device,
406
+ )
407
+ time_weight = time_weight[None, None, :]
408
+
409
+ decay_speed = [
410
+ -5 + 8 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
411
+ for h in range(attention_hidden_size)
412
+ ]
413
+ decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device)
414
+ zigzag = (
415
+ torch.tensor(
416
+ [(i + 1) % 3 - 1 for i in range(attention_hidden_size)],
417
+ dtype=module.time_first.dtype,
418
+ device=module.time_first.device,
419
+ )
420
+ * 0.5
421
+ )
422
+
423
+ module.time_decay.data = decay_speed
424
+ module.time_first.data = torch.ones_like(module.time_first) * math.log(0.3) + zigzag
425
+
426
+ module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
427
+ module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
428
+ module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
429
+ elif isinstance(module, RwkvFeedForward):
430
+ layer_id = module.layer_id
431
+ num_hidden_layers = module.config.num_hidden_layers
432
+ hidden_size = module.config.hidden_size
433
+
434
+ ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
435
+
436
+ time_weight = torch.tensor(
437
+ [i / hidden_size for i in range(hidden_size)],
438
+ dtype=module.time_mix_key.dtype,
439
+ device=module.time_mix_key.device,
440
+ )
441
+ time_weight = time_weight[None, None, :]
442
+
443
+ module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
444
+ module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
445
+ elif isinstance(module, nn.Linear):
446
+ shape = module.weight.data.shape
447
+ gain = 1.0
448
+ scale = 1.0 # extra scale for gain
449
+ if module.bias is not None:
450
+ module.bias.data.zero_()
451
+ if shape[0] > shape[1]:
452
+ gain = math.sqrt(shape[0] / shape[1])
453
+ if shape[0] == self.config.vocab_size and shape[1] == self.config.hidden_size: # final projection?
454
+ scale = 0.5
455
+
456
+ gain *= scale
457
+ nn.init.orthogonal_(module.weight, gain=gain)
458
+ elif isinstance(module, nn.Embedding):
459
+ shape = module.weight.data.shape
460
+ gain = 1e-4 * math.sqrt(max(shape[0], shape[1]))
461
+ nn.init.orthogonal_(module.weight, gain=gain)
462
+ elif isinstance(module, nn.LayerNorm):
463
+ module.weight.data.fill_(1.0)
464
+ module.bias.data.zero_()
465
+
466
+
467
+ @dataclass
468
+ @auto_docstring(
469
+ custom_intro="""
470
+ Class for the RWKV model outputs.
471
+ """
472
+ )
473
+ class RwkvOutput(ModelOutput):
474
+ r"""
475
+ state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
476
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
477
+ avoid providing the old `input_ids`.
478
+ """
479
+
480
+ last_hidden_state: Optional[torch.FloatTensor] = None
481
+ state: Optional[list[torch.FloatTensor]] = None
482
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
483
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
484
+
485
+
486
+ @dataclass
487
+ @auto_docstring(
488
+ custom_intro="""
489
+ Base class for causal language model (or autoregressive) outputs.
490
+ """
491
+ )
492
+ class RwkvCausalLMOutput(ModelOutput):
493
+ r"""
494
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
495
+ Language modeling loss (for next-token prediction).
496
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
497
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
498
+ state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
499
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
500
+ avoid providing the old `input_ids`.
501
+ """
502
+
503
+ loss: Optional[torch.FloatTensor] = None
504
+ logits: Optional[torch.FloatTensor] = None
505
+ state: Optional[list[torch.FloatTensor]] = None
506
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
507
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
508
+
509
+
510
+ @auto_docstring
511
+ class RwkvModel(RwkvPreTrainedModel):
512
+ def __init__(self, config):
513
+ super().__init__(config)
514
+
515
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
516
+ self.blocks = nn.ModuleList([RwkvBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
517
+ self.ln_out = nn.LayerNorm(config.hidden_size)
518
+
519
+ self.layers_are_rescaled = False
520
+
521
+ self.gradient_checkpointing = False
522
+
523
+ # Initialize weights and apply final processing
524
+ self.post_init()
525
+
526
+ def get_input_embeddings(self):
527
+ return self.embeddings
528
+
529
+ def set_input_embeddings(self, new_embeddings):
530
+ self.embeddings = new_embeddings
531
+
532
+ @auto_docstring
533
+ def forward(
534
+ self,
535
+ input_ids: Optional[torch.LongTensor] = None,
536
+ attention_mask: Optional[torch.LongTensor] = None, # noqa
537
+ inputs_embeds: Optional[torch.FloatTensor] = None,
538
+ state: Optional[list[torch.FloatTensor]] = None,
539
+ use_cache: Optional[bool] = None,
540
+ output_attentions: Optional[bool] = None,
541
+ output_hidden_states: Optional[bool] = None,
542
+ return_dict: Optional[bool] = None,
543
+ ) -> Union[tuple, RwkvOutput]:
544
+ r"""
545
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
546
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
547
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
548
+ sequence tokens in the vocabulary.
549
+
550
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
551
+ `input_ids`.
552
+
553
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
554
+ [`PreTrainedTokenizer.__call__`] for details.
555
+
556
+ [What are input IDs?](../glossary#input-ids)
557
+ state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
558
+ If passed along, the model uses the previous state in all the blocks (which will give the output for the
559
+ `input_ids` provided as if the model had `state_input_ids + input_ids` as context).
560
+ use_cache (`bool`, *optional*):
561
+ If set to `True`, the last state is returned and can be used to quickly generate the next logits.
562
+ """
563
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
564
+ output_hidden_states = (
565
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
566
+ )
567
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
568
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
569
+
570
+ if attention_mask is not None:
571
+ logger.warning_once("`attention_mask` was passed, but it is unused in this model.")
572
+
573
+ if self.training == self.layers_are_rescaled:
574
+ self._rescale_layers()
575
+
576
+ if input_ids is not None and inputs_embeds is not None:
577
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
578
+ elif input_ids is None and inputs_embeds is None:
579
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
580
+
581
+ if inputs_embeds is None:
582
+ inputs_embeds = self.embeddings(input_ids)
583
+
584
+ if use_cache and state is None:
585
+ shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers)
586
+ state = [
587
+ torch.zeros(
588
+ *shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=inputs_embeds.device
589
+ )
590
+ for i in range(5)
591
+ ]
592
+ state[4] -= 1e30
593
+
594
+ if self.gradient_checkpointing and self.training:
595
+ if use_cache:
596
+ logger.warning_once(
597
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
598
+ )
599
+ use_cache = False
600
+
601
+ hidden_states = inputs_embeds
602
+
603
+ all_self_attentions = () if output_attentions else None
604
+ all_hidden_states = () if output_hidden_states else None
605
+ for idx, block in enumerate(self.blocks):
606
+ hidden_states, state, attentions = block(
607
+ hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
608
+ )
609
+
610
+ if (
611
+ self.layers_are_rescaled
612
+ and self.config.rescale_every > 0
613
+ and (idx + 1) % self.config.rescale_every == 0
614
+ ):
615
+ hidden_states = hidden_states / 2
616
+
617
+ if output_hidden_states:
618
+ all_hidden_states = all_hidden_states + (hidden_states,)
619
+
620
+ if output_attentions:
621
+ all_self_attentions = all_self_attentions + (attentions,)
622
+
623
+ hidden_states = self.ln_out(hidden_states)
624
+
625
+ if output_hidden_states:
626
+ all_hidden_states = all_hidden_states + (hidden_states,)
627
+
628
+ if not return_dict:
629
+ return tuple(x for x in [hidden_states, state, all_hidden_states, all_self_attentions] if x is not None)
630
+
631
+ return RwkvOutput(
632
+ last_hidden_state=hidden_states,
633
+ state=state,
634
+ hidden_states=all_hidden_states,
635
+ attentions=all_self_attentions,
636
+ )
637
+
638
+ def _rescale_layers(self):
639
+ # Layers should be rescaled for inference only.
640
+ if self.layers_are_rescaled == (not self.training):
641
+ return
642
+ if self.config.rescale_every > 0:
643
+ with torch.no_grad():
644
+ for block_id, block in enumerate(self.blocks):
645
+ if self.training:
646
+ block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
647
+ block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
648
+ else:
649
+ # Deal with quantization statistics
650
+ if hasattr(block.attention.output.weight, "SCB"):
651
+ block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
652
+ block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
653
+ elif hasattr(block.attention.output.weight, "quant_state"):
654
+ self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
655
+ self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
656
+ else:
657
+ block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
658
+ block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
659
+
660
+ self.layers_are_rescaled = not self.training
661
+
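+ # Rescaling sketch (illustrative): with `rescale_every=6`, inference divides the attention-output and
+ # feed-forward value weights of blocks 0-5 by 2**0, blocks 6-11 by 2**1, blocks 12-17 by 2**2, and so on,
+ # matching the `hidden_states / 2` applied every `rescale_every` blocks in `forward`; switching back to
+ # training multiplies the same factors back in.
+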
662
+ def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
663
+ r"""
664
+ Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
665
+ be quantized again.
666
+ """
667
+ if not is_bitsandbytes_available():
668
+ raise ImportError("Please install bitsandbytes to use this method.")
669
+ import bitsandbytes as bnb
670
+
671
+ dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
672
+
673
+ dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
674
+
675
+ # re-quantize the model:
676
+ # we need to put it first on CPU then back to the device
677
+ # this will create an overhead :/
678
+ # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
679
+ # bugs with bnb
680
+ quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
681
+ setattr(target_layer, "weight", quant_weight)
682
+
683
+
684
+ @auto_docstring(
685
+ custom_intro="""
686
+ The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
687
+ embeddings).
688
+ """
689
+ )
690
+ class RwkvForCausalLM(RwkvPreTrainedModel, GenerationMixin):
691
+ _tied_weights_keys = ["head.weight"]
692
+
693
+ def __init__(self, config):
694
+ super().__init__(config)
695
+ self.rwkv = RwkvModel(config)
696
+ self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
697
+
698
+ # Initialize weights and apply final processing
699
+ self.post_init()
700
+
701
+ def get_output_embeddings(self):
702
+ return self.head
703
+
704
+ def set_output_embeddings(self, new_embeddings):
705
+ self.head = new_embeddings
706
+
707
+ def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, use_cache=None, **kwargs):
708
+ # Overwritten -- this model uses `state`, but doesn't have a cache (`past_key_values`)
709
+
710
+ # only last token for inputs_ids if the state is passed along.
711
+ if state is not None:
712
+ input_ids = input_ids[:, -1].unsqueeze(-1)
713
+
714
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
715
+ if inputs_embeds is not None and state is None:
716
+ model_inputs = {"inputs_embeds": inputs_embeds}
717
+ else:
718
+ model_inputs = {"input_ids": input_ids}
719
+
720
+ model_inputs["state"] = state
721
+ model_inputs["use_cache"] = use_cache
722
+ return model_inputs
723
+
724
+ @auto_docstring
725
+ def forward(
726
+ self,
727
+ input_ids: Optional[torch.LongTensor] = None,
728
+ attention_mask: Optional[torch.LongTensor] = None, # noqa
729
+ inputs_embeds: Optional[torch.FloatTensor] = None,
730
+ state: Optional[list[torch.FloatTensor]] = None,
731
+ labels: Optional[torch.LongTensor] = None,
732
+ use_cache: Optional[bool] = None,
733
+ output_attentions: Optional[bool] = None,
734
+ output_hidden_states: Optional[bool] = None,
735
+ return_dict: Optional[bool] = None,
736
+ **kwargs,
737
+ ) -> Union[tuple, RwkvCausalLMOutput]:
738
+ r"""
739
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
740
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
741
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
742
+ sequence tokens in the vocabulary.
743
+
744
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
745
+ `input_ids`.
746
+
747
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
748
+ [`PreTrainedTokenizer.__call__`] for details.
749
+
750
+ [What are input IDs?](../glossary#input-ids)
751
+ state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
752
+ If passed along, the model uses the previous state in all the blocks (which will give the output for the
753
+ `input_ids` provided as if the model had `state_input_ids + input_ids` as context).
754
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
755
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
756
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
757
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
758
+ use_cache (`bool`, *optional*):
759
+ If set to `True`, the last state is returned and can be used to quickly generate the next logits.
760
+ """
761
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
762
+
763
+ rwkv_outputs = self.rwkv(
764
+ input_ids,
765
+ inputs_embeds=inputs_embeds,
766
+ state=state,
767
+ use_cache=use_cache,
768
+ output_attentions=output_attentions,
769
+ output_hidden_states=output_hidden_states,
770
+ return_dict=return_dict,
771
+ )
772
+ hidden_states = rwkv_outputs[0]
773
+
774
+ logits = self.head(hidden_states)
775
+
776
+ loss = None
777
+ if labels is not None:
778
+ loss = self.loss_function(
779
+ logits,
780
+ labels,
781
+ vocab_size=self.config.vocab_size,
782
+ **kwargs,
783
+ )
784
+
785
+ if not return_dict:
786
+ output = (logits,) + rwkv_outputs[1:]
787
+ return ((loss,) + output) if loss is not None else output
788
+
789
+ return RwkvCausalLMOutput(
790
+ loss=loss,
791
+ logits=logits,
792
+ state=rwkv_outputs.state,
793
+ hidden_states=rwkv_outputs.hidden_states,
794
+ attentions=rwkv_outputs.attentions,
795
+ )
796
+
797
+
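+ # Generation sketch reusing `state` (illustrative; assumes a hosted checkpoint such as "RWKV/rwkv-4-169m-pile"):
+ # >>> from transformers import AutoTokenizer, RwkvForCausalLM
+ # >>> tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
+ # >>> model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")
+ # >>> out = model(**tokenizer("Hello", return_tensors="pt"), use_cache=True)
+ # >>> # feed only the new tokens together with the returned state
+ # >>> next_out = model(tokenizer(" world", return_tensors="pt").input_ids, state=out.state)
+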
798
+ __all__ = ["RwkvForCausalLM", "RwkvModel", "RwkvPreTrainedModel"]
phivenv/Lib/site-packages/transformers/models/sam/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_sam import *
22
+ from .image_processing_sam import *
23
+ from .image_processing_sam_fast import *
24
+ from .modeling_sam import *
25
+ from .modeling_tf_sam import *
26
+ from .processing_sam import *
27
+ else:
28
+ import sys
29
+
30
+ _file = globals()["__file__"]
31
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
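+ # Lazy-import sketch (illustrative): submodules are only imported on first attribute access, e.g.
+ # >>> from transformers.models.sam import SamConfig  # resolves configuration_sam lazily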
phivenv/Lib/site-packages/transformers/models/sam/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (641 Bytes). View file
 
phivenv/Lib/site-packages/transformers/models/sam/__pycache__/configuration_sam.cpython-39.pyc ADDED
Binary file (12.7 kB). View file
 
phivenv/Lib/site-packages/transformers/models/sam/__pycache__/image_processing_sam.cpython-39.pyc ADDED
Binary file (49.9 kB). View file
 
phivenv/Lib/site-packages/transformers/models/sam/__pycache__/image_processing_sam_fast.cpython-39.pyc ADDED
Binary file (26.4 kB). View file
 
phivenv/Lib/site-packages/transformers/models/sam/__pycache__/modeling_sam.cpython-39.pyc ADDED
Binary file (46.7 kB). View file
 
phivenv/Lib/site-packages/transformers/models/sam/__pycache__/modeling_tf_sam.cpython-39.pyc ADDED
Binary file (55.4 kB). View file
 
phivenv/Lib/site-packages/transformers/models/sam/__pycache__/processing_sam.cpython-39.pyc ADDED
Binary file (8.46 kB). View file
 
phivenv/Lib/site-packages/transformers/models/sam/configuration_sam.py ADDED
@@ -0,0 +1,337 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """SAM model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class SamPromptEncoderConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`SamPromptEncoder`]. The [`SamPromptEncoder`]
27
+ module is used to encode the input 2D points and bounding boxes. Instantiating a configuration defaults will yield
28
+ a similar configuration to that of the SAM-vit-h
29
+ [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ hidden_size (`int`, *optional*, defaults to 256):
36
+ Dimensionality of the hidden states.
37
+ image_size (`int`, *optional*, defaults to 1024):
38
+ The expected output resolution of the image.
39
+ patch_size (`int`, *optional*, defaults to 16):
40
+ The size (resolution) of each patch.
41
+ mask_input_channels (`int`, *optional*, defaults to 16):
42
+ The number of channels to be fed to the `MaskDecoder` module.
43
+ num_point_embeddings (`int`, *optional*, defaults to 4):
44
+ The number of point embeddings to be used.
45
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
46
+ The non-linear activation function in the encoder and pooler.
47
+ """
48
+
49
+ base_config_key = "prompt_encoder_config"
50
+
51
+ def __init__(
52
+ self,
53
+ hidden_size=256,
54
+ image_size=1024,
55
+ patch_size=16,
56
+ mask_input_channels=16,
57
+ num_point_embeddings=4,
58
+ hidden_act="gelu",
59
+ layer_norm_eps=1e-6,
60
+ **kwargs,
61
+ ):
62
+ super().__init__(**kwargs)
63
+ self.hidden_size = hidden_size
64
+ self.image_size = image_size
65
+ self.patch_size = patch_size
66
+ self.image_embedding_size = image_size // patch_size
67
+ self.mask_input_channels = mask_input_channels
68
+ self.num_point_embeddings = num_point_embeddings
69
+ self.hidden_act = hidden_act
70
+ self.layer_norm_eps = layer_norm_eps
71
+
72
+
73
+ class SamMaskDecoderConfig(PretrainedConfig):
74
+ r"""
75
+ This is the configuration class to store the configuration of a [`SamMaskDecoder`]. It is used to instantiate a SAM
76
+ mask decoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults
77
+ will yield a similar configuration to that of the SAM-vit-h
78
+ [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
79
+
80
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
81
+ documentation from [`PretrainedConfig`] for more information.
82
+
83
+ Args:
84
+ hidden_size (`int`, *optional*, defaults to 256):
85
+ Dimensionality of the hidden states.
86
+ hidden_act (`str`, *optional*, defaults to `"relu"`):
87
+ The non-linear activation function used inside the `SamMaskDecoder` module.
88
+ mlp_dim (`int`, *optional*, defaults to 2048):
89
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
90
+ num_hidden_layers (`int`, *optional*, defaults to 2):
91
+ Number of hidden layers in the Transformer encoder.
92
+ num_attention_heads (`int`, *optional*, defaults to 8):
93
+ Number of attention heads for each attention layer in the Transformer encoder.
94
+ attention_downsample_rate (`int`, *optional*, defaults to 2):
95
+ The downsampling rate of the attention layer.
96
+ num_multimask_outputs (`int`, *optional*, defaults to 3):
97
+ The number of outputs from the `SamMaskDecoder` module. In the Segment Anything paper, this is set to 3.
98
+ iou_head_depth (`int`, *optional*, defaults to 3):
99
+ The number of layers in the IoU head module.
100
+ iou_head_hidden_dim (`int`, *optional*, defaults to 256):
101
+ The dimensionality of the hidden states in the IoU head module.
102
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
103
+ The epsilon used by the layer normalization layers.
104
+
105
+ """
106
+
107
+ base_config_key = "mask_decoder_config"
108
+
109
+ def __init__(
110
+ self,
111
+ hidden_size=256,
112
+ hidden_act="relu",
113
+ mlp_dim=2048,
114
+ num_hidden_layers=2,
115
+ num_attention_heads=8,
116
+ attention_downsample_rate=2,
117
+ num_multimask_outputs=3,
118
+ iou_head_depth=3,
119
+ iou_head_hidden_dim=256,
120
+ layer_norm_eps=1e-6,
121
+ **kwargs,
122
+ ):
123
+ super().__init__(**kwargs)
124
+ self.hidden_size = hidden_size
125
+ self.hidden_act = hidden_act
126
+ self.mlp_dim = mlp_dim
127
+ self.num_hidden_layers = num_hidden_layers
128
+ self.num_attention_heads = num_attention_heads
129
+ self.attention_downsample_rate = attention_downsample_rate
130
+ self.num_multimask_outputs = num_multimask_outputs
131
+ self.iou_head_depth = iou_head_depth
132
+ self.iou_head_hidden_dim = iou_head_hidden_dim
133
+ self.layer_norm_eps = layer_norm_eps
134
+
135
+
136
+ class SamVisionConfig(PretrainedConfig):
137
+ r"""
138
+ This is the configuration class to store the configuration of a [`SamVisionModel`]. It is used to instantiate a SAM
139
+ vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
140
+ with the defaults will yield a similar configuration to that of the SAM ViT-h
141
+ [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
142
+
143
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
144
+ documentation from [`PretrainedConfig`] for more information.
145
+
146
+ Args:
147
+ hidden_size (`int`, *optional*, defaults to 768):
148
+ Dimensionality of the encoder layers and the pooler layer.
149
+ output_channels (`int`, *optional*, defaults to 256):
150
+ Dimensionality of the output channels in the Patch Encoder.
151
+ num_hidden_layers (`int`, *optional*, defaults to 12):
152
+ Number of hidden layers in the Transformer encoder.
153
+ num_attention_heads (`int`, *optional*, defaults to 12):
154
+ Number of attention heads for each attention layer in the Transformer encoder.
155
+ num_channels (`int`, *optional*, defaults to 3):
156
+ Number of channels in the input image.
157
+ image_size (`int`, *optional*, defaults to 1024):
158
+ Expected resolution. Target size of the resized input image.
159
+ patch_size (`int`, *optional*, defaults to 16):
160
+ Size of the patches to be extracted from the input image.
161
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
162
+ The non-linear activation function (function or string)
163
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
164
+ The epsilon used by the layer normalization layers.
165
+ attention_dropout (`float`, *optional*, defaults to 0.0):
166
+ The dropout ratio for the attention probabilities.
167
+ initializer_range (`float`, *optional*, defaults to 1e-10):
168
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
169
+ qkv_bias (`bool`, *optional*, defaults to `True`):
170
+ Whether to add a bias to query, key, value projections.
171
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
172
+ Ratio of mlp hidden dim to embedding dim.
173
+ use_abs_pos (`bool`, *optional*, defaults to `True`):
174
+ Whether to use absolute position embedding.
175
+ use_rel_pos (`bool`, *optional*, defaults to `True`):
176
+ Whether to use relative position embedding.
177
+ window_size (`int`, *optional*, defaults to 14):
178
+ Window size for relative position.
179
+ global_attn_indexes (`list[int]`, *optional*, defaults to `[2, 5, 8, 11]`):
180
+ The indexes of the global attention layers.
181
+ num_pos_feats (`int`, *optional*, defaults to 128):
182
+ The dimensionality of the position embedding.
183
+ mlp_dim (`int`, *optional*):
184
+ The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio *
185
+ hidden_size`.
186
+
187
+ Example:
188
+
189
+ ```python
190
+ >>> from transformers import (
191
+ ... SamVisionConfig,
192
+ ... SamVisionModel,
193
+ ... )
194
+
195
+ >>> # Initializing a SamVisionConfig with `"facebook/sam-vit-huge"` style configuration
196
+ >>> configuration = SamVisionConfig()
197
+
198
+ >>> # Initializing a SamVisionModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
199
+ >>> model = SamVisionModel(configuration)
200
+
201
+ >>> # Accessing the model configuration
202
+ >>> configuration = model.config
203
+ ```"""
204
+
205
+ base_config_key = "vision_config"
206
+ model_type = "sam_vision_model"
207
+
208
+ def __init__(
209
+ self,
210
+ hidden_size=768,
211
+ output_channels=256,
212
+ num_hidden_layers=12,
213
+ num_attention_heads=12,
214
+ num_channels=3,
215
+ image_size=1024,
216
+ patch_size=16,
217
+ hidden_act="gelu",
218
+ layer_norm_eps=1e-06,
219
+ attention_dropout=0.0,
220
+ initializer_range=1e-10,
221
+ qkv_bias=True,
222
+ mlp_ratio=4.0,
223
+ use_abs_pos=True,
224
+ use_rel_pos=True,
225
+ window_size=14,
226
+ global_attn_indexes=[2, 5, 8, 11],
227
+ num_pos_feats=128,
228
+ mlp_dim=None,
229
+ **kwargs,
230
+ ):
231
+ super().__init__(**kwargs)
232
+
233
+ self.hidden_size = hidden_size
234
+ self.output_channels = output_channels
235
+ self.num_hidden_layers = num_hidden_layers
236
+ self.num_attention_heads = num_attention_heads
237
+ self.num_channels = num_channels
238
+ self.image_size = image_size
239
+ self.patch_size = patch_size
240
+ self.hidden_act = hidden_act
241
+ self.layer_norm_eps = layer_norm_eps
242
+ self.attention_dropout = attention_dropout
243
+ self.initializer_range = initializer_range
244
+ self.qkv_bias = qkv_bias
245
+ self.mlp_ratio = mlp_ratio
246
+ self.use_abs_pos = use_abs_pos
247
+ self.use_rel_pos = use_rel_pos
248
+ self.window_size = window_size
249
+ self.global_attn_indexes = global_attn_indexes
250
+ self.num_pos_feats = num_pos_feats
251
+ self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim
252
+
253
+
254
+ class SamConfig(PretrainedConfig):
255
+ r"""
256
+ [`SamConfig`] is the configuration class to store the configuration of a [`SamModel`]. It is used to instantiate a
257
+ SAM model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder
258
+ configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
259
+ SAM-ViT-H [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
260
+
261
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
262
+ documentation from [`PretrainedConfig`] for more information.
263
+
264
+ Args:
265
+ vision_config (Union[`dict`, `SamVisionConfig`], *optional*):
266
+ Dictionary of configuration options used to initialize [`SamVisionConfig`].
267
+ prompt_encoder_config (Union[`dict`, `SamPromptEncoderConfig`], *optional*):
268
+ Dictionary of configuration options used to initialize [`SamPromptEncoderConfig`].
269
+ mask_decoder_config (Union[`dict`, `SamMaskDecoderConfig`], *optional*):
270
+ Dictionary of configuration options used to initialize [`SamMaskDecoderConfig`].
271
+
272
+ kwargs (*optional*):
273
+ Dictionary of keyword arguments.
274
+
275
+ Example:
276
+
277
+ ```python
278
+ >>> from transformers import (
279
+ ... SamVisionConfig,
280
+ ... SamPromptEncoderConfig,
281
+ ... SamMaskDecoderConfig,
282
+ ... SamModel,
283
+ ... )
284
+
285
+ >>> # Initializing a SamConfig with `"facebook/sam-vit-huge"` style configuration
286
+ >>> configuration = SamConfig()
287
+
288
+ >>> # Initializing a SamModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
289
+ >>> model = SamModel(configuration)
290
+
291
+ >>> # Accessing the model configuration
292
+ >>> configuration = model.config
293
+
294
+ >>> # We can also initialize a SamConfig from a SamVisionConfig, SamPromptEncoderConfig, and SamMaskDecoderConfig
295
+
296
+ >>> # Initializing SAM vision, prompt encoder and mask decoder configurations
297
+ >>> vision_config = SamVisionConfig()
298
+ >>> prompt_encoder_config = SamPromptEncoderConfig()
299
+ >>> mask_decoder_config = SamMaskDecoderConfig()
300
+
301
+ >>> config = SamConfig(vision_config, prompt_encoder_config, mask_decoder_config)
302
+ ```"""
303
+
304
+ model_type = "sam"
305
+ sub_configs = {
306
+ "prompt_encoder_config": SamPromptEncoderConfig,
307
+ "mask_decoder_config": SamMaskDecoderConfig,
308
+ "vision_config": SamVisionConfig,
309
+ }
310
+
311
+ def __init__(
312
+ self,
313
+ vision_config=None,
314
+ prompt_encoder_config=None,
315
+ mask_decoder_config=None,
316
+ initializer_range=0.02,
317
+ **kwargs,
318
+ ):
319
+ super().__init__(**kwargs)
320
+ vision_config = vision_config if vision_config is not None else {}
321
+ prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
322
+ mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
323
+
324
+ if isinstance(vision_config, SamVisionConfig):
325
+ vision_config = vision_config.to_dict()
326
+ if isinstance(prompt_encoder_config, SamPromptEncoderConfig):
327
+ prompt_encoder_config = prompt_encoder_config.to_dict()
328
+ if isinstance(mask_decoder_config, SamMaskDecoderConfig):
329
+ mask_decoder_config = mask_decoder_config.to_dict()
330
+
331
+ self.vision_config = SamVisionConfig(**vision_config)
332
+ self.prompt_encoder_config = SamPromptEncoderConfig(**prompt_encoder_config)
333
+ self.mask_decoder_config = SamMaskDecoderConfig(**mask_decoder_config)
334
+ self.initializer_range = initializer_range
335
+
336
+
337
+ __all__ = ["SamConfig", "SamMaskDecoderConfig", "SamPromptEncoderConfig", "SamVisionConfig"]
phivenv/Lib/site-packages/transformers/models/sam/image_processing_sam.py ADDED
@@ -0,0 +1,1499 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for SAM."""
16
+
17
+ import math
18
+ from copy import deepcopy
19
+ from itertools import product
20
+ from typing import Any, Optional, Union
21
+
22
+ import numpy as np
23
+
24
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
25
+ from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
26
+ from ...image_utils import (
27
+ IMAGENET_DEFAULT_MEAN,
28
+ IMAGENET_DEFAULT_STD,
29
+ ChannelDimension,
30
+ ImageInput,
31
+ PILImageResampling,
32
+ get_image_size,
33
+ infer_channel_dimension_format,
34
+ is_scaled_image,
35
+ make_list_of_images,
36
+ to_numpy_array,
37
+ valid_images,
38
+ validate_preprocess_arguments,
39
+ )
40
+ from ...utils import (
41
+ TensorType,
42
+ filter_out_non_signature_kwargs,
43
+ is_tf_available,
44
+ is_torch_available,
45
+ is_torchvision_available,
46
+ logging,
47
+ requires_backends,
48
+ )
49
+
50
+
51
+ if is_torch_available():
52
+ import torch
53
+ import torch.nn.functional as F
54
+
55
+ if is_torchvision_available():
56
+ from torchvision.ops.boxes import batched_nms
57
+
58
+ if is_tf_available():
59
+ import tensorflow as tf
60
+ from tensorflow.experimental import numpy as tnp
61
+
62
+ from ...tf_utils import flatten, shape_list
63
+
64
+ logger = logging.get_logger(__name__)
65
+
66
+
67
+ class SamImageProcessor(BaseImageProcessor):
68
+ r"""
69
+ Constructs a SAM image processor.
70
+
71
+ Args:
72
+ do_resize (`bool`, *optional*, defaults to `True`):
73
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
74
+ `do_resize` parameter in the `preprocess` method.
75
+ size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`):
76
+ Size of the output image after resizing. Resizes the longest edge of the image to match
77
+ `size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `size` parameter in the
78
+ `preprocess` method.
79
+ mask_size (`dict`, *optional*, defaults to `{"longest_edge": 256}`):
80
+ Size of the output segmentation map after resizing. Resizes the longest edge of the image to match
81
+ `size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `mask_size` parameter
82
+ in the `preprocess` method.
83
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
84
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
85
+ `preprocess` method.
86
+ do_rescale (`bool`, *optional*, defaults to `True`):
87
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
88
+ `do_rescale` parameter in the `preprocess` method.
89
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
90
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
91
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
92
+ do_normalize (`bool`, *optional*, defaults to `True`):
93
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
94
+ method.
95
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
96
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
97
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
99
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
100
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
101
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
103
+ do_pad (`bool`, *optional*, defaults to `True`):
104
+ Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the
105
+ `preprocess` method.
106
+ pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
107
+ Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess`
108
+ method.
109
+ mask_pad_size (`dict`, *optional*, defaults to `{"height": 256, "width": 256}`):
110
+ Size of the output segmentation map after padding. Can be overridden by the `mask_pad_size` parameter in
111
+ the `preprocess` method.
112
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
113
+ Whether to convert the image to RGB.
114
+ """
115
+
116
+ model_input_names = ["pixel_values"]
117
+
118
+ def __init__(
119
+ self,
120
+ do_resize: bool = True,
121
+ size: Optional[dict[str, int]] = None,
122
+ mask_size: Optional[dict[str, int]] = None,
123
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
124
+ do_rescale: bool = True,
125
+ rescale_factor: Union[int, float] = 1 / 255,
126
+ do_normalize: bool = True,
127
+ image_mean: Optional[Union[float, list[float]]] = None,
128
+ image_std: Optional[Union[float, list[float]]] = None,
129
+ do_pad: bool = True,
130
+ pad_size: Optional[dict[str, int]] = None,
131
+ mask_pad_size: Optional[dict[str, int]] = None,
132
+ do_convert_rgb: bool = True,
133
+ **kwargs,
134
+ ) -> None:
135
+ super().__init__(**kwargs)
136
+ size = size if size is not None else {"longest_edge": 1024}
137
+ size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
138
+
139
+ pad_size = pad_size if pad_size is not None else {"height": 1024, "width": 1024}
140
+ pad_size = get_size_dict(pad_size, default_to_square=True)
141
+
142
+ mask_size = mask_size if mask_size is not None else {"longest_edge": 256}
143
+ mask_size = (
144
+ get_size_dict(max_size=mask_size, default_to_square=False)
145
+ if not isinstance(mask_size, dict)
146
+ else mask_size
147
+ )
148
+
149
+ mask_pad_size = mask_pad_size if mask_pad_size is not None else {"height": 256, "width": 256}
150
+ mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)
151
+
152
+ self.do_resize = do_resize
153
+ self.size = size
154
+ self.mask_size = mask_size
155
+ self.resample = resample
156
+ self.do_rescale = do_rescale
157
+ self.rescale_factor = rescale_factor
158
+ self.do_normalize = do_normalize
159
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
160
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
161
+ self.do_pad = do_pad
162
+ self.pad_size = pad_size
163
+ self.mask_pad_size = mask_pad_size
164
+ self.do_convert_rgb = do_convert_rgb
165
+
166
+ def pad_image(
167
+ self,
168
+ image: np.ndarray,
169
+ pad_size: dict[str, int],
170
+ data_format: Optional[Union[str, ChannelDimension]] = None,
171
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
172
+ **kwargs,
173
+ ) -> np.ndarray:
174
+ """
175
+ Pad an image to `(pad_size["height"], pad_size["width"])` with zeros to the right and bottom.
176
+
177
+ Args:
178
+ image (`np.ndarray`):
179
+ Image to pad.
180
+ pad_size (`dict[str, int]`):
181
+ Size of the output image after padding.
182
+ data_format (`str` or `ChannelDimension`, *optional*):
183
+ The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
184
+ `data_format` of the `image` will be used.
185
+ input_data_format (`str` or `ChannelDimension`, *optional*):
186
+ The channel dimension format of the input image. If not provided, it will be inferred.
187
+ """
188
+ output_height, output_width = pad_size["height"], pad_size["width"]
189
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
190
+
191
+ pad_width = output_width - input_width
192
+ pad_height = output_height - input_height
193
+
194
+ padded_image = pad(
195
+ image,
196
+ ((0, pad_height), (0, pad_width)),
197
+ data_format=data_format,
198
+ input_data_format=input_data_format,
199
+ **kwargs,
200
+ )
201
+ return padded_image
202
+
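As a quick illustration of the convention implemented by `pad_image` above, padding is only ever added at the bottom and right, so the original content stays anchored at the top-left corner. A minimal sketch (editorial, not part of the library file), assuming `SamImageProcessor` is importable from `transformers` and using a random channels-first array in place of a real image:

import numpy as np
from transformers import SamImageProcessor

processor = SamImageProcessor()
image = np.random.rand(3, 700, 900).astype(np.float32)  # (channels, height, width)
padded = processor.pad_image(
    image,
    pad_size={"height": 1024, "width": 1024},
    input_data_format="channels_first",
)
# 324 zero rows are appended at the bottom and 124 zero columns on the right.
assert padded.shape == (3, 1024, 1024)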
203
+ def _get_preprocess_shape(self, old_shape: tuple[int, int], longest_edge: int):
204
+ """
205
+ Compute the output size given input size and target long side length.
206
+ """
207
+ oldh, oldw = old_shape
208
+ scale = longest_edge * 1.0 / max(oldh, oldw)
209
+ newh, neww = oldh * scale, oldw * scale
210
+ newh = int(newh + 0.5)
211
+ neww = int(neww + 0.5)
212
+ return (newh, neww)
213
+
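For example, with the default `longest_edge` of 1024, a 480x640 input is scaled by 1024 / 640 = 1.6 and rounded half-up, giving an output shape of (768, 1024). A minimal sketch of the arithmetic used above:

# Same rounding rule as _get_preprocess_shape above.
old_height, old_width = 480, 640
scale = 1024 / max(old_height, old_width)  # 1.6
new_shape = (int(old_height * scale + 0.5), int(old_width * scale + 0.5))
print(new_shape)  # (768, 1024)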
214
+ def resize(
215
+ self,
216
+ image: np.ndarray,
217
+ size: dict[str, int],
218
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
219
+ data_format: Optional[Union[str, ChannelDimension]] = None,
220
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
221
+ **kwargs,
222
+ ) -> np.ndarray:
223
+ """
224
+ Resize an image to `(size["height"], size["width"])`.
225
+
226
+ Args:
227
+ image (`np.ndarray`):
228
+ Image to resize.
229
+ size (`dict[str, int]`):
230
+ Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest
231
+ edge of the image will be resized to the specified size, while the other edge will be resized to
232
+ maintain the aspect ratio.
233
+ resample:
234
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
235
+ data_format (`ChannelDimension` or `str`, *optional*):
236
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
237
+ image is used. Can be one of:
238
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
239
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
240
+ input_data_format (`ChannelDimension` or `str`, *optional*):
241
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
242
+ from the input image. Can be one of:
243
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
244
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
245
+
246
+ Returns:
247
+ `np.ndarray`: The resized image.
248
+ """
249
+ size = get_size_dict(size)
250
+ if "longest_edge" not in size:
251
+ raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
252
+ input_size = get_image_size(image, channel_dim=input_data_format)
253
+ output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])
254
+ return resize(
255
+ image,
256
+ size=(output_height, output_width),
257
+ resample=resample,
258
+ data_format=data_format,
259
+ input_data_format=input_data_format,
260
+ **kwargs,
261
+ )
262
+
263
+ def _preprocess(
264
+ self,
265
+ image: ImageInput,
266
+ do_resize: bool,
267
+ do_rescale: bool,
268
+ do_normalize: bool,
269
+ size: Optional[dict[str, int]] = None,
270
+ resample: PILImageResampling = None,
271
+ rescale_factor: Optional[float] = None,
272
+ image_mean: Optional[Union[float, list[float]]] = None,
273
+ image_std: Optional[Union[float, list[float]]] = None,
274
+ do_pad: Optional[bool] = None,
275
+ pad_size: Optional[dict[str, int]] = None,
276
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
277
+ ):
278
+ if do_resize:
279
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
280
+ reshaped_input_size = get_image_size(image, channel_dim=input_data_format)
281
+
282
+ if do_rescale:
283
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
284
+
285
+ if do_normalize:
286
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
287
+
288
+ if do_pad:
289
+ image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format)
290
+
291
+ return image, reshaped_input_size
292
+
293
+ def _preprocess_image(
294
+ self,
295
+ image: ImageInput,
296
+ do_resize: Optional[bool] = None,
297
+ size: Optional[dict[str, int]] = None,
298
+ resample: PILImageResampling = None,
299
+ do_rescale: Optional[bool] = None,
300
+ rescale_factor: Optional[float] = None,
301
+ do_normalize: Optional[bool] = None,
302
+ image_mean: Optional[Union[float, list[float]]] = None,
303
+ image_std: Optional[Union[float, list[float]]] = None,
304
+ do_pad: Optional[bool] = None,
305
+ pad_size: Optional[dict[str, int]] = None,
306
+ do_convert_rgb: Optional[bool] = None,
307
+ data_format: Optional[Union[str, ChannelDimension]] = None,
308
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
309
+ ) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
310
+ # PIL RGBA images are converted to RGB
311
+ if do_convert_rgb:
312
+ image = convert_to_rgb(image)
313
+
314
+ # All transformations expect numpy arrays.
315
+ image = to_numpy_array(image)
316
+
317
+ if do_rescale and is_scaled_image(image):
318
+ logger.warning_once(
319
+ "It looks like you are trying to rescale already rescaled images. If the input"
320
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
321
+ )
322
+
323
+ if input_data_format is None:
324
+ input_data_format = infer_channel_dimension_format(image)
325
+
326
+ original_size = get_image_size(image, channel_dim=input_data_format)
327
+
328
+ image, reshaped_input_size = self._preprocess(
329
+ image=image,
330
+ do_resize=do_resize,
331
+ size=size,
332
+ resample=resample,
333
+ do_rescale=do_rescale,
334
+ rescale_factor=rescale_factor,
335
+ do_normalize=do_normalize,
336
+ image_mean=image_mean,
337
+ image_std=image_std,
338
+ do_pad=do_pad,
339
+ pad_size=pad_size,
340
+ input_data_format=input_data_format,
341
+ )
342
+
343
+ if data_format is not None:
344
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
345
+
346
+ return image, original_size, reshaped_input_size
347
+
348
+ def _preprocess_mask(
349
+ self,
350
+ segmentation_map: ImageInput,
351
+ do_resize: Optional[bool] = None,
352
+ mask_size: Optional[dict[str, int]] = None,
353
+ do_pad: Optional[bool] = None,
354
+ mask_pad_size: Optional[dict[str, int]] = None,
355
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
356
+ ) -> np.ndarray:
357
+ segmentation_map = to_numpy_array(segmentation_map)
358
+
359
+ # Add channel dimension if missing - needed for certain transformations
360
+ if segmentation_map.ndim == 2:
361
+ added_channel_dim = True
362
+ segmentation_map = segmentation_map[None, ...]
363
+ input_data_format = ChannelDimension.FIRST
364
+ else:
365
+ added_channel_dim = False
366
+ if input_data_format is None:
367
+ input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
368
+
369
+ original_size = get_image_size(segmentation_map, channel_dim=input_data_format)
370
+
371
+ segmentation_map, _ = self._preprocess(
372
+ image=segmentation_map,
373
+ do_resize=do_resize,
374
+ size=mask_size,
375
+ resample=PILImageResampling.NEAREST,
376
+ do_rescale=False,
377
+ do_normalize=False,
378
+ do_pad=do_pad,
379
+ pad_size=mask_pad_size,
380
+ input_data_format=input_data_format,
381
+ )
382
+
383
+ # Remove extra channel dimension if added for processing
384
+ if added_channel_dim:
385
+ segmentation_map = segmentation_map.squeeze(0)
386
+ segmentation_map = segmentation_map.astype(np.int64)
387
+
388
+ return segmentation_map, original_size
389
+
390
+ def __call__(self, images, segmentation_maps=None, **kwargs):
391
+ # Overrides the `__call__` method of the `BaseImageProcessor` class such that the images and segmentation maps can both
392
+ # be passed in as positional arguments.
393
+ return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
394
+
395
+ @filter_out_non_signature_kwargs()
396
+ def preprocess(
397
+ self,
398
+ images: ImageInput,
399
+ segmentation_maps: Optional[ImageInput] = None,
400
+ do_resize: Optional[bool] = None,
401
+ size: Optional[dict[str, int]] = None,
402
+ mask_size: Optional[dict[str, int]] = None,
403
+ resample: Optional["PILImageResampling"] = None,
404
+ do_rescale: Optional[bool] = None,
405
+ rescale_factor: Optional[Union[int, float]] = None,
406
+ do_normalize: Optional[bool] = None,
407
+ image_mean: Optional[Union[float, list[float]]] = None,
408
+ image_std: Optional[Union[float, list[float]]] = None,
409
+ do_pad: Optional[bool] = None,
410
+ pad_size: Optional[dict[str, int]] = None,
411
+ mask_pad_size: Optional[dict[str, int]] = None,
412
+ do_convert_rgb: Optional[bool] = None,
413
+ return_tensors: Optional[Union[str, TensorType]] = None,
414
+ data_format: ChannelDimension = ChannelDimension.FIRST,
415
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
416
+ ):
417
+ """
418
+ Preprocess an image or batch of images.
419
+
420
+ Args:
421
+ images (`ImageInput`):
422
+ Image to preprocess. Expects a single image or a batch of images with pixel values ranging from 0 to 255. If
423
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
424
+ segmentation_maps (`ImageInput`, *optional*):
425
+ Segmentation map to preprocess.
426
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
427
+ Whether to resize the image.
428
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
429
+ Controls the size of the image after `resize`. The longest edge of the image is resized to
430
+ `size["longest_edge"]` whilst preserving the aspect ratio.
431
+ mask_size (`dict[str, int]`, *optional*, defaults to `self.mask_size`):
432
+ Controls the size of the segmentation map after `resize`. The longest edge of the segmentation map is resized to
433
+ `mask_size["longest_edge"]` whilst preserving the aspect ratio.
434
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
435
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
436
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
437
+ Whether to rescale the image pixel values by rescaling factor.
438
+ rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
439
+ Rescale factor to apply to the image pixel values.
440
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
441
+ Whether to normalize the image.
442
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
443
+ Image mean to normalize the image by if `do_normalize` is set to `True`.
444
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
445
+ Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
446
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
447
+ Whether to pad the image.
448
+ pad_size (`dict[str, int]`, *optional*, defaults to `self.pad_size`):
449
+ Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and
450
+ `pad_size["width"]` if `do_pad` is set to `True`.
451
+ mask_pad_size (`dict[str, int]`, *optional*, defaults to `self.mask_pad_size`):
452
+ Controls the size of the padding applied to the segmentation map. The image is padded to
453
+ `mask_pad_size["height"]` and `mask_pad_size["width"]` if `do_pad` is set to `True`.
454
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
455
+ Whether to convert the image to RGB.
456
+ return_tensors (`str` or `TensorType`, *optional*):
457
+ The type of tensors to return. Can be one of:
458
+ - Unset: Return a list of `np.ndarray`.
459
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
460
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
461
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
462
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
463
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
464
+ The channel dimension format for the output image. Can be one of:
465
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
466
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
467
+ - Unset: Use the channel dimension format of the input image.
468
+ input_data_format (`ChannelDimension` or `str`, *optional*):
469
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
470
+ from the input image. Can be one of:
471
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
472
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
473
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
474
+ """
475
+ do_resize = do_resize if do_resize is not None else self.do_resize
476
+ size = size if size is not None else self.size
477
+ size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
478
+ mask_size = mask_size if mask_size is not None else self.mask_size
479
+ mask_size = (
480
+ get_size_dict(max_size=mask_size, default_to_square=False)
481
+ if not isinstance(mask_size, dict)
482
+ else mask_size
483
+ )
484
+ resample = resample if resample is not None else self.resample
485
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
486
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
487
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
488
+ image_mean = image_mean if image_mean is not None else self.image_mean
489
+ image_std = image_std if image_std is not None else self.image_std
490
+ do_pad = do_pad if do_pad is not None else self.do_pad
491
+ pad_size = pad_size if pad_size is not None else self.pad_size
492
+ pad_size = get_size_dict(pad_size, default_to_square=True)
493
+ mask_pad_size = mask_pad_size if mask_pad_size is not None else self.mask_pad_size
494
+ mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)
495
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
496
+
497
+ images = make_list_of_images(images)
498
+
499
+ if not valid_images(images):
500
+ raise ValueError(
501
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
502
+ "torch.Tensor, tf.Tensor or jax.ndarray."
503
+ )
504
+
505
+ if segmentation_maps is not None:
506
+ segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
507
+
508
+ if not valid_images(segmentation_maps):
509
+ raise ValueError(
510
+ "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
511
+ "torch.Tensor, tf.Tensor or jax.ndarray."
512
+ )
513
+ validate_preprocess_arguments(
514
+ do_rescale=do_rescale,
515
+ rescale_factor=rescale_factor,
516
+ do_normalize=do_normalize,
517
+ image_mean=image_mean,
518
+ image_std=image_std,
519
+ do_pad=do_pad,
520
+ size_divisibility=pad_size, # Here _preprocess needs do_pad and pad_size.
521
+ do_resize=do_resize,
522
+ size=size,
523
+ resample=resample,
524
+ )
525
+
526
+ images, original_sizes, reshaped_input_sizes = zip(
527
+ *(
528
+ self._preprocess_image(
529
+ image=img,
530
+ do_resize=do_resize,
531
+ size=size,
532
+ resample=resample,
533
+ do_rescale=do_rescale,
534
+ rescale_factor=rescale_factor,
535
+ do_normalize=do_normalize,
536
+ image_mean=image_mean,
537
+ image_std=image_std,
538
+ do_pad=do_pad,
539
+ pad_size=pad_size,
540
+ do_convert_rgb=do_convert_rgb,
541
+ data_format=data_format,
542
+ input_data_format=input_data_format,
543
+ )
544
+ for img in images
545
+ )
546
+ )
547
+
548
+ data = {
549
+ "pixel_values": images,
550
+ "original_sizes": original_sizes,
551
+ "reshaped_input_sizes": reshaped_input_sizes,
552
+ }
553
+
554
+ if segmentation_maps is not None:
555
+ segmentation_maps, original_mask_sizes = zip(
556
+ *(
557
+ self._preprocess_mask(
558
+ segmentation_map=mask,
559
+ do_resize=do_resize,
560
+ mask_size=mask_size,
561
+ do_pad=do_pad,
562
+ mask_pad_size=mask_pad_size,
563
+ input_data_format=input_data_format,
564
+ )
565
+ for mask in segmentation_maps
566
+ )
567
+ )
568
+
569
+ # masks should start out the same size as input images
570
+ assert all(
571
+ original_im_size == original_mask_size
572
+ for original_im_size, original_mask_size in zip(original_sizes, original_mask_sizes)
573
+ ), "Segmentation maps should be the same size as input images."
574
+
575
+ data["labels"] = segmentation_maps
576
+
577
+ return BatchFeature(data=data, tensor_type=return_tensors)
578
+
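A minimal end-to-end sketch of `preprocess` (editorial, assuming `SamImageProcessor` is importable from `transformers` and using a random HWC array as a stand-in for a real photo): with the defaults, the longest edge is resized to 1024, the image is normalized and zero-padded to 1024x1024, and the original and reshaped sizes are returned for later use by `post_process_masks`.

import numpy as np
from transformers import SamImageProcessor

processor = SamImageProcessor()
image = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)  # stand-in for a real image
inputs = processor(image, return_tensors="pt")
print(inputs["pixel_values"].shape)     # torch.Size([1, 3, 1024, 1024])
print(inputs["original_sizes"])         # tensor([[480, 640]])
print(inputs["reshaped_input_sizes"])   # tensor([[768, 1024]])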
579
+ def post_process_masks(
580
+ self,
581
+ masks,
582
+ original_sizes,
583
+ reshaped_input_sizes,
584
+ mask_threshold=0.0,
585
+ binarize=True,
586
+ pad_size=None,
587
+ return_tensors="pt",
588
+ ):
589
+ """
590
+ Remove padding and upscale masks to the original image size.
591
+
592
+ Args:
593
+ masks (`Union[list[torch.Tensor], list[np.ndarray], list[tf.Tensor]]`):
594
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
595
+ original_sizes (`Union[torch.Tensor, tf.Tensor, list[tuple[int,int]]]`):
596
+ The original sizes of each image before it was resized to the model's expected input shape, in (height,
597
+ width) format.
598
+ reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, list[tuple[int,int]]]`):
599
+ The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
600
+ mask_threshold (`float`, *optional*, defaults to 0.0):
601
+ The threshold to use for binarizing the masks.
602
+ binarize (`bool`, *optional*, defaults to `True`):
603
+ Whether to binarize the masks.
604
+ pad_size (`dict[str, int]`, *optional*, defaults to `self.pad_size`):
605
+ The target size the images were padded to before being passed to the model. If None, the target size is
606
+ assumed to be the processor's `pad_size`.
607
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
608
+ If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors.
609
+ Returns:
610
+ (`Union[torch.Tensor, tf.Tensor]`): Batched masks in (batch_size, num_channels, height, width) format, where
611
+ (height, width) is given by original_size.
612
+ """
613
+ if return_tensors == "pt":
614
+ return self._post_process_masks_pt(
615
+ masks=masks,
616
+ original_sizes=original_sizes,
617
+ reshaped_input_sizes=reshaped_input_sizes,
618
+ mask_threshold=mask_threshold,
619
+ binarize=binarize,
620
+ pad_size=pad_size,
621
+ )
622
+ elif return_tensors == "tf":
623
+ return self._post_process_masks_tf(
624
+ masks=masks,
625
+ original_sizes=original_sizes,
626
+ reshaped_input_sizes=reshaped_input_sizes,
627
+ mask_threshold=mask_threshold,
628
+ binarize=binarize,
629
+ pad_size=pad_size,
630
+ )
631
+ else:
632
+ raise ValueError("return_tensors must be either 'pt' or 'tf'")
633
+
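A minimal sketch of the PyTorch path (editorial, assuming `SamImageProcessor` is importable from `transformers`): the random tensor below stands in for the low-resolution mask logits a SAM-style mask decoder would produce. The masks are upsampled to the padded 1024x1024 canvas, the padding is cropped away using `reshaped_input_sizes`, and the result is resized to `original_sizes` and binarized.

import torch
from transformers import SamImageProcessor

processor = SamImageProcessor()
low_res_masks = [torch.randn(1, 3, 256, 256)]  # stand-in for decoder output (1 image, 3 masks)
masks = processor.post_process_masks(
    low_res_masks,
    original_sizes=[(480, 640)],
    reshaped_input_sizes=[(768, 1024)],
)
print(masks[0].shape, masks[0].dtype)  # torch.Size([1, 3, 480, 640]) torch.bool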
634
+ def _post_process_masks_pt(
635
+ self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
636
+ ):
637
+ """
638
+ Remove padding and upscale masks to the original image size.
639
+
640
+ Args:
641
+ masks (`Union[list[torch.Tensor], list[np.ndarray]]`):
642
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
643
+ original_sizes (`Union[torch.Tensor, list[tuple[int,int]]]`):
644
+ The original sizes of each image before it was resized to the model's expected input shape, in (height,
645
+ width) format.
646
+ reshaped_input_sizes (`Union[torch.Tensor, list[tuple[int,int]]]`):
647
+ The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
648
+ mask_threshold (`float`, *optional*, defaults to 0.0):
649
+ The threshold to use for binarizing the masks.
650
+ binarize (`bool`, *optional*, defaults to `True`):
651
+ Whether to binarize the masks.
652
+ pad_size (`dict[str, int]`, *optional*, defaults to `self.pad_size`):
653
+ The target size the images were padded to before being passed to the model. If None, the target size is
654
+ assumed to be the processor's `pad_size`.
655
+ Returns:
656
+ (`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
657
+ is given by original_size.
658
+ """
659
+ requires_backends(self, ["torch"])
660
+ pad_size = self.pad_size if pad_size is None else pad_size
661
+ target_image_size = (pad_size["height"], pad_size["width"])
662
+ if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
663
+ original_sizes = original_sizes.tolist()
664
+ if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
665
+ reshaped_input_sizes = reshaped_input_sizes.tolist()
666
+ output_masks = []
667
+ for i, original_size in enumerate(original_sizes):
668
+ if isinstance(masks[i], np.ndarray):
669
+ masks[i] = torch.from_numpy(masks[i])
670
+ elif not isinstance(masks[i], torch.Tensor):
671
+ raise TypeError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
672
+ interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
673
+ interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
674
+ interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
675
+ if binarize:
676
+ interpolated_mask = interpolated_mask > mask_threshold
677
+ output_masks.append(interpolated_mask)
678
+
679
+ return output_masks
680
+
681
+ def _post_process_masks_tf(
682
+ self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
683
+ ):
684
+ """
685
+ Remove padding and upscale masks to the original image size.
686
+
687
+ Args:
688
+ masks (`tf.Tensor`):
689
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
690
+ original_sizes (`tf.Tensor`):
691
+ The original size of the images before resizing for input to the model, in (height, width) format.
692
+ reshaped_input_sizes (`tf.Tensor`):
693
+ The size of the image input to the model, in (height, width) format. Used to remove padding.
694
+ mask_threshold (`float`, *optional*, defaults to 0.0):
695
+ The threshold to use for binarizing the masks.
696
+ binarize (`bool`, *optional*, defaults to `True`):
697
+ Whether to binarize the masks.
698
+ pad_size (`dict[str, int]`, *optional*, defaults to `self.pad_size`):
699
+ The target size the images were padded to before being passed to the model. If None, the target size is
700
+ assumed to be the processor's `pad_size`.
701
+ Returns:
702
+ (`tf.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width) is
703
+ given by original_size.
704
+ """
705
+ requires_backends(self, ["tf"])
706
+ pad_size = self.pad_size if pad_size is None else pad_size
707
+ target_image_size = (pad_size["height"], pad_size["width"])
708
+
709
+ output_masks = []
710
+ for i, original_size in enumerate(original_sizes):
711
+ # tf.image expects NHWC, we transpose the NCHW inputs for it
712
+ mask = tf.transpose(masks[i], perm=[0, 2, 3, 1])
713
+ interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear")
714
+ interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :]
715
+ interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear")
716
+ if binarize:
717
+ interpolated_mask = interpolated_mask > mask_threshold
718
+ # And then we transpose them back at the end
719
+ output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2]))
720
+
721
+ return output_masks
722
+
723
+ def post_process_for_mask_generation(
724
+ self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"
725
+ ):
726
+ """
727
+ Post-processes the generated masks by applying Non Maximum Suppression (NMS) to the predicted masks.
728
+
729
+ Args:
730
+ all_masks (`Union[list[torch.Tensor], list[tf.Tensor]]`):
731
+ List of all predicted segmentation masks
732
+ all_scores (`Union[list[torch.Tensor], list[tf.Tensor]]`):
733
+ List of all predicted iou scores
734
+ all_boxes (`Union[list[torch.Tensor], list[tf.Tensor]]`):
735
+ List of all bounding boxes of the predicted masks
736
+ crops_nms_thresh (`float`):
737
+ Threshold for NMS (Non Maximum Suppression) algorithm.
738
+ return_tensors (`str`, *optional*, defaults to `pt`):
739
+ If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
740
+ """
741
+ if return_tensors == "pt":
742
+ return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
743
+ elif return_tensors == "tf":
744
+ return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh)
745
+
746
+ def generate_crop_boxes(
747
+ self,
748
+ image,
749
+ target_size,
750
+ crop_n_layers: int = 0,
751
+ overlap_ratio: float = 512 / 1500,
752
+ points_per_crop: Optional[int] = 32,
753
+ crop_n_points_downscale_factor: Optional[list[int]] = 1,
754
+ device: Optional["torch.device"] = None,
755
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
756
+ return_tensors: str = "pt",
757
+ ):
758
+ """
759
+ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
760
+
761
+ Args:
762
+ image (`np.array`):
763
+ Input original image
764
+ target_size (`int`):
765
+ Target size of the resized image
766
+ crop_n_layers (`int`, *optional*, defaults to 0):
767
+ If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
768
+ each layer has 2**i_layer number of image crops.
769
+ overlap_ratio (`float`, *optional*, defaults to 512/1500):
770
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
771
+ the image length. Later layers with more crops scale down this overlap.
772
+ points_per_crop (`int`, *optional*, defaults to 32):
773
+ Number of points to sample from each crop.
774
+ crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1):
775
+ The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
776
+ device (`torch.device`, *optional*, defaults to None):
777
+ Device to use for the computation. If None, cpu will be used.
778
+ input_data_format (`str` or `ChannelDimension`, *optional*):
779
+ The channel dimension format of the input image. If not provided, it will be inferred.
780
+ return_tensors (`str`, *optional*, defaults to `pt`):
781
+ If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
782
+ """
783
+ crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
784
+ image,
785
+ target_size,
786
+ crop_n_layers,
787
+ overlap_ratio,
788
+ points_per_crop,
789
+ crop_n_points_downscale_factor,
790
+ input_data_format,
791
+ )
792
+ if return_tensors == "pt":
793
+ if device is None:
794
+ device = torch.device("cpu")
795
+ crop_boxes = torch.tensor(crop_boxes, device=device)
796
+ points_per_crop = torch.tensor(points_per_crop, device=device)
797
+ # cropped_images stays as np
798
+ input_labels = torch.tensor(input_labels, device=device)
799
+
800
+ elif return_tensors == "tf":
801
+ if device is not None:
802
+ raise ValueError("device is not a supported argument when return_tensors is tf!")
803
+ crop_boxes = tf.convert_to_tensor(crop_boxes)
804
+ points_per_crop = tf.convert_to_tensor(points_per_crop)
805
+ # cropped_images stays as np
806
+ input_labels = tf.convert_to_tensor(input_labels)
807
+ else:
808
+ raise ValueError("return_tensors must be either 'pt' or 'tf'.")
809
+ return crop_boxes, points_per_crop, cropped_images, input_labels
810
+
811
+ def filter_masks(
812
+ self,
813
+ masks,
814
+ iou_scores,
815
+ original_size,
816
+ cropped_box_image,
817
+ pred_iou_thresh=0.88,
818
+ stability_score_thresh=0.95,
819
+ mask_threshold=0,
820
+ stability_score_offset=1,
821
+ return_tensors="pt",
822
+ ):
823
+ """
824
+ Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
825
+ that the iou scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
826
+ score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
827
+ bounding boxes and pads the predicted masks if necessary.
828
+
829
+ Args:
830
+ masks (`Union[torch.Tensor, tf.Tensor]`):
831
+ Input masks.
832
+ iou_scores (`Union[torch.Tensor, tf.Tensor]`):
833
+ List of IoU scores.
834
+ original_size (`tuple[int,int]`):
835
+ Size of the original image.
836
+ cropped_box_image (`np.array`):
837
+ The cropped image.
838
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
839
+ The threshold for the iou scores.
840
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
841
+ The threshold for the stability score.
842
+ mask_threshold (`float`, *optional*, defaults to 0):
843
+ The threshold for the predicted masks.
844
+ stability_score_offset (`float`, *optional*, defaults to 1):
845
+ The offset for the stability score used in the `_compute_stability_score` method.
846
+ return_tensors (`str`, *optional*, defaults to `pt`):
847
+ If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
848
+ """
849
+ if return_tensors == "pt":
850
+ return self._filter_masks_pt(
851
+ masks=masks,
852
+ iou_scores=iou_scores,
853
+ original_size=original_size,
854
+ cropped_box_image=cropped_box_image,
855
+ pred_iou_thresh=pred_iou_thresh,
856
+ stability_score_thresh=stability_score_thresh,
857
+ mask_threshold=mask_threshold,
858
+ stability_score_offset=stability_score_offset,
859
+ )
860
+ elif return_tensors == "tf":
861
+ return self._filter_masks_tf(
862
+ masks=masks,
863
+ iou_scores=iou_scores,
864
+ original_size=original_size,
865
+ cropped_box_image=cropped_box_image,
866
+ pred_iou_thresh=pred_iou_thresh,
867
+ stability_score_thresh=stability_score_thresh,
868
+ mask_threshold=mask_threshold,
869
+ stability_score_offset=stability_score_offset,
870
+ )
871
+
872
+ def _filter_masks_pt(
873
+ self,
874
+ masks,
875
+ iou_scores,
876
+ original_size,
877
+ cropped_box_image,
878
+ pred_iou_thresh=0.88,
879
+ stability_score_thresh=0.95,
880
+ mask_threshold=0,
881
+ stability_score_offset=1,
882
+ ):
883
+ """
884
+ Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
885
+ that the iou scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
886
+ score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
887
+ bounding boxes and pads the predicted masks if necessary.
888
+
889
+ Args:
890
+ masks (`torch.Tensor`):
891
+ Input masks.
892
+ iou_scores (`torch.Tensor`):
893
+ List of IoU scores.
894
+ original_size (`tuple[int,int]`):
895
+ Size of the original image.
896
+ cropped_box_image (`np.array`):
897
+ The cropped image.
898
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
899
+ The threshold for the iou scores.
900
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
901
+ The threshold for the stability score.
902
+ mask_threshold (`float`, *optional*, defaults to 0):
903
+ The threshold for the predicted masks.
904
+ stability_score_offset (`float`, *optional*, defaults to 1):
905
+ The offset for the stability score used in the `_compute_stability_score` method.
906
+
907
+ """
908
+ requires_backends(self, ["torch"])
909
+ original_height, original_width = original_size
910
+ iou_scores = iou_scores.flatten(0, 1)
911
+ masks = masks.flatten(0, 1)
912
+
913
+ if masks.shape[0] != iou_scores.shape[0]:
914
+ raise ValueError("masks and iou_scores must have the same batch size.")
915
+
916
+ if masks.device != iou_scores.device:
917
+ iou_scores = iou_scores.to(masks.device)
918
+
919
+ batch_size = masks.shape[0]
920
+
921
+ keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
922
+
923
+ if pred_iou_thresh > 0.0:
924
+ keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
925
+
926
+ # compute stability score
927
+ if stability_score_thresh > 0.0:
928
+ stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset)
929
+ keep_mask = keep_mask & (stability_scores > stability_score_thresh)
930
+
931
+ scores = iou_scores[keep_mask]
932
+ masks = masks[keep_mask]
933
+
934
+ # binarize masks
935
+ masks = masks > mask_threshold
936
+ converted_boxes = _batched_mask_to_box(masks)
937
+
938
+ keep_mask = ~_is_box_near_crop_edge(
939
+ converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
940
+ )
941
+
942
+ scores = scores[keep_mask]
943
+ masks = masks[keep_mask]
944
+ converted_boxes = converted_boxes[keep_mask]
945
+
946
+ masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
947
+ # conversion to rle is necessary to run non-maximum suppression
948
+ masks = _mask_to_rle_pytorch(masks)
949
+
950
+ return masks, scores, converted_boxes
951
+
952
+ def _filter_masks_tf(
953
+ self,
954
+ masks,
955
+ iou_scores,
956
+ original_size,
957
+ cropped_box_image,
958
+ pred_iou_thresh=0.88,
959
+ stability_score_thresh=0.95,
960
+ mask_threshold=0,
961
+ stability_score_offset=1,
962
+ ):
963
+ """
964
+ Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
965
+ that the iou scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
966
+ score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
967
+ bounding boxes and pads the predicted masks if necessary.
968
+
969
+ Args:
970
+ masks (`tf.Tensor`):
971
+ Input masks.
972
+ iou_scores (`tf.Tensor`):
973
+ List of IoU scores.
974
+ original_size (`tuple[int,int]`):
975
+ Size of the original image.
976
+ cropped_box_image (`np.array`):
977
+ The cropped image.
978
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
979
+ The threshold for the iou scores.
980
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
981
+ The threshold for the stability score.
982
+ mask_threshold (`float`, *optional*, defaults to 0):
983
+ The threshold for the predicted masks.
984
+ stability_score_offset (`float`, *optional*, defaults to 1):
985
+ The offset for the stability score used in the `_compute_stability_score` method.
986
+
987
+ """
988
+ requires_backends(self, ["tf"])
989
+ original_height, original_width = original_size
990
+ iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]])
991
+ masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]])
992
+
993
+ if masks.shape[0] != iou_scores.shape[0]:
994
+ raise ValueError("masks and iou_scores must have the same batch size.")
995
+
996
+ batch_size = masks.shape[0]
997
+
998
+ keep_mask = tf.ones(batch_size, dtype=tf.bool)
999
+
1000
+ if pred_iou_thresh > 0.0:
1001
+ keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
1002
+
1003
+ # compute stability score
1004
+ if stability_score_thresh > 0.0:
1005
+ stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset)
1006
+ keep_mask = keep_mask & (stability_scores > stability_score_thresh)
1007
+
1008
+ scores = iou_scores[keep_mask]
1009
+ masks = masks[keep_mask]
1010
+
1011
+ # binarize masks
1012
+ masks = masks > mask_threshold
1013
+ converted_boxes = _batched_mask_to_box_tf(masks)
1014
+
1015
+ keep_mask = ~_is_box_near_crop_edge_tf(
1016
+ converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
1017
+ )
1018
+
1019
+ scores = scores[keep_mask]
1020
+ masks = masks[keep_mask]
1021
+ converted_boxes = converted_boxes[keep_mask]
1022
+
1023
+ masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width)
1024
+ # conversion to rle is necessary to run non-maximum suppression
1025
+ masks = _mask_to_rle_tf(masks)
1026
+
1027
+ return masks, scores, converted_boxes
1028
+
1029
+
1030
+ def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
1031
+ # One mask is always contained inside the other.
1032
+ # Save memory by preventing unnecessary cast to torch.int64
1033
+ intersections = (
1034
+ (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
1035
+ )
1036
+ unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
1037
+ stability_scores = intersections / unions
1038
+ return stability_scores
1039
+
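In other words, the stability score is the IoU between the mask binarized at `mask_threshold + offset` and at `mask_threshold - offset`; since the tighter mask is always contained in the looser one, this reduces to a ratio of areas. A minimal numeric sketch, assuming this file is importable as `transformers.models.sam.image_processing_sam`:

import torch
from transformers.models.sam.image_processing_sam import _compute_stability_score_pt

logits = torch.tensor([[[-0.5, 1.5, 2.5],
                        [0.1, 0.2, 3.0]]])  # a single 2x3 mask of logits
score = _compute_stability_score_pt(logits, mask_threshold=0.0, stability_score_offset=1.0)
# 3 pixels exceed +1.0 and 6 pixels exceed -1.0, so the score is 3 / 6 = 0.5.
print(score)  # tensor([0.5000])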
1040
+
1041
+ def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int):
1042
+ # Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure
1043
+ # we get the right division results.
1044
+ intersections = tf.count_nonzero(
1045
+ masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32
1046
+ )
1047
+ unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32)
1048
+ stability_scores = intersections / unions
1049
+ return stability_scores
1050
+
1051
+
1052
+ def _build_point_grid(n_per_side: int) -> np.ndarray:
1053
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
1054
+ offset = 1 / (2 * n_per_side)
1055
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
1056
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
1057
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
1058
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
1059
+ return points
1060
+
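For instance, `n_per_side=2` places points at 0.25 and 0.75 along each axis, giving four normalized (x, y) prompt points. A minimal sketch, assuming this file is importable as `transformers.models.sam.image_processing_sam`:

from transformers.models.sam.image_processing_sam import _build_point_grid

print(_build_point_grid(2))
# [[0.25 0.25]
#  [0.75 0.25]
#  [0.25 0.75]
#  [0.75 0.75]]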
1061
+
1062
+ def _normalize_coordinates(
1063
+ target_size: int, coords: np.ndarray, original_size: tuple[int, int], is_bounding_box=False
1064
+ ) -> np.ndarray:
1065
+ """
1066
+ Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width)
1067
+ format.
1068
+ """
1069
+ old_height, old_width = original_size
1070
+
1071
+ scale = target_size * 1.0 / max(old_height, old_width)
1072
+ new_height, new_width = old_height * scale, old_width * scale
1073
+ new_width = int(new_width + 0.5)
1074
+ new_height = int(new_height + 0.5)
1075
+
1076
+ coords = deepcopy(coords).astype(float)
1077
+
1078
+ if is_bounding_box:
1079
+ coords = coords.reshape(-1, 2, 2)
1080
+
1081
+ coords[..., 0] = coords[..., 0] * (new_width / old_width)
1082
+ coords[..., 1] = coords[..., 1] * (new_height / old_height)
1083
+
1084
+ if is_bounding_box:
1085
+ coords = coords.reshape(-1, 4)
1086
+
1087
+ return coords
1088
+
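As a worked example, for a 600x800 (height, width) image and `target_size=1024`, the scale is 1024 / 800 = 1.28, so the point (400, 300) maps to (512, 384) in the resized coordinate space. A minimal sketch under the same import assumption as above:

import numpy as np
from transformers.models.sam.image_processing_sam import _normalize_coordinates

point = np.array([[400.0, 300.0]])  # (x, y) in the original 600x800 image
print(_normalize_coordinates(1024, point, original_size=(600, 800)))
# [[512. 384.]]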
1089
+
1090
+ def _generate_crop_boxes(
1091
+ image,
1092
+ target_size: int, # Is it tuple here?
1093
+ crop_n_layers: int = 0,
1094
+ overlap_ratio: float = 512 / 1500,
1095
+ points_per_crop: Optional[int] = 32,
1096
+ crop_n_points_downscale_factor: Optional[list[int]] = 1,
1097
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1098
+ ) -> tuple[list[list[int]], list[int]]:
1099
+ """
1100
+ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
1101
+
1102
+ Args:
1103
+ image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
1104
+ Image to generate crops for.
1105
+ target_size (`int`):
1106
+ Size of the smallest crop.
1107
+ crop_n_layers (`int`, *optional*):
1108
+ If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers
1109
+ to run, where each layer has 2**i_layer number of image crops.
1110
+ overlap_ratio (`int`, *optional*):
1111
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
1112
+ image length. Later layers with more crops scale down this overlap.
1113
+ points_per_crop (`int`, *optional*):
1114
+ Number of points to sample per crop.
1115
+ crop_n_points_downscale_factor (`int`, *optional*):
1116
+ The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
1117
+ input_data_format (`str` or `ChannelDimension`, *optional*):
1118
+ The channel dimension format of the input image. If not provided, it will be inferred.
1119
+ """
1120
+
1121
+ if isinstance(image, list):
1122
+ raise TypeError("Only one image is allowed for crop generation.")
1123
+ image = to_numpy_array(image)
1124
+ original_size = get_image_size(image, input_data_format)
1125
+
1126
+ points_grid = []
1127
+ for i in range(crop_n_layers + 1):
1128
+ n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
1129
+ points_grid.append(_build_point_grid(n_points))
1130
+
1131
+ crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
1132
+
1133
+ cropped_images, point_grid_per_crop = _generate_crop_images(
1134
+ crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format
1135
+ )
1136
+ crop_boxes = np.array(crop_boxes)
1137
+ crop_boxes = crop_boxes.astype(np.float32)
1138
+ points_per_crop = np.array([point_grid_per_crop])
1139
+ points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
1140
+
1141
+ input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64)
1142
+
1143
+ return crop_boxes, points_per_crop, cropped_images, input_labels
1144
+
1145
+
1146
+ def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
1147
+ """
1148
+ Generates 2 ** (layer idx + 1) crops per side for each layer in `crop_n_layers`. Crops are in the XYXY format: the
1149
+ XYXY format consists of the following required indices:
1150
+ - X1: X coordinate of the top left of the bounding box
1151
+ - Y1: Y coordinate of the top left of the bounding box
1152
+ - X2: X coordinate of the bottom right of the bounding box
1153
+ - Y2: Y coordinate of the bottom right of the bounding box
1154
+ """
1155
+ crop_boxes, layer_idxs = [], []
1156
+ im_height, im_width = original_size
1157
+ short_side = min(im_height, im_width)
1158
+
1159
+ # Original image
1160
+ crop_boxes.append([0, 0, im_width, im_height])
1161
+ layer_idxs.append(0)
1162
+ for i_layer in range(crop_n_layers):
1163
+ n_crops_per_side = 2 ** (i_layer + 1)
1164
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
1165
+
1166
+ crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
1167
+ crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
1168
+
1169
+ crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
1170
+ crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
1171
+
1172
+ for left, top in product(crop_box_x0, crop_box_y0):
1173
+ box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
1174
+ crop_boxes.append(box)
1175
+ layer_idxs.append(i_layer + 1)
1176
+
1177
+ return crop_boxes, layer_idxs
1178
+
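As a worked example, a single crop layer on a 1500x1500 image with the default overlap ratio of 512/1500 produces the full-image box plus four 1006x1006 crops that overlap by 512 pixels. A minimal sketch under the same import assumption as above:

from transformers.models.sam.image_processing_sam import _generate_per_layer_crops

boxes, layer_idxs = _generate_per_layer_crops(
    crop_n_layers=1, overlap_ratio=512 / 1500, original_size=(1500, 1500)
)
print(layer_idxs)  # [0, 1, 1, 1, 1]
print(boxes)
# [[0, 0, 1500, 1500], [0, 0, 1006, 1006], [0, 494, 1006, 1500],
#  [494, 0, 1500, 1006], [494, 494, 1500, 1500]]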
1179
+
1180
+ def _generate_crop_images(
1181
+ crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
1182
+ ):
1183
+ """
1184
+ Takes as input the bounding boxes that are used to crop the image. Based on the crops, the corresponding points are
1185
+ also passed.
1186
+ """
1187
+ cropped_images = []
1188
+ total_points_per_crop = []
1189
+ for i, crop_box in enumerate(crop_boxes):
1190
+ left, top, right, bottom = crop_box
1191
+
1192
+ channel_dim = infer_channel_dimension_format(image, input_data_format)
1193
+ if channel_dim == ChannelDimension.LAST:
1194
+ cropped_im = image[top:bottom, left:right, :]
1195
+ else:
1196
+ cropped_im = image[:, top:bottom, left:right]
1197
+
1198
+ cropped_images.append(cropped_im)
1199
+
1200
+ cropped_im_size = get_image_size(cropped_im, channel_dim)
1201
+ points_scale = np.array(cropped_im_size)[None, ::-1]
1202
+
1203
+ points = points_grid[layer_idxs[i]] * points_scale
1204
+ normalized_points = _normalize_coordinates(target_size, points, original_size)
1205
+ total_points_per_crop.append(normalized_points)
1206
+
1207
+ return cropped_images, total_points_per_crop
1208
+
1209
+
1210
+ def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int):
1211
+ left, top, right, bottom = crop_box
1212
+ if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
1213
+ return masks
1214
+ # Coordinate transform masks
1215
+ pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
1216
+ pad = (left, pad_x - left, top, pad_y - top)
1217
+ return torch.nn.functional.pad(masks, pad, value=0)
1218
+
1219
+
1220
+ def _pad_masks_tf(masks, crop_box: list[int], orig_height: int, orig_width: int):
1221
+ left, top, right, bottom = crop_box
1222
+ if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
1223
+ return masks
1224
+ # Coordinate transform masks
1225
+ pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
1226
+ pad = (left, pad_x - left, top, pad_y - top)
1227
+ return tf.pad(masks, pad, constant_values=0)
1228
+
1229
+
1230
+ def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
1231
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
1232
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
1233
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
1234
+
1235
+ left, top, _, _ = crop_box
1236
+ offset = torch.tensor([[left, top, left, top]], device=boxes.device)
1237
+ # Check if boxes has a channel dimension
1238
+ if len(boxes.shape) == 3:
1239
+ offset = offset.unsqueeze(1)
1240
+ boxes = (boxes + offset).float()
1241
+
1242
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
1243
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
1244
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
1245
+ return torch.any(near_crop_edge, dim=1)
1246
+
1247
+
1248
+ def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
1249
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
1250
+ crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
1251
+ orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)
1252
+
1253
+ left, top, _, _ = crop_box
1254
+ offset = tf.convert_to_tensor([[left, top, left, top]])
1255
+ # Check if boxes has a channel dimension
1256
+ if len(boxes.shape) == 3:
1257
+ offset = tf.expand_dims(offset, 1)
1258
+ boxes = tf.cast(boxes + offset, tf.float32)
1259
+
1260
+ near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
1261
+ near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
1262
+ near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
1263
+ return tf.reduce_any(near_crop_edge, axis=1)
1264
+
1265
+
1266
+ def _batched_mask_to_box(masks: "torch.Tensor"):
1267
+ """
1268
+ Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
1269
+ corresponds the following required indices:
1270
+ - LEFT: left hand side of the bounding box
1271
+ - TOP: top of the bounding box
1272
+ - RIGHT: right of the bounding box
1273
+ - BOTTOM: bottom of the bounding box
1274
+
1275
+ Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
1276
+ is channel_1 x channel_2 x ... x 4.
1277
+
1278
+ Args:
1279
+ - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
1280
+ """
1281
+ # torch.max below raises an error on empty inputs, just skip in this case
1282
+
1283
+ if torch.numel(masks) == 0:
1284
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
1285
+
1286
+ # Normalize shape to Cxheightxwidth
1287
+ shape = masks.shape
1288
+ height, width = shape[-2:]
1289
+
1290
+ # Get top and bottom edges
1291
+ in_height, _ = torch.max(masks, dim=-1)
1292
+ in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
1293
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
1294
+ in_height_coords = in_height_coords + height * (~in_height)
1295
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
1296
+
1297
+ # Get left and right edges
1298
+ in_width, _ = torch.max(masks, dim=-2)
1299
+ in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
1300
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
1301
+ in_width_coords = in_width_coords + width * (~in_width)
1302
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
1303
+
1304
+ # If the mask is empty the right edge will be to the left of the left edge.
1305
+ # Replace these boxes with [0, 0, 0, 0]
1306
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
1307
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
1308
+ out = out * (~empty_filter).unsqueeze(-1)
1309
+
1310
+ # Return to original shape
1311
+ out = out.reshape(*shape[:-2], 4)
1312
+ return out
1313
+
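A minimal sketch of the box extraction on a tiny mask, under the same import assumption as above: a 2x3 blob of ones with its top-left pixel at row 1, column 1 yields the XYXY box [1, 1, 3, 2].

import torch
from transformers.models.sam.image_processing_sam import _batched_mask_to_box

mask = torch.zeros(1, 1, 4, 4, dtype=torch.bool)
mask[..., 1:3, 1:4] = True  # rows 1-2, columns 1-3
print(_batched_mask_to_box(mask))  # tensor([[[1, 1, 3, 2]]])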
1314
+
1315
+ def _batched_mask_to_box_tf(masks: "tf.Tensor"):
1316
+ """
1317
+ Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
1318
+ corresponds to the following required indices:
1319
+ - LEFT: left hand side of the bounding box
1320
+ - TOP: top of the bounding box
1321
+ - RIGHT: right of the bounding box
1322
+ - BOTTOM: bottom of the bounding box
1323
+
1324
+ Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
1325
+ is channel_1 x channel_2 x ... x 4.
1326
+
1327
+ Args:
1328
+ - masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`)
1329
+ """
1330
+
1331
+ if tf.size(masks) == 0:
1332
+ return tf.zeros([*masks.shape[:-2], 4])
1333
+
1334
+ # Normalize shape to Cxheightxwidth
1335
+ shape = shape_list(masks)
1336
+ height, width = shape[-2:]
1337
+
1338
+ # Get top and bottom edges
1339
+ in_height = tf.reduce_max(masks, axis=-1)
1340
+ in_height_coords = in_height * tf.range(height)[None, :]
1341
+ bottom_edges = tf.reduce_max(in_height_coords, axis=-1)
1342
+ in_height_coords = in_height_coords + height * (~in_height)
1343
+ top_edges = tf.reduce_min(in_height_coords, axis=-1)
1344
+
1345
+ # Get left and right edges
1346
+ in_width = tf.reduce_max(masks, axis=-2)
1347
+ in_width_coords = in_width * tf.range(width)[None, :]
1348
+ right_edges = tf.reduce_max(in_width_coords, axis=-1)
1349
+ in_width_coords = in_width_coords + width * (~in_width)
1350
+ left_edges = tf.reduce_min(in_width_coords, axis=-1)
1351
+
1352
+ # If the mask is empty the right edge will be to the left of the left edge.
1353
+ # Replace these boxes with [0, 0, 0, 0]
1354
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
1355
+ out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
1356
+ out = out * tf.expand_dims(~empty_filter, -1)
1357
+
1358
+ # Return to original shape
1359
+ out = tf.reshape(out, [*shape[:-2], 4])
1360
+ return out
1361
+
1362
+
1363
+ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
1364
+ """
1365
+ Encodes masks to the run-length encoding (RLE), in the format expected by pycocotools.
1366
+ """
1367
+ # Put in fortran order and flatten height and width
1368
+ batch_size, height, width = input_mask.shape
1369
+ input_mask = input_mask.permute(0, 2, 1).flatten(1)
1370
+
1371
+ # Compute change indices
1372
+ diff = input_mask[:, 1:] ^ input_mask[:, :-1]
1373
+ change_indices = diff.nonzero()
1374
+
1375
+ # Encode run length
1376
+ out = []
1377
+ for i in range(batch_size):
1378
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
1379
+ if len(cur_idxs) == 0:
1380
+ # No changes => either all 0 or all 1
1381
+ # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
1382
+ if input_mask[i, 0] == 0:
1383
+ out.append({"size": [height, width], "counts": [height * width]})
1384
+ else:
1385
+ out.append({"size": [height, width], "counts": [0, height * width]})
1386
+ continue
1387
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
1388
+ counts = [] if input_mask[i, 0] == 0 else [0]
1389
+ counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
1390
+ out.append({"size": [height, width], "counts": counts})
1391
+ return out
1392
+
1393
+
1394
+ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
1395
+ """
1396
+ Encodes masks to the run-length encoding (RLE), in the format expected by pycocotools.
1397
+ """
1398
+ # Put in fortran order and flatten height and width
1399
+ batch_size, height, width = input_mask.shape
1400
+ input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1)
1401
+
1402
+ # Compute change indices
1403
+ diff = input_mask[:, 1:] ^ input_mask[:, :-1]
1404
+ change_indices = tf.where(diff)
1405
+
1406
+ # Encode run length
1407
+ out = []
1408
+ for i in range(batch_size):
1409
+ cur_idxs = change_indices[change_indices[:, 0] == i][:, 1] + 1
1410
+ if len(cur_idxs) == 0:
1411
+ # No changes => either all 0 or all 1
1412
+ # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
1413
+ if input_mask[i, 0] == 0:
1414
+ out.append({"size": [height, width], "counts": [height * width]})
1415
+ else:
1416
+ out.append({"size": [height, width], "counts": [0, height * width]})
1417
+ continue
1418
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
1419
+ counts = [] if input_mask[i, 0] == 0 else [0]
1420
+ counts += (
1421
+ [cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()]
1422
+ )
1423
+ out.append({"size": [height, width], "counts": counts})
1424
+ return out
1425
+
1426
+
1427
+ def _rle_to_mask(rle: dict[str, Any]) -> np.ndarray:
1428
+ """Compute a binary mask from an uncompressed RLE."""
1429
+ height, width = rle["size"]
1430
+ mask = np.empty(height * width, dtype=bool)
1431
+ idx = 0
1432
+ parity = False
1433
+ for count in rle["counts"]:
1434
+ mask[idx : idx + count] = parity
1435
+ idx += count
1436
+ parity = not parity
1437
+ mask = mask.reshape(width, height)
1438
+ return mask.transpose() # Reshape to original shape
1439
+
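A minimal round-trip sketch of the RLE helpers, under the same import assumption as above; the mask is flattened in column-major order, so the runs below read down each column of the 2x3 mask:

import torch
from transformers.models.sam.image_processing_sam import _mask_to_rle_pytorch, _rle_to_mask

mask = torch.tensor([[[0, 1, 1],
                      [0, 1, 0]]], dtype=torch.bool)
rle = _mask_to_rle_pytorch(mask)
print(rle)  # [{'size': [2, 3], 'counts': [2, 3, 1]}]
print(_rle_to_mask(rle[0]).astype(int))
# [[0 1 1]
#  [0 1 0]]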
1440
+
1441
+ def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
1442
+ """
1443
+ Perform NMS (Non Maximum Suppression) on the outputs.
1444
+
1445
+ Args:
1446
+ rle_masks (`torch.Tensor`):
1447
+ binary masks in the RLE format
1448
+ iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
1449
+ iou_scores predicted by the model
1450
+ mask_boxes (`torch.Tensor`):
1451
+ The bounding boxes corresponding to segmentation masks
1452
+ amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
1453
+ NMS threshold.
1454
+ """
1455
+ keep_by_nms = batched_nms(
1456
+ boxes=mask_boxes.float(),
1457
+ scores=iou_scores,
1458
+ idxs=torch.zeros(mask_boxes.shape[0]),
1459
+ iou_threshold=amg_crops_nms_thresh,
1460
+ )
1461
+
1462
+ iou_scores = iou_scores[keep_by_nms]
1463
+ rle_masks = [rle_masks[i] for i in keep_by_nms]
1464
+ mask_boxes = mask_boxes[keep_by_nms]
1465
+ masks = [_rle_to_mask(rle) for rle in rle_masks]
1466
+
1467
+ return masks, iou_scores, rle_masks, mask_boxes
1468
+
1469
+
1470
+ def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
1471
+ """
1472
+ Perform NMS (Non Maximum Suppression) on the outputs.
1473
+
1474
+ Args:
1475
+ rle_masks (`tf.Tensor`):
1476
+ binary masks in the RLE format
1477
+ iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
1478
+ iou_scores predicted by the model
1479
+ mask_boxes (`tf.Tensor`):
1480
+ The bounding boxes corresponding to segmentation masks
1481
+ amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
1482
+ NMS threshold.
1483
+ """
1484
+ keep_by_nms = tf.image.combined_non_max_suppression(
1485
+ boxes=mask_boxes.float(),
1486
+ scores=iou_scores,
1487
+ idxs=torch.zeros(mask_boxes.shape[0]),
1488
+ iou_threshold=amg_crops_nms_thresh,
1489
+ )
1490
+
1491
+ iou_scores = iou_scores[keep_by_nms]
1492
+ rle_masks = [rle_masks[i] for i in keep_by_nms]
1493
+ mask_boxes = mask_boxes[keep_by_nms]
1494
+ masks = [_rle_to_mask(rle) for rle in rle_masks]
1495
+
1496
+ return masks, iou_scores, rle_masks, mask_boxes
1497
+
1498
+
1499
+ __all__ = ["SamImageProcessor"]
phivenv/Lib/site-packages/transformers/models/sam/image_processing_sam_fast.py ADDED
@@ -0,0 +1,829 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Image processor class for SAM."""
16
+
17
+ import math
18
+ from copy import deepcopy
19
+ from itertools import product
20
+ from typing import Any, Optional, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from ...image_processing_utils import BatchFeature, get_size_dict
26
+ from ...image_processing_utils_fast import (
27
+ BaseImageProcessorFast,
28
+ DefaultFastImageProcessorKwargs,
29
+ group_images_by_shape,
30
+ reorder_images,
31
+ )
32
+ from ...image_utils import (
33
+ IMAGENET_DEFAULT_MEAN,
34
+ IMAGENET_DEFAULT_STD,
35
+ ChannelDimension,
36
+ ImageInput,
37
+ PILImageResampling,
38
+ SizeDict,
39
+ pil_torch_interpolation_mapping,
40
+ )
41
+ from ...processing_utils import Unpack
42
+ from ...utils import (
43
+ TensorType,
44
+ auto_docstring,
45
+ is_torch_available,
46
+ is_torchvision_available,
47
+ is_torchvision_v2_available,
48
+ )
49
+
50
+
51
+ if is_torch_available():
52
+ import torch
53
+ from torch.nn import functional as F
54
+
55
+ if is_torchvision_v2_available():
56
+ from torchvision.ops.boxes import batched_nms
57
+ from torchvision.transforms.v2 import functional as F_t
58
+ elif is_torchvision_available():
59
+ from torchvision.ops.boxes import batched_nms
60
+ from torchvision.transforms import functional as F_t
61
+
62
+
63
+ class SamFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
64
+ r"""
65
+ do_pad (`bool`, *optional*, defaults to `True`):
66
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
67
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
68
+ pad_size (`dict[str, int]`, *optional*):
69
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
70
+ provided for preprocessing.
71
+ mask_size (`dict[str, int]`, *optional*):
72
+ The size `{"longest_edge": int}` to resize the segmentation maps to.
73
+ mask_pad_size (`dict[str, int]`, *optional*):
74
+ The size `{"height": int, "width": int}` to pad the segmentation maps to. Must be larger than any segmentation
75
+ map size provided for preprocessing.
76
+ """
77
+
78
+ mask_size: Optional[dict[str, int]]
79
+ do_pad: Optional[bool]
80
+ pad_size: Optional[dict[str, int]]
81
+ mask_pad_size: Optional[dict[str, int]]
82
+
83
+
84
+ @auto_docstring
85
+ class SamImageProcessorFast(BaseImageProcessorFast):
86
+ resample = PILImageResampling.BILINEAR
87
+ image_mean = IMAGENET_DEFAULT_MEAN
88
+ image_std = IMAGENET_DEFAULT_STD
89
+ size = {"longest_edge": 1024}
90
+ mask_size = {"longest_edge": 256}
91
+ do_resize = True
92
+ do_rescale = True
93
+ do_normalize = True
94
+ do_convert_rgb = True
95
+
96
+ valid_kwargs = SamFastImageProcessorKwargs
97
+
98
+ do_pad = True
99
+ pad_size = {"height": 1024, "width": 1024}
100
+ mask_pad_size = {"height": 256, "width": 256}
101
+
102
+ def __init__(self, **kwargs: Unpack[SamFastImageProcessorKwargs]):
103
+ super().__init__(**kwargs)
104
+
105
+ def pad_image(self, images: "torch.Tensor", pad_size: SizeDict):
106
+ """Pad images to the specified size."""
107
+ output_height, output_width = pad_size.height, pad_size.width
108
+ input_height, input_width = images.shape[-2:]
109
+ pad_width = output_width - input_width
110
+ pad_height = output_height - input_height
111
+ padding = (0, 0, pad_width, pad_height)
112
+ return F_t.pad(images, padding)
113
+
114
+ def _get_preprocess_shape(self, old_shape: tuple[int, int], longest_edge: int):
115
+ """
116
+ Compute the output size given input size and target long side length.
117
+ """
118
+ oldh, oldw = old_shape
119
+ scale = longest_edge * 1.0 / max(oldh, oldw)
120
+ newh, neww = oldh * scale, oldw * scale
121
+ newh = int(newh + 0.5)
122
+ neww = int(neww + 0.5)
123
+ return (newh, neww)
124
+
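As a quick numeric aside (not part of the file): with the default `longest_edge=1024`, a 480x640 input maps to (768, 1024), since the scale factor 1024/640 = 1.6 is applied to both sides and rounded:

processor = SamImageProcessorFast()
# int(480 * 1.6 + 0.5) = 768, int(640 * 1.6 + 0.5) = 1024
assert processor._get_preprocess_shape((480, 640), longest_edge=1024) == (768, 1024)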
125
+ def resize(
126
+ self, image: "torch.Tensor", size: SizeDict, interpolation: Optional["F_t.InterpolationMode"], **kwargs
127
+ ) -> "torch.Tensor":
128
+ """
129
+ Resize an image so that its longest edge matches `size["longest_edge"]` while preserving the aspect ratio.
130
+
131
+ Args:
132
+ image (`torch.Tensor`):
133
+ Image to resize.
134
+ size (`dict[str, int]`):
135
+ Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest
136
+ edge of the image will be resized to the specified size, while the other edge will be resized to
137
+ maintain the aspect ratio.
138
+ interpolation:
139
+ `F_t.InterpolationMode` filter to use when resizing the image e.g. `F_t.InterpolationMode.BICUBIC`.
140
+
141
+ Returns:
142
+ `torch.Tensor`: The resized image.
143
+ """
144
+ if not size.longest_edge:
145
+ raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
146
+ input_size = image.shape[-2:]
147
+ output_height, output_width = self._get_preprocess_shape(input_size, size.longest_edge)
148
+ return super().resize(
149
+ image, size=SizeDict(height=output_height, width=output_width), interpolation=interpolation, **kwargs
150
+ )
151
+
152
+ def _further_process_kwargs(
153
+ self,
154
+ size: Optional[SizeDict] = None,
155
+ pad_size: Optional[SizeDict] = None,
156
+ mask_size: Optional[SizeDict] = None,
157
+ mask_pad_size: Optional[SizeDict] = None,
158
+ default_to_square: Optional[bool] = None,
159
+ image_mean: Optional[Union[float, list[float]]] = None,
160
+ image_std: Optional[Union[float, list[float]]] = None,
161
+ data_format: Optional[ChannelDimension] = None,
162
+ **kwargs,
163
+ ) -> dict:
164
+ """
165
+ Update kwargs that need further processing before being validated
166
+ Can be overridden by subclasses to customize the processing of kwargs.
167
+ """
168
+ if kwargs is None:
169
+ kwargs = {}
170
+ if size is not None:
171
+ size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
172
+ if pad_size is not None:
173
+ pad_size = SizeDict(**get_size_dict(pad_size, param_name="pad_size"))
174
+ if mask_size is not None:
175
+ mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size"))
176
+ if mask_pad_size is not None:
177
+ mask_pad_size = SizeDict(**get_size_dict(mask_pad_size, param_name="mask_pad_size"))
178
+ if isinstance(image_mean, list):
179
+ image_mean = tuple(image_mean)
180
+ if isinstance(image_std, list):
181
+ image_std = tuple(image_std)
182
+ if data_format is None:
183
+ data_format = ChannelDimension.FIRST
184
+
185
+ kwargs["size"] = size
186
+ kwargs["pad_size"] = pad_size
187
+ kwargs["mask_size"] = mask_size
188
+ kwargs["mask_pad_size"] = mask_pad_size
189
+ kwargs["image_mean"] = image_mean
190
+ kwargs["image_std"] = image_std
191
+ kwargs["data_format"] = data_format
192
+
193
+ # torch resize uses interpolation instead of resample
194
+ # Check if resample is an int before checking if it's an instance of PILImageResampling
195
+ # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
196
+ # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
197
+ resample = kwargs.pop("resample")
198
+ kwargs["interpolation"] = (
199
+ pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
200
+ )
201
+
202
+ return kwargs
203
+
204
+ @auto_docstring
205
+ def preprocess(
206
+ self,
207
+ images: ImageInput,
208
+ segmentation_maps: Optional[ImageInput] = None,
209
+ **kwargs: Unpack[SamFastImageProcessorKwargs],
210
+ ) -> BatchFeature:
211
+ r"""
212
+ segmentation_maps (`ImageInput`, *optional*):
213
+ The segmentation maps to preprocess.
214
+ """
215
+ return super().preprocess(images, segmentation_maps, **kwargs)
216
+
217
+ def _preprocess_image_like_inputs(
218
+ self,
219
+ images: ImageInput,
220
+ segmentation_maps: Optional[ImageInput],
221
+ do_convert_rgb: bool,
222
+ input_data_format: ChannelDimension,
223
+ device: Optional[Union[str, "torch.device"]] = None,
224
+ **kwargs: Unpack[SamFastImageProcessorKwargs],
225
+ ) -> BatchFeature:
226
+ """
227
+ Preprocess image-like inputs.
228
+ """
229
+ images = self._prepare_image_like_inputs(
230
+ images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
231
+ )
232
+ original_sizes = [image.shape[-2:] for image in images]
233
+ images_kwargs = kwargs.copy()
234
+ pixel_values = self._preprocess(images, **images_kwargs)
235
+ reshaped_input_sizes = [image.shape[-2:] for image in images]
236
+ data = {
237
+ "pixel_values": pixel_values,
238
+ "original_sizes": original_sizes,
239
+ "reshaped_input_sizes": reshaped_input_sizes,
240
+ }
241
+
242
+ if segmentation_maps is not None:
243
+ processed_segmentation_maps = self._prepare_image_like_inputs(
244
+ images=segmentation_maps,
245
+ expected_ndims=2,
246
+ do_convert_rgb=False,
247
+ input_data_format=ChannelDimension.FIRST,
248
+ )
249
+
250
+ segmentation_maps_kwargs = kwargs.copy()
251
+ segmentation_maps_kwargs.update(
252
+ {
253
+ "do_normalize": False,
254
+ "do_rescale": False,
255
+ "interpolation": F_t.InterpolationMode.NEAREST_EXACT
256
+ if is_torchvision_v2_available()
257
+ else F_t.InterpolationMode.NEAREST,
258
+ "size": segmentation_maps_kwargs.pop("mask_size"),
259
+ "pad_size": segmentation_maps_kwargs.pop("mask_pad_size"),
260
+ }
261
+ )
262
+ processed_segmentation_maps = self._preprocess(
263
+ images=processed_segmentation_maps, **segmentation_maps_kwargs
264
+ )
265
+ data["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64)
266
+
267
+ return BatchFeature(data=data, tensor_type=kwargs["return_tensors"])
268
+
269
+ def _preprocess(
270
+ self,
271
+ images: list["torch.Tensor"],
272
+ do_resize: bool,
273
+ size: SizeDict,
274
+ interpolation: Optional["F_t.InterpolationMode"],
275
+ do_rescale: bool,
276
+ rescale_factor: float,
277
+ do_normalize: bool,
278
+ image_mean: Optional[Union[float, list[float]]],
279
+ image_std: Optional[Union[float, list[float]]],
280
+ do_pad: bool,
281
+ pad_size: SizeDict,
282
+ disable_grouping: Optional[bool],
283
+ return_tensors: Optional[Union[str, TensorType]],
284
+ **kwargs,
285
+ ) -> Union["torch.Tensor", list["torch.Tensor"]]:
286
+ # Group images by size for batched resizing
287
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
288
+ resized_images_grouped = {}
289
+ for shape, stacked_images in grouped_images.items():
290
+ if do_resize:
291
+ stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
292
+ resized_images_grouped[shape] = stacked_images
293
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
294
+
295
+ # Group images by size for further processing
296
+ # Needed in case do_resize is False, or resize returns images with different sizes
297
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
298
+ processed_images_grouped = {}
299
+ for shape, stacked_images in grouped_images.items():
300
+ # Fused rescale and normalize
301
+ stacked_images = self.rescale_and_normalize(
302
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
303
+ )
304
+ if do_pad:
305
+ stacked_images = self.pad_image(stacked_images, pad_size)
306
+ processed_images_grouped[shape] = stacked_images
307
+
308
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
309
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
310
+
311
+ return processed_images
312
+
313
+ def generate_crop_boxes(
314
+ self,
315
+ image: "torch.Tensor",
316
+ target_size,
317
+ crop_n_layers: int = 0,
318
+ overlap_ratio: float = 512 / 1500,
319
+ points_per_crop: Optional[int] = 32,
320
+ crop_n_points_downscale_factor: Optional[list[int]] = 1,
321
+ device: Optional["torch.device"] = None,
322
+ ):
323
+ """
324
+ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
325
+
326
+ Args:
327
+ image (`torch.Tensor`):
328
+ Input original image
329
+ target_size (`int`):
330
+ Target size of the resized image
331
+ crop_n_layers (`int`, *optional*, defaults to 0):
332
+ If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
333
+ each layer has 2**i_layer number of image crops.
334
+ overlap_ratio (`float`, *optional*, defaults to 512/1500):
335
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
336
+ the image length. Later layers with more crops scale down this overlap.
337
+ points_per_crop (`int`, *optional*, defaults to 32):
338
+ Number of points to sample from each crop.
339
+ crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1):
340
+ The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
341
+ device (`torch.device`, *optional*, defaults to None):
342
+ Device to use for the computation. If None, cpu will be used.
343
+ input_data_format (`str` or `ChannelDimension`, *optional*):
344
+ The channel dimension format of the input image. If not provided, it will be inferred.
345
+ return_tensors (`str`, *optional*, defaults to `pt`):
346
+ If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
347
+ """
348
+ image = self._process_image(image)
349
+ crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
350
+ image,
351
+ target_size,
352
+ crop_n_layers,
353
+ overlap_ratio,
354
+ points_per_crop,
355
+ crop_n_points_downscale_factor,
356
+ )
357
+ if device is None:
358
+ device = torch.device("cpu")
359
+ crop_boxes = crop_boxes.to(device)
360
+ points_per_crop = points_per_crop.to(device)
361
+ # cropped_images stays as torch.Tensor
362
+ input_labels = input_labels.to(device)
363
+
364
+ return crop_boxes, points_per_crop, cropped_images, input_labels
365
+
366
+ def filter_masks(
367
+ self,
368
+ masks,
369
+ iou_scores,
370
+ original_size,
371
+ cropped_box_image,
372
+ pred_iou_thresh=0.88,
373
+ stability_score_thresh=0.95,
374
+ mask_threshold=0,
375
+ stability_score_offset=1,
376
+ ):
377
+ """
378
+ Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
379
+ that the iou scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
380
+ score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
381
+ bounding boxes and pads the predicted masks if necessary.
382
+
383
+ Args:
384
+ masks (`torch.Tensor`):
385
+ Input masks.
386
+ iou_scores (`torch.Tensor`):
387
+ List of IoU scores.
388
+ original_size (`tuple[int,int]`):
389
+ Size of the original image.
390
+ cropped_box_image (`torch.Tensor`):
391
+ The cropped image.
392
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
393
+ The threshold for the iou scores.
394
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
395
+ The threshold for the stability score.
396
+ mask_threshold (`float`, *optional*, defaults to 0):
397
+ The threshold for the predicted masks.
398
+ stability_score_offset (`float`, *optional*, defaults to 1):
399
+ The offset for the stability score used in the `_compute_stability_score` method.
400
+
401
+ """
402
+ original_height, original_width = original_size
403
+ iou_scores = iou_scores.flatten(0, 1)
404
+ masks = masks.flatten(0, 1)
405
+
406
+ if masks.shape[0] != iou_scores.shape[0]:
407
+ raise ValueError("masks and iou_scores must have the same batch size.")
408
+
409
+ if masks.device != iou_scores.device:
410
+ iou_scores = iou_scores.to(masks.device)
411
+
412
+ batch_size = masks.shape[0]
413
+
414
+ keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
415
+
416
+ if pred_iou_thresh > 0.0:
417
+ keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
418
+
419
+ # compute stability score
420
+ if stability_score_thresh > 0.0:
421
+ stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
422
+ keep_mask = keep_mask & (stability_scores > stability_score_thresh)
423
+
424
+ scores = iou_scores[keep_mask]
425
+ masks = masks[keep_mask]
426
+
427
+ # binarize masks
428
+ masks = masks > mask_threshold
429
+ converted_boxes = _batched_mask_to_box(masks)
430
+
431
+ keep_mask = ~_is_box_near_crop_edge(
432
+ converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
433
+ )
434
+
435
+ scores = scores[keep_mask]
436
+ masks = masks[keep_mask]
437
+ converted_boxes = converted_boxes[keep_mask]
438
+
439
+ masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
440
+ # conversion to rle is necessary to run non-maximum suppression
441
+ masks = _mask_to_rle(masks)
442
+
443
+ return masks, scores, converted_boxes
444
+
445
+ def post_process_masks(
446
+ self,
447
+ masks,
448
+ original_sizes,
449
+ reshaped_input_sizes,
450
+ mask_threshold=0.0,
451
+ binarize=True,
452
+ pad_size=None,
453
+ ):
454
+ """
455
+ Remove padding and upscale masks to the original image size.
456
+
457
+ Args:
458
+ masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
459
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
460
+ original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
461
+ The original sizes of each image before it was resized to the model's expected input shape, in (height,
462
+ width) format.
463
+ reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
464
+ The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
465
+ mask_threshold (`float`, *optional*, defaults to 0.0):
466
+ The threshold to use for binarizing the masks.
467
+ binarize (`bool`, *optional*, defaults to `True`):
468
+ Whether to binarize the masks.
469
+ pad_size (`int`, *optional*, defaults to `self.pad_size`):
470
+ The target size the images were padded to before being passed to the model. If None, the target size is
471
+ assumed to be the processor's `pad_size`.
472
+ Returns:
473
+ (`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
474
+ is given by original_size.
475
+ """
476
+ pad_size = self.size if pad_size is None else pad_size
477
+ target_image_size = (pad_size["height"], pad_size["width"])
478
+ if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
479
+ original_sizes = original_sizes.tolist()
480
+ if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
481
+ reshaped_input_sizes = reshaped_input_sizes.tolist()
482
+
483
+ output_masks = []
484
+ for i, original_size in enumerate(original_sizes):
485
+ if isinstance(masks[i], np.ndarray):
486
+ masks[i] = torch.from_numpy(masks[i])
487
+ elif not isinstance(masks[i], torch.Tensor):
488
+ raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
489
+ interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
490
+ interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
491
+ interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
492
+ if binarize:
493
+ interpolated_mask = interpolated_mask > mask_threshold
494
+ output_masks.append(interpolated_mask)
495
+
496
+ return output_masks
497
+
498
+ def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
499
+ """
500
+ Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.
501
+
502
+ Args:
503
+ all_masks (`torch.Tensor`):
504
+ List of all predicted segmentation masks
505
+ all_scores (`torch.Tensor`):
506
+ List of all predicted iou scores
507
+ all_boxes (`torch.Tensor`):
508
+ List of all bounding boxes of the predicted masks
509
+ crops_nms_thresh (`float`):
510
+ Threshold for NMS (Non Maximum Suppression) algorithm.
511
+ """
512
+ return _post_process_for_mask_generation(all_masks, all_scores, all_boxes, crops_nms_thresh)
513
+
514
+
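A minimal end-to-end usage sketch of `SamImageProcessorFast` (illustrative, not part of the file; it assumes the public `facebook/sam-vit-base` checkpoint and a local image path, both of which are placeholders here):

import torch
from PIL import Image
from transformers import SamModel, SamImageProcessorFast

processor = SamImageProcessorFast()
model = SamModel.from_pretrained("facebook/sam-vit-base")

raw_image = Image.open("example.png").convert("RGB")   # hypothetical local image
inputs = processor(images=raw_image, return_tensors="pt")

with torch.no_grad():
    outputs = model(pixel_values=inputs["pixel_values"])

# Remove padding and upscale the low-resolution masks back to the original image size.
masks = processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"], inputs["reshaped_input_sizes"]
)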
515
+ def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
516
+ # One mask is always contained inside the other.
517
+ # Save memory by preventing unnecessary cast to torch.int64
518
+ intersections = (
519
+ (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
520
+ )
521
+ unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
522
+ stability_scores = intersections / unions
523
+ return stability_scores
524
+
525
+
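Side note (illustrative): the stability score above is the IoU between the mask binarized at `mask_threshold + offset` and at `mask_threshold - offset`, so logits far from the threshold give scores close to 1. A tiny worked example:

import torch

logits = torch.tensor([[[-2.0, -0.5, 0.5, 2.0]]])   # hypothetical 1x1x4 mask logits
score = _compute_stability_score(logits, mask_threshold=0.0, stability_score_offset=1.0)
# intersection: logits > 1.0 -> 1 pixel; union: logits > -1.0 -> 3 pixels; score = 1/3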
526
+ def _mask_to_rle(input_mask: "torch.Tensor"):
527
+ """
528
+ Encodes masks to run-length encoding (RLE), in the format expected by pycocotools.
529
+ """
530
+ # Put in fortran order and flatten height and width
531
+ batch_size, height, width = input_mask.shape
532
+ input_mask = input_mask.permute(0, 2, 1).flatten(1)
533
+
534
+ # Compute change indices
535
+ diff = input_mask[:, 1:] ^ input_mask[:, :-1]
536
+ change_indices = diff.nonzero()
537
+
538
+ # Encode run length
539
+ out = []
540
+ for i in range(batch_size):
541
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
542
+ if len(cur_idxs) == 0:
543
+ # No changes => either all 0 or all 1
544
+ # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
545
+ if input_mask[i, 0] == 0:
546
+ out.append({"size": [height, width], "counts": [height * width]})
547
+ else:
548
+ out.append({"size": [height, width], "counts": [0, height * width]})
549
+ continue
550
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
551
+ counts = [] if input_mask[i, 0] == 0 else [0]
552
+ counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
553
+ out.append({"size": [height, width], "counts": counts})
554
+ return out
555
+
556
+
557
+ def _batched_mask_to_box(masks: "torch.Tensor"):
558
+ """
559
+ Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
560
+ corresponds to the following required indices:
561
+ - LEFT: left hand side of the bounding box
562
+ - TOP: top of the bounding box
563
+ - RIGHT: right of the bounding box
564
+ - BOTTOM: bottom of the bounding box
565
+
566
+ Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
567
+ is channel_1 x channel_2 x ... x 4.
568
+
569
+ Args:
570
+ - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
571
+ """
572
+ # torch.max below raises an error on empty inputs, just skip in this case
573
+
574
+ if torch.numel(masks) == 0:
575
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
576
+
577
+ # Normalize shape to Cxheightxwidth
578
+ shape = masks.shape
579
+ height, width = shape[-2:]
580
+
581
+ # Get top and bottom edges
582
+ in_height, _ = torch.max(masks, dim=-1)
583
+ in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
584
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
585
+ in_height_coords = in_height_coords + height * (~in_height)
586
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
587
+
588
+ # Get left and right edges
589
+ in_width, _ = torch.max(masks, dim=-2)
590
+ in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
591
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
592
+ in_width_coords = in_width_coords + width * (~in_width)
593
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
594
+
595
+ # If the mask is empty the right edge will be to the left of the left edge.
596
+ # Replace these boxes with [0, 0, 0, 0]
597
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
598
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
599
+ out = out * (~empty_filter).unsqueeze(-1)
600
+
601
+ # Return to original shape
602
+ out = out.reshape(*shape[:-2], 4)
603
+ return out
604
+
605
+
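Quick illustration (not part of the file): a 4x5 mask with ones in rows 1-2 and columns 2-3 yields the XYXY box [2, 1, 3, 2]:

import torch

mask = torch.zeros(1, 4, 5, dtype=torch.bool)
mask[0, 1:3, 2:4] = True
box = _batched_mask_to_box(mask)   # tensor([[2, 1, 3, 2]])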
606
+ def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
607
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
608
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
609
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
610
+
611
+ left, top, _, _ = crop_box
612
+ offset = torch.tensor([[left, top, left, top]], device=boxes.device)
613
+ # Check if boxes has a channel dimension
614
+ if len(boxes.shape) == 3:
615
+ offset = offset.unsqueeze(1)
616
+ boxes = (boxes + offset).float()
617
+
618
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
619
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
620
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
621
+ return torch.any(near_crop_edge, dim=1)
622
+
623
+
624
+ def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int):
625
+ left, top, right, bottom = crop_box
626
+ if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
627
+ return masks
628
+ # Coordinate transform masks
629
+ pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
630
+ pad = (left, pad_x - left, top, pad_y - top)
631
+ return torch.nn.functional.pad(masks, pad, value=0)
632
+
633
+
634
+ def _generate_crop_boxes(
635
+ image,
636
+ target_size: int, # Is it tuple here?
637
+ crop_n_layers: int = 0,
638
+ overlap_ratio: float = 512 / 1500,
639
+ points_per_crop: Optional[int] = 32,
640
+ crop_n_points_downscale_factor: Optional[list[int]] = 1,
641
+ ) -> tuple[list[list[int]], list[int]]:
642
+ """
643
+ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
644
+
645
+ Args:
646
+ image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
647
+ Image to generate crops for.
648
+ target_size (`int`):
649
+ Size of the smallest crop.
650
+ crop_n_layers (`int`, *optional*):
651
+ If `crop_n_layers > 0`, mask prediction will be run again on crops of the image. Sets the number of layers
652
+ to run, where each layer has 2**i_layer number of image crops.
653
+ overlap_ratio (`float`, *optional*):
654
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
655
+ image length. Later layers with more crops scale down this overlap.
656
+ points_per_crop (`int`, *optional*):
657
+ Number of points to sample per crop.
658
+ crop_n_points_downscale_factor (`int`, *optional*):
659
+ The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
660
+ input_data_format (`str` or `ChannelDimension`, *optional*):
661
+ The channel dimension format of the input image. If not provided, it will be inferred.
662
+ """
663
+
664
+ if isinstance(image, list):
665
+ raise ValueError("Only one image is allowed for crop generation.")
666
+ original_size = image.shape[-2:]
667
+
668
+ points_grid = []
669
+ for i in range(crop_n_layers + 1):
670
+ n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
671
+ points_grid.append(_build_point_grid(n_points))
672
+
673
+ crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
674
+
675
+ cropped_images, point_grid_per_crop = _generate_crop_images(
676
+ crop_boxes, image, points_grid, layer_idxs, target_size, original_size
677
+ )
678
+ crop_boxes = torch.tensor(crop_boxes)
679
+ crop_boxes = crop_boxes.float()
680
+ points_per_crop = torch.stack(point_grid_per_crop)
681
+ points_per_crop = points_per_crop.unsqueeze(0).permute(0, 2, 1, 3)
682
+ cropped_images = torch.stack(cropped_images)
683
+
684
+ input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.int64)
685
+
686
+ return crop_boxes, points_per_crop, cropped_images, input_labels
687
+
688
+
689
+ def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
690
+ """
691
+ Generates 2 ** (layers idx + 1) crops for each crop_n_layers. Crops are in the XYWH format : The XYWH format
692
+ consists of the following required indices:
693
+ - X: X coordinate of the top left of the bounding box
694
+ - Y: Y coordinate of the top left of the bounding box
695
+ - W: width of the bounding box
696
+ - H: height of the bounding box
697
+ """
698
+ crop_boxes, layer_idxs = [], []
699
+ im_height, im_width = original_size
700
+ short_side = min(im_height, im_width)
701
+
702
+ # Original image
703
+ crop_boxes.append([0, 0, im_width, im_height])
704
+ layer_idxs.append(0)
705
+ for i_layer in range(crop_n_layers):
706
+ n_crops_per_side = 2 ** (i_layer + 1)
707
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
708
+
709
+ crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
710
+ crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
711
+
712
+ crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
713
+ crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
714
+
715
+ for left, top in product(crop_box_x0, crop_box_y0):
716
+ box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
717
+ crop_boxes.append(box)
718
+ layer_idxs.append(i_layer + 1)
719
+
720
+ return crop_boxes, layer_idxs
721
+
722
+
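For intuition (illustrative only): with a single crop layer and a 100x100 image, the helper returns the full-image box plus a 2x2 grid of overlapping 60x60 crops:

boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers=1, overlap_ratio=0.2, original_size=(100, 100))
# boxes[0] == [0, 0, 100, 100]; boxes[1:] are four 60x60 crops; layer_idxs == [0, 1, 1, 1, 1]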
723
+ def _build_point_grid(n_per_side: int) -> torch.Tensor:
724
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
725
+ offset = 1 / (2 * n_per_side)
726
+ points_one_side = torch.linspace(offset, 1 - offset, n_per_side)
727
+ points_x = torch.tile(points_one_side[None, :], (n_per_side, 1))
728
+ points_y = torch.tile(points_one_side[:, None], (1, n_per_side))
729
+ points = torch.stack([points_x, points_y], dim=-1).reshape(-1, 2)
730
+ return points
731
+
732
+
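A quick check of the grid above (illustrative): `n_per_side=2` yields four evenly spaced points at 0.25 and 0.75 along each axis:

grid = _build_point_grid(2)
# tensor([[0.25, 0.25], [0.75, 0.25], [0.25, 0.75], [0.75, 0.75]]), shape (4, 2)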
733
+ def _generate_crop_images(
734
+ crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
735
+ ):
736
+ """
737
+ Takes as input bounding boxes that are used to crop the image. Based on the crops, the corresponding points are
738
+ also passed.
739
+ """
740
+ cropped_images = []
741
+ total_points_per_crop = []
742
+ for i, crop_box in enumerate(crop_boxes):
743
+ left, top, right, bottom = crop_box
744
+ cropped_im = image[:, top:bottom, left:right]
745
+
746
+ cropped_images.append(cropped_im)
747
+
748
+ cropped_im_size = cropped_im.shape[-2:]
749
+ points_scale = torch.tensor(cropped_im_size).flip(dims=(0,)).unsqueeze(0)
750
+
751
+ points = points_grid[layer_idxs[i]] * points_scale
752
+ normalized_points = _normalize_coordinates(target_size, points, original_size)
753
+ total_points_per_crop.append(normalized_points)
754
+
755
+ return cropped_images, total_points_per_crop
756
+
757
+
758
+ def _normalize_coordinates(
759
+ target_size: int, coords: torch.Tensor, original_size: tuple[int, int], is_bounding_box=False
760
+ ) -> torch.Tensor:
761
+ """
762
+ Expects a tensor of length 2 in the final dimension. Requires the original image size in (height, width)
763
+ format.
764
+ """
765
+ old_height, old_width = original_size
766
+
767
+ scale = target_size * 1.0 / max(old_height, old_width)
768
+ new_height, new_width = old_height * scale, old_width * scale
769
+ new_width = int(new_width + 0.5)
770
+ new_height = int(new_height + 0.5)
771
+
772
+ coords = deepcopy(coords).float()
773
+
774
+ if is_bounding_box:
775
+ coords = coords.reshape(-1, 2, 2)
776
+
777
+ coords[..., 0] = coords[..., 0] * (new_width / old_width)
778
+ coords[..., 1] = coords[..., 1] * (new_height / old_height)
779
+
780
+ if is_bounding_box:
781
+ coords = coords.reshape(-1, 4)
782
+
783
+ return coords
784
+
785
+
786
+ def _rle_to_mask(rle: dict[str, Any]) -> torch.Tensor:
787
+ """Compute a binary mask from an uncompressed RLE."""
788
+ height, width = rle["size"]
789
+ mask = torch.empty(height * width, dtype=bool)
790
+ idx = 0
791
+ parity = False
792
+ for count in rle["counts"]:
793
+ mask[idx : idx + count] = parity
794
+ idx += count
795
+ parity = not parity
796
+ mask = mask.reshape(width, height)
797
+ return mask.transpose(0, 1) # Reshape to original shape
798
+
799
+
800
+ def _post_process_for_mask_generation(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
801
+ """
802
+ Perform NMS (Non Maximum Suppression) on the outputs.
803
+
804
+ Args:
805
+ rle_masks (`torch.Tensor`):
806
+ binary masks in the RLE format
807
+ iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
808
+ iou_scores predicted by the model
809
+ mask_boxes (`torch.Tensor`):
810
+ The bounding boxes corresponding to segmentation masks
811
+ amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
812
+ NMS threshold.
813
+ """
814
+ keep_by_nms = batched_nms(
815
+ boxes=mask_boxes.float(),
816
+ scores=iou_scores,
817
+ idxs=torch.zeros(mask_boxes.shape[0]),
818
+ iou_threshold=amg_crops_nms_thresh,
819
+ )
820
+
821
+ iou_scores = iou_scores[keep_by_nms]
822
+ rle_masks = [rle_masks[i] for i in keep_by_nms]
823
+ mask_boxes = mask_boxes[keep_by_nms]
824
+ masks = [_rle_to_mask(rle) for rle in rle_masks]
825
+
826
+ return masks, iou_scores, rle_masks, mask_boxes
827
+
828
+
829
+ __all__ = ["SamImageProcessorFast"]
phivenv/Lib/site-packages/transformers/models/sam/modeling_sam.py ADDED
@@ -0,0 +1,1368 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch SAM model."""
16
+
17
+ import collections
18
+ from dataclasses import dataclass
19
+ from typing import Callable, Optional, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from torch import Tensor, nn
25
+
26
+ from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs
27
+
28
+ from ...activations import ACT2FN
29
+ from ...modeling_layers import GradientCheckpointingLayer
30
+ from ...modeling_outputs import BaseModelOutput
31
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
32
+ from ...processing_utils import Unpack
33
+ from ...utils import (
34
+ ModelOutput,
35
+ auto_docstring,
36
+ logging,
37
+ )
38
+ from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ @dataclass
45
+ @auto_docstring(
46
+ custom_intro="""
47
+ Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
48
+ layer to the pooler_output.
49
+ """
50
+ )
51
+ class SamVisionEncoderOutput(ModelOutput):
52
+ r"""
53
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
54
+ The image embeddings obtained by applying the projection layer to the pooler_output.
55
+ """
56
+
57
+ image_embeds: Optional[torch.FloatTensor] = None
58
+ last_hidden_state: Optional[torch.FloatTensor] = None
59
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
60
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
61
+
62
+
63
+ @dataclass
64
+ @auto_docstring(
65
+ custom_intro="""
66
+ Base class for Segment-Anything model's output
67
+ """
68
+ )
69
+ class SamImageSegmentationOutput(ModelOutput):
70
+ r"""
71
+ iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
72
+ The iou scores of the predicted masks.
73
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
74
+ The predicted low resolutions masks. Needs to be post-processed by the processor
75
+ vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
76
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
77
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
78
+
79
+ Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
80
+ vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
81
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
82
+ sequence_length)`.
83
+
84
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
85
+ heads.
86
+ mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
87
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
88
+ sequence_length)`.
89
+
90
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
91
+ heads.
92
+ """
93
+
94
+ iou_scores: Optional[torch.FloatTensor] = None
95
+ pred_masks: Optional[torch.FloatTensor] = None
96
+ vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
97
+ vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
98
+ mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
99
+
100
+
101
+ class SamPatchEmbeddings(nn.Module):
102
+ """
103
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
104
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
105
+ Transformer.
106
+ """
107
+
108
+ def __init__(self, config):
109
+ super().__init__()
110
+ image_size, patch_size = config.image_size, config.patch_size
111
+ num_channels, hidden_size = config.num_channels, config.hidden_size
112
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
113
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
114
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
115
+ self.image_size = image_size
116
+ self.patch_size = patch_size
117
+ self.num_channels = num_channels
118
+ self.num_patches = num_patches
119
+
120
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
121
+
122
+ def forward(self, pixel_values):
123
+ batch_size, num_channels, height, width = pixel_values.shape
124
+ if num_channels != self.num_channels:
125
+ raise ValueError(
126
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
127
+ )
128
+ if height != self.image_size[0] or width != self.image_size[1]:
129
+ raise ValueError(
130
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
131
+ )
132
+ embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
133
+ return embeddings
134
+
135
+
136
+ class SamMLPBlock(nn.Module):
137
+ def __init__(self, config):
138
+ super().__init__()
139
+ self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
140
+ self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
141
+ self.act = ACT2FN[config.hidden_act]
142
+
143
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
144
+ hidden_states = self.lin1(hidden_states)
145
+ hidden_states = self.act(hidden_states)
146
+ hidden_states = self.lin2(hidden_states)
147
+ return hidden_states
148
+
149
+
150
+ # Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
151
+ class SamLayerNorm(nn.LayerNorm):
152
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
153
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
154
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
155
+ """
156
+
157
+ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
158
+ super().__init__(normalized_shape, eps=eps, **kwargs)
159
+ if data_format not in ["channels_last", "channels_first"]:
160
+ raise NotImplementedError(f"Unsupported data format: {data_format}")
161
+ self.data_format = data_format
162
+
163
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
164
+ """
165
+ Args:
166
+ features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
167
+ """
168
+ if self.data_format == "channels_first":
169
+ features = features.permute(0, 2, 3, 1)
170
+ features = super().forward(features)
171
+ features = features.permute(0, 3, 1, 2)
172
+ else:
173
+ features = super().forward(features)
174
+ return features
175
+
176
+
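Illustrative only: in `channels_first` mode the layer permutes to channels-last, applies the standard LayerNorm over the channel dimension, and permutes back, so the input shape is preserved:

import torch

norm = SamLayerNorm(64, data_format="channels_first")
features = torch.randn(2, 64, 8, 8)   # (batch, channels, height, width)
assert norm(features).shape == (2, 64, 8, 8)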
177
+ def eager_attention_forward(
178
+ module: nn.Module,
179
+ query: torch.Tensor,
180
+ key: torch.Tensor,
181
+ value: torch.Tensor,
182
+ attention_mask: Optional[torch.Tensor],
183
+ scaling: float,
184
+ dropout: float = 0.0,
185
+ **kwargs,
186
+ ):
187
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
188
+ if attention_mask is not None:
189
+ attn_weights = attn_weights + attention_mask
190
+
191
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
192
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
193
+ attn_output = torch.matmul(attn_weights, value)
194
+ attn_output = attn_output.transpose(1, 2).contiguous()
195
+
196
+ return attn_output, attn_weights
197
+
198
+
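Shape sketch for the helper above (illustrative; the dimensions and the `_Dummy` module are made up): with 2 heads of size 32 over 16 tokens,

import torch
from torch import nn

class _Dummy(nn.Module):   # hypothetical stand-in; only its `training` flag is read for dropout
    pass

q = k = v = torch.randn(4, 2, 16, 32)   # (batch * point_batch, num_heads, tokens, head_dim)
out, weights = eager_attention_forward(_Dummy(), q, k, v, attention_mask=None, scaling=32**-0.5)
# out: (4, 16, 2, 32) after the internal transpose; weights: (4, 2, 16, 16)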
199
+ class SamAttention(nn.Module):
200
+ """
201
+ SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
202
+ values.
203
+ """
204
+
205
+ def __init__(self, config, downsample_rate=None):
206
+ super().__init__()
207
+ self.config = config
208
+ self.hidden_size = config.hidden_size
209
+
210
+ downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
211
+
212
+ self.internal_dim = config.hidden_size // downsample_rate
213
+ self.num_attention_heads = config.num_attention_heads
214
+ if self.internal_dim % config.num_attention_heads != 0:
215
+ raise ValueError("num_attention_heads must divide hidden_size.")
216
+ self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5
217
+
218
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
219
+ self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
220
+ self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
221
+ self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
222
+
223
+ self.is_causal = False
224
+
225
+ def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
226
+ batch, point_batch_size, n_tokens, channel = hidden_states.shape
227
+ c_per_head = channel // num_attention_heads
228
+ hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
229
+ return hidden_states.transpose(1, 2)
230
+
231
+ def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
232
+ batch, n_tokens, n_heads, c_per_head = hidden_states.shape
233
+ return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
234
+
235
+ def forward(
236
+ self,
237
+ query: Tensor,
238
+ key: Tensor,
239
+ value: Tensor,
240
+ attention_similarity: Optional[Tensor] = None,
241
+ **kwargs: Unpack[TransformersKwargs],
242
+ ) -> Tensor:
243
+ # Input projections
244
+ query = self.q_proj(query)
245
+ key = self.k_proj(key)
246
+ value = self.v_proj(value)
247
+
248
+ point_batch_size = query.shape[1]
249
+ # Separate into heads
250
+ query = self._separate_heads(query, self.num_attention_heads)
251
+ key = self._separate_heads(key, self.num_attention_heads)
252
+ value = self._separate_heads(value, self.num_attention_heads)
253
+
254
+ # SamAttention
255
+ attention_interface: Callable = eager_attention_forward
256
+ if self.config._attn_implementation != "eager":
257
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
258
+
259
+ attn_output, attn_weights = attention_interface(
260
+ self,
261
+ query,
262
+ key,
263
+ value,
264
+ attention_mask=attention_similarity,
265
+ dropout=0.0,
266
+ scaling=self.scaling,
267
+ is_causal=self.is_causal,
268
+ **kwargs,
269
+ )
270
+
271
+ attn_output = self._recombine_heads(attn_output, point_batch_size)
272
+ attn_output = self.out_proj(attn_output)
273
+
274
+ return attn_output, attn_weights
275
+
276
+
277
+ class SamTwoWayAttentionBlock(nn.Module):
278
+ def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
279
+ """
280
+ A transformer block with four layers:
281
+ (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
282
+ sparse inputs (4) cross attention of dense inputs -> sparse inputs
283
+
284
+ Arguments:
285
+ config (`SamMaskDecoderConfig`):
286
+ The configuration file used to instantiate the block
287
+ attention_downsample_rate (*optionalk*, int, defaults to 2):
288
+ The downsample ratio of the block used to reduce the inner dim of the attention.
289
+ skip_first_layer_pe (*optional*, bool, defaults to `False`):
290
+ Whether or not to skip the addition of the query_point_embedding on the first layer.
291
+ """
292
+ super().__init__()
293
+
294
+ self.hidden_size = config.hidden_size
295
+ self.layer_norm_eps = config.layer_norm_eps
296
+
297
+ self.self_attn = SamAttention(config, downsample_rate=1)
298
+ self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
299
+
300
+ self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
301
+ self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
302
+
303
+ self.mlp = SamMLPBlock(config)
304
+ self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
305
+
306
+ self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
307
+ self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
308
+ self.skip_first_layer_pe = skip_first_layer_pe
309
+
310
+ def forward(
311
+ self,
312
+ queries: Tensor,
313
+ keys: Tensor,
314
+ query_point_embedding: Tensor,
315
+ key_point_embedding: Tensor,
316
+ attention_similarity: Tensor,
317
+ **kwargs: Unpack[TransformersKwargs],
318
+ ):
319
+ # Self attention block
320
+ if self.skip_first_layer_pe:
321
+ queries, _ = self.self_attn(query=queries, key=queries, value=queries)
322
+ else:
323
+ query = queries + query_point_embedding
324
+ attn_out, _ = self.self_attn(query=query, key=query, value=queries)
325
+ queries = queries + attn_out
326
+ queries = self.layer_norm1(queries)
327
+
328
+ # Cross attention block, tokens attending to image embedding
329
+ query = queries + query_point_embedding
330
+ key = keys + key_point_embedding
331
+
332
+ attn_out, _ = self.cross_attn_token_to_image(
333
+ query=query, key=key, value=keys, attention_similarity=attention_similarity
334
+ )
335
+ queries = queries + attn_out
336
+
337
+ queries = self.layer_norm2(queries)
338
+
339
+ # MLP block
340
+ mlp_out = self.mlp(queries)
341
+ queries = queries + mlp_out
342
+ queries = self.layer_norm3(queries)
343
+
344
+ # Cross attention block, image embedding attending to tokens
345
+ query = queries + query_point_embedding
346
+ key = keys + key_point_embedding
347
+
348
+ attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
349
+ keys = keys + attn_out
350
+
351
+ keys = self.layer_norm4(keys)
352
+ return queries, keys, attn_out
353
+
354
+
355
+ class SamTwoWayTransformer(nn.Module):
356
+ def __init__(self, config: SamMaskDecoderConfig):
357
+ super().__init__()
358
+ self.config = config
359
+
360
+ self.num_hidden_layers = config.num_hidden_layers
361
+ self.layers = nn.ModuleList()
362
+
363
+ for i in range(self.num_hidden_layers):
364
+ self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
365
+
366
+ self.final_attn_token_to_image = SamAttention(config)
367
+ self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
368
+
369
+ def forward(
370
+ self,
371
+ point_embeddings: Tensor,
372
+ image_embeddings: Tensor,
373
+ image_positional_embeddings: Tensor,
374
+ attention_similarity: Tensor,
375
+ target_embedding=None,
376
+ **kwargs: Unpack[TransformersKwargs],
377
+ ) -> Union[tuple, BaseModelOutput]:
378
+ if image_embeddings is None:
379
+ raise ValueError("You have to specify an image_embedding")
380
+
381
+ image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
382
+ image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
383
+
384
+ # Prepare queries
385
+ queries = point_embeddings
386
+ keys = image_embeddings
387
+
388
+ # Apply transformer blocks and final layernorm
389
+ for layer in self.layers:
390
+ if target_embedding is not None:
391
+ queries += target_embedding
392
+
393
+ queries, keys, _ = layer(
394
+ queries=queries,
395
+ keys=keys,
396
+ query_point_embedding=point_embeddings,
397
+ key_point_embedding=image_positional_embeddings,
398
+ attention_similarity=attention_similarity,
399
+ **kwargs,
400
+ )
401
+ # Apply the final attention layer from the points to the image
402
+ query = queries + point_embeddings
403
+ key = keys + image_positional_embeddings
404
+
405
+ attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
406
+
407
+ queries = queries + attn_out
408
+ queries = self.layer_norm_final_attn(queries)
409
+ return queries, keys
410
+
411
+
412
+ class SamFeedForward(nn.Module):
413
+ def __init__(
414
+ self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
415
+ ):
416
+ super().__init__()
417
+ self.num_layers = num_layers
418
+ self.activation = nn.ReLU()
419
+ self.proj_in = nn.Linear(input_dim, hidden_dim)
420
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
421
+ self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
422
+ self.sigmoid_output = sigmoid_output
423
+
424
+ def forward(self, hidden_states):
425
+ hidden_states = self.proj_in(hidden_states)
426
+ hidden_states = self.activation(hidden_states)
427
+ for layer in self.layers:
428
+ hidden_states = self.activation(layer(hidden_states))
429
+
430
+ hidden_states = self.proj_out(hidden_states)
431
+ if self.sigmoid_output:
432
+ hidden_states = F.sigmoid(hidden_states)
433
+ return hidden_states
434
+
435
+
436
+ class SamMaskDecoder(nn.Module):
437
+ def __init__(self, config: SamMaskDecoderConfig):
438
+ super().__init__()
439
+ self.config = config
440
+ self.hidden_size = config.hidden_size
441
+
442
+ self.num_multimask_outputs = config.num_multimask_outputs
443
+ self.num_mask_tokens = config.num_multimask_outputs + 1
444
+
445
+ self.iou_token = nn.Embedding(1, self.hidden_size)
446
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
447
+
448
+ self.transformer = SamTwoWayTransformer(config)
449
+
450
+ # Upscaling stack for the image embeddings (kept inline rather than factored into its own module)
451
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
452
+ self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
453
+ self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
454
+ self.activation = nn.GELU()
455
+
456
+ mlps_list = []
457
+ for _ in range(self.num_mask_tokens):
458
+ mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
459
+ self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
460
+
461
+ self.iou_prediction_head = SamFeedForward(
462
+ self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
463
+ )
464
+
465
+ def forward(
466
+ self,
467
+ image_embeddings: torch.Tensor,
468
+ image_positional_embeddings: torch.Tensor,
469
+ sparse_prompt_embeddings: torch.Tensor,
470
+ dense_prompt_embeddings: torch.Tensor,
471
+ multimask_output: bool,
472
+ attention_similarity: Optional[torch.Tensor] = None,
473
+ target_embedding: Optional[torch.Tensor] = None,
474
+ ) -> tuple[torch.Tensor, torch.Tensor]:
475
+ """
476
+ Predict masks given image and prompt embeddings.
477
+
478
+ Args:
479
+ image_embeddings (`torch.Tensor`):
480
+ the embeddings from the image encoder
481
+ image_positional_embeddings (`torch.Tensor`):
482
+ positional encoding with the shape of image_embeddings
483
+ sparse_prompt_embeddings (`torch.Tensor`):
484
+ The embeddings of the points and boxes
485
+ dense_prompt_embeddings (`torch.Tensor`):
486
+ the embeddings of the mask inputs
487
+ multimask_output (bool):
488
+ Whether to return multiple masks or a single mask.
489
+ """
490
+ batch_size, num_channels, height, width = image_embeddings.shape
491
+ point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
492
+ # Concatenate output tokens
493
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
494
+ output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
495
+
496
+ if sparse_prompt_embeddings is not None:
497
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
498
+ else:
499
+ tokens = output_tokens
500
+ point_embeddings = tokens.to(self.iou_token.weight.dtype)
501
+
502
+ # Expand per-image data in batch direction to be per-point
503
+ image_embeddings = image_embeddings + dense_prompt_embeddings
504
+ image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
505
+ image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
506
+
507
+ # Run the transformer; the image_positional_embeddings are consumed here
508
+ point_embedding, image_embeddings = self.transformer(
509
+ point_embeddings=point_embeddings,
510
+ image_embeddings=image_embeddings,
511
+ image_positional_embeddings=image_positional_embeddings,
512
+ attention_similarity=attention_similarity,
513
+ target_embedding=target_embedding,
514
+ )
515
+ iou_token_out = point_embedding[:, :, 0, :]
516
+ mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
517
+
518
+ # Upscale mask embeddings and predict masks using the mask tokens
519
+ image_embeddings = image_embeddings.transpose(2, 3).reshape(
520
+ batch_size * point_batch_size, num_channels, height, width
521
+ )
522
+
523
+ upscaled_embedding = self.upscale_conv1(image_embeddings)
524
+ upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
525
+ upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
526
+
527
+ hyper_in_list = []
528
+ for i in range(self.num_mask_tokens):
529
+ current_mlp = self.output_hypernetworks_mlps[i]
530
+ hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
531
+ hyper_in = torch.stack(hyper_in_list, dim=2)
532
+
533
+ _, num_channels, height, width = upscaled_embedding.shape
534
+ upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
535
+ masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)
536
+
537
+ # Generate mask quality predictions
538
+ iou_pred = self.iou_prediction_head(iou_token_out)
539
+
540
+ # Select the correct mask or masks for output
541
+ if multimask_output:
542
+ mask_slice = slice(1, None)
543
+ else:
544
+ mask_slice = slice(0, 1)
545
+ masks = masks[:, :, mask_slice, :, :]
546
+ iou_pred = iou_pred[:, :, mask_slice]
547
+ return masks, iou_pred
548
+
549
+
550
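+ # Shape sketch (illustrative only, not upstream code): the final matmul in SamMaskDecoder.forward
+ # multiplies per-token hypernetwork weights with the flattened upscaled image features, yielding
+ # one low-resolution mask per mask token. All sizes below are assumptions.
+ def _example_hypernetwork_mask_shapes() -> torch.Size:
+     batch_size, point_batch_size, num_mask_tokens, channels, height, width = 1, 1, 4, 32, 256, 256
+     hyper_in = torch.randn(batch_size, point_batch_size, num_mask_tokens, channels)
+     upscaled_embedding = torch.randn(batch_size, point_batch_size, channels, height * width)
+     masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)
+     return masks.shape  # torch.Size([1, 1, 4, 256, 256])
+
+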
+ class SamPositionalEmbedding(nn.Module):
551
+ def __init__(self, config):
552
+ super().__init__()
553
+ self.scale = config.hidden_size // 2
554
+ self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
555
+
556
+ def forward(self, input_coords, input_shape=None):
557
+ """Positionally encode points that are normalized to [0,1]."""
558
+ coordinates = input_coords.clone()
559
+
560
+ if input_shape is not None:
561
+ coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
562
+ coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
563
+
564
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
565
+ coordinates = 2 * coordinates - 1
566
+ coordinates = coordinates.to(self.positional_embedding.dtype)
567
+ coordinates = coordinates @ self.positional_embedding
568
+ coordinates = 2 * np.pi * coordinates
569
+ # outputs d_1 x ... x d_n x channel shape
570
+ return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
571
+
572
+
573
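+ # Hedged sketch (not upstream code): SamPositionalEmbedding maps normalized (x, y) coordinates to
+ # a sin/cos random-Fourier encoding of width 2 * num_pos_feats. The SimpleNamespace below is a
+ # hypothetical stand-in for a SamVisionConfig and carries only the two fields this module reads.
+ def _example_positional_embedding_shape() -> torch.Size:
+     from types import SimpleNamespace
+     config = SimpleNamespace(hidden_size=256, num_pos_feats=128)
+     embed = SamPositionalEmbedding(config)
+     points = torch.rand(1, 1, 5, 2)  # (batch, point_batch, num_points, 2), already in [0, 1]
+     return embed(points).shape  # torch.Size([1, 1, 5, 256])
+
+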
+ class SamMaskEmbedding(nn.Module):
574
+ def __init__(self, config: SamPromptEncoderConfig):
575
+ super().__init__()
576
+ self.mask_input_channels = config.mask_input_channels // 4
577
+ self.activation = ACT2FN[config.hidden_act]
578
+ self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
579
+ self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
580
+ self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
581
+ self.layer_norm1 = SamLayerNorm(
582
+ self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
583
+ )
584
+ self.layer_norm2 = SamLayerNorm(
585
+ self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
586
+ )
587
+
588
+ def forward(self, masks):
589
+ hidden_states = self.conv1(masks)
590
+ hidden_states = self.layer_norm1(hidden_states)
591
+ hidden_states = self.activation(hidden_states)
592
+
593
+ hidden_states = self.conv2(hidden_states)
594
+ hidden_states = self.layer_norm2(hidden_states)
595
+ hidden_states = self.activation(hidden_states)
596
+ dense_embeddings = self.conv3(hidden_states)
597
+ return dense_embeddings
598
+
599
+
600
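+ # Hedged sketch (not upstream code): SamMaskEmbedding downscales a low-resolution input mask by
+ # 4x and projects it to hidden_size channels. The SimpleNamespace is a hypothetical stand-in for
+ # SamPromptEncoderConfig, carrying only the fields read in __init__; the sizes are assumptions.
+ def _example_mask_embedding_shape() -> torch.Size:
+     from types import SimpleNamespace
+     config = SimpleNamespace(mask_input_channels=16, hidden_act="gelu", hidden_size=256, layer_norm_eps=1e-6)
+     mask_embed = SamMaskEmbedding(config)
+     masks = torch.randn(1, 1, 256, 256)  # (batch, 1, 4 * embed_height, 4 * embed_width)
+     return mask_embed(masks).shape  # torch.Size([1, 256, 64, 64])
+
+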
+ class SamPromptEncoder(nn.Module):
601
+ def __init__(self, config: SamConfig):
602
+ super().__init__()
603
+ self.shared_embedding = SamPositionalEmbedding(config.vision_config)
604
+ config = config.prompt_encoder_config
605
+ self.mask_embed = SamMaskEmbedding(config)
606
+ self.no_mask_embed = nn.Embedding(1, config.hidden_size)
607
+
608
+ self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
609
+ self.input_image_size = config.image_size
610
+
611
+ self.point_embed = nn.ModuleList(
612
+ [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)]
613
+ )
614
+ self.hidden_size = config.hidden_size
615
+ self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
616
+
617
+ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
618
+ """Embeds point prompts."""
619
+ points = points + 0.5 # Shift to center of pixel
620
+ if pad:
621
+ target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
622
+ target_labels_shape = (points.shape[0], points.shape[1], 1)
623
+ padding_point = torch.zeros(target_point_shape, device=points.device)
624
+ padding_label = -torch.ones(target_labels_shape, device=labels.device)
625
+ points = torch.cat([points, padding_point], dim=2)
626
+ labels = torch.cat([labels, padding_label], dim=2)
627
+ input_shape = (self.input_image_size, self.input_image_size)
628
+ point_embedding = self.shared_embedding(points, input_shape)
629
+
630
+ # torch.where and expanding the labels tensor is required by the ONNX export
631
+ point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
632
+
633
+ # This is required for the ONNX export. The dtype, device need to be explicitly
634
+ # specified as otherwise torch.onnx.export interprets as double
635
+ point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding))
636
+
637
+ point_embedding = torch.where(
638
+ (labels == 0)[:, :, :, None],
639
+ point_embedding + self.point_embed[0].weight[None, None, :, :],
640
+ point_embedding,
641
+ )
642
+
643
+ point_embedding = torch.where(
644
+ (labels == 1)[:, :, :, None],
645
+ point_embedding + self.point_embed[1].weight[None, None, :, :],
646
+ point_embedding,
647
+ )
648
+
649
+ return point_embedding
650
+
651
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
652
+ """Embeds box prompts."""
653
+ boxes = boxes + 0.5 # Shift to center of pixel
654
+ batch_size, nb_boxes = boxes.shape[:2]
655
+ coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
656
+ input_shape = (self.input_image_size, self.input_image_size)
657
+ corner_embedding = self.shared_embedding(coords, input_shape)
658
+ corner_embedding[:, :, 0, :] += self.point_embed[2].weight
659
+ corner_embedding[:, :, 1, :] += self.point_embed[3].weight
660
+ return corner_embedding
661
+
662
+ def forward(
663
+ self,
664
+ input_points: Optional[tuple[torch.Tensor, torch.Tensor]],
665
+ input_labels: Optional[torch.Tensor],
666
+ input_boxes: Optional[torch.Tensor],
667
+ input_masks: Optional[torch.Tensor],
668
+ ) -> tuple[torch.Tensor, torch.Tensor]:
669
+ """
670
+ Embeds different types of prompts, returning both sparse and dense embeddings.
671
+
672
+ Args:
673
+ input_points (`torch.Tensor`, *optional*):
+ point coordinates to embed.
+ input_labels (`torch.Tensor`, *optional*):
+ labels for the points; required whenever `input_points` is provided.
+ input_boxes (`torch.Tensor`, *optional*):
+ boxes to embed.
+ input_masks (`torch.Tensor`, *optional*):
+ masks to embed.
679
+ """
680
+ sparse_embeddings = None
681
+ batch_size = 1
682
+ if input_points is not None:
683
+ batch_size = input_points.shape[0]
684
+ if input_labels is None:
685
+ raise ValueError("If points are provided, labels must also be provided.")
686
+ point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
687
+ sparse_embeddings = point_embeddings
688
+ if input_boxes is not None:
689
+ batch_size = input_boxes.shape[0]
690
+ box_embeddings = self._embed_boxes(input_boxes)
691
+ if sparse_embeddings is None:
692
+ sparse_embeddings = box_embeddings
693
+ else:
694
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
695
+ if input_masks is not None:
696
+ dense_embeddings = self.mask_embed(input_masks)
697
+ else:
698
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
699
+ batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
700
+ )
701
+
702
+ return sparse_embeddings, dense_embeddings
703
+
704
+
705
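+ # Hedged usage sketch, assuming a default `SamConfig()` is constructible in this module (as the
+ # type hints above suggest): a single positive point prompt yields sparse embeddings of shape
+ # (batch, point_batch, num_points + 1 padding point, hidden) and, with no mask input, dense
+ # embeddings broadcast from the learned no-mask embedding. The point coordinate is illustrative.
+ def _example_prompt_encoder_shapes() -> tuple[torch.Size, torch.Size]:
+     prompt_encoder = SamPromptEncoder(SamConfig())
+     input_points = torch.tensor([[[[512.0, 512.0]]]])  # (batch=1, point_batch=1, num_points=1, 2)
+     input_labels = torch.ones(1, 1, 1, dtype=torch.long)
+     sparse, dense = prompt_encoder(input_points, input_labels, input_boxes=None, input_masks=None)
+     return sparse.shape, dense.shape  # (1, 1, 2, 256) and (1, 256, 64, 64)
+
+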
+ class SamVisionAttention(nn.Module):
706
+ """Multi-head Attention block with relative position embeddings."""
707
+
708
+ def __init__(self, config, window_size):
709
+ super().__init__()
710
+ input_size = (
711
+ (config.image_size // config.patch_size, config.image_size // config.patch_size)
712
+ if window_size == 0
713
+ else (window_size, window_size)
714
+ )
715
+
716
+ self.num_attention_heads = config.num_attention_heads
717
+ head_dim = config.hidden_size // config.num_attention_heads
718
+ self.scale = head_dim**-0.5
719
+ self.dropout = config.attention_dropout
720
+
721
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
722
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size)
723
+
724
+ self.use_rel_pos = config.use_rel_pos
725
+ if self.use_rel_pos:
726
+ if input_size is None:
727
+ raise ValueError("Input size must be provided if using relative positional encoding.")
728
+
729
+ # initialize relative positional embeddings
730
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
731
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
732
+
733
+ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
734
+ """
735
+ Get relative positional embeddings according to the relative positions of
736
+ query and key sizes.
737
+
738
+ Args:
739
+ q_size (int):
740
+ size of the query.
741
+ k_size (int):
742
+ size of key k.
743
+ rel_pos (`torch.Tensor`):
744
+ relative position embeddings (L, channel).
745
+
746
+ Returns:
747
+ Extracted positional embeddings according to relative positions.
748
+ """
749
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
750
+ # Interpolate rel pos.
751
+ rel_pos_resized = F.interpolate(
752
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
753
+ size=max_rel_dist,
754
+ mode="linear",
755
+ )
756
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
757
+
758
+ # Scale the coords with short length if shapes for q and k are different.
759
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
760
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
761
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
762
+
763
+ return rel_pos_resized[relative_coords.long()]
764
+
765
+ def get_decomposed_rel_pos(
766
+ self,
767
+ query: torch.Tensor,
768
+ rel_pos_h: torch.Tensor,
769
+ rel_pos_w: torch.Tensor,
770
+ q_size: tuple[int, int],
771
+ k_size: tuple[int, int],
772
+ ) -> torch.Tensor:
773
+ """
774
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
775
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
776
+
777
+ Args:
778
+ query (`torch.Tensor`):
779
+ query q in the attention layer with shape (batch_size, query_height * query_width, channel).
780
+ rel_pos_h (`torch.Tensor`):
781
+ relative position embeddings (Lh, channel) for height axis.
782
+ rel_pos_w (`torch.Tensor`):
783
+ relative position embeddings (Lw, channel) for width axis.
784
+ q_size (tuple):
785
+ spatial sequence size of query q with (query_height, query_width).
786
+ k_size (tuple):
787
+ spatial sequence size of key k with (key_height, key_width).
788
+
789
+ Returns:
790
+ decomposed_rel_pos (`torch.Tensor`):
791
+ decomposed relative position embeddings.
792
+ """
793
+ query_height, query_width = q_size
794
+ key_height, key_width = k_size
795
+ relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
796
+ relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
797
+
798
+ batch_size, _, dim = query.shape
799
+ reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
800
+ rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
801
+ rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
802
+
803
+ decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
804
+
805
+ return decomposed_rel_pos
806
+
807
+ def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]:
808
+ batch_size, height, width, _ = hidden_states.shape
809
+ # qkv with shape (3, batch_size, nHead, height * width, channel)
810
+ qkv = (
811
+ self.qkv(hidden_states)
812
+ .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
813
+ .permute(2, 0, 3, 1, 4)
814
+ )
815
+ # q, k, v with shape (batch_size * nHead, height * width, channel)
816
+ query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
817
+
818
+ attn_weights = (query * self.scale) @ key.transpose(-2, -1)
819
+
820
+ if self.use_rel_pos:
821
+ decomposed_rel_pos = self.get_decomposed_rel_pos(
822
+ query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
823
+ )
824
+ decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
825
+ attn_weights = attn_weights + decomposed_rel_pos
826
+
827
+ attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
828
+
829
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
830
+
831
+ attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
832
+ attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
833
+
834
+ attn_output = self.proj(attn_output)
835
+ return attn_output, attn_weights
836
+
837
+
838
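+ # Illustrative sketch (not upstream code) of the index arithmetic used in
+ # SamVisionAttention.get_rel_pos: for q_size == k_size == 3, every (query, key) pair maps to an
+ # offset in [0, 2 * 3 - 2], i.e. a row of the (2 * size - 1, head_dim) relative position table.
+ def _example_relative_coords(q_size: int = 3, k_size: int = 3) -> torch.Tensor:
+     q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+     k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+     relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+     return relative_coords  # tensor([[2., 1., 0.], [3., 2., 1.], [4., 3., 2.]])
+
+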
+ class SamVisionSdpaAttention(SamVisionAttention):
839
+ """
840
+ Multi-head Attention block with relative position embeddings.
841
+ Using SDPA instead of the default attention.
842
+ """
843
+
844
+ def __init__(self, config, window_size):
845
+ super().__init__(config, window_size)
846
+
847
+ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
848
+ if output_attentions:
849
+ logger.warning_once(
850
+ "`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
851
+ "`output_attentions=True`. Falling back to the manual attention implementation, but "
852
+ "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
853
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
854
+ )
855
+ return super().forward(
856
+ hidden_states=hidden_states,
857
+ output_attentions=output_attentions,
858
+ )
859
+
860
+ batch_size, height, width, _ = hidden_states.shape
861
+ # qkv with shape (3, B, nHead, H * W, C)
862
+ qkv = (
863
+ self.qkv(hidden_states)
864
+ .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
865
+ .permute(2, 0, 3, 1, 4)
866
+ )
867
+ # q, k, v with shape (B * nHead, H * W, C)
868
+ query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
869
+
870
+ attn_bias = None
871
+ if self.use_rel_pos:
872
+ decomposed_rel_pos = self.get_decomposed_rel_pos(
873
+ query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
874
+ )
875
+ decomposed_rel_pos = decomposed_rel_pos.reshape(
876
+ batch_size, self.num_attention_heads, height * width, height * width
877
+ )
878
+ attn_bias = decomposed_rel_pos
879
+
880
+ query = query.view(batch_size, self.num_attention_heads, height * width, -1)
881
+ key = key.view(batch_size, self.num_attention_heads, height * width, -1)
882
+ value = value.view(batch_size, self.num_attention_heads, height * width, -1)
883
+
884
+ attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
885
+
886
+ attn_output = (
887
+ attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
888
+ .permute(0, 2, 3, 1, 4)
889
+ .reshape(batch_size, height, width, -1)
890
+ )
891
+
892
+ attn_output = self.proj(attn_output)
893
+ return attn_output, None
894
+
895
+
896
+ SAM_VISION_ATTENTION_CLASSES = {
897
+ "eager": SamVisionAttention,
898
+ "sdpa": SamVisionSdpaAttention,
899
+ }
900
+
901
+
902
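+ # Hedged parity sketch (not upstream code): with dropout disabled and relative positions off, the
+ # eager and SDPA entries registered above should agree numerically once they share weights. The
+ # SimpleNamespace is a hypothetical stand-in for a tiny SamVisionConfig; all sizes are assumptions.
+ def _example_vision_attention_parity() -> bool:
+     from types import SimpleNamespace
+     config = SimpleNamespace(
+         image_size=32, patch_size=16, hidden_size=8, num_attention_heads=2,
+         attention_dropout=0.0, qkv_bias=True, use_rel_pos=False,
+     )
+     eager = SAM_VISION_ATTENTION_CLASSES["eager"](config, window_size=0).eval()
+     sdpa = SAM_VISION_ATTENTION_CLASSES["sdpa"](config, window_size=0).eval()
+     sdpa.load_state_dict(eager.state_dict())
+     hidden_states = torch.randn(1, 2, 2, 8)  # (batch, height, width, hidden_size)
+     out_eager, _ = eager(hidden_states)
+     out_sdpa, _ = sdpa(hidden_states)
+     return torch.allclose(out_eager, out_sdpa, atol=1e-6)  # True
+
+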
+ class SamVisionLayer(GradientCheckpointingLayer):
903
+ def __init__(self, config, window_size):
904
+ super().__init__()
905
+ self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
906
+ self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
907
+ self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
908
+ self.mlp = SamMLPBlock(config)
909
+ self.window_size = window_size
910
+
911
+ def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]:
912
+ """
913
+ Partition into non-overlapping windows, padding the input if needed.
+
+ Args:
+ hidden_states (`torch.Tensor`): input tokens with [batch_size, height, width, channel].
+ window_size (`int`): window size.
917
+
918
+ Returns:
919
+ windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
920
+ (pad_height, pad_width): padded height and width before partition
921
+ """
922
+ batch_size, height, width, channel = hidden_states.shape
923
+
924
+ pad_h = (window_size - height % window_size) % window_size
925
+ pad_w = (window_size - width % window_size) % window_size
926
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
927
+ pad_height, pad_width = height + pad_h, width + pad_w
928
+
929
+ hidden_states = hidden_states.reshape(
930
+ batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
931
+ )
932
+ windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
933
+ return windows, (pad_height, pad_width)
934
+
935
+ def window_unpartition(
936
+ self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int]
937
+ ) -> torch.Tensor:
938
+ """
939
+ Reverse the window partition, restoring the original sequences and removing padding.
+
+ Args:
+ windows (`torch.Tensor`):
+ input tokens with [batch_size * num_windows, window_size, window_size, channel].
+ window_size (`int`):
+ window size.
+ padding_shape (`tuple`):
+ padded height and width (pad_height, pad_width).
+ original_shape (`tuple`):
+ original height and width (height, width) before padding.
948
+
949
+ Returns:
950
+ hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
951
+ """
952
+ pad_height, pad_width = padding_shape
953
+ height, width = original_shape
954
+ batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
955
+ hidden_states = windows.reshape(
956
+ batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
957
+ )
958
+ hidden_states = (
959
+ hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
960
+ )
961
+
962
+ hidden_states = hidden_states[:, :height, :width, :].contiguous()
963
+ return hidden_states
964
+
965
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.FloatTensor]:
966
+ residual = hidden_states
967
+ hidden_states = self.layer_norm1(hidden_states)
968
+ # Window partition
969
+ if self.window_size > 0:
970
+ height, width = hidden_states.shape[1], hidden_states.shape[2]
971
+ hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
972
+
973
+ hidden_states, attn_weights = self.attn(
974
+ hidden_states=hidden_states,
975
+ )
976
+ # Reverse window partition
977
+ if self.window_size > 0:
978
+ hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
979
+
980
+ hidden_states = residual + hidden_states
981
+ layernorm_output = self.layer_norm2(hidden_states)
982
+ hidden_states = hidden_states + self.mlp(layernorm_output)
983
+ return hidden_states
984
+
985
+
986
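+ # Hedged round-trip sketch (not upstream code): window_partition followed by window_unpartition
+ # recovers the original tensor. Neither method reads `self`, so they are invoked unbound here on a
+ # toy tensor whose height/width are deliberately not multiples of the window size.
+ def _example_window_partition_roundtrip() -> bool:
+     hidden_states = torch.randn(1, 5, 7, 4)  # (batch, height, width, channel)
+     window_size = 3
+     windows, padding_shape = SamVisionLayer.window_partition(None, hidden_states, window_size)
+     # 2 x 3 padded windows, each of shape (window_size, window_size, channel)
+     restored = SamVisionLayer.window_unpartition(None, windows, window_size, padding_shape, (5, 7))
+     return windows.shape == (6, 3, 3, 4) and torch.equal(restored, hidden_states)  # True
+
+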
+ class SamVisionNeck(nn.Module):
987
+ def __init__(self, config: SamVisionConfig):
988
+ super().__init__()
989
+ self.config = config
990
+
991
+ self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
992
+ self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")
993
+ self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
994
+ self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first")
995
+
996
+ def forward(self, hidden_states):
997
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
998
+ hidden_states = self.conv1(hidden_states)
999
+ hidden_states = self.layer_norm1(hidden_states)
1000
+
1001
+ hidden_states = self.conv2(hidden_states)
1002
+ hidden_states = self.layer_norm2(hidden_states)
1003
+ return hidden_states
1004
+
1005
+
1006
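+ # Hedged usage sketch, assuming the default SamVisionConfig is constructible in this module: the
+ # neck projects the ViT's channels-last features to `output_channels` and returns them
+ # channels-first for the mask decoder. The 64 x 64 feature grid is an assumption.
+ def _example_vision_neck_shape() -> torch.Size:
+     config = SamVisionConfig()
+     neck = SamVisionNeck(config)
+     features = torch.randn(1, 64, 64, config.hidden_size)  # (batch, height, width, hidden_size)
+     return neck(features).shape  # torch.Size([1, config.output_channels, 64, 64])
+
+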
+ @auto_docstring
1007
+ class SamPreTrainedModel(PreTrainedModel):
1008
+ config: SamConfig
1009
+ base_model_prefix = "sam"
1010
+ main_input_name = "pixel_values"
1011
+ _no_split_modules = ["SamVisionAttention"]
1012
+ supports_gradient_checkpointing = True
1013
+ _supports_sdpa = True
1014
+
1015
+ def _init_weights(self, module: nn.Module):
1016
+ super()._init_weights(module)
1017
+ if isinstance(module, SamVisionAttention):
1018
+ if module.use_rel_pos:
1019
+ module.rel_pos_h.data.zero_()
1020
+ module.rel_pos_w.data.zero_()
1021
+ elif isinstance(module, SamVisionEncoder):
1022
+ if self.config.use_abs_pos:
1023
+ module.pos_embed.data.zero_()
1024
+
1025
+
1026
+ class SamVisionEncoder(SamPreTrainedModel):
1027
+ _can_record_outputs = {"hidden_states": SamVisionLayer, "attentions": SamVisionAttention}
1028
+
1029
+ def __init__(self, config: SamVisionConfig):
1030
+ super().__init__(config)
1031
+ self.config = config
1032
+ self.image_size = config.image_size
1033
+ self.patch_embed = SamPatchEmbeddings(config)
1034
+
1035
+ self.pos_embed = None
1036
+ if config.use_abs_pos:
1037
+ # Initialize absolute positional embedding with pretrain image size.
1038
+ self.pos_embed = nn.Parameter(
1039
+ torch.zeros(
1040
+ 1,
1041
+ config.image_size // config.patch_size,
1042
+ config.image_size // config.patch_size,
1043
+ config.hidden_size,
1044
+ )
1045
+ )
1046
+
1047
+ self.layers = nn.ModuleList()
1048
+ for i in range(config.num_hidden_layers):
1049
+ layer = SamVisionLayer(
1050
+ config,
1051
+ window_size=config.window_size if i not in config.global_attn_indexes else 0,
1052
+ )
1053
+ self.layers.append(layer)
1054
+
1055
+ self.neck = SamVisionNeck(config)
1056
+
1057
+ self.gradient_checkpointing = False
1058
+
1059
+ def get_input_embeddings(self):
1060
+ return self.patch_embed
1061
+
1062
+ @check_model_inputs
1063
+ def forward(
1064
+ self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs]
1065
+ ) -> SamVisionEncoderOutput:
1066
+ if pixel_values is None:
1067
+ raise ValueError("You have to specify pixel_values")
1068
+
1069
+ hidden_states = self.patch_embed(pixel_values)
1070
+ if self.pos_embed is not None:
1071
+ hidden_states = hidden_states + self.pos_embed
1072
+ for layer_module in self.layers:
1073
+ hidden_states = layer_module(hidden_states)
1074
+ hidden_states = self.neck(hidden_states)
1075
+ return SamVisionEncoderOutput(
1076
+ last_hidden_state=hidden_states,
1077
+ )
1078
+
1079
+
1080
+ @auto_docstring(
1081
+ custom_intro="""
1082
+ The vision model from Sam without any head or projection on top.
1083
+ """
1084
+ )
1085
+ class SamVisionModel(SamPreTrainedModel):
1086
+ config: SamVisionConfig
1087
+ main_input_name = "pixel_values"
1088
+
1089
+ def __init__(self, config: SamVisionConfig):
1090
+ super().__init__(config)
1091
+ self.vision_encoder = SamVisionEncoder(config)
1092
+ self.post_init()
1093
+
1094
+ def get_input_embeddings(self) -> nn.Module:
1095
+ return self.vision_encoder.patch_embed
1096
+
1097
+ @auto_docstring
1098
+ def forward(
1099
+ self,
1100
+ pixel_values: Optional[torch.FloatTensor] = None,
1101
+ **kwargs: Unpack[TransformersKwargs],
1102
+ ) -> Union[tuple, SamVisionEncoderOutput]:
1103
+ return self.vision_encoder(pixel_values, **kwargs)
1104
+
1105
+
1106
+ @auto_docstring(
1107
+ custom_intro="""
1108
+ Segment Anything Model (SAM) for generating segmentation masks, given an input image and
1109
+ input points and labels, boxes, or masks.
1110
+ """
1111
+ )
1112
+ class SamModel(SamPreTrainedModel):
1113
+ _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
1114
+ # needs to be ignored, as it's a buffer that will not be correctly detected as a tied weight
1115
+ _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
1116
+ _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(SamTwoWayAttentionBlock, index=2)}
1117
+
1118
+ def __init__(self, config: SamConfig):
1119
+ super().__init__(config)
1120
+ self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
1121
+
1122
+ self.vision_encoder = SamVisionEncoder(config.vision_config)
1123
+ self.prompt_encoder = SamPromptEncoder(config)
1124
+ # The module using it is not a PreTrainedModel subclass so we need this
1125
+ config.mask_decoder_config._attn_implementation = config._attn_implementation
1126
+ self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)
1127
+
1128
+ self.post_init()
1129
+
1130
+ def _tie_weights(self):
1131
+ self.prompt_encoder.shared_embedding.positional_embedding.data = (
1132
+ self.shared_image_embedding.positional_embedding.data
1133
+ )
1134
+
1135
+ def get_input_embeddings(self):
1136
+ return self.vision_encoder.get_input_embeddings()
1137
+
1138
+ def get_image_wide_positional_embeddings(self):
1139
+ size = self.config.prompt_encoder_config.image_embedding_size
1140
+ target_device = self.shared_image_embedding.positional_embedding.device
1141
+ target_dtype = self.shared_image_embedding.positional_embedding.dtype
1142
+ grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
1143
+ y_embed = grid.cumsum(dim=0) - 0.5
1144
+ x_embed = grid.cumsum(dim=1) - 0.5
1145
+ y_embed = y_embed / size
1146
+ x_embed = x_embed / size
1147
+
1148
+ positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
1149
+ return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
1150
+
1151
+ @torch.no_grad()
1152
+ def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs]):
1153
+ r"""
1154
+ Returns the image embeddings by passing the pixel values through the vision encoder.
1155
+
1156
+ Args:
1157
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1158
+ Input pixel values
1159
+ """
1160
+ vision_output = self.vision_encoder(
1161
+ pixel_values,
1162
+ **kwargs,
1163
+ )
1164
+ image_embeddings = vision_output[0]
1165
+ return image_embeddings
1166
+
1167
+ @torch.no_grad()
1168
+ def get_prompt_embeddings(
1169
+ self,
1170
+ input_points: Optional[torch.FloatTensor] = None,
1171
+ input_labels: Optional[torch.LongTensor] = None,
1172
+ input_boxes: Optional[torch.FloatTensor] = None,
1173
+ input_masks: Optional[torch.LongTensor] = None,
1174
+ ):
1175
+ r"""
1176
+ Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
1177
+
1178
+ Args:
1179
+ input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
1180
+ Optional input points for the prompt encoder. The padding of the point is automatically done by the
1181
+ processor. `point_batch_size` refers to the number of masks that we want the model to predict per
1182
+ point. The model will output `point_batch_size` times 3 masks in total.
1183
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
1184
+ Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
1185
+ processor, or can be fed by the user.
1186
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
1187
+ Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
1188
+ processor. Users can also pass the input boxes manually.
1189
+ input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
1190
+ Optional input masks for the prompt encoder.
1191
+ """
1192
+ prompt_output = self.prompt_encoder(
1193
+ input_points=input_points,
1194
+ input_labels=input_labels,
1195
+ input_boxes=input_boxes,
1196
+ input_masks=input_masks,
1197
+ )
1198
+ return prompt_output
1199
+
1200
+ @check_model_inputs
1201
+ @auto_docstring
1202
+ def forward(
1203
+ self,
1204
+ pixel_values: Optional[torch.FloatTensor] = None,
1205
+ input_points: Optional[torch.FloatTensor] = None,
1206
+ input_labels: Optional[torch.LongTensor] = None,
1207
+ input_boxes: Optional[torch.FloatTensor] = None,
1208
+ input_masks: Optional[torch.LongTensor] = None,
1209
+ image_embeddings: Optional[torch.FloatTensor] = None,
1210
+ multimask_output: bool = True,
1211
+ attention_similarity: Optional[torch.FloatTensor] = None,
1212
+ target_embedding: Optional[torch.FloatTensor] = None,
1213
+ **kwargs: Unpack[TransformersKwargs],
1214
+ ) -> SamImageSegmentationOutput:
1215
+ r"""
1216
+ input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
1217
+ Input 2D spatial points; these are used by the prompt encoder to encode the prompt. Generally yields much
1218
+ better results. The points can be obtained by passing a list of list of list to the processor that will
1219
+ create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
1220
+ second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
1221
+ per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
1222
+ multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
1223
+ coordinates of the point. If a different number of points is passed either for each image, or for each
1224
+ mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
1225
+ computation of the embedding will be skipped for these points using the labels.
1226
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
1227
+ Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
1228
+ official implementation, there are 3 types of labels
1229
+
1230
+ - `1`: the point is a point that contains the object of interest
1231
+ - `0`: the point is a point that does not contain the object of interest
1232
+ - `-1`: the point corresponds to the background
1233
+
1234
+ We added the label:
1235
+
1236
+ - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
1237
+
1238
+ The padding labels should be automatically done by the processor.
1239
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
1240
+ Input boxes for the points; these are used by the prompt encoder to encode the prompt. Generally yields
1241
+ much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
1242
+ that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
1243
+ size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
1244
+ In the order (`x1`, `y1`, `x2`, `y2`):
1245
+
1246
+ - `x1`: the x coordinate of the top left point of the input box
1247
+ - `y1`: the y coordinate of the top left point of the input box
1248
+ - `x2`: the x coordinate of the bottom right point of the input box
1249
+ - `y2`: the y coordinate of the bottom right point of the input box
1250
+ input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
1251
+ SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
1252
+ generate a corresponding embedding that will be fed later on to the mask decoder. These masks need to be
1253
+ manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
1254
+ image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
1255
+ Image embeddings; these are used by the mask decoder to generate masks and IoU scores. For more memory
1256
+ efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
1257
+ method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
1258
+ multimask_output (`bool`, *optional*):
1259
+ In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
1260
+ bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
1261
+ "best" mask, by specifying `multimask_output=False`.
1262
+ attention_similarity (`torch.FloatTensor`, *optional*):
1263
+ Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
1264
+ model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
1265
+ target_embedding (`torch.FloatTensor`, *optional*):
1266
+ Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
1267
+ the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
1268
+
1269
+ Example:
1270
+
1271
+ ```python
1272
+ >>> from PIL import Image
1273
+ >>> import requests
1274
+ >>> from transformers import AutoModel, AutoProcessor
1275
+
1276
+ >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
1277
+ >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
1278
+
1279
+ >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
1280
+ >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
1281
+ >>> input_points = [[[400, 650]]] # 2D location of a window on the car
1282
+ >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
1283
+
1284
+ >>> # Get segmentation mask
1285
+ >>> outputs = model(**inputs)
1286
+
1287
+ >>> # Postprocess masks
1288
+ >>> masks = processor.post_process_masks(
1289
+ ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
1290
+ ... )
1291
+ ```
1292
+ """
1293
+ if pixel_values is None and image_embeddings is None:
1294
+ raise ValueError("Either pixel_values or image_embeddings must be provided.")
1295
+
1296
+ if pixel_values is not None and image_embeddings is not None:
1297
+ raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
1298
+
1299
+ if input_points is not None and len(input_points.shape) != 4:
1300
+ raise ValueError(
1301
+ "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
1302
+ f" got {input_points.shape}.",
1303
+ )
1304
+ if input_boxes is not None and len(input_boxes.shape) != 3:
1305
+ raise ValueError(
1306
+ "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
1307
+ f" got {input_boxes.shape}.",
1308
+ )
1309
+ if input_points is not None and input_boxes is not None:
1310
+ point_batch_size = input_points.shape[1]
1311
+ box_batch_size = input_boxes.shape[1]
1312
+ if point_batch_size != box_batch_size:
1313
+ raise ValueError(
1314
+ f"You should provide as many bounding boxes as input points per box. Got {point_batch_size} and {box_batch_size}."
1315
+ )
1316
+
1317
+ image_positional_embeddings = self.get_image_wide_positional_embeddings()
1318
+ # repeat with batch size
1319
+ batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
1320
+ image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
1321
+
1322
+ vision_attentions = None
1323
+ vision_hidden_states = None
1324
+
1325
+ if pixel_values is not None:
1326
+ vision_outputs: SamVisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs)
1327
+ image_embeddings = vision_outputs.last_hidden_state
1328
+ vision_hidden_states = vision_outputs.hidden_states
1329
+ vision_attentions = vision_outputs.attentions
1330
+
1331
+ if input_points is not None and input_labels is None:
1332
+ input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
1333
+
1334
+ if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
1335
+ raise ValueError(
1336
+ "The batch size of the image embeddings and the input points must be the same. ",
1337
+ f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively.",
1338
+ " if you want to pass multiple points for the same image, make sure that you passed ",
1339
+ " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
1340
+ " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
1341
+ )
1342
+
1343
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
1344
+ input_points=input_points,
1345
+ input_labels=input_labels,
1346
+ input_boxes=input_boxes,
1347
+ input_masks=input_masks,
1348
+ )
1349
+
1350
+ low_res_masks, iou_predictions = self.mask_decoder(
1351
+ image_embeddings=image_embeddings,
1352
+ image_positional_embeddings=image_positional_embeddings,
1353
+ sparse_prompt_embeddings=sparse_embeddings,
1354
+ dense_prompt_embeddings=dense_embeddings,
1355
+ multimask_output=multimask_output,
1356
+ attention_similarity=attention_similarity,
1357
+ target_embedding=target_embedding,
1358
+ )
1359
+
1360
+ return SamImageSegmentationOutput(
1361
+ iou_scores=iou_predictions,
1362
+ pred_masks=low_res_masks,
1363
+ vision_hidden_states=vision_hidden_states,
1364
+ vision_attentions=vision_attentions,
1365
+ )
1366
+
1367
+
1368
+ __all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]