본문 바로가기
AI/회귀

[회귀 - 3] 단순 선형 회귀(2)

by Nhahan 2022. 4. 19.

3. 테스트 세트의 결과 예측하기

> y_pred = predict(regressor, newdata = test_set)
> y_pred

위처럼 직관적으로 predict 펑션을 쓰면,

 

       2         4         5         8        11        16        20        21        24        26
37766.77  44322.33  46195.35  55560.43  62115.99  71481.07  81782.66  89274.72 102385.84 109877.90

이런 결과를 얻을 수 있는데, 위 숫자가 index, 아래는 예측값이다.

test_set

실제 test_set과 비교해보면 비슷하게 값을 예측한 것을 알 수 있다.

 

 

 

4.  훈련 세트 Visualize

 install.packages(ggplot2)

Visualize를 위해 위와 같이 실행하여 ggplot2라는 라이브러리를 설치하고

 

# init ggplot
library(ggplot2)
ggplot() +
# 실제 데이터 점 찍기, 색은 red
    geom_point(aes(x = training_set$YearsExperience, y = training_set$Salary),
               colour = 'red') +
# 예측 회귀선 긋기, 색은 blue
    geom_line(aes(x = training_set$YearsExperience, y = predict(regressor, newdata = training_set)),
               colour = 'blue') +
# 순서대로 타이틀, x축명, y축명 (생략가능)
    ggtitle('Salary vs Experince (Training Set)') +
    xlab('Years of experience') +
    ylab('Salary')

 

위 코드를 실행시키면 (설명은 주석에)

와우

이런 결과를 얻을 수 있다. 빨간점이 파란색 회귀선의 y축값보다 크다면 연차대비 연봉이 높은 편이고, 작다면 낮은 편이라고 분석해볼 수 있다.

 

 

5. 테스트 세트 Visualize

library(ggplot2)
ggplot() +
  geom_point(aes(x = test_set$YearsExperience, y = test_set$Salary),
             colour = 'red') +
  geom_line(aes(x = training_set$YearsExperience, y = predict(regressor, newdata = training_set)),
            colour = 'blue') +
  ggtitle('Salary vs Experience (Test set)') +
  xlab('Years of experience') +
  ylab('Salary')

geom_point는 실제 값의 데이터이므로 test_set을 넣어준다.

하지만 geom_line은 훈련 세트로 해야한다. 이미 훈련 세트로 회귀선을 만들어주었는데 테스트 세트로 만든다면 낭비일 것이다. (머신러닝을 시킬 때, 훈련 세트로 훈련을 시키고 테스트 세트로 실제 테스트를 해본다고 생각)

 

와우 뽕맛

훈련 세트로 테스트 세트의 값을 잘 예측한 것을 볼 수 있다.

댓글