Deep Learning/NLP

Deep Learning/NLP

[DL] RETAIN 모델 논문정리 / interpretable AI

김쪼욘 2020. 11. 11. 01:55

RETAIN : An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism (NIPS 2016)

 

병원에서 인턴을 하면서 알게되었고 실제로 빅데이터 콘테스트에 참여하며 적용해 보았던 RETAIN을 정리해보았다!

 

EHR 데이터는 high-demensional clinical variables, sequential한 데이터이다. (e.g., 진단기록, 처방기록)

환자가 방문했을 때에 처방 또는 진단받은 기록이 남게 되는데 기존의 통계모델, 머신러닝 기법으로는 진단을 받은 시기와 진단의 종류 사이의 관계를 모두 학습하지 못한다는 단점이 있다. sequential한 데이터에 강세를 보이는 RNN 모델을 사용해서 예측하게 되면 예측력은 높지만, 해석이 어렵다는 단점이 있다. 이전에도 RNN을 사용해 해석하려는 다양한 시도들이 있었지만 이 논문에서는 그런 방법론들이 의료분야에서는 잘 적용되지 못하고 있다고 언급하고있다.

 

실제로 서울대 병원에서 EMR데이터를 활용해 기저질환 또는 약물복용 대한 target 질병의 발생 여부를 예측하는 프로젝트를 진행했을 때, RNN 모델은 예측력은 높지만 해석이 불가능 하기 때문에 전통적인(?) 후향적 연구에서는 잘 사용되지 않는 것 같았다. 예측보다는 질병군에 대한 설명과 그룹간 차이에 대한 통계적 검정 등등,, 을 좀 더 중요하게 여기는 느낌이었다. 물론 이건 우리 연구 이야기고 "정확한 예측"을 중심 주제로 연구하시는 분들은 머신러닝, 딥러닝 기반 모델도 적극적으로 사용하셨다.

 

RETAIN은 예측력과 해석력 모두 갖추어 의료분야의 다양한 연구에서도 딥러닝 기반 모델을 활용할 수 있는 가능성을 확인했다는 의의가 있는 것으로 생각된다.

 

RETAIN 모델의 핵심 구조

  • two - level neural attention model
  • attention generation mechanism
  • reverse time order

출처 : RETAIN 논문

RETAIN test

EHR dataset (14 million visits , 263K patients, 8년)

heart failure을 진단받는 경우를 예측해 RNN과 비교

 

 

Methodology

1) EHR data의 구조와 notation

- 각 환자들에 대한 time-labeled sequence of multivariate observations

- $r$ 개의 변수( 진단, 처방 등)

- $n-th$ 환자 (총 $N$명)

- $(t_i^{(n)} , x_i^{(n)})$ , $ i = 1,...,T^{(n)}$

- $t_i^{(n)}$ : n번 환자의 i번째 방문 시점

- $\{c_1,c_2,...,c_n\}$ , $c_j$ : vocaburary $C$의 $j-th$ code, $r=|C|$  : unique한 질병코드의 총 갯수가 필요함

- $x_i ∈ \{0,1\}^{|c|}$ : i 번째 방문에 $c_{j}$진단여부에 대한 binary data

- $y ∈ \{0,1\}^s$ : s는 1개 이상일 수 있음.

 

- input vector $x_{i}$ example : [[[1,2,3], [4,5,6,7]], [[2,4], [8,3,1], [3]]]

 : 첫번째 환자는 첫번째 방문에 1,2,3 질병 진단     $x_{1}^{(1)}$

 : 첫번째 환자는 두번째 방문에 4,5,6,7 질병 진단  $x_{2}^{(1)}$

 : 두번째 환자는 첫번째 방문에 2,4 질병 진단        $x_{1}^{(2)}$

 : 두번째 환자는 두번째 방문에 8,3,1 질병 진단     $x_{2}^{(2)}$

                                              ... 

 

- encounter sequence modeling (ESM) 

환자가 방문했을 때 진단받은 진단코드가 다음과 같을 때 ${c_{1}, ..., c_{n}}$, 변수의 갯수는 $r= |C|$ 이다 입력값 $x_{i}$는 binary variable로 방문시에 질병은 진단 받은 경우 1로 기록하게 된다.  환자별로 각 time step $i$에서 다음 방문인 $x_2,...,x_{t+1}$의 code occuring을 예측한다.


- learning to diagnose (L2D) 

$x_1,...x_T$ 는 continuous clinical measures 이고 measure 가 $r$개 일때, 특정 질병 ${(s=1)}$ 또는 여러 질명 ${(s>1)}$의 발생여부 예측

 

1) Preliminaries on Neural Attention Model

이 논문에서는 EHR data를 사용해 해석가능한 모델을 만들기 위해 attention mechanism의 기본적인 구조를 따라간다.

의사가 실제로 환자에게 필요한 진단을 내릴때 과거 진단받은 질병(e.g., key risk factors)을 참고한다는 점에서 착안한 모델이다.

2) Reverse Time Attention Model RETAIN

출처 : retain 논문

