ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 1.3 구글 Colaboratory 와 scikit-learn을 활용한 머신러닝 모델 훈련하기
    DEV etc/머신러닝 2021. 8. 25. 04:29

    구글 Colab 과 scikit-learn을 사용하여 간단한 K - 최근접 이웃 분류 모델을 만들어봅니다.

     

    준비하기

     

    구글 Colaboratory?

    구글 Colaboratory 는 한마디로 클라우드 기반의 Jupyter Notebook 개발환경 이라고 할 수 있다. 

    구글 클라우드의 가상서버 와 Compute Engine 를 사용하는 Colab은 구글 계정만 있다면 누구나 자신의 컴퓨터 사양에 상관없이 실습이 가능하다는 장점이 있다. 

     

     

    https://colab.research.google.com/

     

    Google Colaboratory

     

    colab.research.google.com

    위 링크에 접속하여 로그인하여 새 노트 를 클릭하면 아래와 같은 창을 볼 수 있다.

    셀을 통해 Hello, world를 출력한 모습

    Colab은 셀 이라는 최소 단위를 기반으로 구성되어있는데,

    여러 셀에 코드를 나누어 사용할 수 있으며, 코드는 물론 HTML 과 마크다운을 활용할 수 있기 때문에 노트북 파일 하나에 설명과 실행코드를 같이 담을 수 있다.

     

    데이터 준비

    머신러닝에서 분류(classification)는 여러 개의 종류(class) 중 하나를 구별해는 문제이다. 

    생선의 길이와 무게 데이터를 활용하여 도미와 빙어를 구분하는 이진 분류(binary classification) 모델을 만들어 볼 것이다!

     

     

    데이터

    노트를 생성한뒤 BreamAndSmelt 라는 이름을 적어주자. 이후 도미와 빙어 데이터를 준비해준다 

    #35마리의 도미 길이,무게 데이터
    bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 
                    31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 
                    35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
    bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 
                    500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 
                    700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]
    #14마리의 빙어 길이, 무게 데이터
    smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
    smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]

    여기서 우리는 데이터의 길이와 무게. 이 두가지 특징을 찾을 수 있는데, 이런 특징을 특성(feature) 이라고 부르도록 한다.

     

     

    산점도 그래프로 데이터 이해하기

    두 특성을 숫자로 보기보다 그래프를 통해 확인하면 데이터를 더 잘 이해할 수 있고 앞으로 할 작업에 대한 힌트도 얻을 수 있다. 

     

    길이를 x축으로, 무게를 y축으로 정한 산점도(점) 그래프를 만들어볼 것이다. 

     

    과학계산용 그래프를 그리는 대표적인 라이브러리인 맷플롯립(matplotlib) 을 활용한다. 

     

    import matplotlib.pyplot as plt # matplotlib 의 pylot 함수를 plt로 줄여서 사용
    
    plt.scatter(bream_length,bream_weight) # 도미 데이터
    plt.scatter(smelt_length,smelt_weight) # 빙어 데이터
    plt.xlabel('length')
    plt.ylabel('weight')
    plt.show() # 출력하기

    *Colab에는 맷플롯립과 같은 패키지들이 미리 준비되어있기 때문에 따로 설치할필요가 없다!1

     

    아래 코드를 실행하면 다음과 같은 산점도 그래프를 확인할 수 있다.

     

    두가지 색으로 구분된 2개의 산점도. 도미와 빙어데이터가 각각 모여있는것을 확인할 수 있다.

     

    훈련 후 예측하기

     

    가장 간단하며 이해하기 쉬운 k-최근접 이웃 알고리즘을 사용해 도미와 빙어 데이터를 구분하는 모델을 만들어볼 것이다.

     

    먼저 두 리스트를 하나로 합쳐준다.

    length = bream_length + smelt_length
    weight = bream_weight + smelt_weight

     

    사이킷런을 사용하려면 각 특성의 리스트를 세로 방향으로 2차원 리스트를 만들어야 하기에 zip 함수와 리스트 내포(list comprehension) 구문을 사용하여 만들어준다.

     

    fish_data = [[l,w] for l,w in zip(length,weight)] # zip 함수로 각각 하나씩 원소를 반환하며 2차원 리스트를 만들어준다.

     

     

    이제 훈련을 위한 [길이,무게] * 49 구성의 2차원 리스트가 완성되었다. 이제 어떤 생선이 도미이고 도미가 아닌지 알려주는 답을 만들어줄 필요가 있다.

     

     

    도미인것과 도미가 아닌것(빙어) 로 구분하는 이진 분류이므로 도미 = 1 빙어 = 0 로 구성된 정답 리스트를 만들어준다.

    fish_target = [1] * 35 + [0] * 14

     

    훈련하기 - KNeighborsClassifier

     

    모든 훈련 준비가 끝났다! KNeighborsClassifier 를 임포트해 객체를 만들어준다.

     

    from sklearn.neighbors import KNeighborsClassifier
    kn = KNeighborsClassifier()

    이제 이 객체에 fish_data 와 fish_target 을 전달하여 도미를 찾기 위한 기준을 학습시켜준다. 

     

    kn.fit(fish_data, fish_target)

     

    훈련을 진행시킨뒤 score() 메서드로 훈련이 잘 되었는지 평가해준다. 

    kn.score(fish_data,fish_target)

     

    실행하면  결괏값으로 1.0을 받은것을 확인할 수 있다. 이 값은 정확도(accuracy) 라고 부르는 값인데, 즉 정확도가 100% 라는것을 알 수 있다!

     

     

    예측하기

     

    이제 학습한 모델에 생선 데이터를 넣어 어떤 생선으로 판단하는지 확인해보자. 

    kn.predict([[30,600]])

    -output

    1 = 도미 

    결괏값으로 1이 나온것으로 보아 도미라는 결과가 나왔음을 알 수 있다. 

     

     

     

     

     

     

    k-최근접 이웃 알고리즘 : KNN

     

     

    k-최근접 이웃 알고리즘은 새로운 데이터에 대해 가장 가까운 직선거리에 어떤 데이터가 있는지를 살펴 예측하는 방식을 사용한다. 

     

     

    간단하고 이해하기 쉽지만 데이터가 아주 많아진다면 사용하기가 어렵다는 단점을 가지고 있다. 

    이 외에도 데이터가 크기 때문에 메모리가 많이 필요하고 직선거리를 계산하는데도 많은 시간이 필요하다는 특징이 있다. 

     

     

     

     

     

     

     

     

     

     

    *위 글은 도서 <혼자 공부하는 머신러닝 + 딥러닝> 를 기반으로 작성되었습니다.

    댓글

Designed by Tistory.