# Copyright 2025 Qilimanjaro Quantum Tech
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Any, Literal, Optional
import matplotlib.font_manager as fm
from pydantic import BaseModel, Field
from .themes import Theme, light
# Font file expected to sit in the same directory as this module; used as the default `fontfname`.
_DEFAULT_FONT_PATH = Path(__file__).with_name("PlusJakartaSans-SemiBold.ttf")
class Style(BaseModel):
    """
    Base visual configuration shared by the plot styles defined in this module.

    Holds the colour theme, figure-level options (``dpi``, ``title``) and the font
    settings that :attr:`font` forwards to ``matplotlib.font_manager.FontProperties``.
    """

    theme: Theme = Field(default=light, description="Colour theme.")

    # --- FontProperties-mapped fields (mirror matplotlib.font_manager.FontProperties) ---
    # If `fontfname` exists, it takes precedence and loads the exact TTF.
    fontfamily: str | list[str] | None = Field(
        default=None, description="Font family name(s), e.g. 'Outfit' or ['Outfit', 'DejaVu Sans']."
    )
    fontstyle: Literal["normal", "italic", "oblique"] = Field(
        default="normal", description="Font style: 'normal', 'italic', or 'oblique'."
    )
    fontvariant: Literal["normal", "small-caps"] = Field(
        default="normal", description="Font variant: typically 'normal' or 'small-caps'."
    )
    fontweight: str | int = Field(
        default="normal", description="Font weight: 'normal', 'bold', 'light', or numeric (100-900)."
    )
    fontstretch: str | int = Field(
        default="normal", description="Width/condensation: 'ultra-condensed'..'ultra-expanded' or numeric."
    )
    fontsize: float | str = Field(
        default=10, description="Font size in pt or keywords like 'small', 'medium', 'large'."
    )
    fontfname: str | None = Field(
        default=str(_DEFAULT_FONT_PATH), description="Absolute path to the TTF/OTF file. If present, overrides family."
    )
    math_fontfamily: str | None = Field(default=None, description="Math text family, e.g. 'dejavusans', 'cm', or None.")

    # --- Figure-level fields ---
    dpi: int = Field(default=150, description="Figure DPI.")
    title: str | None = Field(default=None, description="Figure title.")

    @property
    def font(self) -> fm.FontProperties:
        """
        Construct a Matplotlib FontProperties from the configured fields.

        If `fontfname` points to a real file, it is used (and overrides family).
        Otherwise the path is dropped and the font is resolved from the remaining
        fields (family, style, variant, weight, stretch, size).

        Returns:
            fm.FontProperties: A new object built from the current field values.
        """
        # Only forward a font file that is actually there. Matplotlib uses `fname`
        # as given when it is set, so a stale path would otherwise only surface as
        # an error once text is rendered.
        font_file = self.fontfname if self.fontfname and Path(self.fontfname).is_file() else None
        return fm.FontProperties(
            family=self.fontfamily,
            style=self.fontstyle,
            variant=self.fontvariant,
            weight=self.fontweight,
            stretch=self.fontstretch,
            size=self.fontsize,
            fname=font_file,
            math_fontfamily=self.math_fontfamily,
        )
class QTensorStyle(Style):
    """All visual parameters controlling the appearance of a QTensor plot."""

    # --- Sphere surface ---
    sphere_points: int = Field(
        default=50, description="Number of points to use when plotting the Bloch sphere surface."
    )
    sphere_color: str = Field(default="#1f77b4", description="Color for the Bloch sphere surface (hex or named color).")

    # --- State vector arrow ---
    arrow_color: str = Field(default="#1f77b4", description="Color for the state vector arrow (hex or named color).")
    arrow_length_ratio: float = Field(
        default=0.1,
        description="Length of the arrow head as a fraction of the arrow length (e.g. 0.1 means the head is 10% of the total arrow length).",
    )

    # --- Reference circle ---
    # NOTE: the two field names below spell "center"/"centre" differently. Both are
    # part of the public interface, so they are kept exactly as they are.
    draw_center_circle: bool = Field(
        default=True,
        description="Whether to draw a circle around the centre of the Bloch sphere for reference.",
    )
    centre_circle_color: str = Field(
        default="#1f77b4",
        description="Color for a circle drawn around the centre of the Bloch sphere for reference (hex or named color).",
    )

    # --- Reference points ---
    draw_reference_points: bool = Field(
        default=True,
        description="Whether to draw reference points (|0⟩, |1⟩, |+⟩, |-⟩, |+i⟩, |-i⟩) on the Bloch sphere for orientation.",
    )
    reference_point_distance: float = Field(
        default=1.2,
        description="Distance from the origin to place the reference point labels (|0⟩, |1⟩, etc.) on the Bloch sphere.",
    )

    # --- Interaction ---
    rotation_style: Literal["azel", "trackball", "sphere", "arcball"] = Field(
        default="azel",
        description="Mouse rotation style for 3D plots.",
    )
