Spaces:
Runtime error
Runtime error
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    """Stack of kernel-size-1 ``nn.Conv1d`` layers with one input skip connection.

    Layer ``l`` maps ``dims[l]`` channels to ``dims[l + 1]`` channels. The
    layer whose index is ``len(dims) // 2`` additionally receives the original
    network input, concatenated along dim 1. That layer therefore takes
    ``dims[l] + dims[0]`` input channels. Every layer except the last is
    followed by ``F.leaky_relu``. ``last_op`` (if given) is applied once, to
    the output of the final layer.

    Args:
        dims: Channel sizes ``[in, hidden..., out]``; ``len(dims) - 1`` conv
            layers are built.
        last_op: Optional callable applied to the final layer's output.
    """

    def __init__(self, dims, last_op=None):
        super().__init__()
        self.dims = dims
        # Index (kept as a list for compatibility) of the layer that also
        # receives the raw network input.
        self.skip_layer = [len(dims) // 2]
        self.last_op = last_op
        # A plain list is kept for backward compatibility. Each conv is also
        # registered as ``conv<l>``, so its parameters are tracked and
        # state_dict keys stay ``conv0.weight``, ``conv1.weight``, ...
        self.layers = []
        for l in range(len(dims) - 1):
            in_channels = dims[l] + dims[0] if l in self.skip_layer else dims[l]
            conv = nn.Conv1d(in_channels, dims[l + 1], 1)
            self.layers.append(conv)
            self.add_module("conv%d" % l, conv)

    def forward(self, latet_code, return_all=False):
        """Run the network on ``latet_code``.

        Args:
            latet_code: Input tensor. It is concatenated with an intermediate
                activation along dim 1, so it is presumably batched as
                ``(batch, dims[0], length)``, as ``nn.Conv1d`` expects.
                TODO: confirm against callers.
            return_all: If True, return a list holding every layer's output
                (after its leaky ReLU, or after ``last_op`` for the final
                layer) instead of only the final output.

        Returns:
            The final output tensor, or a list with one tensor per layer when
            ``return_all`` is True. With no layers (``len(dims) <= 1``), the
            input is returned unchanged, or an empty list when ``return_all``
            is True.
        """
        skip_input = latet_code
        y = latet_code
        y_list = []
        last = len(self.layers) - 1
        for l, layer in enumerate(self.layers):
            if l in self.skip_layer:
                # Skip connection: re-inject the raw input alongside features.
                y = layer(torch.cat([y, skip_input], 1))
            else:
                y = layer(y)
            if l != last:
                y = F.leaky_relu(y)
            elif self.last_op is not None:
                # Output head only; hidden layers never see last_op.
                y = self.last_op(y)
            y_list.append(y)
        return y_list if return_all else y