Source code for pysepal.solara.notifications.bus

"""Kernel-scoped notification bus: state management and registry."""

import logging
import threading
from dataclasses import replace
from typing import Dict, Optional

import solara
import solara.server.kernel_context

from .state import Toast, ToastType, TrackedTask

logger: logging.Logger = logging.getLogger(__name__)

# Maximum number of toasts a single bus keeps queued at once.
MAX_TOAST_QUEUE: int = 20
# Maximum number of tracked tasks a bus keeps; running/pending tasks are
# exempt from pruning, so only finished tasks are dropped to honor it.
MAX_TASK_HISTORY: int = 50
# Identical (same message and type) non-error toasts created less than this
# many seconds apart are merged into one entry instead of being appended.
DEDUP_WINDOW_SECONDS: float = 2.0


class NotificationBus:
    """Own the notification state for a single kernel/session.

    State lives in two Solara reactives, ``toasts`` and ``tasks``. Every
    mutation builds a new list and assigns it to ``.value``; lists are never
    mutated in place. All mutations are serialized by an internal
    ``threading.Lock``.
    """

    def __init__(self):
        """Initialize reactive state containers and thread lock."""
        self.toasts: solara.Reactive[list[Toast]] = solara.reactive([])
        self.tasks: solara.Reactive[list[TrackedTask]] = solara.reactive([])
        self._lock = threading.Lock()

    @staticmethod
    def _trim_oldest(items: list, limit: int, is_protected, *, trim_protected: bool) -> list:
        """Drop the oldest entries until at most ``limit`` remain, keeping order.

        Unprotected entries are dropped first, oldest first. Protected entries
        are dropped (oldest first) only when ``trim_protected`` is true and
        removing every unprotected entry was not enough. Survivors keep their
        original relative order.

        Args:
            items: Entries ordered oldest -> newest.
            limit: Maximum number of entries to keep.
            is_protected: Predicate marking entries that outlive the others.
            trim_protected: Whether protected entries may be dropped at all.
                When false, the result can exceed ``limit``.

        Returns:
            ``items`` itself when it is already within ``limit``, otherwise a
            new list.
        """
        excess = len(items) - limit
        if excess <= 0:
            return items
        unprotected_count = sum(1 for item in items if not is_protected(item))
        drop_unprotected = min(excess, unprotected_count)
        drop_protected = (excess - drop_unprotected) if trim_protected else 0
        kept = []
        for item in items:
            if is_protected(item):
                if drop_protected > 0:
                    drop_protected -= 1
                    continue
            elif drop_unprotected > 0:
                drop_unprotected -= 1
                continue
            kept.append(item)
        return kept

    def add_toast(self, toast: Toast) -> None:
        """Add a toast, applying error-replacement, dedup and queue-limit rules.

        * A new ERROR toast removes every previous error toast, so only the
          latest error is kept.
        * A non-error toast whose message and type match an existing toast
          created less than ``DEDUP_WINDOW_SECONDS`` earlier is merged into
          that entry. The entry takes the new ``id``, ``created_at`` and
          ``timeout``, and its ``count`` is incremented.
        * Otherwise the toast is appended. If the queue then exceeds
          ``MAX_TOAST_QUEUE``, the oldest non-errors are dropped first, then
          the oldest errors. Surviving toasts keep chronological order.
        """
        with self._lock:
            current = list(self.toasts.value)
            if toast.type == ToastType.ERROR:
                # Error replacement: new errors remove all previous errors.
                current = [t for t in current if t.type != ToastType.ERROR]
            else:
                # Dedup: merge if identical message+type within window.
                for i, existing in enumerate(current):
                    if (
                        existing.message == toast.message
                        and existing.type == toast.type
                        and (toast.created_at - existing.created_at) < DEDUP_WINDOW_SECONDS
                    ):
                        # Refresh the toast identity/timestamp so the frontend
                        # resets its dismiss timer and progress bar on repeated
                        # notifications instead of expiring relative to the
                        # first occurrence in the burst.
                        current[i] = replace(
                            existing,
                            id=toast.id,
                            created_at=toast.created_at,
                            timeout=toast.timeout,
                            count=existing.count + 1,
                        )
                        self.toasts.value = current
                        return
            current.append(toast)
            # Enforce the queue limit without reordering: errors are protected
            # and only dropped once no non-error is left to drop.
            self.toasts.value = self._trim_oldest(
                current,
                MAX_TOAST_QUEUE,
                lambda t: t.type == ToastType.ERROR,
                trim_protected=True,
            )

    def remove_toast(self, toast_id: str) -> None:
        """Remove a toast by ID. Unknown IDs are a no-op."""
        with self._lock:
            self.toasts.value = [t for t in self.toasts.value if t.id != toast_id]

    def add_task(self, task: TrackedTask) -> None:
        """Add a tracked task, pruning the oldest finished tasks beyond MAX_TASK_HISTORY.

        Tasks whose status value is ``"running"`` or ``"pending"`` are never
        pruned. The list can therefore exceed ``MAX_TASK_HISTORY`` when that
        many tasks are active; in that case every finished task is dropped.
        Surviving tasks keep their insertion order.
        """
        with self._lock:
            current = [*self.tasks.value, task]
            # Keep only as many finished tasks as the limit leaves room for;
            # when active tasks alone reach the limit, that number is zero.
            self.tasks.value = self._trim_oldest(
                current,
                MAX_TASK_HISTORY,
                lambda t: t.status.value in ("running", "pending"),
                trim_protected=False,
            )

    def update_task(self, task_id: str, **changes) -> None:
        """Update a tracked task by ID. Unknown IDs are silently ignored.

        ``changes`` are passed to ``dataclasses.replace``, so the task is
        replaced by a modified copy rather than mutated.
        """
        with self._lock:
            self.tasks.value = [
                replace(t, **changes) if t.id == task_id else t for t in self.tasks.value
            ]

    def remove_task(self, task_id: str) -> None:
        """Remove a tracked task by ID. Unknown IDs are a no-op."""
        with self._lock:
            self.tasks.value = [t for t in self.tasks.value if t.id != task_id]