class CircuitStyle(Style):
    """All visual parameters controlling the appearance of a circuit plot."""

    # --- Overall spacing ---
    end_wire_ext: int = Field(default=2, description="Extra space after last layer.")
    padding: float = Field(default=0.3, description="Padding around drawing (inches).")
    gate_margin: float = Field(default=0.15, description="Left/right margin per gate.")
    wire_sep: float = Field(default=0.5, description="Vertical separation of wires.")
    layer_sep: float = Field(default=0.5, description="Horizontal separation of layers.")
    gate_pad: float = Field(default=0.05, description="Padding around gate text.")
    label_pad: float = Field(default=0.1, description="Padding before wire label.")

    # --- Gate boxes, layers and wire labels ---
    bulge: str = Field(default="round", description="Box-style for gate rectangles.")
    align_layer: bool = Field(default=True, description="Align layers across wires.")
    wire_label: list[Any] | None = Field(default=None, description="Custom wire labels.")
    start_pad: float = Field(
        default=0.1, description="Minimum spacing (inches) before the first layer so wire labels fit."
    )
    min_gate_h: float = Field(default=0.2, description="Minimum gate box height (inches).")
    min_gate_w: float = Field(default=0.2, description="Minimum gate box width (inches).")

    # --- Control / target glyphs ---
    connector_r: float = Field(
        default=0.01, description="Radius (inches) of small connector dots on multi-target gates."
    )
    target_r: float = Field(default=0.12, description="Radius (inches) of ⊕ target circle and SWAP half-width.")
    control_r: float = Field(default=0.05, description="Radius (inches) of a filled control dot.")

    # --- Layout strategy ---
    layout: Literal["normal", "compact"] = Field(
        default="normal",
        description="If 'compact' minimizes the layers to highlight circuit depth, if 'normal' conserves the order of the circuit",
    )
class ScheduleStyle(Style):
    """
    Customization options for matplotlib schedule plots, with theme support.
    """

    # Figure and axes
    figsize: Optional[tuple] = Field(default=(8, 5), description="Figure size in inches (width, height).")
    grid: bool = Field(default=True, description="Whether to show grid lines on the plot.")
    grid_style: dict[str, Any] = Field(
        default_factory=lambda: {"linestyle": "--", "color": "#e0e0e0", "alpha": 0.7},
        description="Style dictionary for grid lines (linestyle, color, alpha, etc.).",
    )

    # Title and labels
    title_fontsize: int = Field(default=16, description="Font size for the plot title.")
    xlabel: str = Field(default="time", description="Label for the x-axis.")
    ylabel: str = Field(default="coefficient value", description="Label for the y-axis.")
    label_fontsize: int = Field(default=14, description="Font size for axis labels.")
    legend_loc: str = Field(
        default="best", description="Location of the legend (matplotlib string, e.g. 'best', 'upper right')."
    )
    legend_fontsize: int = Field(default=12, description="Font size for legend text.")
    legend_frame: bool = Field(default=True, description="Whether to draw a frame around the legend.")

    # Line style
    # `line_styles` defaults to an empty mapping so the model can be built without it;
    # `default_line_style` is described as the fallback for labels that have no entry.
    line_styles: dict[str, dict[str, Any]] = Field(
        default_factory=dict,
        description="Custom line style dictionary for each Hamiltonian (e.g. {label: {color, linestyle, linewidth}}).",
    )
    default_line_style: dict[str, Any] = Field(
        default_factory=lambda: {"linestyle": "-", "linewidth": 2},
        description="Default line style for Hamiltonians not in line_styles.",
    )

    # Marker style
    marker: Optional[str] = Field(
        default=None, description="Matplotlib marker style for data points (e.g. 'o', 's', None for no marker)."
    )
    marker_size: int = Field(default=6, description="Size of markers if used.")

    # Ticks
    xtick_fontsize: int = Field(default=12, description="Font size for x-axis tick labels.")
    ytick_fontsize: int = Field(default=12, description="Font size for y-axis tick labels.")
    tick_color: Optional[str] = Field(
        default=None, description="Color for tick labels (None uses theme.on_background)."
    )

    # Misc
    tight_layout: bool = Field(default=True, description="Whether to use matplotlib's tight_layout for figure spacing.")