from typing import Dict, List, Literal, Optional, Tuple, Union
from zipfile import ZipFile, ZipExtFile
from flirt.hrv import get_hrv_features
from tqdm import tqdm
from scipy.signal import resample as scipy_resample
from plotly.subplots import make_subplots
from physioview._plotting import *
import warnings
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import datetime as dt
import pyedflib
# Names exported by `from <this module> import *`.
__all__ = [
'Actiwave', 'Empatica', 'compute_ibis', 'compute_hrv', 'plot_signal',
'write_beat_editor_file', 'process_beat_edits'
]
[docs]
class Actiwave:
    """
    A class for convenient preprocessing of data from the Actiwave Cardio
    device.

    Parameters/Attributes
    ---------------------
    file : str
        The path of the Actiwave Cardio device file saved in European
        Data Format (.edf).
    """

    def __init__(self, file: str):
        """
        Initialize the Actiwave object.

        Parameters
        ----------
        file : str
            The path of the Actiwave Cardio device file saved in European
            Data Format (.edf).

        Raises
        ------
        ValueError
            If `file` does not end in '.edf' or '.EDF'.
        """
        if not file.endswith(('.edf', '.EDF')):
            raise ValueError(
                'Invalid file path. The `file` parameter must take a string '
                'value ending in \'.EDF\' or \'.edf\'.')
        self.file = file

    def preprocess(
        self,
        time_aligned: bool = False
    ) -> Union[tuple[pd.DataFrame, pd.DataFrame], pd.DataFrame]:
        """
        Preprocess electrocardiograph (ECG) and acceleration data from
        an Actiwave Cardio file.

        Parameters
        ----------
        time_aligned : bool, optional
            Whether to time-align ECG and acceleration data based on the
            sampling rate of the ECG data; by default, False.

        Returns
        -------
        tuple or pandas.DataFrame
            If `time_aligned` is False, returns a tuple (`ecg`, `acc`),
            where `ecg` is a DataFrame with 'Timestamp' and 'ECG' columns
            and `acc` is a DataFrame with 'Timestamp', 'X', 'Y', 'Z', and
            'Magnitude' columns. If `time_aligned` is True, returns a
            single DataFrame holding the ECG data plus the X-, Y-, and
            Z-axis acceleration resampled to the number of ECG samples.

        Raises
        ------
        NameError
            If the file contains no ECG channel or no acceleration channel.
        """
        f = pyedflib.EdfReader(self.file)
        # The reader is closed in `finally` so the file handle is released
        # even when a channel is missing or a read fails.
        try:
            # NOTE(review): the start datetime is converted with
            # `datetime.timestamp` and later rendered as UTC; if the reader
            # returns a naive local datetime this shifts the timestamps by
            # the local UTC offset -- confirm this is intended.
            start = dt.datetime.timestamp(f.getStartdatetime())
            signal_labels = f.getSignalLabels()
            ecg_chn = [i for i, label in enumerate(signal_labels)
                       if 'ECG' in label]
            acc_chn = [i for i, label in enumerate(signal_labels)
                       if any(axis in label for axis in ('X', 'Y', 'Z'))]
            if not ecg_chn:
                raise NameError('No ECG channel found.')
            if not acc_chn:
                raise NameError('No ACC channels found.')
            # Acceleration channels are assumed to appear in X, Y, Z order.
            acc_sig = dict(zip(['X', 'Y', 'Z'], acc_chn))
            ecg_fs = f.getSampleFrequency(ecg_chn[0])
            acc_fs = f.getSampleFrequency(acc_chn[0])
            ecg_signal = f.readSignal(ecg_chn[0]) / 1000
            acc_signals = {axis: pd.Series(f.readSignal(chn))
                           for axis, chn in acc_sig.items()}
        finally:
            f.close()

        # Timestamps are derived from the actual sample count rather than
        # from a float-step `np.arange(start, end, 1 / fs)`, whose length
        # is not numerically stable and could exceed the signal length.
        ecg = pd.DataFrame({
            'Timestamp': self._make_timestamps(
                start, len(ecg_signal), ecg_fs),
            'ECG': ecg_signal})
        acc = pd.DataFrame(acc_signals)
        acc.insert(0, 'Timestamp',
                   self._make_timestamps(start, len(acc), acc_fs))
        acc['Magnitude'] = np.sqrt(
            (acc[['X', 'Y', 'Z']] ** 2).sum(axis = 1))

        if not time_aligned:
            return ecg, acc

        # Resample each acceleration axis to the number of ECG samples.
        resampled = pd.DataFrame(
            {col: scipy_resample(acc[col], len(ecg))
             for col in ('X', 'Y', 'Z')})
        return pd.concat([ecg, resampled], axis = 1)

    def get_ecg_fs(self) -> float:
        """
        Get the sampling rate of ECG data from an Actiwave Cardio device.

        The first channel whose label contains 'ECG' is used, matching the
        channel that `preprocess()` reads.

        Returns
        -------
        fs : int, float
            The sampling rate of the ECG recording.

        Raises
        ------
        NameError
            If no channel label contains 'ECG'.
        """
        return self._get_channel_fs('ECG', 'No ECG channel found.')

    def get_acc_fs(self) -> float:
        """
        Get the sampling rate of accelerometer data from an Actiwave Cardio
        device.

        The first channel whose label contains 'X' is used.

        Returns
        -------
        fs : int, float
            The sampling rate of the accelerometer recording.

        Raises
        ------
        NameError
            If no channel label contains 'X'.
        """
        return self._get_channel_fs('X', 'No ACC channels found.')

    def _get_channel_fs(self, token: str, missing_msg: str) -> float:
        """Return the sampling rate of the first channel whose label
        contains `token`, raising NameError(`missing_msg`) if none does."""
        f = pyedflib.EdfReader(self.file)
        try:
            for chn, label in enumerate(f.getSignalLabels()):
                if token in label:
                    return f.getSampleFrequency(chn)
            raise NameError(missing_msg)
        finally:
            f.close()

    @staticmethod
    def _make_timestamps(start: float, n_samples: int, fs: float) -> pd.Series:
        """Build `n_samples` naive UTC datetimes spaced `1 / fs` seconds
        apart, beginning at the Unix time `start`."""
        unix = pd.Series(start + np.arange(n_samples) / fs)
        # Equivalent to the deprecated `datetime.utcfromtimestamp`.
        return unix.apply(
            lambda t: dt.datetime.fromtimestamp(
                t, dt.timezone.utc).replace(tzinfo = None))