retain 모델은 위 그림과 같은 구조를 가지고 있다.

입력벡터 $x_1....x_t$를 가지고 $y_i label$을 예측한다. (표기의 편의를 위해 $x_{i}^{n}$ 에서 $n$ 생략)

 

1. 입력벡터는 linear embedding layer를 거치게 된다. 이때 $embedding size = m$ 으로 설정하게 된다.

$v_i = W_{emb}x_i$

 

${v_{i}}\in\mathbb{R}^{m}$, ${x_{i}}\in\mathbb{R}^{r}$ , $W_{emb}\in\mathbb{R}^{m \times r}$

 

2. 시점 정보(visit-level attention weight)를 저장하는 $RNN_{\alpha}$와 변수정보(variable-level attention weight)를 저장하는 $RNN_{\beta}$를 분리해서 학습해 $\alpha,\beta$ 를 얻어낸다.

 

  - $\alpha_1 , ... , \alpha_i$

     visit - level attention weights  - $v_1, ..., v_i$ 임베딩 자체에 영향을 받음

  - $\beta_1, ... , \beta_i$

     variable-level attention weights - $ v_{1,1} , ... ,v_{1,m} , ... ,v_{i,m}$ 임베딩 상호관계에 영향을 받음

 

$g_i , g_{i-1}, ... , g_1 = RNN_{\alpha}(v_i, v_{i-1} , ... , v_1)$,

$e_j = w_a^Tg_j + b_{\alpha} $    $for$   $j = 1, ..., i$

$\alpha1, ... , \alpha_i = softmax(e_1,e_2, ... ,e_i)$

 

$h_i , h_{i-1}, ... , h_1 = RNN_{\beta}(v_i, v_{i-1} , ... , v_1)$

$\beta_j = tanh(w_a^Tg_j + b_{\alpha})$    $for$   $j = 1, ..., i$

 

  - time-step i 에 해당하는 hidden layer

    ${g_{i}}\in\mathbb{R}^{p}$ , ${h_{i}\in\mathbb{R}^{q}}$ 

 

  - parameters

    $w_{\alpha}\in\mathbb{R}^{p}$, $b_{\alpha}\in\mathbb{R}$,  

    $W_{\beta}\in\mathbb{R}^(m \times r)$ , $b_{\beta}\in\mathbb{R}^{m}$ 

 

3. 각 time step 마다 $\alpha, \beta$ 마다 attention vector을 구해준다.

언급했듯이 RETAIN은 attention vector를 reverse time order로 입력을 받아내 얻어낸다고 했다.

즉, 두 RNN이 visit embedding을 $(v_i , v_{i-1}, ... ,v_1)$와 같은 순서로 받는 다는 것을 말한다.

실제로 theano 로 작성된 코드 상에서 입력값을 거꾸로 넣어준다. (keras로 작성된 코드는 이렇게 직접적으로 Embedding 값을 거꾸로 넣어주지 않고 bidirection을 적용해준 것으로 보인다.)

 

이런식으로 reverse time order로 넣어주게 되면 각 time step마다 예측할 때  rnn layer의 output 값에 큰 변화를 주고 attention vector가 각 time step마다 다르게 출력되도록 한다. 

만약에 입력값을 원래 순서대로 넣어줄 경우 즉, 1~3 , 1~5 이런식으로 넣어줄때 매번 시작이 1이기때문에 $e_1, \beta_1$은 모든 time step마다 같은 값을 가지게 될것이다. 하지만 많은 경우 과거에 진단받은 기록보다 최근에 진단받은 기록이 더 중요하기 때문에 입력을 거꾸로 해줌으로써 rnn layer의 output에 매 time step 마다 영향을 주겠다는 것이다.

 

4. generated attention을 통해 context vector $c_i$ 를 뽑아낸다.

 

$c_i = \sum_{j=1}^{i}\alpha_j\beta_j$$v_j$, ⊙ : element - wise multiplication

 

5. context vector $c_i$ 는 label $y_i$를 예측하기 위해 사용된다.

$y_i^{hat} = Softmax(Wc_i+b)$

$loss = cross entropy$

분류 task가 아닌경우 여타 다른 모델과 마찬가지로 softmax 를 사용하지 않고 Loss 또한 MSE로 변경해준다.

 

3) Interpreting RETAIN

그래서 이 모델이 어떻게 해석력을 가질까. 여기서 염두해야할 점은 다음 두가지이다.

1. n 차원의 입력값을 M차원으로 임베딩 했다는 점

2. 변수의 시점과 종류를 모두 고려한 해석이 가능해야 한다는 점

논문에서는 결과 수식을 차례대로 풀어가면서 input vector의 weight을 찾아 정리해 해석했다.

 

$p(y_i | x_1, ... x_i) = p(y_i | c_i) = softmax(Wc_i + b)$

                                             $= softmax(W(\sum_{j=1}^i\alpha_j\beta_j⊙v_j) + b)$

                                             $= softmax(W(\sum_{j=1}^i\alpha_j\beta_j⊙\sum_{k=1}^rx_{j,k}W_{emb}[:,k]) + b)$

                                             $= softmax(\sum_{j=1}^i\sum_{k=1}^rx_{j,k}\alpha_jW(\beta_j⊙W_{emb}[:,k])+b)$

 

