Source code for pysepal.solara.notifications.bus
"""Kernel-scoped notification bus: state management and registry."""
import logging
import threading
from dataclasses import replace
from typing import Dict, Optional
import solara
import solara.server.kernel_context
from .state import Toast, ToastType, TrackedTask
logger = logging.getLogger(name=__name__)

# Upper bound on queued toasts; add_toast trims the oldest entries beyond it.
MAX_TOAST_QUEUE: int = 20
# Upper bound on tracked tasks; add_task prunes the oldest finished tasks beyond it.
MAX_TASK_HISTORY: int = 50
# Toasts with identical message and type whose created_at values differ by
# less than this many seconds are merged by add_toast instead of appended.
DEDUP_WINDOW_SECONDS: float = 2.0
class NotificationBus:
    """Owns notification state for a single kernel/session.

    State lives in two reactive containers, ``toasts`` and ``tasks``.
    All mutations build a new list and assign it back (never mutate in
    place). Every read-modify-write cycle runs under an internal lock.
    """

    # ``TrackedTask.status.value`` values that mark a task as still active;
    # add_task never prunes tasks in these states.
    _ACTIVE_TASK_STATUSES = ("running", "pending")

    def __init__(self):
        """Initialize reactive state containers and thread lock."""
        self.toasts: solara.Reactive[list[Toast]] = solara.reactive([])
        self.tasks: solara.Reactive[list[TrackedTask]] = solara.reactive([])
        self._lock = threading.Lock()

    @staticmethod
    def _trim(items: list, limit: int, drop_order) -> list:
        """Return ``items`` with the oldest droppable entries removed.

        Args:
            items: Entries in insertion order (oldest first).
            limit: Target maximum length.
            drop_order: Predicates tried in turn. Every entry matching the
                first predicate is dropped (oldest first) before any entry
                matching the second, and so on. Entries matching no
                predicate are never dropped, so the result may still
                exceed ``limit``.

        Returns:
            ``items`` itself when already within ``limit``; otherwise a new
            list holding the survivors in their original relative order.
        """
        excess = len(items) - limit
        if excess <= 0:
            return items
        candidates = [
            index
            for matches in drop_order
            for index, item in enumerate(items)
            if matches(item)
        ]
        # dict.fromkeys de-duplicates while keeping priority order, in case
        # an entry matches more than one predicate.
        victims = set(list(dict.fromkeys(candidates))[:excess])
        return [item for index, item in enumerate(items) if index not in victims]

    def add_toast(self, toast: Toast) -> None:
        """Add a toast, applying error-replacement, dedup and queue-limit rules.

        * Error toasts replace previous errors (only the latest error is kept).
        * A non-error toast whose message and type match an existing toast
          created less than ``DEDUP_WINDOW_SECONDS`` earlier is merged into
          it. The existing entry takes the new id/created_at/timeout and its
          ``count`` is incremented, instead of a new entry being appended.
        * The queue is capped at ``MAX_TOAST_QUEUE``. The oldest non-error
          toasts are dropped first, then the oldest errors. Surviving toasts
          keep their insertion order.
        """
        with self._lock:
            current = list(self.toasts.value)

            if toast.type == ToastType.ERROR:
                # Error replacement: a new error removes all previous errors.
                current = [t for t in current if t.type != ToastType.ERROR]
            else:
                # Dedup: merge if identical message+type within window.
                for i, existing in enumerate(current):
                    if (
                        existing.message == toast.message
                        and existing.type == toast.type
                        and (toast.created_at - existing.created_at)
                        < DEDUP_WINDOW_SECONDS
                    ):
                        # Refresh the toast identity/timestamp so the frontend
                        # resets its dismiss timer and progress bar on repeated
                        # notifications instead of expiring relative to the
                        # first occurrence in the burst.
                        current[i] = replace(
                            existing,
                            id=toast.id,
                            created_at=toast.created_at,
                            timeout=toast.timeout,
                            count=existing.count + 1,
                        )
                        self.toasts.value = current
                        return

            current.append(toast)
            self.toasts.value = self._trim(
                current,
                MAX_TOAST_QUEUE,
                (
                    lambda t: t.type != ToastType.ERROR,
                    lambda t: t.type == ToastType.ERROR,
                ),
            )

    def remove_toast(self, toast_id: str) -> None:
        """Remove a toast by ID."""
        with self._lock:
            self.toasts.value = [t for t in self.toasts.value if t.id != toast_id]

    def add_task(self, task: TrackedTask) -> None:
        """Add a tracked task, pruning history beyond ``MAX_TASK_HISTORY``.

        Only finished tasks (status other than running/pending) are pruned,
        oldest first. Active tasks are always kept, so the list can exceed
        the limit when more than ``MAX_TASK_HISTORY`` tasks are active.
        Surviving tasks keep their insertion order.
        """
        active = self._ACTIVE_TASK_STATUSES
        with self._lock:
            current = [*self.tasks.value, task]
            self.tasks.value = self._trim(
                current,
                MAX_TASK_HISTORY,
                (lambda t: t.status.value not in active,),
            )

    def update_task(self, task_id: str, **changes) -> None:
        """Update a tracked task by ID. Unknown IDs are silently ignored.

        ``changes`` are passed to ``dataclasses.replace``, so an unknown
        field name raises ``TypeError`` when the matching task is rebuilt.
        """
        with self._lock:
            self.tasks.value = [
                replace(t, **changes) if t.id == task_id else t for t in self.tasks.value
            ]

    def remove_task(self, task_id: str) -> None:
        """Remove a tracked task by ID."""
        with self._lock:
            self.tasks.value = [t for t in self.tasks.value if t.id != task_id]