# ==================== Empatica E4 Pre-Processing and SQA ====================
[docs]
class Empatica:
    """
    A class to conveniently preprocess and assess quality of PPG and EDA data
    from Empatica E4 devices.

    Attributes
    ----------
    file : str
        The path of the Empatica archive file with a '.zip' extension.
    """

    class Data:
        """A class to store preprocessed data variables."""

        def __init__(self, **kwargs):
            # Every keyword argument becomes an attribute of the same name.
            for key, value in kwargs.items():
                setattr(self, key, value)

    def __init__(self, file: str):
        """
        Initialize the Empatica object.

        Parameters
        ----------
        file : str
            The path of the Empatica archive file with a '.zip' extension.

        Raises
        ------
        ValueError
            If `file` does not end in '.zip' or '.ZIP'.
        """
        if not file.endswith(('.zip', '.ZIP')):
            raise ValueError(
                'Invalid file path. The `file` parameter must take a string '
                'value ending in \'.zip\' or \'.ZIP\'.')
        self.file = file

    def preprocess(self, time_aligned: bool = False) -> 'Empatica.Data':
        """
        Preprocess all data from the Empatica E4.

        Parameters
        ----------
        time_aligned : bool, optional
            Whether to time-align all data based on the signal with the
            highest sampling rate (i.e. blood volume pulse); by default,
            False.

        Returns
        -------
        data : Empatica.Data object
            An `Empatica.Data` object with the following attributes and
            corresponding preprocessed data:

            If `time_aligned` is False:
                acc : pandas.DataFrame
                    The preprocessed ACC data with timestamps.
                bvp : pandas.DataFrame
                    The preprocessed BVP data with timestamps.
                eda : pandas.DataFrame
                    The preprocessed EDA data with timestamps.
                hr : pandas.DataFrame
                    The preprocessed HR data with timestamps.
                ibi : pandas.DataFrame
                    The preprocessed IBI data with timestamps and seconds
                    elapsed since the start time of the IBI recording.
                temp : pandas.DataFrame
                    The preprocessed temperature data with timestamps.
                start : float
                    The Unix-formatted start time of the E4 recording.
                bvp_fs : float
                    The sampling rate of the BVP recording.
                eda_fs : float
                    The sampling rate of the EDA recording.

            If `time_aligned` is True:
                hrv : pandas.DataFrame
                    Time-synced BVP, HR, IBI, and acceleration data.
                eda : pandas.DataFrame
                    Time-synced EDA, temperature, and acceleration data.
                start : float
                    The Unix-formatted start time of the E4 recording.
                bvp_fs : float
                    The sampling rate of the BVP recording.
                eda_fs : float
                    The sampling rate of the EDA recording.

        Raises
        ------
        ValueError
            If any of the ACC, BVP, EDA, HR, IBI, or TEMP files is missing
            from the archive.

        Examples
        --------
        >>> from physioview import physioview
        >>> e4_archive = 'Sample_E4_Data.zip'
        >>> E4 = physioview.Empatica(e4_archive)
        >>> ALL_E4_DATA = E4.preprocess()
        """
        with ZipFile(self.file, 'r') as archive:
            e4_files = archive.namelist()

        acc_data = bvp_data = eda_data = None
        hr_data = ibi_data = temp_data = None
        start_time = bvp_fs = eda_fs = None

        # Each getter re-opens the archive itself, so it is called once per
        # matching member; the recording start comes from whichever of
        # BVP/EDA is listed last in the archive.
        for file in e4_files:
            if 'ACC' in file:
                acc_data = self.get_acc().acc
            if 'BVP' in file:
                bvp = self.get_bvp()
                bvp_data, start_time, bvp_fs = bvp.bvp, bvp.start, bvp.fs
            if 'EDA' in file:
                eda = self.get_eda()
                eda_data, start_time, eda_fs = eda.eda, eda.start, eda.fs
            if 'HR' in file:
                hr_data = self.get_hr().hr
            if 'IBI' in file:
                ibi_data = self.get_ibi().ibi
            if 'TEMP' in file:
                temp_data = self.get_temp().temp

        loaded = {'ACC': acc_data, 'BVP': bvp_data, 'EDA': eda_data,
                  'HR': hr_data, 'IBI': ibi_data, 'TEMP': temp_data}
        for tag, value in loaded.items():
            if value is None:
                raise ValueError(f'No "{tag}.csv" file found.')

        if not time_aligned:
            return self.Data(**{'acc': acc_data,
                                'bvp': bvp_data,
                                'eda': eda_data,
                                'hr': hr_data,
                                'ibi': ibi_data,
                                'temp': temp_data,
                                'start': start_time,
                                'bvp_fs': bvp_fs,
                                'eda_fs': eda_fs})

        # Merge IBI and HR values into BVP data frame
        full_hrv = pd.merge_asof(
            bvp_data, ibi_data.drop(['Seconds'], axis = 1),
            on = 'Timestamp', direction = 'nearest')
        full_hrv = pd.merge_asof(
            full_hrv, hr_data,
            on = 'Timestamp', direction = 'nearest')

        # `merge_asof` fills every BVP row with its nearest IBI/HR value;
        # keep those values only at the rows where a measurement falls and
        # blank out the rest.
        bvp_ts = bvp_data['Timestamp'].values
        ibi_ts = ibi_data['Timestamp'].values
        hr_ts = hr_data['Timestamp'].values
        ibi_insertion_points = np.searchsorted(bvp_ts, ibi_ts) - 1
        hr_insertion_points = np.searchsorted(bvp_ts, hr_ts)
        row_numbers = np.arange(len(full_hrv))
        full_hrv.loc[~np.isin(row_numbers, ibi_insertion_points),
                     'IBI'] = np.nan
        full_hrv.loc[~np.isin(row_numbers, hr_insertion_points),
                     'HR'] = np.nan

        # Resample acceleration data to match BVP and EDA sampling rates
        acc_cols = ['X', 'Y', 'Z', 'Magnitude']

        def resample_acc(n_samples):
            return pd.DataFrame(
                {col: scipy_resample(acc_data[col], n_samples)
                 for col in acc_cols})

        full_hrv = pd.merge(full_hrv, resample_acc(len(bvp_data)),
                            left_index = True, right_index = True)
        full_eda = pd.merge(eda_data, temp_data,
                            on = 'Timestamp', how = 'inner')
        full_eda = pd.merge(full_eda, resample_acc(len(eda_data)),
                            left_index = True, right_index = True)

        return self.Data(**{'hrv': full_hrv,
                            'eda': full_eda,
                            'start': start_time,
                            'bvp_fs': bvp_fs,
                            'eda_fs': eda_fs})

    def get_acc(self) -> 'Empatica.Data':
        """
        Get the preprocessed acceleration data and its start time and
        sampling rate from the Empatica E4.

        Raw axis values are divided by 64 and multiplied by 9.81, and a
        'Magnitude' column is added.

        Returns
        -------
        acc_data : Empatica.Data object
            An `Empatica.Data` object with the following attributes:
                acc : pandas.DataFrame
                    The preprocessed ACC data ('X', 'Y', 'Z', 'Magnitude')
                    with corresponding timestamps.
                start : float
                    The Unix-formatted start time of the ACC recording.
                fs : int
                    The sampling rate of the ACC data.

        Raises
        ------
        ValueError
            If the archive has no member whose name contains 'ACC'.
        """
        from physioview.pipeline.ACC import compute_magnitude
        acc, acc_start, acc_fs = self._read_member(
            'ACC', lambda handle: self._get_e4_data(
                handle, name = ['X', 'Y', 'Z']))
        acc = acc.apply(lambda x: (x / 64) * 9.81
                        if x.name != 'Timestamp' else x)
        acc['Magnitude'] = compute_magnitude(
            acc['X'], acc['Y'], acc['Z'])
        return self.Data(**{'acc': acc,
                            'start': acc_start,
                            'fs': acc_fs})

    def get_bvp(self) -> 'Empatica.Data':
        """
        Get the raw blood volume pulse (BVP) data and its start time and
        sampling rate from the Empatica E4.

        Returns
        -------
        bvp_data : Empatica.Data object
            An `Empatica.Data` object with the following attributes:
                bvp : pandas.DataFrame
                    The BVP data with corresponding timestamps.
                start : float
                    The Unix-formatted start time of the BVP recording.
                fs : int
                    The sampling rate of the BVP data.

        Raises
        ------
        ValueError
            If the archive has no member whose name contains 'BVP'.
        """
        bvp, bvp_start, bvp_fs = self._read_member(
            'BVP', lambda handle: self._get_e4_data(handle, name = 'BVP'))
        return self.Data(**{'bvp': bvp,
                            'start': bvp_start,
                            'fs': bvp_fs})

    def get_eda(self) -> 'Empatica.Data':
        """
        Get the raw electrodermal activity (EDA) data and its recording
        start time and sampling rate from the Empatica E4.

        Returns
        -------
        eda_data : Empatica.Data object
            An `Empatica.Data` object with the following attributes:
                eda : pandas.DataFrame
                    The EDA data with corresponding timestamps.
                start : float
                    The Unix-formatted start time of the EDA recording.
                fs : int
                    The sampling rate of the EDA data.

        Raises
        ------
        ValueError
            If the archive has no member whose name contains 'EDA'.
        """
        eda, eda_start, eda_fs = self._read_member(
            'EDA', lambda handle: self._get_e4_data(handle, name = 'EDA'))
        return self.Data(**{'eda': eda,
                            'start': eda_start,
                            'fs': eda_fs})

    def get_hr(self) -> 'Empatica.Data':
        """
        Get the preprocessed heart rate (HR) data, start time of the
        first HR measurement, and sampling rate from the Empatica E4.

        Returns
        -------
        hr_data : Empatica.Data object
            An `Empatica.Data` object with the following attributes:
                hr : pandas.DataFrame
                    The HR data with corresponding timestamps.
                start : float
                    The Unix-formatted start time of the HR measurements.
                fs : int
                    The sampling rate of the HR data.

        Raises
        ------
        ValueError
            If the archive has no member whose name contains 'HR'.
        """
        hr, hr_start, hr_fs = self._read_member(
            'HR', lambda handle: self._get_e4_data(handle, name = 'HR'))
        return self.Data(**{'hr': hr,
                            'start': hr_start,
                            'fs': hr_fs})

    def get_ibi(self) -> 'Empatica.Data':
        """
        Get the preprocessed interbeat interval (IBI) data and the start
        time of the first interval from the Empatica E4.

        The 'IBI' column is multiplied by 1000 and a 'Timestamp' column is
        derived from the start time plus the 'Seconds' column.

        Returns
        -------
        ibi_data : Empatica.Data object
            An `Empatica.Data` object with the following attributes:
                ibi : pandas.DataFrame
                    The IBI data ('Timestamp', 'Seconds', 'IBI').
                start : float
                    The Unix-formatted start time of the IBI data.

        Raises
        ------
        ValueError
            If the archive has no member whose name contains 'IBI'.
        """
        def read_ibi(handle):
            frame = pd.read_csv(handle, header = 0,
                                names = ['Seconds', 'IBI'])
            # Rewind so the start time can be read from the first row.
            handle.seek(0)
            return frame, self._get_e4_start_time(handle)

        ibi, ibi_start = self._read_member('IBI', read_ibi)
        ibi['IBI'] *= 1000
        ibi.insert(
            0, 'Timestamp',
            (ibi['Seconds'] + ibi_start).apply(self._utc_datetime))
        return self.Data(**{'ibi': ibi, 'start': ibi_start})

    def get_temp(self) -> 'Empatica.Data':
        """
        Get the raw skin temperature data and its recording start time and
        sampling rate from the Empatica E4.

        Returns
        -------
        temp_data : Empatica.Data object
            An `Empatica.Data` object with the following attributes:
                temp : pandas.DataFrame
                    The temperature data with corresponding timestamps.
                start : float
                    The Unix-formatted start time of the temperature
                    recording.
                fs : int
                    The sampling rate of the temperature data.

        Raises
        ------
        ValueError
            If the archive has no member whose name contains 'TEMP'.
        """
        temp, temp_start, temp_fs = self._read_member(
            'TEMP', lambda handle: self._get_e4_data(handle, name = 'TEMP'))
        return self.Data(**{'temp': temp,
                            'start': temp_start,
                            'fs': temp_fs})

    def get_e4_beats(
        self,
        bvp_data: pd.DataFrame,
        ibi_data: pd.DataFrame,
        start_time: int,
        show_progress: bool = True
    ) -> list[int]:
        """
        Get locations of beats from Empatica E4 interbeat interval (IBI)
        data relative to its blood volume pulse (BVP) data.

        Parameters
        ----------
        bvp_data : pandas.DataFrame
            A DataFrame containing the Empatica E4 BVP data, outputted from
            `Empatica.preprocess()`.
        ibi_data : pandas.DataFrame
            A DataFrame containing the Empatica E4 IBI data, outputted from
            `Empatica.preprocess()`.
        start_time : int
            The Unix timestamp of the recording start time.
        show_progress : bool, optional
            Whether to display a progress bar while the function runs; by
            default, True.

        Returns
        -------
        e4_beats : list
            For each IBI row, the index label of the BVP row whose
            timestamp is closest to that beat.
        """
        beat_times = (ibi_data['Seconds'] + start_time).apply(
            self._utc_datetime)
        bvp_times = pd.to_datetime(bvp_data['Timestamp'])
        e4_beats = []
        for t in tqdm(beat_times, disable = not show_progress):
            e4_beats.append(np.abs(bvp_times - t).idxmin())
        return e4_beats

    def compute_sqa(
        self,
        dtype: str,
        seg_size: int = 60,
        initial_hr: Union[int, float, Literal['auto']] = 'auto',
        min_hr: int = 40,
        min_eda: float = 0.2,
        max_eda: float = 40.0,
        rolling_window: int = None,
        rolling_step: int = 15,
        show_progress: bool = True
    ) -> Union[pd.DataFrame, tuple[pd.DataFrame, pd.DataFrame]]:
        """
        Compute signal quality assessment (SQA) metrics for PPG and/or EDA
        data from Empatica E4 devices.

        Parameters
        ----------
        dtype : str
            The type of data whose SQA to compute. This value must be a
            string variation (case-insensitive) of 'all', 'eda', or 'ppg'.
        seg_size : int
            The segment size in seconds; by default, 60.
        initial_hr : int, float, or 'auto', optional
            The heart rate value for the first interbeat interval (IBI) to
            be validated against; by default, 'auto' for automatic
            calculation using the mean heart rate value obtained from six
            consecutive IBIs with the smallest average successive
            difference.
        min_hr : int, float
            The minimum acceptable heart rate against which the number of
            beats in the last partial segment will be compared; by
            default, 40.
        min_eda : float, optional
            The minimum acceptable value for EDA data in microsiemens; by
            default, 0.2 uS.
        max_eda : float, optional
            The maximum acceptable value for EDA data in microsiemens; by
            default, 40 uS.
        rolling_window : int, optional
            The size, in seconds, of the sliding window across which to
            compute the SQA metrics; by default, None.
        rolling_step : int, optional
            The step size, in seconds, of the sliding windows; by
            default, 15.
        show_progress : bool, optional
            Whether to display a progress bar while the function runs; by
            default, True.

        Returns
        -------
        metrics : pandas.DataFrame or tuple of pandas.DataFrame
            The PPG metrics for 'ppg', the EDA metrics for 'eda', or the
            tuple (`ppg_metrics`, `eda_metrics`) for 'all'.

        Raises
        ------
        ValueError
            If `dtype` is not a variation of 'all', 'eda', or 'ppg'.

        Notes
        -----
        If a value is given in the `rolling_window` parameter, the rolling
        window approach will override the segmented approach, ignoring any
        `seg_size` value.
        """
        from physioview.pipeline.SQA import Cardio, EDA
        from physioview.pipeline.PPG import BeatDetectors
        from physioview.pipeline.EDA import Filters as eda_filters

        # Normalize once so mixed-case input selects the same branches that
        # the validation accepts.
        dtype = dtype.lower()
        if dtype not in ('all', 'eda', 'ppg'):
            raise ValueError('The `dtype` parameter must take a string value '
                             '\'all\', \'eda\', or \'ppg\'.')
        targets = ('eda', 'ppg') if dtype == 'all' else (dtype,)

        ppg_metrics, eda_metrics = None, None

        if 'ppg' in targets:
            bvp_record = self.get_bvp()
            bvp, fs = bvp_record.bvp, bvp_record.fs
            ppg_beats = BeatDetectors(fs, False).adaptive_threshold(
                bvp['BVP'])
            sqa = Cardio(fs)
            artifact_beats = sqa.identify_artifacts(
                ppg_beats, 'both', initial_hr, 6, 5, 1)
            ppg_metrics = sqa.compute_metrics(
                bvp, ppg_beats, artifact_beats, 'Timestamp', seg_size,
                min_hr, rolling_window, rolling_step, show_progress)

        if 'eda' in targets:
            eda_record = self.get_eda()
            eda, fs = eda_record.eda, eda_record.fs
            start_time = pd.to_datetime(eda_record.start, unit = 's')
            eda['EDA'] = eda_filters(fs).filter_signal(eda['EDA'])
            temp = self.get_temp().temp
            sqa = EDA(fs, eda_min = min_eda, eda_max = max_eda)
            # Use the caller's segment size so the metrics match the
            # timestamp spacing generated below.
            eda_metrics = sqa.compute_metrics(
                eda['EDA'], temp['TEMP'], preprocessed = True,
                seg_size = seg_size, rolling_window = rolling_window,
                rolling_step = rolling_step)
            # NOTE(review): timestamps are spaced by `seg_size` even when
            # `rolling_window` is set -- confirm against the row spacing
            # that `EDA.compute_metrics` produces in rolling mode.
            ts = pd.date_range(
                start = start_time,
                periods = len(eda_metrics),
                freq = pd.Timedelta(seconds = seg_size)
            )
            eda_metrics.insert(1, 'Timestamp', ts)

        if ppg_metrics is not None and eda_metrics is not None:
            return ppg_metrics, eda_metrics
        if ppg_metrics is not None:
            return ppg_metrics
        return eda_metrics

    def plot_signals(
        self,
        segment: int = 1,
        seg_size: int = 60,
        interactive: bool = True,
        **kwargs
    ) -> Union['go.Figure', tuple]:
        """
        Display a plot of a segment of signals recorded with the Empatica E4
        device.

        Parameters
        ----------
        segment : int, optional
            The number of the position of the segment to plot; by
            default, 1.
        seg_size : int, optional
            The segment size in seconds; by default, 60.
        interactive : bool, optional
            Whether to plot an interactive visualization; by default, True.
        **kwargs : dict, optional
            Additional keyword arguments passed to the Plotly figure's
            `update_layout()` method (interactive mode only). Values given
            here override the defaults for `template` and `legend`.

        Returns
        -------
        fig : plotly.graph_objects.Figure or tuple
            If `interactive` is True, displays and returns an interactive
            Plotly figure containing the plotted signals. If `interactive`
            is False, displays a static Matplotlib figure and returns the
            tuple (`fig`, `axs`).

        Examples
        --------
        >>> from physioview import physioview
        >>> e4 = physioview.Empatica('empatica_file.zip')
        >>> fig = e4.plot_signals(
        >>>     interactive = True, template = 'simple_white')
        """
        data = self.preprocess(time_aligned = True)

        # One entry per subplot, in display order:
        # (frame, sampling rate, column,
        #  interactive (name, y label, color), static (name, y label, color))
        panels = (
            (data.hrv, data.bvp_fs, 'Magnitude',
             ('ACC', 'm/s²', 'forestgreen'),
             ('ACC', 'm/s²', 'forestgreen')),
            (data.hrv, data.bvp_fs, 'BVP',
             ('BVP', 'bvp', '#3562bd'),
             ('BVP', 'BVP', '#3562bd')),
            (data.eda, data.eda_fs, 'EDA',
             ('EDA', 'uS', '#249ab5'),
             ('EDA', 'uS', '#43c9de')),
            (data.eda, data.eda_fs, 'TEMP',
             ('TEMP', '°C', '#8659c2'),
             ('Temperature', '°C', '#8b3ac9')),
        )

        def segment_xy(df, fs, column):
            # Slice the requested segment using the frame's own rate.
            seg_start = int((segment - 1) * fs * seg_size)
            seg_end = seg_start + int(fs * seg_size)
            return (df['Timestamp'].iloc[seg_start:seg_end],
                    df[column].iloc[seg_start:seg_end])

        if interactive:
            fig = make_subplots(
                rows = 4, cols = 1,
                shared_xaxes = True,
                vertical_spacing = 0.02,
                row_heights = [0.2, 0.3, 0.3, 0.2])
            for row, (df, fs, column, style, _) in enumerate(panels, 1):
                signal_name, ylabel, color = style
                x, y = segment_xy(df, fs, column)
                fig.add_trace(
                    go.Scatter(
                        x = x, y = y,
                        name = signal_name,
                        line = dict(color = color, width = 1.5),
                        hovertemplate = f'<b>{signal_name}</b>: %{{y:.2f}} '
                                        f'{ylabel}<extra></extra>'),
                    row = row, col = 1)
                fig.update_yaxes(
                    title_text = ylabel,
                    row = row, col = 1,
                    showgrid = True,
                    gridwidth = 0.5,
                    gridcolor = 'lightgrey',
                    griddash = 'dot',
                    tickcolor = 'grey',
                    linecolor = 'grey')

            # Apply user-supplied layout modifications. Defaults and user
            # values are merged into one dict so a user-supplied `template`
            # or `legend` overrides the default instead of raising a
            # duplicate-keyword TypeError.
            if kwargs:
                layout = dict(
                    template = 'simple_white',
                    legend = dict(
                        font = dict(size = 15), orientation = 'h',
                        yanchor = 'bottom', y = 1.05,
                        xanchor = 'right', x = 1.0))
                layout.update(kwargs)
                fig.update_layout(**layout)
            fig.show()
            return fig

        # Static figure: each signal is drawn on its own axis from the
        # time-aligned frames, sliced with that frame's sampling rate.
        fig, axs = plt.subplots(4, 1, figsize = (10, 8))
        for ax, (df, fs, column, _, style) in zip(axs, panels):
            signal_name, ylabel, color = style
            x, y = segment_xy(df, fs, column)
            ax.plot(x, y, label = signal_name, color = color, lw = 1.2)
            ax.set_xlabel('Timestamp')
            ax.set_ylabel(ylabel)
            ax.legend(frameon = False)
        plt.tight_layout()
        plt.show()
        return fig, axs

    def _read_member(self, tag, reader):
        """Open the first archive member whose name contains `tag` and
        return `reader(handle)`; raise ValueError if there is none."""
        with ZipFile(self.file, 'r') as archive:
            member = next(
                (name for name in archive.namelist() if tag in name), None)
            if member is None:
                raise ValueError(f'No "{tag}.csv" file found.')
            with archive.open(member) as handle:
                return reader(handle)

    @staticmethod
    def _utc_datetime(t):
        """Convert a Unix time to a naive UTC datetime (equivalent to the
        deprecated `datetime.utcfromtimestamp`)."""
        return dt.datetime.fromtimestamp(
            t, dt.timezone.utc).replace(tzinfo = None)

    def _get_e4_data(self, file, name):
        """
        Extract data from an Empatica E4 file.

        The first row is read as the start time and the second as the
        sampling rate; remaining rows become the data columns named by
        `name`. Returns the tuple (`data`, `start_time`, `fs`), where
        `data` has a leading 'Timestamp' column.
        """
        if not isinstance(name, (list, str)):
            raise ValueError('The `name` parameter must take either a string '
                             'or a list of strings.')
        col_name = name if isinstance(name, list) else [name]

        data = pd.read_csv(file, header = 1, names = col_name)
        can_rewind = not isinstance(file, str) and hasattr(file, 'seek')
        # Open file objects must be rewound before each header re-read;
        # string paths are re-opened by pandas every time.
        if can_rewind:
            file.seek(0)
        fs = self._get_e4_fs(file)
        if can_rewind:
            file.seek(0)
        start_time = self._get_e4_start_time(file)

        # A Timedelta frequency avoids formatting `1 / fs` into a string,
        # which breaks when the float is rendered in exponent notation.
        timestamps = pd.date_range(
            start = pd.to_datetime(start_time, unit = 's'),
            periods = len(data), freq = pd.Timedelta(seconds = 1 / fs))
        timestamps = pd.Series(timestamps, name = 'Timestamp')
        data = pd.merge(timestamps, data,
                        left_index = True, right_index = True)
        return data, start_time, fs

    def _get_e4_fs(self, file):
        """Get the sampling rate from an Empatica E4 file (first column of
        the second row)."""
        contents = pd.read_csv(file, header = None, nrows = 2, usecols = [0])
        return contents.iloc[1].item()

    def _get_e4_start_time(self, file):
        """Get the Unix-formatted start time of an Empatica E4 recording
        (first column of the first row)."""
        contents = pd.read_csv(file, header = None, nrows = 2, usecols = [0])
        # The start time is in the same cell for every E4 file type, so no
        # inspection of the file's name or type is needed; this also works
        # for file-like objects that are neither paths nor zip members.
        start = contents.iloc[0, 0]
        return start.item() if hasattr(start, 'item') else start
