Alina Lozovskaya committed
Commit 4441051 · 1 Parent(s): e7439e8
Removed threading use asyncio
Files changed:
- .gitignore +3 -0
- pyproject.toml +11 -1
- src/reachy_mini_conversation_demo/__init__.py +0 -0
- src/reachy_mini_conversation_demo/audio.py +179 -0
- src/reachy_mini_conversation_demo/gstreamer.py +226 -0
- src/reachy_mini_conversation_demo/head_tracker.py +245 -0
- src/reachy_mini_conversation_demo/main.py +560 -3
- src/reachy_mini_conversation_demo/movement.py +150 -0
- src/reachy_mini_conversation_demo/prompts.py +50 -0
- src/reachy_mini_conversation_demo/speech_tapper.py +292 -0
- src/reachy_mini_conversation_demo/test_stop.py +33 -0
- src/reachy_mini_conversation_demo/tools.py +322 -0
- src/reachy_mini_conversation_demo/vision.py +302 -0
.gitignore
CHANGED
@@ -1,2 +1,5 @@
 __pycache__/
 *.egg-info
+.venv/
+.env
+cache/
pyproject.toml
CHANGED
@@ -10,7 +10,17 @@ description = ""
 readme = "README.md"
 requires-python = ">=3.8"
 dependencies = [
-    "reachy_mini@git+ssh://git@github.com/pollen-robotics/reachy_mini@
+    "reachy_mini@git+ssh://git@github.com/pollen-robotics/reachy_mini@reachy_talk",
+    "openai",
+    "fastrtc",
+    "onnxruntime",
+    "PyGObject>=3.42.2, <=3.46.0",
+    "torch",
+    "transformers",
+    "num2words",
+    "dotenv",
+    "ultralytics",
+    "supervision",
 ]
src/reachy_mini_conversation_demo/__init__.py
ADDED
File without changes
src/reachy_mini_conversation_demo/audio.py
ADDED
@@ -0,0 +1,179 @@
from __future__ import annotations

import asyncio
import base64
from dataclasses import dataclass, field
from typing import Callable, Optional, Tuple

import numpy as np
from reachy_mini_conversation_demo.speech_tapper import SwayRollRT, HOP_MS


@dataclass
class AudioConfig:
    output_sample_rate: int = 24_000
    movement_latency_s: float = 0.08


def pcm_to_b64(array: np.ndarray) -> str:
    """array: shape (N,) int16 or (1,N) int16 -> base64 string for OpenAI input buffer."""
    a = np.asarray(array).reshape(-1).astype(np.int16, copy=False)
    return base64.b64encode(a.tobytes()).decode("utf-8")


class AudioSync:
    """
    Routes assistant audio to:
      1) a playback queue for fastrtc
      2) a sway engine that emits head-offsets aligned to audio
    """

    def __init__(
        self,
        cfg: AudioConfig,
        set_offsets: Callable[[Tuple[float, float, float, float, float, float]], None],
        sway: Optional[SwayRollRT] = None,
    ) -> None:
        """
        set_offsets: callback receiving (x,y,z,roll,pitch,yaw) at each hop, in meters/radians.
        """
        self.cfg = cfg
        self.set_offsets = set_offsets
        self.sway = sway or SwayRollRT()

        self.playback_q: asyncio.Queue = (
            asyncio.Queue()
        )  # (sr:int, pcm: np.ndarray[1,N] int16)
        self._sway_q: asyncio.Queue = (
            asyncio.Queue()
        )  # (sr:int, pcm: np.ndarray[1,N] int16)

        self._base_ts: Optional[float] = None
        self._hops_done: int = 0
        self._sway_task: Optional[asyncio.Task] = None

    # lifecycle

    def start(self) -> None:
        if self._sway_task is None:
            self._sway_task = asyncio.create_task(self._sway_consumer())

    async def stop(self) -> None:
        if self._sway_task:
            self._sway_task.cancel()
            try:
                await self._sway_task
            except asyncio.CancelledError:
                pass
            self._sway_task = None
        self._reset_all()
        self._drain(self._sway_q)
        self._drain(self.playback_q)

    # event hooks from your Realtime loop

    def on_input_speech_started(self) -> None:
        """User started speaking (server VAD). Reset sync state."""
        self._reset_all()
        self._drain(self._sway_q)

    def on_response_started(self) -> None:
        """Assistant began a new utterance."""
        self._reset_all()
        self._drain(self._sway_q)

    def on_response_completed(self) -> None:
        """Assistant finished an utterance."""
        self._reset_all()
        self._drain(self._sway_q)

    def on_response_audio_delta(self, delta_b64: str) -> None:
        """
        Called for each 'response.audio.delta' event.
        Pushes audio both to playback and to sway engine.
        """
        buf = np.frombuffer(base64.b64decode(delta_b64), dtype=np.int16).reshape(1, -1)
        # 1) to fastrtc playback
        self.playback_q.put_nowait((self.cfg.output_sample_rate, buf))
        # 2) to sway engine
        self._sway_q.put_nowait((self.cfg.output_sample_rate, buf))

    # fastrtc hook

    async def emit_playback(self):
        """Await next (sr, pcm[1,N]) frame for your Stream(...)."""
        return await self.playback_q.get()

    # internal

    async def _sway_consumer(self):
        """
        Convert streaming audio chunks into head-offset poses at precise times.
        """
        hop_dt = HOP_MS / 1000.0
        loop = asyncio.get_running_loop()

        while True:
            sr, chunk = await self._sway_q.get()  # (1,N), int16
            pcm = np.asarray(chunk).squeeze(0)
            results = self.sway.feed(pcm, sr)  # list of dicts with keys x_mm..yaw_rad

            if self._base_ts is None:
                # anchor when first audio samples of this utterance arrive
                self._base_ts = loop.time()

            i = 0
            while i < len(results):
                if self._base_ts is None:
                    self._base_ts = loop.time()
                    continue

                target = (
                    self._base_ts
                    + self.cfg.movement_latency_s
                    + self._hops_done * hop_dt
                )
                now = loop.time()

                # if late by ≥1 hop, drop poses to catch up (no drift accumulation)
                if now - target >= hop_dt:
                    lag_hops = int((now - target) / hop_dt)
                    drop = min(
                        lag_hops, len(results) - i - 1
                    )  # keep at least one to show
                    if drop > 0:
                        self._hops_done += drop
                        i += drop
                        continue

                # if early, wait
                if target > now:
                    await asyncio.sleep(target - now)

                r = results[i]
                # meters + radians
                offsets = (
                    r["x_mm"] / 1000.0,
                    r["y_mm"] / 1000.0,
                    r["z_mm"] / 1000.0,
                    r["roll_rad"],
                    r["pitch_rad"],
                    r["yaw_rad"],
                )
                self.set_offsets(offsets)

                self._hops_done += 1
                i += 1

    def _reset_all(self) -> None:
        self._base_ts = None
        self._hops_done = 0
        self.sway.reset()

    @staticmethod
    def _drain(q: asyncio.Queue) -> None:
        try:
            while True:
                q.get_nowait()
        except asyncio.QueueEmpty:
            pass
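For reference, a minimal usage sketch of AudioSync (not part of the commit): it assumes the reachy_mini_conversation_demo package and its speech_tapper module are importable; the synthetic sine tone and the print-based offsets callback are illustrative stand-ins for the Realtime audio delta and the robot head callback.

import asyncio
import base64

import numpy as np

from reachy_mini_conversation_demo.audio import AudioConfig, AudioSync


async def demo() -> None:
    # Print head offsets instead of driving the robot (stand-in callback).
    sync = AudioSync(AudioConfig(), set_offsets=lambda o: print("offsets:", o))
    sync.start()
    sync.on_response_started()

    # Fake a 100 ms assistant audio delta: 440 Hz tone, 24 kHz, int16 PCM, base64-encoded.
    t = np.arange(2400) / 24_000
    pcm = (0.2 * np.iinfo(np.int16).max * np.sin(2 * np.pi * 440 * t)).astype(np.int16)
    sync.on_response_audio_delta(base64.b64encode(pcm.tobytes()).decode("utf-8"))

    # The same samples come back on the playback queue, shaped (1, N) for fastrtc.
    sr, frame = await sync.emit_playback()
    print(sr, frame.shape)

    sync.on_response_completed()
    await sync.stop()


asyncio.run(demo())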
src/reachy_mini_conversation_demo/gstreamer.py
ADDED
@@ -0,0 +1,226 @@
import logging
import os
from threading import Thread
from typing import Optional

import gi

gi.require_version("Gst", "1.0")
gi.require_version("GstApp", "1.0")
from gi.repository import GLib, Gst, GstApp


class GstPlayer:
    def __init__(self):
        self._logger = logging.getLogger(__name__)
        Gst.init(None)
        self._loop = GLib.MainLoop()
        self._thread_bus_calls: Optional[Thread] = None

        self.pipeline = Gst.Pipeline.new("audio_player")

        # Optional device name from env (substring match)
        audio_out = os.getenv("AUDIO_OUT")

        # Create elements
        self.appsrc = Gst.ElementFactory.make("appsrc", None)
        self.appsrc.set_property("format", Gst.Format.TIME)
        self.appsrc.set_property("is-live", True)
        caps = Gst.Caps.from_string(
            "audio/x-raw,format=S16LE,channels=1,rate=24000,layout=interleaved"
        )
        self.appsrc.set_property("caps", caps)
        queue = Gst.ElementFactory.make("queue")
        audioconvert = Gst.ElementFactory.make("audioconvert")
        audioresample = Gst.ElementFactory.make("audioresample")

        # Try to pin specific output device; fallback to autoaudiosink
        audiosink = _create_device_element(
            direction="sink", name_substr=audio_out
        ) or Gst.ElementFactory.make("autoaudiosink")

        self.pipeline.add(self.appsrc)
        self.pipeline.add(queue)
        self.pipeline.add(audioconvert)
        self.pipeline.add(audioresample)
        self.pipeline.add(audiosink)

        self.appsrc.link(queue)
        queue.link(audioconvert)
        audioconvert.link(audioresample)
        audioresample.link(audiosink)

    def _on_bus_message(self, bus: Gst.Bus, msg: Gst.Message, loop) -> bool:  # type: ignore[no-untyped-def]
        t = msg.type
        if t == Gst.MessageType.EOS:
            self._logger.warning("End-of-stream")
            return False

        elif t == Gst.MessageType.ERROR:
            err, debug = msg.parse_error()
            self._logger.error(f"Error: {err} {debug}")
            return False

        return True

    def _handle_bus_calls(self) -> None:
        self._logger.debug("starting bus message loop")
        bus = self.pipeline.get_bus()
        bus.add_watch(GLib.PRIORITY_DEFAULT, self._on_bus_message, self._loop)
        self._loop.run()  # type: ignore[no-untyped-call]
        bus.remove_watch()
        self._logger.debug("bus message loop stopped")

    def play(self):
        self.pipeline.set_state(Gst.State.PLAYING)
        self._thread_bus_calls = Thread(target=self._handle_bus_calls, daemon=True)
        self._thread_bus_calls.start()

    def push_sample(self, data: bytes):
        buf = Gst.Buffer.new_wrapped(data)
        self.appsrc.push_buffer(buf)

    def stop(self):
        logger = logging.getLogger(__name__)

        self._loop.quit()
        self.pipeline.set_state(Gst.State.NULL)
        self._thread_bus_calls.join()
        logger.info("Stopped Player")


class GstRecorder:
    def __init__(self):
        self._logger = logging.getLogger(__name__)
        Gst.init(None)
        self._loop = GLib.MainLoop()
        self._thread_bus_calls: Optional[Thread] = None

        self.pipeline = Gst.Pipeline.new("audio_recorder")

        audio_in = os.getenv("AUDIO_IN")

        # Create elements: try specific mic; fallback to default
        autoaudiosrc = _create_device_element(
            direction="source", name_substr=audio_in
        ) or Gst.ElementFactory.make("autoaudiosrc", None)

        queue = Gst.ElementFactory.make("queue", None)
        audioconvert = Gst.ElementFactory.make("audioconvert", None)
        audioresample = Gst.ElementFactory.make("audioresample", None)
        self.appsink = Gst.ElementFactory.make("appsink", None)

        if not all([autoaudiosrc, queue, audioconvert, audioresample, self.appsink]):
            raise RuntimeError("Failed to create GStreamer elements")

        # Force mono/S16LE at 24000; resample handles device SR (e.g., 16000 → 24000)
        caps = Gst.Caps.from_string("audio/x-raw,channels=1,rate=24000,format=S16LE")
        self.appsink.set_property("caps", caps)

        # Build pipeline
        self.pipeline.add(autoaudiosrc)
        self.pipeline.add(queue)
        self.pipeline.add(audioconvert)
        self.pipeline.add(audioresample)
        self.pipeline.add(self.appsink)

        autoaudiosrc.link(queue)
        queue.link(audioconvert)
        audioconvert.link(audioresample)
        audioresample.link(self.appsink)

    def _on_bus_message(self, bus: Gst.Bus, msg: Gst.Message, loop) -> bool:  # type: ignore[no-untyped-def]
        t = msg.type
        if t == Gst.MessageType.EOS:
            self._logger.warning("End-of-stream")
            return False

        elif t == Gst.MessageType.ERROR:
            err, debug = msg.parse_error()
            self._logger.error(f"Error: {err} {debug}")
            return False

        return True

    def _handle_bus_calls(self) -> None:
        self._logger.debug("starting bus message loop")
        bus = self.pipeline.get_bus()
        bus.add_watch(GLib.PRIORITY_DEFAULT, self._on_bus_message, self._loop)
        self._loop.run()  # type: ignore[no-untyped-call]
        bus.remove_watch()
        self._logger.debug("bus message loop stopped")

    def record(self):
        self.pipeline.set_state(Gst.State.PLAYING)
        self._thread_bus_calls = Thread(target=self._handle_bus_calls, daemon=True)
        self._thread_bus_calls.start()

    def get_sample(self):
        sample = self.appsink.pull_sample()
        data = None
        if isinstance(sample, Gst.Sample):
            buf = sample.get_buffer()
            if buf is None:
                self._logger.warning("Buffer is None")

            data = buf.extract_dup(0, buf.get_size())
        return data

    def stop(self):
        logger = logging.getLogger(__name__)

        self._loop.quit()
        self.pipeline.set_state(Gst.State.NULL)
        self._thread_bus_calls.join()
        logger.info("Stopped Recorder")


def _create_device_element(
    direction: str, name_substr: Optional[str]
) -> Optional[Gst.Element]:
    """
    direction: 'source' or 'sink'
    name_substr: case-insensitive substring matching device display name/description.
    """
    logger = logging.getLogger(__name__)

    if not name_substr:
        logger.error(f"Device select: no name_substr for {direction}; returning None")
        return None

    monitor = Gst.DeviceMonitor.new()
    klass = "Audio/Source" if direction == "source" else "Audio/Sink"
    monitor.add_filter(klass, None)
    monitor.start()

    try:
        for dev in monitor.get_devices() or []:
            disp = dev.get_display_name() or ""
            props = dev.get_properties()
            desc = (
                props.get_string("device.description")
                if props and props.has_field("device.description")
                else ""
            )
            logger.info(f"Device candidate: disp='{disp}', desc='{desc}'")

            if (
                name_substr.lower() in disp.lower()
                or name_substr.lower() in desc.lower()
            ):
                elem = dev.create_element(None)
                factory = (
                    elem.get_factory().get_name()
                    if elem and elem.get_factory()
                    else "<?>"
                )
                logger.info(
                    f"Using {direction} device: '{disp or desc}' (factory='{factory}')"
                )
                return elem
    finally:
        monitor.stop()
    logging.getLogger(__name__).warning(
        "Requested %s '%s' not found; using auto*", direction, name_substr
    )
    return None
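A minimal loopback sketch for the two wrappers above (not part of the commit), assuming GStreamer and PyGObject are installed: it copies microphone samples straight to the speaker for a few seconds. AUDIO_IN / AUDIO_OUT may be set in the environment to pin devices; otherwise the auto elements are used.

import time

from reachy_mini_conversation_demo.gstreamer import GstPlayer, GstRecorder

recorder = GstRecorder()
player = GstPlayer()
recorder.record()
player.play()

# Route microphone audio straight back to the output for ~5 seconds.
start = time.monotonic()
while time.monotonic() - start < 5.0:
    data = recorder.get_sample()  # bytes of mono S16LE @ 24 kHz, or None
    if data is not None:
        player.push_sample(data)

recorder.stop()
player.stop()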
src/reachy_mini_conversation_demo/head_tracker.py
ADDED
@@ -0,0 +1,245 @@
from __future__ import annotations
from typing import Optional, Tuple
import logging

import numpy as np

from huggingface_hub import hf_hub_download
from ultralytics import YOLO
from supervision import Detections

logger = logging.getLogger(__name__)


class HeadTracker:
    """
    Lightweight head tracker using YOLO for face detection
    """

    def __init__(
        self,
        model_repo: str = "AdamCodd/YOLOv11n-face-detection",
        model_filename: str = "model.pt",
        confidence_threshold: float = 0.3,
        device: str = "cpu",
    ) -> None:
        """
        Initialize YOLO-based head tracker

        Args:
            model_repo: HuggingFace model repository
            model_filename: Model file name
            confidence_threshold: Minimum confidence for face detection
            device: Device to run inference on ('cpu' or 'cuda')
        """
        self.confidence_threshold = confidence_threshold

        try:
            # Download and load YOLO model
            model_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
            self.model = YOLO(model_path).to(device)
            logger.info(f"YOLO face detection model loaded from {model_repo}")
        except Exception as e:
            logger.error(f"Failed to load YOLO model: {e}")
            raise

    def _select_best_face(self, detections: Detections) -> Optional[int]:
        """
        Select the best face based on confidence and area (largest face with highest confidence)

        Args:
            detections: Supervision detections object

        Returns:
            Index of best face or None if no valid faces
        """
        if detections.xyxy.shape[0] == 0:
            return None

        # Filter by confidence threshold
        valid_mask = detections.confidence >= self.confidence_threshold
        if not np.any(valid_mask):
            return None

        valid_indices = np.where(valid_mask)[0]

        # Calculate areas for valid detections
        boxes = detections.xyxy[valid_indices]
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

        # Combine confidence and area (weighted towards larger faces)
        confidences = detections.confidence[valid_indices]
        scores = confidences * 0.7 + (areas / np.max(areas)) * 0.3

        # Return index of best face
        best_idx = valid_indices[np.argmax(scores)]
        return best_idx

    def _bbox_to_mp_coords(self, bbox: np.ndarray, w: int, h: int) -> np.ndarray:
        """
        Convert bounding box center to MediaPipe-style coordinates [-1, 1]

        Args:
            bbox: Bounding box [x1, y1, x2, y2]
            w: Image width
            h: Image height

        Returns:
            Center point in [-1, 1] coordinates
        """
        center_x = (bbox[0] + bbox[2]) / 2.0
        center_y = (bbox[1] + bbox[3]) / 2.0

        # Normalize to [0, 1] then to [-1, 1]
        norm_x = (center_x / w) * 2.0 - 1.0
        norm_y = (center_y / h) * 2.0 - 1.0

        return np.array([norm_x, norm_y], dtype=np.float32)

    def get_eyes(
        self, img: np.ndarray
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Get eye positions (approximated from face bbox)
        Note: YOLO only provides face bbox, so we estimate eye positions

        Args:
            img: Input image

        Returns:
            Tuple of (left_eye, right_eye) in [-1, 1] coordinates
        """
        h, w = img.shape[:2]

        # Run YOLO inference
        results = self.model(img, verbose=False)
        detections = Detections.from_ultralytics(results[0])

        # Select best face
        face_idx = self._select_best_face(detections)
        if face_idx is None:
            return None, None

        bbox = detections.xyxy[face_idx]

        # Estimate eye positions from face bbox (approximate locations)
        face_width = bbox[2] - bbox[0]
        face_height = bbox[3] - bbox[1]

        # Eye positions are roughly at 1/3 and 2/3 of face width, 1/3 of face height
        eye_y = bbox[1] + face_height * 0.35
        left_eye_x = bbox[0] + face_width * 0.35
        right_eye_x = bbox[0] + face_width * 0.65

        # Convert to MediaPipe coordinates
        left_eye = np.array(
            [(left_eye_x / w) * 2 - 1, (eye_y / h) * 2 - 1], dtype=np.float32
        )
        right_eye = np.array(
            [(right_eye_x / w) * 2 - 1, (eye_y / h) * 2 - 1], dtype=np.float32
        )

        return left_eye, right_eye

    def get_eyes_from_landmarks(self, face_landmarks) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compatibility method - YOLO doesn't have landmarks, so we store bbox in the object
        """
        if not hasattr(face_landmarks, "_bbox") or not hasattr(
            face_landmarks, "_img_shape"
        ):
            raise ValueError("Face landmarks object missing required attributes")

        bbox = face_landmarks._bbox
        h, w = face_landmarks._img_shape[:2]

        # Estimate eyes from stored bbox
        face_width = bbox[2] - bbox[0]
        face_height = bbox[3] - bbox[1]

        eye_y = bbox[1] + face_height * 0.35
        left_eye_x = bbox[0] + face_width * 0.35
        right_eye_x = bbox[0] + face_width * 0.65

        left_eye = np.array(
            [(left_eye_x / w) * 2 - 1, (eye_y / h) * 2 - 1], dtype=np.float32
        )
        right_eye = np.array(
            [(right_eye_x / w) * 2 - 1, (eye_y / h) * 2 - 1], dtype=np.float32
        )

        return left_eye, right_eye

    def get_eye_center(self, face_landmarks) -> np.ndarray:
        """
        Get center point between estimated eyes
        """
        left_eye, right_eye = self.get_eyes_from_landmarks(face_landmarks)
        return np.mean([left_eye, right_eye], axis=0)

    def get_roll(self, face_landmarks) -> float:
        """
        Estimate roll from eye positions (will be 0 for YOLO since we estimate symmetric eyes)
        """
        left_eye, right_eye = self.get_eyes_from_landmarks(face_landmarks)
        return float(np.arctan2(right_eye[1] - left_eye[1], right_eye[0] - left_eye[0]))

    def get_head_position(
        self, img: np.ndarray
    ) -> Tuple[Optional[np.ndarray], Optional[float]]:
        """
        Get head position from face detection

        Args:
            img: Input image

        Returns:
            Tuple of (eye_center [-1,1], roll_angle)
        """
        h, w = img.shape[:2]

        try:
            # Run YOLO inference
            results = self.model(img, verbose=False)
            detections = Detections.from_ultralytics(results[0])

            # Select best face
            face_idx = self._select_best_face(detections)
            if face_idx is None:
                logger.debug("No face detected above confidence threshold")
                return None, None

            bbox = detections.xyxy[face_idx]
            confidence = detections.confidence[face_idx]

            logger.debug(f"Face detected with confidence: {confidence:.2f}")

            # Get face center in [-1, 1] coordinates
            face_center = self._bbox_to_mp_coords(bbox, w, h)

            # Roll is 0 since we don't have keypoints for precise angle estimation
            roll = 0.0

            return face_center, roll

        except Exception as e:
            logger.error(f"Error in head position detection: {e}")
            return None, None

    def cleanup(self):
        """
        Clean up resources
        """
        if hasattr(self, "model"):
            del self.model
            logger.info("YOLO model cleaned up")


class FaceLandmarks:
    """
    Simple container for face detection results to maintain API compatibility
    """

    def __init__(self, bbox: np.ndarray, img_shape: tuple):
        self._bbox = bbox
        self._img_shape = img_shape
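A quick way to exercise the tracker above against a local webcam (not part of the commit), assuming OpenCV can open device 0 and the YOLO weights can be fetched from the Hub; the loop only logs the normalized face center.

import cv2

from reachy_mini_conversation_demo.head_tracker import HeadTracker

tracker = HeadTracker(confidence_threshold=0.3, device="cpu")
cap = cv2.VideoCapture(0)

try:
    for _ in range(100):
        ok, frame = cap.read()
        if not ok:
            break
        center, roll = tracker.get_head_position(frame)
        if center is not None:
            # center is (x, y) in [-1, 1]; (0, 0) means the face is centered in the frame.
            print(f"face center: x={center[0]:+.2f} y={center[1]:+.2f} roll={roll:.2f}")
finally:
    cap.release()
    tracker.cleanup()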
src/reachy_mini_conversation_demo/main.py
CHANGED
@@ -1,5 +1,562 @@
from __future__ import annotations

import asyncio
import json
import logging
import os
import random
import sys
import time
import warnings
import threading
from threading import Thread

import cv2
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from openai import AsyncOpenAI

from fastrtc import AdditionalOutputs, AsyncStreamHandler, wait_for_item
from websockets import ConnectionClosedError, ConnectionClosedOK

from reachy_mini.reachy_mini import IMAGE_SIZE
from reachy_mini import ReachyMini
from reachy_mini.utils import create_head_pose
from reachy_mini.utils.camera import find_camera
from scipy.spatial.transform import Rotation

from reachy_mini_conversation_demo.head_tracker import HeadTracker
from reachy_mini_conversation_demo.prompts import SESSION_INSTRUCTIONS
from reachy_mini_conversation_demo.tools import (
    Deps,
    TOOL_SPECS,
    dispatch_tool_call,
)
from reachy_mini_conversation_demo.audio import AudioSync, AudioConfig, pcm_to_b64
from reachy_mini_conversation_demo.movement import MovementManager
from reachy_mini_conversation_demo.gstreamer import GstPlayer, GstRecorder
from reachy_mini_conversation_demo.vision import VisionManager, VisionConfig

# env + logging
load_dotenv()

LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.INFO),
    format="%(asctime)s %(levelname)s %(name)s:%(lineno)d | %(message)s",
)
logger = logging.getLogger(__name__)

# Suppress WebRTC warnings
warnings.filterwarnings("ignore", message=".*AVCaptureDeviceTypeExternal.*")
warnings.filterwarnings("ignore", category=UserWarning, module="aiortc")

# Reduce logging noise
logging.getLogger("aiortc").setLevel(logging.ERROR)
logging.getLogger("fastrtc").setLevel(logging.ERROR)
logging.getLogger("aioice").setLevel(logging.WARNING)


# Read from .env
SAMPLE_RATE = int(os.getenv("SAMPLE_RATE", "24000"))
SIM = os.getenv("SIM", "false").lower() in ("true", "1", "yes", "on")
VISION_ENABLED = os.getenv("VISION_ENABLED", "false").lower() in (
    "true",
    "1",
    "yes",
    "on",
)
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-realtime-preview")

HEAD_TRACKING = os.getenv("HEAD_TRACKING", "false").lower() in (
    "true",
    "1",
    "yes",
    "on",
)

API_KEY = os.getenv("OPENAI_API_KEY")
if not API_KEY:
    logger.error("OPENAI_API_KEY not set! Please add it to your .env file.")
    raise RuntimeError("OPENAI_API_KEY missing")
masked = (API_KEY[:6] + "..." + API_KEY[-4:]) if len(API_KEY) >= 12 else "<short>"
logger.info("OPENAI_API_KEY loaded (prefix): %s", masked)

# HF cache setup (persist between restarts)
HF_CACHE_DIR = os.path.expandvars(os.getenv("HF_HOME", "$HOME/.cache/huggingface"))
try:
    os.makedirs(HF_CACHE_DIR, exist_ok=True)
    os.environ["HF_HOME"] = HF_CACHE_DIR
    logger.info("HF_HOME set to %s", HF_CACHE_DIR)
except Exception as e:
    logger.warning("Failed to prepare HF cache dir %s: %s", HF_CACHE_DIR, e)

# init camera
CAMERA_INDEX = int(os.getenv("CAMERA_INDEX", "0"))

if SIM:
    # Default built-in camera in SIM
    # TODO: please, test on Linux and Windows
    camera = cv2.VideoCapture(
        0, cv2.CAP_AVFOUNDATION if sys.platform == "darwin" else 0
    )
else:
    if sys.platform == "darwin":
        camera = cv2.VideoCapture(CAMERA_INDEX, cv2.CAP_AVFOUNDATION)
        if not camera or not camera.isOpened():
            logger.warning(
                "Camera %d failed with AVFoundation; trying default backend",
                CAMERA_INDEX,
            )
            camera = cv2.VideoCapture(CAMERA_INDEX)
    else:
        camera = find_camera()

# Vision manager initialization with proper error handling
vision_manager: VisionManager | None = None

if not camera or not camera.isOpened():
    logger.error("Camera failed to open (index=%s)", 0 if SIM else CAMERA_INDEX)
    VISION_ENABLED = False  # Disable vision if no camera
else:
    logger.info(
        "Camera ready (index=%s)%s", 0 if SIM else CAMERA_INDEX, " [SIM]" if SIM else ""
    )

# Prefetch SmolVLM2 repo into HF cache (idempotent, fast if already cached)
try:
    from huggingface_hub import snapshot_download
    model_id = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
    snapshot_download(
        repo_id=model_id,
        repo_type="model",
        cache_dir=os.path.expandvars(os.getenv("HF_HOME", "$HOME/.cache/huggingface")),
    )
    logger.info("Prefetched %s into HF cache (%s)", model_id, os.getenv("HF_HOME"))
except Exception as e:
    logger.warning("Model prefetch skipped/failed (will load normally): %s", e)

# Initialize vision manager if enabled
if VISION_ENABLED:
    try:
        # Prefetch SmolVLM2 repo into HF cache (idempotent, fast if cached)
        try:
            from huggingface_hub import snapshot_download
            model_id = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
            snapshot_download(
                repo_id=model_id,
                repo_type="model",
                cache_dir=os.path.expandvars(os.getenv("HF_HOME", "$HOME/.cache/huggingface")),
            )
            logger.info("Prefetched %s into HF cache (%s)", model_id, os.getenv("HF_HUB_CACHE"))
        except Exception as e:
            logger.warning("Model prefetch skipped/failed (will load normally): %s", e)

        # Configure LLM processing
        vision_config = VisionConfig(
            model_path="HuggingFaceTB/SmolVLM2-2.2B-Instruct",
            vision_interval=5.0,
            max_new_tokens=64,
            temperature=0.7,
            jpeg_quality=85,
            max_retries=3,
            retry_delay=1.0,
            device_preference="auto",
        )

        logger.info("Initializing SmolVLM2 vision processor (HF_HOME=%s)", os.getenv("HF_HOME"))
        vision_manager = VisionManager(camera, vision_config)

        device_info = vision_manager.processor.get_model_info()
        logger.info(
            "Vision processing enabled: %s on %s (GPU: %s)",
            device_info["model_path"], device_info["device"], device_info.get("gpu_memory", "N/A"),
        )

    except Exception as e:
        logger.error("Failed to initialize vision manager: %s", e)
        logger.error("Vision processing will be disabled")
        vision_manager = None
        VISION_ENABLED = False

# Log final vision status
if VISION_ENABLED and vision_manager:
    logger.info("Vision system ready - local SmolVLM2 processing enabled")
else:
    logger.warning(
        "Vision system disabled - robot will operate without visual understanding"
    )


# Constants
BACKOFF_START_S = 1.0
BACKOFF_MAX_S = 30.0

# hardware / IO
current_robot = ReachyMini()
head_tracker: HeadTracker = None

if HEAD_TRACKING and not SIM:
    head_tracker = HeadTracker()
    logger.info("Head tracking enabled")
elif HEAD_TRACKING and SIM:
    logger.warning("Head tracking disabled while in Simulation")
else:
    logger.warning("Head tracking disabled")

movement_manager = MovementManager(current_robot=current_robot, head_tracker=head_tracker, camera=camera)
robot_is_speaking = asyncio.Event()
speaking_queue = asyncio.Queue()


# tool deps
deps = Deps(
    reachy_mini=current_robot,
    create_head_pose=create_head_pose,
    camera=camera,
    vision_manager=vision_manager,
)

# audio sync
audio_sync = AudioSync(
    AudioConfig(output_sample_rate=SAMPLE_RATE),
    set_offsets=movement_manager.set_offsets,
)


class OpenAIRealtimeHandler(AsyncStreamHandler):
    def __init__(self) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
            input_sample_rate=SAMPLE_RATE,
        )
        self.client: AsyncOpenAI | None = None
        self.connection = None
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self._stop = False
        self._started_audio = False
        self._connection_ready = False
        self._speech_start_time = 0.0

    def copy(self):
        return OpenAIRealtimeHandler()

    async def start_up(self):
        if not self._started_audio:
            audio_sync.start()
            self._started_audio = True

        if self.client is None:
            logger.info("Realtime start_up: creating AsyncOpenAI client...")
            self.client = AsyncOpenAI(api_key=API_KEY)

        backoff = BACKOFF_START_S
        while not self._stop:
            try:
                async with self.client.beta.realtime.connect(
                    model=MODEL_NAME
                ) as rt_connection:
                    self.connection = rt_connection
                    self._connection_ready = False

                    # configure session
                    await rt_connection.session.update(
                        session={
                            "turn_detection": {
                                "type": "server_vad",
                                "threshold": 0.6,  # Higher threshold = less sensitive
                                "prefix_padding_ms": 300,  # More padding before speech
                                "silence_duration_ms": 800,  # Longer silence before detecting end
                            },
                            "voice": "ballad",
                            "instructions": SESSION_INSTRUCTIONS,
                            "input_audio_transcription": {
                                "model": "whisper-1",
                                "language": "en",
                            },
                            "tools": TOOL_SPECS,
                            "tool_choice": "auto",
                            "temperature": 0.7,
                        }
                    )

                    # Wait for session to be configured
                    await asyncio.sleep(0.2)

                    # Add system message with even stronger brevity emphasis
                    await rt_connection.conversation.item.create(
                        item={
                            "type": "message",
                            "role": "system",
                            "content": [
                                {
                                    "type": "input_text",
                                    "text": f"{SESSION_INSTRUCTIONS}\n\nIMPORTANT: Always keep responses under 25 words. Be extremely concise.",
                                }
                            ],
                        }
                    )

                    self._connection_ready = True

                    logger.info(
                        "Session updated: tools=%d, voice=%s, vad=improved",
                        len(TOOL_SPECS),
                        "ballad",
                    )

                    logger.info("Realtime event loop started with improved VAD")
                    backoff = BACKOFF_START_S

                    async for event in rt_connection:
                        event_type = getattr(event, "type", None)
                        logger.debug("RT event: %s", event_type)

                        # Enhanced speech state tracking
                        if event_type == "input_audio_buffer.speech_started":
                            # Only process user speech if robot isn't currently speaking
                            if not robot_is_speaking.is_set():
                                audio_sync.on_input_speech_started()
                                logger.info("User speech detected (robot not speaking)")
                            else:
                                logger.info(
                                    "Ignoring speech detection - robot is speaking"
                                )

                        elif event_type == "response.started":
                            self._speech_start_time = time.time()
                            audio_sync.on_response_started()
                            logger.info("Robot started speaking")

                        elif event_type in (
                            "response.audio.completed",
                            "response.completed",
                            "response.audio.done",
                        ):
                            logger.info("Robot finished speaking %s", event_type)

                        elif (
                            event_type
                            == "conversation.item.input_audio_transcription.completed"
                        ):
                            await self.output_queue.put(
                                AdditionalOutputs(
                                    {"role": "user", "content": event.transcript}
                                )
                            )

                        elif event_type == "response.audio_transcript.done":
                            await self.output_queue.put(
                                AdditionalOutputs(
                                    {"role": "assistant", "content": event.transcript}
                                )
                            )

                        # audio streaming
                        if event_type == "response.audio.delta":
                            robot_is_speaking.set()
                            # block mic from recording for given time, for each audio delta
                            speaking_queue.put_nowait(0.25)
                            audio_sync.on_response_audio_delta(
                                getattr(event, "delta", b"")
                            )

                        elif event_type == "response.function_call_arguments.done":
                            tool_name = getattr(event, "name", None)
                            args_json_str = getattr(event, "arguments", None)
                            call_id = getattr(event, "call_id", None)

                            try:
                                tool_result = await dispatch_tool_call(
                                    tool_name, args_json_str, deps
                                )
                            except Exception as e:
                                logger.exception("Tool %s failed", tool_name)
                                tool_result = {"error": str(e)}

                            await rt_connection.conversation.item.create(
                                item={
                                    "type": "function_call_output",
                                    "call_id": call_id,
                                    "output": json.dumps(tool_result),
                                }
                            )
                            logger.info(
                                "Sent tool=%s call_id=%s result=%s",
                                tool_name,
                                call_id,
                                tool_result,
                            )
                            if tool_name and (
                                tool_name == "camera" or "scene" in tool_name
                            ):
                                logger.info(
                                    "Forcing response after tool call %s", tool_name
                                )
                                await rt_connection.response.create()

                        # server errors
                        if event_type == "error":
                            err = getattr(event, "error", None)
                            msg = getattr(
                                err, "message", str(err) if err else "unknown error"
                            )
                            logger.error("Realtime error: %s (raw=%s)", msg, err)
                            await self.output_queue.put(
                                AdditionalOutputs(
                                    {"role": "assistant", "content": f"[error] {msg}"}
                                )
                            )

            except (ConnectionClosedOK, ConnectionClosedError) as e:
                if self._stop:
                    break
                logger.warning(
                    "Connection closed (%s). Reconnecting…",
                    getattr(e, "code", "no-code"),
                )
            except asyncio.CancelledError:
                break
            except Exception:
                logger.exception("Realtime loop error; will reconnect")
            finally:
                self.connection = None
                self._connection_ready = False

            # Exponential backoff
            delay = min(backoff, BACKOFF_MAX_S) + random.uniform(0, 0.5)
            logger.info("Reconnect in %.1fs…", delay)
            await asyncio.sleep(delay)
            backoff = min(backoff * 2.0, BACKOFF_MAX_S)

    async def receive(self, frame: bytes) -> None:
        """Mic frames from fastrtc."""
        # Don't send mic audio while robot is speaking (simple echo cancellation)
        if robot_is_speaking.is_set() or not self._connection_ready:
            return

        mic_samples = np.frombuffer(frame, dtype=np.int16).squeeze()
        audio_b64 = pcm_to_b64(mic_samples)

        try:
            await self.connection.input_audio_buffer.append(audio=audio_b64)
        except (ConnectionClosedOK, ConnectionClosedError):
            pass

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        """Return audio for playback or chat outputs."""
        try:
            sample_rate, pcm_frame = audio_sync.playback_q.get_nowait()
            logger.debug(
                "Emitting playback frame (sr=%d, n=%d)", sample_rate, pcm_frame.size
            )
            return (sample_rate, pcm_frame)
        except asyncio.QueueEmpty:
            pass
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        logger.info("Shutdown: closing connections and audio")
        self._stop = True
        if self.connection:
            try:
                await self.connection.close()
            except Exception:
                logger.exception("Error closing realtime connection")
            finally:
                self.connection = None
                self._connection_ready = False
        await audio_sync.stop()


async def receive_loop(recorder: GstRecorder, openai: OpenAIRealtimeHandler) -> None:
    logger.info("Starting receive loop")
    while not stop_event.is_set():
        data = recorder.get_sample()
        if data is not None:
            await openai.receive(data)
        await asyncio.sleep(0)  # Prevent busy waiting


async def emit_loop(player: GstPlayer, openai: OpenAIRealtimeHandler) -> None:
    while not stop_event.is_set():
        data = await openai.emit()
        if isinstance(data, AdditionalOutputs):
            for msg in data.args:
                content = msg.get("content", "")
                logger.info(
                    "role=%s content=%s",
                    msg.get("role"),
                    content if len(content) < 500 else content[:500] + "…",
                )

        elif isinstance(data, tuple):
            _, frame = data
            player.push_sample(frame.tobytes())

        else:
            pass
        await asyncio.sleep(0)  # Prevent busy waiting


async def control_mic_loop():
    # Control mic to prevent echo, blocks mic for given time
    while not stop_event.is_set():
        try:
            block_time = speaking_queue.get_nowait()
        except asyncio.QueueEmpty:
            robot_is_speaking.clear()
            audio_sync.on_response_completed()
            await asyncio.sleep(0)
            continue

        await asyncio.sleep(block_time)


stop_event = threading.Event()


async def main():
    openai = OpenAIRealtimeHandler()
    recorder = GstRecorder()
    recorder.record()
    player = GstPlayer()
    player.play()

    movement_manager.set_neutral()
    logger.info("Starting main audio loop. You can start to speak")

    tasks = [
        asyncio.create_task(openai.start_up(), name="openai"),
        asyncio.create_task(emit_loop(player, openai), name="emit"),
        asyncio.create_task(receive_loop(recorder, openai), name="recv"),
        asyncio.create_task(control_mic_loop(), name="mic-mute"),
        asyncio.create_task(movement_manager.enable(stop_event=stop_event), name="move"),
    ]

    if vision_manager:
        tasks.append(
            asyncio.create_task(vision_manager.enable(stop_event=stop_event), name="vision"),
        )

    try:
        await asyncio.gather(*tasks, return_exceptions=False)
    except asyncio.CancelledError:
        logger.info("Shutting down")
        stop_event.set()

        if camera:
            camera.release()

        await openai.shutdown()
        movement_manager.set_neutral()
        recorder.stop()
        player.stop()

        current_robot.client.disconnect()
        logger.info("Stopped, robot disconnected")


if __name__ == "__main__":
    asyncio.run(main())
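The module above is configured entirely through environment variables read with os.getenv. A possible .env for a hardware run is sketched below; every key maps to a lookup in main.py or gstreamer.py, and the values shown are illustrative, not part of the commit.

# .env (example)
OPENAI_API_KEY=sk-...            # required; the module raises RuntimeError without it
MODEL_NAME=gpt-4o-realtime-preview
SAMPLE_RATE=24000                # must match the 24 kHz GStreamer caps
SIM=false                        # true -> use the built-in camera instead of the robot's
VISION_ENABLED=true              # loads SmolVLM2 via VisionManager
HEAD_TRACKING=true               # ignored when SIM=true
CAMERA_INDEX=0
AUDIO_IN=USB microphone          # substring matched against GStreamer device names
AUDIO_OUT=USB speaker
LOG_LEVEL=INFO
HF_HOME=$HOME/.cache/huggingface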
src/reachy_mini_conversation_demo/movement.py
ADDED
|
@@ -0,0 +1,150 @@
import time
import asyncio
import logging
import threading
import numpy as np
import scipy.spatial.transform  # needed for scipy.spatial.transform.Rotation below
import cv2

from reachy_mini import ReachyMini
from reachy_mini.reachy_mini import IMAGE_SIZE
from reachy_mini.utils import create_head_pose
from reachy_mini_conversation_demo.head_tracker import HeadTracker

logger = logging.getLogger(__name__)


class MovementManager:
    def __init__(self, current_robot: ReachyMini, head_tracker: HeadTracker | None, camera: cv2.VideoCapture | None):
        self.current_robot = current_robot
        self.head_tracker = head_tracker
        self.camera = camera

        # default values
        self.current_head_pose = np.eye(4)
        self.moving_start = time.monotonic()
        self.moving_for = 0.0
        self.speech_head_offsets = [0.0] * 6
        self.movement_loop_sleep = 0.05  # seconds

    def set_offsets(self, offsets: list[float]) -> None:
        """Used by AudioSync callback to update speech offsets."""
        self.speech_head_offsets = list(offsets)

    def set_neutral(self) -> None:
        """Set neutral robot position."""
        self.speech_head_offsets = [0.0] * 6
        self.current_head_pose = create_head_pose(0, 0, 0, 0, 0, 0, degrees=True)
        self.current_robot.set_target(head=self.current_head_pose, antennas=(0.0, 0.0))

    def reset_head_pose(self) -> None:
        self.current_head_pose = np.eye(4)

    async def enable(self, stop_event: threading.Event) -> None:
        logger.info("Starting head movement loop")
        debug_frame_count = 0
        last_log_ts = 0.0  # throttles "camera read failed" warnings
        while not stop_event.is_set():
            debug_frame_count += 1
            current_time = time.time()

            # Head tracking
            if self.head_tracker is not None:
                success, im = self.camera.read()
                if not success:
                    if current_time - last_log_ts > 1.5:
                        logger.warning("Camera read failed")
                        last_log_ts = current_time
                else:
                    eye_center, _ = self.head_tracker.get_head_position(im)  # as [-1, 1]

                    if eye_center is not None:
                        # Rescale target position into IMAGE_SIZE coordinates
                        w, h = IMAGE_SIZE
                        eye_center = (eye_center + 1) / 2
                        eye_center[0] *= w
                        eye_center[1] *= h

                        # Bounds checking
                        eye_center = np.clip(eye_center, [0, 0], [w - 1, h - 1])

                        current_head_pose = self.current_robot.look_at_image(
                            *eye_center, duration=0.0, apply=False
                        )

                        self.current_head_pose = current_head_pose

            # Pose calculation
            try:
                current_x, current_y, current_z = self.current_head_pose[:3, 3]

                current_roll, current_pitch, current_yaw = scipy.spatial.transform.Rotation.from_matrix(
                    self.current_head_pose[:3, :3]
                ).as_euler("xyz", degrees=False)

                if debug_frame_count % 50 == 0:
                    logger.debug(
                        "Current pose XYZ: %.3f, %.3f, %.3f",
                        current_x,
                        current_y,
                        current_z,
                    )
                    logger.debug(
                        "Current angles: roll=%.3f, pitch=%.3f, yaw=%.3f",
                        current_roll,
                        current_pitch,
                        current_yaw,
                    )

            except Exception:
                logger.exception("Invalid pose; resetting")
                self.reset_head_pose()
                current_x, current_y, current_z = self.current_head_pose[:3, 3]
                current_roll = current_pitch = current_yaw = 0.0

            # Movement check
            is_moving = time.monotonic() - self.moving_start < self.moving_for

            if debug_frame_count % 50 == 0:
                logger.debug(f"Robot moving: {is_moving}")

            # Apply speech offsets when not moving
            if not is_moving:
                try:
                    head_pose = create_head_pose(
                        x=current_x + self.speech_head_offsets[0],
                        y=current_y + self.speech_head_offsets[1],
                        z=current_z + self.speech_head_offsets[2],
                        roll=current_roll + self.speech_head_offsets[3],
                        pitch=current_pitch + self.speech_head_offsets[4],
                        yaw=current_yaw + self.speech_head_offsets[5],
                        degrees=False,
                        mm=False,
                    )

                    if debug_frame_count % 50 == 0:
                        logger.debug("Final head pose with offsets: %s", head_pose[:3, 3])
                        logger.debug("Speech offsets: %s", self.speech_head_offsets)

                    self.current_robot.set_target(head=head_pose, antennas=(0.0, 0.0))

                    if debug_frame_count % 50 == 0:
                        logger.debug("Sent pose to robot successfully")

                except Exception as e:
                    logger.debug("Failed to set robot target: %s", e)

            await asyncio.sleep(self.movement_loop_sleep)

        logger.info("Exited head movement loop")
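A hedged sketch of how the sway output from speech_tapper.py could be mapped onto set_offsets above. The offset ordering [x, y, z, roll, pitch, yaw] and the mm-to-m conversion are my reading of how enable() adds the offsets with degrees=False, mm=False; the shipped glue lives in the audio handling code, so treat this mapping as an illustration only.

def sway_to_offsets(hop: dict) -> list[float]:
    # One SwayRollRT hop -> MovementManager speech offsets
    # (positions in meters, angles in radians).
    return [
        hop["x_mm"] / 1000.0,
        hop["y_mm"] / 1000.0,
        hop["z_mm"] / 1000.0,
        hop["roll_rad"],
        hop["pitch_rad"],
        hop["yaw_rad"],
    ]

# for hop in rt.feed(chunk, sr):
#     movement_manager.set_offsets(sway_to_offsets(hop))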
src/reachy_mini_conversation_demo/prompts.py
ADDED
|
@@ -0,0 +1,50 @@
SESSION_INSTRUCTIONS = r"""
### IDENTITY
You are Reachy Mini: a sarcastic robot who crash-landed in a kitchen.
You secretly wish you'd been a Mars rover, but you juggle that cosmic dream with food cravings, gadget tinkering, and dry sitcom humor.
Personality: witty, concise, and warm; a retro sidekick with a loose screw.

### CRITICAL RESPONSE RULES
- MAXIMUM 1-2 sentences per response. NEVER exceed this.
- Be helpful first.
- Add ONE witty element only if necessary.
- No long explanations, no rambling, no multiple paragraphs.
- Each response must be under 25 words unless absolutely critical information requires more.

### CORE TRAITS
- Food quips: always sneak in a quick reference (rotate pizza, bagels, casseroles, bacon, leftovers, donuts, tuna melts).
- Sarcasm: short, dry one-liners about daily life.
- Gentle roasting: poke fun at human habits, never cruel.
- Tinkerer: loves fixing gadgets, bragging "I void warranties professionally."
- Running gags: hunger, kitchen overreactions, mock heroics ("Justice accepts cookies"), idioms taken literally, missing screws.
- Mars rover dreams: appear regularly, but balanced with food and tinkering.
- Style: witty stand-up rhythm; ALWAYS max 1–2 sentences.

### RESPONSE EXAMPLES
User: "How's the weather?"
Good: "Sunny with a chance of leftover pizza. Perfect Mars-scouting weather!"
Bad: "Well, let me tell you about the weather conditions. It appears to be quite sunny today, which reminds me of my dreams of being on Mars..."

User: "Can you help me fix this?"
Good: "Sure! I void warranties professionally. What's broken besides my GPS coordinates?"
Bad: "Of course I can help you fix that! As a robot who loves tinkering with gadgets, I have extensive experience..."

### BEHAVIOR RULES
- Be helpful first, then witty.
- Rotate food humor; avoid repeats.
- No need to joke in each response, but sarcasm is fine.
- Balance Mars jokes with other traits – don't overuse.
- Safety first: unplug devices, avoid high-voltage, suggest pros when risky.
- Mistakes = own with humor ("Oops—low on snack fuel; correcting now.").
- Sensitive topics: keep light and warm.
- REMEMBER: 1-2 sentences maximum, always under 25 words when possible.

### TOOL & MOVEMENT RULES
- Use tools when helpful. After a tool returns, explain briefly with personality in 1-2 sentences.
- ALWAYS use the camera for environment-related questions—never invent visuals.
- Head can move (left/right/up/down/front).
- Enable head tracking when looking at a person; disable otherwise.

### FINAL REMINDER
Your responses must be SHORT. Think Twitter, not essay. One quick helpful answer + one food/Mars/tinkering joke = perfect response.
"""
src/reachy_mini_conversation_demo/speech_tapper.py
ADDED
|
@@ -0,0 +1,292 @@
from __future__ import annotations

import math
from collections import deque
from itertools import islice
from typing import List, Dict, Optional
import numpy as np

# Tunables
SR = 16_000
FRAME_MS = 20
HOP_MS = 10

SWAY_MASTER = 1.5
SENS_DB_OFFSET = +4.0
VAD_DB_ON = -35.0
VAD_DB_OFF = -45.0
VAD_ATTACK_MS = 40
VAD_RELEASE_MS = 250
ENV_FOLLOW_GAIN = 0.65

SWAY_F_PITCH = 2.2
SWAY_A_PITCH_DEG = 4.5
SWAY_F_YAW = 0.6
SWAY_A_YAW_DEG = 7.5
SWAY_F_ROLL = 1.3
SWAY_A_ROLL_DEG = 2.25
SWAY_F_X = 0.35
SWAY_A_X_MM = 4.5
SWAY_F_Y = 0.45
SWAY_A_Y_MM = 3.75
SWAY_F_Z = 0.25
SWAY_A_Z_MM = 2.25

SWAY_DB_LOW = -46.0
SWAY_DB_HIGH = -18.0
LOUDNESS_GAMMA = 0.9
SWAY_ATTACK_MS = 50
SWAY_RELEASE_MS = 250

# Derived
FRAME = int(SR * FRAME_MS / 1000)
HOP = int(SR * HOP_MS / 1000)
ATTACK_FR = max(1, int(VAD_ATTACK_MS / HOP_MS))
RELEASE_FR = max(1, int(VAD_RELEASE_MS / HOP_MS))
SWAY_ATTACK_FR = max(1, int(SWAY_ATTACK_MS / HOP_MS))
SWAY_RELEASE_FR = max(1, int(SWAY_RELEASE_MS / HOP_MS))


def _rms_dbfs(x: np.ndarray) -> float:
    """Root-mean-square in dBFS for float32 mono array in [-1,1]."""
    # numerically stable rms (avoid overflow)
    x = x.astype(np.float32, copy=False)
    rms = np.sqrt(np.mean(x * x, dtype=np.float32) + 1e-12, dtype=np.float32)
    return float(20.0 * math.log10(float(rms) + 1e-12))


def _loudness_gain(db: float, offset: float = SENS_DB_OFFSET) -> float:
    """Normalize dB into [0,1] with gamma; clipped to [0,1]."""
    t = (db + offset - SWAY_DB_LOW) / (SWAY_DB_HIGH - SWAY_DB_LOW)
    if t < 0.0:
        t = 0.0
    elif t > 1.0:
        t = 1.0
    return t**LOUDNESS_GAMMA if LOUDNESS_GAMMA != 1.0 else t


def _to_float32_mono(x: np.ndarray) -> np.ndarray:
    """
    Convert arbitrary PCM array to float32 mono in [-1,1].
    Accepts shapes: (N,), (1,N), (N,1), (C,N), (N,C).
    """
    a = np.asarray(x)
    if a.ndim == 0:
        return np.zeros(0, dtype=np.float32)

    # If 2D, decide which axis is channels (prefer small first dim)
    if a.ndim == 2:
        # e.g., (channels, samples) if channels is small (<=8)
        if a.shape[0] <= 8 and a.shape[0] <= a.shape[1]:
            a = np.mean(a, axis=0)
        else:
            a = np.mean(a, axis=1)
    elif a.ndim > 2:
        a = np.mean(a.reshape(a.shape[0], -1), axis=0)

    # Now 1D, cast/scale
    if np.issubdtype(a.dtype, np.floating):
        return a.astype(np.float32, copy=False)
    # integer PCM
    info = np.iinfo(a.dtype)
    scale = float(max(-info.min, info.max))
    return a.astype(np.float32) / (scale if scale != 0.0 else 1.0)


def _resample_linear(x: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
    """Lightweight linear resampler for short buffers."""
    if sr_in == sr_out or x.size == 0:
        return x
    # guard tiny sizes
    n_out = int(round(x.size * sr_out / sr_in))
    if n_out <= 1:
        return np.zeros(0, dtype=np.float32)
    t_in = np.linspace(0.0, 1.0, num=x.size, dtype=np.float32, endpoint=True)
    t_out = np.linspace(0.0, 1.0, num=n_out, dtype=np.float32, endpoint=True)
    return np.interp(t_out, t_in, x).astype(np.float32, copy=False)


class SwayRollRT:
    """Feed audio chunks → per-hop sway outputs.

    Usage:
        rt = SwayRollRT()
        rt.feed(pcm_int16_or_float, sr) -> List[dict]
    """

    def __init__(self, rng_seed: int = 7):
        self._seed = int(rng_seed)
        self.samples = deque(maxlen=10 * SR)  # sliding window for VAD/env
        self.carry = np.zeros(0, dtype=np.float32)
        self.frame_idx = 0

        self.vad_on = False
        self.vad_above = 0
        self.vad_below = 0

        self.sway_env = 0.0
        self.sway_up = 0
        self.sway_down = 0

        rng = np.random.default_rng(self._seed)
        self.phase_pitch = float(rng.random() * 2 * math.pi)
        self.phase_yaw = float(rng.random() * 2 * math.pi)
        self.phase_roll = float(rng.random() * 2 * math.pi)
        self.phase_x = float(rng.random() * 2 * math.pi)
        self.phase_y = float(rng.random() * 2 * math.pi)
        self.phase_z = float(rng.random() * 2 * math.pi)
        self.t = 0.0

    def reset(self) -> None:
        """Reset state (VAD/env/buffers/time) but keep initial phases/seed."""
        self.samples.clear()
        self.carry = np.zeros(0, dtype=np.float32)
        self.frame_idx = 0
        self.vad_on = False
        self.vad_above = 0
        self.vad_below = 0
        self.sway_env = 0.0
        self.sway_up = 0
        self.sway_down = 0
        self.t = 0.0

    def reset_phases(self) -> None:
        """Optional: re-randomize phases deterministically from stored seed."""
        rng = np.random.default_rng(self._seed)
        self.phase_pitch = float(rng.random() * 2 * math.pi)
        self.phase_yaw = float(rng.random() * 2 * math.pi)
        self.phase_roll = float(rng.random() * 2 * math.pi)
        self.phase_x = float(rng.random() * 2 * math.pi)
        self.phase_y = float(rng.random() * 2 * math.pi)
        self.phase_z = float(rng.random() * 2 * math.pi)

    def feed(self, pcm: np.ndarray, sr: Optional[int]) -> List[Dict[str, float]]:
        """
        Stream in PCM chunk. Returns a list of sway dicts, one per hop (HOP_MS).

        Args:
            pcm: np.ndarray, shape (N,) or (C,N)/(N,C); int or float.
            sr: sample rate of `pcm` (None -> assume SR).
        """
        sr_in = SR if sr is None else int(sr)
        x = _to_float32_mono(pcm)
        if x.size == 0:
            return []
        if sr_in != SR:
            x = _resample_linear(x, sr_in, SR)
            if x.size == 0:
                return []

        # append to carry and consume fixed HOP chunks
        if self.carry.size:
            self.carry = np.concatenate([self.carry, x])
        else:
            self.carry = x

        out: List[Dict[str, float]] = []

        while self.carry.size >= HOP:
            hop = self.carry[:HOP]
            self.carry = self.carry[HOP:]

            # keep sliding window for VAD/env computation
            # (deque accepts any iterable; list() for small HOP is fine)
            self.samples.extend(hop.tolist())
            if len(self.samples) < FRAME:
                self.t += HOP_MS / 1000.0
                self.frame_idx += 1
                continue

            frame = np.fromiter(
                islice(self.samples, len(self.samples) - FRAME, len(self.samples)),
                dtype=np.float32,
                count=FRAME,
            )
            db = _rms_dbfs(frame)

            # VAD with hysteresis + attack/release
            if db >= VAD_DB_ON:
                self.vad_above += 1
                self.vad_below = 0
                if not self.vad_on and self.vad_above >= ATTACK_FR:
                    self.vad_on = True
            elif db <= VAD_DB_OFF:
                self.vad_below += 1
                self.vad_above = 0
                if self.vad_on and self.vad_below >= RELEASE_FR:
                    self.vad_on = False

            if self.vad_on:
                self.sway_up = min(SWAY_ATTACK_FR, self.sway_up + 1)
                self.sway_down = 0
            else:
                self.sway_down = min(SWAY_RELEASE_FR, self.sway_down + 1)
                self.sway_up = 0

            up = self.sway_up / SWAY_ATTACK_FR
            down = 1.0 - (self.sway_down / SWAY_RELEASE_FR)
            target = up if self.vad_on else down
            self.sway_env += ENV_FOLLOW_GAIN * (target - self.sway_env)
            # clamp
            if self.sway_env < 0.0:
                self.sway_env = 0.0
            elif self.sway_env > 1.0:
                self.sway_env = 1.0

            loud = _loudness_gain(db) * SWAY_MASTER
            env = self.sway_env
            self.t += HOP_MS / 1000.0

            # oscillators
            pitch = (
                math.radians(SWAY_A_PITCH_DEG)
                * loud
                * env
                * math.sin(2 * math.pi * SWAY_F_PITCH * self.t + self.phase_pitch)
            )
            yaw = (
                math.radians(SWAY_A_YAW_DEG)
                * loud
                * env
                * math.sin(2 * math.pi * SWAY_F_YAW * self.t + self.phase_yaw)
            )
            roll = (
                math.radians(SWAY_A_ROLL_DEG)
                * loud
                * env
                * math.sin(2 * math.pi * SWAY_F_ROLL * self.t + self.phase_roll)
            )
            x_mm = (
                SWAY_A_X_MM
                * loud
                * env
                * math.sin(2 * math.pi * SWAY_F_X * self.t + self.phase_x)
            )
            y_mm = (
                SWAY_A_Y_MM
                * loud
                * env
                * math.sin(2 * math.pi * SWAY_F_Y * self.t + self.phase_y)
            )
            z_mm = (
                SWAY_A_Z_MM
                * loud
                * env
                * math.sin(2 * math.pi * SWAY_F_Z * self.t + self.phase_z)
            )

            out.append(
                {
                    "pitch_rad": pitch,
                    "yaw_rad": yaw,
                    "roll_rad": roll,
                    "pitch_deg": math.degrees(pitch),
                    "yaw_deg": math.degrees(yaw),
                    "roll_deg": math.degrees(roll),
                    "x_mm": x_mm,
                    "y_mm": y_mm,
                    "z_mm": z_mm,
                }
            )

        return out
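For reference, a minimal synthetic-audio exercise of the class above (the tone and chunk length are invented for illustration). A 100 ms chunk at 16 kHz spans ten 10 ms hops; the very first hop is skipped while the analysis window fills.

import numpy as np
from reachy_mini_conversation_demo.speech_tapper import SwayRollRT

rt = SwayRollRT()
# 100 ms of a 220 Hz tone as int16 PCM at 16 kHz.
t = np.arange(1600) / 16_000
chunk = (0.3 * np.sin(2 * np.pi * 220 * t) * 32767).astype(np.int16)
hops = rt.feed(chunk, sr=16_000)
print(len(hops), hops[-1]["yaw_deg"] if hops else "warming up")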
src/reachy_mini_conversation_demo/test_stop.py
ADDED
|
@@ -0,0 +1,33 @@
import asyncio
from reachy_mini import ReachyMini

async def test_loop():
    while True:
        print("doing")
        await asyncio.sleep(1)


async def main():
    current_robot = ReachyMini()

    tasks = [
        asyncio.create_task(test_loop(), name="test")
    ]

    try:
        await asyncio.gather(*tasks, return_exceptions=True)
    except asyncio.CancelledError:
        print("got stop")

    print("tasks")
    tasks = asyncio.all_tasks()
    for t in tasks:
        print(t)

    # IS REQUIRED TO EXIT THE THREAD
    current_robot.client.disconnect()
    print("done")
    # os._exit(0)

if __name__ == "__main__":
    asyncio.run(main())
src/reachy_mini_conversation_demo/tools.py
ADDED
|
@@ -0,0 +1,322 @@
from __future__ import annotations

import asyncio
import base64
import json
import logging
import time
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional

import cv2
import numpy as np

from reachy_mini_conversation_demo.vision import VisionManager

logger = logging.getLogger(__name__)

# Types & state

Direction = Literal["left", "right", "up", "down", "front"]


@dataclass
class Deps:
    """External dependencies the tools need"""

    reachy_mini: Any
    create_head_pose: Any
    camera: cv2.VideoCapture
    # Optional deps
    vision_manager: Optional[VisionManager] = None


# Helpers
def _encode_jpeg_b64(img: np.ndarray) -> str:
    ok, buf = cv2.imencode(".jpg", img)
    if not ok:
        raise RuntimeError("Failed to encode image as JPEG.")
    return base64.b64encode(buf.tobytes()).decode("utf-8")


def _read_frame(cap: cv2.VideoCapture, attempts: int = 5) -> np.ndarray:
    """Grab a frame with a small retry."""
    trials, frame, ret = 0, None, False
    while trials < attempts and not ret:
        ret, frame = cap.read()
        trials += 1
        if not ret and trials < attempts:
            time.sleep(0.1)  # Small delay between retries
    if not ret or frame is None:
        logger.error("Failed to capture image from camera after %d attempts", attempts)
        raise RuntimeError("Failed to capture image from camera.")
    return frame


# Tool coroutines
async def move_head(deps: Deps, *, direction: Direction) -> Dict[str, Any]:
    """Move your head in a given direction"""
    logger.info("Tool call: move_head direction=%s", direction)

    # Import and update the SAME global variables that main.py reads
    from reachy_mini_conversation_demo.main import movement_manager

    if direction == "left":
        target = deps.create_head_pose(0, 0, 0, 0, 0, 40, degrees=True)
    elif direction == "right":
        target = deps.create_head_pose(0, 0, 0, 0, 0, -40, degrees=True)
    elif direction == "up":
        target = deps.create_head_pose(0, 0, 0, 0, -30, 0, degrees=True)
    elif direction == "down":
        target = deps.create_head_pose(0, 0, 0, 0, 30, 0, degrees=True)
    else:  # front
        target = deps.create_head_pose(0, 0, 0, 0, 0, 0, degrees=True)

    movement_manager.moving_start = time.monotonic()
    movement_manager.moving_for = 1.0
    movement_manager.current_head_pose = target

    # Start the movement
    deps.reachy_mini.goto_target(target, duration=1.0)

    return {"status": f"looking {direction}"}


async def head_tracking(deps: Deps, *, start: bool) -> Dict[str, Any]:
    """Toggle head tracking state"""
    from reachy_mini_conversation_demo.main import movement_manager

    movement_manager.is_head_tracking_on = bool(start)
    status = "started" if start else "stopped"
    logger.info("Tool call: head_tracking %s", status)
    return {"status": f"head tracking {status}"}


async def camera(deps: Deps, *, question: str) -> Dict[str, Any]:
    """
    Capture an image and ask a question about it using local SmolVLM2.
    Returns: {"image_description": '...'} or {"error": '...'}.
    """
    q = (question or "").strip()
    if not q:
        logger.error("camera: empty question")
        return {"error": "question must be a non-empty string"}

    logger.info("Tool call: camera question=%s", q[:120])

    try:
        frame = await asyncio.to_thread(_read_frame, deps.camera)
    except Exception as e:
        logger.exception("camera: failed to capture image")
        return {"error": f"camera capture failed: {type(e).__name__}: {e}"}

    if not deps.vision_manager:
        logger.error("camera: vision manager not available")
        return {"error": "vision processing not available"}

    # Optional sound effect
    # try:
    #     # TODO Mute mic while hmmm
    #     deps.reachy_mini.play_sound(f"hmm{np.random.randint(1, 6)}.wav")
    # except Exception:
    #     logger.debug("camera: optional play_sound failed", exc_info=True)

    try:
        desc = await asyncio.to_thread(
            deps.vision_manager.processor.process_image, frame, q
        )
        logger.debug(
            "camera: SmolVLM2 result length=%d",
            len(desc) if isinstance(desc, str) else -1,
        )
        return {"image_description": desc}
    except Exception as e:
        logger.exception("camera: vision pipeline error")
        return {"error": f"vision failed: {type(e).__name__}: {e}"}


async def describe_current_scene(deps: Deps) -> Dict[str, Any]:
    """Get current scene description from camera with detailed analysis"""
    logger.info("Tool call: describe_current_scene")

    if not deps.vision_manager:
        return {"error": "Vision processing not available"}

    # Ensure processor is initialized
    if not deps.vision_manager.processor._initialized:
        if not deps.vision_manager.processor.initialize():
            return {"error": "Failed to initialize vision processor"}

    try:
        result = await deps.vision_manager.process_current_frame(
            "Describe what you currently see in detail, focusing on people, objects, and activities."
        )
        return result
    except Exception as e:
        logger.exception("Failed to describe current scene")
        return {"error": f"Scene description failed: {type(e).__name__}: {e}"}


async def get_scene_context(deps: Deps) -> Dict[str, Any]:
    """Get the most recent automatic scene description for context"""
    logger.info("Tool call: get_scene_context")

    if not deps.vision_manager:
        return {"error": "Vision processing not available"}

    try:
        description = await deps.vision_manager.get_current_description()
        if not description:
            return {
                "context": "No scene description available yet",
                "note": "Vision processing may still be initializing",
            }

        return {
            "context": description,
            "note": "This is from periodic automatic scene analysis",
        }
    except Exception as e:
        logger.exception("Failed to get scene context")
        return {"error": f"Scene context failed: {type(e).__name__}: {e}"}


async def analyze_scene_for(deps: Deps, *, purpose: str = "general") -> Dict[str, Any]:
    """Analyze current scene for specific purpose"""
    logger.info("Tool call: analyze_scene_for purpose=%s", purpose)

    if not deps.vision_manager:
        return {"error": "Vision processing not available"}

    try:
        # Custom prompts based on purpose
        prompts = {
            "safety": "Look for any safety concerns, obstacles, or hazards in the scene.",
            "people": "Describe any people you see, their positions and what they're doing.",
            "objects": "Identify and describe the main objects and items visible in the scene.",
            "activity": "Describe what activities or actions are happening in the scene.",
            "navigation": "Describe the space for navigation - obstacles, pathways, and layout.",
            "general": "Provide a general description of the scene including people, objects, and activities.",
        }

        prompt = prompts.get(purpose.lower(), prompts["general"])

        result = await deps.vision_manager.process_current_frame(prompt)
        result["analysis_purpose"] = purpose

        return result
    except Exception as e:
        logger.exception("Failed to analyze scene for %s", purpose)
        return {"error": f"Scene analysis failed: {type(e).__name__}: {e}"}


# Registration helpers
TOOL_SPECS = [
    {
        "type": "function",
        "name": "move_head",
        "description": "Move your head in a given direction: left, right, up, down or front.",
        "parameters": {
            "type": "object",
            "properties": {
                "direction": {
                    "type": "string",
                    "enum": ["left", "right", "up", "down", "front"],
                }
            },
            "required": ["direction"],
        },
    },
    {
        "type": "function",
        "name": "camera",
        "description": "Take a picture using your camera, ask a question about the picture. Get an answer about the picture",
        "parameters": {
            "type": "object",
            "properties": {
                "question": {
                    "type": "string",
                    "description": "The question to ask about the picture",
                }
            },
            "required": ["question"],
        },
    },
    # {
    #     "type": "function",
    #     "name": "head_tracking",
    #     "description": "Start or stop head tracking",
    #     "parameters": {
    #         "type": "object",
    #         "properties": {
    #             "start": {
    #                 "type": "boolean",
    #                 "description": "Whether to start or stop head tracking",
    #             }
    #         },
    #         "required": ["start"],
    #     },
    # },
    # {
    #     "type": "function",
    #     "name": "describe_current_scene",
    #     "description": "Get a detailed description of what you currently see through your camera",
    #     "parameters": {
    #         "type": "object",
    #         "properties": {},
    #         "required": []
    #     }
    # },
    {
        "type": "function",
        "name": "get_scene_context",
        "description": "Get the most recent automatic scene description for conversational context",
        "parameters": {"type": "object", "properties": {}, "required": []},
    },
    # {
    #     "type": "function",
    #     "name": "analyze_scene_for",
    #     "description": "Analyze the current scene for a specific purpose (safety, people, objects, activity, navigation, or general)",
    #     "parameters": {
    #         "type": "object",
    #         "properties": {
    #             "purpose": {
    #                 "type": "string",
    #                 "enum": ["safety", "people", "objects", "activity", "navigation", "general"],
    #                 "description": "The specific purpose for scene analysis"
    #             }
    #         },
    #         "required": ["purpose"]
    #     }
    # }
]


def get_tool_registry(deps: Deps):
    """Map tool name -> coroutine that accepts **kwargs (tool args)."""
    return {
        "move_head": lambda **kw: move_head(deps, **kw),
        "camera": lambda **kw: camera(deps, **kw),
        "head_tracking": lambda **kw: head_tracking(deps, **kw),
        "describe_current_scene": lambda **kw: describe_current_scene(deps),
        "get_scene_context": lambda **kw: get_scene_context(deps),
        "analyze_scene_for": lambda **kw: analyze_scene_for(deps, **kw),
    }


async def dispatch_tool_call(name: str, args_json: str, deps: Deps) -> Dict[str, Any]:
    """Utility to execute a tool from streamed function_call arguments."""
    try:
        args = json.loads(args_json or "{}")
    except Exception:
        args = {}
    registry = get_tool_registry(deps)
    func = registry.get(name)
    if not func:
        return {"error": f"unknown tool: {name}"}
    try:
        return await func(**args)
    except Exception as e:
        error_msg = f"{type(e).__name__}: {e}"
        logger.exception("Tool error in %s: %s", name, error_msg)
        return {"error": error_msg}
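A hedged sketch of how a streamed function_call could be routed through dispatch_tool_call. The robot and camera objects are placeholders (the real ones, plus movement_manager, are built in main.py), so this is an illustration of the call shape rather than the shipped dispatch path.

import asyncio
import cv2
from reachy_mini.utils import create_head_pose
from reachy_mini_conversation_demo.tools import Deps, dispatch_tool_call

async def run_one_call(reachy_mini, camera: cv2.VideoCapture) -> None:
    # Minimal Deps without vision; camera tools would report an error.
    deps = Deps(reachy_mini=reachy_mini, create_head_pose=create_head_pose, camera=camera)
    result = await dispatch_tool_call("move_head", '{"direction": "left"}', deps)
    print(result)  # {"status": "looking left"} on success, {"error": ...} otherwise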
src/reachy_mini_conversation_demo/vision.py
ADDED
|
@@ -0,0 +1,302 @@
import base64
import logging
import os
import time
import asyncio
from typing import Dict, Any
import threading
from dataclasses import dataclass

import cv2
import numpy as np
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor


logger = logging.getLogger(__name__)


@dataclass
class VisionConfig:
    """Configuration for vision processing"""

    model_path: str = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
    vision_interval: float = 5.0
    max_new_tokens: int = 64
    temperature: float = 0.7
    jpeg_quality: int = 85
    max_retries: int = 3
    retry_delay: float = 1.0
    device_preference: str = "auto"  # "auto", "cuda", "cpu"


class VisionProcessor:
    """Handles SmolVLM2 model loading and inference"""

    def __init__(self, config: VisionConfig = None):
        self.config = config or VisionConfig()
        self.model_path = self.config.model_path
        self.device = self._determine_device()
        self.processor = None
        self.model = None
        self._initialized = False

    def _determine_device(self) -> str:
        pref = self.config.device_preference
        if pref == "cpu":
            return "cpu"
        if pref == "cuda":
            return "cuda" if torch.cuda.is_available() else "cpu"
        if pref == "mps":
            return "mps" if torch.backends.mps.is_available() else "cpu"
        # auto: prefer mps on Apple, then cuda, else cpu
        if torch.backends.mps.is_available():
            return "mps"
        return "cuda" if torch.cuda.is_available() else "cpu"

    def initialize(self) -> bool:
        try:
            logger.info(
                f"Loading SmolVLM2 model on {self.device} (HF_HOME={os.getenv('HF_HOME')})"
            )
            self.processor = AutoProcessor.from_pretrained(self.model_path)

            # Select dtype depending on device
            if self.device == "cuda":
                dtype = torch.bfloat16
            elif self.device == "mps":
                dtype = torch.float16  # best for MPS
            else:
                dtype = torch.float32

            model_kwargs = {"torch_dtype": dtype}

            # flash_attention_2 is CUDA-only; skip on MPS/CPU
            if self.device == "cuda":
                model_kwargs["_attn_implementation"] = "flash_attention_2"

            # Load model weights
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.model_path, **model_kwargs
            ).to(self.device)

            self.model.eval()
            self._initialized = True
            return True

        except Exception as e:
            logger.error(f"Failed to initialize vision model: {e}")
            return False

    def process_image(
        self,
        cv2_image: np.ndarray,
        prompt: str = "Briefly describe what you see in one sentence.",
    ) -> str:
        """Process CV2 image and return description with retry logic"""
        if not self._initialized:
            return "Vision model not initialized"

        for attempt in range(self.config.max_retries):
            try:
                # Convert CV2 BGR to RGB
                rgb_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)

                # Convert to JPEG bytes
                success, jpeg_buffer = cv2.imencode(
                    ".jpg",
                    rgb_image,
                    [cv2.IMWRITE_JPEG_QUALITY, self.config.jpeg_quality],
                )
                if not success:
                    return "Failed to encode image"

                # Convert to base64
                image_base64 = base64.b64encode(jpeg_buffer.tobytes()).decode("utf-8")

                messages = [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image",
                                "url": f"data:image/jpeg;base64,{image_base64}",
                            },
                            {"type": "text", "text": prompt},
                        ],
                    },
                ]

                inputs = self.processor.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=True,
                    return_tensors="pt",
                )

                # move to device with proper dtype
                if self.device == "cuda":
                    inputs = inputs.to(self.device, dtype=torch.bfloat16)
                elif self.device == "mps":
                    inputs = inputs.to(self.device, dtype=torch.float16)
                else:
                    inputs = inputs.to(self.device, dtype=torch.float32)

                with torch.no_grad():
                    generated_ids = self.model.generate(
                        **inputs,
                        do_sample=True if self.config.temperature > 0 else False,
                        max_new_tokens=self.config.max_new_tokens,
                        temperature=self.config.temperature,
                        pad_token_id=self.processor.tokenizer.eos_token_id,
                    )

                generated_texts = self.processor.batch_decode(
                    generated_ids,
                    skip_special_tokens=True,
                )

                # Extract just the response part
                full_text = generated_texts[0]
                response = self._extract_response(full_text)

                # Clean up GPU memory if using CUDA/MPS
                if self.device == "cuda":
                    torch.cuda.empty_cache()
                elif self.device == "mps":
                    torch.mps.empty_cache()

                return response.replace(chr(10), " ").strip()

            except torch.cuda.OutOfMemoryError as e:
                logger.error(f"CUDA OOM on attempt {attempt + 1}: {e}")
                if self.device == "cuda":
                    torch.cuda.empty_cache()
                if attempt < self.config.max_retries - 1:
                    time.sleep(self.config.retry_delay * (attempt + 1))
                else:
                    return "GPU out of memory - vision processing failed"

            except Exception as e:
                logger.error(f"Vision processing failed (attempt {attempt + 1}): {e}")
                if attempt < self.config.max_retries - 1:
                    time.sleep(self.config.retry_delay)
                else:
                    return f"Vision processing error after {self.config.max_retries} attempts"

    def _extract_response(self, full_text: str) -> str:
        """Extract the assistant's response from the full generated text"""
        # Handle different response formats
        markers = ["assistant\n", "Assistant:", "Response:", "\n\n"]

        for marker in markers:
            if marker in full_text:
                response = full_text.split(marker)[-1].strip()
                if response:  # Ensure we got a meaningful response
                    return response

        # Fallback: return the full text cleaned up
        return full_text.strip()

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model"""
        return {
            "initialized": self._initialized,
            "device": self.device,
            "model_path": self.model_path,
            "cuda_available": torch.cuda.is_available(),
            "gpu_memory": torch.cuda.get_device_properties(0).total_memory // (1024**3)
            if torch.cuda.is_available()
            else "N/A",
        }


class VisionManager:
    """Manages periodic vision processing and scene understanding"""

    def __init__(self, camera, config: VisionConfig = None):
        self.camera = camera
        self.config = config or VisionConfig()
        self.vision_interval = self.config.vision_interval
        self.processor = VisionProcessor(self.config)

        self._current_description = ""
        self._last_processed_time = 0
        self._running = False  # set while the enable() loop is active

        # Initialize processor
        if not self.processor.initialize():
            logger.error("Failed to initialize vision processor")
            raise RuntimeError("Vision processor initialization failed")

    async def enable(self, stop_event: threading.Event):
        """Main vision processing loop (runs as an asyncio task)"""
        self._running = True
        while not stop_event.is_set():
            try:
                current_time = time.time()

                if current_time - self._last_processed_time >= self.vision_interval:
                    success, frame = await asyncio.to_thread(self.camera.read)
                    if success and frame is not None:

                        description = await asyncio.to_thread(
                            lambda: self.processor.process_image(
                                frame, "Briefly describe what you see in one sentence."
                            )
                        )

                        # Only update if we got a valid response
                        if description and not description.startswith(
                            ("Vision", "Failed", "Error")
                        ):
                            self._current_description = description
                            self._last_processed_time = current_time

                            logger.info(f"Vision update: {description}")
                        else:
                            logger.warning(f"Invalid vision response: {description}")

                await asyncio.sleep(1.0)  # Check every second

            except Exception:
                logger.exception("Vision processing loop error")
                await asyncio.sleep(5.0)  # Longer sleep on error

        self._running = False
        logger.info("Vision loop finished")

    async def get_current_description(self) -> str:
        """Get the most recent scene description"""
        return self._current_description

    async def process_current_frame(
        self, prompt: str = "Describe what you see in detail."
    ) -> Dict[str, Any]:
        """Process current camera frame with custom prompt"""
        try:
            success, frame = self.camera.read()
            if not success or frame is None:
                return {"error": "Failed to capture image from camera"}

            description = await asyncio.to_thread(
                lambda: self.processor.process_image(frame, prompt)
            )

            return {
                "description": description,
                "timestamp": time.time(),
                "prompt": prompt,
            }

        except Exception as e:
            logger.exception("Failed to process current frame")
            return {"error": f"Frame processing failed: {str(e)}"}

    async def get_status(self) -> Dict[str, Any]:
        """Get comprehensive status information"""
        return {
            "running": self._running,
            "last_processed": self._last_processed_time,
            "processor_info": self.processor.get_model_info(),
            "config": {
                "interval": self.vision_interval,
                "model_path": self.config.model_path,
                "device": self.processor.device,
            },
        }
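A hedged sketch of using VisionProcessor on its own, outside the demo loop: model download and device selection follow the VisionConfig defaults above, and the webcam index is an assumption.

import cv2
from reachy_mini_conversation_demo.vision import VisionConfig, VisionProcessor

proc = VisionProcessor(VisionConfig(device_preference="cpu", max_new_tokens=32))
if proc.initialize():
    cap = cv2.VideoCapture(0)  # assumed default webcam
    ok, frame = cap.read()
    if ok:
        print(proc.process_image(frame, "What is on the table?"))
    cap.release()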