# Copyright 2025 Qilimanjaro Quantum Tech
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import inspect
from bisect import bisect_right
from collections.abc import Callable
from copy import copy
from enum import Enum
from typing import TYPE_CHECKING, Any, Mapping
import numpy as np
from typing_extensions import TypeAlias
from qilisdk.core.parameterizable import Parameterizable
from qilisdk.core.variables import LEQ, BaseVariable, Parameter, Term
from qilisdk.settings import get_settings
from qilisdk.yaml import yaml
if TYPE_CHECKING:
from qilisdk.core.types import Number
# Label reserved for the time variable inside coefficient expressions and callables.
_TIME_PARAMETER_NAME = "t"

# Type aliases, just to keep the annotations in this module short.

# A value that is either a plain number or a symbolic ``Parameter`` / ``Term`` expression.
PARAMETERIZED_NUMBER: TypeAlias = float | Parameter | Term

# Schedule definition: keys are single time points or two-point intervals, values are
# coefficients or callables that produce a coefficient.
TimeDict = dict[PARAMETERIZED_NUMBER | tuple[float, float], PARAMETERIZED_NUMBER | Callable[..., PARAMETERIZED_NUMBER]]
class Interpolation(str, Enum):
    """Interpolation rules available between schedule points.

    The members are ``str`` subclasses, so they compare equal to their descriptive string value.
    """

    # Hold the coefficient of the previous schedule point until the next one is reached.
    STEP = "Step function interpolation between schedule points"
    # Blend linearly between the two neighbouring schedule points.
    LINEAR = "linear interpolation between schedule points"
def _process_callable(
    function: Callable[..., PARAMETERIZED_NUMBER], current_time: Parameter, **kwargs: Any
) -> tuple[PARAMETERIZED_NUMBER, dict[str, Parameter]]:
    """
    Evaluate a coefficient-producing callable and collect any parameters it exposes.

    Arguments of ``function`` annotated as ``Parameter`` are treated as tunable parameters:

    * a ``Parameter`` default is registered (as a copy) under its own label;
    * otherwise a numeric value (taken from ``kwargs``, then from a numeric default, then ``0``)
      is wrapped in a new ``Parameter`` named after the argument and passed to the callable;
    * a ``Parameter`` supplied through ``kwargs`` is registered under its own label.

    If the callable accepts an argument named ``t``, ``current_time`` is bound to it.

    Args:
        function (Callable[..., PARAMETERIZED_NUMBER]): Callable that returns a coefficient expression.
        current_time (Parameter): Time parameter to bind when evaluating the callable.
        **kwargs: Additional keyword arguments passed to the callable.

    Returns:
        tuple[PARAMETERIZED_NUMBER, dict[str, Parameter]]: Evaluated expression and parameters discovered.

    Raises:
        ValueError: If the callable uses variables other than time or ``Parameter`` instances.
    """
    parameters: dict[str, Parameter] = {}
    signature_params = inspect.signature(function).parameters
    empty = inspect.Parameter.empty

    for name, info in signature_params.items():
        # NOTE: the annotation is compared by identity, so stringized annotations
        # (e.g. from postponed evaluation in the callable's module) are not recognised here.
        if info.annotation is not Parameter:
            continue

        default = info.default
        if isinstance(default, Parameter):
            parameters[default.label] = copy(default)
            continue

        # No usable Parameter default: fall back to the kwarg, then to a plain default, then to 0.
        value = kwargs.get(name, 0 if default is empty else default)
        if isinstance(value, (float, int)):
            parameters[name] = Parameter(name, value)
            # Needed because kwargs may not contain ``name`` and there may be no default,
            # in which case the ``function(**kwargs)`` call below would fail.
            kwargs[name] = parameters[name]
        elif isinstance(value, Parameter):
            parameters[value.label] = value

    if _TIME_PARAMETER_NAME in signature_params:
        kwargs[_TIME_PARAMETER_NAME] = current_time

    term = function(**kwargs)

    def _is_allowed(variable: BaseVariable) -> bool:
        # Only parameters and the reserved time variable may appear in a coefficient.
        return isinstance(variable, Parameter) or variable.label == _TIME_PARAMETER_NAME

    if isinstance(term, Term) and not all(_is_allowed(v) for v in term.variables()):
        raise ValueError("function contains variables that are not time. Only Parameters are allowed.")
    if isinstance(term, BaseVariable) and not _is_allowed(term):
        raise ValueError("function contains variables that are not time. Only Parameters are allowed.")
    return term, parameters