# ======================== Other Data Pre-Processing =========================
def get_duration(
    data: Union[pd.DataFrame, pd.Series, np.ndarray],
    fs: int,
    unit: str = 'sec'
) -> float:
    """
    Get the duration of a signal.

    Parameters
    ----------
    data : array-like
        An array or DataFrame containing the signal.
    fs : int
        The sampling rate of the data.
    unit : str
        The unit in which the duration should be calculated: 'sec'/'s',
        'min'/'m', or 'hour'/'h'; by default, in seconds (`sec`).

    Returns
    -------
    dur : float
        The duration of the signal in the requested unit, rounded to two
        decimal places.

    Raises
    ------
    ValueError
        If `unit` is not one of the accepted values.
    """
    if unit not in ('sec', 's', 'min', 'm', 'hour', 'h'):
        raise ValueError('The `unit` parameter must take \'sec\', \'min\', '
                         'or \'hour\'.')
    dur = len(data) / fs
    if unit in ('min', 'm'):
        return round((dur / 60), 2)
    # Membership test: the previous `unit == ('hour', 'h')` compared a
    # string to a tuple, so hour requests fell through to seconds.
    if unit in ('hour', 'h'):
        return round(((dur / 60) / 60), 2)
    return round(dur, 2)
def segment_data(
    data: pd.DataFrame,
    fs: int,
    seg_size: int
) -> pd.DataFrame:
    """
    Segment data into specific window sizes.

    Rows are labelled by position, so the result does not depend on the
    DataFrame's index.

    Parameters
    ----------
    data : pandas.DataFrame
        The DataFrame containing the data to be segmented.
    fs : int
        The sampling rate of the data.
    seg_size : int
        The window size, in seconds, into which the data should be
        segmented.

    Returns
    -------
    df : pandas.DataFrame
        A copy of the original DataFrame with a leading 'Segment' column
        holding 1-based segment numbers.

    Raises
    ------
    ValueError
        If `seg_size * fs` amounts to less than one sample.
    """
    samples_per_segment = int(seg_size * fs)
    if samples_per_segment <= 0:
        raise ValueError('`seg_size * fs` must amount to at least one '
                         'sample per segment.')
    df = data.copy()
    # Integer division of the row position yields the 0-based segment.
    labels = np.arange(len(df)) // samples_per_segment + 1
    df.insert(0, 'Segment', labels)
    return df
