Spaces:
Sleeping
Sleeping
| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """A function to build localization and classification losses from config.""" | |
| import functools | |
| from object_detection.core import balanced_positive_negative_sampler as sampler | |
| from object_detection.core import losses | |
| from object_detection.protos import losses_pb2 | |
| from object_detection.utils import ops | |
def build(loss_config):
  """Build losses based on the config.

  Builds classification and localization losses, and optionally a hard example
  miner, a random example sampler and an expected-loss-weights function, based
  on the config.

  Args:
    loss_config: A losses_pb2.Loss object.

  Returns:
    A 7-tuple of:
    classification_loss: Classification loss object.
    localization_loss: Localization loss object.
    classification_weight: Classification loss weight.
    localization_weight: Localization loss weight.
    hard_example_miner: Hard example miner object, or None when the config
      does not set `hard_example_miner`.
    random_example_sampler: BalancedPositiveNegativeSampler object, or None
      when the config does not set `random_example_sampler`.
    expected_loss_weights_fn: A functools.partial over one of
      `ops.expected_classification_loss_by_expected_sampling` or
      `ops.expected_classification_loss_by_reweighting_unmatched_anchors`
      (with `min_num_negative_samples` and `desired_negative_sampling_ratio`
      bound from the config), or None when `expected_loss_weights` is NONE.

  Raises:
    ValueError: If hard_example_miner is used with sigmoid_focal_loss.
    ValueError: If random_example_sampler is getting non-positive value as
      desired positive example fraction.
    ValueError: If `expected_loss_weights` is not one of NONE,
      EXPECTED_SAMPLING or REWEIGHTING_UNMATCHED_ANCHORS.
  """
  classification_loss = _build_classification_loss(
      loss_config.classification_loss)
  localization_loss = _build_localization_loss(
      loss_config.localization_loss)
  classification_weight = loss_config.classification_weight
  localization_weight = loss_config.localization_weight

  hard_example_miner = None
  if loss_config.HasField('hard_example_miner'):
    # Hard example mining and sigmoid focal loss are rejected as a combination.
    if (loss_config.classification_loss.WhichOneof('classification_loss') ==
        'weighted_sigmoid_focal'):
      raise ValueError('HardExampleMiner should not be used with sigmoid focal '
                       'loss')
    hard_example_miner = build_hard_example_miner(
        loss_config.hard_example_miner,
        classification_weight,
        localization_weight)

  random_example_sampler = None
  if loss_config.HasField('random_example_sampler'):
    positive_fraction = (
        loss_config.random_example_sampler.positive_sample_fraction)
    if positive_fraction <= 0:
      # Adjacent string literals are concatenated with no separator, so the
      # first piece must end with a space.
      raise ValueError('RandomExampleSampler should not use non-positive '
                       'value as positive sample fraction.')
    random_example_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=positive_fraction)

  if loss_config.expected_loss_weights == loss_config.NONE:
    expected_loss_weights_fn = None
  elif loss_config.expected_loss_weights == loss_config.EXPECTED_SAMPLING:
    expected_loss_weights_fn = functools.partial(
        ops.expected_classification_loss_by_expected_sampling,
        min_num_negative_samples=loss_config.min_num_negative_samples,
        desired_negative_sampling_ratio=loss_config
        .desired_negative_sampling_ratio)
  elif (loss_config.expected_loss_weights == loss_config
        .REWEIGHTING_UNMATCHED_ANCHORS):
    expected_loss_weights_fn = functools.partial(
        ops.expected_classification_loss_by_reweighting_unmatched_anchors,
        min_num_negative_samples=loss_config.min_num_negative_samples,
        desired_negative_sampling_ratio=loss_config
        .desired_negative_sampling_ratio)
  else:
    raise ValueError('Not a valid value for expected_classification_loss.')

  return (classification_loss, localization_loss, classification_weight,
          localization_weight, hard_example_miner, random_example_sampler,
          expected_loss_weights_fn)
def build_hard_example_miner(config,
                             classification_weight,
                             localization_weight):
  """Constructs a losses.HardExampleMiner from its proto configuration.

  Args:
    config: A losses_pb2.HardExampleMiner object.
    classification_weight: Weight passed to the miner as `cls_loss_weight`.
    localization_weight: Weight passed to the miner as `loc_loss_weight`.

  Returns:
    A losses.HardExampleMiner instance.
  """
  miner_enum = losses_pb2.HardExampleMiner
  loss_type_names = {
      miner_enum.BOTH: 'both',
      miner_enum.CLASSIFICATION: 'cls',
      miner_enum.LOCALIZATION: 'loc',
  }
  # Enum values outside the three above are passed on as None.
  mining_loss_type = loss_type_names.get(config.loss_type)

  # Non-positive proto values are passed on as None rather than as numbers.
  negatives_cap = (config.max_negatives_per_positive
                   if config.max_negatives_per_positive > 0 else None)
  hard_example_count = (config.num_hard_examples
                        if config.num_hard_examples > 0 else None)

  return losses.HardExampleMiner(
      num_hard_examples=hard_example_count,
      iou_threshold=config.iou_threshold,
      loss_type=mining_loss_type,
      cls_loss_weight=classification_weight,
      loc_loss_weight=localization_weight,
      max_negatives_per_positive=negatives_cap,
      min_negatives_per_image=config.min_negatives_per_image)