@yaml.register_class
class Interpolator(Parameterizable):
    """Mapping of time points to coefficients with optional interpolation.

    Time points may be plain numbers, ``Parameter`` objects or ``Term`` expressions. Coefficients may
    additionally be callables, which are evaluated once when the time point is added. The whole
    schedule can be rescaled to a new maximum time with ``set_max_time``.
    """

    def __init__(
        self,
        time_dict: TimeDict,
        interpolation: Interpolation = Interpolation.LINEAR,
        nsamples: int = 100,
    ) -> None:
        """Initialize an interpolator over discrete points or intervals.

        Args:
            time_dict (TimeDict): Mapping from time points or intervals to coefficients or callables.
            interpolation (Interpolation): Interpolation rule between provided points (``LINEAR`` or ``STEP``).
            nsamples (int): Number of samples used to expand interval definitions. Not used by the
                current implementation; accepted for interface compatibility.

        Raises:
            ValueError: If an interval is not defined by exactly two points, if a time point is
                defined twice, or if a point intersects an interval.
        """
        super().__init__()
        self._interpolation = interpolation
        self._time_dict: dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER] = {}
        self._current_time = Parameter(_TIME_PARAMETER_NAME, 0)
        self._total_time: float | None = None
        self.iter_time_step = 0
        self._cached = False
        self._cached_time: dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER | Number] = {}
        self._tlist: list[PARAMETERIZED_NUMBER] | None = None
        self._fixed_tlist: list[float] | None = None
        self._max_time: PARAMETERIZED_NUMBER | None = None
        self._time_scale_cache: float | None = None

        # Validate interval shapes first: the ordering checks below index ``t[0]`` / ``t[1]`` and
        # would otherwise fail with an IndexError instead of the documented ValueError.
        for key in time_dict:
            if isinstance(key, tuple) and len(key) != 2:  # noqa: PLR2004
                raise ValueError(
                    f"time intervals need to be defined by two points, but this interval was provided: {key}"
                )

        def _start_value(entry: PARAMETERIZED_NUMBER | tuple[float, float]) -> float:
            # Intervals are ordered by their smallest endpoint, single points by their own value.
            if isinstance(entry, tuple):
                return self._get_value(min(entry, key=self._get_value))  # ty:ignore[no-matching-overload]
            return self._get_value(entry)

        fixed_times: list[PARAMETERIZED_NUMBER | tuple[float, float]] = sorted(time_dict.keys(), key=_start_value)

        # Compare the end of each entry with the start of the following one.
        for ti, tj in zip(fixed_times, fixed_times[1:]):
            t0 = self._get_value(ti[1] if isinstance(ti, tuple) else ti)  # ty:ignore[invalid-argument-type]
            t1 = self._get_value(tj[0] if isinstance(tj, tuple) else tj)  # ty:ignore[invalid-argument-type]
            if abs(t0 - t1) < get_settings().atol:
                raise ValueError(f"The time point {t0} is defined twice.")
            if t0 > t1:
                raise ValueError(f"Can't provide a point that intersects an interval (issue in: {ti} and {tj}).")

        # An interval registers the same coefficient at both of its endpoints.
        for time, coefficient in time_dict.items():
            if isinstance(time, tuple):
                start, end = time
                self.add_time_point(start, coefficient)  # ty:ignore[invalid-argument-type]
                self.add_time_point(end, coefficient)  # ty:ignore[invalid-argument-type]
            else:
                self.add_time_point(time, coefficient)
        self._tlist = self._generate_tlist()

        # Symbolic time points must stay ordered with respect to their neighbours, so an
        # ordering constraint is registered on each side of every Parameter/Term time point.
        ordered_points = sorted(
            [k for item in time_dict for k in (item if isinstance(item, tuple) else (item,))],
            key=self._get_value,
        )  # ty:ignore[no-matching-overload]
        last_position = len(ordered_points) - 1
        for position, point in enumerate(ordered_points):
            if not isinstance(point, (Parameter, Term)):
                continue
            if position > 0:
                constraint = LEQ(ordered_points[position - 1], point)
                if constraint not in self._parameter_constraints:
                    self._parameter_constraints.append(constraint)
            if position < last_position:
                constraint = LEQ(point, ordered_points[position + 1])
                if constraint not in self._parameter_constraints:
                    self._parameter_constraints.append(constraint)

    def _generate_tlist(self) -> list[PARAMETERIZED_NUMBER]:
        """
        Generate a sorted list of the registered time keys.

        Returns:
            list[PARAMETERIZED_NUMBER]: Sorted time indices based on their evaluated value.
        """
        return sorted(self._time_dict.keys(), key=self._get_value)  # ty:ignore[invalid-return-type]

    @property
    def _time_scale(self) -> float:
        """
        Return the scaling factor applied to time points when ``max_time`` is set.

        This handles the caching, so ``self._time_scale_cache`` should never be accessed directly.

        Returns:
            float: Scaling factor (``1.0`` when no maximum time is set).
        """
        if self._time_scale_cache is None:
            if self._max_time is None:
                self._time_scale_cache = 1.0
            else:
                if self._tlist is None:
                    self._tlist = self._generate_tlist()
                # Largest registered time, floored at zero.
                max_t = 0.0
                for t in self._tlist:
                    max_t = max(max_t, self._get_value(t))
                # Avoid dividing by (numerically) zero when every point sits at the origin.
                if abs(max_t) <= get_settings().atol:
                    max_t = 1.0
                self._time_scale_cache = self._get_value(self._max_time) / max_t
        return self._time_scale_cache

    @property
    def tlist(self) -> list[PARAMETERIZED_NUMBER]:
        """
        Return the (possibly rescaled) list of time points used for interpolation.

        Returns:
            list[PARAMETERIZED_NUMBER]: Interpolation time points, rescaled if ``max_time`` is set.
        """
        if self._tlist is None:
            self._tlist = self._generate_tlist()
        if self._max_time is not None:
            return [t * self._time_scale for t in self._tlist]
        return self._tlist

    @property
    def fixed_tlist(self) -> list[float]:
        """
        Return the list of time points as plain floats.

        Returns:
            list[float]: Evaluated time points.
        """
        if self._fixed_tlist is None:
            self._fixed_tlist = [self._get_value(k) for k in self.tlist]
        return self._fixed_tlist

    @property
    def total_time(self) -> float:
        """
        Return the maximum time among all points.

        Returns:
            float: Largest time value in ``fixed_tlist``.
        """
        if self._total_time is None:
            self._total_time = max(self.fixed_tlist)
        return self._total_time

    def items(self) -> list[tuple[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER]]:
        """
        Return (time, coefficient) pairs, rescaling time if a max is set.

        Pairs follow the insertion order of the internal mapping.

        Returns:
            list[tuple[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER]]: Time and coefficient pairs.
        """
        if self._max_time is not None:
            return [(k * self._time_scale, v) for k, v in self._time_dict.items()]
        return list(self._time_dict.items())

    def fixed_items(self) -> list[tuple[float, float]]:
        """
        Return (time, coefficient) pairs evaluated to floats.

        Returns:
            list[tuple[float, float]]: Evaluated time and coefficient pairs.
        """
        return [(t, self._get_value(self[t], t)) for t in self.fixed_tlist]

    @property
    def coefficients(self) -> list[PARAMETERIZED_NUMBER]:
        """
        Return coefficients in the order of ``tlist`` without evaluation.

        Returns:
            list[PARAMETERIZED_NUMBER]: Coefficients aligned with ``tlist``.
        """
        if self._tlist is None:
            self._tlist = self._generate_tlist()
        # Walk the sorted keys so the result lines up with ``tlist`` even when the points
        # were inserted out of time order.
        return [self._time_dict[t] for t in self._tlist]

    @property
    def coefficients_dict(self) -> dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER]:
        """
        Return a shallow copy of the internal time-to-coefficient mapping.

        Returns:
            dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER]: Mapping from time to coefficient expressions.
        """
        return copy(self._time_dict)

    @property
    def fixed_coefficients(self) -> list[float]:
        """
        Return coefficients evaluated to floats in the order of ``fixed_tlist``.

        Returns:
            list[float]: Evaluated coefficients.
        """
        return [self._get_value(self[t]) for t in self.fixed_tlist]

    def set_max_time(self, max_time: PARAMETERIZED_NUMBER) -> None:
        """
        Rescale all time points to a new maximum duration while keeping relative spacing.

        Args:
            max_time (PARAMETERIZED_NUMBER): Desired maximum time after rescaling.

        Raises:
            ValueError: If the max time is set to zero.
        """
        if abs(self._get_value(max_time)) < get_settings().atol:
            raise ValueError("Cannot set the max time to zero.")
        self.delete_cache()
        self._max_time = max_time

    def delete_cache(self) -> None:
        """Clear cached evaluations and derived lists."""
        self._cached = False
        self._total_time = None
        self._cached_time = {}
        self._tlist = None
        self._fixed_tlist = None
        self._time_scale_cache = None

    def _get_value(self, value: PARAMETERIZED_NUMBER | complex, t: float | None = None) -> float:
        """
        Evaluate a numeric, parameter, or term into a concrete float.

        Args:
            value (PARAMETERIZED_NUMBER | complex): Value or expression to evaluate.
            t (float | None): Time value to bind when evaluating time-dependent expressions.

        Returns:
            float: Evaluated numeric value (the real part for complex values).

        Raises:
            ValueError: If evaluating a time parameter without a provided time, or an unsupported type is used.
        """
        if isinstance(value, (int, float)):
            return value
        if isinstance(value, complex):
            return value.real
        if isinstance(value, Parameter):
            if value.label == _TIME_PARAMETER_NAME:
                if t is None:
                    raise ValueError("Can't evaluate Parameter because time is not provided.")
                # The time parameter carries no value of its own; bind the requested time.
                value.set_value(t)
            return float(value.evaluate())
        if isinstance(value, Term):
            ctx: Mapping[BaseVariable, list[int] | int | float] = {self._current_time: t} if t is not None else {}
            evaluated = value.evaluate(ctx)
            return evaluated.real if isinstance(evaluated, complex) else float(evaluated)
        raise ValueError(f"Invalid value of type {type(value)} is being evaluated.")

    def _extract_parameters(self, element: PARAMETERIZED_NUMBER) -> None:
        """
        Collect parameters from an element, ensuring only allowed variables are used.

        The reserved time parameter is never registered.

        Args:
            element (PARAMETERIZED_NUMBER): Element to inspect for parameters.

        Raises:
            ValueError: If the element contains variables that are not parameters.
        """
        if isinstance(element, Parameter) and element.label != _TIME_PARAMETER_NAME:
            self._add_parameter(element.label, element)
        elif isinstance(element, Term):
            if not element.is_parameterized_term():
                raise ValueError(
                    f"Tlist can only contain parameters and no variables, but the term {element} contains objects other than parameters."
                )
            for variable in element.variables():
                if isinstance(variable, Parameter) and variable.label != _TIME_PARAMETER_NAME:
                    self._add_parameter(variable.label, variable)

    def add_time_point(
        self,
        time: PARAMETERIZED_NUMBER,
        coefficient: PARAMETERIZED_NUMBER | Callable[..., PARAMETERIZED_NUMBER],
    ) -> None:
        """
        Add or update a coefficient associated with a time point, processing callables if needed.

        Args:
            time (PARAMETERIZED_NUMBER): Time point for the coefficient.
            coefficient (PARAMETERIZED_NUMBER | Callable[..., PARAMETERIZED_NUMBER]): Coefficient value or callable.

        Raises:
            ValueError: If the coefficient type is unsupported or the callable uses invalid variables.
        """
        self._extract_parameters(time)
        coeff = coefficient
        if callable(coeff):
            # Evaluate the callable with the time parameter bound to this time point.
            self._current_time.set_value(self._get_value(time))
            coeff, discovered = _process_callable(coeff, self._current_time)  # ty:ignore[invalid-argument-type]
            self._extract_parameters(coeff)
            if discovered:
                self._update_parameters(discovered)
        elif isinstance(coeff, (int, float, Parameter, Term)):
            self._extract_parameters(coeff)
        else:
            raise ValueError(
                "Coefficient must be a number, Parameter, Term, or callable that returns one of these types."
            )
        # Keys are stored unscaled; the scale is applied again when times are read back.
        self._time_dict[time / self._time_scale] = coeff
        self.delete_cache()

    def set_parameter_values(
        self,
        values: list[float],
        where: Callable[[Parameter], bool] | None = None,
    ) -> None:
        """
        Assign parameter values by position and clear caches.

        Args:
            values (list[float]): New values ordered consistently with ``get_parameter_names()``.
            where (Callable[[Parameter], bool] | None): Optional predicate selecting parameters to update.
        """
        self.delete_cache()
        super().set_parameter_values(values=values, where=where)

    def set_parameters(self, parameters: dict[str, int | float]) -> None:
        """
        Assign parameter values by name and clear caches.

        Args:
            parameters (dict[str, int | float]): Mapping from parameter labels to numeric values.
        """
        self.delete_cache()
        super().set_parameters(parameters)

    def set_parameter_bounds(self, ranges: dict[str, tuple[float, float]]) -> None:
        """
        Update parameter bounds and clear caches.

        Args:
            ranges (dict[str, tuple[float, float]]): Bounds keyed by parameter label.
        """
        self.delete_cache()
        super().set_parameter_bounds(ranges)

    def get_coefficient(self, time_step: float) -> float:
        """
        Return the numeric coefficient at a given time, applying interpolation and scaling.

        Args:
            time_step (float): Time at which to evaluate the coefficient.

        Returns:
            float: Evaluated coefficient.
        """
        time_step = time_step.item() if isinstance(time_step, np.generic) else self._get_value(time_step)  # ty:ignore[invalid-assignment]
        expression = self.get_coefficient_expression(time_step=time_step)
        if self._max_time is not None:
            # Expressions are defined on the unscaled time axis.
            time_step /= self._time_scale
        return self._get_value(expression, time_step)

    def get_coefficient_expression(self, time_step: float) -> Number | Term | Parameter:
        """
        Return the raw expression for the coefficient at ``time_step`` without final evaluation.

        Args:
            time_step (float): Time at which to retrieve the coefficient expression.

        Returns:
            Number | Term | Parameter: Coefficient expression before numeric evaluation.

        Raises:
            ValueError: If the interpolation mode is unsupported or evaluation fails.
        """
        time_step = time_step.item() if isinstance(time_step, np.generic) else self._get_value(time_step)  # ty:ignore[invalid-assignment]
        self._tlist = self._generate_tlist()

        # Exact hit on a registered time point: return its stored expression.
        fixed_times = self.fixed_tlist
        if time_step in fixed_times:
            return self._time_dict[self._tlist[fixed_times.index(time_step)]]
        if time_step in self._cached_time:
            return self._cached_time[time_step]

        # Remember the requested time: it is the key used for the cache lookup above, and
        # reconstructing it as ``(t / scale) * scale`` may differ by floating-point rounding.
        requested_time = time_step
        if self._max_time is not None:
            time_step /= self._time_scale

        if self._interpolation is Interpolation.STEP:
            result = self._get_coefficient_expression_step(time_step)
        elif self._interpolation is Interpolation.LINEAR:
            result = self._get_coefficient_expression_linear(time_step)
        else:
            raise ValueError(f"Interpolation type {self._interpolation.value} is not supported.")

        self._cached_time[requested_time] = result
        return result

    def _get_coefficient_expression_step(self, time_step: float) -> Number | Term | Parameter:
        """
        Return the step-interpolated coefficient expression for ``time_step``.

        Args:
            time_step (float): Time (on the unscaled axis) at which to retrieve the coefficient.

        Returns:
            Number | Term | Parameter: Coefficient expression of the closest time point at or before
            ``time_step``; the first point's coefficient when ``time_step`` precedes every point.
        """
        self._tlist = self._generate_tlist()
        prev_indx = bisect_right(self._tlist, time_step, key=self._get_value) - 1
        # Before the first point the index would be -1, which silently wraps to the last point.
        prev_indx = max(prev_indx, 0)
        return self._time_dict[self._tlist[prev_indx]]

    def _get_coefficient_expression_linear(self, time_step: float) -> Number | Term | Parameter:
        """
        Return the linearly interpolated coefficient expression for ``time_step``.

        Outside the registered range the line through the two nearest end points is extended.

        Args:
            time_step (float): Time (on the unscaled axis) at which to interpolate.

        Returns:
            Number | Term | Parameter: Coefficient expression interpolated between neighbor points,
            the single stored coefficient if only one point exists, or ``0`` if there are none.

        Raises:  # noqa: DOC502
            ValueError: If two points share the same time or an unexpected interpolation state is reached.
        """
        self._tlist = self._generate_tlist()
        points = self._tlist
        insert_pos = bisect_right(points, time_step, key=self._get_value)

        def _linear_value(
            t0: PARAMETERIZED_NUMBER, v0: PARAMETERIZED_NUMBER, t1: PARAMETERIZED_NUMBER, v1: PARAMETERIZED_NUMBER
        ) -> PARAMETERIZED_NUMBER:
            t0_val = self._get_value(t0)
            t1_val = self._get_value(t1)
            if t0_val == t1_val:
                raise ValueError(
                    f"Ambiguous evaluation: The same time step {t0_val} has two different coefficient assignation ({v0} and {v1})."
                )
            alpha: float = (time_step - t0_val) / (t1_val - t0_val)
            next_is_term = isinstance(v1, (Term, Parameter))
            prev_is_term = isinstance(v0, (Term, Parameter))
            # Symbolic end points are evaluated at their own time before blending, unless both
            # ends share the same expression, which is then kept symbolic.
            if next_is_term and prev_is_term and v1 != v0:
                v1 = self._get_value(v1, t1_val)
                v0 = self._get_value(v0, t0_val)
            elif next_is_term and not prev_is_term:
                v1 = self._get_value(v1, t1_val)
            elif prev_is_term and not next_is_term:
                v0 = self._get_value(v0, t0_val)
            return v1 * alpha + v0 * (1 - alpha)

        if not points:
            return 0
        if len(points) == 1:
            return self._time_dict[points[0]]

        if insert_pos == 0:
            # Before the first point: extend the first segment.
            left, right = points[0], points[1]
        elif insert_pos == len(points):
            # After the last point: extend the last segment.
            left, right = points[-2], points[-1]
        else:
            left, right = points[insert_pos - 1], points[insert_pos]
        return _linear_value(left, self._time_dict[left], right, self._time_dict[right])

    def __getitem__(self, time_step: float) -> float:
        """
        Enable bracket access to coefficients: ``interp[t]``.

        Args:
            time_step (float): Time at which to evaluate the coefficient.

        Returns:
            float: Evaluated coefficient.
        """
        return self.get_coefficient(time_step)

    def __len__(self) -> int:
        """
        Return the number of defined time points.

        Returns:
            int: Number of time points stored.
        """
        return len(self.tlist)

    def __iter__(self) -> "Interpolator":
        """
        Return an iterator over evaluated coefficients in time order.

        The instance is its own iterator, so starting a new iteration resets the position.

        Returns:
            Interpolator: Iterator over the instance.
        """
        self.iter_time_step = 0
        return self

    def __next__(self) -> float:
        """
        Iterate over evaluated coefficients across the schedule.

        Returns:
            float: Next coefficient in time order.

        Raises:
            StopIteration: When all coefficients have been iterated over.
        """
        if self.iter_time_step < len(self):
            result = self[self.fixed_tlist[self.iter_time_step]]
            self.iter_time_step += 1
            return result
        raise StopIteration