[docs]
def compute_ibis(
    data: pd.DataFrame,
    fs: int,
    beats_ix: np.ndarray,
    ts_col: Optional[str] = None
) -> pd.DataFrame:
    """
    Compute interbeat intervals from beat locations in ECG or PPG data.

    Each interval, in milliseconds, is stored at the row position of the
    beat that ends it; all other rows hold NaN.

    Parameters
    ----------
    data : pandas.DataFrame
        The DataFrame containing the preprocessed ECG/PPG data.
    fs : int
        The sampling rate of the ECG/PPG data.
    beats_ix : array_like
        An array of row positions corresponding to beat occurrences.
    ts_col : str
        The name of the column in `data` containing timestamp values; by
        default, None.

    Returns
    -------
    ibi : pandas.DataFrame
        A DataFrame with one row per row of `data`, containing the
        `ts_col` column (or a 1-based 'Sample' column when `ts_col` is
        None) and an 'IBI' column.

    Examples
    --------
    >>> import physioview
    >>> fs = 1024  # sampling rate
    >>> # Here, `ecg` is a DataFrame with a "Timestamp" column
    >>> beats_ix = physioview.ECGBeatDetectors(fs).manikandan(ecg['ECG'])
    >>> ibi = physioview.compute_ibis(ecg, fs, beats_ix, 'Timestamp')
    """
    beats = np.asarray(beats_ix).astype(int, copy = False)
    ibis = (np.diff(beats) / fs) * 1000
    n_rows = len(data)

    if ts_col is not None:
        ibi = data[[ts_col]].copy()
    else:
        ibi = pd.DataFrame({'Sample': np.arange(n_rows) + 1})

    # Positional, vectorized placement: the 'IBI' column always exists
    # (even with fewer than two beats) and the result does not depend on
    # the index labels of `data`.
    values = np.full(n_rows, np.nan)
    values[beats[1:]] = ibis
    ibi['IBI'] = values
    return ibi
