# Source code for physioview.pipeline.SQA

from typing import Literal, Optional, Union
from tqdm import tqdm
from math import ceil
from scipy.interpolate import interp1d
from physioview import physioview
import pandas as pd
import numpy as np
import plotly.graph_objects as go

# Module-level switch read by the inner helpers of
# ``Cardio.correct_interval()``: when True, they print every
# accept / add / insert / average / limit-check decision they make.
DEBUGGING: bool = False

# ============================== CARDIOVASCULAR ==============================
class Cardio:
    """
    A class for signal quality assessment on cardiovascular data, including
    electrocardiograph (ECG) or photoplethysmograph (PPG) data.

    Parameters/Attributes
    ---------------------
    fs : int
        The sampling rate of the cardiovascular data.
    """

    def __init__(self, fs: int):
        """
        Initialize the Cardiovascular object.

        Parameters
        ----------
        fs : int
            The sampling rate of the ECG or PPG recording. Non-integer
            values are truncated with ``int()``.

        Raises
        ------
        ValueError
            If `fs` is not a positive number after conversion to ``int``.
        """
        self.fs = int(fs)
        # Every method steps/divides by `fs` (e.g. `range(0, n, fs)`,
        # `index // fs`), so reject unusable rates here instead of failing
        # later with an unrelated ZeroDivisionError/ValueError.
        if self.fs <= 0:
            raise ValueError(
                f'Sampling rate must be a positive integer; got {fs!r}.')
def compute_metrics(
    self,
    data: pd.DataFrame,
    beats_ix: Optional[np.ndarray] = None,
    artifacts_ix: Optional[np.ndarray] = None,
    ts_col: Optional[str] = None,
    seg_size: int = 60,
    min_hr: float = 40,
    rolling_window: Optional[int] = None,
    rolling_step: int = 15,
    show_progress: bool = True,
    max_hr: float = 220
) -> pd.DataFrame:
    """
    Compute all SQA metrics for cardiovascular data by segment or moving
    window.

    Metrics per segment or moving window include numbers of detected,
    expected, missing, and artifactual beats and percentages of missing
    and artifactual beats.

    Parameters
    ----------
    data : pandas.DataFrame
        A DataFrame containing pre-processed ECG or PPG data.
    beats_ix : array_like, optional
        An array containing the indices of detected beats. Required if
        `data` does not contain a "Beat" column with beat occurrences.
    artifacts_ix : array_like, optional
        An array containing the indices of artifactual beats. Required if
        `data` does not contain an "Artifact" column with artifactual beat
        occurrences.
    ts_col : str, optional
        The name of the column containing timestamps; by default, None.
        If a string value is given, the output will contain a timestamps
        column.
    seg_size : int, optional
        The segment size in seconds; by default, 60.
    min_hr : float, optional
        The minimum heart rate (bpm) against which the number of detected
        beats is considered valid; by default, 40.
    rolling_window : int, optional
        The size, in seconds, of the sliding window across which to
        compute the SQA metrics; by default, None.
    rolling_step : int, optional
        The step size, in seconds, of the sliding windows; by default, 15.
    show_progress : bool, optional
        Whether to display a progress bar while the function runs; by
        default, True.
    max_hr : float, optional
        The maximum heart rate (bpm) against which the number of detected
        beats is considered valid; by default, 220. Like `min_hr`, it is
        scaled to the segment length.

    Returns
    -------
    metrics : pandas.DataFrame
        A DataFrame with all computed SQA metrics per segment or moving
        window. In the segmented approach, an "Invalid" column holds 1 for
        segments whose number of detected beats is below
        ``int(min_hr * seg_size / 60)`` or above ``max_hr * seg_size / 60``
        and NaN otherwise.

    Notes
    -----
    If a value is given in the `rolling_window` parameter, the rolling
    window approach will override the segmented approach, ignoring any
    `seg_size` value. Windows that run past the end of the recording are
    shorter than `rolling_window`; their expected number of beats is
    scaled to the duration they actually cover.

    Examples
    --------
    >>> from physioview.pipeline import SQA
    >>> sqa = SQA.Cardio(fs = 1000)
    >>> artifacts_ix = sqa.identify_artifacts(beats_ix, method = 'cbd')
    >>> cardio_qa = sqa.compute_metrics(ecg, beats_ix, artifacts_ix, \\
    ...                                 ts_col = 'Timestamp', \\
    ...                                 seg_size = 60, min_hr = 40)
    """
    df = data.copy()
    df.index = df.index.astype(int)

    # Ensure a "Beat" column exists
    if 'Beat' not in df.columns:
        df.loc[beats_ix, 'Beat'] = 1

    # Compute IBIs if no "IBI" column
    if 'IBI' not in df.columns:
        ibi = physioview.compute_ibis(df, self.fs, beats_ix, ts_col)
        df['IBI'] = ibi['IBI']

    # Compute SQA metrics across rolling windows
    if rolling_window is not None:
        win_len = rolling_window * self.fs
        step = rolling_step * self.fs
        # Per-second artifact counts, re-aggregated for each window below
        artifacts = self.get_artifacts(
            df, beats_ix, artifacts_ix, seg_size = 1, ts_col = ts_col)
        results = []
        last_valid_hr = np.nan

        for s, start in enumerate(
                tqdm(range(0, len(df), step), disable = not show_progress),
                start = 1):
            window = df.iloc[start:start + win_len]
            # Duration actually covered; trailing windows can be shorter
            # than `rolling_window`.
            window_sec = len(window) / self.fs

            # Expected HR, carrying the last valid estimate forward when
            # this window yields none
            median_hrs = self._window_medians(window)
            if median_hrs:
                exp_hr = float(np.nanmedian(median_hrs))
                last_valid_hr = exp_hr
            else:
                exp_hr = last_valid_hr

            if np.isnan(exp_hr):
                n_expected = np.nan
            else:
                n_expected = int(round(exp_hr * (window_sec / 60.0)))

            n_detected = window['Beat'].notna().sum()

            # Missing beats; a zero expectation has no defined percentage
            if np.isnan(n_expected):
                n_missing = np.nan
                perc_missing = np.nan
            else:
                n_missing = max(0, n_expected - n_detected)
                perc_missing = (round((n_missing / n_expected) * 100, 2)
                                if n_expected > 0 else np.nan)

            # Artifactual beats in this window
            start_sec = start // self.fs
            window_artifact = artifacts.iloc[
                start_sec:start_sec + rolling_window]
            n_artifact = window_artifact['N Artifact'].sum()
            perc_artifact = np.nan if n_detected == 0 \
                else round((n_artifact / n_detected) * 100, 2)

            row = {'Moving Window': s}
            if ts_col is not None and ts_col in window.columns:
                row['Timestamp'] = window[ts_col].iloc[0]
            row.update({
                'N Expected': n_expected,
                'N Detected': n_detected,
                'N Missing': n_missing,
                '% Missing': perc_missing,
                'N Artifact': n_artifact,
                '% Artifact': perc_artifact,
            })
            results.append(row)

        return pd.DataFrame(results)

    # Compute SQA metrics across non-overlapping segments
    missing = self.get_missing(
        df, beats_ix, artifacts_ix, seg_size, ts_col = ts_col)
    artifacts = self.get_artifacts(
        df, beats_ix, artifacts_ix, seg_size, ts_col)
    keys = ['Segment'] if ts_col is None else ['Segment', 'Timestamp']
    metrics = pd.merge(missing, artifacts, on = keys)

    # Plausible beat-count bounds: bpm limits scaled to the segment length
    min_beats = int(min_hr * (seg_size / 60))
    max_beats = max_hr * (seg_size / 60)
    metrics['Invalid'] = metrics['N Detected'].apply(
        lambda x: 1 if x < min_beats or x > max_beats else np.nan)
    return metrics
