diff --git a/dlclivegui/cameras/backends/aravis_backend.py b/dlclivegui/cameras/backends/aravis_backend.py index 60059c4..d1dd24c 100644 --- a/dlclivegui/cameras/backends/aravis_backend.py +++ b/dlclivegui/cameras/backends/aravis_backend.py @@ -11,7 +11,7 @@ import numpy as np from ...config import CameraSettings -from ..base import CameraBackend, SupportLevel, register_backend +from ..base import CameraBackend, CapturedFrame, SupportLevel, register_backend from ..factory import DetectedCamera LOG = logging.getLogger(__name__) @@ -372,7 +372,7 @@ def open(self) -> None: self._camera.start_acquisition() - def read(self) -> tuple[np.ndarray, float]: + def read(self) -> CapturedFrame: """Read a frame from the camera.""" if self._camera is None or self._stream is None: raise RuntimeError("Aravis camera not initialized") @@ -430,7 +430,7 @@ def read(self) -> tuple[np.ndarray, float]: # Always push buffer back to stream self._stream.push_buffer(buffer) - return frame, timestamp + return CapturedFrame(frame=frame, software_timestamp=timestamp, timestamp_metadata=None) def stop(self) -> None: """Stop camera acquisition.""" diff --git a/dlclivegui/cameras/backends/basler_backend.py b/dlclivegui/cameras/backends/basler_backend.py index b1aff59..2739820 100644 --- a/dlclivegui/cameras/backends/basler_backend.py +++ b/dlclivegui/cameras/backends/basler_backend.py @@ -7,11 +7,10 @@ import time from typing import ClassVar -import numpy as np - from ...config import BASLER_DO_LOG_TIMING, CameraTriggerSettings from ...utils.stats import WorkerTimingStats -from ..base import CameraBackend, SupportLevel, register_backend +from ...utils.timestamps import FrameTimestampMetadata +from ..base import CameraBackend, CapturedFrame, SupportLevel, register_backend LOG = logging.getLogger(__name__) @@ -57,6 +56,8 @@ def __init__(self, settings): # (may skip StartGrabbing and converter setup for faster capability probing; not suitable for normal capture) self._fast_start: bool = bool(self.ns.get("fast_start", False)) self._retrieve_timeout_ms: int = 100 # default; may be overridden by trigger settings + self._timestamp_tick_frequency_hz: float | None = None + self._timestamp_tick_frequency_source: str | None = None # ---- Trigger settings ---- raw_trigger = self.ns.get("trigger", self._props.get("trigger")) @@ -179,6 +180,7 @@ def static_capabilities(cls) -> dict[str, SupportLevel]: "stable_identity": SupportLevel.SUPPORTED, "hardware_trigger": SupportLevel.BEST_EFFORT, "preserve_mono": SupportLevel.SUPPORTED, + "hardware_frame_timestamps": SupportLevel.BEST_EFFORT, } ) return caps @@ -472,6 +474,7 @@ def _configure_frame_rate(self) -> None: "BslResultingAcquisitionFrameRate", "ExposureAuto", "ExposureTime", + "ExposureTimeAbs", "Width", "Height", "PixelFormat", @@ -541,7 +544,10 @@ def open(self) -> None: try: if hasattr(self._camera, "ExposureAuto"): self._camera.ExposureAuto.SetValue("Off") - self._camera.ExposureTime.SetValue(float(self.settings.exposure)) + if hasattr(self._camera, "ExposureTime"): + self._camera.ExposureTime.SetValue(float(self.settings.exposure)) + if hasattr(self._camera, "ExposureTimeAbs"): + self._camera.ExposureTimeAbs.SetValue(float(self.settings.exposure)) LOG.info("[Basler] Exposure set to %s us (auto off)", self.settings.exposure) except Exception as exc: LOG.warning("[Basler] Failed to set exposure: %s", exc) @@ -652,9 +658,28 @@ def open(self) -> None: getattr(self.settings, "gain", None), ) - # ---------------------------- + # Get hardware tick frequency for timestamp conversion + try: + node = getattr(self._camera, "GevTimestampTickFrequency", None) + if node is not None and node.IsReadable(): + self._timestamp_tick_frequency_hz = float(node.GetValue()) + self._timestamp_tick_frequency_source = "GevTimestampTickFrequency" + LOG.info( + "[Basler] timestamp tick frequency: %.3f Hz from GevTimestampTickFrequency", + self._timestamp_tick_frequency_hz, + ) + except Exception: + LOG.debug("[Basler] Could not read GevTimestampTickFrequency", exc_info=True) + + if not self._timestamp_tick_frequency_hz or self._timestamp_tick_frequency_hz <= 0: + self._timestamp_tick_frequency_hz = 1_000_000_000.0 + self._timestamp_tick_frequency_source = "assumed_default_1ghz" + LOG.info( + "[Basler] timestamp tick frequency unavailable; assuming %.3f Hz", + self._timestamp_tick_frequency_hz, + ) + # Persist stable identity into namespace - # ---------------------------- try: serial = device.GetSerialNumber() if serial: @@ -667,7 +692,36 @@ def open(self) -> None: except Exception: pass - def read(self) -> tuple[np.ndarray, float]: + def _make_timestamp_metadata(self, grab_result) -> FrameTimestampMetadata | None: + try: + ticks = int(grab_result.GetTimeStamp()) + except Exception: + return None + + if ticks == 0: + # Basler returns 0 if the timestamp is not available (e.g. for some GigE cameras) + return None + + freq = getattr(self, "_timestamp_tick_frequency_hz", None) + seconds = ticks / freq if freq and freq > 0 else None + + return FrameTimestampMetadata( + source="grab_result.GetTimeStamp", + backend="basler", + default_reported="seconds" if seconds is not None else "raw_value", + seconds=seconds, + wall_clock_time=None, + raw_value=ticks, + raw_unit="ticks", + tick_frequency_hz=freq, + timebase="Basler camera timestamp counter", + kind="camera_clock", + extra={ + "tick_frequency_source": self._timestamp_tick_frequency_source, + }, + ) + + def read(self) -> CapturedFrame: if self._camera is None: raise RuntimeError("Basler camera not opened") if self._converter is None: @@ -696,6 +750,10 @@ def read(self) -> tuple[np.ndarray, float]: with self._timing.measure("Basler.get_array"): frame = image.GetArray() + with self._timing.measure("Basler.timestamp"): + software_timestamp = time.time() + timestamp_metadata = self._make_timestamp_metadata(grab_result) + if not self._logged_first_frame: self._logged_first_frame = True LOG.info( @@ -722,7 +780,11 @@ def read(self) -> tuple[np.ndarray, float]: self._timing.note_frame() self._timing.maybe_log() - return frame, time.time() + return CapturedFrame( + frame=frame, + software_timestamp=software_timestamp, + timestamp_metadata=timestamp_metadata, + ) except Exception as exc: if grab_result is not None: diff --git a/dlclivegui/cameras/backends/gentl_backend.py b/dlclivegui/cameras/backends/gentl_backend.py index a433fb1..e462a1d 100644 --- a/dlclivegui/cameras/backends/gentl_backend.py +++ b/dlclivegui/cameras/backends/gentl_backend.py @@ -13,7 +13,7 @@ import numpy as np from ...config import CameraTriggerSettings -from ..base import CameraBackend, SupportLevel, register_backend +from ..base import CameraBackend, CapturedFrame, SupportLevel, register_backend from ..factory import DetectedCamera from .utils import gentl_discovery as cti_finder @@ -615,7 +615,7 @@ def _output_format_for_frame(frame: np.ndarray) -> str: return f"{channels}ch-{frame.dtype}" return str(frame.dtype) - def read(self) -> tuple[np.ndarray, float]: + def read(self) -> CapturedFrame: if self._acquirer is None: raise RuntimeError("GenTL image acquirer not initialised") @@ -655,7 +655,11 @@ def read(self) -> tuple[np.ndarray, float]: pass self._actual_output_format = self._output_format_for_frame(frame) - return frame, timestamp + return CapturedFrame( + frame=frame, + software_timestamp=timestamp, + timestamp_metadata=None, + ) def stop(self) -> None: if self._acquirer is not None: diff --git a/dlclivegui/cameras/backends/opencv_backend.py b/dlclivegui/cameras/backends/opencv_backend.py index 869dde4..1201749 100644 --- a/dlclivegui/cameras/backends/opencv_backend.py +++ b/dlclivegui/cameras/backends/opencv_backend.py @@ -10,10 +10,9 @@ from typing import TYPE_CHECKING, Literal import cv2 -import numpy as np from pydantic import BaseModel, Field, model_validator -from ..base import CameraBackend, SupportLevel, register_backend +from ..base import CameraBackend, CapturedFrame, SupportLevel, register_backend from ..factory import DetectedCamera from .utils.opencv_discovery import ( ModeRequest, @@ -199,21 +198,45 @@ def open(self) -> None: self._configure_capture() - def read(self) -> tuple[np.ndarray | None, float]: - """Robust frame read: return (None, ts) on transient failures; never raises.""" + def read(self) -> CapturedFrame: + """Robust frame read: return CapturedFrame(frame=None, ...) on transient failures; never raises.""" if self._capture is None: logger.warning("OpenCVCameraBackend.read() called before open()") - return None, time.time() + return CapturedFrame( + frame=None, + software_timestamp=time.time(), + timestamp_metadata=None, + ) + try: if not self._capture.grab(): - return None, time.time() + return CapturedFrame( + frame=None, + software_timestamp=time.time(), + timestamp_metadata=None, + ) + success, frame = self._capture.retrieve() if not success or frame is None or frame.size == 0: - return None, time.time() - return frame, time.time() + return CapturedFrame( + frame=None, + software_timestamp=time.time(), + timestamp_metadata=None, + ) + + return CapturedFrame( + frame=frame, + software_timestamp=time.time(), + timestamp_metadata=None, + ) + except Exception as exc: - logger.debug(f"OpenCV read transient error: {exc}") - return None, time.time() + logger.debug("OpenCV read transient error: %s", exc) + return CapturedFrame( + frame=None, + software_timestamp=time.time(), + timestamp_metadata=None, + ) def close(self) -> None: self._release_capture() diff --git a/dlclivegui/cameras/base.py b/dlclivegui/cameras/base.py index f86f3d1..9217ad8 100644 --- a/dlclivegui/cameras/base.py +++ b/dlclivegui/cameras/base.py @@ -3,6 +3,7 @@ import logging from abc import ABC, abstractmethod +from dataclasses import dataclass from enum import Enum from typing import TYPE_CHECKING, Any, ClassVar @@ -11,6 +12,7 @@ from ..config import CameraSettings if TYPE_CHECKING: + from ..utils.timestamps import FrameTimestampMetadata from .factory import DetectedCamera _BACKEND_REGISTRY: dict[str, type[CameraBackend]] = {} @@ -72,9 +74,24 @@ class SupportLevel(str, Enum): "device_discovery": SupportLevel.UNSUPPORTED, "stable_identity": SupportLevel.UNSUPPORTED, "hardware_trigger": SupportLevel.UNSUPPORTED, + "hardware_frame_timestamps": SupportLevel.UNSUPPORTED, } +@dataclass(frozen=True) +class CapturedFrame: + """Frame plus software timestamp and optional backend timestamp metadata.""" + + frame: np.ndarray | None + software_timestamp: float + timestamp_metadata: FrameTimestampMetadata | None = None + + def __iter__(self): + """Backwards-compatible unpacking: frame, software_timestamp = backend.read()""" + yield self.frame + yield self.software_timestamp + + class CameraBackend(ABC): """Abstract base class for camera backends.""" @@ -107,6 +124,11 @@ def actual_pixel_format(self) -> str | None: def recommended_preserve_mono(self) -> bool | None: return None + @property + def last_frame_timestamp_metadata(self) -> FrameTimestampMetadata | None: + """Return backend-provided timestamp metadata for the last read frame.""" + return None + @classmethod def options_key(cls) -> str: """Return the key used to store this backend's options in CameraSettings.""" @@ -171,7 +193,7 @@ def open(self) -> None: raise NotImplementedError @abstractmethod - def read(self) -> tuple[np.ndarray, float]: + def read(self) -> CapturedFrame: """Read a frame and return the image with a timestamp.""" raise NotImplementedError diff --git a/dlclivegui/gui/main_window.py b/dlclivegui/gui/main_window.py index 894fc4d..2677b8f 100644 --- a/dlclivegui/gui/main_window.py +++ b/dlclivegui/gui/main_window.py @@ -445,7 +445,7 @@ def _build_dlc_group(self) -> QGroupBox: # Processor selection processor_path_layout = QHBoxLayout() self.processor_folder_edit = QLineEdit() - self.processor_folder_edit.setText(default_processors_dir()) + self.processor_folder_edit.setText(self._settings_store.get_processor_folder(default=default_processors_dir())) processor_path_layout.addWidget(self.processor_folder_edit) self.browse_processor_folder_button = QPushButton("Browse...") @@ -1081,10 +1081,11 @@ def _action_browse_directory(self) -> None: def _action_browse_processor_folder(self) -> None: """Browse for processor folder.""" - current_path = self.processor_folder_edit.text() or default_processors_dir() + current_path = self.processor_folder_edit.text().strip() or default_processors_dir() directory = QFileDialog.getExistingDirectory(self, "Select processor folder", current_path) if directory: self.processor_folder_edit.setText(directory) + self._settings_store.set_processor_folder(directory) self._refresh_processors() def _action_open_recording_folder(self) -> None: @@ -1138,10 +1139,17 @@ def _refresh_processors(self) -> None: self.processor_combo.addItem("No Processor", None) selected_folder = self.processor_folder_edit.text().strip() - if Path(selected_folder).exists(): - self._scanned_processors = scan_processor_folder(selected_folder) + selected_path = Path(selected_folder).expanduser() if selected_folder else None + + if selected_path is not None and selected_path.is_dir(): + resolved_folder = str(selected_path.resolve()) + self._settings_store.set_processor_folder(resolved_folder) + self._scanned_processors = scan_processor_folder(resolved_folder) + source_text = resolved_folder else: self._scanned_processors = scan_processor_package("dlclivegui.processors") + source_text = "package dlclivegui.processors" + self._processor_keys = list(self._scanned_processors.keys()) for key in self._processor_keys: @@ -1150,9 +1158,7 @@ def _refresh_processors(self) -> None: self.processor_combo.addItem(display_name, key) self.processor_combo.update_shrink_width() - self.statusBar().showMessage( - f"Found {len(self._processor_keys)} processor(s) in package dlclivegui.processors", 3000 - ) + self.statusBar().showMessage(f"Found {len(self._processor_keys)} processor(s) in {source_text}", 3000) # ------------------------------------------------------------------ # Recording path preview and session name persistence @@ -1399,7 +1405,9 @@ def _render_overlays_for_recording(self, cam_id, frame): ) return output - def _on_recording_frame_ready(self, camera_id: str, frame: np.ndarray, timestamp: float) -> None: + def _on_recording_frame_ready( + self, camera_id: str, frame: np.ndarray, timestamp: float, timestamp_metadata: object | None = None + ) -> None: """Handle full-rate per-camera frames for recording only. Intentionally lean: @@ -1415,7 +1423,7 @@ def _on_recording_frame_ready(self, camera_id: str, frame: np.ndarray, timestamp if self.record_with_overlays_checkbox.isChecked(): frame = self._render_overlays_for_recording(camera_id, frame) - self._rec_manager.write_frame(camera_id, frame, timestamp) + self._rec_manager.write_frame(camera_id, frame, timestamp, timestamp_metadata=timestamp_metadata) def _on_multi_frame_processing_ready(self, frame_data: MultiFrameData) -> None: """Handle frames from multiple cameras. @@ -1728,24 +1736,28 @@ def _update_inference_buttons(self) -> None: def _update_dlc_controls_enabled(self) -> None: """Enable/disable DLC settings based on inference state.""" allow_changes = not self._dlc_active - processor_controls = allow_changes and self._processor_control_enabled() widgets = [ self.model_path_edit, self.browse_model_button, self.dlc_camera_combo, - # self.additional_options_edit, ] + processor_widgets = [ self.processor_folder_edit, self.browse_processor_folder_button, self.refresh_processors_button, self.processor_combo, ] + for widget in widgets: widget.setEnabled(allow_changes) + for widget in processor_widgets: - widget.setEnabled(processor_controls) + widget.setEnabled(allow_changes) + + if hasattr(self, "allow_processor_ctrl_checkbox"): + self.allow_processor_ctrl_checkbox.setEnabled(allow_changes) def _update_camera_controls_enabled(self) -> None: multi_cam_recording = self._rec_manager.is_active @@ -2151,6 +2163,9 @@ def closeEvent(self, event: QCloseEvent) -> None: # pragma: no cover - GUI beha # Remember model path on exit self._model_path_store.save_if_valid(self.model_path_edit.text().strip()) + # Remember processor folder on exit + if hasattr(self, "processor_folder_edit"): + self._settings_store.set_processor_folder(self.processor_folder_edit.text().strip()) # Close the window super().closeEvent(event) diff --git a/dlclivegui/gui/recording_manager.py b/dlclivegui/gui/recording_manager.py index f3509ac..ddcef47 100644 --- a/dlclivegui/gui/recording_manager.py +++ b/dlclivegui/gui/recording_manager.py @@ -202,14 +202,27 @@ def stop_all(self) -> None: self._session_dir = None self._run_dir = None - def write_frame(self, cam_id: str, frame: np.ndarray, timestamp: float | None = None) -> None: + def write_frame( + self, cam_id: str, frame: np.ndarray, timestamp: float | None = None, timestamp_metadata: object | None = None + ) -> None: rec = self._recorders.get(cam_id) if not rec or not rec.is_running: return try: - rec.write(frame, timestamp=timestamp if timestamp is not None else time.time()) + rec.write( + frame, + timestamp=timestamp if timestamp is not None else time.time(), + timestamp_metadata=timestamp_metadata, + ) except Exception as exc: - log.warning("Failed to write frame for %s: %s", cam_id, exc) + log.warning( + "Failed to write frame for %s: %s: %s frame_shape=%s dtype=%s", + cam_id, + type(exc).__name__, + str(exc) or repr(exc), + getattr(frame, "shape", None), + getattr(frame, "dtype", None), + ) try: rec.stop() except Exception: diff --git a/dlclivegui/processors/PLUGIN_SYSTEM.md b/dlclivegui/processors/PLUGIN_SYSTEM.md index 9e975e0..e6a1436 100644 --- a/dlclivegui/processors/PLUGIN_SYSTEM.md +++ b/dlclivegui/processors/PLUGIN_SYSTEM.md @@ -16,7 +16,8 @@ Processors are Python classes (typically subclasses of `dlclive.Processor`) that ### Useful files -- `dlclivegui/processors/dlc_processor_socket.py` — Example socket-based processor base class + examples +- `dlclivegui/processors/dlc_processor_socket.py` — Example socket-based processor base class +- `dlclivegui/processors/examples.py` — Example processor implementations (e.g., One-Euro filter) - `dlclivegui/processors/processor_utils.py` — Scanning + instantiation helpers used by the GUI --- @@ -204,12 +205,7 @@ The built-in `BaseProcessorSocket` (in `dlc_processor_socket.py`) demonstrates a ```python from dlclive import Processor - -PROCESSOR_REGISTRY = {} - -def register_processor(cls): - PROCESSOR_REGISTRY[getattr(cls, "PROCESSOR_ID", cls.__name__)] = cls - return cls +from dlclivegui.processors import register_processor, PROCESSOR_REGISTRY @register_processor class MyNewProcessor(Processor): diff --git a/dlclivegui/processors/__init__.py b/dlclivegui/processors/__init__.py new file mode 100644 index 0000000..8e77171 --- /dev/null +++ b/dlclivegui/processors/__init__.py @@ -0,0 +1,3 @@ +from .registry import PROCESSOR_REGISTRY, register_processor + +__all__ = ["register_processor", "PROCESSOR_REGISTRY"] diff --git a/dlclivegui/processors/dlc_processor_socket.py b/dlclivegui/processors/dlc_processor_socket.py index 8ded010..594512c 100644 --- a/dlclivegui/processors/dlc_processor_socket.py +++ b/dlclivegui/processors/dlc_processor_socket.py @@ -7,14 +7,17 @@ import sys import time from collections import deque -from math import acos, atan2, copysign, degrees, pi, sqrt from multiprocessing.connection import Client, Listener from pathlib import Path from threading import Event, Thread import numpy as np import pandas as pd -from dlclive import Processor # type: ignore + +try: + from dlclive.processor import Processor # type: ignore +except ImportError: + Processor = object # Fallback for type checking if dlclive is not installed logger = logging.getLogger("dlc_processor_socket") @@ -24,59 +27,6 @@ _handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")) logger.addHandler(_handler) -# Registry for GUI discovery -PROCESSOR_REGISTRY = {} - - -def register_processor(cls): - registry_key = getattr(cls, "PROCESSOR_ID", cls.__name__) - if registry_key in PROCESSOR_REGISTRY: - raise ValueError( - f"Duplicate processor registration key '{registry_key}': " - f"{PROCESSOR_REGISTRY[registry_key].__name__} vs {cls.__name__}" - ) - PROCESSOR_REGISTRY[registry_key] = cls - return cls - - -class OneEuroFilter: # pragma: no cover - def __init__(self, t0, x0, dx0=None, min_cutoff=1.0, beta=0.0, d_cutoff=1.0): - self.min_cutoff = min_cutoff - self.beta = beta - self.d_cutoff = d_cutoff - self.x_prev = x0 - if dx0 is None: - dx0 = np.zeros_like(x0) - self.dx_prev = dx0 - self.t_prev = t0 - - @staticmethod - def smoothing_factor(t_e, cutoff): - r = 2 * pi * cutoff * t_e - return r / (r + 1) - - @staticmethod - def exponential_smoothing(alpha, x, x_prev): - return alpha * x + (1 - alpha) * x_prev - - def __call__(self, t, x): - t_e = t - self.t_prev - if t_e <= 0: - return x - a_d = self.smoothing_factor(t_e, self.d_cutoff) - dx = (x - self.x_prev) / t_e - dx_hat = self.exponential_smoothing(a_d, dx, self.dx_prev) - - cutoff = self.min_cutoff + self.beta * abs(dx_hat) - a = self.smoothing_factor(t_e, cutoff) - x_hat = self.exponential_smoothing(a, x, self.x_prev) - - self.x_prev = x_hat - self.dx_prev = dx_hat - self.t_prev = t - - return x_hat - # pragma: cover class BaseProcessorSocket(Processor): @@ -474,375 +424,3 @@ def get_data(self): if self.dlc_cfg is not None: save_dict["dlc_cfg"] = self.dlc_cfg return save_dict - - -@register_processor -class ExampleProcessorSocketCalculateMousePose(BaseProcessorSocket): # pragma: no cover - """ - DLC Processor with pose calculations (center, heading, head angle) and optional filtering. - - Calculates: - - center: Weighted average of head keypoints - - heading: Body orientation (degrees) - - head_angle: Head rotation relative to body (radians) - - Broadcasts: [timestamp, center_x, center_y, heading, head_angle] - """ - - PROCESSOR_NAME = "Example Experiment Pose Processor" - PROCESSOR_DESCRIPTION = "Calculates mouse center, heading, and head angle with optional One-Euro filtering" - PROCESSOR_PARAMS = { - "bind": { - "type": "tuple", - "default": ("127.0.0.1", 6000), - "description": "Server address (host, port)", - }, - "authkey": { - "type": "bytes", - "default": b"secret password", - "description": "Authentication key for clients", - }, - "use_perf_counter": { - "type": "bool", - "default": False, - "description": "Use time.perf_counter() instead of time.time()", - }, - "use_filter": { - "type": "bool", - "default": False, - "description": "Apply One-Euro filter to calculated values", - }, - "filter_kwargs": { - "type": "dict", - "default": {"min_cutoff": 1.0, "beta": 0.02, "d_cutoff": 1.0}, - "description": "One-Euro filter parameters (min_cutoff, beta, d_cutoff)", - }, - "save_original": { - "type": "bool", - "default": False, - "description": "Save raw pose arrays for analysis", - }, - } - - def __init__( - self, - bind=("127.0.0.1", 6000), - authkey=b"secret password", - use_perf_counter=False, - use_filter=False, - filter_kwargs: dict | None = None, - save_original=False, - ): - super().__init__( - bind=bind, - authkey=authkey, - use_perf_counter=use_perf_counter, - save_original=save_original, - ) - - self.center_x = deque() - self.center_y = deque() - self.heading_direction = deque() - self.head_angle = deque() - - self.use_filter = use_filter - self.filter_kwargs = filter_kwargs if filter_kwargs is not None else {} - self.filters = None - - def _clear_data_queues(self): - super()._clear_data_queues() - self.center_x.clear() - self.center_y.clear() - self.heading_direction.clear() - self.head_angle.clear() - - def _initialize_filters(self, vals): - t0 = self.timing_func() - self.filters = { - "center_x": OneEuroFilter(t0, vals[0], **self.filter_kwargs), - "center_y": OneEuroFilter(t0, vals[1], **self.filter_kwargs), - "heading": OneEuroFilter(t0, vals[2], **self.filter_kwargs), - "head_angle": OneEuroFilter(t0, vals[3], **self.filter_kwargs), - } - logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}") - - def process(self, pose, **kwargs): - # Extract keypoints and confidence - xy = pose[:, :2] - conf = pose[:, 2] - - # Calculate weighted center from head keypoints - head_xy = xy[[0, 1, 2, 3, 4, 5, 6, 26], :] - head_conf = conf[[0, 1, 2, 3, 4, 5, 6, 26]] - center = np.average(head_xy, axis=0, weights=head_conf) - - # Calculate body axis (tail_base -> neck) - body_axis = xy[7] - xy[13] - body_axis /= sqrt(np.sum(body_axis**2)) - - # Calculate head axis (neck -> nose) - head_axis = xy[0] - xy[7] - head_axis /= sqrt(np.sum(head_axis**2)) - - # Calculate head angle relative to body - cross = body_axis[0] * head_axis[1] - head_axis[0] * body_axis[1] - sign = copysign(1, cross) # Positive when looking left - sign = copysign(1, cross) - try: - head_angle = acos(body_axis @ head_axis) * sign - except ValueError: - head_angle = 0 - - # Calculate heading (body orientation) - heading = degrees(atan2(body_axis[1], body_axis[0])) - - # Raw values (heading unwrapped for filtering) - vals = [center[0], center[1], heading, head_angle] - - # Apply filtering if enabled - curr_time = self.timing_func() - if self.use_filter: - if self.filters is None: - self._initialize_filters(vals) - - vals = [ - self.filters["center_x"](curr_time, vals[0]), - self.filters["center_y"](curr_time, vals[1]), - self.filters["heading"](curr_time, vals[2]), - self.filters["head_angle"](curr_time, vals[3]), - ] - - # Wrap heading to [0, 360) after filtering - vals[2] = vals[2] % 360 - # Update step counter - self.curr_step = self.curr_step + 1 - - # Store processed data (only if recording) - if self.recording: - if self.save_original and self.original_pose is not None: - self.original_pose.append(pose.copy()) - self.center_x.append(vals[0]) - self.center_y.append(vals[1]) - self.heading_direction.append(vals[2]) - self.head_angle.append(vals[3]) - self.time_stamp.append(curr_time) - self.step.append(self.curr_step) - self.frame_time.append(kwargs.get("frame_time", -1)) - if "pose_time" in kwargs: - self.pose_time.append(kwargs["pose_time"]) - - payload = [curr_time, vals[0], vals[1], vals[2], vals[3]] - self.broadcast(payload) - return pose - - def get_data(self): - save_dict = super().get_data() - save_dict["x_pos"] = np.array(self.center_x) - save_dict["y_pos"] = np.array(self.center_y) - save_dict["heading_direction"] = np.array(self.heading_direction) - save_dict["head_angle"] = np.array(self.head_angle) - save_dict["use_filter"] = self.use_filter - save_dict["filter_kwargs"] = self.filter_kwargs - return save_dict - - -@register_processor -class ExampleProcessorSocketFilterKeypoints(BaseProcessorSocket): # pragma: no cover - PROCESSOR_NAME = "Mouse Pose with less keypoints" - PROCESSOR_DESCRIPTION = "Calculates mouse center, heading, and head angle with optional One-Euro filtering" - PROCESSOR_PARAMS = { - "bind": { - "type": "tuple", - "default": ("127.0.0.1", 6000), - "description": "Server address (host, port)", - }, - "authkey": { - "type": "bytes", - "default": b"secret password", - "description": "Authentication key for clients", - }, - "use_perf_counter": { - "type": "bool", - "default": False, - "description": "Use time.perf_counter() instead of time.time()", - }, - "use_filter": { - "type": "bool", - "default": False, - "description": "Apply One-Euro filter to calculated values", - }, - "filter_kwargs": { - "type": "dict", - "default": {"min_cutoff": 1.0, "beta": 0.02, "d_cutoff": 1.0}, - "description": "One-Euro filter parameters (min_cutoff, beta, d_cutoff)", - }, - "save_original": { - "type": "bool", - "default": True, - "description": "Save raw pose arrays for analysis", - }, - } - - def __init__( - self, - bind=("127.0.0.1", 6000), - authkey=b"secret password", - use_perf_counter=False, - use_filter=False, - filter_kwargs: dict | None = None, - save_original=True, - p_cutoff=0.4, - ): - super().__init__( - bind=bind, - authkey=authkey, - use_perf_counter=use_perf_counter, - save_original=save_original, - ) - - self.center_x = deque() - self.center_y = deque() - self.heading_direction = deque() - self.head_angle = deque() - - self.p_cutoff = p_cutoff - - self.use_filter = use_filter - self.filter_kwargs = filter_kwargs if filter_kwargs is not None else {} - self.filters = None - - def _clear_data_queues(self): - super()._clear_data_queues() - self.center_x.clear() - self.center_y.clear() - self.heading_direction.clear() - self.head_angle.clear() - - def _initialize_filters(self, vals): - t0 = self.timing_func() - self.filters = { - "center_x": OneEuroFilter(t0, vals[0], **self.filter_kwargs), - "center_y": OneEuroFilter(t0, vals[1], **self.filter_kwargs), - "heading": OneEuroFilter(t0, vals[2], **self.filter_kwargs), - "head_angle": OneEuroFilter(t0, vals[3], **self.filter_kwargs), - } - logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}") - - def process(self, pose, **kwargs): - # Extract keypoints and confidence - xy = pose[:, :2] - conf = pose[:, 2] - - # Calculate weighted center from head keypoints - head_xy = xy[[0, 1, 2, 3, 5, 6, 7], :] - head_conf = conf[[0, 1, 2, 3, 5, 6, 7]] - # set low confidence keypoints to zero weight - head_conf = np.where(head_conf < self.p_cutoff, 0, head_conf) - try: - center = np.average(head_xy, axis=0, weights=head_conf) - except ZeroDivisionError: - # If all keypoints have zero weight, return without processing - return pose - - neck = np.average(xy[[2, 3, 6, 7], :], axis=0, weights=conf[[2, 3, 6, 7]]) - - # Calculate body axis (tail_base -> neck) - body_axis = neck - xy[9] - body_axis /= sqrt(np.sum(body_axis**2)) - - # Calculate head axis (neck -> nose) - head_axis = xy[0] - neck - head_axis /= sqrt(np.sum(head_axis**2)) - - # Calculate head angle relative to body - cross = body_axis[0] * head_axis[1] - head_axis[0] * body_axis[1] - sign = copysign(1, cross) # Positive when looking left - sign = copysign(1, cross) - try: - head_angle = acos(body_axis @ head_axis) * sign - except ValueError: - head_angle = 0 - - # Calculate heading (body orientation) - heading = degrees(atan2(body_axis[1], body_axis[0])) - vals = [center[0], center[1], heading, head_angle] - - curr_time = self.timing_func() - if self.use_filter: - if self.filters is None: - self._initialize_filters(vals) - - vals = [ - self.filters["center_x"](curr_time, vals[0]), - self.filters["center_y"](curr_time, vals[1]), - self.filters["heading"](curr_time, vals[2]), - self.filters["head_angle"](curr_time, vals[3]), - ] - - # Wrap heading to [0, 360) after filtering - vals[2] = vals[2] % 360 - # Update step counter - self.curr_step = self.curr_step + 1 - - # Store processed data (only if recording) - if self.recording: - if self.save_original and self.original_pose is not None: - self.original_pose.append(pose.copy()) - self.center_x.append(vals[0]) - self.center_y.append(vals[1]) - self.heading_direction.append(vals[2]) - self.head_angle.append(vals[3]) - self.time_stamp.append(curr_time) - self.step.append(self.curr_step) - self.frame_time.append(kwargs.get("frame_time", -1)) - if "pose_time" in kwargs: - self.pose_time.append(kwargs["pose_time"]) - - payload = [curr_time, vals[0], vals[1], vals[2], vals[3]] - self.broadcast(payload) - return pose - - def get_data(self): - save_dict = super().get_data() - save_dict["x_pos"] = np.array(self.center_x) - save_dict["y_pos"] = np.array(self.center_y) - save_dict["heading_direction"] = np.array(self.heading_direction) - save_dict["head_angle"] = np.array(self.head_angle) - save_dict["use_filter"] = self.use_filter - save_dict["filter_kwargs"] = self.filter_kwargs - return save_dict - - -def get_available_processors(): - """ - Get list of available processor classes. - - Returns: - dict: Dictionary mapping registry keys to processor info. - """ - return { - name: { - "class": cls, - "name": getattr(cls, "PROCESSOR_NAME", name), - "description": getattr(cls, "PROCESSOR_DESCRIPTION", ""), - "params": getattr(cls, "PROCESSOR_PARAMS", {}), - } - for name, cls in PROCESSOR_REGISTRY.items() - } - - -def instantiate_processor(class_name, **kwargs): - """ - Instantiate a processor by class name with given parameters. - - Args: - class_name: Registry key (e.g., "MyProcessorSocket") - **kwargs: Constructor kwargs - - Raises: - ValueError: If class_name is not in registry - """ - if class_name not in PROCESSOR_REGISTRY: - available = ", ".join(PROCESSOR_REGISTRY.keys()) - raise ValueError(f"Unknown processor '{class_name}'. Available: {available}") - return PROCESSOR_REGISTRY[class_name](**kwargs) diff --git a/dlclivegui/processors/examples.py b/dlclivegui/processors/examples.py new file mode 100644 index 0000000..7ed7691 --- /dev/null +++ b/dlclivegui/processors/examples.py @@ -0,0 +1,391 @@ +from __future__ import annotations + +import logging +from collections import deque +from math import acos, atan2, copysign, degrees, pi, sqrt + +import numpy as np + +from dlclivegui.processors import register_processor +from dlclivegui.processors.dlc_processor_socket import BaseProcessorSocket + +logger = logging.getLogger(__name__) + + +class OneEuroFilter: # pragma: no cover + def __init__(self, t0, x0, dx0=None, min_cutoff=1.0, beta=0.0, d_cutoff=1.0): + self.min_cutoff = min_cutoff + self.beta = beta + self.d_cutoff = d_cutoff + self.x_prev = x0 + if dx0 is None: + dx0 = np.zeros_like(x0) + self.dx_prev = dx0 + self.t_prev = t0 + + @staticmethod + def smoothing_factor(t_e, cutoff): + r = 2 * pi * cutoff * t_e + return r / (r + 1) + + @staticmethod + def exponential_smoothing(alpha, x, x_prev): + return alpha * x + (1 - alpha) * x_prev + + def __call__(self, t, x): + t_e = t - self.t_prev + if t_e <= 0: + return x + a_d = self.smoothing_factor(t_e, self.d_cutoff) + dx = (x - self.x_prev) / t_e + dx_hat = self.exponential_smoothing(a_d, dx, self.dx_prev) + + cutoff = self.min_cutoff + self.beta * abs(dx_hat) + a = self.smoothing_factor(t_e, cutoff) + x_hat = self.exponential_smoothing(a, x, self.x_prev) + + self.x_prev = x_hat + self.dx_prev = dx_hat + self.t_prev = t + + return x_hat + + +@register_processor +class ExampleProcessorSocketCalculateMousePose(BaseProcessorSocket): # pragma: no cover + """ + DLC Processor with pose calculations (center, heading, head angle) and optional filtering. + + Calculates: + - center: Weighted average of head keypoints + - heading: Body orientation (degrees) + - head_angle: Head rotation relative to body (radians) + + Broadcasts: [timestamp, center_x, center_y, heading, head_angle] + """ + + PROCESSOR_NAME = "Example Experiment Pose Processor" + PROCESSOR_DESCRIPTION = "Calculates mouse center, heading, and head angle with optional One-Euro filtering" + PROCESSOR_PARAMS = { + "bind": { + "type": "tuple", + "default": ("127.0.0.1", 6000), + "description": "Server address (host, port)", + }, + "authkey": { + "type": "bytes", + "default": b"secret password", + "description": "Authentication key for clients", + }, + "use_perf_counter": { + "type": "bool", + "default": False, + "description": "Use time.perf_counter() instead of time.time()", + }, + "use_filter": { + "type": "bool", + "default": False, + "description": "Apply One-Euro filter to calculated values", + }, + "filter_kwargs": { + "type": "dict", + "default": {"min_cutoff": 1.0, "beta": 0.02, "d_cutoff": 1.0}, + "description": "One-Euro filter parameters (min_cutoff, beta, d_cutoff)", + }, + "save_original": { + "type": "bool", + "default": False, + "description": "Save raw pose arrays for analysis", + }, + } + + def __init__( + self, + bind=("127.0.0.1", 6000), + authkey=b"secret password", + use_perf_counter=False, + use_filter=False, + filter_kwargs: dict | None = None, + save_original=False, + ): + super().__init__( + bind=bind, + authkey=authkey, + use_perf_counter=use_perf_counter, + save_original=save_original, + ) + + self.center_x = deque() + self.center_y = deque() + self.heading_direction = deque() + self.head_angle = deque() + + self.use_filter = use_filter + self.filter_kwargs = filter_kwargs if filter_kwargs is not None else {} + self.filters = None + + def _clear_data_queues(self): + super()._clear_data_queues() + self.center_x.clear() + self.center_y.clear() + self.heading_direction.clear() + self.head_angle.clear() + + def _initialize_filters(self, vals): + t0 = self.timing_func() + self.filters = { + "center_x": OneEuroFilter(t0, vals[0], **self.filter_kwargs), + "center_y": OneEuroFilter(t0, vals[1], **self.filter_kwargs), + "heading": OneEuroFilter(t0, vals[2], **self.filter_kwargs), + "head_angle": OneEuroFilter(t0, vals[3], **self.filter_kwargs), + } + logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}") + + def process(self, pose, **kwargs): + # Extract keypoints and confidence + xy = pose[:, :2] + conf = pose[:, 2] + + # Calculate weighted center from head keypoints + head_xy = xy[[0, 1, 2, 3, 4, 5, 6, 26], :] + head_conf = conf[[0, 1, 2, 3, 4, 5, 6, 26]] + try: + center = np.average(head_xy, axis=0, weights=head_conf) + except ZeroDivisionError: + center = np.zeros(2) + + # Calculate body axis (tail_base -> neck) + body_axis = xy[7] - xy[13] + body_axis /= sqrt(np.sum(body_axis**2)) + + # Calculate head axis (neck -> nose) + head_axis = xy[0] - xy[7] + head_axis /= sqrt(np.sum(head_axis**2)) + + # Calculate head angle relative to body + cross = body_axis[0] * head_axis[1] - head_axis[0] * body_axis[1] + sign = copysign(1, cross) # Positive when looking left + + try: + head_angle = acos(body_axis @ head_axis) * sign + except ValueError: + head_angle = 0 + + # Calculate heading (body orientation) + heading = degrees(atan2(body_axis[1], body_axis[0])) + + # Raw values (heading unwrapped for filtering) + vals = [center[0], center[1], heading, head_angle] + + # Apply filtering if enabled + curr_time = self.timing_func() + if self.use_filter: + if self.filters is None: + self._initialize_filters(vals) + + vals = [ + self.filters["center_x"](curr_time, vals[0]), + self.filters["center_y"](curr_time, vals[1]), + self.filters["heading"](curr_time, vals[2]), + self.filters["head_angle"](curr_time, vals[3]), + ] + + # Wrap heading to [0, 360) after filtering + vals[2] = vals[2] % 360 + # Update step counter + self.curr_step = self.curr_step + 1 + + # Store processed data (only if recording) + if self.recording: + if self.save_original and self.original_pose is not None: + self.original_pose.append(pose.copy()) + self.center_x.append(vals[0]) + self.center_y.append(vals[1]) + self.heading_direction.append(vals[2]) + self.head_angle.append(vals[3]) + self.time_stamp.append(curr_time) + self.step.append(self.curr_step) + self.frame_time.append(kwargs.get("frame_time", -1)) + if "pose_time" in kwargs: + self.pose_time.append(kwargs["pose_time"]) + + payload = [curr_time, vals[0], vals[1], vals[2], vals[3]] + self.broadcast(payload) + return pose + + def get_data(self): + save_dict = super().get_data() + save_dict["x_pos"] = np.array(self.center_x) + save_dict["y_pos"] = np.array(self.center_y) + save_dict["heading_direction"] = np.array(self.heading_direction) + save_dict["head_angle"] = np.array(self.head_angle) + save_dict["use_filter"] = self.use_filter + save_dict["filter_kwargs"] = self.filter_kwargs + return save_dict + + +@register_processor +class ExampleProcessorSocketFilterKeypoints(BaseProcessorSocket): # pragma: no cover + PROCESSOR_NAME = "Mouse Pose with less keypoints" + PROCESSOR_DESCRIPTION = "Calculates mouse center, heading, and head angle with optional One-Euro filtering" + PROCESSOR_PARAMS = { + "bind": { + "type": "tuple", + "default": ("127.0.0.1", 6000), + "description": "Server address (host, port)", + }, + "authkey": { + "type": "bytes", + "default": b"secret password", + "description": "Authentication key for clients", + }, + "use_perf_counter": { + "type": "bool", + "default": False, + "description": "Use time.perf_counter() instead of time.time()", + }, + "use_filter": { + "type": "bool", + "default": False, + "description": "Apply One-Euro filter to calculated values", + }, + "filter_kwargs": { + "type": "dict", + "default": {"min_cutoff": 1.0, "beta": 0.02, "d_cutoff": 1.0}, + "description": "One-Euro filter parameters (min_cutoff, beta, d_cutoff)", + }, + "save_original": { + "type": "bool", + "default": True, + "description": "Save raw pose arrays for analysis", + }, + } + + def __init__( + self, + bind=("127.0.0.1", 6000), + authkey=b"secret password", + use_perf_counter=False, + use_filter=False, + filter_kwargs: dict | None = None, + save_original=True, + p_cutoff=0.4, + ): + super().__init__( + bind=bind, + authkey=authkey, + use_perf_counter=use_perf_counter, + save_original=save_original, + ) + + self.center_x = deque() + self.center_y = deque() + self.heading_direction = deque() + self.head_angle = deque() + + self.p_cutoff = p_cutoff + + self.use_filter = use_filter + self.filter_kwargs = filter_kwargs if filter_kwargs is not None else {} + self.filters = None + + def _clear_data_queues(self): + super()._clear_data_queues() + self.center_x.clear() + self.center_y.clear() + self.heading_direction.clear() + self.head_angle.clear() + + def _initialize_filters(self, vals): + t0 = self.timing_func() + self.filters = { + "center_x": OneEuroFilter(t0, vals[0], **self.filter_kwargs), + "center_y": OneEuroFilter(t0, vals[1], **self.filter_kwargs), + "heading": OneEuroFilter(t0, vals[2], **self.filter_kwargs), + "head_angle": OneEuroFilter(t0, vals[3], **self.filter_kwargs), + } + logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}") + + def process(self, pose, **kwargs): + # Extract keypoints and confidence + xy = pose[:, :2] + conf = pose[:, 2] + + # Calculate weighted center from head keypoints + head_xy = xy[[0, 1, 2, 3, 5, 6, 7], :] + head_conf = conf[[0, 1, 2, 3, 5, 6, 7]] + # set low confidence keypoints to zero weight + head_conf = np.where(head_conf < self.p_cutoff, 0, head_conf) + try: + center = np.average(head_xy, axis=0, weights=head_conf) + except ZeroDivisionError: + # If all keypoints have zero weight, return without processing + return pose + + neck = np.average(xy[[2, 3, 6, 7], :], axis=0, weights=conf[[2, 3, 6, 7]]) + + # Calculate body axis (tail_base -> neck) + body_axis = neck - xy[9] + body_axis /= sqrt(np.sum(body_axis**2)) + + # Calculate head axis (neck -> nose) + head_axis = xy[0] - neck + head_axis /= sqrt(np.sum(head_axis**2)) + + # Calculate head angle relative to body + cross = body_axis[0] * head_axis[1] - head_axis[0] * body_axis[1] + sign = copysign(1, cross) # Positive when looking left + + try: + head_angle = acos(body_axis @ head_axis) * sign + except ValueError: + head_angle = 0 + + # Calculate heading (body orientation) + heading = degrees(atan2(body_axis[1], body_axis[0])) + vals = [center[0], center[1], heading, head_angle] + + curr_time = self.timing_func() + if self.use_filter: + if self.filters is None: + self._initialize_filters(vals) + + vals = [ + self.filters["center_x"](curr_time, vals[0]), + self.filters["center_y"](curr_time, vals[1]), + self.filters["heading"](curr_time, vals[2]), + self.filters["head_angle"](curr_time, vals[3]), + ] + + # Wrap heading to [0, 360) after filtering + vals[2] = vals[2] % 360 + # Update step counter + self.curr_step = self.curr_step + 1 + + # Store processed data (only if recording) + if self.recording: + if self.save_original and self.original_pose is not None: + self.original_pose.append(pose.copy()) + self.center_x.append(vals[0]) + self.center_y.append(vals[1]) + self.heading_direction.append(vals[2]) + self.head_angle.append(vals[3]) + self.time_stamp.append(curr_time) + self.step.append(self.curr_step) + self.frame_time.append(kwargs.get("frame_time", -1)) + if "pose_time" in kwargs: + self.pose_time.append(kwargs["pose_time"]) + + payload = [curr_time, vals[0], vals[1], vals[2], vals[3]] + self.broadcast(payload) + return pose + + def get_data(self): + save_dict = super().get_data() + save_dict["x_pos"] = np.array(self.center_x) + save_dict["y_pos"] = np.array(self.center_y) + save_dict["heading_direction"] = np.array(self.heading_direction) + save_dict["head_angle"] = np.array(self.head_angle) + save_dict["use_filter"] = self.use_filter + save_dict["filter_kwargs"] = self.filter_kwargs + return save_dict diff --git a/dlclivegui/processors/processor_utils.py b/dlclivegui/processors/processor_utils.py index b32445c..467792b 100644 --- a/dlclivegui/processors/processor_utils.py +++ b/dlclivegui/processors/processor_utils.py @@ -17,6 +17,65 @@ def default_processors_dir() -> str: return str(path) +def _processor_base_class(): + from dlclive.processor import Processor + + return Processor + + +def _is_processor_subclass(obj, *, include_base: bool = False) -> bool: + """Return True for dlclive.Processor subclasses, including indirect subclasses.""" + if not inspect.isclass(obj): + return False + + try: + processor_base = _processor_base_class() + except Exception: + logger.exception("Could not import dlclive.Processor") + return False + + try: + if obj is processor_base: + return bool(include_base) + return issubclass(obj, processor_base) + except Exception: + logger.exception(f"Error checking if {obj} is a subclass of dlclive.Processor") + return False + + +def _processor_info_from_class(cls, fallback_name: str) -> dict: + return { + "class": cls, + "name": getattr(cls, "PROCESSOR_NAME", fallback_name), + "description": getattr(cls, "PROCESSOR_DESCRIPTION", ""), + "params": getattr(cls, "PROCESSOR_PARAMS", {}), + } + + +def discover_processor_classes(module, *, only_defined_in_module: bool = True) -> dict[str, dict]: + """Discover dlclive.Processor subclasses in a module. + + Includes indirect subclasses of Processor. + + Args: + module: Imported Python module. + only_defined_in_module: If True, ignore Processor subclasses imported + from other modules to avoid duplicate registry entries. + """ + processors: dict[str, dict] = {} + + for name, obj in inspect.getmembers(module, inspect.isclass): + if only_defined_in_module and getattr(obj, "__module__", None) != module.__name__: + continue + + if not _is_processor_subclass(obj): + continue + + processors[name] = _processor_info_from_class(obj, name) + + return processors + + def scan_processor_folder(folder_path): all_processors = {} folder = Path(folder_path) @@ -39,11 +98,9 @@ def scan_processor_folder(folder_path): return all_processors -def scan_processor_package(package_name: str = "dlclivegui.processors") -> dict[str | dict]: +def scan_processor_package(package_name: str = "dlclivegui.processors") -> dict[str, dict]: """ Discover and load processor classes from a package namespace. - Returns a dict keyed as 'module.py::ClassName' with the same - structure you use today. """ all_processors: dict[str, dict] = {} @@ -59,28 +116,16 @@ def scan_processor_package(package_name: str = "dlclivegui.processors") -> dict[ continue try: mod = import_module(mod_name) + # Skip dlc_processor_socket.py as it's the base class and registry + if mod.__name__.endswith("dlc_processor_socket"): + continue # Prefer module-level registry function if present if hasattr(mod, "get_available_processors"): processors = mod.get_available_processors() else: # Fallback: scan for dlclive.Processor subclasses - from dlclive import Processor - - processors = {} - for attr_name in dir(mod): - obj = getattr(mod, attr_name) - try: - if isinstance(obj, type) and obj is not Processor and issubclass(obj, Processor): - processors[attr_name] = { - "class": obj, - "name": getattr(obj, "PROCESSOR_NAME", attr_name), - "description": getattr(obj, "PROCESSOR_DESCRIPTION", ""), - "params": getattr(obj, "PROCESSOR_PARAMS", {}), - } - except Exception: - # Non-class or weird metaclass; ignore - pass + processors = discover_processor_classes(mod) # Normalize into your “file::class” shape module_file = mod.__name__.split(".")[-1] + ".py" @@ -131,26 +176,7 @@ def load_processors_from_file(file_path: str | Path): return processors # Fallback path: discover subclasses of dlclive.Processor - from dlclive import Processor - - processors: dict[str, dict] = {} - for name, obj in inspect.getmembers(module, inspect.isclass): - if obj is Processor: - continue - # Guard: module might define other classes; only include Processor subclasses - try: - if issubclass(obj, Processor): - processors[name] = { - "class": obj, - "name": getattr(obj, "PROCESSOR_NAME", name), - "description": getattr(obj, "PROCESSOR_DESCRIPTION", ""), - "params": getattr(obj, "PROCESSOR_PARAMS", {}), - } - except Exception: - # Some "classes" can fail issubclass checks; ignore safely - continue - - return processors + return discover_processor_classes(module) except Exception: # Full traceback helps a ton when a plugin fails to import diff --git a/dlclivegui/processors/registry.py b/dlclivegui/processors/registry.py new file mode 100644 index 0000000..2889297 --- /dev/null +++ b/dlclivegui/processors/registry.py @@ -0,0 +1,53 @@ +import logging + +logger = logging.getLogger(__name__) + +# Registry for GUI discovery +PROCESSOR_REGISTRY = {} + + +def register_processor(cls): + registry_key = getattr(cls, "PROCESSOR_ID", cls.__name__) + if registry_key in PROCESSOR_REGISTRY: + msg = ( + f"Duplicate processor registration key '{registry_key}': " + f"{PROCESSOR_REGISTRY[registry_key].__name__} vs {cls.__name__}" + ) + logger.warning(msg) + PROCESSOR_REGISTRY[registry_key] = cls + return cls + + +def get_available_processors(): + """ + Get list of available processor classes. + + Returns: + dict: Dictionary mapping registry keys to processor info. + """ + return { + name: { + "class": cls, + "name": getattr(cls, "PROCESSOR_NAME", name), + "description": getattr(cls, "PROCESSOR_DESCRIPTION", ""), + "params": getattr(cls, "PROCESSOR_PARAMS", {}), + } + for name, cls in PROCESSOR_REGISTRY.items() + } + + +def instantiate_processor(class_name, **kwargs): + """ + Instantiate a processor by class name with given parameters. + + Args: + class_name: Registry key (e.g., "MyProcessorSocket") + **kwargs: Constructor kwargs + + Raises: + ValueError: If class_name is not in registry + """ + if class_name not in PROCESSOR_REGISTRY: + available = ", ".join(PROCESSOR_REGISTRY.keys()) + raise ValueError(f"Unknown processor '{class_name}'. Available: {available}") + return PROCESSOR_REGISTRY[class_name](**kwargs) diff --git a/dlclivegui/services/multi_camera_controller.py b/dlclivegui/services/multi_camera_controller.py index fe5b669..97c53dc 100644 --- a/dlclivegui/services/multi_camera_controller.py +++ b/dlclivegui/services/multi_camera_controller.py @@ -47,7 +47,7 @@ class MultiFrameData: class SingleCameraWorker(QObject): """Worker for a single camera in multi-camera mode.""" - frame_captured = Signal(str, object, float) # camera_id, frame, timestamp + frame_captured = Signal(str, object, float, object) # camera_id, frame, timestamp, timestamp_metadata error_occurred = Signal(str, str) # camera_id, error_message runtime_info = Signal(str, object) # camera_id, dict of runtime info started = Signal(str) # camera_id @@ -117,7 +117,10 @@ def run(self) -> None: while not self._stop_event.is_set(): try: with self._timing.measure("Single.read"): - frame, timestamp = self._backend.read() + captured = self._backend.read() + frame = captured.frame + timestamp = captured.software_timestamp + timestamp_metadata = captured.timestamp_metadata if frame is None or frame.size == 0: consecutive_errors += 1 if consecutive_errors >= self._max_consecutive_errors: @@ -131,7 +134,7 @@ def run(self) -> None: consecutive_errors = 0 with self._timing.measure("Single.emit.frame_captured"): - self.frame_captured.emit(self._camera_id, frame, timestamp) + self.frame_captured.emit(self._camera_id, frame, timestamp, timestamp_metadata) self._timing.note_frame() self._timing.maybe_log() @@ -298,7 +301,9 @@ class MultiCameraController(QObject): # Signals frame_ready = Signal(object) # MultiFrameData (full cam FPS; inference only) - recording_frame_ready = Signal(str, object, float) # camera_id, frame, timestamp (full cam FPS; for recording) + recording_frame_ready = Signal( + str, object, float, object + ) # camera_id, frame, timestamp, timestamp_metadata (full cam FPS; for recording) display_ready = Signal(object) # MultiFrameData for GUI display (throttled to GUI_MAX_DISPLAY_FPS) camera_started = Signal(str, object) # camera_id, settings camera_stopped = Signal(str) # camera_id @@ -568,7 +573,9 @@ def stop(self, wait: bool = True) -> None: self.all_stopped.emit() - def _on_frame_captured(self, camera_id: str, frame: np.ndarray, timestamp: float) -> None: + def _on_frame_captured( + self, camera_id: str, frame: np.ndarray, timestamp: float, timestamp_metadata: object | None = None + ) -> None: """Handle a frame from one camera.""" timing = self._timing_for_camera(camera_id) frame_data: MultiFrameData | None = None @@ -587,7 +594,7 @@ def _on_frame_captured(self, camera_id: str, frame: np.ndarray, timestamp: float if self._recording_frame_emission_enabled: with timing.measure("Multi.emit.recording_frame_ready"): - self.recording_frame_ready.emit(camera_id, frame, timestamp) + self.recording_frame_ready.emit(camera_id, frame, timestamp, timestamp_metadata) with self._frame_lock: with timing.measure("Multi.store_latest"): diff --git a/dlclivegui/services/video_recorder.py b/dlclivegui/services/video_recorder.py index 9cc75ef..44369a5 100644 --- a/dlclivegui/services/video_recorder.py +++ b/dlclivegui/services/video_recorder.py @@ -94,6 +94,7 @@ def __init__( self._writer: Any | None = None self._frame_size = frame_size self._frame_rate = frame_rate + self._hardware_timestamp_source: dict[str, Any] | None = None self._codec = codec self._crf = int(crf) self._buffer_size = max(1, int(buffer_size)) @@ -115,7 +116,7 @@ def __init__( self._written_times: deque[float] = deque(maxlen=600) self._encode_error: Exception | None = None self._last_log_time = 0.0 - self._frame_timestamps: list[float] = [] + self._frame_timestamps: list[dict[str, Any]] = [] # Timing self._process_timing = WorkerTimingStats( f"RecorderProcess[{self._output.name}]", logger=logger, log_interval=1.0, enabled=REC_DO_LOG_TIMING @@ -211,6 +212,7 @@ def start(self) -> None: self._last_latency = 0.0 self._written_times.clear() self._frame_timestamps.clear() + self._hardware_timestamp_source = None self._encode_error = None self._stop_event.clear() self._writer_thread = threading.Thread( @@ -224,7 +226,9 @@ def configure_stream(self, frame_size: tuple[int, int], frame_rate: float | None self._frame_size = frame_size self._frame_rate = frame_rate - def write(self, frame: np.ndarray, timestamp: float | None = None) -> bool: + def write( + self, frame: np.ndarray, timestamp: float | None = None, timestamp_metadata: object | None = None + ) -> bool: error = self._current_error() if error is not None: raise RuntimeError(f"Video encoding failed: {error}") from error @@ -274,22 +278,23 @@ def write(self, frame: np.ndarray, timestamp: float | None = None) -> bool: expected_h, expected_w = self._frame_size actual_h, actual_w = frame.shape[:2] if (actual_h, actual_w) != (expected_h, expected_w): - logger.warning( - f"Frame size mismatch: expected (h={expected_h}, w={expected_w}), " - f"got (h={actual_h}, w={actual_w}). " - "Stopping recorder to prevent encoding errors." + message = ( + f"Frame size mismatch for recorder {self._output.name}: " + f"expected_hw=({expected_h}, {expected_w}) " + f"actual_hw=({actual_h}, {actual_w}) " + f"{self._describe_frame(frame)}. " + "Stopping recorder to prevent FFmpeg pipe errors." ) - with self._stats_lock: - self._encode_error = ValueError( - f"Frame size changed from (h={expected_h}, w={expected_w}) to (h={actual_h}, w={actual_w})" - ) + + logger.warning(message) + self._set_encode_error(message) self._process_timing.note_error() self._process_timing.maybe_log() return False try: with self._process_timing.measure("Recorder.queue_put"): - q.put((frame, timestamp), block=False) + q.put((frame, timestamp, timestamp_metadata), block=False) except queue.Full: with self._stats_lock: self._dropped_frames += 1 @@ -422,9 +427,12 @@ def _writer_loop(self) -> None: break continue except Exception as exc: - with self._stats_lock: - self._encode_error = exc - logger.exception("Could not retrieve item from queue", exc_info=exc) + message = ( + f"Could not retrieve frame from recorder queue for {self._output.name}: " + f"{type(exc).__name__}: {exc!s}" + ) + self._set_encode_error(message, exc) + logger.exception(message) self._stop_event.set() break @@ -432,7 +440,7 @@ def _writer_loop(self) -> None: if item is _SENTINEL: break else: - frame, timestamp = item + frame, timestamp, timestamp_metadata = item start = time.perf_counter() try: @@ -443,10 +451,54 @@ def _writer_loop(self) -> None: with self._writer_timing.measure("Recorder.writer_write"): writer.write(frame) + record: dict[str, Any] = { + "frame_index": self._frames_written, + "software_timestamp": float(timestamp), + } + + if timestamp_metadata is not None: + if ( + hasattr(timestamp_metadata, "to_source_dict") + and self._hardware_timestamp_source is None + ): + self._hardware_timestamp_source = timestamp_metadata.to_source_dict() + + if hasattr(timestamp_metadata, "to_frame_dict"): + record["hardware_timestamp"] = timestamp_metadata.to_frame_dict() + if hasattr(timestamp_metadata, "get_default_reported"): + default_value = timestamp_metadata.get_default_reported() + if default_value is not None: + record["hardware_timestamp_default"] = default_value + elif isinstance(timestamp_metadata, dict): + record["hardware_timestamp"] = dict(timestamp_metadata) + else: + record["hardware_timestamp"] = repr(timestamp_metadata) + + self._frame_timestamps.append(record) + except Exception as exc: + queue_size = q.qsize() if q is not None else -1 + with self._stats_lock: - self._encode_error = exc - logger.exception("Video encoding failed while writing frame", exc_info=exc) + frames_enqueued = self._frames_enqueued + frames_written = self._frames_written + dropped_frames = self._dropped_frames + + message = ( + f"Video encoding failed for recorder {self._output.name}: " + f"{type(exc).__name__}: {exc!s}. " + f"{self._describe_frame(frame)} " + f"expected_frame_size={self._frame_size} " + f"frames_written={frames_written} " + f"frames_enqueued={frames_enqueued} " + f"dropped={dropped_frames} " + f"queue_size={queue_size}. " + "The FFmpeg/WriteGear pipe is no longer usable; stopping this recorder." + ) + + self._set_encode_error(message, exc) + + logger.exception(message) self._stop_event.set() self._writer_timing.note_error() self._writer_timing.maybe_log() @@ -459,7 +511,6 @@ def _writer_loop(self) -> None: self._total_latency += elapsed self._last_latency = elapsed self._written_times.append(now) - self._frame_timestamps.append(timestamp) if now - self._last_log_time >= 1.0: self._compute_write_fps_locked() self._last_log_time = now @@ -496,37 +547,80 @@ def _compute_write_fps_locked(self) -> float: return 0.0 return (len(self._written_times) - 1) / duration + def _describe_frame(self, frame: np.ndarray | None) -> str: + if frame is None: + return "frame=None" + + try: + return ( + f"shape={frame.shape} " + f"dtype={frame.dtype} " + f"contiguous={frame.flags.c_contiguous} " + f"nbytes={frame.nbytes / (1024 * 1024):.2f}MB" + ) + except Exception: + return f"frame=" + def _current_error(self) -> Exception | None: with self._stats_lock: return self._encode_error + def _set_encode_error(self, message: str, exc: Exception | None = None) -> Exception: + error = RuntimeError(message) + if exc is not None: + error.__cause__ = exc + + with self._stats_lock: + self._encode_error = error + + return error + def _save_timestamps(self) -> None: """Save frame timestamps to a JSON file alongside the video.""" if not self._frame_timestamps: logger.info("No timestamps to save") return - # Create timestamps file path timestamp_file = self._output.with_suffix("").with_suffix(self._output.suffix + "_timestamps.json") try: with self._stats_lock: - timestamps = self._frame_timestamps.copy() + frame_timestamps = self._frame_timestamps.copy() + hardware_timestamp_source = ( + dict(self._hardware_timestamp_source) if self._hardware_timestamp_source is not None else None + ) + + software_timestamps = [ + float(rec["software_timestamp"]) for rec in frame_timestamps if "software_timestamp" in rec + ] - # Prepare metadata data = { + "schema_version": 2, "video_file": str(self._output.name), - "num_frames": len(timestamps), - "timestamps": timestamps, - "start_time": timestamps[0] if timestamps else None, - "end_time": timestamps[-1] if timestamps else None, - "duration_seconds": timestamps[-1] - timestamps[0] if len(timestamps) > 1 else 0.0, + "num_frames": len(frame_timestamps), + # "timestamps": software_timestamps, + "timestamp_sources": { + "software_timestamp": { + "source": "host_time.time", + "backend": "host", + "kind": "software_wall_clock", + "timebase": "Unix epoch", + "unit": "seconds", + "description": "Host-side software timestamp captured during acquisition.", + }, + "hardware_timestamp": hardware_timestamp_source, + }, + "frame_timestamps": frame_timestamps, + "start_time": software_timestamps[0] if software_timestamps else None, + "end_time": software_timestamps[-1] if software_timestamps else None, + "duration_seconds": ( + software_timestamps[-1] - software_timestamps[0] if len(software_timestamps) > 1 else 0.0 + ), } - # Write to JSON with open(timestamp_file, "w") as f: json.dump(data, f, indent=2) - logger.info(f"Saved {len(timestamps)} frame timestamps to {timestamp_file}") + logger.info("Saved %d frame timestamps to %s", len(frame_timestamps), timestamp_file) except Exception as exc: - logger.exception(f"Failed to save timestamps to {timestamp_file}: {exc}") + logger.exception("Failed to save timestamps to %s: %s", timestamp_file, exc) diff --git a/dlclivegui/temp/engine.py b/dlclivegui/temp/engine.py index a6bb225..85c4755 100644 --- a/dlclivegui/temp/engine.py +++ b/dlclivegui/temp/engine.py @@ -6,7 +6,7 @@ # or if we update dlclive.Engine to have these methods and use that instead of a separate enum here. # The latter would be more cohesive but also creates a dependency from utils to dlclive, # pending release of dlclive -class Engine(Enum): +class Engine(str, Enum): TENSORFLOW = "tensorflow" PYTORCH = "pytorch" @@ -26,6 +26,12 @@ def is_tensorflow_model_dir_path(model_path: str | Path) -> bool: @classmethod def from_model_type(cls, model_type: str) -> "Engine": + if not isinstance(model_type, str): + try: + model_type = getattr(model_type, "value", str(model_type)) + except Exception as e: + raise ValueError(f"Could not convert model_type to string: {model_type}") from e + if model_type.lower() == "pytorch": return cls.PYTORCH elif model_type.lower() in ("tensorflow", "base", "tensorrt", "lite"): diff --git a/dlclivegui/utils/settings_store.py b/dlclivegui/utils/settings_store.py index a0c5677..0107afb 100644 --- a/dlclivegui/utils/settings_store.py +++ b/dlclivegui/utils/settings_store.py @@ -57,6 +57,42 @@ def get_fast_encoding(self, default: bool = False) -> bool: return value return str(value).strip().lower() in {"1", "true", "yes", "on"} + def get_processor_folder(self, default: str = "") -> str: + """ + Return the persisted processor folder if it still exists and is a directory. + Otherwise return default. + """ + value = self._s.value("dlc/processor_folder", default) + value = str(value).strip() if value is not None else "" + + if not value: + return default + + try: + path = Path(value).expanduser() + if path.is_dir(): + return str(path.resolve()) + except Exception: + logger.debug("Persisted processor folder is invalid: %s", value, exc_info=True) + + return default + + def set_processor_folder(self, folder: str) -> None: + """ + Persist processor folder only if it exists and is a directory. + Invalid folders are ignored. + """ + folder = str(folder).strip() if folder is not None else "" + if not folder: + return + + try: + path = Path(folder).expanduser() + if path.is_dir(): + self._s.setValue("dlc/processor_folder", str(path.resolve())) + except Exception: + logger.debug("Failed to persist processor folder: %s", folder, exc_info=True) + def set_fast_encoding(self, enabled: bool) -> None: self._s.setValue("recording/fast_encoding", bool(enabled)) diff --git a/dlclivegui/utils/timestamps.py b/dlclivegui/utils/timestamps.py new file mode 100644 index 0000000..dea14ed --- /dev/null +++ b/dlclivegui/utils/timestamps.py @@ -0,0 +1,82 @@ +# dlclivegui/utils/timestamps.py +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class FrameTimestampMetadata: + """Optional backend-provided timestamp metadata for a captured frame. + + This supplements, but does not replace, the software timestamp. + + Notes: + - `seconds` is in the backend/hardware timebase. + - `wall_clock_time` should only be set if the backend can confidently + provide Unix/wall-clock seconds. + - `raw_value` preserves the original device-specific timestamp. + """ + + source: str + backend: str + + # Which value should downstream consumers use by default, if any. + # Expected values: "seconds", "wall_clock_time", or "raw_value". + default_reported: str | None = None + + # Device/hardware timebase value, if convertible to seconds + seconds: float | None = None + + # True Unix/wall-clock timestamp, if available + wall_clock_time: float | None = None + + # Raw backend value, e.g. device clock ticks + raw_value: int | float | str | None = None + raw_unit: str | None = None + + # Conversion metadata. + tick_frequency_hz: float | None = None + timebase: str | None = None + + # e.g. "camera_clock", "ptp_camera_clock", "hardware_wall_clock", + # "frame_counter", "unknown" + kind: str = "unknown" + + # Backend-specific extras. + extra: dict[str, Any] | None = None + + def to_source_dict(self) -> dict[str, Any]: + """Return metadata that should be written once per recording stream.""" + return { + "source": self.source, + "backend": self.backend, + "default_reported": self.default_reported, + "raw_unit": self.raw_unit, + "tick_frequency_hz": self.tick_frequency_hz, + "timebase": self.timebase, + "kind": self.kind, + "extra": self.extra or {}, + } + + def to_frame_dict(self) -> dict[str, Any]: + """Return defined per-frame timestamp values only.""" + ts = {} + for k in ["seconds", "wall_clock_time", "raw_value"]: + v = getattr(self, k) + if v is not None: + ts[k] = v + return ts + + def to_dict(self) -> dict[str, Any]: + """Return full representation, useful for logging/debugging.""" + return { + **self.to_source_dict(), + **self.to_frame_dict(), + } + + def get_default_reported(self) -> int | float | str | None: + """Return the value selected by `default_reported`, if configured.""" + if not self.default_reported: + return None + return self.to_frame_dict().get(self.default_reported) diff --git a/tests/cameras/backends/conftest.py b/tests/cameras/backends/conftest.py index 5bbcac3..a74a0ee 100644 --- a/tests/cameras/backends/conftest.py +++ b/tests/cameras/backends/conftest.py @@ -525,6 +525,9 @@ def GrabSucceeded(self): def Release(self): self.released = True + def GetTimeStamp(self): + return 123456789 + class InstantCamera: def __init__(self, device): self._device = device @@ -549,6 +552,7 @@ def __init__(self, device): self.AcquisitionFrameRateEnable = FakePylon._Feature(False) self.AcquisitionFrameRate = FakePylon._Feature(30.0) + self.GevTimestampTickFrequency = FakePylon._Feature(1_000_000_000.0) self.MaxNumBuffer = FakePylon._Feature(10) diff --git a/tests/cameras/backends/test_aravis_backend.py b/tests/cameras/backends/test_aravis_backend.py index 797fd11..4f7ac55 100644 --- a/tests/cameras/backends/test_aravis_backend.py +++ b/tests/cameras/backends/test_aravis_backend.py @@ -243,7 +243,7 @@ def make_backend(settings, buffers): @pytest.mark.unit def test_device_name(): - be, cam, s = make_backend(Settings(), []) + be, _cam, s = make_backend(Settings(), []) assert be.device_name() == "FakeVendor FakeModel (12345)" @@ -253,9 +253,9 @@ def test_read_mono8(): data = (np.arange(w * h) % 256).astype(np.uint8).tobytes() buf = FakeAravis.Buffer(data, w, h, FakeAravis.PIXEL_FORMAT_MONO_8) - be, cam, s = make_backend(Settings(), [buf]) + be, _cam, s = make_backend(Settings(), [buf]) - frame, ts = be.read() + frame = be.read().frame assert frame.shape == (h, w, 3) assert frame.dtype == np.uint8 # Ensure grayscale expanded to 3 channels @@ -272,9 +272,9 @@ def test_read_rgb8_converts_to_bgr(): data = np.array([255, 0, 0, 0, 255, 0], dtype=np.uint8).tobytes() buf = FakeAravis.Buffer(data, w, h, FakeAravis.PIXEL_FORMAT_RGB_8_PACKED) - be, cam, s = make_backend(Settings(), [buf]) + be, _cam, s = make_backend(Settings(), [buf]) - frame, _ = be.read() + frame = be.read().frame assert frame.shape == (1, 2, 3) # BGR conversion: red → [0,0,255], green → [0,255,0] assert (frame[0, 0] == np.array([0, 0, 255])).all() @@ -288,9 +288,9 @@ def test_read_bgr8_passthrough(): data = np.array([10, 20, 30, 40, 50, 60], dtype=np.uint8).tobytes() buf = FakeAravis.Buffer(data, w, h, FakeAravis.PIXEL_FORMAT_BGR_8_PACKED) - be, cam, s = make_backend(Settings(), [buf]) + be, _cam, s = make_backend(Settings(), [buf]) - frame, _ = be.read() + frame = be.read().frame assert frame.shape == (1, 2, 3) assert (frame.flatten() == np.array([10, 20, 30, 40, 50, 60])).all() assert s.pushed >= 1 @@ -302,9 +302,9 @@ def test_read_mono16_scaling(): raw = np.array([0, 32768, 65535], dtype=np.uint16) buf = FakeAravis.Buffer(raw.tobytes(), w, h, FakeAravis.PIXEL_FORMAT_MONO_16) - be, cam, s = make_backend(Settings(), [buf]) + be, _cam, s = make_backend(Settings(), [buf]) - frame, _ = be.read() + frame = be.read().frame assert frame.shape == (1, 3, 3) # scaling: 0 → 0, max → 255, mid → ~128 @@ -320,9 +320,9 @@ def test_read_unknown_format_fallback_to_mono8(): data = (np.arange(w * h) % 256).astype(np.uint8).tobytes() # Unknown token buf = FakeAravis.Buffer(data, w, h, "SOME_UNKNOWN_FMT") - be, cam, s = make_backend(Settings(), [buf]) + be, _cam, s = make_backend(Settings(), [buf]) - frame, _ = be.read() + frame = be.read().frame assert frame.shape == (h, w, 3) assert np.all(frame[..., 0] == frame[..., 1]) assert np.all(frame[..., 1] == frame[..., 2]) @@ -331,7 +331,7 @@ def test_read_unknown_format_fallback_to_mono8(): @pytest.mark.unit def test_read_timeout_raises(): - be, cam, s = make_backend(Settings(), []) + be, _cam, s = make_backend(Settings(), []) with pytest.raises(TimeoutError): be.read() @@ -341,7 +341,7 @@ def test_read_status_error_raises_and_pushes_back(): w, h = 1, 1 data = b"\x00" buf = FakeAravis.Buffer(data, w, h, FakeAravis.PIXEL_FORMAT_MONO_8, status="ERROR") - be, cam, s = make_backend(Settings(), [buf]) + be, _cam, s = make_backend(Settings(), [buf]) with pytest.raises(TimeoutError): be.read() @@ -350,7 +350,7 @@ def test_read_status_error_raises_and_pushes_back(): @pytest.mark.unit def test_close_is_idempotent(): - be, cam, s = make_backend(Settings(), []) + be, _cam, s = make_backend(Settings(), []) be.close() be.close() # should not raise diff --git a/tests/cameras/backends/test_basler_backend.py b/tests/cameras/backends/test_basler_backend.py index 18f49a1..48b8a29 100644 --- a/tests/cameras/backends/test_basler_backend.py +++ b/tests/cameras/backends/test_basler_backend.py @@ -3,6 +3,9 @@ import numpy as np import pytest +from dlclivegui.cameras.base import CapturedFrame +from dlclivegui.utils.timestamps import FrameTimestampMetadata + # --------------------------------------------------------------------- # Core lifecycle # --------------------------------------------------------------------- @@ -21,7 +24,8 @@ def test_basler_open_starts_grabbing_and_read_returns_frame(patch_basler_sdk, ba assert be._camera.IsGrabbing() assert be._converter is not None - frame, ts = be.read() + payload = be.read() + frame, ts = payload.frame, payload.software_timestamp assert isinstance(ts, float) assert isinstance(frame, np.ndarray) assert frame.shape == (10, 10, 3) @@ -257,7 +261,8 @@ def test_basler_default_trigger_is_off_and_free_runs( assert be._camera.TriggerMode.GetValue() == "Off" assert be.waits_for_hardware_trigger is False - frame, _ = be.read() + payload = be.read() + frame = payload.frame assert frame.shape == (10, 10, 3) be.close() @@ -356,7 +361,8 @@ def test_basler_follower_non_strict_invalid_source_disables_trigger( assert be._camera.TriggerMode.GetValue() == "Off" assert be.waits_for_hardware_trigger is False - frame, _ = be.read() + payload = be.read() + frame = payload.frame assert frame.shape == (10, 10, 3) be.close() @@ -430,7 +436,7 @@ def test_basler_software_trigger_requires_trigger_once_before_read( be.trigger_once() assert be._camera.software_trigger_calls == 1 - frame, _ = be.read() + frame = be.read().frame assert frame.shape == (10, 10, 3) be.close() @@ -463,3 +469,45 @@ def test_basler_close_turns_input_trigger_off( be.close() assert cam.TriggerMode.GetValue() == "Off" + + +class TestBaslerFrameTimestamps: + @pytest.mark.unit + def test_read_returns_captured_frame_with_hardware_timestamp_metadata( + self, + patch_basler_sdk, + basler_settings_factory, + ): + import dlclivegui.cameras.backends.basler_backend as bb + + settings = basler_settings_factory() + be = bb.BaslerCameraBackend(settings) + be.open() + + captured = be.read() + + assert isinstance(captured, CapturedFrame) + assert captured.frame is not None + assert isinstance(captured.software_timestamp, float) + + meta = captured.timestamp_metadata + assert isinstance(meta, FrameTimestampMetadata) + + assert meta.backend == "basler" + assert meta.source == "grab_result.GetTimeStamp" + assert meta.kind == "camera_clock" + assert meta.raw_unit == "ticks" + assert meta.raw_value == 123456789 + assert meta.tick_frequency_hz == pytest.approx(1_000_000_000.0) + assert meta.seconds == pytest.approx(0.123456789) + assert meta.default_reported == "seconds" + + source_dict = meta.to_source_dict() + assert source_dict["backend"] == "basler" + assert source_dict["source"] == "grab_result.GetTimeStamp" + + frame_dict = meta.to_frame_dict() + assert frame_dict["seconds"] == pytest.approx(0.123456789) + assert frame_dict["raw_value"] == 123456789 + + be.close() diff --git a/tests/cameras/backends/test_gentl_backend.py b/tests/cameras/backends/test_gentl_backend.py index 3ffdab2..3cb7d9e 100644 --- a/tests/cameras/backends/test_gentl_backend.py +++ b/tests/cameras/backends/test_gentl_backend.py @@ -54,12 +54,12 @@ def test_open_starts_stream_and_read_returns_frame(patch_gentl_sdk, gentl_settin assert be._acquirer is not None # Strict model validated via behavior: read must succeed after normal open() - frame, ts = be.read() - assert isinstance(ts, float) - assert isinstance(frame, np.ndarray) - assert frame.size > 0 + captured = be.read() + assert isinstance(captured.software_timestamp, float) + assert isinstance(captured.frame, np.ndarray) + assert captured.frame.size > 0 # Backend converts to BGR; ensure 3-channel output - assert frame.ndim == 3 and frame.shape[2] == 3 + assert captured.frame.ndim == 3 and captured.frame.shape[2] == 3 be.close() assert be._harvester is None @@ -422,7 +422,7 @@ def test_pixel_format_unavailable_does_not_crash_open_and_streams(patch_gentl_sd be.open() # No fake-internal checks; just verify it can read - frame, _ = be.read() + frame = be.read().frame assert frame is not None and frame.size > 0 be.close() diff --git a/tests/cameras/backends/test_gentl_trigger.py b/tests/cameras/backends/test_gentl_trigger.py index 57339a1..b445f4e 100644 --- a/tests/cameras/backends/test_gentl_trigger.py +++ b/tests/cameras/backends/test_gentl_trigger.py @@ -289,7 +289,7 @@ def test_trigger_timeout_is_capped_for_hardware_trigger_fetch_polling( assert be._timeout == pytest.approx(expected_fetch_timeout) # Fake acquisition is started, so read should pass and record the capped timeout. - frame, _ = be.read() + frame = be.read().frame assert frame is not None assert be._acquirer.fetch_calls[-1] == pytest.approx(expected_fetch_timeout) diff --git a/tests/cameras/backends/test_opencv_backend.py b/tests/cameras/backends/test_opencv_backend.py index 2f15578..5fff099 100644 --- a/tests/cameras/backends/test_opencv_backend.py +++ b/tests/cameras/backends/test_opencv_backend.py @@ -124,7 +124,8 @@ def test_read_returns_none_on_grab_failure(fake_capture_factory): cap.grab_ok = False backend._capture = cap - frame, ts = backend.read() + payload = backend.read() + frame, ts = payload.frame, payload.software_timestamp assert frame is None assert isinstance(ts, float) @@ -135,7 +136,8 @@ def test_read_returns_none_on_retrieve_failure(fake_capture_factory): cap.retrieve_ok = False backend._capture = cap - frame, ts = backend.read() + payload = backend.read() + frame, ts = payload.frame, payload.software_timestamp assert frame is None assert isinstance(ts, float) @@ -150,7 +152,8 @@ def boom(): cap.grab = boom backend._capture = cap - frame, ts = backend.read() + payload = backend.read() + frame, ts = payload.frame, payload.software_timestamp assert frame is None assert isinstance(ts, float) diff --git a/tests/conftest.py b/tests/conftest.py index 49cd1c6..f04941e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ from dlclivegui.cameras import CameraFactory from dlclivegui.cameras.base import ( CameraBackend, + CapturedFrame, SupportLevel, register_backend_direct, unregister_backend, @@ -86,7 +87,7 @@ def read(self): raise RuntimeError("not opened") self._counter += 1 frame = np.zeros(frame_shape, dtype=np.uint8) - return frame, float(timestamp_fn()) + return CapturedFrame(frame=frame, software_timestamp=float(timestamp_fn()), timestamp_metadata=None) _TestBackend.__name__ = f"TestBackend_{name}" return _TestBackend @@ -391,10 +392,10 @@ def start(self): def stop(self): self.stopped = True - def write(self, frame, timestamp=None): + def write(self, frame, timestamp=None, timestamp_metadata=None): if self.raise_on_write: raise RuntimeError("write failed") - self.write_calls.append((frame, timestamp)) + self.write_calls.append((frame, timestamp, timestamp_metadata)) return True def get_stats(self): @@ -418,7 +419,7 @@ def patch_video_recorder(monkeypatch): def recording_frame_spy(monkeypatch, window): captured = {} - def _fake_write_frame(cam_id, frame, timestamp=None): + def _fake_write_frame(cam_id, frame, timestamp=None, timestamp_metadata=None): captured[cam_id] = frame.copy() monkeypatch.setattr(window._rec_manager, "write_frame", _fake_write_frame) diff --git a/tests/custom_processors/test_base_processor.py b/tests/custom_processors/test_base_processor.py index d38749b..94dabab 100644 --- a/tests/custom_processors/test_base_processor.py +++ b/tests/custom_processors/test_base_processor.py @@ -13,15 +13,21 @@ def _mock_dlclive(monkeypatch): - """Provide a dummy dlclive.Processor so the module can import in tests.""" - fake = types.ModuleType("dlclive") - class Processor: def __init__(self, *args, **kwargs): pass - fake.Processor = Processor - monkeypatch.setitem(sys.modules, "dlclive", fake) + def process(self, pose, **kwargs): + return pose + + dlclive_mod = types.ModuleType("dlclive") + processor_mod = types.ModuleType("dlclive.processor") + + dlclive_mod.Processor = Processor + processor_mod.Processor = Processor + + monkeypatch.setitem(sys.modules, "dlclive", dlclive_mod) + monkeypatch.setitem(sys.modules, "dlclive.processor", processor_mod) @pytest.fixture @@ -37,6 +43,19 @@ def socket_mod(monkeypatch): return importlib.import_module(mod_name) +@pytest.fixture +def example_processor_mod(monkeypatch): + """ + Import the example processor module with dlclive mocked. + Adjust module name if your file lives elsewhere. + """ + _mock_dlclive(monkeypatch) + mod_name = "dlclivegui.processors.examples" + if mod_name in sys.modules: + del sys.modules[mod_name] + return importlib.import_module(mod_name) + + def _module_data_dir(socket_mod) -> Path: """Compute the data/ directory where save() writes artifacts.""" return Path(socket_mod.__file__).parent.parent.parent / "data" @@ -233,12 +252,14 @@ def test_save_ignores_pre_recording_original_pose_frames(socket_mod): ("ExampleProcessorSocketFilterKeypoints", 10), ], ) -def test_subclass_save_ignores_pre_recording_original_pose_frames(socket_mod, class_name, n_keypoints): +def test_subclass_save_ignores_pre_recording_original_pose_frames( + socket_mod, example_processor_mod, class_name, n_keypoints +): """ Concrete processors must keep original_pose aligned with recorded metadata even when process() is called before recording starts. """ - processor_class = getattr(socket_mod, class_name) + processor_class = getattr(example_processor_mod, class_name) proc = processor_class(bind=("127.0.0.1", 0), save_original=True) try: diff --git a/tests/gui/test_rec_manager.py b/tests/gui/test_rec_manager.py index f97c43a..cf4bca2 100644 --- a/tests/gui/test_rec_manager.py +++ b/tests/gui/test_rec_manager.py @@ -7,6 +7,7 @@ from dlclivegui.gui.recording_manager import RecordingManager from dlclivegui.services.multi_camera_controller import get_camera_id, get_display_id from dlclivegui.utils.stats import RecorderStats +from dlclivegui.utils.timestamps import FrameTimestampMetadata @pytest.fixture @@ -422,3 +423,41 @@ def test_start_all_passes_writegear_options( assert rec.writer_options["-crf"] == "23" assert rec.writer_options["-preset"] == "ultrafast" assert rec.writer_options["-tune"] == "zerolatency" + + +class TestRecordingManagerTimestampMetadata: + @pytest.mark.unit + def test_write_frame_passes_timestamp_metadata( + self, + recording_settings, + _active_cams_two, + current_frames, + patch_video_recorder, + patch_build_run_dir, + ): + mgr = RecordingManager() + mgr.start_all(recording_settings, _active_cams_two, current_frames, session_name="Sess") + + cam0_id = get_camera_id(_active_cams_two[0]) + frame = current_frames[cam0_id] + + meta = FrameTimestampMetadata( + source="grab_result.GetTimeStamp", + backend="basler", + default_reported="seconds", + seconds=0.001, + raw_value=1_000_000, + raw_unit="ticks", + tick_frequency_hz=1_000_000_000.0, + kind="camera_clock", + ) + + mgr.write_frame(cam0_id, frame, timestamp=123.0, timestamp_metadata=meta) + + rec = mgr.recorders[cam0_id] + assert len(rec.write_calls) == 1 + + written_frame, written_timestamp, written_metadata = rec.write_calls[0] + assert written_frame is frame + assert written_timestamp == 123.0 + assert written_metadata is meta diff --git a/tests/services/test_multicam_controller.py b/tests/services/test_multicam_controller.py index 747b5da..4eafbda 100644 --- a/tests/services/test_multicam_controller.py +++ b/tests/services/test_multicam_controller.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from dlclivegui.cameras.factory import CameraFactory @@ -9,6 +10,7 @@ get_camera_id, get_display_id, ) +from dlclivegui.utils.timestamps import FrameTimestampMetadata @pytest.mark.unit @@ -517,7 +519,7 @@ def test_recording_frame_ready_only_emits_when_enabled(qtbot, patch_factory): cam_id = get_camera_id(cam) seen: list[tuple[str, tuple, float]] = [] - def on_recording_frame(camera_id, frame, timestamp): + def on_recording_frame(camera_id, frame, timestamp, timestamp_metadata=None): seen.append((camera_id, frame.shape, timestamp)) mc.recording_frame_ready.connect(on_recording_frame) @@ -548,3 +550,50 @@ def on_recording_frame(camera_id, frame, timestamp): finally: with qtbot.waitSignal(mc.all_stopped, timeout=2000): mc.stop(wait=True) + + +class TestRecordingFrameTimestamps: + @pytest.mark.unit + def test_recording_frame_ready_forwards_timestamp_metadata(self, qtbot): + mc = MultiCameraController() + mc._running = True + mc._recording_frame_emission_enabled = True + + cam_id = "basler:0815-0000" + mc._settings[cam_id] = CameraSettings( + name="C", + backend="basler", + index=0, + enabled=True, + ).apply_defaults() + mc._camera_display_order = [cam_id] + mc._display_ids[cam_id] = "C" + + frame = np.zeros((10, 10), dtype=np.uint8) + meta = FrameTimestampMetadata( + source="grab_result.GetTimeStamp", + backend="basler", + default_reported="seconds", + seconds=0.001, + raw_value=1_000_000, + raw_unit="ticks", + tick_frequency_hz=1_000_000_000.0, + kind="camera_clock", + ) + + seen = [] + + def on_recording_frame(camera_id, emitted_frame, timestamp, timestamp_metadata): + seen.append((camera_id, emitted_frame, timestamp, timestamp_metadata)) + + mc.recording_frame_ready.connect(on_recording_frame) + + mc._on_frame_captured(cam_id, frame, 123.0, meta) + + assert len(seen) == 1 + + camera_id, emitted_frame, timestamp, timestamp_metadata = seen[0] + assert camera_id == cam_id + assert emitted_frame is frame + assert timestamp == 123.0 + assert timestamp_metadata is meta diff --git a/tests/services/test_video_recorder.py b/tests/services/test_video_recorder.py index efde6e2..8389fbb 100644 --- a/tests/services/test_video_recorder.py +++ b/tests/services/test_video_recorder.py @@ -9,6 +9,7 @@ import pytest import dlclivegui.services.video_recorder as vr_mod +from dlclivegui.utils.timestamps import FrameTimestampMetadata # ---------------------------- # Helpers @@ -228,10 +229,13 @@ def test_stop_writes_timestamps_sidecar_json(patch_writegear, output_path, rgb_f data = json.loads(ts_path.read_text()) assert data["video_file"] == output_path.name assert data["num_frames"] == 2 - assert data["timestamps"] == [10.0, 12.0] assert data["start_time"] == 10.0 assert data["end_time"] == 12.0 assert data["duration_seconds"] == 2.0 + assert data["schema_version"] == 2 + assert data["timestamp_sources"]["hardware_timestamp"] is None + assert data["frame_timestamps"][0]["software_timestamp"] == 10.0 + assert data["frame_timestamps"][1]["software_timestamp"] == 12.0 def test_encoder_write_error_sets_encode_error_and_future_writes_raise(patch_writegear, output_path, rgb_frame): @@ -418,3 +422,109 @@ def close(self): rec.stop() assert written[0].shape == (10, 20, 3) + + +class TestVideoRecorderTimestampSidecar: + def test_stop_writes_software_only_timestamp_sidecar_json( + self, + patch_writegear, + output_path, + rgb_frame, + ): + rec = vr_mod.VideoRecorder(output_path, buffer_size=10) + rec.start() + + rec.write(rgb_frame, timestamp=10.0) + rec.write(rgb_frame, timestamp=12.0) + + wait_until(lambda: len(FakeWriteGear.instances[0].frames) >= 2) + rec.stop() + + ts_path = output_path.with_suffix("").with_suffix(output_path.suffix + "_timestamps.json") + assert ts_path.exists() + + data = json.loads(ts_path.read_text()) + + assert data["schema_version"] == 2 + assert data["video_file"] == output_path.name + assert data["num_frames"] == 2 + + assert data["timestamp_sources"]["software_timestamp"]["kind"] == "software_wall_clock" + assert data["timestamp_sources"]["hardware_timestamp"] is None + + assert data["frame_timestamps"] == [ + { + "frame_index": 0, + "software_timestamp": 10.0, + }, + { + "frame_index": 1, + "software_timestamp": 12.0, + }, + ] + + def test_stop_writes_hardware_timestamp_metadata_sidecar_json( + self, + patch_writegear, + output_path, + rgb_frame, + ): + rec = vr_mod.VideoRecorder(output_path, buffer_size=10) + rec.start() + + meta = FrameTimestampMetadata( + source="grab_result.GetTimeStamp", + backend="basler", + default_reported="seconds", + seconds=0.001, + raw_value=1_000_000, + raw_unit="ticks", + tick_frequency_hz=1_000_000_000.0, + timebase="Basler camera timestamp counter", + kind="camera_clock", + ) + + rec.write(rgb_frame, timestamp=10.0, timestamp_metadata=meta) + + wait_until(lambda: len(FakeWriteGear.instances[0].frames) >= 1) + rec.stop() + + ts_path = output_path.with_suffix("").with_suffix(output_path.suffix + "_timestamps.json") + assert ts_path.exists() + + data = json.loads(ts_path.read_text()) + + assert data["schema_version"] == 2 + assert data["video_file"] == output_path.name + assert data["num_frames"] == 1 + + # Backward-compatible software timestamp list. + assert data["start_time"] == 10.0 + assert data["end_time"] == 10.0 + assert data["duration_seconds"] == 0.0 + + # Static hardware source metadata is written once. + hw_source = data["timestamp_sources"]["hardware_timestamp"] + assert hw_source == { + "source": "grab_result.GetTimeStamp", + "backend": "basler", + "default_reported": "seconds", + "raw_unit": "ticks", + "tick_frequency_hz": 1_000_000_000.0, + "timebase": "Basler camera timestamp counter", + "kind": "camera_clock", + "extra": {}, + } + + # Per-frame records contain only per-frame values. + frame_ts = data["frame_timestamps"] + assert len(frame_ts) == 1 + + rec0 = frame_ts[0] + assert rec0["frame_index"] == 0 + assert rec0["software_timestamp"] == 10.0 + assert rec0["hardware_timestamp"] == { + "seconds": 0.001, + "raw_value": 1_000_000, + } + assert rec0["hardware_timestamp_default"] == 0.001 diff --git a/tests/utils/test_timestamps.py b/tests/utils/test_timestamps.py new file mode 100644 index 0000000..5608729 --- /dev/null +++ b/tests/utils/test_timestamps.py @@ -0,0 +1,63 @@ +import pytest + +from dlclivegui.utils.timestamps import FrameTimestampMetadata + + +class TestFrameTimestampMetadata: + def test_splits_source_and_frame_values(self): + meta = FrameTimestampMetadata( + source="grab_result.GetTimeStamp", + backend="basler", + default_reported="seconds", + seconds=0.123456789, + wall_clock_time=None, + raw_value=123456789, + raw_unit="ticks", + tick_frequency_hz=1_000_000_000.0, + timebase="Basler camera timestamp counter", + kind="camera_clock", + ) + + assert meta.to_source_dict() == { + "source": "grab_result.GetTimeStamp", + "backend": "basler", + "default_reported": "seconds", + "raw_unit": "ticks", + "tick_frequency_hz": 1_000_000_000.0, + "timebase": "Basler camera timestamp counter", + "kind": "camera_clock", + "extra": {}, + } + + frame_dict = meta.to_frame_dict() + assert frame_dict["seconds"] == pytest.approx(0.123456789) + assert frame_dict["raw_value"] == 123456789 + assert "wall_clock_time" not in frame_dict + + assert meta.get_default_reported() == pytest.approx(0.123456789) + + def test_default_reported_raw_value(self): + meta = FrameTimestampMetadata( + source="device_counter", + backend="some_backend", + default_reported="raw_value", + raw_value=42, + raw_unit="frames", + kind="frame_counter", + ) + + assert meta.to_frame_dict() == {"raw_value": 42} + assert meta.get_default_reported() == 42 + + def test_unknown_default_field_returns_none(self): + meta = FrameTimestampMetadata( + source="device_counter", + backend="some_backend", + default_reported="seconds", + raw_value=42, + raw_unit="frames", + kind="frame_counter", + ) + + assert meta.to_frame_dict() == {"raw_value": 42} + assert meta.get_default_reported() is None