# Source code for qilisdk.core.interpolator

# Copyright 2025 Qilimanjaro Quantum Tech
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import inspect
from bisect import bisect_right
from collections.abc import Callable
from copy import copy
from enum import Enum
from typing import TYPE_CHECKING, Any, Mapping

import numpy as np
from typing_extensions import TypeAlias

from qilisdk.core.parameterizable import Parameterizable
from qilisdk.core.variables import LEQ, BaseVariable, Parameter, Term
from qilisdk.settings import get_settings
from qilisdk.yaml import yaml

if TYPE_CHECKING:
    from qilisdk.core.types import Number

# Label of the parameter that represents the schedule time inside coefficient expressions.
_TIME_PARAMETER_NAME = "t"

# Type aliases used throughout this module to keep signatures short.
# A time point or coefficient: a plain number, a Parameter, or a Term built from Parameters.
PARAMETERIZED_NUMBER: TypeAlias = float | Parameter | Term
# Mapping from a time point, or a ``(start, end)`` interval, to a coefficient or a callable producing one.
TimeDict: TypeAlias = dict[
    PARAMETERIZED_NUMBER | tuple[float, float],
    PARAMETERIZED_NUMBER | Callable[..., PARAMETERIZED_NUMBER],
]
class Interpolation(str, Enum):
    """Strategy used to obtain coefficients between the defined schedule points.

    Members are ``str`` instances whose value is a human-readable description of the strategy.
    """

    # Hold the coefficient of the closest preceding time point.
    STEP = "Step function interpolation between schedule points"
    # Interpolate linearly between the neighbouring time points.
    LINEAR = "linear interpolation between schedule points"
def _process_callable(
    function: Callable[..., PARAMETERIZED_NUMBER], current_time: Parameter, **kwargs: Any
) -> tuple[PARAMETERIZED_NUMBER, dict[str, Parameter]]:
    """
    Evaluate a coefficient-producing callable and collect any parameters it exposes.

    Only arguments annotated with ``Parameter`` become parameters. A ``Parameter`` default is copied. A
    ``Parameter`` passed through ``kwargs`` is reused. A numeric value (``0`` when the argument is absent from
    ``kwargs``) is wrapped in a new ``Parameter`` named after the argument. An argument named ``t`` is always
    bound to ``current_time`` and is never registered as a parameter.

    Args:
        function (Callable[..., PARAMETERIZED_NUMBER]): Callable that returns a coefficient expression.
        current_time (Parameter): Time parameter to bind when evaluating the callable.
        **kwargs: Additional keyword arguments passed to the callable.

    Returns:
        tuple[PARAMETERIZED_NUMBER, dict[str, Parameter]]: Evaluated expression and parameters discovered.

    Raises:
        ValueError: If the callable uses variables other than time or ``Parameter`` instances.
    """
    parameters: dict[str, Parameter] = {}
    signature_params = inspect.signature(function).parameters

    for param_name, param_info in signature_params.items():
        # The time argument is bound to ``current_time`` below. Registering it as a parameter as well would
        # expose a spurious tunable "t" (``Interpolator._extract_parameters`` skips it for the same reason).
        if param_name == _TIME_PARAMETER_NAME:
            continue
        if param_info.annotation is not Parameter:
            continue
        if param_info.default is not inspect.Parameter.empty:
            parameters[param_info.default.label] = copy(param_info.default)
            continue
        value = kwargs.get(param_name, 0)
        if isinstance(value, (float, int)):
            parameters[param_name] = Parameter(param_name, value)
            # Without a default and without a kwarg, the call below would fail, so pass the new Parameter in.
            kwargs[param_name] = parameters[param_name]
        elif isinstance(value, Parameter):
            parameters[value.label] = value

    if _TIME_PARAMETER_NAME in signature_params:
        kwargs[_TIME_PARAMETER_NAME] = current_time

    term = function(**kwargs)

    error_message = "function contains variables that are not time. Only Parameters are allowed."
    if isinstance(term, Term) and not all(
        isinstance(v, Parameter) or v.label == _TIME_PARAMETER_NAME for v in term.variables()
    ):
        raise ValueError(error_message)
    if isinstance(term, BaseVariable) and not (isinstance(term, Parameter) or term.label == _TIME_PARAMETER_NAME):
        raise ValueError(error_message)
    return term, parameters