def get_artifacts(
    self,
    data: pd.DataFrame,
    beats_ix: np.ndarray,
    artifacts_ix: np.ndarray,
    seg_size: int = 60,
    ts_col: Optional[str] = None
) -> pd.DataFrame:
    """
    Summarize the number and proportion of artifactual beats per segment.

    Parameters
    ----------
    data : pandas.DataFrame
        A DataFrame containing the pre-processed ECG or PPG data.
    beats_ix : array_like
        An array containing the indices of detected beats. Only used if
        `data` has no "Beat" column.
    artifacts_ix : array_like
        An array containing the indices of artifactual beats. This is
        outputted from `SQA.Cardio.identify_artifacts()`. Only used if
        `data` has no "Artifact" column.
    seg_size : int
        The size of the segment in seconds; by default, 60.
    ts_col : str, optional
        The name of the column containing timestamps; by default, None.
        If a string value is given, the output will contain a timestamps
        column holding the first non-null timestamp of each segment.

    Returns
    -------
    artifacts : pandas.DataFrame
        A DataFrame with columns "Segment", optionally "Timestamp",
        "N Artifact" and "% Artifact" (artifacts as a percentage of
        detected beats in the segment, rounded to two decimals).

    Notes
    -----
    Segments are numbered ``index // (fs * seg_size) + 1``, i.e. from the
    sample index itself, which is the same numbering used by
    `SQA.Cardio.get_missing()`.

    See Also
    --------
    SQA.Cardio.identify_artifacts :
        Identify artifactual beats using both or either of the methods.
    """
    df = data.copy()
    if 'Beat' not in df.columns:
        df.loc[beats_ix, 'Beat'] = 1
    if 'Artifact' not in df.columns:
        df.loc[artifacts_ix, 'Artifact'] = 1

    # Group samples into segments once and reuse the grouping
    grouped = df.groupby(df.index // (self.fs * seg_size))

    n_detected = grouped['Beat'].sum().fillna(0).astype(int)
    n_artifact = grouped['Artifact'].sum().fillna(0).astype(int)
    perc_artifact = ((n_artifact / n_detected) * 100).round(2)

    # Segment numbers come from the group keys (not a positional counter)
    # so they stay aligned with the data even if its index is not 0-based.
    artifacts = pd.DataFrame(
        {'Segment': n_artifact.index.to_numpy() + 1},
        index = n_artifact.index)
    if ts_col is not None:
        artifacts['Timestamp'] = grouped[ts_col].first()
    artifacts['N Artifact'] = n_artifact
    artifacts['% Artifact'] = perc_artifact
    return artifacts.reset_index(drop = True)
def identify_artifacts(
    self,
    beats_ix: np.ndarray,
    method: Literal['hegarty', 'cbd', 'both'],
    initial_hr: Union[float, Literal['auto'], None] = None,
    prev_n: Optional[int] = None,
    neighbors: Optional[int] = None,
    tol: Optional[float] = None
) -> np.ndarray:
    """
    Identify locations of artifactual beats in cardiovascular data based
    on the criterion beat difference approach by Berntson et al. (1990),
    the Hegarty-Craver et al. (2018) approach, or both.

    Parameters
    ----------
    beats_ix : array_like
        An array containing the indices of detected beats.
    method : {'hegarty', 'cbd', 'both'}
        The artifact identification method for identifying artifacts.
        This must be 'hegarty', 'cbd', or 'both'.
    initial_hr : float, or 'auto', optional
        The heart rate value for the first interbeat interval (IBI) to be
        validated against; by default, 'auto' for automatic calculation
        using the mean IBI of the six consecutive IBIs with the smallest
        average successive difference. Required for the 'hegarty' method.
    prev_n : int, optional
        The number of preceding IBIs to validate against; by default, 6.
        Required for 'hegarty' method.
    neighbors : int, optional
        The number of surrounding IBIs with which to derive the criterion
        beat difference score; by default, 5. Required for 'cbd' method.
    tol : float, optional
        A configurable hyperparameter used to fine-tune the stringency of
        the criterion beat difference test; by default, 1. Required for
        'cbd' method.

    Returns
    -------
    artifacts_ix : numpy.ndarray
        Sorted, de-duplicated indices of identified artifact beats, with
        the same dtype as `beats_ix`. Empty if fewer than two beats are
        given.

    Raises
    ------
    ValueError
        If `method` is not 'hegarty', 'cbd', or 'both'.

    Notes
    -----
    The source code for the criterion beat difference test is from work
    by Hoemann et al. (2020).

    References
    ----------
    Berntson, G., Quigley, K., Jang, J., Boysen, S. (1990). An approach to
    artifact identification: Application to heart period data.
    Psychophysiology, 27(5), 586–598.

    Hegarty-Craver, M. et al. (2018). Automated respiratory sinus
    arrhythmia measurement: Demonstration using executive function
    assessment. Behavior Research Methods, 50, 1816–1823.

    Hoemann, K. et al. (2020). Context-aware experience sampling reveals
    the scale of variation in affective experience. Scientific
    Reports, 10(1), 1–16.
    """
    if method not in ('hegarty', 'cbd', 'both'):
        raise ValueError(
            'Invalid method. Method must be \'hegarty\', \'cbd\', '
            'or \'both\'.')

    # Accept any array_like; positional fancy indexing below needs ndarray
    beats_ix = np.asarray(beats_ix)
    if beats_ix.size < 2:
        # No interbeat interval exists, so nothing can be flagged
        return np.array([], dtype = beats_ix.dtype)

    def identify_artifacts_hegarty(
        beats_ix: np.ndarray,
        initial_hr: Union[float, Literal['auto']] = 'auto',
        prev_n: int = 6
    ) -> tuple[np.ndarray, np.ndarray]:
        """Identify locations of artifactual beats in cardiovascular data
        based on the approach by Hegarty-Craver et al. (2018).

        Returns a ``(valid_beats, artifact_beats)`` tuple."""
        ibis = (np.diff(beats_ix) / self.fs) * 1000
        beats = beats_ix[1:]  # drop the first beat
        artifact_beats = []
        valid_beats = [beats_ix[0]]  # assume first beat is valid

        # Set the initial IBI to compare against
        if initial_hr == 'auto':
            successive_diff = np.abs(np.diff(ibis))
            if successive_diff.size == 0:
                # Only one IBI: it is its own best estimate
                first_ibi = ibis.mean()
            else:
                min_diff_ix = np.convolve(
                    successive_diff, np.ones(6) / 6, mode = 'valid').argmin()
                first_ibi = ibis[min_diff_ix:min_diff_ix + 6].mean()
        else:
            first_ibi = 60000 / initial_hr

        for n in range(len(ibis)):
            current_ibi = ibis[n]
            current_beat = beats[n]

            # Check against an estimate of the first N IBIs
            if n < prev_n:
                if n == 0:
                    ibi_estimate = first_ibi
                else:
                    next_five = np.insert(ibis[:n], 0, first_ibi)
                    ibi_estimate = np.median(next_five)
            # Check against an estimate of the preceding N IBIs
            else:
                ibi_estimate = np.median(ibis[n - prev_n:n])

            # Set the acceptable/valid range of IBIs
            low = (26 / 32) * ibi_estimate
            high = (44 / 32) * ibi_estimate

            if low <= current_ibi <= high:
                valid_beats.append(current_beat)
            else:
                artifact_beats.append(current_beat)

        return np.array(valid_beats), np.array(artifact_beats)

    def identify_artifacts_cbd(
        beats_ix: np.ndarray,
        neighbors: int = 5,
        tol: float = 1
    ) -> np.ndarray:
        """Identify locations of abnormal interbeat intervals (IBIs) using
        the criterion beat difference test by Berntson et al. (1990)."""
        # Derive IBIs (ms) from beat indices
        ibis = ((np.ediff1d(beats_ix)) / self.fs) * 1000
        # Compute consecutive absolute differences across IBIs
        ibi_diffs = np.abs(np.ediff1d(ibis))
        artifact_beats = []

        # Flag IBIs that are implausible = below a 40 bpm threshold
        min_ibi = 60000 / 40
        invalid_ix = np.where(ibis > min_ibi)[0]
        if len(invalid_ix) > 0:
            artifact_beats.extend(beats_ix[invalid_ix + 1])

        if len(ibi_diffs) < neighbors:
            neighbors = len(ibi_diffs)

        for ii in range(len(ibi_diffs)):
            # If there are not enough neighbors in the beginning
            if ii < int(neighbors / 2) + 1:
                select = np.concatenate(
                    (ibi_diffs[:ii], ibi_diffs[(ii + 1):(neighbors + 1)]))
                select_ibi = np.concatenate(
                    (ibis[:ii], ibis[(ii + 1):(neighbors + 1)]))
            # If there are not enough neighbors at the end
            elif (len(ibi_diffs) - ii) < (int(neighbors / 2) + 1) and (
                    len(ibi_diffs) - ii) > 1:
                select = np.concatenate(
                    (ibi_diffs[-(neighbors - 1):ii], ibi_diffs[ii + 1:]))
                select_ibi = np.concatenate(
                    (ibis[-(neighbors - 1):ii], ibis[ii + 1:]))
            # If there is only one neighbor left to check against
            elif len(ibi_diffs) - ii == 1:
                select = ibi_diffs[-(neighbors - 1):-1]
                select_ibi = ibis[-(neighbors - 1):-1]
            else:
                select = np.concatenate(
                    (ibi_diffs[ii - int(neighbors / 2):ii],
                     ibi_diffs[(ii + 1):(ii + 1 + int(neighbors / 2))]))
                select_ibi = np.concatenate(
                    (ibis[ii - int(neighbors / 2):ii],
                     ibis[(ii + 1):(ii + 1 + int(neighbors / 2))]))

            # Calculate the quartile deviation
            QD = self._quartile_deviation(select)
            # Calculate the maximum expected difference (MED)
            MED = 3.32 * QD
            # Calculate the minimal artifact difference (MAD)
            MAD = (np.median(select_ibi) - 2.9 * QD) / 3
            # Calculate the criterion beat difference score
            criterion_beat_diff = (MED + MAD) / 2

            # Flag beats following an IBI difference that fails the test
            if (ibi_diffs[ii]) > tol * criterion_beat_diff:
                bad_neighbors = int(neighbors * 0.25)
                if ii + (bad_neighbors - 1) < len(beats_ix):
                    artifact_beats.extend(
                        beats_ix[ii + 1:(ii + bad_neighbors + 1)])
                else:
                    artifact_beats.extend(
                        beats_ix[ii + 1:(ii + (bad_neighbors - 1))])

        return np.array(artifact_beats).flatten()

    if method == 'hegarty':
        initial_hr = initial_hr if initial_hr is not None else 'auto'
        prev_n = prev_n if prev_n is not None else 6
        _, artifacts_ix = identify_artifacts_hegarty(
            beats_ix, initial_hr, prev_n)
    elif method == 'cbd':
        neighbors = neighbors if neighbors is not None else 5
        tol = tol if tol is not None else 1
        artifacts_ix = identify_artifacts_cbd(beats_ix, neighbors, tol)
    else:  # 'both'
        initial_hr = initial_hr if initial_hr is not None else 'auto'
        prev_n = prev_n if prev_n is not None else 6
        neighbors = neighbors if neighbors is not None else 5
        tol = tol if tol is not None else 1
        _, artifact_hegarty = identify_artifacts_hegarty(
            beats_ix, initial_hr, prev_n)
        artifact_cbd = identify_artifacts_cbd(beats_ix, neighbors, tol)
        artifacts_ix = np.union1d(artifact_hegarty, artifact_cbd)

    # Same contract for every method: sorted, unique, beat-index dtype
    # (an empty list would otherwise come back as float64).
    return np.unique(np.asarray(artifacts_ix)).astype(
        beats_ix.dtype, copy = False)
def get_missing(
    self,
    data: pd.DataFrame,
    beats_ix: Optional[np.ndarray] = None,
    artifacts_ix: Optional[np.ndarray] = None,
    seg_size: int = 60,
    ts_col: Optional[str] = None,
) -> pd.DataFrame:
    """
    Summarize the number and proportion of missing beats per segment.

    Parameters
    ----------
    data : pandas.DataFrame
        The DataFrame containing the pre-processed ECG or PPG data.
    beats_ix : array_like, optional
        An array containing the indices of detected beats. Required if
        `data` doesn't contain a "Beat" column.
    artifacts_ix : array_like, optional
        An array containing the indices of artifactual beats. Required if
        `data` doesn't contain an "Artifact" column.
    seg_size : int, optional
        The size of the segment in seconds; by default, 60.
    ts_col : str, optional
        The name of the column containing timestamps; by default, None.
        If a string value is given (and present in `data`), the output
        will contain a timestamps column.

    Returns
    -------
    missing : pandas.DataFrame
        A DataFrame with columns "Segment", optionally "Timestamp",
        "N Detected", "N Expected", "N Missing" and "% Missing".

    Notes
    -----
    The expected HR of a segment comes from ``self._window_medians`` on
    that segment. If that yields nothing, the last half of the previous
    segment and then the first half of the next segment are tried. If it
    still cannot be estimated, the last valid estimate is carried forward.
    Leading segments without an estimate are back-filled from the first
    one that has one. The final (possibly partial) segment's expected
    count is scaled by its actual duration.
    """
    data = data.copy()

    # Ensure a "Segment" column exists
    if 'Segment' not in data.columns:
        data.insert(0, 'Segment', data.index // (self.fs * seg_size) + 1)

    # Ensure "Beat" and "Artifact" columns exist
    if 'Beat' not in data.columns:
        data.loc[beats_ix, 'Beat'] = 1
    if 'Artifact' not in data.columns:
        data.loc[artifacts_ix, 'Artifact'] = 1

    # Compute IBIs if no "IBI" column
    if 'IBI' not in data.columns:
        ibi = physioview.compute_ibis(data, self.fs, beats_ix, ts_col)
        data['IBI'] = ibi['IBI']

    # Half a segment measured in SAMPLES (`seg_size` is in seconds)
    half_seg = int(seg_size * 0.5 * self.fs)

    def _expected_hr(seg: int, seg_nums: np.ndarray) -> float:
        """Estimate the expected HR for a segment with adjacent segment
        fallback."""
        segment = data.loc[data.Segment == seg]
        median_hrs = self._window_medians(segment)

        # Check the last 50% of the previous segment
        if not median_hrs and (seg - 1) in seg_nums:
            prev = data.loc[data.Segment == seg - 1]
            median_hrs = self._window_medians(prev.tail(half_seg))

        # Check the first 50% of the next segment
        if not median_hrs and (seg + 1) in seg_nums:
            nxt = data.loc[data.Segment == seg + 1]
            median_hrs = self._window_medians(nxt.head(half_seg))

        return float(np.nanmedian(median_hrs)) if median_hrs else np.nan

    with_ts = ts_col is not None and ts_col in data.columns
    seg_nums = data.Segment.unique()
    results = []
    last_valid_hr = np.nan

    for seg in seg_nums:
        seg_data = data.loc[data.Segment == seg]
        exp_hr = _expected_hr(seg, seg_nums)

        # Detected beats
        n_detected = int(seg_data['Beat'].notna().sum())

        # If exp_hr cannot be estimated, fall back on the last valid one
        if np.isnan(exp_hr) and not np.isnan(last_valid_hr):
            exp_hr = last_valid_hr
        elif not np.isnan(exp_hr):
            last_valid_hr = exp_hr

        # Calculate expected number of beats in this segment
        if np.isnan(exp_hr):
            n_expected = np.nan
        else:
            n_expected = int(round(exp_hr * (seg_size / 60)))
            # Rescale n_expected for the last partial segment
            if seg == seg_nums[-1]:
                factor = (len(seg_data) / self.fs) / seg_size
                n_expected = int(round(n_expected * factor))

        row = {'Segment': seg}
        if with_ts:
            row['Timestamp'] = seg_data[ts_col].iloc[0]
        row.update({'N Detected': n_detected, 'N Expected': n_expected})
        results.append(row)

    if not results:
        columns = ['Segment'] + (['Timestamp'] if with_ts else []) + [
            'N Detected', 'N Expected', 'N Missing', '% Missing']
        return pd.DataFrame(columns = columns)

    missing = pd.DataFrame(results)

    # Backfill for any un-estimable leading segments
    first_valid = missing['N Expected'].first_valid_index()
    if first_valid is not None:
        missing.loc[:first_valid, 'N Expected'] = \
            missing.loc[first_valid, 'N Expected']

    # Missing counts are derived only after backfilling. The nullable
    # division below yields NaN instead of raising when a (tiny trailing)
    # segment expects 0 beats.
    missing['N Expected'] = missing['N Expected'].astype('Int64')
    missing['N Missing'] = (
        missing['N Expected'] - missing['N Detected']).clip(lower = 0)
    missing['% Missing'] = (
        (missing['N Missing'] / missing['N Expected']) * 100).round(2)
    return missing
def get_seconds(
    self,
    data: pd.DataFrame,
    beats_ix: np.ndarray,
    ts_col: Optional[str] = None,
    show_progress: bool = True
) -> pd.DataFrame:
    """Get instantaneous (second-by-second) HR, IBI, and beat counts from
    ECG or PPG data according to the approach by Graham (1978).

    Parameters
    ----------
    data : pandas.DataFrame
        The DataFrame containing the pre-processed ECG or PPG data.
    beats_ix : array-like
        An array containing the indices of detected beats.
    ts_col : str, optional
        The name of the column containing timestamps; by default, None.
        If a string value is given, the output will contain a timestamps
        column holding the first timestamp of each second.
    show_progress : bool, optional
        Whether to display a progress bar while the function runs; by
        default, True.

    Returns
    -------
    interval_data : pandas.DataFrame
        One row per (possibly partial) second, with columns "Second"
        (1-based), optionally "Timestamp", "Mean HR", "Mean IBI" (ms) and
        "N Beats" (beats falling inside that second).

    Notes
    -----
    For each second, IBIs are taken from all detected beats inside an
    evaluation window made of the previous second (when there is one) and
    the current second. "Mean IBI" is their arithmetic mean. "Mean HR" is
    the reciprocal of the mean reciprocal per-IBI heart rate. Both are
    `NaN` when the window holds fewer than two beats, i.e. no IBI can be
    formed.

    NOTE(review): the window does not include the following second;
    confirm that this matches the intended Graham (1978) procedure.

    References
    ----------
    Graham, F. K. (1978). Constraints on measuring heart rate and period
    sequentially through real and cardiac time. Psychophysiology, 15(5),
    492–495.
    """
    frame = data.copy()
    marker = '_temp_beat'
    frame.index = frame.index.astype(int)
    frame.loc[beats_ix, marker] = 1

    rows = []
    second_starts = tqdm(range(0, len(frame), self.fs),
                         disable = not show_progress)
    for second, offset in enumerate(second_starts, start = 1):
        this_second = frame.iloc[offset:offset + self.fs]
        # Previous second (clamped at the start) plus the current second
        context = frame.iloc[max(offset - self.fs, 0):offset + self.fs]

        n_beats = int((this_second[marker] == 1).sum())
        context_beats = context.index.values[
            (context[marker] == 1).to_numpy()]
        gaps_ms = np.diff(context_beats) / self.fs * 1000

        if gaps_ms.size:
            mean_ibi = np.mean(gaps_ms)
            mean_hr = 1 / np.mean(1 / (60000 / gaps_ms))
        else:
            mean_ibi = np.nan
            mean_hr = np.nan

        row = {'Second': second}
        if ts_col is not None:
            row['Timestamp'] = this_second.iloc[0][ts_col]
        row['Mean HR'] = mean_hr
        row['Mean IBI'] = mean_ibi
        row['N Beats'] = n_beats
        rows.append(row)

    return pd.DataFrame(rows)
[docs] def correct_interval( self, beats_ix: np.ndarray, initial_hr: Union[float, Literal['auto']] = 'auto', prev_n: int = 6, min_bpm: int = 40, max_bpm: int = 200, hr_estimate_window: int = 6, print_estimated_hr: bool = True, short_threshold: float = (24 / 32), long_threshold: float = (44 / 32), extra_threshold: float = (52 / 32) ) -> tuple[np.ndarray, pd.DataFrame, pd.DataFrame, pd.DataFrame]: ''' Correct artifactual beats in cardiovascular data based on the approach by Hegarty-Craver et al. (2018). Parameters ---------- beats_ix : array_like An array containing the indices of detected beats. initial_hr : float or {'auto'}, optional The heart rate value for the first interbeat interval (IBI) to be validated against; by default, 'auto' (i.e., the value is determined automatically). prev_n : int, optional The number of preceding IBIs to validate against; by default, 6. min_bpm : int, optional The minimum possible heart rate in beats per minute (bpm); by default, 40. max_bpm : int, optional The maximum possible heart rate in beats per minute (bpm); by default, 200. hr_estimate_window : int, optional The window size for estimating the heart rate; by default, 6. print_estimated_hr : bool, optional Whether to print the estimated heart rate; by default, True. short_threshold : float, optional The threshold for short IBIs; by default, 24/32. long_threshold : float, optional The threshold for long IBIs; by default, 44/32. extra_threshold : float, optional The threshold for extra long IBIs; by default, 52/32. Returns ------- beats_ix_corrected: array_like An array containing the indices of corrected beats. corrected_ibis : array_like An array containing the indices of corrected IBIs. original: pandas.DataFrame A DataFrame containing the original IBIs (millisecond-based and index-based) and beat indices. corrected: pandas.DataFrame A DataFrame containing the corrected IBIs (millisecond-based and index-based) and beat indices. References ---------- Hegarty-Craver, M. et al. 
(2018). Automated respiratory sinus arrhythmia measurement: Demonstration using executive function assessment. Behavioral Research Methods, 50, 1816–1823. ''' global MIN_BPM, MAX_BPM MIN_BPM = min_bpm MAX_BPM = max_bpm ibis = np.diff(beats_ix) # drop the first beat beats = beats_ix[1:] global cnt, corrected_ibis, corrected_beats, corrected_flags global prev_ibi, prev_beat, prev_flag, current_ibi, current_beat, \ current_flag, correction_flags # increment when correcting the ibi and decrement when accepting the ibi cnt = 0 # initialize prev_ibi = 0 prev_beat = 0 prev_flag = None current_ibi = 0 current_beat = 0 current_flag = None corrected_ibis = [] corrected_beats = [] corrected_flags = [] correction_flags = [0 for i in range(len(beats))] # Set the initial IBI to compare against global prev_ibis_fifo, first_ibi, correction_failed if initial_hr == 'auto': successive_diff = np.abs(np.diff(ibis)) min_diff_ix = np.convolve( successive_diff, np.ones(hr_estimate_window) / hr_estimate_window, mode = 'valid').argmin() first_ibi = ibis[min_diff_ix:min_diff_ix + hr_estimate_window].mean() if print_estimated_hr: print('Estimated average HR (bpm): ', np.floor(60 / (first_ibi / self.fs))) else: first_ibi = self.fs * 60 / initial_hr # FIFO for the previous n+1 IBIs prev_ibis_fifo = self._MaxNFifo(prev_n, first_ibi) # Store whether the correction failed for the last n IBIs correction_failed = self._MaxNFifo(prev_n - 1) def _estimate_ibi(prev_ibis: np.ndarray) -> int: ''' Estimate IBI based on the previous IBIs. Parameters ---------- prev_ibis: array_like A list of prev_n number of previous IBIs. Returns ------- estimated_ibi : int ''' return np.median(prev_ibis) def _return_flag( current_ibi: int, prev_ibis: Optional[np.ndarray] = None ) -> str: ''' Return whether the current IBI is correct, short, long, or extra long based on the previous IBIs. 
Correct: 26/32 - 44/32 of the estimated IBI Short: < 26/32 of the estimated IBI Long: > 44/32 and < 54/32 of the estimated IBI Extra Long: > 54/32 of the estimated IBI Parameters ---------- current_ibi: int current IBI value in the number of indices. prev_ibis: array_like, optional A list of prev_n number of previous IBIs. Returns ------- flag : str The flag of the current IBI: 'Correct', 'Short', 'Long', or 'Extra Long'. ''' # Calculate the estimated IBI estimated_ibi = _estimate_ibi(prev_ibis) # Set the acceptable/valid range of IBIs low = short_threshold * estimated_ibi high = long_threshold * estimated_ibi extra = extra_threshold * estimated_ibi # Flag the ibi: correct, short, long, or extra long if low <= current_ibi <= high: flag = 'Correct' elif current_ibi < low: flag = 'Short' elif current_ibi > high and current_ibi < extra: flag = 'Long' else: flag = 'Extra Long' return flag def _acceptance_check( corrected_ibi: int, prev_ibis: np.ndarray ) -> bool: ''' Check if the corrected IBI is acceptable (falls within 27/32 to 42/32 of the estimated IBI). Parameters ---------- corrected_ibi: int The corrected IBI value. prev_ibis: array_like A list of prev_n number of previous IBIs. Returns ------- bool True if the corrected IBI is within the acceptable range, False otherwise. ''' # Calculate the estimated IBI estimated_ibi = _estimate_ibi(prev_ibis) # Set the acceptable/valid range of IBIs low = short_threshold * estimated_ibi high = long_threshold * estimated_ibi # If the corrected value is within the range, return True if corrected_ibi >= low and corrected_ibi <= high: #if corrected_ibi <= high: return True else: return False def _accept_ibi(n: int, correction_failed_flag: int = 0) -> None: ''' Accept the current IBI without correction. Parameters ---------- n : int The index of the current IBI. correction_failed_flag : int, optional Flag to indicate whether the correction failed for the current IBI; by default, 0. 
If the flag is 1, the correction failed for the current IBI. ''' global prev_ibis_fifo, cnt, correction_failed global corrected_ibis, corrected_beats, corrected_flags global prev_ibi, prev_beat, prev_flag, current_ibi, current_beat, current_flag # Check if previous IBI is within limits before accepting current IBI _check_limits(n) # Fix the previous IBI corrected_ibis.append(prev_ibi) corrected_beats.append(prev_beat) corrected_flags.append(prev_flag) # Add the previous IBI to the queue prev_ibis_fifo.push(prev_ibi) # Update the previous IBI to the current IBI prev_ibi = current_ibi prev_beat = current_beat prev_flag = current_flag # Decrement the counter cnt = max(0, cnt-1) if DEBUGGING: print('accepted:', current_ibi, ' flag:', current_flag, ' based on ', prev_ibis_fifo.get_queue()[1:]) # If the correction failed for the current IBI, push 1 to the # correction_failed FIFO, otherwise push 0 if correction_failed_flag == 0: correction_failed.push(0) else: correction_failed.push(1) def _add_prev_and_current(n: int) -> None: ''' Add the previous and current IBIs if the sum is less than 42/32 of the estimated IBI. Parameters ---------- n : int The index of the current IBI. 
''' global prev_ibis_fifo, cnt global corrected_ibis, corrected_beats, corrected_flags global prev_ibi, prev_beat, prev_flag, current_ibi, current_beat, \ current_flag, correction_flags # Add the previous and current IBIs corrected_ibi = prev_ibi + current_ibi # Check if the corrected IBI is acceptable if _acceptance_check(corrected_ibi, prev_ibis_fifo.get_queue()[1:]): # Update the current IBI to the corrected IBI current_ibi = corrected_ibi current_beat = current_beat current_flag = _return_flag(current_ibi, prev_ibis_fifo.get_queue()[1:]) if n == 1: # Update the previous IBI to the current IBI prev_ibi = current_ibi prev_beat = current_beat prev_flag = current_flag else: # Pull up the second previous IBI as previous IBI prev_ibi = corrected_ibis[-1] prev_beat = corrected_beats[-1] prev_flag = corrected_flags[-1] # Check if the previous IBI is within the limits before # accepting the current IBI _check_limits(n) # Check_limits function may update the previous IBI pulled, # so update the value corrected_ibis[-1] = prev_ibi corrected_beats[-1] = prev_beat corrected_flags[-1] = prev_flag # Update the last IBI value in the queue prev_ibis_fifo.change_last(prev_ibi) # Update the previous IBI to the current IBI prev_ibi = current_ibi prev_beat = current_beat prev_flag = current_flag # Flag that previous and current IBIs are corrected correction_flags[n-1] = 1 correction_flags[n] = 1 # Increment the counter cnt += 1 if DEBUGGING: print('added:', current_ibi, ' flag:', current_flag, ' based on ', prev_ibis_fifo.get_queue()[1:]) else: if DEBUGGING: print('acceptance check failed for adding: ', corrected_ibi) # If the corrected IBI is not acceptable, accept the current IBI _accept_ibi(n, correction_failed_flag = 1) def _add_secondprev_and_prev(n: int) -> None: ''' Add the second previous and previous IBIs if the sum is less than 42/32 of the estimated IBI. Parameters ---------- n : int The index of the current IBI. 
''' global prev_ibis_fifo, cnt global corrected_ibis, corrected_beats, corrected_flags global prev_ibi, prev_beat, prev_flag, current_ibi, current_beat, \ current_flag, correction_flags # Add the previous and current IBIs corrected_ibi = corrected_ibis[-1] + prev_ibi # Check if the corrected IBI is acceptable # Use IBIs before the second previous IBI if _acceptance_check(corrected_ibi, prev_ibis_fifo.get_queue()[:-2]): # Update the current IBI to the corrected IBI # Pull up the second previous IBI as previous IBI prev_ibi = corrected_ibi prev_beat = prev_beat prev_flag = _return_flag(prev_ibi, prev_ibis_fifo.get_queue()[:-2]) # Check if the previous IBI is within the limits before accepting the current IBI _check_limits(n) # Update the value corrected_ibis[-1] = prev_ibi corrected_beats[-1] = prev_beat corrected_flags[-1] = prev_flag # Update the last IBI value in the queue prev_ibis_fifo.change_last(prev_ibi) # Update the previous IBI to the current IBI prev_ibi = current_ibi prev_beat = current_beat prev_flag = current_flag # Flag that previous and current IBIs are corrected correction_flags[n-2] = 1 correction_flags[n-1] = 1 # Increment the counter cnt += 1 if DEBUGGING: print('added second prev + prev:', prev_ibi, ' flag:', prev_flag, ' based on ', prev_ibis_fifo.get_queue()[:-2]) else: if DEBUGGING: print('acceptance check failed for adding second prev + prev: ', corrected_ibi) # If the corrected IBI is not acceptable, accept the current IBI _accept_ibi(n, correction_failed_flag = 1) def _insert_interval(n: int) -> None: ''' Split the (previous IBI + current IBI) into multiple intervals. The number of splits is determined based on the initial_hr parameter. Parameters ---------- n : int The index of the current IBI. 
''' global prev_ibis_fifo, cnt, first_ibi global corrected_ibis, corrected_beats, corrected_flags global prev_ibi, prev_beat, prev_flag, current_ibi, current_beat, \ current_flag, correction_flags # Calculate the number of splits n_split = round((prev_ibi + current_ibi) / _estimate_ibi( prev_ibis_fifo.get_queue()[1:]), 0).astype(int) # Calculate the new IBI ibi = np.floor((prev_ibi + current_ibi) / n_split) # Check if the corrected IBI is acceptable if _acceptance_check(ibi, prev_ibis_fifo.get_queue()[1:]): # Fix inserted IBIs other than previous/current IBIs for i in range(n_split - 2): corrected_ibis.append(ibi) corrected_flags.append(_return_flag(ibi, prev_ibis_fifo.get_queue()[1:])) if (n == 1 and i == 0) | (len(corrected_beats) == 0): corrected_beats.append(beats_ix[0] + ibi) else: corrected_beats.append(corrected_beats[-1] + ibi) # Add to the queue prev_ibis_fifo.push(ibi) # Update the previous IBI prev_ibi = ibi if len(corrected_beats) > 0: prev_beat = corrected_beats[-1] + ibi else: prev_beat = beats_ix[0] + ibi prev_flag = _return_flag(ibi, prev_ibis_fifo.get_queue()[:-1]) # Update the current IBI current_ibi = current_beat - prev_beat current_flag = _return_flag(ibi, prev_ibis_fifo.get_queue()[1:]) # Check if the previous IBI is within the limits _check_limits(n) # Fix the previous IBI corrected_ibis.append(prev_ibi) corrected_beats.append(prev_beat) corrected_flags.append(prev_flag) # Add to the queue prev_ibis_fifo.push(prev_ibi) # Update the previous IBI to the current IBI prev_ibi = current_ibi prev_beat = current_beat prev_flag = current_flag # Flag that previous and current IBIs are corrected correction_flags[n-1] = 1 correction_flags[n] = 1 # Increment the counter by n_split - 1 in this case cnt += n_split - 1 if DEBUGGING: print('inserted ', n_split - 2, ' intervals: ', ibi, ' flag:', current_flag, ' based on ', prev_ibis_fifo.get_queue()[1:]) else: if DEBUGGING: print('acceptance check failed for inserting: ', ibi) # If the corrected IBI is not 
acceptable, accept the current IBI _accept_ibi(n, correction_failed_flag = 1) def _average_prev_and_current(n: int) -> None: ''' Average the previous and current IBIs. Parameters ---------- n : int The index of the current IBI. ''' global prev_ibis_fifo, cnt global corrected_ibis, corrected_beats, corrected_flags global prev_ibi, prev_beat, prev_flag, current_ibi, current_beat, current_flag, correction_flags # Average the previous and current IBIs ibi = np.floor((prev_ibi + current_ibi) / 2) # Check if the corrected IBI is acceptable if _acceptance_check(ibi, prev_ibis_fifo.get_queue()[1:]): # Update the previous and current IBI prev_ibi = ibi if n == 1: prev_beat = beats_ix[0] + ibi else: prev_beat = corrected_beats[-1] + ibi prev_flag = _return_flag(ibi, prev_ibis_fifo.get_queue()[:-1]) current_ibi = current_beat - prev_beat current_flag = _return_flag(ibi, prev_ibis_fifo.get_queue()[1:]) # Check if the previous IBI is within the limits _check_limits(n) # Fix the previous IBI corrected_ibis.append(prev_ibi) corrected_beats.append(prev_beat) corrected_flags.append(prev_flag) # Add to the queue prev_ibis_fifo.push(prev_ibi) # Update the previous IBI to the current IBI prev_ibi = current_ibi prev_beat = current_beat prev_flag = current_flag # Flag that previous and current IBIs are corrected correction_flags[n-1] = 1 correction_flags[n] = 1 # Increment the counter cnt += 1 if DEBUGGING: print('averaged:', ibi, ' flag:', current_flag, ' based on ', prev_ibis_fifo.get_queue()[1:]) else: if DEBUGGING: print('acceptance check failed for averaging: ', ibi) _accept_ibi(n, correction_failed_flag=1) def _check_limits(n): ''' Check if the previous IBI (n-1) is within the limits. If it is longer the maximum IBI, shorten the previous IBI and lengthen the current IBI. If it is shorter than the minimum IBI, lengthen the previous IBI and shorten the current IBI. Parameters --------------------- n : int The index of the current IBI. 
''' global prev_ibis_fifo, cnt global corrected_ibis, corrected_beats, corrected_flags global prev_ibi, prev_beat, prev_flag, current_ibi, current_beat, \ current_flag, correction_flags MIN_IBI = np.floor(self.fs * 60 / MAX_BPM) # minimum IBI in indices MAX_IBI = np.floor(self.fs * 60 / MIN_BPM) # maximum IBI in indices # If the previous IBI is shorter than the minimum IBI, lengthen the previous IBI and shorten the current IBI if prev_ibi < MIN_IBI: remainder = MIN_IBI - prev_ibi prev_beat = prev_beat + remainder prev_ibi = MIN_IBI prev_flag = _return_flag(prev_ibi, prev_ibis_fifo.get_queue()[:-1]) current_ibi = current_ibi - remainder current_flag = _return_flag(current_ibi, prev_ibis_fifo.get_queue()[1:]) # Flag that previous and current IBIs are corrected correction_flags[n-1] = 1 correction_flags[n] = 1 # Increment the counter cnt += 1 if DEBUGGING: print('Shorter than the minimum IBI and corrected: ', prev_ibi, ' ', prev_flag, ' | ', current_ibi, ' ', current_flag) # If the previous IBI is longer than the maximum IBI, shorten the previous IBI and lengthen the current IBI elif prev_ibi > MAX_IBI: remainder = prev_ibi - MAX_IBI prev_beat = prev_beat - remainder prev_ibi = MAX_IBI prev_flag = _return_flag(prev_ibi, prev_ibis_fifo.get_queue()[:-1]) current_ibi = current_ibi + remainder current_flag = _return_flag(current_ibi, prev_ibis_fifo.get_queue()[1:]) # Flag that previous and current IBIs are corrected correction_flags[n-1] = 1 correction_flags[n] = 1 # Increment the counter cnt += 1 if DEBUGGING: print('Longer than the maximum IBI and corrected: ', prev_ibi, ' ', prev_flag, ' | ', current_ibi, ' ', current_flag) return for n in range(len(ibis)): current_ibi = ibis[n] current_beat = beats[n] # Accept the first ibi if n == 0: current_flag = _return_flag(current_ibi, prev_ibis = prev_ibis_fifo.get_queue()) # Update the previous IBI to the current IBI prev_ibi = current_ibi prev_beat = current_beat prev_flag = current_flag else: current_flag = 
_return_flag(current_ibi, prev_ibis = prev_ibis_fifo.get_queue()[:-1]) if DEBUGGING: print('n:', n) print('prev:', prev_ibi, ' ', prev_flag, ' | current:', current_ibi, ' ', current_flag) # If current IBI is correct if current_flag == 'Correct': # If previous IBI is correct/long, accept current if prev_flag == 'Correct' or prev_flag == 'Long': _accept_ibi(n) elif prev_flag == 'Short': if n == 1: _add_prev_and_current(n) else: # If previous IBI is shorter than current IBI, add them together if corrected_ibis[-1] > current_ibi: _add_prev_and_current(n) else: _add_secondprev_and_prev(n) # If previous IBI is extra long, split previous and current elif prev_flag == 'Extra Long': _insert_interval(n) # If current IBI is short elif current_flag == 'Short': # If previous IBI is correct, accept it if prev_flag == 'Correct': _accept_ibi(n) # If previous IBI is short, add previous + current elif prev_flag == 'Short': _add_prev_and_current(n) # If previous IBI is long/extra long, average previous and current elif prev_flag == 'Long' or prev_flag == 'Extra Long': _average_prev_and_current(n) # If the current IBI is long elif current_flag == 'Long': # If previous IBI is correct or long, accept it if prev_flag == 'Correct' or prev_flag == 'Long': _accept_ibi(n) # If previous IBI is short, average previous and current elif prev_flag == 'Short': _average_prev_and_current(n) # If previous IBI is extra long, split previous and current elif prev_flag == 'Extra Long': _insert_interval(n) # If current IBI is extra long elif current_flag == 'Extra Long': # If previous IBI is correct, long, or extra long, split previous and current if prev_flag == 'Correct' or prev_flag == 'Long' or prev_flag == 'Extra Long': _insert_interval(n) # If previous IBI is short, average previous and current elif prev_flag == 'Short': _average_prev_and_current(n) # If more than 3 corrections are made in the last prev_n IBIs, reset the FIFO if sum(correction_failed.get_queue()) >= 3: 
prev_ibis_fifo.reset(first_ibi) # Add the last beat corrected_ibis.append(current_ibi) corrected_beats.append(current_beat) corrected_flags.append(current_flag) correction_flags = np.array(correction_flags).astype(int) # Convert the IBIs to milliseconds original_ibis_ms = np.round((np.array(ibis) / self.fs) * 1000, 2) original = pd.DataFrame({ 'Original IBI (ms)': np.insert(original_ibis_ms, 0, np.nan), 'Original IBI (index)': np.insert(ibis.astype(object), 0, np.nan), 'Original Beat': np.insert(beats, 0, beats_ix[0]), 'Correction': np.insert(correction_flags, 0, 0) }) corrected_ibis_ms = np.round((np.array(corrected_ibis) / self.fs) * 1000, 2) corrected_ibis = np.array(corrected_ibis).astype(object) corrected_flags = np.array(corrected_flags).astype(object) # Add the first beat and create a dataframe corrected = pd.DataFrame({ 'Corrected IBI (ms)': np.insert(corrected_ibis_ms, 0, np.nan), 'Corrected IBI (index)': np.insert(corrected_ibis, 0, np.nan), 'Corrected Beat': np.insert(corrected_beats, 0, beats_ix[0]), 'Flag': np.insert(corrected_flags, 0, np.nan) }) beats_ix_corrected = np.insert(corrected_beats, 0, beats_ix[0]).astype(int) return beats_ix_corrected, corrected_ibis, original, corrected
def get_corrected(
    self,
    beats_ix: np.ndarray,
    seg_size: int = 60,
    initial_hr: Union[float, Literal['auto']] = 'auto',
    prev_n: int = 6,
    min_bpm: int = 40,
    max_bpm: int = 200,
    hr_estimate_window: int = 6,
    print_estimated_hr: bool = True,
    short_threshold: float = (24 / 32),
    long_threshold: float = (44 / 32),
    extra_threshold: float = (52 / 32)
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Get the corrected interbeat intervals (IBIs) and beat indices, with a
    per-segment summary of flags and corrections.

    Parameters
    ----------
    beats_ix : array_like
        An array containing the indices of detected beats.
    seg_size : int
        The size of the segment in seconds; by default, 60.
    initial_hr : float or {'auto'}, optional
        The heart rate value for the first interbeat interval (IBI) to be
        validated against; by default, 'auto' (i.e., the value is
        determined automatically).
    prev_n : int, optional
        The number of preceding IBIs to validate against; by default, 6.
    min_bpm : int, optional
        The minimum possible heart rate in beats per minute (bpm); by
        default, 40.
    max_bpm : int, optional
        The maximum possible heart rate in beats per minute (bpm); by
        default, 200.
    hr_estimate_window : int, optional
        The window size for estimating the heart rate; by default, 6.
    print_estimated_hr : bool, optional
        Whether to print the estimated heart rate; by default, True.
    short_threshold : float, optional
        The threshold for short IBIs; by default, 24/32.
    long_threshold : float, optional
        The threshold for long IBIs; by default, 44/32.
    extra_threshold : float, optional
        The threshold for extra long IBIs; by default, 52/32.

    Returns
    -------
    original : pandas.DataFrame
        The original IBIs and beat indices, with an added 'Segment' column.
    corrected : pandas.DataFrame
        The corrected IBIs and beat indices, with an added 'Segment' column.
    combined : pandas.DataFrame
        Per segment: one count column per IBI flag (e.g., 'Correct',
        'Short', 'Long', 'Extra Long'), '# Corrected', and '% Corrected'.
    """
    # `correct_interval` returns (beats_ix_corrected, corrected_ibis,
    # original, corrected); only the two trailing DataFrames are needed.
    # Star-unpacking takes the last two regardless of how many come first.
    *_, original, corrected = self.correct_interval(
        beats_ix = beats_ix,
        initial_hr = initial_hr,
        prev_n = prev_n,
        min_bpm = min_bpm,
        max_bpm = max_bpm,
        hr_estimate_window = hr_estimate_window,
        print_estimated_hr = print_estimated_hr,
        short_threshold = short_threshold,
        long_threshold = long_threshold,
        extra_threshold = extra_threshold)

    # Segment number of each beat is ceil(beat index / samples per segment).
    # Note that a beat at sample 0 therefore falls into segment 0.
    seg_len = seg_size * self.fs
    original['Segment'] = np.ceil(
        original['Original Beat'].astype(float) / seg_len
    ).astype(pd.Int64Dtype())
    corrected['Segment'] = np.ceil(
        corrected['Corrected Beat'].astype(float) / seg_len
    ).astype(pd.Int64Dtype())

    # Number and percentage of corrected beats in each segment
    corrections = original.groupby('Segment')['Correction']
    original_seg = pd.DataFrame({
        '# Corrected': corrections.sum().astype(pd.Int64Dtype()),
        '# Beats': corrections.count().astype(pd.Int64Dtype()),
    }).reset_index()
    original_seg['% Corrected'] = round(
        (original_seg['# Corrected'] / original_seg['# Beats']) * 100, 2)
    original_seg.drop('# Beats', axis = 1, inplace = True)

    # Number of each flag (Correct/Short/Long/Extra Long) in each segment;
    # flags that never occur in a segment become 0 after the pivot.
    corrected_seg = corrected.groupby('Segment')['Flag'].value_counts() \
        .astype(pd.Int64Dtype())
    corrected_seg = pd.DataFrame(corrected_seg.reset_index(name = 'Count'))
    corrected_seg = corrected_seg.pivot(
        index = 'Segment', columns = 'Flag', values = 'Count'
    ).reset_index().fillna(0)
    corrected_seg.columns.name = None
    corrected_seg = corrected_seg.rename_axis(None, axis = 1)

    combined = pd.merge(corrected_seg, original_seg, on = 'Segment')
    return original, corrected, combined
def plot_missing(
    self,
    sqa_metrics: pd.DataFrame,
    invalid_thresh: Union[int, float] = 30,
    title: Optional[str] = None
) -> go.Figure:
    """
    Plot detected and missing beat counts.

    Parameters
    ----------
    sqa_metrics : pandas.DataFrame
        The DataFrame containing SQA metrics per segment; must contain
        'Segment', 'N Detected', 'N Expected', and 'N Missing' columns.
    invalid_thresh : int, float
        The minimum number of beats detected for a segment to be
        considered valid; segments below it are marked with a red '!';
        by default, 30.
    title : str, optional
        The title of the plot.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        A Plotly bar chart of detected and missing beat counts.
    """
    # The 'Missing' bar is drawn at the *expected* count behind the
    # 'Detected' bar (barmode='overlay'), so the y-axis must fit the taller
    # of the two. Sizing it from detected counts alone clipped the missing
    # portion whenever beats were missing in the busiest segment.
    peak = sqa_metrics[['N Detected', 'N Expected']].max().max()
    # Round up to a multiple of 10 (at least 10, so the axis and its five
    # ticks are never zero-width).
    max_beats = max(ceil(peak / 10) * 10, 10)
    dtick_value = max_beats / 5

    fig = go.Figure(
        data = [
            go.Bar(
                x = sqa_metrics['Segment'],
                y = sqa_metrics['N Expected'],
                name = 'Missing',
                marker = dict(color = '#f2816d'),
                customdata = sqa_metrics['N Missing'],
                hovertemplate = '<b>Segment %{x}:</b> %{customdata:.0f} '
                                'missing<extra></extra>'),
            go.Bar(
                x = sqa_metrics['Segment'],
                y = sqa_metrics['N Detected'],
                name = 'Detected',
                marker = dict(color = '#313c42'),
                hovertemplate = '<b>Segment %{x}:</b> %{y:.0f} '
                                'detected<extra></extra>')
        ]
    )

    # Segments with too few detected beats get a red '!' just above the bar
    flagged = [
        (segment_num, n_detected)
        for segment_num, n_detected in zip(
            sqa_metrics['Segment'], sqa_metrics['N Detected'])
        if n_detected < invalid_thresh
    ]
    if flagged:
        fig.add_trace(go.Scatter(
            x = [segment_num for segment_num, _ in flagged],
            y = [n_detected + 3 for _, n_detected in flagged],
            mode = 'text',
            text = ['<b>!</b>'] * len(flagged),
            textposition = 'top center',
            textfont = dict(size = 20, color = '#db0f0f'),
            showlegend = False,
            hoverinfo = 'skip'  # disable tooltips
        ))
        fig.add_annotation(
            text = '<span style="color: #db0f0f"><b>!</b></span> '
                   'Invalid Number of Beats ',
            align = 'right',
            showarrow = False,
            xref = 'paper',
            yref = 'paper',
            x = 1,
            y = 1.3)

    fig.update_layout(
        xaxis_title = 'Segment Number',
        xaxis = dict(
            tickmode = 'linear',
            dtick = 1,
            range = [sqa_metrics['Segment'].min() - 0.5,
                     sqa_metrics['Segment'].max() + 0.5]),
        yaxis = dict(
            title = 'Number of Beats',
            range = [0, max_beats],
            dtick = dtick_value),
        legend = dict(
            orientation = 'h',
            yanchor = 'bottom',
            y = 1.0,
            xanchor = 'right',
            x = 1.0),
        font = dict(family = 'Poppins', size = 13),
        height = 289,
        margin = dict(t = 70, r = 20, l = 40, b = 65),
        barmode = 'overlay',
        template = 'simple_white',
    )
    if title is not None:
        fig.update_layout(title = title)
    return fig
def plot_artifact(
    self,
    sqa_metrics: pd.DataFrame,
    invalid_thresh: int = 30,
    title: Optional[str] = None
) -> go.Figure:
    """
    Plot detected and artifactual beat counts.

    Parameters
    ----------
    sqa_metrics : pandas.DataFrame
        The DataFrame containing SQA metrics per segment; must contain
        'Segment', 'N Detected', and 'N Artifact' columns.
    invalid_thresh : int, float
        The minimum number of beats detected for a segment to be
        considered valid; segments below it are marked with a red '!';
        by default, 30.
    title : str, optional
        The title of the plot.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        A Plotly bar chart of detected and artifactual beat counts.
    """
    # Fit the y-axis to the taller of the two overlaid bar series, rounded
    # up to a multiple of 10 (at least 10, so the axis and its five ticks
    # are never zero-width).
    peak = sqa_metrics[['N Detected', 'N Artifact']].max().max()
    max_beats = max(ceil(peak / 10) * 10, 10)
    dtick_value = max_beats / 5

    fig = go.Figure(
        data = [
            go.Bar(
                x = sqa_metrics['Segment'],
                y = sqa_metrics['N Detected'],
                name = 'Detected',
                marker = dict(color = '#313c42'),
                hovertemplate = '<b>Segment %{x}:</b> %{y:.0f} '
                                'detected<extra></extra>'),
            go.Bar(
                x = sqa_metrics['Segment'],
                y = sqa_metrics['N Artifact'],
                name = 'Artifact',
                marker = dict(color = '#f2b463'),
                hovertemplate = '<b>Segment %{x}:</b> %{y:.0f} '
                                'artifact<extra></extra>')
        ],
    )

    # Segments with too few detected beats get a red '!' just above the bar
    flagged = [
        (segment_num, n_detected)
        for segment_num, n_detected in zip(
            sqa_metrics['Segment'], sqa_metrics['N Detected'])
        if n_detected < invalid_thresh
    ]
    if flagged:
        fig.add_trace(go.Scatter(
            x = [segment_num for segment_num, _ in flagged],
            y = [n_detected + 3 for _, n_detected in flagged],
            mode = 'text',
            text = ['<b>!</b>'] * len(flagged),
            textposition = 'top center',
            textfont = dict(size = 20, color = '#db0f0f'),
            showlegend = False,
            hoverinfo = 'skip'  # disable tooltips
        ))
        fig.add_annotation(
            text = '<span style="color: #db0f0f"><b>!</b></span> '
                   'Invalid Number of Beats ',
            align = 'right',
            showarrow = False,
            xref = 'paper',
            yref = 'paper',
            x = 1,
            y = 1.3)

    fig.update_layout(
        xaxis_title = 'Segment Number',
        xaxis = dict(
            tickmode = 'linear',
            dtick = 1,
            range = [sqa_metrics['Segment'].min() - 0.5,
                     sqa_metrics['Segment'].max() + 0.5]),
        yaxis = dict(
            title = 'Number of Beats',
            range = [0, max_beats],
            dtick = dtick_value),
        legend = dict(
            orientation = 'h',
            yanchor = 'bottom',
            y = 1.0,
            xanchor = 'right',
            x = 1.0,
            traceorder = 'reversed'),
        font = dict(family = 'Poppins', size = 13),
        height = 289,
        margin = dict(t = 70, r = 20, l = 40, b = 65),
        barmode = 'overlay',
        template = 'simple_white',
    )
    if title is not None:
        fig.update_layout(title = title)
    return fig
def _get_iqr(self, data: np.ndarray) -> float: """Compute the interquartile range of a data array.""" q75, q25 = np.percentile(data, [75, 25]) iqr = q75 - q25 return iqr def _quartile_deviation(self, data: np.ndarray) -> float: """Compute the quartile deviation in the criterion beat difference test.""" iqr = self._get_iqr(data) QD = iqr * 0.5 return QD def _window_medians(self, segment: pd.DataFrame, win_size: int = 5) -> list: """Calculate median HRs from artifact-free windows in a segment slice for get_missing().""" median_hrs = [] beats = segment.dropna(subset = ['Beat']) n = len(beats) for i in range(n - win_size + 1): window = beats.iloc[i:i + win_size] if window.Artifact.any(): continue ibi_vals = window.IBI.values med_hr = np.nanmedian(60000 / ibi_vals) median_hrs.append(med_hr) return median_hrs class _MaxNFifo: """ A class for FIFO with N elements at maximum. Parameters/Attributes --------------------- prev_n : int The maximum number of elements in the FIFO. item : int, optional The initial item to add to the FIFO; by default, None. The item is added twice if it is not None. """ def __init__(self, prev_n: int, item: Optional[int] = None): """ Initialize the FIFO object. Parameters ---------- prev_n : int The maximum number of elements in the FIFO. item : int, optional The initial item to add to the FIFO; by default, None. The item is added twice if it is not None. """ self.prev_n = prev_n if item is not None: self.queue = [item, item] else: self.queue = [] def push(self, item: int) -> None: """ Push an item to the FIFO. If the number of elements exceeds the maximum, remove the first element. Parameters ---------- item : int The item to add to the FIFO. """ self.queue.append(item) if len(self.queue) > self.prev_n + 1: self.queue.pop(0) def get_queue(self) -> list: """ Return the FIFO queue. Return ------ queue : list """ return self.queue def change_last(self, item: int) -> None: """ Change the last item in the FIFO queue. 
Parameters ---------- item : int The new item to replace the last item in the queue. """ self.queue[-1] = item def reset(self, item: Optional[int] = None) -> None: """ Reset the FIFO queue. If an item is given, reset the queue with the item. If not, reset the queue with an empty list. Parameters ---------- item: int, optional The item to add to the FIFO; by default, None. The item is added twice if it is not None. """ if item is None: self.queue = [] else: self.queue = [item, item]
# =================================== EDA ====================================
class EDA:
    """
    A class for signal quality assessment on electrodermal activity (EDA)
    data.

    Parameters/Attributes
    ---------------------
    fs : int
        The sampling rate of the EDA data.
    eda_min : float, optional
        The minimum acceptable value for EDA data in microsiemens; by
        default, 0.2 uS.
    eda_max : float, optional
        The maximum acceptable value for EDA data in microsiemens; by
        default, 40 uS.
    eda_max_slope : float, optional
        The maximum slope of EDA data in microsiemens per second; by
        default, 5 uS/sec.
    temp_min : float, optional
        The minimum acceptable temperature in degrees Celsius; by
        default, 20.
    temp_max : float, optional
        The maximum acceptable temperature in degrees Celsius; by
        default, 40.
    invalid_spread_dur : float, optional
        The transition radius for artifacts in seconds; by default, 2.5.
    """

    def __init__(
        self,
        fs: int,
        eda_min: float = 0.2,
        eda_max: float = 40,
        eda_max_slope: float = 5,
        temp_min: float = 20,
        temp_max: float = 40,
        invalid_spread_dur: float = 2.5
    ):
        """
        Initialize the EDA object.

        Parameters
        ----------
        fs : int
            The sampling rate of the EDA recording.
        eda_min : float, optional
            The minimum acceptable value for EDA data in microsiemens; by
            default, 0.2 uS.
        eda_max : float, optional
            The maximum acceptable value for EDA data in microsiemens; by
            default, 40 uS.
        eda_max_slope : float, optional
            The maximum slope of EDA data in microsiemens per second; by
            default, 5 uS/sec.
        temp_min : float, optional
            The minimum acceptable temperature in degrees Celsius; by
            default, 20.
        temp_max : float, optional
            The maximum acceptable temperature in degrees Celsius; by
            default, 40.
        invalid_spread_dur : float, optional
            The transition radius for artifacts in seconds; by default,
            2.5 seconds.

        Raises
        ------
        ValueError
            If `eda_min >= eda_max` or `temp_min >= temp_max`.
        """
        # Reject empty or inverted acceptance ranges up front
        if eda_min >= eda_max:
            raise ValueError('`eda_min` must be smaller than `eda_max`.')
        if temp_min >= temp_max:
            raise ValueError('`temp_min` must be smaller than `temp_max`.')
        self.fs = fs
        self.eda_min = eda_min
        self.eda_max = eda_max
        self.eda_max_slope = eda_max_slope
        self.temp_min = temp_min
        self.temp_max = temp_max
        self.invalid_spread_dur = invalid_spread_dur
def get_validity_metrics(
    self,
    signal: np.ndarray,
    temp: Optional[np.ndarray] = None,
    timestamps: Optional[np.ndarray] = None,
    preprocessed: bool = True,
) -> pd.DataFrame:
    """
    Assess and flag valid and invalid EDA data points.

    Parameters
    ----------
    signal : array_like
        An array containing the EDA signal in microsiemens.
    temp : array_like, optional
        An array containing temperature data in Celsius. If its length
        differs from `signal`, it is resampled with `_equalize_temp`.
    timestamps : array_like, optional
        An array of timestamps corresponding to each data point.
    preprocessed : bool, optional
        Whether filtered EDA data is being inputted; by default, True.
        If False, an FIR low-pass filter is applied before assessment
        (the 'EDA' output column still holds the input signal).

    Returns
    -------
    eda_validity : pd.DataFrame
        A DataFrame with the columns:
        - 'Timestamp' (the given timestamps, or 0-based sample indices)
        - 'EDA'
        - 'TEMP' (if `temp` is provided)
        - 'Valid' (1 if valid, NaN otherwise)
        - 'Invalid' (1 if invalid, NaN otherwise)
    """
    valid_ix, invalid_ix, _ = self._edaqa(signal, temp, preprocessed)

    eda_validity = pd.DataFrame({
        'Timestamp': (timestamps if timestamps is not None
                      else np.arange(len(signal))),
        'EDA': signal,
    })
    if temp is not None:
        # `_edaqa` tolerates a temperature series of a different length by
        # equalizing it; do the same here so the column assignment fits.
        if len(temp) != len(signal):
            temp = self._equalize_temp(signal, temp)
        eda_validity['TEMP'] = temp

    # Create both flag columns up front so they always exist (all NaN)
    # even when one of the index arrays is empty.
    eda_validity['Valid'] = np.nan
    eda_validity['Invalid'] = np.nan
    eda_validity.loc[valid_ix, 'Valid'] = 1
    eda_validity.loc[invalid_ix, 'Invalid'] = 1
    return eda_validity
def get_quality_metrics(
    self,
    signal: np.ndarray,
    temp: Optional[np.ndarray] = None,
    timestamps: Optional[np.ndarray] = None,
) -> pd.DataFrame:
    """
    Assess and flag rule violations of EDA quality based on the quality
    assessment procedure by Kleckner et al. (2017).

    Parameters
    ----------
    signal : array_like
        The EDA signal in microsiemens.
    temp : array_like, optional
        Temperature data in Celsius. If its length differs from `signal`,
        it is resampled with `_equalize_temp`.
    timestamps : array_like, optional
        Array of timestamps corresponding to each data point.

    Returns
    -------
    eda_quality : pd.DataFrame
        A DataFrame with the columns:
        - 'Timestamp' (if `timestamps` is provided), otherwise 'Sample'
          (1-based sample numbers)
        - 'EDA'
        - 'Out of Range' (1 where violated, NaN otherwise)
        - 'Excessive Slope' (1 where violated, NaN otherwise)
        - 'TEMP' (if `temp` is provided)
        - 'Temp Out of Range' (if `temp` is provided; 1/NaN)

    References
    ----------
    Kleckner, I.R., Jones, R. M., Wilder-Smith, O., Wormwood, J.B.,
    Akcakaya, M., Quigley, K.S., ... & Goodwin, M.S. (2017). Simple,
    transparent, and flexible automated quality assessment procedures for
    ambulatory electrodermal activity data. IEEE Transactions on
    Biomedical Engineering, 65(7), 1460-1467.
    """
    # Match `_edaqa`: a temperature series of a different length is
    # equalized to the EDA signal instead of failing on column assignment.
    if temp is not None and len(temp) != len(signal):
        temp = self._equalize_temp(signal, temp)

    sampling_interval = 1 / self.fs

    # Rule-specific masks
    mask_out_of_range = self._check_out_of_range(signal)
    mask_excessive_slope = self._check_excessive_slope(
        signal, sampling_interval)
    mask_temp = (self._check_temp_out_of_range(temp)
                 if temp is not None else None)

    eda_quality = pd.DataFrame({
        'EDA': signal,
        'Out of Range': np.where(mask_out_of_range, 1, np.nan),
        'Excessive Slope': np.where(mask_excessive_slope, 1, np.nan),
    })
    if temp is not None:
        eda_quality['TEMP'] = temp
    if timestamps is not None:
        eda_quality.insert(0, 'Timestamp', timestamps)
    else:
        eda_quality.insert(0, 'Sample', np.arange(len(signal)) + 1)
    if mask_temp is not None:
        eda_quality['Temp Out of Range'] = np.where(mask_temp, 1, np.nan)
    return eda_quality
def compute_metrics(
    self,
    signal: np.ndarray,
    temp: Optional[np.ndarray] = None,
    preprocessed: bool = True,
    peaks_ix: Optional[np.ndarray] = None,
    seg_size: int = 60,
    rolling_window: Optional[int] = None,
    rolling_step: int = 15,
    show_progress: bool = True,
) -> pd.DataFrame:
    """
    Assess the quality of electrodermal activity (EDA) data using the
    rules defined by Kleckner et al. (2017).

    The method identifies valid and invalid data points and computes
    rule-specific quality metrics (e.g., proportions of out-of-range
    points, excessive slopes, temperature violations, and spread-invalid
    counts), either by segment or across sliding windows.

    Parameters
    ----------
    signal : array_like
        An array containing the EDA signal in microsiemens.
    temp : array_like, optional
        An optional array containing temperature data in Celsius; by
        default, None.
    preprocessed : boolean, optional
        Whether filtered EDA data is being inputted; by default, True.
    peaks_ix : array_like, optional
        An optional array containing locations of SCR peaks; by default,
        None. If provided, an 'N SCRs' metric is included in the output.
    seg_size : int
        The segment size in seconds; by default, 60. Only complete
        segments are assessed; trailing samples that do not fill a whole
        segment are ignored.
    rolling_window : int, optional
        The size, in seconds, of the sliding window across which to
        compute the EDA SQA metrics; by default, None. When given, it
        overrides the segmented approach.
    rolling_step : int, optional
        The step size, in seconds, of the sliding windows; by default, 15.
    show_progress : bool, optional
        Whether to show a progress bar for the sliding-window approach;
        by default, True.

    Returns
    -------
    metrics : pd.DataFrame
        A DataFrame containing EDA quality assessment metrics by segment
        ('Segment' column) or sliding window ('Moving Window' column).

    Raises
    ------
    ValueError
        If the segment, window, or step length is shorter than one sample.

    References
    ----------
    Kleckner, I.R., Jones, R. M., Wilder-Smith, O., Wormwood, J.B.,
    Akcakaya, M., Quigley, K.S., ... & Goodwin, M.S. (2017). Simple,
    transparent, and flexible automated quality assessment procedures for
    ambulatory electrodermal activity data. IEEE Transactions on
    Biomedical Engineering, 65(7), 1460-1467.
    """
    fs = self.fs
    # Use the same test for the label as for the branch below, so a
    # rolling run is never labelled 'Segment'.
    use_rolling = rolling_window is not None
    seg_name = 'Moving Window' if use_rolling else 'Segment'
    if peaks_ix is not None:
        peaks_ix = np.asarray(peaks_ix, dtype = int)

    def _summarize(number: int, start: int, end: int) -> dict:
        """Run `_edaqa` on signal[start:end] and build one output row."""
        segment = signal[start:end]
        seg_temp = temp[start:end] if temp is not None else None
        valid_ix, invalid_ix, seg_metrics = self._edaqa(
            segment, seg_temp, preprocessed)
        total_len = len(segment)
        row = {
            seg_name: number,
            'N Valid': len(valid_ix),
            '% Valid': round((len(valid_ix) / total_len) * 100, 2),
            'N Invalid': len(invalid_ix),
            '% Invalid': round((len(invalid_ix) / total_len) * 100, 2),
            **seg_metrics
        }
        if peaks_ix is not None:
            # SCR peaks falling in the half-open interval [start, end)
            row['N SCRs'] = int(np.count_nonzero(
                (peaks_ix >= start) & (peaks_ix < end)))
        return row

    if use_rolling:
        win_len = int(rolling_window * fs)
        step = int(rolling_step * fs)
        if win_len <= 0 or step <= 0:
            raise ValueError(
                '`rolling_window` and `rolling_step` must each span at '
                'least one sample.')
        starts = range(0, len(signal) - win_len + 1, step)
        metrics = [
            _summarize(i + 1, start, start + win_len)
            for i, start in enumerate(
                tqdm(starts, desc = 'EDA QA', disable = not show_progress))
        ]
    else:
        seg_len = int(seg_size * fs)
        if seg_len <= 0:
            raise ValueError('`seg_size` must span at least one sample.')
        n_segments = len(signal) // seg_len
        metrics = [
            _summarize(i + 1, i * seg_len, (i + 1) * seg_len)
            for i in range(n_segments)
        ]

    return pd.DataFrame(metrics)
def _edaqa(
    self,
    signal,
    temp: Optional[np.ndarray] = None,
    preprocessed: bool = True
) -> tuple[np.ndarray, np.ndarray, dict]:
    """
    Evaluate the input signal against Kleckner et al.'s (2017) quality rules.

    Parameters
    ----------
    signal : array_like
        EDA samples for one segment or window.
    temp : array_like, optional
        Temperature samples for the same span; by default, None. An empty
        array is treated the same as None.
    preprocessed : bool, optional
        If False, `signal` is low-pass filtered first; by default, True.

    Returns
    -------
    valid_ix : numpy.ndarray
        Indices of samples that remain valid after all rules.
    invalid_ix : numpy.ndarray
        Indices of samples marked invalid, including those invalidated by
        neighbor spreading (Rule 4).
    quality_metrics : dict
        Counts and percentages of samples failing Rules 1-3. Temperature
        entries are NaN when no temperature data is available.
    """
    signal = np.asarray(signal, dtype = float)

    # Filter EDA signal with a FIR low pass filter
    if not preprocessed:
        from EDA import Filters as eda_filters
        try:
            signal = eda_filters.lowpass_fir(signal)
        except ValueError:
            # Deliberately best-effort: keep the unfiltered signal when the
            # filter rejects the input (presumably a segment too short for
            # the filter -- TODO confirm against EDA.Filters).
            pass

    # An empty temperature slice (e.g., past the end of a shorter
    # temperature array) carries no information; treat it as absent.
    if temp is not None and len(temp) == 0:
        temp = None

    # Smooth temperature with an edge-corrected moving average
    if temp is not None:
        temp = self._smooth_temp(temp)

    sampling_interval = 1 / self.fs
    total_len = len(signal)

    # Rule 1: EDA amplitude outside [eda_min, eda_max]
    out_of_range_mask = self._check_out_of_range(signal)

    # Rule 2: sample-to-sample slope magnitude above eda_max_slope
    excessive_slope_mask = self._check_excessive_slope(
        signal, sampling_interval)

    # Rule 3: temperature outside [temp_min, temp_max]
    temp_out_of_range_mask = None
    if temp is not None:
        if len(signal) != len(temp):
            temp = self._equalize_temp(signal, temp)
        temp_out_of_range_mask = self._check_temp_out_of_range(temp)

    # Combine rule masks
    invalid_mask = out_of_range_mask | excessive_slope_mask
    if temp_out_of_range_mask is not None:
        invalid_mask = invalid_mask | temp_out_of_range_mask

    # Rule 4: spread invalidity to neighboring samples
    invalid_data = self._set_neighbors_invalid(
        invalid_mask, sampling_interval)

    # Get indices of valid and invalid data points
    valid_ix = np.where(~invalid_data)[0]
    invalid_ix = np.where(invalid_data)[0]

    # Compute metrics (Rules 1-3 only; counts are before spreading)
    quality_metrics = {
        'Out of Range': np.sum(out_of_range_mask),
        '% Out of Range': round(
            (np.sum(out_of_range_mask) / total_len) * 100, 2),
        'Excessive Slope': np.sum(excessive_slope_mask),
        '% Excessive Slope': round(
            (np.sum(excessive_slope_mask) / total_len) * 100, 2),
        'Temp Out of Range': (
            np.sum(temp_out_of_range_mask)
            if temp_out_of_range_mask is not None else np.nan),
        '% Temp Out of Range': (
            round((np.sum(temp_out_of_range_mask) / total_len) * 100, 2)
            if temp_out_of_range_mask is not None else np.nan),
    }

    return valid_ix, invalid_ix, quality_metrics

def _smooth_temp(self, temp: np.ndarray) -> np.ndarray:
    """Smooth temperature with a moving average of about 2 seconds.

    The window is ``int(2 * fs)`` samples, shortened to ``len(temp)`` for
    short inputs so the output length always equals the input length. Each
    output sample is the mean of only the input samples the window
    actually covers, so values near the edges are not pulled toward zero
    as they are by a zero-padded ``np.convolve(..., mode='same')``.
    """
    temp = np.asarray(temp, dtype = float)
    if temp.size == 0:
        return temp
    window = max(1, min(int(2 * self.fs), temp.size))
    kernel = np.ones(window)
    sums = np.convolve(temp, kernel, mode = 'same')
    # Number of real (non-padded) samples under the window at each position
    counts = np.convolve(np.ones(temp.size), kernel, mode = 'same')
    return sums / counts

def _check_out_of_range(
    self,
    signal: np.ndarray
) -> np.ndarray:
    """Return a boolean mask where EDA values are below eda_min or above
    eda_max (Rule 1)."""
    return (signal < self.eda_min) | (signal > self.eda_max)

def _check_excessive_slope(
    self,
    signal: np.ndarray,
    sampling_interval: float
) -> np.ndarray:
    """Return a boolean mask where the slope exceeds eda_max_slope (Rule 2).

    The slope at sample k is the difference from sample k - 1 divided by
    `sampling_interval`; the first sample's slope is taken as 0.
    """
    slopes = np.concatenate([[0], np.diff(signal) / sampling_interval])
    return np.abs(slopes) > self.eda_max_slope

def _check_temp_out_of_range(
    self,
    temp: Optional[np.ndarray] = None
) -> Union[None, np.ndarray]:
    """Return a boolean mask where temperature values are below temp_min
    or above temp_max (Rule 3), or None if `temp` is None."""
    if temp is None:
        return None
    return (temp < self.temp_min) | (temp > self.temp_max)

def _set_neighbors_invalid(
    self,
    invalid_mask: np.ndarray,
    sampling_interval: float
) -> np.ndarray:
    """Spread invalid labels +/- invalid_spread_dur seconds around detected
    invalid points (Rule 4).

    Every sample within ``int(invalid_spread_dur / sampling_interval)``
    samples of a flagged sample (clipped to the array bounds) is marked
    invalid. This is computed with a difference array in O(n) instead of
    a per-sample loop.
    """
    invalid_mask = np.asarray(invalid_mask, dtype = bool)
    n = invalid_mask.size
    half_width = int(self.invalid_spread_dur / sampling_interval)

    flagged = np.flatnonzero(invalid_mask)
    if flagged.size == 0:
        return np.zeros(n, dtype = bool)

    # Each flagged sample covers [start, end); clipping keeps indices valid
    # (and yields empty coverage if half_width is negative, as before).
    starts = np.clip(flagged - half_width, 0, n)
    ends = np.clip(flagged + half_width + 1, 0, n)
    coverage = np.zeros(n + 1, dtype = int)
    np.add.at(coverage, starts, 1)
    np.add.at(coverage, ends, -1)
    return np.cumsum(coverage[:-1]) > 0

def _window_column(self, metrics: pd.DataFrame) -> str:
    """Return the segment/window index column name used in `metrics`."""
    for col in ('Segment', 'Moving Window'):
        if col in metrics.columns:
            return col
    raise KeyError(
        "metrics must contain a 'Segment' or 'Moving Window' column")

def _style_bar_figure(
    self,
    fig: go.Figure,
    x_label: str,
    title: Optional[str] = None,
) -> None:
    """Apply the shared stacked-bar layout to `fig` in place."""
    fig.update_layout(
        barmode = 'stack',
        font = dict(family = 'Poppins', color = 'black'),
        xaxis = dict(
            title = dict(
                text = x_label,
                font = dict(size = 16),
                standoff = 5),
            tickfont = dict(size = 14)
        ),
        yaxis = dict(
            title = dict(
                text = 'Proportion',
                font = dict(size = 16),
                standoff = 2),
            tickfont = dict(size = 14)
        ),
        legend = dict(font = dict(size = 14),
                      orientation = 'h',
                      yanchor = 'bottom', y = 1.05,
                      xanchor = 'right', x = 1.0),
        template = 'simple_white',
        margin = dict(l = 30, r = 15, t = 60, b = 50)
    )
    if title is not None:
        fig.update_layout(title = title)

def plot_validity(
    self,
    metrics: pd.DataFrame,
    title: Optional[str] = None,
) -> go.Figure:
    """Plot stacked percentages of invalid and valid data per segment or
    sliding window.

    Parameters
    ----------
    metrics : pandas.DataFrame
        Output of `compute_metrics`: a 'Segment' or 'Moving Window' column
        plus '% Invalid' and '% Valid'. If an 'N SCRs' column is present,
        a star marker is drawn above every bar with at least one SCR.
    title : str, optional
        The figure title; by default, None.

    Returns
    -------
    fig : plotly.graph_objects.Figure
    """
    seg_col = self._window_column(metrics)
    x = metrics[seg_col]
    hover = f'<b>{seg_col} %{{x}}:</b> %{{y}}% '

    fig = go.Figure(
        data = [
            go.Bar(
                x = x,
                y = metrics['% Invalid'],
                name = 'Invalid',
                marker = dict(color = 'tomato'),
                hovertemplate = hover + 'invalid<extra></extra>'),
            go.Bar(
                x = x,
                y = metrics['% Valid'],
                name = 'Valid',
                marker = dict(
                    color = 'white',
                    pattern = dict(
                        shape = '/', fgcolor = '#4aba74',
                        size = 5, solidity = 0.2)
                ),
                hovertemplate = hover + 'valid<extra></extra>')
        ]
    )

    # If N SCRs exist, add markers above bars
    if 'N SCRs' in metrics.columns:
        y_top = metrics['% Invalid'] + metrics['% Valid']
        has_scr = metrics['N SCRs'] > 0
        fig.add_trace(
            go.Scatter(
                x = metrics.loc[has_scr, seg_col],
                y = (y_top + 3).loc[has_scr],
                mode = 'text+markers',
                text = ['✦'] * int(has_scr.sum()),  # star only where SCRs exist
                textposition = 'middle center',
                textfont = dict(color = '#f9c669'),
                marker = dict(size = 1, color = '#f9c669', symbol = 'circle'),
                showlegend = False,
                hovertemplate = 'SCR(s) detected<extra></extra>',
            )
        )

    self._style_bar_figure(fig, seg_col, title)
    return fig

def plot_quality_metrics(
    self,
    metrics: pd.DataFrame,
    title: Optional[str] = None,
) -> go.Figure:
    """Plot stacked percentages of rule violations and valid data per
    segment or sliding window.

    Parameters
    ----------
    metrics : pandas.DataFrame
        Output of `compute_metrics`: a 'Segment' or 'Moving Window' column
        plus '% Out of Range', '% Excessive Slope', and '% Valid'. The
        temperature trace is drawn only if '% Temp Out of Range' is
        present and has at least one non-NaN value.
    title : str, optional
        The figure title; by default, None.

    Returns
    -------
    fig : plotly.graph_objects.Figure
    """
    seg_col = self._window_column(metrics)
    x = metrics[seg_col]
    hover = f'<b>{seg_col} %{{x}}:</b> %{{y}}% '

    traces = [
        go.Bar(
            x = x,
            y = metrics['% Out of Range'],
            name = 'EDA Out of Range',
            marker = dict(color = '#7cabcc'),
            hovertemplate = hover + 'EDA out of range<extra></extra>'),
        go.Bar(
            x = x,
            y = metrics['% Excessive Slope'],
            name = 'Excessive Slope',
            marker = dict(color = '#ed77aa'),
            hovertemplate = hover + 'excessive slope<extra></extra>'),
    ]

    # The column is always emitted (NaN without temperature data), so only
    # add the trace when there is something to show.
    if ('% Temp Out of Range' in metrics.columns
            and metrics['% Temp Out of Range'].notna().any()):
        traces.append(
            go.Bar(
                x = x,
                y = metrics['% Temp Out of Range'],
                name = 'Temp Out of Range',
                marker = dict(color = '#b095c2'),
                hovertemplate = hover + 'temp out of range<extra></extra>'
            )
        )

    # Append '% Valid' trace
    traces.append(
        go.Bar(
            x = x,
            y = metrics['% Valid'],
            name = 'Valid',
            marker = dict(
                color = 'rgba(0,0,0,0)',
                pattern = dict(
                    shape = '/', fgcolor = '#4aba74',
                    size = 5, solidity = 0.2)),
            hovertemplate = hover + 'valid<extra></extra>'
        )
    )

    fig = go.Figure(data = traces)
    self._style_bar_figure(fig, seg_col, title)
    return fig

def _equalize_temp(self, eda, temp):
    """Match the temperature array's length to the EDA array's length.

    A longer `temp` is truncated. A shorter `temp` is linearly
    interpolated on sample indices with ``fill_value='extrapolate'``, so
    samples beyond its end are linearly extrapolated rather than
    resampled. A single temperature sample is held constant.
    NOTE(review): this assumes `temp` shares the EDA sampling rate --
    confirm.
    """
    temp = np.asarray(temp, dtype = float)
    n_eda = len(eda)
    if len(temp) > n_eda:
        return temp[:n_eda]
    if len(temp) < n_eda:
        if len(temp) == 1:
            # interp1d needs at least two points; hold the single value.
            return np.full(n_eda, temp[0])
        interp_func = interp1d(np.arange(len(temp)), temp,
                               kind = 'linear', fill_value = 'extrapolate')
        return interp_func(np.arange(n_eda))
    return temp