def build_faster_rcnn_classification_loss(loss_config):
  """Creates the Faster RCNN classification loss selected by `loss_config`.

  Args:
    loss_config: A losses_pb2.ClassificationLoss object.

  Returns:
    A classification loss object. Any oneof choice other than
    weighted_sigmoid, weighted_logits_softmax or weighted_sigmoid_focal
    (including no choice at all) yields a WeightedSoftmaxClassificationLoss
    built from the `weighted_softmax` sub-message.

  Raises:
    ValueError: If loss_config is not a losses_pb2.ClassificationLoss.
  """
  if not isinstance(loss_config, losses_pb2.ClassificationLoss):
    raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')

  selected = loss_config.WhichOneof('classification_loss')

  if selected == 'weighted_sigmoid':
    return losses.WeightedSigmoidClassificationLoss()

  if selected == 'weighted_logits_softmax':
    logits_cfg = loss_config.weighted_logits_softmax
    return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
        logit_scale=logits_cfg.logit_scale)

  if selected == 'weighted_sigmoid_focal':
    focal_cfg = loss_config.weighted_sigmoid_focal
    focal_alpha = focal_cfg.alpha if focal_cfg.HasField('alpha') else None
    return losses.SigmoidFocalClassificationLoss(
        gamma=focal_cfg.gamma,
        alpha=focal_alpha)

  # 'weighted_softmax' and every remaining case end up here: by default the
  # Faster RCNN second stage classifier uses Softmax loss with anchor-wise
  # outputs.
  softmax_cfg = loss_config.weighted_softmax
  return losses.WeightedSoftmaxClassificationLoss(
      logit_scale=softmax_cfg.logit_scale)
def _build_localization_loss(loss_config):
  """Maps a LocalizationLoss proto onto the matching loss object.

  Args:
    loss_config: A losses_pb2.LocalizationLoss object.

  Returns:
    The localization loss object for the oneof field that is set.

  Raises:
    ValueError: If loss_config is not a losses_pb2.LocalizationLoss, or if no
      recognized `localization_loss` oneof field is set.
  """
  if not isinstance(loss_config, losses_pb2.LocalizationLoss):
    raise ValueError('loss_config not of type losses_pb2.LocalizationLoss.')

  # Constructors are wrapped in lambdas so that only the selected loss class
  # is looked up and instantiated.
  constructors = {
      'weighted_l2': lambda: losses.WeightedL2LocalizationLoss(),
      'weighted_smooth_l1': lambda: losses.WeightedSmoothL1LocalizationLoss(
          loss_config.weighted_smooth_l1.delta),
      'weighted_iou': lambda: losses.WeightedIOULocalizationLoss(),
      'l1_localization_loss': lambda: losses.L1LocalizationLoss(),
  }

  make_loss = constructors.get(loss_config.WhichOneof('localization_loss'))
  if make_loss is None:
    raise ValueError('Empty loss config.')
  return make_loss()
def _build_classification_loss(loss_config):
  """Maps a ClassificationLoss proto onto the matching loss object.

  Args:
    loss_config: A losses_pb2.ClassificationLoss object.

  Returns:
    The classification loss object for the oneof field that is set.

  Raises:
    ValueError: If loss_config is not a losses_pb2.ClassificationLoss, or if
      no recognized `classification_loss` oneof field is set.
  """
  if not isinstance(loss_config, losses_pb2.ClassificationLoss):
    raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')

  kind = loss_config.WhichOneof('classification_loss')

  if kind == 'weighted_sigmoid':
    built_loss = losses.WeightedSigmoidClassificationLoss()
  elif kind == 'weighted_sigmoid_focal':
    focal_cfg = loss_config.weighted_sigmoid_focal
    # alpha is optional in the proto; pass None when it is not set.
    focal_alpha = focal_cfg.alpha if focal_cfg.HasField('alpha') else None
    built_loss = losses.SigmoidFocalClassificationLoss(
        gamma=focal_cfg.gamma, alpha=focal_alpha)
  elif kind == 'weighted_softmax':
    built_loss = losses.WeightedSoftmaxClassificationLoss(
        logit_scale=loss_config.weighted_softmax.logit_scale)
  elif kind == 'weighted_logits_softmax':
    built_loss = losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
        logit_scale=loss_config.weighted_logits_softmax.logit_scale)
  elif kind == 'bootstrapped_sigmoid':
    bootstrap_cfg = loss_config.bootstrapped_sigmoid
    if bootstrap_cfg.hard_bootstrap:
      bootstrap_mode = 'hard'
    else:
      bootstrap_mode = 'soft'
    built_loss = losses.BootstrappedSigmoidClassificationLoss(
        alpha=bootstrap_cfg.alpha,
        bootstrap_type=bootstrap_mode)
  elif kind == 'penalty_reduced_logistic_focal_loss':
    penalty_cfg = loss_config.penalty_reduced_logistic_focal_loss
    built_loss = losses.PenaltyReducedLogisticFocalLoss(
        alpha=penalty_cfg.alpha, beta=penalty_cfg.beta)
  else:
    raise ValueError('Empty loss config.')

  return built_loss