File size: 3,949 Bytes
ec9a6bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torch.nn as nn
import numpy as np
import pickle
from scipy.io import loadmat
from pytorch3d.transforms import so3_exponential_map


class FVMModule(nn.Module):
    """Fittable linear 3D face model loaded from a FaceVerse ``.npy`` asset.

    The model data (mean shape, identity basis, expression basis, triangles,
    keypoint indices, skin mask) are stored as buffers. The fitting
    parameters are:

    * ``id_coeff``  -- (1, id_dims) identity coefficients shared by all frames,
    * ``exp_coeff`` -- (batch_size, exp_dims) per-frame expression coefficients,
    * ``scale``     -- (1,) global scale,
    * ``pose``      -- (batch_size, 6) per-frame pose; columns 0-2 are passed to
      ``so3_exponential_map`` and columns 3-5 are added as a translation.
    """

    def __init__(self, batch_size, model_path='assets/FVM/faceverse_simple_v2.npy'):
        """Load the model file and create the fitting parameters.

        Args:
            batch_size: number of frames fitted jointly; leading dimension of
                ``exp_coeff`` and ``pose``.
            model_path: ``.npy`` file containing a pickled dict with the keys
                ``skinmask``, ``keypoints``, ``meanshape``, ``idBase``,
                ``exBase`` and ``tri``. Defaults to the previously hard-coded
                asset path.
        """
        super().__init__()

        self.id_dims = 150
        self.exp_dims = 52

        # allow_pickle=True unpickles the file, which can execute arbitrary
        # code -- only point model_path at trusted assets.
        model_dict = np.load(model_path, allow_pickle=True).item()
        self.register_buffer('skinmask', torch.tensor(model_dict['skinmask']))
        kp_inds = torch.tensor(model_dict['keypoints']).squeeze().long()
        self.register_buffer('kp_inds', kp_inds)

        # Coordinate components 1 and 2 (presumably y and z) are negated in the
        # mean shape and in both bases alike, so all three share one frame.
        meanshape = torch.tensor(model_dict['meanshape'])
        meanshape[:, 1:] = -meanshape[:, 1:]
        self.register_buffer('meanshape', meanshape.view(1, -1).float())

        # Bases are reshaped to (num_vertices, 3, dims) only to flip the same
        # two components, then flattened back to (3 * num_vertices, dims).
        idBase = torch.tensor(model_dict['idBase']).view(-1, 3, self.id_dims).float()
        idBase[:, 1:, :] = -idBase[:, 1:, :]
        self.register_buffer('idBase', idBase.view(-1, self.id_dims))

        exBase = torch.tensor(model_dict['exBase']).view(-1, 3, self.exp_dims).float()
        exBase[:, 1:, :] = -exBase[:, 1:, :]
        self.register_buffer('exBase', exBase.view(-1, self.exp_dims))

        self.register_buffer('faces', torch.tensor(model_dict['tri']).long())

        self.batch_size = batch_size
        self.id_coeff = nn.Parameter(torch.zeros(1, self.id_dims).float())
        self.exp_coeff = nn.Parameter(torch.zeros(self.batch_size, self.exp_dims).float())
        self.scale = nn.Parameter(torch.ones(1).float() * 0.3)
        self.pose = nn.Parameter(torch.zeros(self.batch_size, 6).float())

    def set_id_param(self, id_coeff, scale):
        """Overwrite identity coefficients and scale, then freeze both.

        Args:
            id_coeff: tensor or array-like holding id_dims values; stored with
                the parameter's dtype/device and shape (1, id_dims).
            scale: scalar, tensor or array-like; stored with the parameter's
                dtype/device and shape (1,).
        """
        # Coerce to the parameter's dtype/device: a float64 input (e.g. from
        # NumPy) would otherwise turn the parameter float64 and break the
        # float32 einsum in get_vs.
        self.id_coeff.data = torch.as_tensor(
            id_coeff, dtype=self.id_coeff.dtype, device=self.id_coeff.device
        ).detach().reshape(1, -1)
        self.scale.data = torch.as_tensor(
            scale, dtype=self.scale.dtype, device=self.scale.device
        ).detach().reshape(-1)
        self.id_coeff.requires_grad = False
        self.scale.requires_grad = False

    def get_lms(self, vs):
        """Gather the keypoint vertices ``vs[:, kp_inds, :]``.

        Assumes ``kp_inds`` is 1-D, so (B, N, 3) vertices give (B, K, 3)
        landmarks.
        """
        return vs[:, self.kp_inds, :]

    def get_vs(self, id_coeff, exp_coeff):
        """Evaluate the linear shape model for a batch of coefficients.

        Computes ``meanshape + idBase @ id + exBase @ exp`` per row, reshapes it
        to (n_b, num_vertices, 3) and subtracts the centroid of the mean shape
        (not of each generated face).

        Args:
            id_coeff: (n_b, id_dims) identity coefficients.
            exp_coeff: (n_b, exp_dims) expression coefficients.

        Returns:
            (n_b, num_vertices, 3) vertex tensor.
        """
        n_b = id_coeff.size(0)

        face_shape = torch.einsum('ij,aj->ai', self.idBase, id_coeff) + \
            torch.einsum('ij,aj->ai', self.exBase, exp_coeff) + self.meanshape

        face_shape = face_shape.view(n_b, -1, 3)
        face_shape = face_shape - \
            self.meanshape.view(1, -1, 3).mean(dim=1, keepdim=True)

        return face_shape

    def forward(self):
        """Build posed vertices and landmarks for every frame in the batch.

        The shared identity is repeated per frame and combined with each
        frame's expression. The shape is then scaled, rotated by
        ``R = so3_exponential_map(pose[:, :3])`` (applied to row vectors as
        ``v @ R^T``) and translated by ``pose[:, 3:]``.

        Returns:
            (vertices, landmarks): (batch_size, N, 3) vertices and the
            keypoint subset selected by ``get_lms``.
        """
        id_coeff = self.id_coeff.repeat(self.batch_size, 1)
        vertices = self.get_vs(id_coeff, self.exp_coeff)
        R = so3_exponential_map(self.pose[:, :3])
        T = self.pose[:, 3:]
        vertices = torch.bmm(vertices * self.scale, R.permute(0, 2, 1)) + T[:, None, :]
        landmarks = self.get_lms(vertices)
        return vertices, landmarks

    def reg_loss(self, id_weight, exp_weight):
        """Weighted L2 regulariser on the coefficients.

        The identity term is the squared norm of the single shared vector; the
        expression term is the per-frame squared norm averaged over frames.
        """
        id_reg_loss = (self.id_coeff ** 2).sum()
        exp_reg_loss = (self.exp_coeff ** 2).sum(-1).mean()
        return id_reg_loss * id_weight + exp_reg_loss * exp_weight

    def temporal_smooth_loss(self, smo_weight):
        """Mean squared difference between consecutive frames' expressions.

        With fewer than two frames there are no consecutive pairs; a zero
        tensor is returned instead of the NaN produced by averaging an empty
        tensor, which would poison any total loss it is added to.
        """
        if self.exp_coeff.shape[0] < 2:
            return self.exp_coeff.new_zeros(())
        diffs = self.exp_coeff[1:] - self.exp_coeff[:-1]
        return (diffs ** 2).sum(-1).mean() * smo_weight

    def save(self, path, batch_id=-1):
        """Write the fitted parameters to an ``.npz`` archive via ``np.savez``.

        Args:
            path: destination passed to ``np.savez``.
            batch_id: when negative, every frame's ``exp_coeff``/``pose`` is
                saved; otherwise only that frame's (kept 2-D, one row).
                ``id_coeff`` and ``scale`` are always saved whole.

        Raises:
            IndexError: if ``batch_id`` is not a valid frame index (previously
                this silently wrote empty arrays).
        """
        if batch_id < 0:
            frames = slice(None)
        else:
            if batch_id >= self.exp_coeff.shape[0]:
                raise IndexError(
                    f'batch_id {batch_id} out of range for {self.exp_coeff.shape[0]} frames')
            frames = slice(batch_id, batch_id + 1)
        np.savez(
            path,
            id_coeff=self.id_coeff.detach().cpu().numpy(),
            exp_coeff=self.exp_coeff[frames].detach().cpu().numpy(),
            scale=self.scale.detach().cpu().numpy(),
            pose=self.pose[frames].detach().cpu().numpy(),
        )