[docs]
def compute_hrv(
    data: pd.DataFrame,
    fs: int,
    beats_ix: np.ndarray,
    window_size: int = 60,
    step_size: int = 1,
    ts_col: Optional[str] = None
) -> pd.DataFrame:
    """
    Compute heart rate variability (HRV) metrics from beat locations in ECG or
    PPG data.

    Interbeat intervals are obtained from `compute_ibis()`, rows without an
    interval are dropped, and the resulting series is passed to
    `flirt.hrv.get_hrv_features` with the time-domain, frequency-domain,
    non-linear, and statistical domains enabled.

    Parameters
    ----------
    data : pandas.DataFrame
        The DataFrame containing the preprocessed ECG/PPG data.
    fs : int
        The sampling rate of the ECG/PPG data.
    beats_ix : array_like
        An array of indices corresponding to beat occurrences.
    window_size : int, optional
        The size of the windows over which HRV metrics are calculated; by
        default, 60.
    step_size : int, optional
        The step size of the windows over which HRV metrics are calculated;
        by default, 1.
    ts_col : str
        The name of the column in `data` containing timestamp values; by
        default, None.

    Returns
    -------
    hrv : pandas.DataFrame
        A DataFrame containing HRV metrics. Its index is named `ts_col`
        when a timestamp column is given; otherwise the index is reset to
        a default integer index.

    Examples
    --------
    >>> import physioview
    >>> fs = 1024  # sampling rate
    >>> # Here, `ecg` is a DataFrame with a "Timestamp" column
    >>> beats_ix = physioview.ECGBeatDetectors(fs).manikandan(ecg['ECG'])
    >>> # Compute HRV across 60-sec windows at 15-sec intervals
    >>> hrv = physioview.compute_hrv(ecg, fs, beats_ix, window_size = 60,
    >>>                              step_size = 15, ts_col = 'Timestamp')
    """
    has_timestamps = ts_col is not None

    intervals = compute_ibis(data, fs, beats_ix, ts_col)
    intervals = intervals.dropna().reset_index(drop = True)

    if has_timestamps:
        ibi_series = intervals.set_index(ts_col)['IBI']
    else:
        # Without real timestamps, synthesize a time index by accumulating
        # the intervals (in milliseconds) from the current time.
        origin = pd.Timestamp.now()
        raw_ibis = intervals['IBI'].values
        elapsed = pd.to_timedelta(pd.Series(raw_ibis).cumsum(), unit = 'ms')
        ibi_series = pd.Series(raw_ibis, index = origin + elapsed,
                               name = 'IBI')

    feature_options = {
        'window_length': window_size,
        'window_step_size': step_size,
        'domains': ['td', 'fd', 'nl', 'stat'],
        'threshold': 0.5,
        'clean_data': False,
    }
    hrv = get_hrv_features(ibi_series, **feature_options)

    if has_timestamps:
        hrv.index.name = ts_col
        return hrv

    # The synthetic time index carries no meaning, so discard it.
    hrv.reset_index(drop = True, inplace = True)
    return hrv