위 식에서  $w(y_i,x_{j,k}) = \alpha_jW(\beta_j⊙W_{emb}[:,k]) x_{j,k}$ 로 해석될 수 있다.

따라서 입력값 앞에 곱해지는 값이 이 input value에 대한 contribution coefficient라고 해석되는 것이다.

(입력 값이 binary인 경우에만 해당하며, non -binary라면 coefficient와 Input value를 곱해 계산해야 contribution coefficient를 적절하게 계산할 수 있다.)

coefficient와 곱해지는 $x_{i,j}$ 는 input vector $x_j$의 $k$번째 element라는 뜻으로 우리는 이러한 계산을 모든 입력 벡터들의 element들에 대해서 수행할 수 있다.

$\alpha, W(context vector에 대한 weight), W_{emb} $ 모두 뽑아낼 수 있는 값이기 때문에 직접 계산해 해석할 수 있다.

 

Experiments

RETAIN, RNN, 전통적인 머신러닝 방법을 비교하였다.

1) Experimental setting

데이터 셋은 helth records from sutter health에서 수집되었다. 50-80세의 심부전 환자를 대상으로 모델을 실험하였으며, 방문기록에서 진단, 처방, 수술 코드를 뽑아서 사용했다.

 

2) heart failure prediction

- objective

$x_{1},...,x_{T}$ 방문 sequence를 input으로 넣었을 때 환자가 마지막 time step에서 심부전을 진단받을 것인지에 대한 위험률을 예측한다. 

 

- cohort construction

3884 cases 가 선택되었으며 각 case별로 10개의 control case.( 28903 controls)를 선택했다

각 case는 심부전을 진단받은 날짜를 index로 진단 이전 18개월을 window로 설정해 그때 기록된 처방,수술,진단 기록을 사용

Training details : train / validation / test = 0.75 / 0.1 / 0.15

- $Logistic Regression$

  환자별 방문 기록에 기반해 medical codes의 숫자를 세 training data로 사용

 

- $MLP$

  Logistic Regression과 같은 데이터를 사용하였으나, 입력과 출력 사이에 hidden layer(size = 256) 추가

 

- $RNN$

  size 256 two hidden layer GRU 사용, 입력벡터 $(x_{1}, ... , x_{n})$  사용

 

- $RNN+\alpha_{M}$

  size 256 directional RNN 로 input embedding $v_{1} , ... , v_{i}$ 생성

  size 256 single hiddel layer MLP 로 visit-level attention $\alpha_{1}, ... , \alpha_{i}$ 생성

 

- $RNN+\alpha_{R}$

  $RNN+\alpha_{M}$와 비슷하지만  size 256 reverse -order RNN 로 visit-level attention  $\alpha_{1}, ... , \alpha_{i}$  생성

  

- 해석

Logistic regression, MLP 는 순환신경망 모델들과 비교했을 때 성능이 떨어진다.

RETAIN은 다른 RNN모델들과 비교했을 때 성능 크게 떨어지지 않으면서 해석이 가능하다.

RNN모델과 RETAIN모델의 학습시간은 크게 차이나지 않으나 모델이 수렴하기 까지의 epoch수는 10,15,30으로 차이가 있다.

 

 

3) Model Interpretation for Heart Failure Prediction 

RETAIN이 시점정보와 변수정보를 잘 보존하고 있는지를 확인해 보기 위해 3개 그림을 비교해 보았다.

각 그래프는 x축은 time, y축은 contribution으로 정의 되었다.

 

Figure(a)는 그래프에 순차적으로 입력된 의료코드를 표시했다.

Figure(b)는 Figure(a)의 입력을 거꾸로 넣어 학습한 뒤 계산된 결과로, Figure(a)에서 contribution이 높았던 HVD, CD, CA 등의 contribution이 양의 값으로 표시되나 그 차이가 적고 예측된 HF risk도 매우 작아졌다.

실제 logistic과 같이 변수별 시점정보를 지울 수 밖에 없는 모델에서는 figure(a)와 figure(b)에 사용된 순서가 다른 데이터의 경우에도 같은 크기의 HF risk를 예측 할 것이다. 이런 점에서 RETAIN을 시점,변수 정보를 잘 보존해 모델에 반영했다고 볼 수 있다.

Figure(c) 는 Figure(a)모델에 medication데이터를 추가한 것으로 HF risk가 0.2475에서 0.2615로 줄어들어 적절한 Medication을 받는 경우 HF risk가 줄어든다고 해석할 수 있다.

 

 

-2020 빅콘테스트에서 동별 편의점 배달지수를 예측하는 모델을 만들 때 RETAIN 알고리즘을 활용해 보았다.

 https://github.com/JngHyun/bigcontest_RETAIN

 

JngHyun/2020_RETAIN_model

빅콘테스트에서 활용한 retain model. Contribute to JngHyun/2020_RETAIN_model development by creating an account on GitHub.

github.com