Rationale:
Deep learning (DL) models offer great promise for improving the speed and quality of diagnosis and treatment in medicine. However, a major flaw of these methods is that they tend to be overconfident in cases where a human would quickly realize they were out of their depth. This stems from an underlying assumption that the variability the model encounters after deployment is drawn from the same distribution as the variability present in its training data. In practice, it is difficult to ensure that all real-world samples are drawn from the same distribution as the training data. The consequences are typically minor in consumer applications, but in medical use this overconfidence could lead to misdiagnosis, injury, or death. It is thus critical for a model used in medical applications to detect when incoming samples are drawn from far outside the training distribution, as these are the situations in which it is likely to fail.
Method:
We leverage a framework based on induced metrics on hierarchical vector spaces (similar to the framework outlined by Lee et al., NeurIPS 2018) to identify when a model encounters samples from a distribution it did not see during training. We take our pre-trained EEG seizure classification model (Saab et al., npj Digital Medicine, 2020) and extract features from its hidden layers. We approximate the probability distribution over the training data with multivariate Gaussians, using class-wise means and a shared covariance matrix derived from these training features. We assign each test sample a score equal to its minimum Mahalanobis distance from the class-wise means. We develop an algorithm that consumes these scores in a streaming fashion to identify when the data distribution shifts, minimizing the number of time steps needed for detection.
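As an illustrative sketch only (in Python/NumPy, not the authors' released implementation), the snippet below shows one way to fit class-wise means with a tied covariance over hidden-layer training features and score a test sample by its minimum Mahalanobis distance to a class mean. The ridge term and all function and variable names are our assumptions, not part of the published method.

```python
import numpy as np

def fit_class_gaussians(features, labels):
    """Fit class-wise means and a shared (tied) covariance over training features.

    features: (N, D) array of hidden-layer activations from the trained model.
    labels:   (N,) integer class labels.
    """
    classes = np.unique(labels)
    means = {c: features[labels == c].mean(axis=0) for c in classes}
    # Pool centered features across classes to estimate one shared covariance.
    centered = np.concatenate([features[labels == c] - means[c] for c in classes])
    cov = centered.T @ centered / len(features)
    # Small ridge term (an assumption) keeps the covariance invertible
    # when the feature dimension is large relative to the sample count.
    precision = np.linalg.inv(cov + 1e-6 * np.eye(cov.shape[0]))
    return means, precision

def mahalanobis_score(x, means, precision):
    """Score a test feature vector x by its minimum squared Mahalanobis
    distance to any class mean: low -> in-distribution, high -> likely OOD."""
    return min((x - mu) @ precision @ (x - mu) for mu in means.values())
```

In practice, the features would come from one or more hidden layers of the seizure classifier; which layer to use and the size of the ridge term are tuning decisions not specified here.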
Results:
We provide experimental results for EEG seizure detection to validate our formulation on time-series classification tasks. The Mahalanobis scores are low when the incoming segment of the data stream is in-distribution and high when it is out-of-distribution (Fig 1). Applying our detection algorithm to these scores, we are able to predict when the EEG data goes out-of-distribution and hence when the model is likely to fail. We demonstrate this across EEGs from several data distributions, drawn from different age populations (adults, children, and neonates) and from different hospital centers (Stanford, LPCH, Temple). Our contributions include a theoretical framework to detect and quantify distribution shifts based on features derived from the trained model.
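The abstract does not specify the streaming detection rule; one standard choice consistent with minimizing time steps to detection is a CUSUM-style accumulator over the Mahalanobis scores, sketched below purely for illustration. The baseline, drift, and threshold parameters are assumptions that would be tuned on held-out in-distribution data.

```python
def detect_shift(score_stream, baseline_mean, drift=0.0, threshold=25.0):
    """CUSUM-style change detector over a stream of Mahalanobis scores.

    baseline_mean: expected score level on in-distribution data (e.g.,
                   estimated on a held-out in-distribution validation set).
    drift:         slack subtracted per step to suppress small fluctuations.
    threshold:     alarm level; trades false alarms against detection delay.
    Returns the time step at which a shift is flagged, or None.
    """
    cumsum = 0.0
    for t, s in enumerate(score_stream):
        # Accumulate evidence that scores sit above the in-distribution level.
        cumsum = max(0.0, cumsum + (s - baseline_mean - drift))
        if cumsum > threshold:
            return t  # earliest time step at which the shift is declared
    return None
```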
Conclusion:
Our metric successfully detects when EEG data is "far" from the data used to train the classifier. The algorithm thus knows when to discount its own predictions and can improve safety by warning human operators or "asking for help." Moreover, our framework naturally gives deployed clinical models an opportunity to engage in active learning and incremental class learning.
Funding:
This work was supported in part by the Stanford Wu Tsai Neuroscience Foundation.
FIGURES
Figure 1. Mahalanobis scores over the incoming EEG data stream: low for in-distribution segments, high for out-of-distribution segments.