# --- Kernel-scoped bus registry (matches SessionManager pattern) ---

# Kernel id -> NotificationBus owned by that kernel.
_buses: Dict[str, NotificationBus] = {}
# Kernel id -> number of create_bus() calls not yet balanced by cleanup_bus().
_bus_refcounts: Dict[str, int] = {}
# Serializes every read/write of the two dicts above.
_registry_lock = threading.Lock()


def _get_kernel_id() -> str:
    """Return a string key identifying the current Solara kernel.

    The key is ``str(id(kernel))`` for the kernel of the current kernel
    context (same approach as SessionManager). Nothing is caught here, so
    whatever ``get_current_context()`` raises propagates to the caller.

    NOTE(review): ``id()`` values may be reused after an object is garbage
    collected. A bus that is never cleaned up could therefore be picked up by
    a later kernel at the same address. Confirm cleanup always runs.
    """
    context = solara.server.kernel_context.get_current_context()
    kernel = context.kernel
    return str(id(kernel))
def get_current_bus() -> Optional[NotificationBus]:
    """Return the NotificationBus registered for the current kernel, or None.

    None is returned in two cases:

    * no bus has been created for this kernel;
    * the kernel id cannot be determined (presumably when called outside a
      Solara kernel context).

    The second case is logged at DEBUG level, with the traceback, so
    unexpected failures are not completely invisible.
    """
    try:
        kernel_id = _get_kernel_id()
    except Exception:
        # Deliberately best-effort: report "no bus" instead of propagating.
        logger.debug(
            "Could not determine current kernel id; no NotificationBus available",
            exc_info=True,
        )
        return None
    with _registry_lock:
        return _buses.get(kernel_id)
def create_bus() -> NotificationBus:
    """Return the current kernel's NotificationBus, creating it on first use.

    Every call takes one reference, to be released with ``cleanup_bus()``.
    A new bus starts at refcount 1; reusing an existing bus increments the
    count. Reuse means remounts or double mounts keep handing out the same
    bus, so notifiers that already hold it stay valid.
    """
    kernel_id = _get_kernel_id()
    with _registry_lock:
        bus = _buses.get(kernel_id)
        if bus is None:
            bus = NotificationBus()
            _buses[kernel_id] = bus
            _bus_refcounts[kernel_id] = 1
            logger.debug(f"Created NotificationBus for kernel {kernel_id}")
        else:
            # A missing refcount for an existing bus is treated as 1.
            refcount = _bus_refcounts.get(kernel_id, 1) + 1
            _bus_refcounts[kernel_id] = refcount
            logger.debug(
                f"Reusing NotificationBus for kernel {kernel_id} "
                f"(refcount={refcount})"
            )
        return bus
def cleanup_bus() -> None:
    """Release one reference to the current kernel's NotificationBus.

    While other references remain, only the refcount is decremented. When
    this was the last reference, or no refcount is recorded at all, the bus
    and its refcount entry are removed from the registry.
    """
    kernel_id = _get_kernel_id()
    with _registry_lock:
        remaining = _bus_refcounts.get(kernel_id, 0) - 1
        if remaining > 0:
            _bus_refcounts[kernel_id] = remaining
            logger.debug(
                f"Decremented NotificationBus refcount for kernel {kernel_id} "
                f"(refcount={remaining})"
            )
            return
        _buses.pop(kernel_id, None)
        _bus_refcounts.pop(kernel_id, None)
        logger.debug(f"Cleaned up NotificationBus for kernel {kernel_id}")