| """ | |
| This file defines a mixin class for sparse transformers that enables elastic memory management. | |
| It provides functionality to dynamically adjust memory usage by controlling gradient checkpointing | |
| across transformer blocks, allowing for trading computation for memory efficiency. | |
| """ | |
from contextlib import contextmanager
from typing import *
import math

from ..modules import sparse as sp
from ..utils.elastic_utils import ElasticModuleMixin


class SparseTransformerElasticMixin(ElasticModuleMixin):
    """
    A mixin class for sparse transformers that provides elastic memory management capabilities.
    Extends the base ElasticModuleMixin with sparse tensor-specific functionality.
    """
    def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
        """
        Determines the input size from a sparse tensor.

        Args:
            x: A SparseTensor input.
            *args, **kwargs: Additional arguments (unused).

        Returns:
            The number of entries in the sparse tensor, i.e. the first dimension
            of its feature tensor (not the channel dimension).
        """
        return x.feats.shape[0]
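
    # Illustrative sketch (hypothetical values; assumes SparseTensor accepts
    # `feats` and `coords` keyword arguments, which is not confirmed here):
    #   coords = torch.zeros(4096, 4, dtype=torch.int32)
    #   feats = torch.randn(4096, 768)
    #   x = sp.SparseTensor(feats=feats, coords=coords)
    #   model._get_input_size(x)  # -> 4096 (active entries), not 768 (channels)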

    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        """
        Context manager that temporarily adjusts memory usage by enabling gradient
        checkpointing for a portion of the transformer blocks, based on the
        requested memory ratio.

        Args:
            mem_ratio: A value between 0 and 1 indicating the desired memory ratio.
                1.0 means use all available memory (no checkpointing);
                lower values enable checkpointing on more blocks to reduce memory usage.

        Yields:
            The exact memory ratio achievable at block granularity, which may
            differ slightly from the requested ratio.
        """
        if mem_ratio == 1.0:
            # No memory optimization needed if the ratio is 1.0
            yield 1.0
            return
        # Calculate how many blocks should use checkpointing
        num_blocks = len(self.blocks)
        num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
        # Calculate the actual memory ratio based on the number of checkpointed blocks
        exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
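        # Worked example (illustrative numbers, not from the source): with
        # num_blocks = 12 and mem_ratio = 0.5,
        #   num_checkpoint_blocks = min(ceil((1 - 0.5) * 12) + 1, 12) = 7
        #   exact_mem_ratio       = 1 - (7 - 1) / 12 = 0.5
        # The "+1"/"-1" pair biases toward more checkpointing, so (unless
        # clamped at num_blocks) the achieved ratio never exceeds the request.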
        # Enable checkpointing for the leading num_checkpoint_blocks blocks
        for i in range(num_blocks):
            self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
        try:
            yield exact_mem_ratio
        finally:
            # Restore all blocks to not use checkpointing, even if the body raised
            for i in range(num_blocks):
                self.blocks[i].use_checkpoint = False
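

# Minimal usage sketch (hypothetical: `MySparseTransformer` and `TransformerBlock`
# are illustrative names, not part of this file; it assumes each block exposes a
# `use_checkpoint` flag that its forward pass reads):
#
#   class MySparseTransformer(SparseTransformerElasticMixin, nn.Module):
#       def __init__(self, num_blocks, channels):
#           super().__init__()
#           self.blocks = nn.ModuleList(
#               TransformerBlock(channels) for _ in range(num_blocks)
#           )
#       ...
#
#   model = MySparseTransformer(num_blocks=12, channels=768)
#   with model.with_mem_ratio(0.5) as exact:
#       out = model(x)  # roughly half of peak activation memory; `exact` is
#                       # the ratio actually achieved at block granularity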