| from torch import nn | |
| from transformers.modeling_utils import PreTrainedModel | |
| from .configuration_my_model import MyModelConfig | |
class MyModelPretrainedModel(PreTrainedModel):
    """Shared base class for the ``MyModel`` family.

    Binds the family to its configuration type via ``config_class`` so the
    ``PreTrainedModel`` machinery knows which config class belongs to these
    models.
    """

    config_class = MyModelConfig
class MyModel(MyModelPretrainedModel):
    """Model constructed from a ``MyModelConfig``.

    Records ``n_layers`` and ``hidden_dim`` from the config and owns one
    square linear projection of width ``config.hidden_dim``.
    """

    def __init__(self, config: MyModelConfig):
        super().__init__(config)
        # NOTE(review): PreTrainedModel.__init__ presumably already stores
        # ``config`` on the instance; this explicit assignment is kept so
        # behavior is unchanged.
        self.config = config

        width = config.hidden_dim
        # n_layers is only recorded here; the visible code builds no layers
        # from it — confirm whether later code (e.g. forward) relies on it.
        self.n_layers = config.n_layers
        self.hidden_dim = width
        # Square projection mapping hidden_dim -> hidden_dim.
        self.linear = nn.Linear(width, width)