@yaml.register_class
class Interpolator(Parameterizable):
    """Mapping of time points to coefficients with optional interpolation."""

    def __init__(
        self,
        time_dict: TimeDict,
        interpolation: Interpolation = Interpolation.LINEAR,
        nsamples: int = 100,
    ) -> None:
        """Initialize an interpolator over discrete points or intervals.

        Args:
            time_dict (TimeDict): Mapping from time points or ``(start, end)`` intervals to coefficients or
                callables. An interval stores the same coefficient at both of its endpoints.
            interpolation (Interpolation): Interpolation rule between provided points (``LINEAR`` or ``STEP``).
            nsamples (int): Number of samples used to expand interval definitions. NOTE: not used by this
                implementation; intervals are stored only through their two endpoints.

        Raises:
            ValueError: If a time interval contains a number of points different than 2, if a time point is
                defined twice, or if a point intersects an interval.
        """
        super().__init__()
        self._interpolation = interpolation
        # Coefficients keyed by *unscaled* time (see ``add_time_point``).
        self._time_dict: dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER] = {}
        # Shared time parameter bound into callables and substituted when evaluating Terms.
        self._current_time = Parameter(_TIME_PARAMETER_NAME, 0)
        self._total_time: float | None = None
        self.iter_time_step = 0
        self._cached = False
        self._cached_time: dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER | Number] = {}
        self._tlist: list[PARAMETERIZED_NUMBER] | None = None
        self._fixed_tlist: list[float] | None = None
        self._max_time: PARAMETERIZED_NUMBER | None = None
        self._time_scale_cache: float | None = None

        # Validate everything up front so a malformed key fails before any point is stored.
        self._validate_time_keys(time_dict)

        for time, coefficient in time_dict.items():
            if isinstance(time, tuple):
                # Both endpoints of an interval carry the same coefficient.
                self.add_time_point(time[0], coefficient)  # ty:ignore[invalid-argument-type]
                self.add_time_point(time[1], coefficient)  # ty:ignore[invalid-argument-type]
            else:
                self.add_time_point(time, coefficient)
        self._tlist = self._generate_tlist()

        self._add_time_ordering_constraints(time_dict)

    def _validate_time_keys(self, time_dict: TimeDict) -> None:
        """
        Check that intervals have two endpoints and that no point or interval overlaps another.

        Args:
            time_dict (TimeDict): Mapping whose keys are validated.

        Raises:
            ValueError: If an interval does not have exactly two points, a time point is defined twice, or a
                point intersects an interval.
        """
        # (start value, end value, original key) for every point/interval; a point has start == end.
        spans: list[tuple[float, float, PARAMETERIZED_NUMBER | tuple[float, float]]] = []
        for key in time_dict:
            if isinstance(key, tuple):
                if len(key) != 2:  # noqa: PLR2004
                    raise ValueError(
                        f"time intervals need to be defined by two points, but this interval was provided: {key}"
                    )
                # Use the actual bounds so reversed intervals such as (5, 2) are checked correctly.
                start, end = sorted(self._get_value(k) for k in key)  # ty:ignore[invalid-argument-type]
            else:
                start = end = self._get_value(key)
            spans.append((start, end, key))

        spans.sort(key=lambda span: span[0])
        atol = get_settings().atol
        for (_, prev_end, prev_key), (next_start, _, next_key) in zip(spans, spans[1:]):
            if abs(prev_end - next_start) < atol:
                raise ValueError(f"The time point {prev_end} is defined twice.")
            if prev_end > next_start:
                raise ValueError(
                    f"Can't provide a point that intersects an interval (issue in: {prev_key} and {next_key})."
                )

    def _add_time_ordering_constraints(self, time_dict: TimeDict) -> None:
        """
        Register ``LEQ`` constraints that keep parameterized time points ordered relative to their neighbours.

        Args:
            time_dict (TimeDict): Mapping whose keys (and interval endpoints) are ordered.
        """
        ordered_points = sorted(
            (point for key in time_dict for point in (key if isinstance(key, tuple) else (key,))),
            key=self._get_value,
        )  # ty:ignore[no-matching-overload]
        last_index = len(ordered_points) - 1
        for index, point in enumerate(ordered_points):
            # Plain numbers are fixed; only Parameter/Term time points get ordering constraints.
            if not isinstance(point, (Parameter, Term)):
                continue
            neighbour_constraints = []
            if index > 0:
                neighbour_constraints.append(LEQ(ordered_points[index - 1], point))
            if index < last_index:
                neighbour_constraints.append(LEQ(point, ordered_points[index + 1]))
            for constraint in neighbour_constraints:
                if constraint not in self._parameter_constraints:
                    self._parameter_constraints.append(constraint)

    def _generate_tlist(self) -> list[PARAMETERIZED_NUMBER]:
        """
        Generate a sorted list of the registered time keys.

        Returns:
            list[PARAMETERIZED_NUMBER]: Sorted time indices based on their evaluated value.
        """
        return sorted(self._time_dict.keys(), key=self._get_value)  # ty:ignore[invalid-return-type]

    @property
    def _time_scale(self) -> float:
        """
        Return the scaling factor applied to time points when ``max_time`` is set.

        This handles the caching, so self._time_scale_cache should never be accessed directly.

        Returns:
            float: Scaling factor (``1.0`` when no max time is set).
        """
        if self._time_scale_cache is None:
            if self._max_time is not None:
                if self._tlist is None:
                    self._tlist = self._generate_tlist()
                # Largest stored time, floored at 0.0.
                max_t = 0.0
                for t in self._tlist:
                    max_t = max(max_t, self._get_value(t))
                # Avoid dividing by (near) zero when every stored time is ~0.
                max_t = max_t if abs(max_t) > get_settings().atol else 1.0
                self._time_scale_cache = self._get_value(self._max_time) / max_t
            else:
                self._time_scale_cache = 1.0
        return self._time_scale_cache

    @property
    def tlist(self) -> list[PARAMETERIZED_NUMBER]:
        """
        Return the (possibly rescaled) list of time points used for interpolation.

        Returns:
            list[PARAMETERIZED_NUMBER]: Interpolation time points, rescaled if ``max_time`` is set.
        """
        if self._tlist is None:
            self._tlist = self._generate_tlist()
        if self._max_time is not None:
            return [t * self._time_scale for t in self._tlist]
        return self._tlist

    @property
    def fixed_tlist(self) -> list[float]:
        """
        Return the list of time points as plain floats.

        Returns:
            list[float]: Evaluated time points (cached until ``delete_cache``).
        """
        if self._fixed_tlist is None:
            self._fixed_tlist = [self._get_value(k) for k in self.tlist]
        return self._fixed_tlist

    @property
    def total_time(self) -> float:
        """
        Return the maximum time among all points.

        Returns:
            float: Largest time value in ``fixed_tlist``, or ``0.0`` when no points are defined.
        """
        if self._total_time is None:
            self._total_time = max(self.fixed_tlist, default=0.0)
        return self._total_time

    def items(self) -> list[tuple[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER]]:
        """
        Return (time, coefficient) pairs, rescaling time if a max is set.

        Returns:
            list[tuple[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER]]: Time and coefficient pairs.
        """
        if self._max_time is not None:
            return [(k * self._time_scale, v) for k, v in self._time_dict.items()]
        return list(self._time_dict.items())

    def fixed_items(self) -> list[tuple[float, float]]:
        """
        Return (time, coefficient) pairs evaluated to floats.

        Returns:
            list[tuple[float, float]]: Evaluated time and coefficient pairs.
        """
        return [(t, self._get_value(self[t], t)) for t in self.fixed_tlist]

    @property
    def coefficients(self) -> list[PARAMETERIZED_NUMBER]:
        """
        Return coefficients in the order of ``tlist`` (sorted by time) without evaluation.

        Returns:
            list[PARAMETERIZED_NUMBER]: Coefficients aligned with ``tlist``.
        """
        # Dict insertion order follows the order points were added, which need not be sorted by time.
        return [self._time_dict[t] for t in self._generate_tlist()]

    @property
    def coefficients_dict(self) -> dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER]:
        """
        Return a shallow copy of the internal time-to-coefficient mapping.

        Returns:
            dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER]: Mapping from (unscaled) time to coefficient
                expressions.
        """
        return copy(self._time_dict)

    @property
    def fixed_coefficients(self) -> list[float]:
        """
        Return coefficients evaluated to floats in the order of ``fixed_tlist``.

        Returns:
            list[float]: Evaluated coefficients.
        """
        return [self._get_value(self[t]) for t in self.fixed_tlist]

    def set_max_time(self, max_time: PARAMETERIZED_NUMBER) -> None:
        """
        Rescale all time points to a new maximum duration while keeping relative spacing.

        Args:
            max_time (PARAMETERIZED_NUMBER): Desired maximum time after rescaling.

        Raises:
            ValueError: If the max time is set to zero.
        """
        if abs(self._get_value(max_time)) < get_settings().atol:
            raise ValueError("Cannot set the max time to zero.")
        self.delete_cache()
        self._max_time = max_time

    def delete_cache(self) -> None:
        """Clear cached evaluations and derived lists."""
        self._cached = False
        self._total_time = None
        self._cached_time = {}
        self._tlist = None
        self._fixed_tlist = None
        self._time_scale_cache = None

    def _get_value(self, value: PARAMETERIZED_NUMBER | complex, t: float | None = None) -> float:
        """
        Evaluate a numeric, parameter, or term into a concrete float.

        Args:
            value (PARAMETERIZED_NUMBER | complex): Value or expression to evaluate. NumPy scalars are accepted.
            t (float | None): Time value to bind when evaluating time-dependent expressions.

        Returns:
            float: Evaluated numeric value (the real part for complex values).

        Raises:
            ValueError: If evaluating a time parameter without a provided time, or an unsupported type is used.
        """
        if isinstance(value, np.generic):
            # NumPy scalars such as ``np.int64`` / ``np.float32`` are not ``int``/``float`` subclasses.
            value = value.item()
        if isinstance(value, (int, float)):
            return value
        if isinstance(value, complex):
            return value.real
        if isinstance(value, Parameter):
            if value.label == _TIME_PARAMETER_NAME:
                if t is None:
                    raise ValueError("Can't evaluate Parameter because time is not provided.")
                value.set_value(t)
            return float(value.evaluate())
        if isinstance(value, Term):
            # Substitute the shared time parameter when a time is given.
            ctx: Mapping[BaseVariable, list[int] | int | float] = {self._current_time: t} if t is not None else {}
            aux = value.evaluate(ctx)
            return aux.real if isinstance(aux, complex) else float(aux)
        raise ValueError(f"Invalid value of type {type(value)} is being evaluated.")

    def _extract_parameters(self, element: PARAMETERIZED_NUMBER) -> None:
        """
        Collect parameters from an element, ensuring only allowed variables are used.

        The time parameter ``t`` is never registered.

        Args:
            element (PARAMETERIZED_NUMBER): Element to inspect for parameters.

        Raises:
            ValueError: If the element contains variables that are not parameters.
        """
        if isinstance(element, Parameter) and element.label != _TIME_PARAMETER_NAME:
            self._add_parameter(element.label, element)
        elif isinstance(element, Term):
            if not element.is_parameterized_term():
                raise ValueError(
                    f"Tlist can only contain parameters and no variables, but the term {element} contains objects other than parameters."
                )
            for p in element.variables():
                if isinstance(p, Parameter) and p.label != _TIME_PARAMETER_NAME:
                    self._add_parameter(p.label, p)

    def add_time_point(
        self,
        time: PARAMETERIZED_NUMBER,
        coefficient: PARAMETERIZED_NUMBER | Callable[..., PARAMETERIZED_NUMBER],
    ) -> None:
        """
        Add or update a coefficient associated with a time point, processing callables if needed.

        Args:
            time (PARAMETERIZED_NUMBER): Time point for the coefficient.
            coefficient (PARAMETERIZED_NUMBER | Callable[..., PARAMETERIZED_NUMBER]): Coefficient value or callable.

        Raises:
            ValueError: If the coefficient type is unsupported or the callable uses invalid variables.
        """
        self._extract_parameters(time)
        coeff = coefficient
        if callable(coeff):
            # Set the shared time parameter to this point before evaluating the callable.
            self._current_time.set_value(self._get_value(time))
            coeff, new_params = _process_callable(coeff, self._current_time)  # ty:ignore[invalid-argument-type]
            self._extract_parameters(coeff)
            if new_params:
                self._update_parameters(new_params)
        elif isinstance(coeff, (int, float, Parameter, Term)):
            self._extract_parameters(coeff)
        else:
            raise ValueError(
                "Coefficient must be a number, Parameter, Term, or callable that returns one of these types."
            )
        # Keys are stored unscaled; ``tlist``/``items`` multiply them by the current time scale.
        self._time_dict[time / self._time_scale] = coeff
        self.delete_cache()

    def set_parameter_values(
        self,
        values: list[float],
        where: Callable[[Parameter], bool] | None = None,
    ) -> None:
        """
        Assign parameter values by position and clear caches.

        Args:
            values (list[float]): New values ordered consistently with ``get_parameter_names()``.
            where (Callable[[Parameter], bool] | None): Optional predicate selecting parameters to update.
        """
        self.delete_cache()
        super().set_parameter_values(values=values, where=where)

    def set_parameters(self, parameters: dict[str, int | float]) -> None:
        """
        Assign parameter values by name and clear caches.

        Args:
            parameters (dict[str, int | float]): Mapping from parameter labels to numeric values.
        """
        self.delete_cache()
        super().set_parameters(parameters)

    def set_parameter_bounds(self, ranges: dict[str, tuple[float, float]]) -> None:
        """
        Update parameter bounds and clear caches.

        Args:
            ranges (dict[str, tuple[float, float]]): Bounds keyed by parameter label.
        """
        self.delete_cache()
        super().set_parameter_bounds(ranges)

    def get_coefficient(self, time_step: float) -> float:
        """
        Return the numeric coefficient at a given time, applying interpolation and scaling.

        Args:
            time_step (float): Time at which to evaluate the coefficient.

        Returns:
            float: Evaluated coefficient.
        """
        time_step = time_step.item() if isinstance(time_step, np.generic) else self._get_value(time_step)  # ty:ignore[invalid-assignment]
        val = self.get_coefficient_expression(time_step=time_step)
        if self._max_time is not None:
            if self._tlist is None:
                self._tlist = self._generate_tlist()
            # Time-dependent expressions are defined on the unscaled time axis.
            time_step /= self._time_scale
        return self._get_value(val, time_step)

    def get_coefficient_expression(self, time_step: float) -> Number | Term | Parameter:
        """
        Return the raw expression for the coefficient at ``time_step`` without final evaluation.

        Args:
            time_step (float): Time at which to retrieve the coefficient expression.

        Returns:
            Number | Term | Parameter: Coefficient expression before numeric evaluation.

        Raises:
            ValueError: If the interpolation mode is unsupported or evaluation fails.
        """
        time_step = time_step.item() if isinstance(time_step, np.generic) else self._get_value(time_step)  # ty:ignore[invalid-assignment]
        # Rebuild the key list sorted by the keys' current evaluated values.
        self._tlist = self._generate_tlist()

        fixed_tlist = self.fixed_tlist
        if time_step in fixed_tlist:
            return self._time_dict[self._tlist[fixed_tlist.index(time_step)]]
        if time_step in self._cached_time:
            return self._cached_time[time_step]

        requested_time = time_step
        if self._max_time is not None:
            time_step /= self._time_scale

        if self._interpolation is Interpolation.STEP:
            result = self._get_coefficient_expression_step(time_step)
        elif self._interpolation is Interpolation.LINEAR:
            result = self._get_coefficient_expression_linear(time_step)
        else:
            raise ValueError(f"Interpolation type {self._interpolation.value} is not supported.")

        # Cache under the time the caller asked for (not a rescaled round-trip of it) so the lookup above hits.
        self._cached_time[requested_time] = result
        return result

    def _get_coefficient_expression_step(self, time_step: float) -> Number | Term | Parameter:
        """
        Return the step-interpolated coefficient expression for ``time_step``.

        Args:
            time_step (float): Time (on the unscaled axis) at which to retrieve the coefficient.

        Returns:
            Number | Term | Parameter: Coefficient of the last time point at or before ``time_step``; the first
                point's coefficient for earlier times, and ``0`` when no points are defined.
        """
        self._tlist = self._generate_tlist()
        if not self._tlist:
            return 0
        # ``bisect_right - 1`` is -1 before the first point; clamp it, since index -1 would wrap to the last point.
        prev_indx = max(bisect_right(self._tlist, time_step, key=self._get_value) - 1, 0)
        return self._time_dict[self._tlist[prev_indx]]

    def _get_coefficient_expression_linear(self, time_step: float) -> Number | Term | Parameter:
        """
        Return the linearly interpolated coefficient expression for ``time_step``.

        Inside the defined range the two neighbouring points are interpolated. Outside it, the first (or last)
        two points are extrapolated linearly. A single point is returned as a constant, and an empty schedule
        yields ``0``.

        Args:
            time_step (float): Time (on the unscaled axis) at which to interpolate.

        Returns:
            Number | Term | Parameter: Coefficient expression interpolated between neighbor points.

        Raises: # noqa: DOC502
            ValueError: If two points share the same time or an unexpected interpolation state is reached.
        """
        self._tlist = self._generate_tlist()
        insert_pos = bisect_right(self._tlist, time_step, key=self._get_value)

        def _linear_value(
            t0: PARAMETERIZED_NUMBER, v0: PARAMETERIZED_NUMBER, t1: PARAMETERIZED_NUMBER, v1: PARAMETERIZED_NUMBER
        ) -> PARAMETERIZED_NUMBER:
            t0_val = self._get_value(t0)
            t1_val = self._get_value(t1)
            if t0_val == t1_val:
                raise ValueError(
                    f"Ambiguous evaluation: The same time step {t0_val} has two different coefficient assignation ({v0} and {v1})."
                )
            alpha: float = (time_step - t0_val) / (t1_val - t0_val)
            next_is_term = isinstance(v1, (Term, Parameter))
            prev_is_term = isinstance(v0, (Term, Parameter))
            # Identical symbolic endpoints stay symbolic (evaluated later at the requested time); otherwise
            # symbolic endpoints are evaluated at their own time point before blending.
            if next_is_term and prev_is_term and v1 != v0:
                v1 = self._get_value(v1, t1_val)
                v0 = self._get_value(v0, t0_val)
            elif next_is_term and not prev_is_term:
                v1 = self._get_value(v1, t1_val)
            elif prev_is_term and not next_is_term:
                v0 = self._get_value(v0, t0_val)
            return v1 * alpha + v0 * (1 - alpha)

        # Defaults of 0 avoid None-typed indices when a neighbour does not exist.
        has_prev = insert_pos >= 1
        has_next = insert_pos < len(self._tlist)
        prev_idx = self._tlist[insert_pos - 1] if has_prev else 0
        prev_expr = self._time_dict[prev_idx] if has_prev else 0
        next_idx = self._tlist[insert_pos] if has_next else 0
        next_expr = self._time_dict[next_idx] if has_next else 0

        if not has_prev and has_next:
            # Before the first point: extrapolate from the first two points.
            if len(self._tlist) == 1:
                return next_expr
            first_idx = self._tlist[0]
            second_idx = self._tlist[1]
            return _linear_value(first_idx, self._time_dict[first_idx], second_idx, self._time_dict[second_idx])
        if not has_next and has_prev:
            # At or after the last point: extrapolate from the last two points.
            if len(self._tlist) == 1:
                return prev_expr
            last_idx = self._tlist[-1]
            penultimate_idx = self._tlist[-2]
            return _linear_value(penultimate_idx, self._time_dict[penultimate_idx], last_idx, self._time_dict[last_idx])
        if not has_next and not has_prev:
            return 0
        # Both neighbours exist here, so the 0 defaults above were not used.
        return _linear_value(prev_idx, prev_expr, next_idx, next_expr)

    def __getitem__(self, time_step: float) -> float:
        """
        Enable bracket access to coefficients: ``interp[t]``.

        Args:
            time_step (float): Time at which to evaluate the coefficient.

        Returns:
            float: Evaluated coefficient.
        """
        return self.get_coefficient(time_step)

    def __len__(self) -> int:
        """
        Return the number of defined time points.

        Returns:
            int: Number of time points stored.
        """
        return len(self.tlist)

    def __iter__(self) -> Interpolator:
        """
        Return an iterator over evaluated coefficients in time order.

        The instance is its own iterator, so iteration restarts from the first point on every call.

        Returns:
            Interpolator: Iterator over the instance.
        """
        self.iter_time_step = 0
        return self

    def __next__(self) -> float:
        """
        Iterate over evaluated coefficients across the schedule.

        Returns:
            float: Next coefficient in time order.

        Raises:
            StopIteration: When all coefficients have been iterated over.
        """
        if self.iter_time_step < len(self):
            result = self[self.fixed_tlist[self.iter_time_step]]
            self.iter_time_step += 1
            return result
        raise StopIteration