# Signal type labels accepted by the `signal_type` parameter of `plot_signal`.
SignalType = Literal['ECG', 'PPG', 'BVP', 'EDA', 'HR', 'RESP', 'TEMP']
[docs]
def plot_signal(
    *,
    signal: pd.DataFrame,
    signal_type: Union[SignalType, List[SignalType]],
    axes: Tuple[str, Union[str, List[str], Dict[str, Union[str, List[str]]]]],
    fs: int,
    peaks_map: Optional[Dict[str, str]] = None,
    peaks_label: Optional[str] = None,
    peaks_color: Optional[str] = None,
    artifacts_map: Optional[Dict[str, str]] = None,
    correction_map: Optional[Dict[str, str]] = None,
    edits_map: Optional[Dict[str, Dict[str, str]]] = None,
    acc: Optional[pd.DataFrame] = None,
    ibi: Optional[pd.DataFrame] = None,
    ibi_corrected: Optional[pd.DataFrame] = None,
    hline: Optional[float] = None,
    hline_name: Optional[str] = None,
    seg_number: Optional[int] = 1,
    seg_size: Optional[int] = 60,
    n_segments: Optional[int] = 1,
    fig_title: Optional[str] = None,
    fig_height: Optional[int] = 450
) -> go.Figure:
    """
    Create a Plotly figure with primary and optional secondary physiological
    signals, including optional peaks and artifact markers.

    Parameters
    ----------
    signal : pandas.DataFrame
        A DataFrame containing the primary signal data to plot.
    signal_type : str or list of str
        The name(s) of the primary signal type(s) to plot. Possible values
        include 'ECG', 'PPG', 'BVP', 'EDA', 'HR', 'RESP', and 'TEMP'. One
        subplot row is created per signal type.
    axes : tuple of (str, str or list of str or dict)
        A tuple specifying the x-axis column and the y-axis signal(s) to
        plot. The first element must be the name of the column in `signal`
        to use for the x-axis (e.g., ``'Timestamp'``). The second element
        can take one of the following forms:

        - str : a single y-axis column to plot for one signal type.
          Example: ``('Timestamp', 'EDA')``.
        - list of str : multiple y-axis columns to plot for the same signal
          type. Example: ``('Timestamp', ['EDA', 'Phasic'])``.
        - dict : mapping of signal types to one or more y-axis columns,
          allowing multiple signal types to be plotted in separate
          subplots. Example:
          ``('Timestamp', {'EDA': 'EDA', 'ECG': 'ECG'})`` or
          ``('Timestamp', {'EDA': ['EDA', 'Phasic'], 'ECG': ['ECG']})``.
    fs : int
        The sampling rate (Hz) of the signal data.
    peaks_map : dict of {str: str}, optional
        A dictionary mapping a signal type to the name of a column in
        `signal` containing binary (0/1) peak annotations; by default, None
        (i.e., no peaks are plotted). Example: ``{'ECG': 'Beat'}`` will plot
        beat markers from the 'Beat' column on the ECG subplot.
    peaks_label : str, optional
        A label for the peak annotations on the signal subplot.
    peaks_color : str, optional
        A color for the peak annotations on the signal subplot.
    artifacts_map : dict of {str: str}, optional
        A dictionary mapping a signal type to the name of a column in
        `signal` containing binary (0/1) artifact annotations; by default,
        None (i.e., no artifacts are plotted). Example:
        ``{'ECG': 'Artifact'}`` will plot artifact markers from the
        'Artifact' column on the ECG subplot.
    correction_map : dict of {str: str}, optional
        A dictionary mapping a signal type to the name of a column in
        `signal` containing binary (0/1) corrected beat annotations; by
        default, None (i.e., no corrected peaks are plotted). Example:
        ``{'ECG': 'Corrected'}`` will plot corrected beat markers from the
        'Corrected' column on the ECG subplot.
    edits_map : dict of {str: dict of {str: str}}, optional
        A dictionary mapping a signal type to one or more edit types and
        their corresponding binary (0/1) annotation columns in `signal`; by
        default, None (i.e., no edits are plotted). Format:
        ``{signal_type: {edit_label: column_name, ...}}``. For example:
        ``{'ECG': {'Added': 'Added Beat', 'Deleted': 'Deleted Beat',
        'Unusable': 'Unusable'}}``.
    acc : pandas.DataFrame, optional
        DataFrame containing accelerometer data. If present, plotted
        as a secondary signal in the first subplot. Must contain
        'Magnitude' or another numeric column.
    ibi : pandas.DataFrame, optional
        DataFrame containing interbeat interval (IBI) data. If present,
        plotted as a secondary signal in the last subplot. Must contain
        'IBI' or another numeric column.
    ibi_corrected : pandas.DataFrame, optional
        DataFrame containing auto-corrected interbeat interval (IBI) data.
        If present, plotted as a secondary signal in the last subplot. Must
        contain 'IBI' or another numeric column.
    hline : float, optional
        If provided, plots a horizontal dotted line for a given reference
        amplitude value in the first primary signal subplot; by default,
        None.
    hline_name : str, optional
        A label for the horizontal line; by default, None.
    seg_number : int, optional
        The positional number of the segment to plot; by default, 1 (first
        segment).
    seg_size : int, optional
        The length of each segment in seconds; by default, 60.
    n_segments : int, optional
        The number of consecutive segments to plot starting from
        `seg_number`; by default, 1.
    fig_title : str, optional
        The title of the figure.
    fig_height : int, optional
        The height of the Plotly figure in pixels; by default, 450.

    Returns
    -------
    fig : go.Figure
        The Plotly figure containing the plotted signals.

    Raises
    ------
    TypeError
        If `axes[1]` is not a str, list of str, or dict.
    ValueError
        If a str/list `axes[1]` is combined with more than one signal type,
        if `signal_type` is empty, or if `acc`, `ibi`, or `ibi_corrected`
        has neither its expected column nor any numeric column.
    KeyError
        If a requested column is missing from `signal`, or if `axes` and
        `peaks_map` name signal types that disagree with `signal_type`.

    Notes
    -----
    Peak, corrected-beat, edit, and artifact overlays are drawn once per
    signal type, positioned on the first y-column given for that type.
    Signal types that have no entry in one of the overlay maps simply get
    no overlay of that kind.

    Examples
    --------
    >>> from physioview import physioview
    >>> # data has columns: 'Timestamp', 'II', 'GSR', 'Beat', 'SCR', 'Artifact'
    >>> fig = physioview.plot_signal(
    >>>     signal = data,
    >>>     signal_type = ['ECG', 'EDA'],
    >>>     axes = ('Timestamp', {'ECG': 'II', 'EDA': 'GSR'}),
    >>>     fs = 256,
    >>>     peaks_map = {'ECG': 'Beat', 'EDA': 'SCR'},
    >>>     artifacts_map = {'ECG': 'Artifact'},
    >>>     acc = acc_data,
    >>>     ibi = ibi_data)
    >>> fig.show()
    """

    def _secondary_column(frame, preferred, arg_name):
        """Return `preferred` if `frame` has it, else its first numeric
        column (with a warning)."""
        if preferred in frame.columns:
            return preferred
        numeric = frame.select_dtypes(include = 'number').columns.tolist()
        if not numeric:
            raise ValueError(
                f"'{preferred}' not found in `{arg_name}` columns and no "
                f"numeric column is available to plot instead.")
        warnings.warn(f"'{preferred}' not found in `{arg_name}` columns. "
                      f"Using {numeric[0]} instead.")
        return numeric[0]

    def _marker_trace(flag_col, y_col, **scatter_kwargs):
        """Build a marker trace at the rows of the segment where
        `flag_col` equals 1."""
        flagged = sig_seg[flag_col] == 1
        return go.Scatter(
            x = sig_seg.loc[flagged, ax_x],
            y = sig_seg.loc[flagged, y_col],
            mode = 'markers',
            showlegend = True,
            **scatter_kwargs)

    ax_x, ax_y = axes[0], axes[1]

    # Work with a list of signal types throughout so that membership tests
    # compare whole names instead of substrings of a single string
    signal_types = ([signal_type] if isinstance(signal_type, str)
                    else list(signal_type))
    if not signal_types:
        raise ValueError('`signal_type` must name at least one signal type.')

    # Normalize `axes[1]` into {signal type: [y-columns]}
    if isinstance(ax_y, str):
        if len(signal_types) != 1:
            raise ValueError(
                'If `axes` is a tuple with a single y-column, '
                '`signal_type` must also be a single string.')
        axes_dict = {signal_types[0]: [ax_y]}
    elif isinstance(ax_y, list):
        if len(signal_types) != 1:
            raise ValueError(
                'If `axes` is a tuple with a list of y-columns, '
                '`signal_type` must be a single string.')
        axes_dict = {signal_types[0]: list(ax_y)}
    elif isinstance(ax_y, dict):
        axes_dict = {
            stype: ([cols] if isinstance(cols, str) else list(cols))
            for stype, cols in ax_y.items()
        }
    else:
        raise TypeError(
            '`axes[1]` must be a str, list of str, or dict '
            'mapping signal types to column(s).')

    # Validate axes against the data and the requested signal types
    if ax_x not in signal.columns:
        raise KeyError(f'{ax_x} not found in `signal` columns.')
    for stype, ycols in axes_dict.items():
        for y in ycols:
            if y not in signal.columns:
                raise KeyError(f'{y} not found in `signal` columns.')
        if stype not in signal_types:
            raise KeyError(f'{stype} not given in `signal_type`.')
    for stype in signal_types:
        if not axes_dict.get(stype):
            raise KeyError(f'No y-axis column given in `axes` for {stype}.')

    # Validate peaks map keys
    if peaks_map is not None:
        for sig in peaks_map:
            if sig not in signal_types:
                raise KeyError(
                    f"'{sig}' in `peaks_map` not found in `signal_type`.")

    # Resolve the column to plot for each secondary signal
    has_acc = acc is not None
    has_ibi = ibi is not None
    has_ibi_corrected = ibi_corrected is not None
    if has_acc:
        acc_col = _secondary_column(acc, 'Magnitude', 'acc')
    if has_ibi:
        ibi_col = _secondary_column(ibi, 'IBI', 'ibi')
    if has_ibi_corrected:
        ibi_corrected_col = _secondary_column(
            ibi_corrected, 'IBI', 'ibi_corrected')

    # Segment data
    seg_len = int(seg_size * fs)
    seg_start = int((seg_number - 1) * seg_len)
    seg_end = seg_start + int(n_segments * seg_len)
    sig_seg = signal.iloc[seg_start:seg_end].copy()

    # Lay out rows: optional ACC row, one row per primary signal type
    # (sharing 60% of the height), and an optional IBI row
    n_primary = len(signal_types)
    primary_total, secondary_each = 0.6, 0.2
    row_heights = []
    if has_acc:
        row_heights.append(secondary_each)
    row_heights.extend([primary_total / n_primary] * n_primary)
    if has_ibi:
        row_heights.append(secondary_each)
    fig = make_subplots(
        rows = len(row_heights), cols = 1,
        shared_xaxes = True, vertical_spacing = 0.02,
        row_heights = row_heights,
    )

    # Plot ACC signal in the first subplot if provided
    start_row = 1
    if has_acc:
        if len(acc) != len(signal):
            warnings.warn('`acc` and `signal` have unmatched lengths. '
                          'Resampling `acc`.')
            acc_y = scipy_resample(acc[acc_col], len(signal))
        else:
            acc_y = acc[acc_col]
        acc_y_seg = acc_y[seg_start:seg_end].copy()
        fig = _acc_subplot(sig_seg[ax_x], acc_y_seg, fig)
        start_row = 2

    # Plot horizontal line if requested
    if hline is not None:
        unit = _DEFAULT_SIGNAL_PARAMS.get(
            signal_types[0], {}).get('unit', signal_types[0])
        fig.add_trace(
            go.Scatter(
                x = sig_seg[ax_x],
                y = [hline] * len(sig_seg[ax_x]),
                mode = 'lines',
                line = dict(color = 'red', dash = 'dot', width = 1),
                showlegend = False,
                hovertemplate = (f'{hline_name}: {hline} {unit}<extra></extra>'
                                 if hline_name is not None else '<extra></extra>'),
            ),
            row = start_row, col = 1
        )

    # Rows of the segment flagged as unusable (all False without the column)
    if 'Unusable' in sig_seg.columns:
        unus_mask = pd.to_numeric(
            sig_seg['Unusable'], errors = 'coerce').eq(1)
    else:
        unus_mask = pd.Series(False, index = sig_seg.index)

    # Plot primary signals
    for i, stype in enumerate(signal_types):
        row_id = start_row + i
        defaults = _DEFAULT_SIGNAL_PARAMS.get(stype, {})
        unit = defaults.get('unit', stype)
        color = defaults.get('color')
        ycols = axes_dict[stype]
        first_y = ycols[0]
        edits_cfg = (edits_map.get(stype) or {}) if edits_map else {}

        # Blank out unusable stretches of the raw trace only when this
        # signal type also draws them as an 'Unusable' edit overlay
        hide_unusable = ('Unusable' in edits_cfg
                         and 'Unusable' in sig_seg.columns)
        for ycol in ycols:
            y_vals = sig_seg[ycol]
            if hide_unusable:
                y_vals = y_vals.where(sig_seg['Unusable'] != 1, np.nan)
            trace_name = ycol if len(ycols) > 1 else stype
            fig.add_trace(
                go.Scatter(
                    x = sig_seg[ax_x],
                    y = y_vals,
                    mode = 'lines',
                    connectgaps = False,
                    line = dict(color = color),
                    hovertemplate = f'%{{x}}<br><b>{trace_name}:</b> %{{y:.2f}} '
                                    f'{unit}<extra></extra>',
                    name = trace_name
                ),
                row = row_id, col = 1)

        # Plot peaks if provided for this signal type
        peaks_col = peaks_map.get(stype) if peaks_map else None
        if peaks_col:
            label = peaks_label or defaults.get('peak')
            fig.add_trace(
                _marker_trace(
                    peaks_col, first_y,
                    name = label,
                    marker = dict(
                        color = ('#f9c669' if peaks_color is None
                                 else peaks_color),
                        size = 8),
                    hovertemplate = f'<b>{label}</b><extra></extra>'),
                row = row_id, col = 1
            )

        # Plot corrected peaks if provided for this signal type
        corrected_col = correction_map.get(stype) if correction_map else None
        if corrected_col:
            fig.add_trace(
                _marker_trace(
                    corrected_col, first_y,
                    name = 'Auto-Corrected Beat',
                    marker = dict(
                        color = 'rgba(250, 250, 250, 0.0)',
                        line = dict(color = 'forestgreen', width = 1.5),
                        size = 7),
                    hovertemplate = '<b>Auto-Corrected Beat</b><extra></extra>'),
                row = row_id, col = 1
            )

        # Plot edits if provided for this signal type
        for edit_type, col in edits_cfg.items():
            if col not in sig_seg.columns:
                continue
            style = _EDIT_STYLES.get(edit_type)
            if style is None:
                warnings.warn(f"No plotting style is defined for edit type "
                              f"'{edit_type}'. Skipping.")
                continue
            if col == 'Unusable':
                # Draw a line only where Unusable == 1
                mask = unus_mask
            else:
                # Edits inside unusable stretches are not shown
                edit_mask = pd.to_numeric(
                    sig_seg[col], errors = 'coerce').eq(1)
                mask = edit_mask & ~unus_mask
            fig.add_trace(
                go.Scatter(
                    x = sig_seg[ax_x],
                    y = sig_seg[first_y].where(mask, np.nan),
                    connectgaps = False,
                    showlegend = True,
                    hovertemplate = f"<b>{style['name']}</b><extra></extra>",
                    **style
                ),
                row = row_id, col = 1
            )

        # Plot artifactual beats if provided for this signal type
        artifacts_col = artifacts_map.get(stype) if artifacts_map else None
        if artifacts_col:
            fig.add_trace(
                _marker_trace(
                    artifacts_col, first_y,
                    name = 'Potential Artifact',
                    marker = dict(color = 'red', size = 8),
                    hovertemplate = '<b>Potential Artifact</b><extra></extra>'),
                row = row_id, col = 1
            )

    # Plot IBI signal in the last subplot if provided
    if has_ibi:
        if len(ibi) != len(signal):
            warnings.warn('`ibi` and `signal` have unmatched lengths. '
                          'Resampling `ibi`.')
            ibi_y = scipy_resample(ibi[ibi_col], len(signal))
        else:
            ibi_y = ibi[ibi_col]
        ibi_y_seg = ibi_y[seg_start:seg_end].copy()
        fig = _ibi_subplot(sig_seg[ax_x], ibi_y_seg, fig)
    if has_ibi_corrected:
        ibi_y = ibi_corrected[ibi_corrected_col]
        ibi_y_seg = ibi_y[seg_start:seg_end].copy()
        fig = _ibi_subplot(
            sig_seg[ax_x], ibi_y_seg, fig,
            line_dict = dict(color = 'rgba(34, 139, 33, 0.5)',
                             width = 2.0),
            name = 'Auto-Corrected IBI')

    # General figure formatting
    x_min, x_max = sig_seg[ax_x].min(), sig_seg[ax_x].max()
    fig.update_xaxes(
        tickfont = dict(size = 14),
        tickcolor = 'grey',
        linecolor = 'grey',
        range = [x_min, x_max]
    )
    fig.update_layout(
        height = fig_height,
        title_text = fig_title,
        template = 'simple_white',
        font = dict(family = 'Poppins', color = 'black'),
        legend = dict(font = dict(size = 16), orientation = 'h',
                      yanchor = 'bottom', y = 1.05,
                      xanchor = 'right', x = 1.0),
        annotations = [dict(
            text = ax_x, x = 0.5, y = -0.22, showarrow = False,
            xref = 'paper', yref = 'paper', font = dict(size = 16)
        )],
        margin = dict(l = 20, r = 20, t = 60, b = 70)
    )

    # Add y-axis labels
    for i, stype in enumerate(signal_types):
        fig.update_yaxes(
            title_text = _DEFAULT_SIGNAL_PARAMS.get(
                stype, {}).get('unit', stype),
            title_standoff = 5,
            row = start_row + i, col = 1
        )

    # Enforce grid lines on all subplots
    for yaxis_name in [k for k in fig.layout if k.startswith('yaxis')]:
        fig.layout[yaxis_name].update(
            showgrid = True,
            gridcolor = 'lightgrey',
            griddash = 'dot',
            gridwidth = 0.5,
            tickcolor = 'grey',
            linecolor = 'grey'
        )
    return fig
