# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Question answering task.""" | |
| import logging | |
| import dataclasses | |
| import tensorflow as tf | |
| import tensorflow_hub as hub | |
| from official.core import base_task | |
| from official.modeling.hyperparams import config_definitions as cfg | |
| from official.nlp.bert import input_pipeline | |
| from official.nlp.configs import encoders | |
| from official.nlp.modeling import models | |
| from official.nlp.tasks import utils | |


@dataclasses.dataclass
class QuestionAnsweringConfig(cfg.TaskConfig):
  """The model config."""
  # At most one of `init_checkpoint` and `hub_module_url` can be specified.
  init_checkpoint: str = ''
  hub_module_url: str = ''
  network: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
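
# A hedged example of filling in the config (field values and paths are
# illustrative assumptions, not part of the original file):
#
#   config = QuestionAnsweringConfig(
#       init_checkpoint='/path/to/pretrained/bert',
#       train_data=cfg.DataConfig(
#           input_path='/path/to/train.tf_record',
#           global_batch_size=32,
#           is_training=True))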


class QuestionAnsweringTask(base_task.Task):
  """Task object for question answering.

  TODO(lehou): Add post-processing.
  """

  def __init__(self, params: cfg.TaskConfig):
    super().__init__(params)
    if params.hub_module_url and params.init_checkpoint:
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')
    if params.hub_module_url:
      self._hub_module = hub.load(params.hub_module_url)
    else:
      self._hub_module = None

  def build_model(self):
    if self._hub_module:
      encoder_network = utils.get_encoder_from_hub(self._hub_module)
    else:
      encoder_network = encoders.instantiate_encoder_from_cfg(
          self.task_config.network)
    return models.BertSpanLabeler(
        network=encoder_network,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=self.task_config.network.initializer_range))
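
  # The encoder comes either from TF-Hub or from the `network` config. A
  # hedged sketch of the hub path (the URL is a placeholder, not from the
  # original file):
  #
  #   config = QuestionAnsweringConfig(
  #       hub_module_url='https://tfhub.dev/<publisher>/<model>/<version>')
  #   model = QuestionAnsweringTask(config).build_model()
  #   # `model` is a models.BertSpanLabeler wrapping the hub encoder.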

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    start_positions = labels['start_positions']
    end_positions = labels['end_positions']
    start_logits, end_logits = model_outputs

    start_loss = tf.keras.losses.sparse_categorical_crossentropy(
        start_positions,
        tf.cast(start_logits, dtype=tf.float32),
        from_logits=True)
    end_loss = tf.keras.losses.sparse_categorical_crossentropy(
        end_positions,
        tf.cast(end_logits, dtype=tf.float32),
        from_logits=True)

    loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
    return loss
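
  # The loss averages two sparse cross-entropies: one over the predicted
  # start index, one over the predicted end index. A standalone sketch of
  # one term (shapes are illustrative assumptions):
  #
  #   logits = tf.zeros((8, 384))          # [batch_size, seq_length]
  #   labels = tf.zeros((8,), tf.int32)    # gold start (or end) indices
  #   per_example = tf.keras.losses.sparse_categorical_crossentropy(
  #       labels, logits, from_logits=True)
  #   term = tf.reduce_mean(per_example)   # scalar, averaged over the batch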

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for the question answering task."""
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        x = dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids)
        y = dict(
            start_positions=tf.constant(0, dtype=tf.int32),
            end_positions=tf.constant(1, dtype=tf.int32))
        return (x, y)

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    batch_size = input_context.get_per_replica_batch_size(
        params.global_batch_size) if input_context else params.global_batch_size
    # TODO(chendouble): add and use nlp.data.question_answering_dataloader.
    dataset = input_pipeline.create_squad_dataset(
        params.input_path,
        params.seq_length,
        batch_size,
        is_training=params.is_training,
        input_pipeline_context=input_context)
    return dataset
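
  # A hedged sketch of the dummy-data path; `seq_length` on the data config
  # is an assumption for illustration, not a documented field:
  #
  #   data_params = cfg.DataConfig(input_path='dummy', global_batch_size=2)
  #   data_params.seq_length = 384  # hypothetical field, set for the sketch
  #   dataset = task.build_inputs(data_params)
  #   features, labels = next(iter(dataset))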

  def build_metrics(self, training=None):
    del training
    # TODO(lehou): a list of metrics doesn't work the same as in compile/fit.
    metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='start_position_accuracy'),
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='end_position_accuracy'),
    ]
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    metrics = {metric.name: metric for metric in metrics}
    start_logits, end_logits = model_outputs
    metrics['start_position_accuracy'].update_state(
        labels['start_positions'], start_logits)
    metrics['end_position_accuracy'].update_state(
        labels['end_positions'], end_logits)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    start_logits, end_logits = model_outputs
    compiled_metrics.update_state(
        # `labels` has keys 'start_positions' and 'end_positions'.
        y_true=labels,
        y_pred={'start_positions': start_logits, 'end_positions': end_logits})

  def initialize(self, model):
    """Loads a pretrained checkpoint (if one exists), then trains from iter 0."""
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      return

    ckpt = tf.train.Checkpoint(**model.checkpoint_items)
    status = ckpt.restore(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
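

# A hedged end-to-end sketch of the task lifecycle (paths and sizes are
# illustrative assumptions, not part of the original file):
#
#   config = QuestionAnsweringConfig(
#       init_checkpoint='/path/to/pretrained/bert',
#       train_data=cfg.DataConfig(input_path='dummy', global_batch_size=2))
#   task = QuestionAnsweringTask(config)
#   model = task.build_model()
#   task.initialize(model)          # restores encoder weights, if any
#   metrics = task.build_metrics()
#   features, labels = next(iter(task.build_inputs(config.train_data)))
#   outputs = model(features)
#   loss = task.build_losses(labels, outputs)
#   task.process_metrics(metrics, labels, outputs)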