DEV etc/머신러닝

1.3 구글 Colaboratory 와 scikit-learn을 활용한 머신러닝 모델 훈련하기

minsung521 2021. 8. 25. 04:29

구글 Colab 과 scikit-learn을 사용하여 간단한 K - 최근접 이웃 분류 모델을 만들어봅니다.

 

준비하기

 

구글 Colaboratory?

구글 Colaboratory 는 한마디로 클라우드 기반의 Jupyter Notebook 개발환경 이라고 할 수 있다. 

구글 클라우드의 가상서버 와 Compute Engine 를 사용하는 Colab은 구글 계정만 있다면 누구나 자신의 컴퓨터 사양에 상관없이 실습이 가능하다는 장점이 있다. 

 

 

https://colab.research.google.com/

 

Google Colaboratory

 

colab.research.google.com

위 링크에 접속하여 로그인하여 새 노트 를 클릭하면 아래와 같은 창을 볼 수 있다.

셀을 통해 Hello, world를 출력한 모습

Colab은 셀 이라는 최소 단위를 기반으로 구성되어있는데,

여러 셀에 코드를 나누어 사용할 수 있으며, 코드는 물론 HTML 과 마크다운을 활용할 수 있기 때문에 노트북 파일 하나에 설명과 실행코드를 같이 담을 수 있다.

 

데이터 준비

머신러닝에서 분류(classification)는 여러 개의 종류(class) 중 하나를 구별해는 문제이다. 

생선의 길이와 무게 데이터를 활용하여 도미와 빙어를 구분하는 이진 분류(binary classification) 모델을 만들어 볼 것이다!

 

 

데이터

노트를 생성한뒤 BreamAndSmelt 라는 이름을 적어주자. 이후 도미와 빙어 데이터를 준비해준다 

#35마리의 도미 길이,무게 데이터
bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 
                31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 
                35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 
                500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 
                700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]
#14마리의 빙어 길이, 무게 데이터
smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]

여기서 우리는 데이터의 길이와 무게. 이 두가지 특징을 찾을 수 있는데, 이런 특징을 특성(feature) 이라고 부르도록 한다.

 

 

산점도 그래프로 데이터 이해하기

두 특성을 숫자로 보기보다 그래프를 통해 확인하면 데이터를 더 잘 이해할 수 있고 앞으로 할 작업에 대한 힌트도 얻을 수 있다. 

 

길이를 x축으로, 무게를 y축으로 정한 산점도(점) 그래프를 만들어볼 것이다. 

 

과학계산용 그래프를 그리는 대표적인 라이브러리인 맷플롯립(matplotlib) 을 활용한다. 

 

import matplotlib.pyplot as plt # matplotlib 의 pylot 함수를 plt로 줄여서 사용

plt.scatter(bream_length,bream_weight) # 도미 데이터
plt.scatter(smelt_length,smelt_weight) # 빙어 데이터
plt.xlabel('length')
plt.ylabel('weight')
plt.show() # 출력하기

*Colab에는 맷플롯립과 같은 패키지들이 미리 준비되어있기 때문에 따로 설치할필요가 없다!1

 

아래 코드를 실행하면 다음과 같은 산점도 그래프를 확인할 수 있다.

 

두가지 색으로 구분된 2개의 산점도. 도미와 빙어데이터가 각각 모여있는것을 확인할 수 있다.

 

훈련 후 예측하기

 

가장 간단하며 이해하기 쉬운 k-최근접 이웃 알고리즘을 사용해 도미와 빙어 데이터를 구분하는 모델을 만들어볼 것이다.

 

먼저 두 리스트를 하나로 합쳐준다.

length = bream_length + smelt_length
weight = bream_weight + smelt_weight

 

사이킷런을 사용하려면 각 특성의 리스트를 세로 방향으로 2차원 리스트를 만들어야 하기에 zip 함수와 리스트 내포(list comprehension) 구문을 사용하여 만들어준다.

 

fish_data = [[l,w] for l,w in zip(length,weight)] # zip 함수로 각각 하나씩 원소를 반환하며 2차원 리스트를 만들어준다.

 

 

이제 훈련을 위한 [길이,무게] * 49 구성의 2차원 리스트가 완성되었다. 이제 어떤 생선이 도미이고 도미가 아닌지 알려주는 답을 만들어줄 필요가 있다.

 

 

도미인것과 도미가 아닌것(빙어) 로 구분하는 이진 분류이므로 도미 = 1 빙어 = 0 로 구성된 정답 리스트를 만들어준다.

fish_target = [1] * 35 + [0] * 14

 

훈련하기 - KNeighborsClassifier

 

모든 훈련 준비가 끝났다! KNeighborsClassifier 를 임포트해 객체를 만들어준다.

 

from sklearn.neighbors import KNeighborsClassifier
kn = KNeighborsClassifier()

이제 이 객체에 fish_data 와 fish_target 을 전달하여 도미를 찾기 위한 기준을 학습시켜준다. 

 

kn.fit(fish_data, fish_target)

 

훈련을 진행시킨뒤 score() 메서드로 훈련이 잘 되었는지 평가해준다. 

kn.score(fish_data,fish_target)

 

실행하면  결괏값으로 1.0을 받은것을 확인할 수 있다. 이 값은 정확도(accuracy) 라고 부르는 값인데, 즉 정확도가 100% 라는것을 알 수 있다!

 

 

예측하기

 

이제 학습한 모델에 생선 데이터를 넣어 어떤 생선으로 판단하는지 확인해보자. 

kn.predict([[30,600]])

-output

1 = 도미 

결괏값으로 1이 나온것으로 보아 도미라는 결과가 나왔음을 알 수 있다. 

 

 

 

 

 

 

k-최근접 이웃 알고리즘 : KNN

 

 

k-최근접 이웃 알고리즘은 새로운 데이터에 대해 가장 가까운 직선거리에 어떤 데이터가 있는지를 살펴 예측하는 방식을 사용한다. 

 

 

간단하고 이해하기 쉽지만 데이터가 아주 많아진다면 사용하기가 어렵다는 단점을 가지고 있다. 

이 외에도 데이터가 크기 때문에 메모리가 많이 필요하고 직선거리를 계산하는데도 많은 시간이 필요하다는 특징이 있다. 

 

 

 

 

 

 

 

 

 

 

*위 글은 도서 <혼자 공부하는 머신러닝 + 딥러닝> 를 기반으로 작성되었습니다.