
TensorFlow 2 Basics 02) A Simple Implementation of Multi-Variable Linear Regression

빅브로오 2020. 9. 21. 17:37

A simple implementation of multi-variable linear regression in TensorFlow 2

# Multi variable linear regression
import tensorflow as tf
import numpy as np

Create a dataset with 5 instances and 3 variables.
We will predict a final exam score out of 200 points from the results of three previous quizzes.

x_data = [[73., 80., 75.],
          [93., 88., 93.],
          [89., 91., 90.],
          [96., 98., 100.],
          [73., 66., 70.]]
y_data = [[152.],
          [185.],
          [180.],
          [196.],
          [142.]]
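
Keras converts these nested Python lists internally when fit is called, but you can also make the conversion explicit with float32 NumPy arrays. A small optional sketch (the names x_train and y_train are mine, not from the original post):

# Optional: explicit float32 arrays instead of nested Python lists
x_train = np.array(x_data, dtype=np.float32)   # shape (5, 3)
y_train = np.array(y_data, dtype=np.float32)   # shape (5, 1)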

Designing a Keras neural net model

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(units = 1, input_dim = 3, activation = 'linear'))

Since there are three variables, the model is designed to take three inputs.
The default activation is already linear, so for linear regression there is no real need to set it explicitly on the output node.
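
For reference, the same one-unit model can also be written with input_shape instead of input_dim, passing the layer list directly to Sequential; a minimal equivalent sketch (alt_model is my name, not part of the post):

# Equivalent model: a single Dense unit computing w1*x1 + w2*x2 + w3*x3 + b
alt_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 1, input_shape = (3,))   # activation defaults to linear
])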

Compiling the model

Here we set the loss function and the optimizer.

model.compile(loss = 'mse', optimizer = tf.keras.optimizers.SGD(learning_rate = 1e-5))
model.summary() # 모델 확인
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 1)                 4         
=================================================================
Total params: 4
Trainable params: 4
Non-trainable params: 0
_________________________________________________________________

A model that outputs a single value has been designed; the 4 parameters in the summary are the 3 weights (one per input variable) plus 1 bias.
Let's train it.

hist = model.fit(x_data, y_data, epochs = 100)
Train on 5 samples
Epoch 1/100
5/5 [==============================] - 0s 84ms/sample - loss: 93357.6016
Epoch 2/100
5/5 [==============================] - 0s 2ms/sample - loss: 29263.9375
Epoch 3/100
5/5 [==============================] - 0s 2ms/sample - loss: 9173.9766
Epoch 4/100
5/5 [==============================] - 0s 2ms/sample - loss: 2876.8394
Epoch 5/100
5/5 [==============================] - 0s 2ms/sample - loss: 903.0208
Epoch 6/100
5/5 [==============================] - 0s 2ms/sample - loss: 284.3332
Epoch 7/100
5/5 [==============================] - 0s 2ms/sample - loss: 90.4071
Epoch 8/100
5/5 [==============================] - 0s 2ms/sample - loss: 29.6211
Epoch 9/100
5/5 [==============================] - 0s 2ms/sample - loss: 10.5672
Epoch 10/100
5/5 [==============================] - 0s 2ms/sample - loss: 4.5944
Epoch 11/100
5/5 [==============================] - 0s 2ms/sample - loss: 2.7217
Epoch 12/100
5/5 [==============================] - 0s 2ms/sample - loss: 2.1343
Epoch 13/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.9496
Epoch 14/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8912
Epoch 15/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8724
Epoch 16/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8660
Epoch 17/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8635
Epoch 18/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8621
Epoch 19/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8612
Epoch 20/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8604
Epoch 21/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8597
Epoch 22/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8589
Epoch 23/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8582
Epoch 24/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8574
Epoch 25/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8567
Epoch 26/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8559
Epoch 27/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8552
Epoch 28/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8545
Epoch 29/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8537
Epoch 30/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8530
Epoch 31/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8522
Epoch 32/100
5/5 [==============================] - 0s 4ms/sample - loss: 1.8515
Epoch 33/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8507
Epoch 34/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8500
Epoch 35/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8493
Epoch 36/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8485
Epoch 37/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8478
Epoch 38/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8471
Epoch 39/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8463
Epoch 40/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8456
Epoch 41/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8448
Epoch 42/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8441
Epoch 43/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8434
Epoch 44/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8426
Epoch 45/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8419
Epoch 46/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8412
Epoch 47/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8404
Epoch 48/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8397
Epoch 49/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8390
Epoch 50/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8382
Epoch 51/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8375
Epoch 52/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8368
Epoch 53/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8360
Epoch 54/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8353
Epoch 55/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8346
Epoch 56/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8338
Epoch 57/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8331
Epoch 58/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8324
Epoch 59/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8316
Epoch 60/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8309
Epoch 61/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8302
Epoch 62/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8295
Epoch 63/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8287
Epoch 64/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8280
Epoch 65/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8273
Epoch 66/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8266
Epoch 67/100
5/5 [==============================] - 0s 1ms/sample - loss: 1.8258
Epoch 68/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8251
Epoch 69/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8244
Epoch 70/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8237
Epoch 71/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8229
Epoch 72/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8222
Epoch 73/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8215
Epoch 74/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8208
Epoch 75/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8200
Epoch 76/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8193
Epoch 77/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8186
Epoch 78/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8179
Epoch 79/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8172
Epoch 80/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8164
Epoch 81/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8157
Epoch 82/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8150
Epoch 83/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8143
Epoch 84/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8135
Epoch 85/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8128
Epoch 86/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8121
Epoch 87/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8114
Epoch 88/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8107
Epoch 89/100
5/5 [==============================] - 0s 1ms/sample - loss: 1.8100
Epoch 90/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8092
Epoch 91/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8085
Epoch 92/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8078
Epoch 93/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8071
Epoch 94/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8064
Epoch 95/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8057
Epoch 96/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8049
Epoch 97/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8042
Epoch 98/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8035
Epoch 99/100
5/5 [==============================] - 0s 3ms/sample - loss: 1.8028
Epoch 100/100
5/5 [==============================] - 0s 2ms/sample - loss: 1.8021
y_pred = model.predict(np.array([[72.,93.,90.]]))
print(y_pred)
[[172.32043]]

The model predicts a score of about 172.
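
Because the model is just a single linear layer, you can pull out the learned kernel and bias and reproduce the prediction by hand; a quick sketch (w, b, and manual_pred are names I introduce here):

w, b = model.layers[0].get_weights()           # kernel shape (3, 1), bias shape (1,)
x_new = np.array([[72., 93., 90.]])
manual_pred = x_new @ w + b                    # should match model.predict(x_new)
print(manual_pred)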

hist.params
{'batch_size': 32,
 'epochs': 100,
 'steps': 1,
 'samples': 5,
 'verbose': 1,
 'do_validation': False,
 'metrics': ['loss']}
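
hist.params only records the fit configuration; the per-epoch loss values live in hist.history, which you can plot to check convergence. A small sketch, assuming matplotlib is available:

import matplotlib.pyplot as plt

plt.plot(hist.history['loss'])                 # training loss per epoch
plt.xlabel('epoch')
plt.ylabel('MSE loss')
plt.show()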

Since this post only reproduces the basics of linear regression, there is no need to tune these parameters here,
but when using a neural net you should pay close attention to such parameter values.

The model above was trained for 100 epochs without any separate validation data.
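
If you did want validation while training, model.fit also accepts validation_data (or validation_split), which adds a val_loss entry to hist.history each epoch. A hedged sketch with a made-up validation pair, purely to illustrate the call:

# Hypothetical validation sample (invented scores, only to show the API)
x_val = np.array([[75., 85., 80.]])
y_val = np.array([[160.]])

hist = model.fit(x_data, y_data, epochs = 100,
                 validation_data = (x_val, y_val))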
