Abstract
Accurate and reliable prediction of Alzheimer’s disease (AD) progression is crucial for effective interventions and treatments that delay its onset. Recently, deep learning models for AD progression have achieved excellent predictive accuracy. However, their predictions lack reliability because the models are not calibrated, which hinders their recognition and acceptance. To address this issue, this paper proposes a temporal attention-aware evidential recurrent network for trustworthy prediction of AD progression. Specifically, the evidential recurrent network explicitly models the uncertainty of the output and converts it into a reliability measure for trustworthy AD progression prediction. Furthermore, since AD progression prediction in practice relies on historical longitudinal data, we introduce temporal attention into the evidential recurrent network, which improves predictive performance. We evaluate the proposed model on the TADPOLE dataset. In terms of predictive performance, the proposed model achieves an mAUC of 0.943 and a BCA of 0.881, comparable to the state-of-the-art (SOTA) model MinimalRNN. More importantly, the proposed model provides reliability measures for its predictions through uncertainty estimation, and its expected calibration error (ECE) on the TADPOLE dataset is 0.101, much lower than the 0.147 of the SOTA model. This indicates that the proposed model can provide important decision-making support for risk-sensitive prediction of AD progression.
Keywords
Introduction
Alzheimer’s Disease (AD) is a progressive and irreversible neurodegenerative disease that leads to cognitive impairment and behavioral dysfunction. There are currently no drugs or clinical treatments that can effectively reverse or cure the disease [1, 2]. However, current research indicates that AD treatment strategies should intervene during the early stages of the disease to delay its onset [3]. Predicting the progression of AD is therefore clinically significant [4].
Recently, machine learning and deep learning methods have been successfully applied to AD diagnosis and progression modeling in practice. Some methods focus on predicting the diagnosis at a single time point from cross-sectional data, such as support vector machines (SVM) [5], Gaussian processes [6], heterogeneous ensemble learning [7], and random forests [8]. These methods perform well at classifying multimodal data at a single time point. In practice, however, the prediction of AD progression often relies on historical longitudinal data. Recent studies have therefore shifted towards deep learning methods, represented by recurrent neural networks (RNN) [9, 10, 11, 12, 13], to model longitudinal AD data.
However, existing work has typically focused on improving predictive accuracy and used the softmax activation function at the end of the network to output probabilities and confidence, ignoring the flaw that this confidence does not match the model’s actual accuracy. The model is therefore uncalibrated [14], which hinders user recognition and acceptance. In medical diagnosis, a risk-sensitive scenario, high predictive accuracy alone is not enough: a reliability measure for the predicted results is also required to avoid the threat of incorrect predictions [15, 16]. Figure 1a shows the confidence histogram of MinimalRNN, proposed by Nguyen et al. [11], on the TADPOLE dataset [2]. The distribution shows a considerable discrepancy between the average confidence and the actual accuracy. This is further illustrated by the reliability diagram in Fig. 1b, which shows the gap between the predicted confidence and perfect calibration in each confidence bin. Furthermore, although existing AD progression prediction models employ mechanisms such as gated units or bidirectional RNNs [13] to process historical longitudinal data, they fail to distinguish the importance of historical data at different time points for predicting disease progression. As a result, information from the more important historical time points is neglected, which degrades the model’s accuracy.
Confidence histograms and reliability diagrams for MinimalRNN on TADPOLE dataset.
To address the problems mentioned above, this paper proposes a temporal attention-aware evidential recurrent network for trustworthy prediction of AD progression. The model uses historical longitudinal data from AD at-risk individuals to provide reliable predictions of future disease progression.1
The code for this paper can be accessed via
The main contributions of this paper are as follows.
We propose an evidential recurrent network, which introduces evidential deep learning regularized by the AvUC loss on top of FastGRNN, to address the lack of calibration in existing AD progression models. To the best of our knowledge, the proposed model is the first work to estimate uncertainty in AD progression prediction.
Because AD progression prediction often relies on historical longitudinal data, we integrate temporal attention into the model to adaptively weight the historical features of different time points, which effectively improves predictive performance.
We conduct experiments on the TADPOLE dataset, which demonstrate that the proposed model not only achieves predictive performance comparable to state-of-the-art models, but also provides reliability measures through uncertainty, which current AD progression models lack.
Related work on Alzheimer’s disease progression
The prediction of Alzheimer’s disease (AD) progression involves modeling longitudinal data to forecast the future course of the disease from the historical diagnoses and AD biomarkers of AD at-risk individuals. To this end, Nguyen et al. [9] successively used an LSTM that handles missing data with a model-filling strategy and MinimalRNN [11], which also performs model filling but with fewer parameters, for longitudinal data modeling. Additionally, Ghazi et al. [10] improved the LSTM architecture and designed a corresponding loss function to address missing historical longitudinal data. Ho et al. [13] employed a bidirectional LSTM and focal loss to enhance AD progression prediction. Ghazi et al. [12] combined LSTM-based deep learning with continuous-time autoregressive modeling (CARRNN) to further improve performance. Xu et al. [32] proposed a sequence learning strategy collaborating with deep latent representations to flexibly handle incomplete, variable-length, longitudinal multimodal AD progression data.
However, these models still fail to distinguish the relative importance of historical data at different time points for disease progression prediction, so information from the more important historical time points is ignored. Attention mechanisms [18, 19] address this issue by adaptively assigning weights to data at different time points. With self-attention [19], an RNN assigns weights to the historical features of all previous time points when updating its hidden state, instead of relying solely on the last hidden state, which effectively improves its performance on longitudinal data. Recently, many studies have brought attention into disease progression prediction with excellent results. Tan et al. [33] designed an unreliability-aware attention mechanism to handle irregular multivariate physiological time series. Lu et al. [34] introduced attention into the LSTM to perform multi-disease prediction for intelligent clinical decision support. Luiz et al. [35] proposed a model based on bidirectional gated recurrent units with attention to predict Parkinson’s disease from tremor signals collected in handwriting exams. Liang et al. [36] introduced attention and smoothness regularization for AD progression prediction. More information about approaches for AD progression is shown in Table 1, where
Nevertheless, these approaches overlook the fact that the model’s output confidence lacks calibration owing to the widespread use of the softmax activation function in the classification stage, which hinders user recognition and acceptance of such models in risk-sensitive scenarios. Evidential deep learning (EDL) [17] has recently gained widespread use in medical image processing because it provides effective uncertainty estimation for uncalibrated predictive confidence and converts this uncertainty into a reliability measure for the output. Ghesu et al. [25] employed EDL for binary classification in medical image diagnosis; specifically, they used the Beta distribution, a special case of the Dirichlet distribution, to model output uncertainty. Li et al. [26] used EDL for brain tumor image segmentation, quantifying segmentation uncertainty to improve the model’s trustworthiness. Zou et al. [27] applied EDL to a wider range of medical image segmentation tasks, including 2D, 3D, and multimodal segmentation, with remarkable results.
Preliminary knowledge
Evidential deep learning (EDL)
Existing deep classification models typically use the softmax activation function to obtain predictive confidence. However, because this confidence is not calibrated, it cannot accurately reflect the reliability of the model’s predictions, which restricts the use of such models in risk-sensitive scenarios. To address this issue, evidential deep learning (EDL) uses a ReLU activation function to model the output of the neural network as evidence in subjective opinions, and parameterizes this evidence as the parameters of a Dirichlet distribution. Based on subjective logic and Dempster-Shafer (DS) evidence theory, EDL provides reliable uncertainty estimation for uncalibrated confidence. This enables the model to output category probabilities together with an uncertainty that serves as a reliability measure of the prediction. Specifically, the Dirichlet distribution models the probability distribution of the parameters of the categorical distribution, and its probability density function is shown in Eq. (1):
where
where
The category probability of the prediction can be obtained by calculating the mean of the Dirichlet distribution, as shown in Eq. (4):
where
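The EDL quantities described above can be sketched in a few lines. This sketch assumes the standard EDL parameterization (ReLU evidence e_k, Dirichlet parameters α_k = e_k + 1, strength S, belief b_k = e_k/S, uncertainty u = K/S, predicted probability α_k/S); because the equations are not reproduced here, the paper's symbols may differ. The function `edl_outputs` is hypothetical and written in plain Python rather than PyTorch for readability.

```python
def edl_outputs(logits):
    """Map raw network outputs for K classes to EDL quantities.

    Standard EDL parameterization (assumed here):
      evidence e_k = ReLU(z_k)        non-negative evidence per class
      alpha_k      = e_k + 1          Dirichlet parameters
      S            = sum_k alpha_k    Dirichlet strength
      b_k          = e_k / S          belief mass per class
      u            = K / S            uncertainty mass (sum_k b_k + u = 1)
      p_k          = alpha_k / S      Dirichlet mean, i.e. predicted probability
    """
    evidence = [max(0.0, z) for z in logits]
    alpha = [e + 1.0 for e in evidence]
    strength = sum(alpha)
    num_classes = len(alpha)
    return {
        "evidence": evidence,
        "alpha": alpha,
        "belief": [e / strength for e in evidence],
        "uncertainty": num_classes / strength,
        "prob": [a / strength for a in alpha],
    }


# Hypothetical raw outputs for the classes (CN, MCI, AD).
out = edl_outputs([4.0, -1.0, 1.0])
print(out["prob"], out["uncertainty"])  # [0.625, 0.125, 0.25] 0.375
```

With no evidence at all (every output non-positive), u = 1 and the predicted probabilities are uniform. In this example 1 − u = 0.625, which is the kind of reliability value compared against MinimalRNN's confidence in the experiments.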
The evidence of subjective opinions is learned by minimizing the Bayes risk of the mean squared error (MSE) loss under the Dirichlet distribution, as shown in Eq. (5):
where
where
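A minimal sketch of the MSE Bayes-risk loss referred to in Eq. (5), assuming the standard EDL form in which the expected squared error under Dir(p | α) splits into an error term and a Dirichlet variance term. The KL regularizer that EDL is usually trained with is omitted, and `edl_mse_loss` is a hypothetical name.

```python
def edl_mse_loss(alpha, target_index):
    """Bayes risk of the squared error under Dir(p | alpha) for one sample.

    loss = sum_k [(y_k - p_k)^2 + p_k (1 - p_k) / (S + 1)],
    with p_k = alpha_k / S and y the one-hot label. The first term fits the
    label; the second is the Dirichlet variance of each class probability.
    """
    strength = sum(alpha)
    loss = 0.0
    for k, a in enumerate(alpha):
        p = a / strength
        y = 1.0 if k == target_index else 0.0
        loss += (y - p) ** 2 + p * (1.0 - p) / (strength + 1.0)
    return loss


# More evidence for the true class (index 0) lowers the loss.
weak = edl_mse_loss([5.0, 1.0, 2.0], target_index=0)
strong = edl_mse_loss([50.0, 1.0, 1.0], target_index=0)
```

Both terms shrink as evidence for the correct class grows, so the loss rewards confident evidence only when it points at the true label.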
Block diagrams for FastRNN.
where
FastGRNN [20] is an RNN model that uses weight sharing and residual connections to mitigate gradient explosion, gradient vanishing, and information forgetting. Because it has few parameters, it is well suited to modeling small-sample medical data. FastGRNN has been shown to outperform traditional gated RNN models such as LSTM and GRU [28] in various applications. The model architecture is depicted in Fig. 2, and the hidden-state update is given in Eqs (8)–(10):
where
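The hidden-state update can be sketched as below, assuming the FastGRNN formulation from [20]: the gate and the candidate state share the matrices W and U, and ζ and ν are sigmoid-squashed trainable scalars. This is a plain-Python sketch for readability, not the PyTorch implementation used in the experiments; the low-rank and sparse weight factorizations of [20] are omitted, and all function and parameter names are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def fastgrnn_step(x, h_prev, W, U, b_z, b_h, zeta_hat, nu_hat):
    """One FastGRNN hidden-state update.

    z_t  = sigmoid(W x_t + U h_{t-1} + b_z)             gate
    h~_t = tanh(W x_t + U h_{t-1} + b_h)                candidate state
    h_t  = (zeta (1 - z_t) + nu) * h~_t + z_t * h_{t-1}
    W and U are shared by the gate and the candidate state.
    """
    pre = [a + b for a, b in zip(matvec(W, x), matvec(U, h_prev))]
    zeta, nu = sigmoid(zeta_hat), sigmoid(nu_hat)
    h_new = []
    for p, bz, bh, hp in zip(pre, b_z, b_h, h_prev):
        z = sigmoid(p + bz)
        h_tilde = math.tanh(p + bh)
        h_new.append((zeta * (1.0 - z) + nu) * h_tilde + z * hp)
    return h_new


def fastgrnn_run(xs, hidden_size, params):
    """Run the cell over a sequence of input vectors, starting from h_0 = 0."""
    h = [0.0] * hidden_size
    states = []
    for x in xs:
        h = fastgrnn_step(x, h, **params)
        states.append(h)
    return states


# Toy 1-D example with hypothetical parameter values.
params = {"W": [[1.0]], "U": [[0.0]], "b_z": [0.0], "b_h": [0.0],
          "zeta_hat": 0.0, "nu_hat": 0.0}
states = fastgrnn_run([[1.0], [0.0]], hidden_size=1, params=params)
```

The z_t * h_{t-1} term acts as the gated residual path: when the input carries no new information (second step above), the previous state is carried forward, scaled by the gate.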
Problem setup
For the prediction of AD progression, the problem considered in this paper is to use the historical diagnostic results and multimodal AD biomarkers from AD at-risk individuals to anticipate future diagnoses of the disease within a certain period of time, while providing the uncertainty of the predicted outcome as a measure of its reliability.
Proposed model
Architecture of the overall model
To address the above problem, this paper proposes a temporal attention-aware evidential recurrent network. As shown in Fig. 3, the model consists of three modules: the model-filling module, the feature extraction module and the prediction output module.
Architecture of the proposed temporal attention-aware evidential recurrent network.
We propose an evidential recurrent network, which introduces evidential deep learning [17] regularized by the AvUC loss [21] on top of FastGRNN [20]. In its classification evidence-aware subnetwork, the temporal features
Block diagrams for evidential recurrent network and classification evidence-aware subnetwork.
To further calibrate the uncertainty and make it more reliable, i.e., to produce lower uncertainty for correctly classified samples and higher uncertainty for misclassified ones, this paper introduces the AvUC loss [21], as shown in Eq. (11):
where
where
the first term of this loss function is designed to guarantee the accuracy of classification, whereas the second term aims to reduce the output variance.
where
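A sketch of the AvUC loss, assuming the formulation commonly used for it: samples are grouped by correctness and by whether their uncertainty is below a threshold, tanh(u) turns the group counts into differentiable soft counts, and the loss is log(1 + (n_AU + n_IC)/(n_AC + n_IU)). Which uncertainty the paper plugs in (presumably the EDL uncertainty u) and the threshold value are assumptions, and `avuc_loss` is a hypothetical name.

```python
import math


def avuc_loss(confidences, uncertainties, correct, u_threshold, eps=1e-10):
    """Accuracy-versus-uncertainty calibration loss over a batch.

    Soft counts (p = confidence of the predicted class, u = uncertainty):
      n_AC = sum p (1 - tanh u)         accurate and certain
      n_AU = sum p tanh u               accurate and uncertain
      n_IC = sum (1 - p)(1 - tanh u)    inaccurate and certain
      n_IU = sum (1 - p) tanh u         inaccurate and uncertain
    loss = log(1 + (n_AU + n_IC) / (n_AC + n_IU))
    """
    n_ac = n_au = n_ic = n_iu = 0.0
    for p, u, ok in zip(confidences, uncertainties, correct):
        t = math.tanh(u)
        certain = u <= u_threshold
        if ok and certain:
            n_ac += p * (1.0 - t)
        elif ok:
            n_au += p * t
        elif certain:
            n_ic += (1.0 - p) * (1.0 - t)
        else:
            n_iu += (1.0 - p) * t
    return math.log(1.0 + (n_au + n_ic) / (n_ac + n_iu + eps))
```

The loss is zero when every sample falls into the AC or IU groups and grows when a misclassified sample is reported with low uncertainty, which is exactly the behavior the calibration objective above asks for.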
To make full use of the historical time points that are crucial for disease progression prediction, this paper introduces temporal attention into the evidential recurrent network to extract temporal features from the complete data after model filling. Temporal attention is illustrated in Fig. 5.
Block diagrams for temporal attention.
With temporal attention, the evidential recurrent network updates its hidden state by focusing on the time points most relevant to the task among the hidden states of all previous time points. This is achieved by adaptively weighting the historical features of different time points according to their importance. In the process of updating the hidden state, context information
where
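The adaptive weighting of previous hidden states can be sketched as follows. The surviving text does not give the scoring function, so this sketch assumes scaled dot-product scoring with the most recent hidden state as the query, as in standard self-attention; how the resulting context vector is combined with the FastGRNN update is not shown. The function names are hypothetical.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def temporal_attention(h_query, history):
    """Weight previous hidden states h_1..h_{t-1} by relevance to h_query.

    score_i = (h_query . h_i) / sqrt(d)
    weights = softmax(scores)
    context = sum_i weights_i * h_i
    """
    d = len(h_query)
    scores = [sum(q * h for q, h in zip(h_query, h_i)) / math.sqrt(d)
              for h_i in history]
    weights = softmax(scores)
    context = [sum(w * h_i[j] for w, h_i in zip(weights, history))
               for j in range(d)]
    return context, weights


# A query aligned with the first stored state puts more weight on it.
context, weights = temporal_attention([5.0, 0.0], [[1.0, 0.0], [0.0, 1.0]])
```

Unlike a plain RNN, which only sees the last hidden state, the context vector can draw on any earlier visit whose state scores highly against the current one.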
In this paper, a composite loss function is designed to train the temporal attention-aware evidential recurrent network. The loss function of the model is shown in Eq. (23):
where
Dataset
In this paper, we validate the effectiveness of the proposed method on the TADPOLE dataset provided by the Alzheimer’s Disease Neuroimaging Initiative (ADNI) [29]. The dataset consists of 12,741 visit records from 1,737 AD at-risk individuals recorded between 2003 and 2007. The disease status of each individual is one of AD, mild cognitive impairment (MCI), and cognitively normal (CN). The ages of the individuals range from 54.40 to 91.40 years. AD biomarker measurements and diagnostic results are recorded at approximately 6-month intervals from baseline. The sampling time points, total follow-up durations and time intervals differ across individuals. The average number of samples (visits) per individual is 7.3.
To facilitate comparison, this paper follows the approach of Nguyen et al. and selects 22 feature variables from the TADPOLE dataset as AD biomarkers. These include neurocognitive test scores, magnetic resonance imaging features, positron emission tomography features, cerebrospinal fluid (CSF) markers and others. The DX variable is used as the diagnostic result. For the diagnosis data, conversion states such as MCI-to-AD are merged into the state after conversion [10]. For the AD biomarkers, brain volume features are normalized by the total intracranial volume (ICV) to compensate for inter-subject variability in brain size [10], and Z-score standardization is then applied to all AD biomarkers. In addition, because AD is irreversible, erroneous records showing a transition from AD to MCI or CN are removed. Moreover, to handle the varying measurement intervals of AD at-risk individuals, the time step of the proposed model is set to the finest measurement interval. Finally, if data are missing at the first time point of a longitudinal sequence, they are filled with the mean of the values at all subsequent time points, so that the model-filling strategy can work properly [11]. Table 3 gives more information about the features and their missing cases.
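Several of the preprocessing steps above can be sketched as small functions. Assumptions: each subject's visits are sorted by time, missing values are `None`, the conversion label strings in `CONVERSION_MAP` are hypothetical (TADPOLE's actual DX codes may differ), and Z-score statistics are computed over whatever values are passed in (the text does not say whether only the training split is used).

```python
from statistics import mean, pstdev

# Hypothetical conversion labels mapped to the post-conversion state.
CONVERSION_MAP = {"CN->MCI": "MCI", "MCI->AD": "AD", "CN->AD": "AD"}


def clean_diagnoses(visits):
    """Merge conversion states into the post-conversion state, then drop
    visits that revert from AD to MCI or CN (AD is treated as irreversible)."""
    kept, seen_ad = [], False
    for visit in visits:
        dx = CONVERSION_MAP.get(visit["DX"], visit["DX"])
        if seen_ad and dx in ("MCI", "CN"):
            continue  # erroneous reversion, removed
        seen_ad = seen_ad or dx == "AD"
        kept.append({**visit, "DX": dx})
    return kept


def icv_normalize(volume, icv):
    """Normalize a brain volume feature by total intracranial volume."""
    return volume / icv


def z_standardize(values):
    """Z-score standardization that leaves missing values missing."""
    obs = [v for v in values if v is not None]
    mu = mean(obs)
    sd = pstdev(obs) or 1.0  # guard against a constant feature
    return [None if v is None else (v - mu) / sd for v in values]


def fill_first_missing(series):
    """If the first time point is missing, fill it with the mean of the
    observed values at the subsequent time points of the same sequence."""
    if series and series[0] is None:
        later = [v for v in series[1:] if v is not None]
        if later:
            series = [mean(later)] + series[1:]
    return series
```

Only the first time point is filled this way; later gaps are left for the model-filling strategy, which needs an observed starting value to work from.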
Demographics of the TADPOLE dataset
Feature information and missing cases
The multiclass AUC (mAUC) is a performance evaluation metric for multiclass classification. In this experiment, the mAUC is the average of three binary AUCs, which are AD vs. non-AD, MCI vs. non-MCI, and CN vs. non-CN. The higher the mAUC, the better the classification performance. The calculation of mAUC is shown in Eq. (24):
where
where
where
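A sketch of the mAUC as described above: the average of three one-vs-rest binary AUCs (CN vs. non-CN, MCI vs. non-MCI, AD vs. non-AD). Each binary AUC is computed with the Mann-Whitney rank-sum statistic and average ranks for ties; that this matches the paper's Eqs. (25)–(26) is an assumption. Note that the pairwise multiclass AUC of Hand and Till is also commonly used with TADPOLE, so numbers from this one-vs-rest sketch need not match it. Function names are hypothetical.

```python
def binary_auc(scores, labels):
    """AUC via the rank-sum statistic; labels are 1 (positive) or 0.
    Assumes both classes are present."""
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    ranks = [0.0] * len(scores)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and scores[order[j + 1]] == scores[order[i]]:
            j += 1
        avg_rank = (i + j) / 2.0 + 1.0  # 1-based average rank for a tie group
        for k in range(i, j + 1):
            ranks[order[k]] = avg_rank
        i = j + 1
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    rank_sum = sum(r for r, y in zip(ranks, labels) if y == 1)
    return (rank_sum - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg)


def mauc_one_vs_rest(prob_rows, true_classes, classes=("CN", "MCI", "AD")):
    """Average of the one-vs-rest AUCs; prob_rows[i][c] is P(class c) for sample i."""
    aucs = []
    for c_idx, c in enumerate(classes):
        scores = [row[c_idx] for row in prob_rows]
        labels = [1 if t == c else 0 for t in true_classes]
        aucs.append(binary_auc(scores, labels))
    return sum(aucs) / len(aucs)
```

The rank formulation avoids choosing thresholds: it directly counts how often a positive sample is scored above a negative one.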
Balanced Classification Accuracy (BCA) is another performance evaluation metric for multiclass classification. The higher the BCA value, the better the classification performance. The calculation of BCA can be represented by Eq. (27):
where
where
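A sketch of the BCA, assuming the usual definition: for each class, the mean of sensitivity and specificity in a one-vs-rest split, averaged over the three classes. The function name `bca` is hypothetical, and every class is assumed to appear among both the true and the non-true labels.

```python
def bca(pred_classes, true_classes, classes=("CN", "MCI", "AD")):
    """Balanced classification accuracy over one-vs-rest splits."""
    pairs = list(zip(pred_classes, true_classes))
    total = 0.0
    for c in classes:
        tp = sum(1 for p, t in pairs if p == c and t == c)
        fn = sum(1 for p, t in pairs if p != c and t == c)
        tn = sum(1 for p, t in pairs if p != c and t != c)
        fp = sum(1 for p, t in pairs if p == c and t != c)
        # Mean of sensitivity and specificity for class c.
        total += 0.5 * (tp / (tp + fn) + tn / (tn + fp))
    return total / len(classes)
```

Averaging sensitivity and specificity per class keeps a frequent class (such as CN or MCI) from dominating the score.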
The accuracy versus uncertainty (AvU) utility function [21] is used to assess uncertainty calibration performance. It divides the predictions into four groups according to whether each prediction is correct and whether its uncertainty is low or high: accurate and certain (AC), accurate and uncertain (AU), inaccurate and certain (IC), and inaccurate and uncertain (IU). The utility function is given in Eq. (29):
where
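A sketch of the AvU utility as a plain count ratio, AvU = (n_AC + n_IU) / (n_AC + n_AU + n_IC + n_IU), assuming a sample counts as certain when its uncertainty is at most a threshold whose value is not given here. The function name `avu` is hypothetical.

```python
def avu(correct, uncertainties, u_threshold):
    """Fraction of samples that are accurate-and-certain or inaccurate-and-uncertain."""
    n_ac = n_au = n_ic = n_iu = 0
    for ok, u in zip(correct, uncertainties):
        certain = u <= u_threshold
        if ok and certain:
            n_ac += 1
        elif ok:
            n_au += 1
        elif certain:
            n_ic += 1
        else:
            n_iu += 1
    return (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu)
```

An AvU of 1 means the uncertainty perfectly separates correct predictions from errors at the chosen threshold.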
The expected calibration error (ECE) [30] quantifies the gap between model output confidence and actual accuracy, reflecting the calibration performance of the model [14]. The expression for ECE is shown in Eq. (30):
where
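A sketch of the ECE with equal-width confidence bins, ECE = Σ_m (|B_m| / n) · |acc(B_m) − conf(B_m)|. Assumptions: 10 bins, left-closed bin edges (the boundary convention is an implementation choice), and the confidence of each sample taken as the probability of its predicted class (the softmax maximum for MinimalRNN, or the Dirichlet mean of the predicted class for the proposed model). The function name is hypothetical.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE over equal-width bins [m/M, (m+1)/M); confidence 1.0 goes to the last bin."""
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for c, ok in zip(confidences, correct):
        idx = min(int(c * n_bins), n_bins - 1)
        bins[idx].append((c, ok))
    ece = 0.0
    for b in bins:
        if b:
            acc = sum(1 for _, ok in b if ok) / len(b)
            conf = sum(c for c, _ in b) / len(b)
            ece += len(b) / n * abs(acc - conf)  # weight by bin size
    return ece
```

Two predictions at confidence 0.95 with only one of them correct give an ECE of about 0.45, the overconfidence pattern visible in the reliability diagram of Fig. 1b.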
In the experiments, the proposed method predicts the future disease progression of AD at-risk individuals from historical diagnoses and AD biomarkers. First, its disease progression prediction performance is compared with that of MinimalRNN [11], GRU [28], LSTM [9], LSS [9] and SVM&SVR [9], as detailed in Section 5.5.1. Second, to verify whether the model can effectively model predictive uncertainty, the relationship between the uncertainty and predictive accuracy of the proposed method is compared with the relationship between the confidence and predictive accuracy of traditional methods, as detailed in Section 5.5.2. Third, the effect of the uncertainty threshold on predictive performance is studied to determine whether the uncertainty estimate can support reliable decision-making, as detailed in Section 5.5.3. Finally, two sets of ablation experiments are conducted: the effect of the AvUC loss on uncertainty calibration and model calibration is validated in Sections 5.6.1 and 5.6.2, respectively, and the effect of temporal attention on disease progression prediction is validated in Section 5.6.3.
Hyperparameter search space of the proposed method
Following the method described by Varoquaux [31], the dataset for all experiments is randomly divided into 20 subsets, and the ratio of subjects in the training, validation and test sets is 18:1:1. In each experiment, the training set is used to train the model, the validation set to tune hyperparameters and select the best model, and the test set to evaluate the model. The random split into training, validation and test sets is repeated 20 times to ensure stable results, and the final result is the average of the evaluation metrics over the 20 experiments. For comparability with Nguyen et al. [11], the data used for prediction are divided into two parts: the first part serves as the model input, and the second part is compared with the model output to assess performance, as shown in Fig. 3.
The hyperparameter search space is shown in Table 4. After hyperparameter tuning, we select the following optimal hyperparameters: a learning rate of 0.001, a batch size of 64, 256 RNN hidden units, 1 RNN hidden layer and 64 MLP hidden units. The RNN input dropout, RNN hidden-layer dropout and attention dropout are set to 0.1, 0.2 and 0.5, respectively, the weight decay to 0.0001, and the loss function weight to 2. We optimize with the Adam algorithm and train the model for 300 epochs in each experiment.
We use an Intel Core i9-9900 CPU @ 3.10 GHz and an NVIDIA GeForce RTX 2080 Ti GPU for our experiments, running on 64-bit Windows 10 with CUDA 11.6. The models are implemented with the PyTorch 1.12.0 deep learning framework and Python 3.9.
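The subject-level split can be sketched as follows. This assumes one reading of the protocol: subjects are shuffled once into 20 subsets, and in run k subset k is the test set, the next subset is the validation set and the remaining 18 form the training set (18:1:1), so each subset is tested once. The text could also be read as 20 independent random splits. Splitting by subject keeps all visits of one individual in the same set. The function name and seed are hypothetical.

```python
import random


def subject_splits(subject_ids, n_folds=20, seed=0):
    """Yield (train, val, test) subject lists for each of n_folds runs."""
    ids = list(subject_ids)
    random.Random(seed).shuffle(ids)
    folds = [ids[i::n_folds] for i in range(n_folds)]
    for k in range(n_folds):
        v = (k + 1) % n_folds  # the next subset serves as validation
        held_out = (k, v)
        train = [s for j, f in enumerate(folds) if j not in held_out for s in f]
        yield train, folds[v], folds[k]


splits = list(subject_splits(range(100)))
```

Metrics would then be computed on each of the 20 test sets and averaged, matching the mean and standard deviation reported in the tables.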
Experimental results
Performance of disease progression prediction
Table 5 presents the mean and standard deviation of mAUC and BCA on the 20 test sets for our method and five comparison methods (none of which estimate uncertainty). The best performance is highlighted in bold. Our method outperforms most of the comparison methods in predictive performance and is comparable to the state-of-the-art models. However, the goal of our method is not solely higher predictive performance; a vital aspect of the model is its ability to estimate the uncertainty of its predictions effectively. We therefore conduct experiments to compare the performance of uncertainty modeling.
Performance of different methods for predicting disease progression
The change of predicted reliability with respect to accuracy.
To verify whether our method can effectively model the uncertainty of its predictions and better reflect predictive reliability, we plot the mean predictive accuracy and the mean reliability of misclassified samples for MinimalRNN and the proposed method on the test sets across 20 experiments, as shown in Fig. 6. MinimalRNN’s predicted reliability is measured by the confidence of the predicted category, whereas our method’s reliability is measured as 1 minus the predicted uncertainty. Figure 6a compares predicted reliability and accuracy for our method: as accuracy decreases, the predicted reliability of misclassified samples also decreases (i.e., the uncertainty increases) for CN, MCI, and AD, showing that our method can effectively model the uncertainty of its predictions. By contrast, Fig. 6b shows that MinimalRNN yields high confidence for all categories even when its predictions are incorrect. Its confidence therefore cannot identify potential errors and cannot serve as a measure of the reliability of model predictions. We also report the ECE of the two models to quantify how closely each reliability measure matches the actual accuracy. The ECE of the proposed method is much lower than that of MinimalRNN, i.e., its output reliability better reflects the model’s actual accuracy.
Effect of the uncertainty threshold on predictive performance
The uncertainty of the model’s output can serve not only as a measure of reliability, but also as a way to reduce potential erroneous predictions by setting an uncertainty rejection threshold: when the uncertainty of an output exceeds the threshold, the decision should be handed over to a medical professional. Samples whose uncertainty exceeds the threshold are excluded from the performance evaluation, since the model signals the low reliability of these predictions through their high uncertainty. Existing methods, by contrast, are prone to overconfidence and make incorrect predictions for such samples. As shown in Fig. 7, as the uncertainty threshold decreases, the mAUC and BCA of the model’s predictions increase, indicating that an uncertainty threshold allows our method to avoid the risk of misjudgment and support reliable decision-making.
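The thresholding procedure can be sketched as below: predictions whose uncertainty exceeds the threshold are deferred and excluded from evaluation, and the metric is computed on the rest. Plain accuracy is used only to keep the sketch short (the paper reports mAUC and BCA), and the function names are hypothetical.

```python
def filter_by_uncertainty(preds, truths, uncertainties, threshold):
    """Keep (prediction, truth) pairs with uncertainty <= threshold; the others
    are deferred to a clinician. Returns the kept pairs and the deferred count."""
    kept = [(p, t) for p, t, u in zip(preds, truths, uncertainties) if u <= threshold]
    return kept, len(preds) - len(kept)


def accuracy(pairs):
    return sum(1 for p, t in pairs if p == t) / len(pairs) if pairs else float("nan")


preds = ["CN", "MCI", "AD", "AD"]
truths = ["CN", "AD", "AD", "CN"]
unc = [0.1, 0.7, 0.2, 0.9]
kept_all, deferred_all = filter_by_uncertainty(preds, truths, unc, 1.0)
kept, deferred = filter_by_uncertainty(preds, truths, unc, 0.5)
```

Lowering the threshold from 1.0 to 0.5 in this toy example defers the two high-uncertainty errors and raises accuracy on the retained cases from 0.5 to 1.0, at the cost of handing two cases to a clinician.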
Effect of AvUC loss on uncertainty calibration performance
The change of model performance with respect to uncertainty threshold.
AvUC loss and uncertainty calibration
To verify whether the AvUC loss can calibrate uncertainty, we compare our proposed method to a model without the AvUC loss on the test sets across 20 experiments. Table 6 shows the mean and standard deviation of the accuracy vs uncertainty (AvU) utility function values. It can be seen that our proposed method has a higher AvU value compared to the model without AvUC loss, indicating a higher proportion of correct and certain predictions as well as incorrect and uncertain predictions, demonstrating that the uncertainty is effectively calibrated.
AvUC loss and model calibration
Furthermore, we also verify the effect of introducing the AvUC loss on model calibration, as shown in Table 7. The table displays the mean and standard deviation of the expected calibration error (ECE) on the test sets across 20 experiments. Compared to the model without the AvUC loss, the model using the AvUC loss has a lower ECE, thus indicating better calibration performance. This suggests that the AvUC loss can not only improve the reliability of the model’s predictive uncertainty, but also effectively enhance the calibration performance of the model.
Temporal attention and predictive performance
To validate whether temporal attention improves the predictive performance of the model, we compare our model with a variant in which temporal attention is removed (ERN). Table 8 presents the mean and standard deviation of mAUC and BCA on the test sets across 20 experiments. The model with temporal attention outperforms ERN in both mAUC and BCA, confirming that temporal attention enables the model to make full use of the historical time points that are crucial for the task, thereby improving the prediction of AD progression.
Effect of AvUC loss on model calibration performance
Effect of temporal attention on predictive performance
This paper presents a temporal attention-aware evidential recurrent network for trustworthy prediction of Alzheimer’s disease progression. We propose an evidential recurrent network, which introduces evidential deep learning regularized by the AvUC loss on top of FastGRNN, to explicitly model output uncertainty and convert it into a reliability measure for trustworthy AD progression prediction. Meanwhile, temporal attention allows the model to adaptively weight historical features at different time points to enhance predictive performance. Results on the TADPOLE dataset show that the proposed method is comparable to the SOTA method in disease progression prediction, and the ablation experiment confirms the effectiveness of temporal attention. Regarding uncertainty modeling, the proposed method provides an uncertainty measure that represents the reliability of its predictions, which the other models do not offer. This provides an important basis for decision-making in risk-sensitive AD progression prediction applications. Moreover, the AvUC loss substantially improves both uncertainty calibration and model calibration. Future research will focus on using the predicted uncertainty to adaptively weight training data, further improving predictive performance while still representing the reliability of the predicted results.
Footnotes
Acknowledgments
This work is supported by National Key Research and Development Program of China Grant Number 2022YFB3303800, Key Research and Development Program of Jiangsu Province Grant Number BE2021093, National Natural Science Foundation of China Grant Numbers 62202241, 62006126 and 61872190.
