'''
Tools to plot ECG signals.
ECGDrawing takes a called reader instance (ECGXMLReader or ECGDICOMReader)
and renders the lead-specific ECG signals on a GridSpec figure with
clinical ECG paper scaling.
'''
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# from __future__ import annotations
import logging
import numpy as np
import matplotlib.pylab as plt
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import MultipleLocator
from typing import Any, NamedTuple, Self, Literal
from ecgprocess.errors import InputValidationError, NotCalledError, is_type
from ecgprocess.constants import PlotNames as PDNames, CoreData
from ecgprocess.utils.general import _update_kwargs
# initiate the logger
_log = logging.getLogger(__name__)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def _minor_only_ticks(
lo: float,
hi: float,
minor_step: float,
major_step: float,
) -> np.ndarray:
"""
Return minor tick positions that do not coincide with major ticks.
Parameters
----------
lo : `float`
Lower axis limit.
hi : `float`
Upper axis limit.
minor_step : `float`
Minor tick spacing.
major_step : `float`
Major tick spacing.
Returns
-------
ticks : `np.ndarray`
Minor positions with major-coincident values removed.
"""
# Small epsilon to keep np.arange inclusive at hi and to absorb
# floating-point noise when testing whether a tick falls on a major line.
tol = minor_step * 1e-6
# get minor tick locations
tick_loc = np.arange(
np.ceil(lo / minor_step) * minor_step,
hi + tol,
minor_step,
)
remainder = tick_loc % major_step
# which tick to remove
on_major = np.isclose(remainder, 0.0, atol=tol) | np.isclose(
remainder, major_step, atol=tol,
)
# return
return tick_loc[~on_major]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def _resolve_lead_kwargs(
kw_dict: dict[str, Any] | None,
lead: str,
) -> dict[str, Any]:
"""
Return the resolved kwargs dict for a single lead.
Parameters
----------
kw_dict : `dict[str, Any]` | `None`
Either a flat dict (uniform across all leads), a dict of dicts
(per-lead), or None.
lead : `str`
Lead name to resolve for.
Returns
-------
resolved : `dict[str, Any]`
Kwargs applicable to this lead. Empty dict when kw_dict is None or
the lead is not listed in a per-lead mapping.
Raises
------
InputValidationError
If kw_dict contains a mix of dict and non-dict values.
Examples
--------
Uniform, same kwargs applied to every lead:
>>> _resolve_lead_kwargs({'color': 'red', 'linewidth': 1.5}, 'I')
{'color': 'red', 'linewidth': 1.5}
Per-lead, different kwargs per lead:
>>> _resolve_lead_kwargs({
'I': {'color': 'red'},
'II': {'color': 'blue'}
}, lead = 'I')
{'color': 'red'}
"""
# return empty dict if None
if not kw_dict:
return {}
vals = list(kw_dict.values())
all_dicts = all(isinstance(v, dict) for v in vals)
any_dicts = any(isinstance(v, dict) for v in vals)
if all_dicts:
# return the lead specific dict
return kw_dict.get(lead, {})
if not any_dicts:
# if there are no dicts, that means the original dict should be
# applied to each lead
return kw_dict
# if mixed dicts raise an error
raise InputValidationError(
"Mixed dict/non-dict values in kwargs, either pass a flat dict "
"(uniform) or a dict of dicts (per-lead)."
)
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
[docs]
class ECGStyle(NamedTuple):
"""
Style configuration for ECGDrawing.
Attributes
----------
figsize : `tuple` [`float`, `float`], default (20.0, 10.0)
Figure width and height in inches.
dpi : `int`, default 600
Figure resolution in dots per inch.
background_colour : `str`, default '#ffd6d6'
Axes background fill colour (ECG paper pink).
major_grid_colour : `str`, default '#ff6666'
Colour of the major grid lines.
minor_grid_colour : `str`, default '#ffaaaa'
Colour of the minor grid lines.
major_grid_linewidth : `float`, default 0.3
Line width of the major grid lines in points.
minor_grid_linewidth : `float`, default 0.2
Line width of the minor grid lines in points.
trace_colour : `str`, default 'black'
ECG trace line colour.
trace_linewidth : `float`, default 0.4
ECG trace line width in points.
paper_speed : `float`, default 25.0
Paper speed in mm/s (standard clinical: 25 mm/s).
mm_per_mv : `float`, default 10.0
Voltage scaling in mm/mV (standard clinical: 10 mm/mV).
clip_on_trace : `bool`, default False
Whether each ECG trace line is clipped to its axes boundary.
False replicates the physical ECG paper look.
hspace : `float`, default 0.0
Vertical spacing between GridSpec rows. Zero mimics continuous ECG
paper; increase to visually separate panels.
wspace : `float`, default 0.0
Horizontal spacing between GridSpec columns. Zero mimics continuous
ECG paper; increase to visually separate panels.
x_lim : `tuple` [`float`, `float`] | `None`, default None
X-axis limits in seconds. None auto-computes from n_samples / fs.
y_lim : `tuple` [`float`, `float`], default (-2.0, 2.0)
Y-axis limits in applied to every lead panel. The default would be fine
for mV, for μV simply multiply by 1,000.
label_coordinates : `tuple` [`float`, `float`], default (0.02, 0.95)
Lead label position as ``(x, y)`` in axes coordinates (0–1).
Notes
-----
Paper speed and mm_per_mv drive the minor/major grid step sizes using the
standard clinical ECG paper mapping:
* x minor = 1 / paper_speed (s) e.g. 0.04 s at 25 mm/s
* x major = 5 / paper_speed (s) e.g. 0.20 s at 25 mm/s
* y minor = 1 / mm_per_mv (mV) e.g. 0.1 mV at 10 mm/mV
* y major = 5 / mm_per_mv (mV) e.g. 0.5 mV at 10 mm/mV
Settings not covered here can be applied after rendering via
``drawing.figure`` / ``drawing.axes`` (standard matplotlib), or passed
as ``**kwargs`` to ``ECGDrawing.__call__`` which forwards them to
``plt.figure()``.
"""
# #### figure output
figsize: tuple[float, float] = (20.0, 10.0)
dpi: int = 600
# #### colours
background_colour: str = '#ffd6d6'
major_grid_colour: str = '#ff6666'
minor_grid_colour: str = '#ffaaaa'
# #### grid line widths
major_grid_linewidth: float = 0.3
minor_grid_linewidth: float = 0.2
# #### trace
trace_colour: str = 'black'
trace_linewidth: float = 0.4
# #### calibration
# x-axis grid spacing is derived from paper_speed (mm/s)
paper_speed: float = 25.0
# y-axis grid spacing is derived from mm_per_mv (mm/mV)
mm_per_mv: float = 10.0
# #### layout
clip_on_trace: bool = False
# Zero spacing mimics continuous ECG paper; increase to separate panels
hspace: float = 0.0
wspace: float = 0.0
# #### axis limits
# None triggers auto-computation from n_samples / fs at render time
x_lim: tuple[float, float] | None = None
y_lim: tuple[float, float] = (-2.0, 2.0)
# #### lead label
label_coordinates: tuple[float, float] = (0.02, 0.95)
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
class ECGDrawing:
    """
    Takes a called reader instance and plots lead-specific ECG signals.

    After calling, the instance exposes the rendered figure and per-lead axes
    so the caller can apply further matplotlib customisation directly.

    Parameters
    ----------
    leads : `list` [`str`] | `None`, default None
        Lead names stored as ``default_leads``. Uses the standard 12-lead
        list derived from ``CoreData.Leads`` when None (or empty).

    Attributes
    ----------
    figure : plt.Figure
        The matplotlib figure.
    axes : dict[str, plt.Axes]
        Dict mapping lead name to its Axes.
    layout_mapper : dict[str, tuple[int, slice]]
        The resolved mapper dict used for this render. Integer column
        indices are stored as one-column slices and an open slice start is
        stored as 0.
    style : ECGStyle
        The ECGStyle instance used.
    minor_step_x : float
        Minor x-axis step in seconds (1 / paper_speed).
    major_step_x : float
        Major x-axis step in seconds (5 / paper_speed).
    minor_step_y : float
        Minor y-axis step in mV (1 / mm_per_mv).
    major_step_y : float
        Major y-axis step in mV (5 / mm_per_mv).
    paper_speed : float
        Resolved paper speed in mm/s.
    mm_per_mv : float
        Resolved voltage scaling in mm/mV.
    default_leads : list[str]
        Lead names set at init. Within this class it is only read by
        ``__str__`` / ``__repr__``; ``__call__`` plots the leads listed in
        its ``mapper`` argument.

    Methods
    -------
    tile_x_axis(xlims)
        Tiles each column's x-axis to an equal slice of the full signal,
        mimicking continuous ECG paper. Per-lead overrides via ``xlims``.
    to_numpy(crop, close)
        Renders the figure to a 3-D numpy array (RGBA).
    """
    # /////////////////////////////////////////////////////////////////////////
    # class attributes
    SOURCE_MAP: dict[str, str] = {
        'waveforms': CoreData.DataTypes.WaveForms,
        'median': CoreData.DataTypes.MedianBeats,
    }
    _VALID_SHOW_AXES: frozenset[str | None] = frozenset({'x', 'y', 'b', None})
    _VALID_CAL_PULSE_STYLES: frozenset[str] = frozenset({'square', 'line'})
    # /////////////////////////////////////////////////////////////////////////
    # __slots__
    # reducing memory overhead and making the public attribute surface explicit.
    __slots__ = (
        PDNames.DEFAULT_LEADS,
        PDNames.FIGURE, PDNames.AXES, PDNames.LAYOUT_MAPPER, PDNames.STYLE,
        PDNames.MINOR_STEP_X, PDNames.MAJOR_STEP_X,
        PDNames.MINOR_STEP_Y, PDNames.MAJOR_STEP_Y,
        PDNames.PAPER_SPEED, PDNames.MM_PER_MV,
    )

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def __init__(
        self,
        leads: list[str] | None = None,
    ) -> None:
        """
        Initialise a new ECGDrawing instance.

        Parameters
        ----------
        leads : `list` [`str`] | `None`, default None
            Default lead names; falls back to the list derived from
            ``CoreData.Leads`` when None or empty.
        """
        # NOTE(review): the filter tests attribute *values*, so namespace
        # strings such as the ``__module__`` value (and a docstring, if
        # present) can survive it; the ``[2:]`` slice is presumably meant to
        # drop those. This depends on the attribute layout of
        # CoreData.Leads — confirm, or filter on attribute names instead.
        lead_default = [
            v for v in vars(CoreData.Leads).values()
            if isinstance(v, str) and not v.startswith("_")
        ][2:]
        # copy so later mutation of the caller's list does not leak in
        setattr(self, PDNames.DEFAULT_LEADS, list(leads or lead_default))
        # every render result stays unset until __call__ succeeds
        for name in (
            PDNames.FIGURE, PDNames.AXES, PDNames.LAYOUT_MAPPER,
            PDNames.STYLE,
            PDNames.MINOR_STEP_X, PDNames.MAJOR_STEP_X,
            PDNames.MINOR_STEP_Y, PDNames.MAJOR_STEP_Y,
            PDNames.PAPER_SPEED, PDNames.MM_PER_MV,
        ):
            setattr(self, name, None)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def __str__(self) -> str:
        """Return human-readable string representation."""
        CLASS_NAME = type(self).__name__
        leads = getattr(self, PDNames.DEFAULT_LEADS)
        return f"{CLASS_NAME} instance with leads={leads}."

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def __repr__(self) -> str:
        """Return unambiguous string representation."""
        CLASS_NAME = type(self).__name__
        leads = getattr(self, PDNames.DEFAULT_LEADS)
        return f"{CLASS_NAME}(leads={leads!r})"

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    @staticmethod
    def _resolve_mapper(
        mapper: dict[str, tuple[int, int | slice]],
    ) -> dict[str, tuple[int, slice]]:
        """
        Validate a layout mapper and normalise its column entries.

        Parameters
        ----------
        mapper : `dict` [`str`, `tuple` [`int`, `int` | `slice`]]
            Lead name to ``(row_index, col)`` mapping, where ``col`` is a
            column slice or a single integer column index.

        Returns
        -------
        resolved : `dict` [`str`, `tuple` [`int`, `slice`]]
            New dict in the same insertion order. Integer columns become
            one-column slices and an open slice start becomes 0, so the
            ``.start`` / ``.stop`` arithmetic used when building and tiling
            the figure always works.

        Raises
        ------
        InputValidationError
            If the mapper is empty or an entry is malformed.
        """
        if not mapper:
            raise InputValidationError(
                "mapper must contain at least one lead."
            )
        resolved: dict[str, tuple[int, slice]] = {}
        for lead, entry in mapper.items():
            try:
                row, col = entry
            except (TypeError, ValueError):
                raise InputValidationError(
                    f"mapper[{lead!r}] must be a (row_index, col_slice) "
                    f"tuple, got {entry!r}."
                ) from None
            if not isinstance(row, (int, np.integer)) or row < 0:
                raise InputValidationError(
                    f"mapper[{lead!r}] row index must be a non-negative "
                    f"int, got {row!r}."
                )
            # a single integer column is shorthand for a one-column slice
            if isinstance(col, (int, np.integer)):
                col = slice(int(col), int(col) + 1)
            if not isinstance(col, slice):
                raise InputValidationError(
                    f"mapper[{lead!r}] column must be an int or a slice, "
                    f"got {col!r}."
                )
            start = 0 if col.start is None else col.start
            stop = col.stop
            if not (
                isinstance(start, (int, np.integer))
                and isinstance(stop, (int, np.integer))
                and 0 <= start < stop
            ):
                raise InputValidationError(
                    f"mapper[{lead!r}] column slice must have an explicit "
                    f"stop and satisfy 0 <= start < stop, got {col!r}."
                )
            resolved[lead] = (int(row), slice(int(start), int(stop)))
        return resolved

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    @staticmethod
    def _major_ticks(lo: float, hi: float, step: float) -> np.ndarray:
        """
        Return the multiples of ``step`` inside ``[lo, hi]`` (inclusive).

        A small tolerance in units of ``step`` absorbs floating-point noise
        so a line lying exactly on an axis limit is not dropped.
        """
        first = np.ceil(lo / step - 1e-6)
        last = np.floor(hi / step + 1e-6)
        return np.arange(first, last + 1) * step

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    @staticmethod
    def _calibration_pulse(
        kind: str,
        sampling_rate: float,
        amplitude: float,
    ) -> np.ndarray:
        """
        Build the calibration pulse samples.

        Parameters
        ----------
        kind : {'square', 'line'}
            ``'square'``: 0.1 s baseline, 0.2 s at ``amplitude``, 0.1 s
            baseline. ``'line'``: 0.1 s baseline, a 3-sample
            rise/peak/fall, 0.1 s baseline.
        sampling_rate : `float`
            Sampling frequency in Hz.
        amplitude : `float`
            Pulse height in signal units.

        Returns
        -------
        pulse : `np.ndarray`
            1-D float array of pulse samples.
        """
        n_edge = round(0.1 * sampling_rate)
        if kind == 'square':
            body = np.full(round(0.2 * sampling_rate), amplitude)
        else:
            body = np.array([0.0, amplitude, 0.0])
        return np.concatenate([np.zeros(n_edge), body, np.zeros(n_edge)])

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def _build_figure(
        self,
        signals: dict[str, np.ndarray],
        mapper: dict[str, tuple[int, slice]],
        style: ECGStyle,
        kwargs_fig: dict[str, Any] | None = None,
        kwargs_gridspec: dict[str, Any] | None = None,
    ) -> tuple[plt.Figure, dict[str, plt.Axes]]:
        """
        Build the GridSpec figure and per-lead axes from a mapper dict.

        Parameters
        ----------
        signals : `dict[str, np.ndarray]`
            Signal data extracted from the reader.
        mapper : `dict[str, tuple[int, slice]]`
            Lead name to ``(row_index, col_slice)`` mapping, as returned by
            ``_resolve_mapper``. Dict insertion order determines plotting
            order. Both n_rows and n_cols are derived from the mapper — no
            assumptions about grid size are made.
        style : `ECGStyle`
            Style configuration.
        kwargs_fig : `dict[str, Any]` | `None`, optional
            Extra keyword arguments for ``plt.figure()``. None is treated
            as an empty dict.
        kwargs_gridspec : `dict[str, Any]` | `None`, optional
            Extra keyword arguments for ``GridSpec()``. None is treated
            as an empty dict.

        Returns
        -------
        figure : `plt.Figure`
            The created figure.
        axes : `dict` [`str`, `plt.Axes`]
            Mapping of lead name to axes.

        Notes
        -----
        If creating the GridSpec or any subplot fails, the new figure is
        closed before the exception propagates.
        """
        # #### derive grid dimensions from mapper
        n_rows = max(row for row, _ in mapper.values()) + 1
        n_cols = max(col.stop for _, col in mapper.values())
        # #### style values are defaults; caller kwargs may overwrite them
        fig_kwargs = _update_kwargs(
            update_dict=kwargs_fig or {},
            figsize=style.figsize, dpi=style.dpi,
        )
        gs_kwargs = _update_kwargs(
            update_dict=kwargs_gridspec or {},
            hspace=style.hspace, wspace=style.wspace,
        )
        fig = plt.figure(**fig_kwargs)
        axes: dict[str, plt.Axes] = {}
        try:
            gs = GridSpec(n_rows, n_cols, figure=fig, **gs_kwargs)
            # #### one axes per mapper entry, in insertion order
            for lead, (row, col) in mapper.items():
                if lead not in signals:
                    _log.warning(
                        "Lead %r not found in signal data — axes left blank.",
                        lead,
                    )
                axes[lead] = fig.add_subplot(gs[row, col])
        except Exception:
            # do not leave an orphaned figure registered with pyplot
            plt.close(fig)
            raise
        # #### warn about signal data keys the mapper did not request
        for lead in signals:
            if lead not in mapper:
                _log.warning(
                    "Lead %r is in signal data but not requested in "
                    "mapper.", lead,
                )
        return fig, axes

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def _render_leads(
        self,
        axes: dict[str, plt.Axes],
        signals: dict[str, np.ndarray],
        style: ECGStyle,
        sampling_rate: float,
        show_axes: dict[str, str | None] | str | None,
        show_lead_labels: bool,
        cal_pulse: str | None,
        cal_leads: list[str] | None,
        kwargs_plot: dict | dict[str, dict] | None,
        kwargs_grid: dict | dict[str, dict] | None,
        kwargs_ticks: dict | dict[str, dict] | None,
        kwargs_background: dict | dict[str, dict] | None,
        kwargs_label: dict | dict[str, dict] | None,
    ) -> None:
        """
        Render each lead: trace, limits, background, grid, and tick visibility.

        Parameters
        ----------
        axes : `dict` [`str`, `plt.Axes`]
            Per-lead axes created by ``_build_figure``.
        signals : `dict[str, np.ndarray]`
            Signal data extracted from the reader.
        style : `ECGStyle`
            Style configuration.
        sampling_rate : `float`
            Sampling frequency in Hz used to build the time axis.
        show_axes : `dict` [`str`, `str` | `None`] | `str` | `None`
            Which axis ticks and labels to show per lead. A string value
            applies uniformly: ``'x'`` x-axis only, ``'y'`` y-axis only,
            ``'b'`` both, ``None`` hides all. A dict maps individual lead
            names to these same string values, leaving unlisted leads hidden.
        show_lead_labels : `bool`
            When True, annotate each panel with its lead name.
        cal_pulse : {'square', 'line'} | `None`
            Calibration pulse style drawn over the start of each selected
            lead, or ``None`` to disable. The pulse height follows the
            clinical 10 mm paper convention and is computed as
            ``10.0 / style.mm_per_mv`` in the signal's own units. With the
            default ``mm_per_mv=10`` this is the standard 1 mV pulse for mV
            signals; for μV signals using e.g. ``mm_per_mv=10/1000`` it is
            1000 (signal units), i.e. a 1 mV-equivalent 10 mm pulse.
            ``'square'`` draws the full rectangular pulse. ``'line'`` draws
            only a single vertical spike, matching the convention used by
            some thermal-paper ECG machines.
        cal_leads : `list` [`str`] | `None`
            Leads receiving the pulse; every lead in ``signals`` when None.
        kwargs_plot : `dict` | `dict` [`str`, `dict`] | `None`
            Extra kwargs for ``ax.plot()``.
        kwargs_grid : `dict` | `dict` [`str`, `dict`] | `None`
            Extra kwargs for the ``ax.vlines()`` / ``ax.hlines()`` calls
            that draw the minor and major grid lines.
        kwargs_ticks : `dict` | `dict` [`str`, `dict`] | `None`
            Extra kwargs for ``ax.tick_params()``.
        kwargs_background : `dict` | `dict` [`str`, `dict`] | `None`
            Extra kwargs for ``ax.patch.set()``.
        kwargs_label : `dict` | `dict` [`str`, `dict`] | `None`
            Extra kwargs for ``ax.text()`` lead annotations.
        """
        minor_x = getattr(self, PDNames.MINOR_STEP_X)
        major_x = getattr(self, PDNames.MAJOR_STEP_X)
        minor_y = getattr(self, PDNames.MINOR_STEP_Y)
        major_y = getattr(self, PDNames.MAJOR_STEP_Y)
        # #### leads that receive the calibration pulse (all when None);
        # names absent from the signal data are simply never matched
        cal_set = set(signals if cal_leads is None else cal_leads)
        # mm_per_mv is used as "mm per signal unit" (the y-axis gain), so
        # 10 / mm_per_mv is the signal amplitude spanning 10 mm of paper
        cal_amp = 10.0 / style.mm_per_mv
        for lead, ax in axes.items():
            if lead not in signals:
                # blank axes already created and warned in _build_figure
                continue
            signal = signals[lead]
            # #### draw the calibration pulse over the start of the trace
            if cal_pulse is not None and lead in cal_set:
                pulse = self._calibration_pulse(
                    cal_pulse, sampling_rate, cal_amp,
                )
                # float copy: the caller's data is not mutated and integer
                # signals do not truncate fractional pulse amplitudes
                signal = np.array(signal, dtype=float)
                # replace (not prepend) so all leads keep the same length
                # and tile_x_axis windows stay aligned; signals shorter than
                # the pulse get a truncated pulse instead of an error
                n_cal = min(len(pulse), len(signal))
                signal[:n_cal] = pulse[:n_cal]
            t = np.arange(len(signal)) / sampling_rate
            # #### resolve per-lead kwargs
            p_kw = _resolve_lead_kwargs(kwargs_plot, lead)
            g_kw = _resolve_lead_kwargs(kwargs_grid, lead)
            t_kw = _resolve_lead_kwargs(kwargs_ticks, lead)
            bg_kw = _resolve_lead_kwargs(kwargs_background, lead)
            # #### trace (zorder 3 keeps it above both grid layers)
            plot_kw = _update_kwargs(
                p_kw,
                color=style.trace_colour,
                lw=style.trace_linewidth,
                clip_on=style.clip_on_trace,
                zorder=3,
            )
            ax.plot(t, signal, **plot_kw)
            # #### axis limits
            if style.x_lim:
                ax.set_xlim(style.x_lim)
            else:
                ax.set_xlim(0.0, len(signal) / sampling_rate)
            ax.set_ylim(style.y_lim)
            # #### background
            ax.patch.set_facecolor(style.background_colour)
            ax.patch.set_visible(True)
            if bg_kw:
                ax.patch.set(**bg_kw)
            # #### grid lines — minor first (behind), major second (in front)
            # ax.grid() does not reliably honour zorder across x/y axes;
            # vlines/hlines produce LineCollections that do.
            ax.xaxis.set_major_locator(MultipleLocator(major_x))
            ax.yaxis.set_major_locator(MultipleLocator(major_y))
            x_lo, x_hi = ax.get_xlim()
            y_lo, y_hi = ax.get_ylim()
            minor_xs = _minor_only_ticks(x_lo, x_hi, minor_x, major_x)
            minor_ys = _minor_only_ticks(y_lo, y_hi, minor_y, major_y)
            major_xs = self._major_ticks(x_lo, x_hi, major_x)
            major_ys = self._major_ticks(y_lo, y_hi, major_y)
            minor_kw = _update_kwargs(
                g_kw,
                color=style.minor_grid_colour,
                linewidth=style.minor_grid_linewidth,
                zorder=1,
            )
            major_kw = _update_kwargs(
                g_kw,
                color=style.major_grid_colour,
                linewidth=style.major_grid_linewidth,
                zorder=2,
            )
            ax.vlines(minor_xs, y_lo, y_hi, **minor_kw)
            ax.hlines(minor_ys, x_lo, x_hi, **minor_kw)
            ax.vlines(major_xs, y_lo, y_hi, **major_kw)
            ax.hlines(major_ys, x_lo, x_hi, **major_kw)
            # #### tick visibility: hide everything, then re-enable per lead
            lead_show = (
                show_axes.get(lead) if isinstance(show_axes, dict)
                else show_axes
            )
            t_params_kwarg = _update_kwargs(
                t_kw,
                axis='both', which='both',
                bottom=False, left=False,
                labelbottom=False, labelleft=False,
            )
            ax.tick_params(**t_params_kwarg)
            if lead_show in ('x', 'b'):
                ax.tick_params(axis='x', bottom=True, labelbottom=True)
            if lead_show in ('y', 'b'):
                ax.tick_params(axis='y', left=True, labelleft=True)
            # #### lead label
            if show_lead_labels:
                lbl_kw = _update_kwargs(
                    _resolve_lead_kwargs(kwargs_label, lead),
                    fontsize=7, fontweight='bold',
                    va='top', ha='left',
                    color=style.trace_colour,
                    clip_on=False,
                )
                ax.text(
                    *style.label_coordinates,
                    lead,
                    transform=ax.transAxes,
                    **lbl_kw,
                )

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def __call__(
        self,
        reader: Any,
        mapper: dict,
        source: str = "waveforms",
        style: ECGStyle | None = None,
        show_axes: dict[str, str | None] | str | None = None,
        sampling_rate: float | None = None,
        show_lead_labels: bool = False,
        cal_pulse: Literal["square", "line"] | None = None,
        cal_leads: list[str] | None = None,
        kwargs_fig: dict | None = None,
        kwargs_gridspec: dict | None = None,
        kwargs_plot: dict | dict[str, dict] | None = None,
        kwargs_grid: dict | dict[str, dict] | None = None,
        kwargs_ticks: dict | dict[str, dict] | None = None,
        kwargs_background: dict | dict[str, dict] | None = None,
        kwargs_label: dict | dict[str, dict] | None = None,
    ) -> Self:
        """
        Creates an ECG figure from a called reader instance.

        All inputs are validated before any figure is created, so a failed
        call never leaves an orphaned pyplot figure behind.

        Parameters
        ----------
        reader : Any
            A called ECGXMLReader or ECGDICOMReader instance.
        mapper : `dict` [`str`, `tuple` [`int`, `slice` | `int`]]
            Layout dict mapping lead names to ``(row_index, col_slice)``.
            A single integer column index is accepted as shorthand for a
            one-column slice. Dict insertion order determines plotting
            order; grid dimensions are derived from the mapper values.
            Mismatch warnings (missing or unrequested leads) are emitted at
            WARNING level via the ``ecgprocess.plot_ecgs`` logger.
        source : {'waveforms', 'median'}, default 'waveforms'
            Whether to plot WaveForms or MedianBeats signals. Can be modified
            to use a non-standard source name by updating self.SOURCE_MAP.
        style : `ECGStyle` | `None`, optional
            Style configuration. Uses defaults when None. ``paper_speed``
            and ``mm_per_mv`` must be finite and positive.
        show_axes : `dict` [`str`, `str` | `None`] | `str` | `None`, optional
            Which axis ticks and labels to show per lead. A string value
            applies uniformly: ``'x'`` x-axis only, ``'y'`` y-axis only,
            ``'b'`` both, ``None`` hides all. A dict maps individual lead
            names to these same string values, leaving unlisted leads hidden.
        sampling_rate : `float` | `None`, optional
            Sampling frequency in Hz. When None, read from
            ``reader.MetaData[CoreData.MetaData.SF]``. Raises
            `InputValidationError` if neither source is available or the
            value is not a finite positive number.
        show_lead_labels : `bool`, default False
            When True, each panel is annotated with its lead name in the
            top-left corner using the default label style.
        cal_pulse : {'square', 'line'} | `None`, default `None`
            Calibration pulse style drawn over the start of each lead
            signal before rendering. ``None`` disables the pulse. The pulse
            replaces the first samples of the trace (it is not prepended),
            so every lead keeps its length and ``tile_x_axis`` windows stay
            aligned; signals shorter than the pulse get a truncated pulse.
            The pulse height follows the clinical 10 mm paper convention
            and is computed as ``10.0 / style.mm_per_mv`` in the signal's
            own units, so it always spans 10 mm of plotted paper regardless
            of unit choice (mV, μV, …). With the default ``mm_per_mv=10``
            this gives the standard 1 mV pulse for mV signals; for μV
            signals using e.g. ``mm_per_mv=10/1000`` it gives 1000 signal
            units (a 1 mV-equivalent 10 mm pulse). ``'square'`` draws the
            full rectangular pulse: 0.1 s baseline, then 0.2 s at the
            computed amplitude, then 0.1 s baseline. ``'line'`` draws only
            a single vertical spike: 0.1 s baseline, then a 3-sample
            rise/peak/fall, then 0.1 s baseline, matching the convention
            used by some thermal-paper ECG machines. The caller's data is
            not mutated.
        cal_leads : `list` [`str`] | `None`, default `None`
            Restricts which leads receive the calibration pulse. When
            `None`, every lead in the signal data receives the pulse
            (provided ``cal_pulse`` is not `None`). Lead names not
            present in the signal data are ignored. Has no effect when
            ``cal_pulse`` is `None`.
        kwargs_fig : `dict` | `None`, optional
            Extra keyword arguments forwarded to ``plt.figure()``. Style
            fields (``figsize``, ``dpi``) are used as defaults and are
            overridden by any matching keys here.
        kwargs_gridspec : `dict` | `None`, optional
            Extra keyword arguments forwarded to ``GridSpec()``. Style
            fields (``hspace``, ``wspace``) are used as defaults and are
            overridden by any matching keys here.
        kwargs_plot : `dict` | `dict` [`str`, `dict`] | `None`, optional
            Extra keyword arguments forwarded to ``ax.plot()``. Pass a flat
            dict to apply uniformly, or a dict of dicts for per-lead control.
        kwargs_grid : `dict` | `dict` [`str`, `dict`] | `None`, optional
            Extra keyword arguments forwarded to the ``ax.vlines()`` /
            ``ax.hlines()`` calls that draw both the minor and the major
            grid lines. Same uniform/per-lead convention as ``kwargs_plot``.
        kwargs_ticks : `dict` | `dict` [`str`, `dict`] | `None`, optional
            Extra keyword arguments forwarded to ``ax.tick_params()``. Same
            uniform/per-lead convention as ``kwargs_plot``.
        kwargs_background : `dict` | `dict` [`str`, `dict`] | `None`, optional
            Extra keyword arguments forwarded to ``ax.patch.set()``. Same
            uniform/per-lead convention as ``kwargs_plot``.
        kwargs_label : `dict` | `dict` [`str`, `dict`] | `None`, optional
            Extra keyword arguments forwarded to ``ax.text()`` for lead
            labels. Only used when ``show_lead_labels=True``. Same
            uniform/per-lead convention as ``kwargs_plot``. Caller-supplied
            values override the defaults (fontsize=7, fontweight='bold',
            va='top', ha='left').

        Returns
        -------
        self : ECGDrawing
            Returns the instance with populated figure, axes, and
            calibration attributes.

        Raises
        ------
        InputValidationError
            If ``source`` is unknown or the reader has no data for it, a
            signal contains NaN, ``style.paper_speed`` / ``style.mm_per_mv``
            are not finite positive numbers, ``show_axes``, ``cal_pulse`` or
            ``mapper`` are invalid, or the sampling rate is unavailable or
            not a finite positive number.
        """
        # #### type-check every parameter at the public boundary
        is_type(mapper, dict)
        is_type(source, str)
        is_type(style, (ECGStyle, type(None)))
        is_type(show_axes, (str, dict, type(None)))
        is_type(sampling_rate, (float, int, type(None)))
        is_type(show_lead_labels, bool)
        is_type(cal_pulse, (str, type(None)))
        is_type(cal_leads, (list, type(None)))
        is_type(kwargs_fig, (dict, type(None)))
        is_type(kwargs_gridspec, (dict, type(None)))
        is_type(kwargs_plot, (dict, type(None)))
        is_type(kwargs_grid, (dict, type(None)))
        is_type(kwargs_ticks, (dict, type(None)))
        is_type(kwargs_background, (dict, type(None)))
        is_type(kwargs_label, (dict, type(None)))
        # #### resolve style
        if style is None:
            style = ECGStyle()
        # #### validate source and extract signals
        if source not in self.SOURCE_MAP:
            # list the actual keys so an extended SOURCE_MAP is reported
            raise InputValidationError(
                f"source must be one of {sorted(self.SOURCE_MAP)!r}, "
                f"got {source!r}."
            )
        signals = getattr(reader, self.SOURCE_MAP[source], None)
        if signals is None or len(signals) == 0:
            raise InputValidationError(
                f"Reader has no data for source={source!r}."
            )
        # #### check for NaN in every lead
        for lead, arr in signals.items():
            if np.any(np.isnan(arr)):
                raise InputValidationError(
                    f"Signal for lead {lead!r} contains NaN values."
                )
        # #### calibration values are divisors for the grid steps below
        for name in ('paper_speed', 'mm_per_mv'):
            value = getattr(style, name)
            if not (np.isfinite(value) and value > 0):
                raise InputValidationError(
                    f"style.{name} must be a finite positive number, "
                    f"got {value!r}."
                )
        # #### validate show_axes content (type already checked above)
        if isinstance(show_axes, str):
            if show_axes not in self._VALID_SHOW_AXES:
                raise InputValidationError(
                    f"show_axes must be 'x', 'y', 'b', or None, "
                    f"got {show_axes!r}."
                )
        elif isinstance(show_axes, dict):
            invalid = {
                v for v in show_axes.values()
                if v not in self._VALID_SHOW_AXES
            }
            if invalid:
                raise InputValidationError(
                    f"show_axes dict contains invalid values: {invalid!r}."
                )
        # #### validate cal_pulse
        if cal_pulse is not None \
                and cal_pulse not in self._VALID_CAL_PULSE_STYLES:
            raise InputValidationError(
                f"cal_pulse must be 'square', 'line', or None, "
                f"got {cal_pulse!r}."
            )
        # #### validate / normalise the layout mapper
        resolved_mapper = self._resolve_mapper(mapper)
        # #### resolve sampling rate, falling back to the reader metadata
        if sampling_rate is None:
            meta = getattr(reader, CoreData.DataTypes.MetaData, None) or {}
            sampling_rate = meta.get(CoreData.MetaData.SF)
        if sampling_rate is None:
            raise InputValidationError(
                "Sampling rate unavailable: pass sampling_rate= or ensure "
                f"MetaData[{CoreData.MetaData.SF!r}] is set."
            )
        sampling_rate = float(sampling_rate)
        if not (np.isfinite(sampling_rate) and sampling_rate > 0):
            raise InputValidationError(
                "sampling_rate must be a finite positive number, "
                f"got {sampling_rate!r}."
            )
        # #### store style and calibration attributes (read by _render_leads)
        setattr(self, PDNames.STYLE, style)
        setattr(self, PDNames.PAPER_SPEED, style.paper_speed)
        setattr(self, PDNames.MM_PER_MV, style.mm_per_mv)
        # grid steps: mm / (mm/s) = s and mm / (mm/mV) = mV
        setattr(self, PDNames.MINOR_STEP_X,
                PDNames.MINOR_GRID_MM / style.paper_speed)
        setattr(self, PDNames.MAJOR_STEP_X,
                PDNames.MAJOR_GRID_MM / style.paper_speed)
        setattr(self, PDNames.MINOR_STEP_Y,
                PDNames.MINOR_GRID_MM / style.mm_per_mv)
        setattr(self, PDNames.MAJOR_STEP_Y,
                PDNames.MAJOR_GRID_MM / style.mm_per_mv)
        # #### build and render; close the figure if rendering fails
        fig, axes = self._build_figure(
            signals, resolved_mapper, style,
            kwargs_fig=kwargs_fig,
            kwargs_gridspec=kwargs_gridspec,
        )
        try:
            self._render_leads(
                axes=axes,
                signals=signals,
                style=style,
                sampling_rate=sampling_rate,
                show_axes=show_axes,
                show_lead_labels=show_lead_labels,
                cal_pulse=cal_pulse,
                cal_leads=cal_leads,
                kwargs_plot=kwargs_plot,
                kwargs_grid=kwargs_grid,
                kwargs_ticks=kwargs_ticks,
                kwargs_background=kwargs_background,
                kwargs_label=kwargs_label,
            )
        except Exception:
            # do not leave a half-drawn figure registered with pyplot
            plt.close(fig)
            raise
        # #### publish the render results only once rendering succeeded
        setattr(self, PDNames.FIGURE, fig)
        setattr(self, PDNames.AXES, axes)
        setattr(self, PDNames.LAYOUT_MAPPER, resolved_mapper)
        return self

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def tile_x_axis(
        self,
        xlims: dict[str, tuple[float, float]] | None = None,
    ) -> Self:
        """
        Tile the x-axis so each column shows an equal slice of the total
        signal duration, mimicking continuous ECG paper.

        The column count is derived from the mapper used at render time.
        Each column's x window runs from
        ``x_min + col * segment`` to ``x_min + (col + 1) * segment``,
        where ``segment = (x_max - x_min) / n_cols``. Multi-column spans
        (e.g. a full-width rhythm strip) receive the proportionally wider
        window. Each row resets to zero independently.

        ``x_min`` / ``x_max`` are taken from the panels that hold a trace;
        blank panels (leads missing from the signal data) were never given
        limits at render time and are ignored.

        Signal data outside a column's x window is removed from the line
        artists rather than clipped. This prevents traces from bleeding
        horizontally into adjacent panels while leaving
        ``clip_on_trace=False`` intact so tall deflections can still cross
        row boundaries. Because data is removed, call this at most once per
        render.

        Parameters
        ----------
        xlims : `dict` [`str`, `tuple` [`float`, `float`]] | `None`, optional
            Per-lead x-axis overrides. Keys are lead names; values are
            ``(x_min, x_max)`` in seconds. Leads not listed fall back to
            the auto-tiled column window.

        Returns
        -------
        self : ECGDrawing
            Returns the instance for method chaining.

        Raises
        ------
        NotCalledError
            If ``__call__`` has not been invoked yet (or the figure was
            closed by ``to_numpy``).
        InputValidationError
            If an ``xlims`` value is not a 2-element tuple or list.
        """
        axes = getattr(self, PDNames.AXES)
        mapper = getattr(self, PDNames.LAYOUT_MAPPER)
        if axes is None or mapper is None:
            raise NotCalledError(
                "ECGDrawing must be called before tile_x_axis()."
            )
        is_type(xlims, (dict, type(None)))
        xlims = xlims or {}
        for lead, lims in xlims.items():
            if not (isinstance(lims, (tuple, list)) and len(lims) == 2):
                raise InputValidationError(
                    f"xlims[{lead!r}] must be an (x_min, x_max) tuple, "
                    f"got {lims!r}."
                )
        # #### derive total x range and column count from rendered state
        # Only panels holding a trace count: blank panels keep matplotlib's
        # default limits and would otherwise collapse every window. The
        # smallest upper limit is used so every tile window has complete
        # signal coverage when leads differ in duration.
        rendered = [ax for ax in axes.values() if ax.lines]
        if not rendered:
            rendered = list(axes.values())
        x_min = min(ax.get_xlim()[0] for ax in rendered)
        x_max = min(ax.get_xlim()[1] for ax in rendered)
        n_cols = max(col.stop for _, col in mapper.values())
        segment = (x_max - x_min) / n_cols
        for lead, (_, col_slice) in mapper.items():
            if lead in xlims:
                col_x_min, col_x_max = xlims[lead]
            else:
                start = col_slice.start or 0
                col_x_min = x_min + start * segment
                col_x_max = x_min + col_slice.stop * segment
            ax = axes[lead]
            # drop samples outside the window instead of clipping them
            for line in ax.lines:
                xd = np.asarray(line.get_xdata())
                yd = np.asarray(line.get_ydata())
                mask = (xd >= col_x_min) & (xd <= col_x_max)
                line.set_xdata(xd[mask])
                line.set_ydata(yd[mask])
            ax.set_xlim(col_x_min, col_x_max)
        return self

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def to_numpy(
        self,
        crop: bool = False,
        close: bool = True,
    ) -> np.ndarray:
        """
        Maps the matplotlib figure to a numpy array.

        Parameters
        ----------
        crop : `bool`, default False
            When True, applies tight layout before conversion, removing
            excess whitespace around the plot area.
        close : `bool`, default True
            Whether to close the figure after extraction. When True the
            ``figure`` and ``axes`` attributes are reset to None.

        Returns
        -------
        array : np.ndarray
            3-dimensional ``uint8`` array of shape (height, width, 4) in
            RGBA format, sized in the canvas's physical pixels.

        Raises
        ------
        NotCalledError
            If __call__ has not been invoked yet.
        """
        is_type(crop, bool)
        is_type(close, bool)
        fig = getattr(self, PDNames.FIGURE)
        if fig is None:
            raise NotCalledError(
                "ECGDrawing must be called before to_numpy()."
            )
        if crop:
            fig.tight_layout(pad=0)
        fig.canvas.draw()
        # buffer_rgba() is already shaped (height, width, 4) in physical
        # pixels; reshaping by get_width_height() (logical pixels) breaks on
        # canvases whose device pixel ratio is not 1. Copy so the array
        # outlives the renderer buffer.
        arr = np.asarray(fig.canvas.buffer_rgba()).copy()
        if close:
            plt.close(fig)
            setattr(self, PDNames.FIGURE, None)
            setattr(self, PDNames.AXES, None)
        return arr