In this section, we introduce the proposed TIA-Net. The framework of TIA-Net is displayed in Fig. 1; its two main operations, (1) transferring general features from a similar ophthalmic dataset and (2) extracting specific features based on transfer induced attention, are highlighted as blue and green blocks, respectively, in the figure. To learn general features, we first pre-train the base CNN network on the labeled cataract dataset and explore the best settings of the transferred layers for glaucoma detection. We then transfer the general features into TIA-Net to help learn specific features, and optimize the weights according to the loss of Eq. (6) using both source and target data. Note that we rely on mini-batches for training, since large batch sizes would increase the computational cost.
Data
In the medical field, digital fundus screening is a popular diagnostic examination, since it is a safe and efficient way to analyze the changes of hypertensive retinopathy and arteriosclerosis in patients with various eye diseases. The retinal fundus images used in this paper fall into two categories: glaucoma images for the target dataset and cataract images for the source dataset, all manually labeled by professional ophthalmologists from Beijing Tongren Eye Center. Subjects in the dataset are mainly from northern China; around 48% are male and the remaining 52% are female, with ages ranging from 10 to 90.
The glaucoma dataset contains 1882 retinal fundus images, including non-glaucoma (1005) and glaucoma (877) images, where the uniform size of each image is 2196 \(*\) 1740 pixels. Fundus images exhibit several common pathological features used for glaucoma diagnosis, such as an increased cup–disc ratio, retinal nerve fiber layer defect (RNFLD), and peripapillary atrophy (PPA). In the retinal image, the optic disc is a vertical shallow ellipse, and at its center is a white cup area, as shown in Fig. 9. The cup-to-disc ratio is measured as the ratio of the diameter of the optic cup to the diameter of the optic disc [40]. Patients with glaucoma usually have a large cup-to-disc ratio; for example, when the ratio is greater than 0.5, glaucoma is likely to occur [4]. RNFLD is a lesion area in the fundus image (a roughly wedge-shaped region starting from the optic disc), which is one of the features used to identify glaucoma [41]. Besides, PPA, a green area around the optic disc, is another major feature of glaucoma images [5]. These features clearly appear in Fig. 9 (Fig. 9b is a glaucoma image, while Fig. 9a shows a normal condition).
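As a minimal sketch of the cup-to-disc criterion described above, the following illustrative functions (the names and the return convention are our own; only the 0.5 threshold comes from the text) make the screening rule concrete:

```python
def cup_to_disc_ratio(cup_diameter, disc_diameter):
    """Cup-to-disc ratio (CDR): optic cup diameter over optic disc diameter."""
    if disc_diameter <= 0:
        raise ValueError("disc diameter must be positive")
    return cup_diameter / disc_diameter

def is_glaucoma_suspect(cdr, threshold=0.5):
    """Flag an eye as a glaucoma suspect when the CDR exceeds the threshold."""
    return cdr > threshold
```

For example, a cup diameter of 0.6 relative to a disc diameter of 1.0 gives a CDR of 0.6 and would be flagged as a suspect under the 0.5 rule.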
The cataract dataset used in our experiment comprises 10463 retinal fundus images (3023 \(*\) 2431 pixels), including non-cataract (3314), mild (2331), moderate (2245), and severe (2573) cataract images. Note that all diagnosis results are based on the unified grading standard [42,43,44]. Figure 10 shows four samples of cataract patients of varying degrees. Figure 10a is a cataract-free image, in which the optic disc and both large and small blood vessels are visible. In the mild cataract image (Fig. 10b), fewer vascular details can be seen; in the moderate cataract image (Fig. 10c), only large vessels and the optic disc are visible; and in the severe cataract image (Fig. 10d), there is hardly anything to see. Based on these retinal fundus images, we can conclude that blood vessels and optic discs are the main references for cataract detection and classification.
Transferring general features from similar ophthalmic dataset
As a kind of deep learning network, the CNN is widely used in image recognition to learn features automatically. With a weight-sharing structure that resembles biological neural networks, the CNN reduces the complexity of the network model; this advantage is especially apparent when the input is a multidimensional image. Such images can be fed into the network directly, avoiding the complex feature extraction and data reconstruction of traditional recognition algorithms. We therefore adopt an extension of a classic CNN network [45] as the base model for transfer learning in our experiment. The base CNN network has seven layers: five convolutional layers and two fully connected (FC) layers. In each convolutional layer, the feature maps computed in the previous layer are convolved with a set of weights, the so-called filters. The resulting feature maps are then passed through a nonlinearity, the rectified linear unit (ReLU). Next, in the pooling layer, each feature map is subsampled by pooling over a contiguous region to produce the pooled maps. After convolution and pooling in the fifth layer, the output is fed into the fully connected layers to perform the classification. In addition, data augmentation and dropout are adopted to reduce overfitting.
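The progressive shrinking of the feature maps through the five convolution-and-pooling stages can be sketched with the standard output-size formulas; the kernel sizes, padding, and 2 \(\times\) 2 pooling below are illustrative assumptions, not the exact settings of the base network:

```python
def conv_out(size, kernel, stride=1, pad=0):
    # Standard convolution output-size formula: floor((n + 2p - k) / s) + 1.
    return (size + 2 * pad - kernel) // stride + 1

def pool_out(size, kernel=2, stride=2):
    # Non-overlapping 2x2 pooling halves the spatial size.
    return (size - kernel) // stride + 1

def feature_map_sizes(input_size=224):
    """Spatial size after each of five conv(+pool) stages of a hypothetical
    7-layer CNN (five conv layers, then two FC layers).  'same' padding keeps
    the conv output size; each pooling step halves it."""
    sizes = [input_size]
    for kernel in (7, 5, 3, 3, 3):          # five convolutional layers
        s = conv_out(sizes[-1], kernel, stride=1, pad=kernel // 2)
        s = pool_out(s)                      # 2x2 max pooling after each conv
        sizes.append(s)
    return sizes
```

With a 224-pixel input, the sketch yields sizes 224, 112, 56, 28, 14, 7, after which the 7 \(\times\) 7 maps would be flattened and fed into the FC layers.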
In a trained CNN, the features of shallow layers are general, while those of higher layers are task-specific; the middle layers transition gradually from general to specific, forming a hierarchical multilayer architecture [33]. The general layers typically extract local edge features similar to Gabor filters. As shown in Fig. 11, we visualize the feature maps and the corresponding deconvolution results of the first convolutional layer. General features, such as edges and line segments of the fundus image, are extracted in different directions: Fig. 11a–c tend to extract edge contours in the − 45, 45, and 90 degree directions, respectively. When a pre-trained CNN structure is fine-tuned, layers are frozen consecutively from the shallow end, so that updates to the weights of the remaining unfrozen layers can still propagate through the deeper layers. However, transferring features from a less related source dataset may instead hurt the transferability of the general features.
Hence, rather than extracting general features from a non-medical dataset, we transfer the weights of the shallow layers, which are optimized to recognize the generalized structures in the cataract dataset (shown as blue blocks in Fig. 1), and then retrain the weights of the deep layers by back-propagation on the glaucoma dataset. This strategy helps to identify the distinguishing features of glaucoma fundus images more accurately under limited supervision.
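A minimal sketch of this transfer strategy, with hypothetical layer names and plain dictionaries standing in for network weights (a real implementation would copy framework tensors and disable their gradients):

```python
import copy

def transfer_shallow_layers(source_weights, target_weights, n_transfer):
    """Copy the first n_transfer layers from the source (cataract) model into
    the target (glaucoma) model and mark them frozen; deeper layers keep their
    own weights and remain trainable.  Layer names here are illustrative."""
    transferred = copy.deepcopy(target_weights)
    frozen = set()
    layer_names = sorted(source_weights)      # e.g. conv1, conv2, conv3, ...
    for name in layer_names[:n_transfer]:
        transferred[name] = copy.deepcopy(source_weights[name])
        frozen.add(name)
    return transferred, frozen

# Toy example: transfer the two shallowest layers from source to target.
src = {"conv1": [1.0], "conv2": [2.0], "conv3": [3.0]}
tgt = {"conv1": [9.0], "conv2": [9.0], "conv3": [9.0]}
weights, frozen = transfer_shallow_layers(src, tgt, n_transfer=2)
```

After the call, `conv1` and `conv2` hold the source weights and are frozen, while `conv3` keeps the target's own (trainable) weights.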
Extracting specific features based on transfer induced attention
Specialization of deep-layer neurons for the target task is based on general features. However, there still exist redundant regions in the fundus image when capturing specific features from the general features of similar ophthalmic datasets. For example, the edge regions of the eyeball or other unrelated pathological areas are redundant for glaucoma detection. To effectively refine the specific features and remove irrelevant redundancy, we use a soft attention design across channels in place of the original CNN architecture.
As is known, the attention mechanism has been successfully applied in deep learning architectures, since it can locate the most salient parts of the features [46,47,48,49]. This meritorious property conforms to human visual perception: instead of trying to process the whole scene at once, human beings use a series of local glimpses to selectively focus on the prominent parts and thus better capture the visual structure [50]. As shown in the green block of Fig. 1, a transfer induced attention module is produced by exploiting the inter-channel relationships of the transferred general features. In our transfer processing, each learned filter operates with a local receptive field; therefore, each unit of the transferred general features \({\mathbf {G}}\) is unable to exploit contextual information outside this region. To tackle this issue, we use global average pooling (GAP) to compress the global spatial information, which helps to accelerate specific feature extraction on glaucoma-critical areas. Specifically, each element of \({\mathbf {o}}\) is generated by shrinking \({\mathbf {g}}\) over the spatial dimensions \(W \times H\):
$$o = {\text{GAP}}\left( \mathbf{g} \right) = \frac{1}{{W \times H}}\sum\limits_{{i = 1}}^{W} {\sum\limits_{{j = 1}}^{H} \mathbf{g} } (i,j).$$
(3)
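Assuming the transferred features are stored as a \(C \times W \times H\) tensor, Eq. (3) amounts to a per-channel spatial average; a minimal NumPy sketch:

```python
import numpy as np

def global_average_pooling(g):
    """Eq. (3): squeeze a C x W x H feature tensor into a length-C descriptor
    by averaging each channel over its W x H spatial extent."""
    return g.mean(axis=(1, 2))

# Toy example: channel c is filled with the constant value c, so the pooled
# descriptor recovers exactly [0, 1, 2, 3].
g = np.stack([np.full((3, 3), c, dtype=float) for c in range(4)])
o = global_average_pooling(g)
```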
The GAP descriptor is then forwarded to FC layers, which aim to recalibrate the channel information adaptively:
$${\mathbf{m}} = {\text{FC}}({\mathbf{o}}) = \sigma \left( {{\mathbf{W}}_{1} \left( {{\mathbf{W}}_{0} {\mathbf{o}}} \right)} \right),$$
(4)
where \(\sigma\) refers to the sigmoid activation function, \({\mathbf {W}}_{1}\) and \({\mathbf {W}}_{0}\) are the FC layer weights, and \({\mathbf {m}}\) is our channel-wise attention map.
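Under the same assumption of a length-\(C\) pooled descriptor, Eq. (4) can be sketched as follows; the squeeze-then-expand weight shapes (reducing \(C\) channels to \(C/2\) and back) are an illustrative choice, not the paper's stated configuration:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def channel_attention(o, W0, W1):
    """Eq. (4): two FC layers recalibrate the pooled descriptor o, and a
    sigmoid maps each channel weight into (0, 1)."""
    return sigmoid(W1 @ (W0 @ o))

# Toy example with C = 4 channels and a reduction to 2 hidden units.
o = np.zeros(4)
W0 = np.ones((2, 4))   # squeeze: 4 channels -> 2
W1 = np.ones((4, 2))   # expand: 2 -> 4 channel weights
m = channel_attention(o, W0, W1)
```

With a zero descriptor, both linear maps output zero and the sigmoid yields a uniform attention map of 0.5 per channel, i.e. no channel is emphasized over another.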
To get the final specific feature \({\mathbf {P}}\), we reweight the original transferred general feature \({\mathbf {G}}\) with the channel attention map \({\mathbf {m}}\):
$$\begin{aligned} {\mathbf {P}}={\mathbf {G}} \otimes {\mathbf {m}}, \end{aligned}$$
(5)
where \(\otimes\) denotes element-wise multiplication; during multiplication, the attention values are broadcast accordingly. Besides, the attention-based specific feature \({\mathbf {P}}\) can highlight the discriminative regions by masking the original fundus image, which helps improve the interpretability of the proposed model. When pre-training the base CNN model on the source dataset, the cross-entropy \(L_{\text{ce}}\) between the predicted label and the corresponding true label is used as the loss function. When transferring general features to learn specific features, a new loss function is defined by integrating three parts:
$${\text{Loss}} = L_{{{\text{ce}}}} \left( {{\mathbf{X}}_{s} ,{\mathbf{Y}}_{s} } \right) + L_{{ce}} \left( {{\mathbf{X}}_{t} ,{\mathbf{Y}}_{t} } \right) + \lambda L_{{{\text{Disc}}}} ,$$
(6)
where \({\mathbf {X}}_{s}\) and \({\mathbf {X}}_{t}\) refer to the sets of training images from the source and target datasets, respectively, and \(\lambda\) is a non-negative regularization parameter. The first and second terms represent the classification losses of the corresponding datasets. The third term, the discrepancy loss, measures the distance between the feature vectors computed from the source and target datasets. Following the popular trend in transfer learning [51, 52], we rely on the Maximum Mean Discrepancy (MMD) [53] to encode this distance. Suppose that \(N_{s}\) and \(N_{t}\) are the numbers of source and target samples, respectively; then \(L_{\text{Disc}}\) is calculated through Eq. (7):
$$\begin{aligned} {\text {MMD}}^{2}\left( {\mathbf {m}}_{s}, {\mathbf {m}}_{t}\right) =\left\| \sum _{i=1}^{N_{s}} \frac{\phi \left( {\mathbf {m}}_{s}^{i}\right) }{N_{s}}-\sum _{j=1}^{N_{t}} \frac{\phi \left( {\mathbf {m}}_{t}^{j}\right) }{N_{t}}\right\| ^{2}, \end{aligned}$$
(7)
where \(\phi (\cdot )\) denotes the mapping to the reproducing kernel Hilbert space (RKHS). For network optimization, mini-batch stochastic gradient descent (SGD) with back-propagation is used in this paper.
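As a simplified sketch of Eqs. (6) and (7), taking the identity feature map \(\phi(x)=x\) (a linear-kernel special case rather than a full RKHS embedding) and an illustrative \(\lambda\) that is not the paper's tuned value:

```python
import numpy as np

def mmd_squared(ms, mt):
    """Eq. (7) with phi(x) = x: squared Euclidean distance between the mean
    source feature vector and the mean target feature vector.  ms and mt are
    N_s x D and N_t x D arrays of per-sample feature vectors."""
    return float(np.sum((ms.mean(axis=0) - mt.mean(axis=0)) ** 2))

def total_loss(ce_source, ce_target, ms, mt, lam=0.1):
    """Eq. (6): source cross-entropy + target cross-entropy + lambda * MMD."""
    return ce_source + ce_target + lam * mmd_squared(ms, mt)

# Toy example: source features at the origin, target features at (1, 1).
ms = np.zeros((3, 2))
mt = np.ones((2, 2))
```

Here the mean feature vectors differ by \((1, 1)\), so the squared MMD is 2.0, and the total loss adds that discrepancy, scaled by \(\lambda\), to the two classification terms.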