File size: 3,907 Bytes
3f7d1b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class
from jax_ik.objectives import ObjectiveFunction

@register_pytree_node_class
class FastSphereCollisionPenaltyObjTraj(ObjectiveFunction):
    """Vectorized sphere-collision penalty over bone segments.

    Each bone is treated as a segment from its parent's head to its own head,
    thickened by ``segment_radius``. A segment is penalised by the squared
    amount it penetrates a sphere inflated by ``min_clearance`` and
    ``segment_radius``. For a batch/trajectory input the per-configuration
    penalties are averaged.

    The weight is stored as a plain Python float in the pytree's static aux
    data rather than as a leaf. It therefore stays concrete, which avoids
    tracer-to-Python concretization errors when the objective is always
    present but its weight varies. The trade-off is that changing the weight
    changes the treedef and may trigger a retrace under ``jax.jit``.
    """

    # Floor for the squared point-to-segment distance. sqrt has an infinite
    # derivative at 0, so without this floor the gradient becomes NaN when
    # the sphere centre lies exactly on a segment.
    _MIN_SQ_DIST = 1e-12

    def __init__(self, sphere_collider: dict, weight: float = 1.0, min_clearance: float = 0.05, segment_radius: float = 0.02):
        """Create the objective.

        Args:
            sphere_collider: Mapping with ``"center"`` (sphere centre,
                presumably a 3-vector) and ``"radius"``.
            weight: Multiplier applied to the final penalty (kept as a
                Python float).
            min_clearance: Extra margin added to the sphere radius.
            segment_radius: Thickness of each bone segment.
        """
        self.center = jnp.asarray(sphere_collider["center"], jnp.float32)
        self.radius = jnp.asarray(sphere_collider["radius"], jnp.float32)
        self.min_clearance = jnp.asarray(min_clearance, jnp.float32)
        self.segment_radius = jnp.asarray(segment_radius, jnp.float32)
        self.weight = float(weight)

    # pytree impl ------------------------------------------------------------
    def tree_flatten(self):
        """Return (array leaves, static aux data) for JAX's pytree machinery."""
        # weight treated as static (aux) so changing it may retrace but avoids concretization errors
        return (self.center, self.radius, self.min_clearance, self.segment_radius), (self.weight,)

    @classmethod
    def tree_unflatten(cls, aux, leaves):
        """Rebuild an instance from aux data and leaves without calling ``__init__``.

        JAX may unflatten with leaves that are tracers, ``None`` or opaque
        sentinel objects (e.g. when handling ``vmap`` axis specs or tree
        utilities). ``__init__`` converts its inputs with ``jnp.asarray``,
        which would fail or change those leaves, so the fields are assigned
        directly instead.
        """
        (weight,) = aux
        center, radius, min_clearance, segment_radius = leaves
        obj = object.__new__(cls)
        obj.center = center
        obj.radius = radius
        obj.min_clearance = min_clearance
        obj.segment_radius = segment_radius
        obj.weight = weight
        return obj

    # API --------------------------------------------------------------------
    def update_params(self, p: dict) -> None:
        """Update parameters in place from a (possibly partial) mapping.

        Accepts a nested ``"sphere_collider"`` dict with ``"center"`` /
        ``"radius"``, and top-level ``"center"``, ``"radius"``,
        ``"min_clearance"``, ``"segment_radius"`` and ``"weight"`` keys.
        Top-level ``"center"`` / ``"radius"`` are applied after, and so
        override, the nested ones.
        """
        if "sphere_collider" in p:
            collider = p["sphere_collider"]
            if "center" in collider:
                self.center = jnp.asarray(collider["center"], jnp.float32)
            if "radius" in collider:
                self.radius = jnp.asarray(collider["radius"], jnp.float32)
        if "center" in p:
            self.center = jnp.asarray(p["center"], jnp.float32)
        if "radius" in p:
            self.radius = jnp.asarray(p["radius"], jnp.float32)
        if "min_clearance" in p:
            self.min_clearance = jnp.asarray(p["min_clearance"], jnp.float32)
        if "segment_radius" in p:
            self.segment_radius = jnp.asarray(p["segment_radius"], jnp.float32)
        if "weight" in p:
            self.weight = float(p["weight"])

    def get_params(self) -> dict:
        """Return the current parameters as plain Python values.

        Must be called outside of tracing, since it converts arrays with
        ``float`` / ``np.asarray``.
        """
        return dict(
            sphere_collider=dict(center=np.asarray(self.center).tolist(), radius=float(self.radius)),
            min_clearance=float(self.min_clearance),
            segment_radius=float(self.segment_radius),
            weight=float(self.weight),
        )

    # core -------------------------------------------------------------------
    def _penalty_single(self, cfg, fk_solver) -> jnp.ndarray:
        """Unweighted penalty for a single configuration ``cfg``.

        ``fk_solver`` must provide ``compute_fk_from_angles(cfg)``, returning
        stacked transforms whose ``[:, :3, 3]`` part is used as bone head
        positions. It must also provide ``parent_list``, where negative
        entries mark bones without a parent.
        """
        fk = fk_solver.compute_fk_from_angles(cfg)  # (N,4,4)
        heads = fk[:, :3, 3]                        # (N,3)
        parents = jnp.asarray(fk_solver.parent_list, jnp.int32)  # (N,)
        # Parentless bones define no segment; they are computed with a dummy
        # parent index (0) to keep shapes static and then masked out.
        seg_mask = (parents >= 0).astype(jnp.float32)            # (N,)
        safe_parent_indices = jnp.where(parents >= 0, parents, 0)
        p_head = heads[safe_parent_indices]
        v = heads - p_head
        # Epsilon avoids division by zero for zero-length segments.
        dot_vv = jnp.sum(v * v, axis=1) + 1e-6
        eff_rad = self.radius + self.min_clearance + self.segment_radius
        # Closest point on each segment to the sphere centre (projection
        # parameter clamped to the segment).
        vc = self.center - p_head
        t = jnp.clip(jnp.sum(vc * v, axis=1) / dot_vv, 0.0, 1.0)
        closest = p_head + t[:, None] * v
        diff = self.center - closest
        # NaN-safe Euclidean norm. jnp.linalg.norm has a NaN gradient at 0,
        # and the mask multiplication below cannot clear NaN (NaN * 0 = NaN).
        sq_dist = jnp.sum(diff * diff, axis=1)
        dist = jnp.sqrt(jnp.maximum(sq_dist, self._MIN_SQ_DIST))
        penetration = jnp.maximum(0.0, eff_rad - dist)
        return jnp.sum((penetration ** 2) * seg_mask)

    def __call__(self, X: jnp.ndarray, fk_solver) -> jnp.ndarray:
        """Weighted penalty for one configuration or a batch of them.

        Args:
            X: A 1-D configuration, or an array whose leading axis indexes
                configurations (e.g. trajectory frames). The latter is
                vmapped over axis 0 and the per-configuration penalties are
                averaged.
            fk_solver: Forward-kinematics helper (see ``_penalty_single``).

        Returns:
            Scalar penalty multiplied by ``self.weight``.
        """
        if X.ndim == 1:
            loss = self._penalty_single(X, fk_solver)
        else:
            loss = jnp.mean(jax.vmap(lambda c: self._penalty_single(c, fk_solver))(X))
        return loss * jnp.float32(self.weight)