# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.optim.lr_scheduler import LambdaLR


def warmup_then_decay(optimizer, total_steps, warmup_steps, max_lr=1e-3, min_lr=1e-5, base_lr=1e-5):
    """
    Create a learning rate scheduler that linearly warms up from ``min_lr`` to
    ``max_lr`` over ``warmup_steps`` steps, then linearly decays back to
    ``min_lr`` by ``total_steps``.

    Note: the optimizer's learning rate should be set to ``base_lr``. LambdaLR
    multiplies the optimizer's lr by the returned factor, so the division by
    ``base_lr`` below recovers the intended absolute learning rate.
    """
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # Warmup phase: interpolate min_lr -> max_lr.
            progress = float(current_step) / float(max(1, warmup_steps))
            # LR(t) = min_lr + (max_lr - min_lr) * progress
            return (min_lr + (max_lr - min_lr) * progress) / base_lr
        else:
            # Decay phase: interpolate max_lr -> min_lr from warmup_steps to total_steps.
            progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            # LR(t) = max_lr + (min_lr - max_lr) * progress
            return (max_lr + (min_lr - max_lr) * progress) / base_lr

    scheduler = LambdaLR(optimizer, lr_lambda)
    return scheduler
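

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module): the model,
    # step counts, and hyperparameter values below are assumptions chosen only
    # to demonstrate how the scheduler is driven.
    import torch

    model = torch.nn.Linear(10, 1)
    base_lr = 1e-5
    # The optimizer's lr must equal base_lr so that lr_lambda's division by
    # base_lr yields the intended absolute learning rates.
    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)
    scheduler = warmup_then_decay(
        optimizer, total_steps=1000, warmup_steps=100,
        max_lr=1e-3, min_lr=1e-5, base_lr=base_lr,
    )

    for step in range(1000):
        optimizer.step()
        scheduler.step()
        if step % 100 == 0:
            # Learning rate rises toward max_lr during the first 100 steps,
            # then decays back toward min_lr by step 1000.
            print(f"step {step}: lr = {scheduler.get_last_lr()[0]:.2e}")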