# JAX-IK / collision_objectives.py
# Author: hvoss-techfak
# Commit 3f7d1b3: "Initial commit"
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class
from jax_ik.objectives import ObjectiveFunction
@register_pytree_node_class
class FastSphereCollisionPenaltyObjTraj(ObjectiveFunction):
    """Vectorized sphere collision penalty over bone segments.

    Every bone whose parent index is non-negative is treated as a capsule
    running from its parent's head to its own head with radius
    ``segment_radius``. The penalty is the sum of squared penetrations of those
    capsules into the sphere inflated by ``min_clearance``.

    Weight stored as Python float (static aux) to avoid tracer-to-Python
    concretization when objective always present with varying weight.
    """

    def __init__(
        self,
        sphere_collider: dict,
        weight: float = 1.0,
        min_clearance: float = 0.05,
        segment_radius: float = 0.02,
    ):
        """Create the objective.

        Args:
            sphere_collider: Mapping with ``"center"`` and ``"radius"`` entries
                (center presumably a 3-vector, matching bone head positions).
            weight: Scalar multiplier applied to the final loss (kept static).
            min_clearance: Extra distance added to the sphere radius.
            segment_radius: Thickness (capsule radius) of each bone segment.
        """
        self.center = jnp.asarray(sphere_collider["center"], jnp.float32)
        self.radius = jnp.asarray(sphere_collider["radius"], jnp.float32)
        self.min_clearance = jnp.asarray(min_clearance, jnp.float32)
        self.segment_radius = jnp.asarray(segment_radius, jnp.float32)
        self.weight = float(weight)

    # pytree impl ------------------------------------------------------------
    def tree_flatten(self):
        # weight treated as static (aux) so changing it may retrace but avoids concretization errors
        return (self.center, self.radius, self.min_clearance, self.segment_radius), (self.weight,)

    @classmethod
    def tree_unflatten(cls, aux, leaves):
        # Bypass __init__ on purpose: JAX transformations may rebuild pytrees
        # with leaves that are tracers, None, or placeholder objects, and the
        # jnp.asarray(..., float32) conversions in __init__ would fail on (or
        # silently recast) them. Leaves must round-trip unchanged.
        (weight,) = aux
        obj = object.__new__(cls)
        obj.center, obj.radius, obj.min_clearance, obj.segment_radius = leaves
        obj.weight = weight
        return obj

    # API --------------------------------------------------------------------
    def update_params(self, p: dict) -> None:
        """Update parameters in place from a (possibly partial) mapping.

        Accepts sphere data either nested under ``"sphere_collider"`` or as
        top-level ``"center"`` / ``"radius"`` keys (top-level keys win when
        both are given). Unknown keys are ignored.
        """
        if "sphere_collider" in p:
            collider = p["sphere_collider"]
            if "center" in collider:
                self.center = jnp.asarray(collider["center"], jnp.float32)
            if "radius" in collider:
                self.radius = jnp.asarray(collider["radius"], jnp.float32)
        if "center" in p:
            self.center = jnp.asarray(p["center"], jnp.float32)
        if "radius" in p:
            self.radius = jnp.asarray(p["radius"], jnp.float32)
        if "min_clearance" in p:
            self.min_clearance = jnp.asarray(p["min_clearance"], jnp.float32)
        if "segment_radius" in p:
            self.segment_radius = jnp.asarray(p["segment_radius"], jnp.float32)
        if "weight" in p:
            # Must be concrete: weight is static pytree aux data.
            self.weight = float(p["weight"])

    def get_params(self) -> dict:
        """Return current parameters as plain Python lists/floats."""
        return dict(
            sphere_collider=dict(center=np.asarray(self.center).tolist(), radius=float(self.radius)),
            min_clearance=float(self.min_clearance),
            segment_radius=float(self.segment_radius),
            weight=float(self.weight),
        )

    # core -------------------------------------------------------------------
    def _penalty_single(self, cfg, fk_solver) -> jnp.ndarray:
        """Sum of squared segment penetrations for one configuration.

        Args:
            cfg: One configuration, passed to ``fk_solver.compute_fk_from_angles``.
            fk_solver: Provides ``compute_fk_from_angles`` (per-bone (N, 4, 4)
                transforms) and ``parent_list`` (parent index per bone,
                negative for roots).

        Returns:
            Scalar penalty; 0 when no segment violates the clearance.
        """
        fk = fk_solver.compute_fk_from_angles(cfg)  # (N, 4, 4)
        heads = fk[:, :3, 3]  # (N, 3) translation part = bone head positions
        parents = jnp.asarray(fk_solver.parent_list, jnp.int32)  # (N,)
        has_parent = parents >= 0
        seg_mask = has_parent.astype(jnp.float32)

        # Root bones have no segment; gather bone 0 for them so indexing stays
        # in bounds. Their contribution is zeroed by seg_mask at the end.
        p_head = heads[jnp.where(has_parent, parents, 0)]
        v = heads - p_head  # segment vector, parent head -> child head

        # Closest point on each segment to the sphere center: project and clamp
        # t to [0, 1]. Zero-length segments use t = 0; the inner where keeps the
        # division (and its gradient) finite without biasing real segments.
        seg_len_sq = jnp.sum(v * v, axis=1)
        degenerate = seg_len_sq < 1e-12
        proj = jnp.sum((self.center - p_head) * v, axis=1)
        t = jnp.where(degenerate, 0.0, proj / jnp.where(degenerate, 1.0, seg_len_sq))
        t = jnp.clip(t, 0.0, 1.0)
        closest = p_head + t[:, None] * v

        # NaN-safe Euclidean distance. The gradient of sqrt/norm is NaN at 0,
        # and 0 * NaN = NaN would poison the gradient even for masked roots.
        diff = self.center - closest
        dist_sq = jnp.sum(diff * diff, axis=1)
        nonzero = dist_sq > 0.0
        dist = jnp.where(nonzero, jnp.sqrt(jnp.where(nonzero, dist_sq, 1.0)), 0.0)

        eff_rad = self.radius + self.min_clearance + self.segment_radius
        penetration = jnp.maximum(0.0, eff_rad - dist)
        return jnp.sum((penetration ** 2) * seg_mask)

    def __call__(self, X: jnp.ndarray, fk_solver) -> jnp.ndarray:
        """Weighted collision penalty for one configuration or a stack of them.

        Args:
            X: A single 1-D configuration, or configurations stacked along the
                leading axis (presumably (T, D) — TODO confirm with callers).
            fk_solver: Forward-kinematics solver (see ``_penalty_single``).

        Returns:
            Scalar loss. For stacked input, this is the mean penalty over the
            leading axis. The loss is multiplied by ``weight``, and an empty
            stack yields 0.
        """
        if X.ndim == 1:
            loss = self._penalty_single(X, fk_solver)
        elif X.shape[0] == 0:
            # jnp.mean over zero rows is NaN; an empty trajectory cannot collide.
            loss = jnp.float32(0.0)
        else:
            loss = jnp.mean(jax.vmap(lambda c: self._penalty_single(c, fk_solver))(X))
        return loss * jnp.float32(self.weight)