Tab study에서 공부한 회귀 알고리즘과 모델 규제에 대해 정리해보려한다.
회귀란?
앞서 물고기 데이터를 분류한 것은 그 물고기가 빙어인지 도미인지 분류하는 분류 문제였다면, 회귀란 예를 들어, 물고기의 데이터를 학습해 그 물고기의 무게를 예측하는 것을 회귀라고 할 수있다.
간단한 농어 무게 예측
농어 데이터와 데이터 나누기
1
2
3
4
5
6
7
8
9
10
11
12
13
14
perch_length = np.array([8.4, 13.7, 15.0, 16.2, 17.4, 18.0, 18.7, 19.0, 19.6, 20.0, 21.0,
21.0, 21.0, 21.3, 22.0, 22.0, 22.0, 22.0, 22.0, 22.5, 22.5, 22.7,
23.0, 23.5, 24.0, 24.0, 24.6, 25.0, 25.6, 26.5, 27.3, 27.5, 27.5,
27.5, 28.0, 28.7, 30.0, 32.8, 34.5, 35.0, 36.5, 36.0, 37.0, 37.0,
39.0, 39.0, 39.0, 40.0, 40.0, 40.0, 40.0, 42.0, 43.0, 43.0, 43.5,
44.0])
perch_weight = np.array([5.9, 32.0, 40.0, 51.5, 70.0, 100.0, 78.0, 80.0, 85.0, 85.0, 110.0,
115.0, 125.0, 130.0, 120.0, 120.0, 130.0, 135.0, 110.0, 130.0,
150.0, 145.0, 150.0, 170.0, 225.0, 145.0, 188.0, 180.0, 197.0,
218.0, 300.0, 260.0, 265.0, 250.0, 250.0, 300.0, 320.0, 514.0,
556.0, 840.0, 685.0, 700.0, 700.0, 690.0, 900.0, 650.0, 820.0,
850.0, 900.0, 1015.0, 820.0, 1100.0, 1000.0, 1100.0, 1000.0,
1000.0])
1
2
3
4
5
from sklearn.model_selection import train_test_split
train_input, test_input, train_target, test_target = train_test_split(perch_length, perch_weight, random_state=42)
train_input = train_input.reshape(-1, 1) # 2차원 배열 만들기
test_input = test_input.reshape(-1, 1)
scikit-learn에서는 2차원 형태의 데이터들만 학습시킬 수 있기 때문에 reshape함수를 통해 데이터들을 2차원 형태로 만들어준다.
결과 확인
1
2
3
4
5
6
from sklearn.neighbors import KNeighborsRegressor
knr = KNeighborsRegressor()
knr.fit(train_input, train_target)
print(knr.score(test_input, test_target))
1
0.992809406101064
약 99퍼센트 확률로 농어의 무게를 예측하는 것을 확인할 수 있다.
과대적합과 과소적합
과대적합과 과소적합은 과대적합 과소적합 설명 이곳에 설명을 했다.
선형회귀
선형회귀란?
선형 회귀는 알려진 다른 관련 데이터 값을 사용하여 알 수 없는 데이터의 값을 예측하는 데이터 분석 기법
농어 데이터 길이로 무게 예측 (선형회귀)
1
2
3
4
5
6
7
8
9
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
# 선형 회귀 모델을 훈련
lr.fit(train_input, train_target)
# 50cm 농어에 대해 예측
print(lr.predict([[50]]))
1
[1241.83860323]
선형 회귀로 데이터를 학습시켰을때 50cm의 농어의 무게는 약 1241로 무게를 높게 예측하는 것을 확인할 수 있다.
선형회귀 그래프
선형회귀 모형의 식은 $y=ax+b$로 a는 가중치라고 부른다.
1
2
3
4
5
6
7
8
9
10
11
12
import matplotlib.pyplot as plt
# 훈련 셋의 산점도
plt.scatter(train_input, train_target)
# 15에서 50까지 1차 방정식 그래프
plt.plot([15, 50], [15*lr.coef_+lr.intercept_, 50*lr.coef_+lr.intercept_])
# 50cm 농어 데이터
plt.scatter(50, 1241.8, marker='^')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
<img src = “https://velog.velcdn.com/images/acadias12/post/623b1951-59b4-41a0-903b-be5fb4dc0c46/image.png” width = 75%>
LinearRegression학습을 하면서 회귀 모델들은 최적의 직선을 찾는데 위의 사진이 모델이 찾은 최적의 직선이다.
다항회귀
위의 농어 무게 예측은 길이로만 예측을 했지만, 농어의 여러 특성을 이용해 무게를 예측하는 것 즉, 여러가지 특성을 가중치로 둔 것을 다항회귀라고 한다.
다항회귀 식으로 나타내보기
$y=w{\tiny 0}+w{\tiny 1}x+…+w{\tiny n}x^n$
다항회귀에는 여러가지 특성이 존재하므로 위의 식처럼 나타낼 수 있다.
L1규제와 L2규제
규제란?
Regularization (정형화, 규제, 일반화) 모델이 과적합되게 학습하지 않고 일반성을 가질 수 있도록 규제 하는 것.
L1규제 (Lasso 회귀)
L1규제는 모델의 가중치를 규제하는데, 손실함수를 줄여나가면서 가중치 값을 최소로 하는 것(특정 feature는 0이 되어 불필요한 feature를 사용하지 않음)을 목표로 진행한다. $Loss=Loss+Sum(|w|)$ 손실함수에 가중치의 절댓값들을 더한 것이다. L1 규제를 통해 불필요한 feature들을 사용하지 않는 간단한 모델을 구현할 수 있게 된다.
L2규제 (Ridge 회귀)
L2규제 또한 모델의 가중치를 규제하는데, 손실함수를 줄여나가면서 가중치 값을 0에 가까워지도록 하는 것을 목표로 진행한다. $Loss=Loss+Sum(w^2)$ 손실함수에 가중치의 제곱을 더한 것이다. L1규제와 차이점은 L2규제는 가중치들이 0에 가까워지도록 가중치들을 규제하여 특정 feature이 지나치게 치우치지 않게 만든다.
느낀점
단항, 다항 회귀모델들을 배워보고 규제에 대해서도 배워보았는데, 우리가 중요하게 생각해야할 것은 데이터들이 과적합(overfitting), 과소적합(underfitting)이 되지 않고, 특정 feature들이 치우치게 하지 않게 해야한다는 것을 알게 되었다.