# Source code for cuperiod.methods.base

"""The method interface and registry.

Every periodogram is a :class:`PeriodogramMethod` subclass registered under an
uppercase name (``"GLS"``, ``"BLS"``, ...). The single-run, batch, and CLI layers only
ever go through this interface — ``default_grid`` to pick a trial grid, ``power`` to
compute a :class:`~cuperiod.core.result.Periodogram`, and ``make_engine`` to build a
reusable plan/kernel for batch throughput — so adding a method requires no changes to
the orchestration.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from types import MappingProxyType
from typing import ClassVar

from pydantic_settings import BaseSettings

from cuperiod.core.backend import available_backends, cuda_available
from cuperiod.core.columns import Domain
from cuperiod.core.errors import BackendUnavailableError, UnknownMethodError
from cuperiod.core.grid import GridSpec
from cuperiod.core.lightcurve import LightCurve, MultiBandLightCurve
from cuperiod.core.result import ObjectiveSense, Periodogram
from cuperiod.core.result import ObjectiveSense, Periodogram


class PeriodogramMethod(ABC):
    """Abstract base for a periodogram method.

    Subclasses set the class attributes and implement :meth:`default_grid` and
    :meth:`power`. Multi-band-capable methods also set ``supports_multiband`` and
    implement :meth:`multiband_power`.
    """

    #: Registry name (canonical uppercase, e.g. ``"GLS"``).
    name: ClassVar[str]
    #: Whether peaks are maxima or minima of the statistic.
    objective_sense: ClassVar[ObjectiveSense]
    #: Whether the method has a multi-band variant.
    supports_multiband: ClassVar[bool] = False
    #: The domain inputs are converted to before computing. ``None`` means the method
    #: is domain-agnostic (magnitude or flux used as supplied); ``Domain.FLUX`` forces
    #: a flux conversion (box/transit methods, where a dip is the signal).
    natural_domain: ClassVar[Domain | None] = None
    #: The settings model class for this method.
    settings_cls: ClassVar[type[BaseSettings]]
    #: Best CPU backend name.
    cpu_backend: ClassVar[str]
    #: GPU backend name, or ``None`` if the method has no GPU path yet.
    gpu_backend: ClassVar[str | None] = None
    #: Every backend this method can run.
    all_backends: ClassVar[tuple[str, ...]]
    #: Backends shipped in optional extras. :meth:`resolve_backend` honors a concrete
    #: request for one of these (or picks it as the GPU backend under ``"auto"`` /
    #: ``"gpu"``) only if ``available_backends()`` reports it; any backend not listed
    #: here is assumed importable. Subclasses adding an optional backend extend this.
    optional_backends: ClassVar[frozenset[str]] = frozenset(
        {"finufft", "cufinufft", "numba", "astropy"}
    )

    # -- settings ----------------------------------------------------------------
    def coerce_settings(self, settings: BaseSettings | None) -> BaseSettings:
        """Return a settings instance of this method's type (default if ``None``).

        Raises
        ------
        TypeError
            If ``settings`` is not an instance of :attr:`settings_cls`.
        """
        if settings is None:
            return self.settings_cls()
        if not isinstance(settings, self.settings_cls):
            raise TypeError(
                f"{self.name} expects {self.settings_cls.__name__}, "
                f"got {type(settings).__name__}"
            )
        return settings

    # -- backend resolution ------------------------------------------------------
    def resolve_backend(self, requested: str) -> str:
        """Resolve ``auto``/``cpu``/``gpu``/concrete to a runnable backend name.

        Parameters
        ----------
        requested : str
            ``"auto"`` (GPU when usable, else CPU), ``"cpu"``, ``"gpu"``, or a
            concrete backend name belonging to this method.

        Returns
        -------
        str
            A concrete backend name. ``"cpu"`` returns :attr:`cpu_backend` as-is;
            every other path returns only a backend that passes the CUDA check (for
            the GPU backend) and the installation check (for optional backends).

        Raises
        ------
        BackendUnavailableError
            If GPU was requested but is unusable (no GPU backend, no CUDA device, or
            the GPU backend is an optional package that is not installed), or a named
            backend is unknown to this method or not installed here.
        """
        if requested == "auto":
            gpu = self._usable_gpu_backend()
            # Fall back to CPU whenever the GPU path cannot actually run here.
            return self.cpu_backend if gpu is None else gpu
        if requested == "cpu":
            return self.cpu_backend
        if requested == "gpu":
            gpu = self._usable_gpu_backend()
            if gpu is None:
                raise BackendUnavailableError(
                    f"{self.name}: GPU backend unavailable (need the [gpu] extra and a "
                    "CUDA device)"
                )
            return gpu
        if requested not in self.all_backends:
            raise BackendUnavailableError(
                f"{self.name}: unknown backend {requested!r}; "
                f"choose from {self.all_backends}"
            )
        if requested == self.gpu_backend and not cuda_available():
            raise BackendUnavailableError(
                f"{self.name}: backend {requested!r} needs a CUDA device"
            )
        if not self._is_installed(requested):
            raise BackendUnavailableError(
                f"{self.name}: backend {requested!r} is not installed"
            )
        return requested

    def _is_installed(self, backend: str) -> bool:
        """Whether ``backend`` is importable here (non-optional backends always are)."""
        if backend not in self.optional_backends:
            return True
        return backend in available_backends()

    def _usable_gpu_backend(self) -> str | None:
        """This method's GPU backend if it exists, CUDA is present and it is installed."""
        gpu = self.gpu_backend
        if gpu is None or not cuda_available() or not self._is_installed(gpu):
            return None
        return gpu

    def is_gpu_backend(self, backend: str) -> bool:
        """Whether ``backend`` is this method's GPU backend."""
        return backend == self.gpu_backend

    # -- compute -----------------------------------------------------------------
    @abstractmethod
    def default_grid(self, lc: LightCurve, settings: BaseSettings) -> GridSpec:
        """Build the default trial grid for ``lc`` under ``settings``."""

    @abstractmethod
    def power(
        self,
        grid: GridSpec,
        lc: LightCurve,
        settings: BaseSettings,
        backend: str,
        engine: object | None = None,
    ) -> Periodogram:
        """Compute the periodogram on ``grid`` using ``backend``."""

    def multiband_power(
        self,
        grid: GridSpec,
        mblc: MultiBandLightCurve,
        settings: BaseSettings,
        backend: str,
    ) -> Periodogram:
        """Compute a multi-band periodogram. Override in multi-band methods.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError(f"{self.name} does not support multi-band input")

    def make_engine(self, backend: str, settings: BaseSettings) -> object | None:
        """Build a reusable plan/kernel engine for batch runs (``None`` for CPU).

        The base implementation always returns ``None``; GPU-capable subclasses
        override it.
        """
        return None

    def estimate_device_bytes(self, n_points: int) -> int:
        """Estimate one worker's GPU footprint (bytes) for an ``n_points`` grid.

        Heuristic: a fixed 64 MiB overhead plus ``16 * 6`` bytes per grid point.
        NOTE(review): the per-point factor presumably covers ~6 work arrays of
        16-byte elements — confirm against the GPU kernels; subclasses may override.
        """
        return 64 * 1024**2 + n_points * 16 * 6


# --- registry ----------------------------------------------------------------

#: Normalized method name -> registered method instance (filled by ``register``).
_REGISTRY: dict[str, PeriodogramMethod] = dict()


@dataclass(frozen=True)
class MethodInfo:
    """Immutable introspection summary of one registered method."""

    # Registry key under which the method is stored.
    name: str
    # Whether peaks are maxima or minima of the method's statistic.
    objective_sense: str
    # Whether the method offers a multi-band variant.
    supports_multiband: bool
    # Input domain the method converts to, or ``"any"`` when domain-agnostic.
    natural_domain: str
    # The subset of ``all_backends`` usable in the current environment.
    available_backends: tuple[str, ...]
    # Every backend the method supports.
    all_backends: tuple[str, ...]
def _normalize_name(name: str) -> str:
    """Map a method name to its registry key.

    Every non-alphanumeric character is dropped and the remainder upper-cased, so
    ``"String-Length"``, ``"StringLength"`` and ``"STRINGLENGTH"`` (and ``"gls"`` /
    ``"GLS"``) all share one key.
    """
    kept = filter(str.isalnum, name)
    return "".join(kept).upper()


def register(method: PeriodogramMethod) -> PeriodogramMethod:
    """Store ``method`` in the registry under its normalized name.

    ``method`` is returned unchanged so the call can be used as a decorator or
    chained. A later registration whose name normalizes to the same key replaces
    the earlier one.
    """
    key = _normalize_name(method.name)
    _REGISTRY[key] = method
    return method
def get_method(name: str) -> PeriodogramMethod:
    """Return the method registered under ``name``.

    The name is normalized before lookup — case and every non-alphanumeric
    character are ignored — so ``"String-Length"``, ``"StringLength"`` and
    ``"STRINGLENGTH"`` all find the same method.

    Raises
    ------
    UnknownMethodError
        If nothing is registered under the normalized ``name``.
    """
    try:
        key = _normalize_name(name)
        method = _REGISTRY[key]
    except KeyError:
        known = sorted(_REGISTRY)
        raise UnknownMethodError(
            f"unknown method {name!r}; registered: {known}"
        ) from None
    return method
def method_names() -> list[str]:
    """Return every registered method name (its normalized key), sorted."""
    names = list(_REGISTRY)
    names.sort()
    return names
def list_methods() -> list[MethodInfo]:
    """Build an introspection record for every registered method.

    Records are ordered by registry key. A backend is listed as available when
    ``available_backends()`` reports it; ``"numpy"`` is always treated as
    available.

    Returns
    -------
    list[MethodInfo]
        One record per registered method. ``natural_domain`` is ``"any"`` for
        domain-agnostic methods (class attribute ``natural_domain is None``).
    """
    available = available_backends()
    return [
        MethodInfo(
            name=key,
            objective_sense=m.objective_sense,
            supports_multiband=m.supports_multiband,
            # ``None`` is the documented "domain-agnostic" marker; test it explicitly
            # so a falsy Domain member is not misreported as "any".
            natural_domain=(
                "any" if m.natural_domain is None else str(m.natural_domain)
            ),
            available_backends=tuple(
                b for b in m.all_backends if b in available or b == "numpy"
            ),
            all_backends=m.all_backends,
        )
        for key, m in sorted(_REGISTRY.items())
    ]
def registered_methods() -> Mapping[str, PeriodogramMethod]:
    """Return a live, read-only view of the registry.

    The mapping is keyed by normalized method name and reflects later
    :func:`register` calls. Item assignment or deletion on it raises
    ``TypeError``; use :func:`register` to add methods, or ``dict(...)`` on the
    result for an independent, mutable snapshot.
    """
    return MappingProxyType(_REGISTRY)


__all__ = [
    "MethodInfo",
    "PeriodogramMethod",
    "get_method",
    "list_methods",
    "method_names",
    "register",
    "registered_methods",
]