[docs]
def write_beat_editor_file(
    data: pd.DataFrame,
    fs: int,
    signal_col: str,
    beats_col: str,
    ts_col: Optional[str] = None,
    filename: Optional[str] = None,
    batch: bool = False,
    verbose: bool = True
) -> None:
    """
    Write cardiac data to a JSON file formatted for the Beat Editor.

    Parameters
    ----------
    data : pandas.DataFrame
        A DataFrame containing the cardiac data. Must contain at least
        two columns with the cardiac signal and beat occurrences labeled as
        1. Optionally, `data` can include a timestamp column (specified by
        `ts_col`) and an "Artifact" column, where artifact occurrences are
        labeled as 1. This allows the Beat Editor to visualize artifactual
        beat locations. The input is copied and never modified.
    fs : int
        The sampling frequency of the signal.
    signal_col : str
        The name of the column in `data` containing the cardiac signal;
        written out as 'Signal'.
    beats_col : str
        The name of the column in `data` containing beat occurrences;
        written out as 'Beat'.
    ts_col : str, optional
        The name of the column in `data` containing the timestamps. It is
        converted to datetime and written out as 'Timestamp'. If not
        provided, a 'Sample' column equal to the DataFrame index plus one
        is added unless `data` already has one.
    filename : str, optional
        The base name of the JSON file to write; '_edit.json' is appended.
        If no filename is provided, the default filename
        'physioview_edit.json' is used.
    batch : bool, optional
        Whether input data is from a batch; by default, `False`. If `True`,
        the JSON file is written to a 'beat-editor/data/batch' subdirectory
        instead of 'beat-editor/data'.
    verbose : bool, optional
        If `True`, print a confirmation message after writing the JSON file.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `signal_col`, `beats_col`, or a given `ts_col` is not a column
        of `data`.

    Warns
    -----
    UserWarning
        If the beats column sums to zero.
    """
    from pathlib import Path

    frame = data.copy()
    json_filename = ('physioview_edit.json' if filename is None
                     else filename + '_edit.json')

    # Report the first required column that is missing, by parameter name
    named_cols = {'signal_col': signal_col, 'beats_col': beats_col}
    if ts_col:
        named_cols['ts_col'] = ts_col
    missing = next(
        (label for label, col in named_cols.items()
         if col not in frame.columns),
        None)
    if missing is not None:
        raise ValueError(f'`{missing}` not found in input data.')

    if frame[beats_col].sum() == 0:
        warnings.warn('No beat occurrences found in input data.', UserWarning)

    # Give every record a time reference: real timestamps when available,
    # otherwise 1-based sample numbers derived from the index
    if ts_col is None:
        if 'Sample' not in frame.columns:
            frame.insert(0, 'Sample', frame.index + 1)
    else:
        frame[ts_col] = pd.to_datetime(frame[ts_col])
        frame = frame.rename(columns = {ts_col: 'Timestamp'})

    # Number 60-second segments from the index when none are supplied
    if 'Segment' not in frame.columns:
        frame.insert(0, 'Segment', (frame.index // (fs * 60)) + 1)

    # Column names expected by the Beat Editor
    frame = frame.rename(columns = {signal_col: 'Signal', beats_col: 'Beat'})

    # The output directory lives one level above this module's directory
    data_dir = Path(__file__).resolve().parents[1] / 'beat-editor' / 'data'
    if batch:
        data_dir = data_dir / 'batch'
    data_dir.mkdir(parents = True, exist_ok = True)

    json_path = data_dir / json_filename
    frame.to_json(json_path, orient = 'records', date_format = 'epoch',
                  lines = False)
    if verbose:
        print(f'Beat Editor JSON file written to {json_path}')
[docs]
def process_beat_edits(
    orig_data: pd.DataFrame,
    edits: pd.DataFrame
) -> pd.DataFrame:
    """
    Apply manual corrections from the Beat Editor output to original data.

    Edits are aligned by timestamp when `orig_data` has a 'Timestamp'
    column and by sample number otherwise. Every ADD/DELETE edit is mapped
    to the single row of `orig_data` nearest to its location, provided that
    row lies within tolerance (2 ms for timestamps, 1 sample for sample
    numbers); edits with no row in tolerance are ignored. If several edits
    land on the same row, the one listed last in `edits` wins.

    Parameters
    ----------
    orig_data : pandas.DataFrame
        A DataFrame containing the original cardiac data inputted to the
        Beat Editor. Must contain a 'Beat' column and either:

        - 'Timestamp' column (datetime), or
        - 'Sample' column (integer sample indices)
    edits : pandas.DataFrame
        A DataFrame of edit instructions parsed from a Beat Editor
        `_edited.json` file. Must contain:

        - 'editType': one of 'ADD', 'DELETE', or 'UNUSABLE'
        - either 'x' (edit location) or 'from' (start of unusable segment,
          with 'to' as the end), in the same time or sample units as
          `orig_data`. Timestamp locations are read as milliseconds since
          the epoch.

    Returns
    -------
    processed : pandas.DataFrame
        A copy of `orig_data`, in its original row order and with its
        original index, with the following additional columns:

        - 'Edited': 1 where all final beats are, otherwise `NaN`
        - 'Deleted Beat': 1 where beats were deleted, otherwise `NaN`
        - 'Added Beat': 1 where beats were added, otherwise `NaN`
        - 'Unusable': 1 where segments are marked unusable, otherwise
          `NaN`. Only present when `edits` contains unusable segments (or
          `orig_data` already had the column).

    Raises
    ------
    ValueError
        If `edits` has no 'editType' column.
    KeyError
        If `orig_data` has neither a 'Timestamp' nor a 'Sample' column.
    """
    # Validate edits input
    if 'editType' not in edits.columns:
        raise ValueError("`edits` must include columns 'editType'.")

    processed = orig_data.copy()
    processed['Edited'] = processed['Beat'].values
    n_rows = len(processed)

    # Split instructions into point edits (ADD/DELETE at `x`) and range
    # edits (unusable segments between `from` and `to`)
    if 'x' in edits.columns:
        beat_edits = edits[['x', 'editType']].dropna(subset = ['x'])
    else:
        beat_edits = pd.DataFrame(columns = ['x', 'editType'])
    if {'from', 'to'}.issubset(edits.columns):
        unusable_edits = edits[['from', 'to']].dropna()
    else:
        unusable_edits = pd.DataFrame(columns = ['from', 'to'])

    # Express row locations as int64 keys: nanoseconds since the epoch for
    # timestamps, or the sample number itself
    by_timestamp = 'Timestamp' in processed.columns
    if by_timestamp:
        if not pd.api.types.is_datetime64_any_dtype(processed['Timestamp']):
            processed['Timestamp'] = pd.to_datetime(
                processed['Timestamp'], errors = 'coerce')
        row_ok = processed['Timestamp'].notna().to_numpy()
        row_keys = processed['Timestamp'].to_numpy(
            dtype = 'datetime64[ns]').astype('int64')
        tolerance = 2_000_000  # 2 ms, in nanoseconds
    elif 'Sample' in processed.columns:
        processed['Sample'] = pd.to_numeric(
            processed['Sample'], errors = 'coerce').astype('int64')
        row_ok = np.ones(n_rows, dtype = bool)
        row_keys = processed['Sample'].to_numpy()
        tolerance = 1  # 1 sample
    else:
        raise KeyError(
            "`orig_data` must contain a 'Timestamp' or 'Sample' column.")

    # Row positions with a usable key, ordered by key for binary search
    valid_rows = np.flatnonzero(row_ok)
    order = valid_rows[np.argsort(row_keys[valid_rows], kind = 'stable')]
    sorted_keys = row_keys[order]
    can_match = sorted_keys.size > 0

    def _edit_keys(values):
        """Convert edit locations to int64 keys plus a validity mask."""
        if by_timestamp:
            stamps = pd.to_datetime(values, unit = 'ms', errors = 'coerce')
            ok = stamps.notna().to_numpy()
            keys = stamps.to_numpy(dtype = 'datetime64[ns]').astype('int64')
        else:
            numbers = pd.to_numeric(values, errors = 'coerce')
            ok = numbers.notna().to_numpy()
            keys = numbers.fillna(0).round().astype('int64').to_numpy()
        return keys, ok

    def _nearest_rows(queries):
        """Return the row position nearest to each query key and the
        distance to it (ties go to the smaller key)."""
        last = len(sorted_keys) - 1
        pos = np.searchsorted(sorted_keys, queries, side = 'left')
        lo = np.clip(pos - 1, 0, last)
        hi = np.clip(pos, 0, last)
        lo_dist = np.abs(sorted_keys[lo] - queries)
        hi_dist = np.abs(sorted_keys[hi] - queries)
        use_hi = hi_dist < lo_dist
        return (order[np.where(use_hi, hi, lo)],
                np.where(use_hi, hi_dist, lo_dist))

    # Map each ADD/DELETE edit to exactly one row. Matching rows to edits
    # instead (a left as-of join of the data onto the edits) would flag
    # every row within tolerance of an edit, not just the nearest one.
    deleted = np.zeros(n_rows, dtype = bool)
    added = np.zeros(n_rows, dtype = bool)
    if can_match and not beat_edits.empty:
        keys, ok = _edit_keys(beat_edits['x'])
        rows, dist = _nearest_rows(keys[ok])
        labels = beat_edits['editType'].to_numpy()[ok]
        within = dist <= tolerance
        # Building a dict keeps only the last edit made at each row
        row_edit = dict(zip(rows[within].tolist(), labels[within].tolist()))
        for row, label in row_edit.items():
            if label == 'DELETE':
                deleted[row] = True
            elif label == 'ADD':
                added[row] = True

    # Record 'Unusable' portions
    if not unusable_edits.empty:
        starts, starts_ok = _edit_keys(unusable_edits['from'])
        stops, stops_ok = _edit_keys(unusable_edits['to'])
        both_ok = starts_ok & stops_ok
        starts, stops = starts[both_ok], stops[both_ok]
        if by_timestamp and can_match:
            # Snap segment bounds onto the nearest recorded timestamps
            starts = row_keys[_nearest_rows(starts)[0]]
            stops = row_keys[_nearest_rows(stops)[0]]
        unusable = np.zeros(n_rows, dtype = bool)
        # min/max tolerates segments whose bounds were given in reverse
        for lo, hi in zip(np.minimum(starts, stops),
                          np.maximum(starts, stops)):
            unusable |= row_ok & (row_keys >= lo) & (row_keys <= hi)
        if 'Unusable' not in processed.columns:
            processed['Unusable'] = np.nan
        processed.loc[unusable, 'Unusable'] = 1

    # Record final edited beat occurrences
    for col in ('Deleted Beat', 'Added Beat'):
        if col not in processed.columns:
            processed[col] = np.nan
    processed.loc[deleted, 'Deleted Beat'] = 1
    processed.loc[added, 'Added Beat'] = 1
    processed.loc[deleted, 'Edited'] = np.nan
    if 'Unusable' in processed.columns:
        processed.loc[
            processed['Unusable'].eq(1).to_numpy(), 'Edited'] = np.nan
    # Applied last so that manually added beats survive inside unusable
    # segments
    processed.loc[added, 'Edited'] = 1
    return processed