# --- Kernel-scoped bus registry (matches SessionManager pattern) ---
# Guards the two dictionaries below; the registry functions in this module
# hold it around every access to them.
_registry_lock = threading.Lock()
# kernel id -> the NotificationBus shared by every create_bus() caller in
# that kernel.
_buses: Dict[str, NotificationBus] = dict()
# kernel id -> number of create_bus() calls not yet matched by cleanup_bus().
_bus_refcounts: Dict[str, int] = dict()
def _get_kernel_id() -> str:
    """Return a string key for the current Solara kernel.

    The key is ``str(id(...))`` of the kernel object held by the current
    kernel context (same approach as SessionManager). Any exception from
    ``get_current_context()`` propagates to the caller.

    NOTE(review): ``id()`` values may be reused once an object is garbage
    collected, so a registry entry left behind by a kernel that never
    called cleanup_bus() could be matched by a later kernel. Confirm that
    kernels always clean up.
    """
    context = solara.server.kernel_context.get_current_context()
    kernel_object = context.kernel
    return str(id(kernel_object))
def get_current_bus() -> Optional[NotificationBus]:
    """Return the NotificationBus registered for the current kernel.

    Returns:
        The registered bus, or ``None`` when no bus is registered for the
        current kernel or when the kernel id cannot be determined (any
        exception from ``_get_kernel_id()`` is swallowed here).
    """
    try:
        key = _get_kernel_id()
    except Exception:
        # Best-effort lookup: presumably there is no active kernel context
        # (e.g. a plain background thread), so report "no bus".
        return None
    with _registry_lock:
        bus = _buses.get(key)
    return bus
def create_bus() -> NotificationBus:
    """Return the current kernel's NotificationBus, creating it if needed.

    The first call for a kernel registers a new bus with refcount 1. Later
    calls reuse the registered bus and bump its refcount, so remounts or
    double-mounts do not invalidate active notifiers. Each call should be
    balanced by one ``cleanup_bus()``.
    """
    kernel_id = _get_kernel_id()
    with _registry_lock:
        bus = _buses.get(kernel_id)
        if bus is None:
            bus = NotificationBus()
            _buses[kernel_id] = bus
            _bus_refcounts[kernel_id] = 1
            logger.debug(f"Created NotificationBus for kernel {kernel_id}")
        else:
            # A missing refcount for a registered bus is treated as 1.
            refcount = _bus_refcounts.get(kernel_id, 1) + 1
            _bus_refcounts[kernel_id] = refcount
            logger.debug(
                f"Reusing NotificationBus for kernel {kernel_id} "
                f"(refcount={refcount})"
            )
    return bus
def cleanup_bus() -> None:
    """Release one reference to the current kernel's NotificationBus.

    Decrements the refcount taken by ``create_bus()``. When the count would
    drop to zero (or no refcount is recorded), the bus and its refcount
    entry are removed from the registry.

    Unlike ``get_current_bus()``, this does not catch exceptions from
    ``_get_kernel_id()``; any such exception propagates.
    """
    kernel_id = _get_kernel_id()
    with _registry_lock:
        count = _bus_refcounts.get(kernel_id, 0)
        if count > 1:
            _bus_refcounts[kernel_id] = count - 1
            logger.debug(
                f"Decremented NotificationBus refcount for kernel {kernel_id} "
                f"(refcount={_bus_refcounts[kernel_id]})"
            )
            return
        removed = _buses.pop(kernel_id, None)
        _bus_refcounts.pop(kernel_id, None)
        if removed is not None:
            logger.debug(f"Cleaned up NotificationBus for kernel {kernel_id}")
        else:
            # Unbalanced cleanup (called twice, or without create_bus()):
            # nothing was removed, so do not log a cleanup that never happened.
            logger.debug(
                f"No NotificationBus registered for kernel {kernel_id}; "
                f"nothing to clean up"
            )