Interpretable Modeling 관련하여 찾아던 중 Neural Additive Models: Interpretable Machine Learning with Neural Nets 이라는 논문을 찾아 간단하게 공부하게 되었다. (Spotlight (Top 3%) at NeurIPS 2021)
Deep Neural Networks(DNN)이 black-box이기 때문에 healthcare, finance, criminal justice와 같이 고위험군 task에서는 적용가능성에 문제점이 있기에 더욱 더 interpretation 가능한 방법론이 필요하다고 말한다. 기존의 방법론인 LIME과 같은 방법론들은 모델의 behavior에 대한 설명 디테일이 부족하거나, fail to provide a global view of the model하거나, explanation이 not faithful하다고 말한다.
그렇기에 저자들은 NN의 구조 자체에 제약을 가하여, Generalized Additive Models(GAMs)의 family에 속하는 Neural Additive Models(NAMs)를 제안하는 것이다. 먼저 GAMs를 다시 한 번 보자.
GAM에서는 독립변수들(x_1 ~ x_K) 각각에 univariate shape function을 씌워서 모델을 만든다. Neural Networks는 universal function approximator 이기 때문에 (참조: Universal approximation theorem) 모든 function은 Neural Network로 대체될 수 있다. GAM 앞에서 씌워진 function을 Neural Network로 바꿔준 것이 바로 Neural Additive Models(NAMs)인 것이다.
NAM Architecture
이진 분류를 위한 NAM의 architecture는 다음과 같다.
NAM은 일단 각 feature 당 할당된 Neural Network를 backpropagation을 통해서 훈련시키고(이 네트워크들이 데이터에 맞는 적합한 shape function을 배우도록) 동시에 각각의 네트워크를 어떻게 linear combination 시킬 지에 대해서도 학습하는 것으로 보인다.
학습 후에 각각의 네트워크들은 feature graph으로 표시된다.
바로 여기서 NAM이 interpretable machine learning이라고 하는 것인데, 실제로 각 feature들(x_1 ~ x_k)이 마지막 종속변수(Y)에 어떤 영향을 미치고 있는지가 그래프로 나오기 때문이다. 바로바로 해석가능하다는 것이다. 실제로 논문에서 저자들도 이 점을 강조하는데, 위 논문 캡처의 마지막 문장을 보면 "the graphs learned by NAMs are an exact description of how NAMs compute a prediction" 이라고 되어있는 걸 볼 수 있다.
DNNs tend to be too smooth to learn jumps well.
여기서 문제점이 발생하는데, standard한 DNN은 값이 확확 점핑되는 함수들을 잘 representation하지 못한다는 것이다. 저자들은 논문에서 이를 jagged shape function이라고 한다. Real-world에는 sharp jump들이 있는 데이터셋이 매우 많은데, 이를 효과적으로 모델링하는 방법이 필요하다는 것이다.
sharp jump라는게 무슨 뜻일까?
오른쪽이 저자들이 제안한 ExU 기법을 통해 jumpy한 function을 모델링한 결과이다. 왼쪽의 standard한 ReLU는 smooth하게 모델링된 반면, 오른쪽에서는 overfitting이 되지 않고도 jumpy한 function을 모델링할 수 있었다는 뜻이다. 참고로 위 figure에서 파란색 점은 실제 데이터가 아니고, log-odds이다(파란 글씨 참조). 어떤 데이터 x가 있을 때, 이 데이터가 나올 확률을 임의로 p라고 정의하고(논문에는 어떻게 p를 정했는지 나오지 않음) 이에 따른 log-odds를 놓은 것이다.
In this setup, 𝑥 values are systematically varied across a defined range, and for each 𝑥, an empirical probability 𝑝 is either observed or generated based on a specific scenario. These probabilities are then converted into log-odds (blue dots), which represent the true statistical relationship between 𝑥 and the outcome in this data. The neural network models are trained to predict these log-odds based on 𝑥, and their predictions are plotted as orange lines. The goal in such a neural network setup is to minimize the difference between the blue dots and the orange line, thereby achieving a model that can accurately predict outcomes based on new 𝑥 inputs.
흔히 우리가 overfitting된 것이라고 하면
이런 식으로 generalize에 좋지 않기에 지양해야 할 것이라고 한다. 여기서는 overfitting이 아니라, overfitting을 하지 않고도 jumpy한 function을 모델링했다는 것이다.
저자들이 제안한 기법은 standard ReLU 대신에 ExU를 쓰자는 것이다.
앞에 f는 ReLU-n 함수, 그리고 exponential term에 w (weight)가 들어가 있다. 중간에 *는 multiplication이다. exponential을 첨가함으로써 급격한 jumping도 모델링 가능하도록 조정한 느낌인 것 같다.
ReLU-n은 min(max(0, x) , n) 이다. ReLU긴 ReLU인데 상한선을 n으로 제한을 걸어놓은 것이다.
그래서 이렇게 ExU를 쓰면 어떻게 되는가?
오른쪽 figure를 먼저 보면 ExU로 모델링된 그래프 (a)들이 더 sharp한 라인을 많이 가지고 있는 것이 보인다.
그리고 왼쪽은 저자들이 ExU로 모델링된 output에 추가적 regularization을 해준 것이다. 너무 sharp하고 jumpy하게 나오는 것도 방지한 것 같다.
ExU까지 통합하여 위의 feature당 output (( f_{i}(x_{i}) ))가 나오는 방식은 다음과 같다.
결과
맨 위 표에서는 4개의 데이터셋(MIMIC-II, Credit, CA Housing, FICO)에 대한 AUC, RMSE를 볼 수 있다. 적혀있는대로 AUC는 높을수록, RMSE는 낮을수록 좋은 것이다. NAM의 성능이 기존것들과 비슷/괜찮게 나오는 것을 볼 수 있다. 특히 DNN과 비교할 때 성능은 비슷한데 결과에 대한 clear한 설명력을 가진다는 것에 의의가 있을 것이다.
Figure 5는 FICO 데이터셋에 대한 내용이다. 각각의 변수들(x축에 영어로 적힌 애들)이 최종 score에 어떤 영향을 얼마나 끼치는지가 그래프로 나와있는 것을 볼 수 있다.
Figure 6는 California Housing 데이터셋인데, House price를 예측할 때 확 비싸지는 지역인 San Francisco / LA에 sharp jump를 잘 모델링한 모습을 볼 수 있다.
Multitask NAMs
저자들은 NAM의 이점 중 하나가 쉽게 Multitask로 확장할 수 있다는 것을 들고 있다.
One advantage of NAMs is that they are easily extended to multitask learning (MTL), whereas
MTL is not available in EBMs or in any major boosted-tree package. In NAMs, the composability
of neural nets makes it easy to train multiple subnets per feature. The model can learn task-specific
weights over these subnets to allow sharing of subnets (shape functions) across tasks while also
allowing subnets to differentiate between tasks as needed.
추가적으로 task들이 similar to each other할 때 / training data is limited할 때 simgle task일 때보다 더 성능이 향상된다는 점을 기술하고 있다.
실제로 Multitask일 때 Lower MSE score를 보여주는 모습이다.
Future Work
저자들은 추후 흥미롭게 연구할 점들에 대해서도 기술해놓았는데,
NAM에게 더 higher dimension의 데이터를 집어넣는(이미지, speech), CNN-LSTM based NAM의 extension 등을 들고 있다.
'Decision Making' 카테고리의 다른 글
Generalized Additive Models(GAMs) (0) | 2024.09.04 |
---|---|
Generalized Linear Models(GLMs) (0) | 2024.09.03 |