Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from jax.tree_util import register_pytree_node_class | |
| from jax_ik.objectives import ObjectiveFunction | |
@register_pytree_node_class
class FastSphereCollisionPenaltyObjTraj(ObjectiveFunction):
    """Vectorized sphere collision penalty over bone segments.

    Every bone with a non-negative parent index defines a segment running
    from the parent's FK position to the bone's own FK position. For each
    segment the closest point to the sphere center is found, and any
    penetration into the inflated sphere
    (``radius + min_clearance + segment_radius``) is penalized quadratically.

    Weight stored as Python float (static aux) to avoid tracer-to-Python
    concretization when objective always present with varying weight.

    Args:
        sphere_collider: Mapping with a ``"center"`` entry and a ``"radius"``
            entry describing the sphere.
        weight: Scalar multiplier applied to the final loss.
        min_clearance: Extra distance kept between segment and sphere.
        segment_radius: Thickness added around every bone segment.
    """

    def __init__(
        self,
        sphere_collider: dict,
        weight: float = 1.0,
        min_clearance: float = 0.05,
        segment_radius: float = 0.02,
    ) -> None:
        self.center = jnp.asarray(sphere_collider["center"], jnp.float32)
        self.radius = jnp.asarray(sphere_collider["radius"], jnp.float32)
        self.min_clearance = jnp.asarray(min_clearance, jnp.float32)
        self.segment_radius = jnp.asarray(segment_radius, jnp.float32)
        self.weight = float(weight)

    # pytree impl ------------------------------------------------------------
    def tree_flatten(self):
        """Split the objective into array leaves and static aux data.

        Returns:
            ``(leaves, aux)`` where ``leaves`` is
            ``(center, radius, min_clearance, segment_radius)`` and ``aux``
            is ``(weight,)``.
        """
        # weight treated as static (aux) so changing it may retrace but
        # avoids concretization errors
        leaves = (self.center, self.radius, self.min_clearance, self.segment_radius)
        return leaves, (self.weight,)

    @classmethod
    def tree_unflatten(cls, aux, leaves):
        """Rebuild an objective from the output of :meth:`tree_flatten`.

        ``__init__`` is bypassed on purpose: JAX may hand over tracers or
        placeholder objects as leaves, and converting those with
        ``jnp.asarray`` would fail, so the leaves are stored untouched.
        """
        (weight,) = aux
        center, radius, min_clearance, segment_radius = leaves
        obj = object.__new__(cls)
        obj.center = center
        obj.radius = radius
        obj.min_clearance = min_clearance
        obj.segment_radius = segment_radius
        obj.weight = weight
        return obj

    # API --------------------------------------------------------------------
    def update_params(self, p: dict) -> None:
        """Update stored parameters in place from the keys present in ``p``.

        Recognized keys: ``"sphere_collider"`` (mapping that may hold
        ``"center"`` and/or ``"radius"``), ``"center"``, ``"radius"``,
        ``"min_clearance"``, ``"segment_radius"`` and ``"weight"``.
        Top-level ``"center"``/``"radius"`` are applied after the nested
        ``"sphere_collider"`` values and therefore take precedence.
        Unknown keys are ignored.
        """
        if "sphere_collider" in p:
            collider = p["sphere_collider"]
            if "center" in collider:
                self.center = jnp.asarray(collider["center"], jnp.float32)
            if "radius" in collider:
                self.radius = jnp.asarray(collider["radius"], jnp.float32)
        if "center" in p:
            self.center = jnp.asarray(p["center"], jnp.float32)
        if "radius" in p:
            self.radius = jnp.asarray(p["radius"], jnp.float32)
        if "min_clearance" in p:
            self.min_clearance = jnp.asarray(p["min_clearance"], jnp.float32)
        if "segment_radius" in p:
            self.segment_radius = jnp.asarray(p["segment_radius"], jnp.float32)
        if "weight" in p:
            self.weight = float(p["weight"])

    def get_params(self) -> dict:
        """Return the current parameters as plain Python lists and floats."""
        return dict(
            sphere_collider=dict(
                center=np.asarray(self.center).tolist(),
                radius=float(self.radius),
            ),
            min_clearance=float(self.min_clearance),
            segment_radius=float(self.segment_radius),
            weight=float(self.weight),
        )

    # core -------------------------------------------------------------------
    def _penalty_single(self, cfg, fk_solver) -> jnp.ndarray:
        """Return the unweighted collision penalty for one configuration.

        Args:
            cfg: Joint configuration forwarded to
                ``fk_solver.compute_fk_from_angles``.
            fk_solver: Object providing ``compute_fk_from_angles`` and a
                ``parent_list`` of per-bone parent indices, where a negative
                index marks a bone without a parent.

        Returns:
            Scalar sum of squared penetration depths over all segments.
        """
        fk = fk_solver.compute_fk_from_angles(cfg)  # (N,4,4)
        heads = fk[:, :3, 3]  # (N,3)
        parents = jnp.asarray(fk_solver.parent_list, jnp.int32)  # (N,)
        has_parent = parents >= 0
        seg_mask = has_parent.astype(jnp.float32)  # (N,)
        # Parentless bones index entry 0 just to keep the gather valid;
        # seg_mask removes their contribution from the sum.
        p_head = heads[jnp.where(has_parent, parents, 0)]
        v = heads - p_head
        # Epsilon guards against division by zero for zero-length segments.
        dot_vv = jnp.sum(v * v, axis=1) + 1e-6
        eff_rad = self.radius + self.min_clearance + self.segment_radius
        vc = self.center - p_head
        # Projection parameter of the sphere center, clamped to the segment.
        t = jnp.clip(jnp.sum(vc * v, axis=1) / dot_vv, 0.0, 1.0)
        closest = p_head + t[:, None] * v
        offset = self.center - closest
        sq_dist = jnp.sum(offset * offset, axis=1)
        # Same forward value as a Euclidean norm, but with a zero (instead of
        # NaN) gradient when the center lies exactly on a segment; a NaN
        # there would survive the mask multiply and poison the whole loss.
        coincident = sq_dist == 0.0
        dist = jnp.where(
            coincident, 0.0, jnp.sqrt(jnp.where(coincident, 1.0, sq_dist))
        )
        penetration = jnp.maximum(0.0, eff_rad - dist)
        return jnp.sum((penetration ** 2) * seg_mask)

    def __call__(self, X: jnp.ndarray, fk_solver) -> jnp.ndarray:
        """Evaluate the weighted penalty for one configuration or a batch.

        Args:
            X: A 1-D configuration, or a stack of configurations along the
                leading axis; for a stack the per-configuration penalties
                are averaged.
            fk_solver: Forward-kinematics solver passed to
                :meth:`_penalty_single`.

        Returns:
            Scalar loss multiplied by ``weight``.
        """
        if X.ndim == 1:
            loss = self._penalty_single(X, fk_solver)
        else:
            per_cfg = jax.vmap(lambda c: self._penalty_single(c, fk_solver))(X)
            loss = jnp.mean(per_cfg)
        return loss * jnp.float32(self.weight)