import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class

from jax_ik.objectives import ObjectiveFunction


@register_pytree_node_class
class FastSphereCollisionPenaltyObjTraj(ObjectiveFunction):
    """Vectorized sphere collision penalty over bone segments.

    Every bone whose parent index is >= 0 is treated as a capsule. The capsule
    runs from the parent's head position to the bone's own head position, with
    radius ``segment_radius``. The penalty for one configuration is the sum of
    squared penetrations of these capsules into the collider sphere, inflated
    by ``min_clearance``. For a trajectory (input with more than one
    dimension), per-frame penalties are averaged. The result is scaled by
    ``weight``.

    Weight stored as Python float (static aux) to avoid tracer-to-Python
    concretization when objective always present with varying weight.
    Consequently, changing ``weight`` changes the pytree's static structure
    and may cause jitted functions taking this object to retrace.
    """

    def __init__(self, sphere_collider: dict, weight: float = 1.0,
                 min_clearance: float = 0.05, segment_radius: float = 0.02):
        """Create the objective.

        Args:
            sphere_collider: Mapping with ``"center"`` and ``"radius"`` entries.
                The center is broadcast against the ``(N, 3)`` bone head
                positions, so it is presumably a 3-vector. The radius is
                expected to be a scalar, since ``get_params`` calls ``float``
                on it.
            weight: Multiplier applied to the final loss. Kept as a Python
                float (static pytree data).
            min_clearance: Extra margin added to the sphere radius.
            segment_radius: Radius of each bone capsule.
        """
        self.center = jnp.asarray(sphere_collider["center"], jnp.float32)
        self.radius = jnp.asarray(sphere_collider["radius"], jnp.float32)
        self.min_clearance = jnp.asarray(min_clearance, jnp.float32)
        self.segment_radius = jnp.asarray(segment_radius, jnp.float32)
        self.weight = float(weight)

    # pytree impl ------------------------------------------------------------
    def tree_flatten(self):
        """Return (array leaves, static aux data) for JAX pytree handling."""
        # weight treated as static (aux) so changing it may retrace but avoids
        # concretization errors
        leaves = (self.center, self.radius, self.min_clearance, self.segment_radius)
        return leaves, (self.weight,)

    @classmethod
    def tree_unflatten(cls, aux, leaves):
        """Rebuild an instance from ``tree_flatten`` output.

        Bypasses ``__init__`` on purpose. JAX may unflatten with leaves that are
        not arrays (tracers, ``None``, or ``object()`` sentinels used
        internally). Running ``jnp.asarray`` on those in ``__init__`` would fail
        or copy needlessly, so the leaves are stored exactly as given.
        """
        (weight,) = aux
        obj = object.__new__(cls)
        obj.center, obj.radius, obj.min_clearance, obj.segment_radius = leaves
        obj.weight = weight
        return obj

    # API --------------------------------------------------------------------
    def update_params(self, p: dict) -> None:
        """Update parameters in place from a (possibly partial) dict.

        Accepted keys:
            ``sphere_collider``: dict with optional ``center`` / ``radius``.
            ``center``, ``radius``: top-level overrides. They are applied after
                ``sphere_collider``, so they take precedence when both are given.
            ``min_clearance``, ``segment_radius``, ``weight``.

        ``weight`` is converted with ``float()``, so it must be a concrete
        value rather than a tracer.
        """
        if "sphere_collider" in p:
            collider = p["sphere_collider"]
            if "center" in collider:
                self.center = jnp.asarray(collider["center"], jnp.float32)
            if "radius" in collider:
                self.radius = jnp.asarray(collider["radius"], jnp.float32)
        if "center" in p:
            self.center = jnp.asarray(p["center"], jnp.float32)
        if "radius" in p:
            self.radius = jnp.asarray(p["radius"], jnp.float32)
        if "min_clearance" in p:
            self.min_clearance = jnp.asarray(p["min_clearance"], jnp.float32)
        if "segment_radius" in p:
            self.segment_radius = jnp.asarray(p["segment_radius"], jnp.float32)
        if "weight" in p:
            self.weight = float(p["weight"])

    def get_params(self) -> dict:
        """Return current parameters as plain Python values.

        The result has the same layout that ``update_params`` accepts. It
        requires concrete (non-traced) values.
        """
        return dict(
            sphere_collider=dict(center=np.asarray(self.center).tolist(), radius=float(self.radius)),
            min_clearance=float(self.min_clearance),
            segment_radius=float(self.segment_radius),
            weight=float(self.weight),
        )

    # core -------------------------------------------------------------------
    def _penalty_single(self, cfg, fk_solver) -> jnp.ndarray:
        """Return the unweighted penalty for a single configuration ``cfg``.

        Relies on ``fk_solver.compute_fk_from_angles(cfg)``, which returns
        per-bone transforms (``(N, 4, 4)``; the translation column is read as
        the head position). It also relies on ``fk_solver.parent_list``, whose
        entries are parent indices with negatives marking roots.
        """
        fk = fk_solver.compute_fk_from_angles(cfg)  # (N,4,4)
        heads = fk[:, :3, 3]  # (N,3)
        parents = jnp.asarray(fk_solver.parent_list, jnp.int32)  # (N,)
        has_parent = parents >= 0
        seg_mask = has_parent.astype(jnp.float32)  # (N,)
        # Roots get a dummy parent (bone 0) so indexing stays in range; their
        # degenerate segment is removed by seg_mask at the end.
        safe_parent_indices = jnp.where(has_parent, parents, 0)
        p_head = heads[safe_parent_indices]
        c_head = heads
        v = c_head - p_head
        # Epsilon keeps the projection finite for zero-length segments.
        dot_vv = jnp.sum(v * v, axis=1) + 1e-6
        eff_rad = self.radius + self.min_clearance + self.segment_radius
        # Closest point on each segment to the sphere center (clamped projection).
        vc = self.center - p_head
        t = jnp.clip(jnp.sum(vc * v, axis=1) / dot_vv, 0.0, 1.0)
        closest = p_head + t[:, None] * v
        diff = self.center - closest
        sq_dist = jnp.sum(diff * diff, axis=1)
        # "Double where" safe sqrt: the forward value equals
        # jnp.linalg.norm(diff, axis=1). The gradient is finite when the center
        # lies exactly on a segment. Without this, sqrt'(0) = inf yields NaN
        # gradients that survive even the seg_mask multiply (NaN * 0 = NaN).
        positive = sq_dist > 0.0
        dist = jnp.where(positive, jnp.sqrt(jnp.where(positive, sq_dist, 1.0)), 0.0)
        penetration = jnp.maximum(0.0, eff_rad - dist)
        return jnp.sum((penetration ** 2) * seg_mask)

    def __call__(self, X: jnp.ndarray, fk_solver) -> jnp.ndarray:
        """Return the weighted collision penalty.

        Args:
            X: A single configuration (1-D), or a stack of configurations whose
                leading axis is mapped over (presumably ``(T, D)`` for a
                trajectory of T frames).
            fk_solver: Forward-kinematics solver; see ``_penalty_single``.

        Returns:
            Scalar loss. A 1-D input gives the single-frame penalty. Otherwise
            the result is the mean over the leading axis, or 0 for an empty
            trajectory. Either way it is multiplied by ``weight``.
        """
        if X.ndim == 1:
            loss = self._penalty_single(X, fk_solver)
        elif X.shape[0] == 0:
            # Mean over zero frames would be NaN; an empty trajectory has no collisions.
            loss = jnp.zeros((), jnp.float32)
        else:
            per_frame = jax.vmap(lambda c: self._penalty_single(c, fk_solver))(X)
            loss = jnp.mean(per_frame)
        return loss * jnp.float32(self.weight)