본문 바로가기

Programming/R

[R] KNN (K-Nearest Neighbours) 알고리즘으로 주식시장 예측하기

R에는 KNN (K-Nearest Neighbours) 알고리즘을 쉽게 구현할 수 있게, knn( ) 함수가 class 라이브러리에 내장되어있다. R에 내장되어있는 Smarket 데이터(2001~2005년도의 S&P 주가 지수를 측정)를 이용해서, 2001~2004년까지의 데이터를 기반으로 KNN 알고리즘을 피팅한 후에 2005년 주식시장이 어떻게 될지 예측해보고, 예측이 얼마나 잘 되었는지 확인해보자👌

 

* 데이터셋에 대한 자세한 내용은, R에서 아래와 같이 ? Smarket이라고 입력하면 확인할 수 있다.

? Smarket

 

먼저 knn()을 사용하기 위한 class 라이브러리와 내장되어있는 데이터셋을 사용하기 위한 라이브러리를 불러오고,

fix(Smarket)으로 이노무 데이터가 어떻게 생겨먹었는지 확인한다. 9개의 변수로 이루어져 있고, Direction이 Up이면 지수가 상승한 것, Down이면 하락한거다.

 

attach( )는 단어 뜻 그대로, 데이터셋 안의 변수를 R 스크립트에 챱 붙여준다. 즉, 스크립트를 작성할 때 변수명 그대로 쓸 수 있게 된다는 뜻 (Direction, Year, Lag1, ... 등) attach하지 않으면, Smarket$Direction 요로코롬 귀찮게 앞에 변수의 출처를 달아줘야 한다.

library(class)
library(MASS)
library(ISLR)

fix(Smarket)
attach(Smarket)

 

 

KNN 알고리즘의 원리는 이름 뜻 그대이다. K-Nearest Neighbours, K명의 가장 가까운 이웃들이란 뜻인데🤔

테스트 샘플과 가까이 있는 트레이닝 샘플(들)의 라벨을 기준으로, 나의 테스트 샘플의 라벨을 예측하는 알고리즘이다.

 

여기서 K는 몇 명의 이웃(들)의 레이블을 기준으로 삼을거냐 하는 숫자인데,

K=1이면, 테스트 샘플을 기준으로 제~일 가까운 딱 하나의 트레이닝 샘플의 라벨 == 테스트 샘플의 레이블

K=3이면, 테스트 샘플을 기준으로 가까운 세 개의 트레이닝 샘플의 레이블들을 기준으로 테스트 샘플의 레이블 결정

K=5이면, 테스트 샘플을 기준으로 가까운 다섯 개의 트레이닝 샘플의 레이블들을 기준으로 테스트 샘플의 레이블 결정

이렇게 생각하면 된다. 

 

K가 3 이상이 되면 예측하고자 하는 레이블의 종류에 따라 레이블을 정하는 방법이 달라지는데,

Continuous한 Quantitative 레이블이라면, 트레이닝 샘플의 레이블들의 평균 == 테스트 샘플의 레이블

Categorical한 Qualitative 레이블이라면, Majority voting 법칙에 의해 쪽수 많은 쪽의 레이블 == 테스트 샘플의 레이블이 된다. KNN 이론은 추후에 다시 한번 정리해둬야지!

 

train <- (Year<2005)
Smarket.2005 <- Smarket[!train,]
Direction.2005 <- Direction[!train]

train.X <- cbind(Lag1, Lag2)[train,]
test.X <- cbind(Lag1, Lag2)[!train,]
train.Direction <- Direction[train]

set.seed(1)
knn.pred <- knn(train.X, test.X, train.Direction, k=1)
table(knn.pred, Direction.2005)
mean(knn.pred == Direction.2005)

 

다시 본론으로 돌아와서👏

Smarket 데이터를 KNN 알고리즘을 훈련시킬 Training set(2001 ~ 2004년)와 훈련이 잘 되었는지 확인해볼 수 있는 Test set(2005년)로 나눈다. 먼저 K=1일 때를 테스트해서 confusion matrix를 그려보면 아래처럼.. 형편없는 결과가 나온다.

Confusion matrix를 보면 (Down, Down) (Down, Up) (Up, Down) (Up, Up) 네 좌표가 있는데,

(Down, Down), (Up, Up) : 제대로 예측한 수

(Down, Up) (Up, Down) : 틀리게 예측한 수

이다. 정확도를 계산해보면 0.5... 50% 밖에 되지 않는다. 그냥 찍어도 50%는 맞추겠다..😂

 

> table(knn.pred, Direction.2005)

        Direction.2005
knn.pred Down Up
    Down   43 58
    Up     68 83
    
    
> mean(knn.pred == Direction.2005)
[1] 0.5

 

K = 1이니 Overfitting 되었을거라 생각하고 K = 3으로 다시 테스트를 해보면,

 

knn.pred <- knn(train.X, test.X, train.Direction, k=3)
table(knn.pred, Direction.2005)
mean(knn.pred == Direction.2005)
> knn.pred <- knn(train.X, test.X, train.Direction, k=3)

> table(knn.pred, Direction.2005)
        Direction.2005
knn.pred Down Up
    Down   48 54
    Up     63 87
    
    
> mean(knn.pred == Direction.2005)
[1] 0.5357143

 

여전히 형편없는 정확도이지만 그래도 K = 1일 때보다는 정확도가 살짝 증가한 것을 확인할 수 있다.

데이터셋 자체가 워낙 작아 더 이상 K를 올리는 것은 무의미하니, 오늘은 여기까지!