Гиперпараметры и кросс-валидация#
Для подбора гиперпараметров необходимо минимизировать влияние конкретного выбора тестовой и обучающей выборки. Для этого можно например использовать следующий алгоритм кросс-валидации K-Fold:
Фиксируется некоторое целое число \(k\) (обычно от 5 до 10), меньшее числа семплов в датасете.
Датасет разбивается на \(k\) одинаковых частей (в последней части может быть меньше семплов, чем в остальных). Эти части называются фолдами.
Далее происходит \(k\) итераций, во время каждой из которых один фолд выступает в роли тестового множества, а объединение остальных - в роли тренировочного. Модель учится на \(k-1\) фолде и тестируется на оставшемся.
Финальный скор модели получается либо усреднением \(k\) получившихся тестовых результатов, либо измеряется на отложенном тестовом множестве, не участвовавшем в кросс-валидации.
Подробнее тут: https://academy.yandex.ru/handbook/ml/article/kross-validaciya
Наконец, обcудим процесс подбора гиперпараметров. Тут всё просто - перебираем какое-либо количество гиперпараметров по сетке (grid-search) или случайным образом (random-search), считаем кросс-валидацией Loss на каждой точке перебора, а в конце выбираем наилучший набор гиперпараметров.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, Lasso
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import mean_squared_error
from ipywidgets import interact, widgets
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
# Кросс-валидация гиперпараметров -
hyperparams = {'alpha': [10000, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9],
'max_iter': [1000, 10000, 100000]}
from sklearn.model_selection import GridSearchCV
gs = GridSearchCV(estimator=Lasso(fit_intercept=False, warm_start=False),
param_grid=hyperparams,
cv=5,
verbose=1,
scoring='neg_mean_squared_error', # Эта функция максимизирует, берём neg
return_train_score=True)
from sklearn.metrics import get_scorer_names
#get_scorer_names()
# Возьмём следующий пример:
xs = np.arange(0, 1, 0.01).reshape((-1, 1))
noise = np.random.normal(0., 0.2, size = xs.shape[0]).reshape((-1, 1))
ys = np.sin(20*xs) + np.sin(10*xs) + np.sin(40*xs) + noise
def make_sin_matrix(x, max_order):
X = np.ones(x.shape)
for k in range(1, max_order + 1):
X = np.concatenate([X, np.sin(k*x)], axis = 1)
return X
import warnings
warnings.filterwarnings('ignore')
gs.fit(make_sin_matrix(xs, 50), ys)
Fitting 5 folds for each of 30 candidates, totalling 150 fits
GridSearchCV(cv=5, estimator=Lasso(fit_intercept=False), param_grid={'alpha': [10000, 0.1, 0.01, 0.001, 0.0001, 1e-05, 1e-06, 1e-07, 1e-08, 1e-09], 'max_iter': [1000, 10000, 100000]}, return_train_score=True, scoring='neg_mean_squared_error', verbose=1)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
GridSearchCV(cv=5, estimator=Lasso(fit_intercept=False), param_grid={'alpha': [10000, 0.1, 0.01, 0.001, 0.0001, 1e-05, 1e-06, 1e-07, 1e-08, 1e-09], 'max_iter': [1000, 10000, 100000]}, return_train_score=True, scoring='neg_mean_squared_error', verbose=1)
Lasso(fit_intercept=False)
Lasso(fit_intercept=False)
gs.best_params_
{'alpha': 0.01, 'max_iter': 1000}
gs.best_score_
-0.09678115862064099
import pandas as pd
results = pd.DataFrame(gs.cv_results_['params'])
results['test_score'] = gs.cv_results_['mean_test_score']
results
alpha | max_iter | test_score | |
---|---|---|---|
0 | 1.000000e+04 | 1000 | -1.498311 |
1 | 1.000000e+04 | 10000 | -1.498311 |
2 | 1.000000e+04 | 100000 | -1.498311 |
3 | 1.000000e-01 | 1000 | -0.370668 |
4 | 1.000000e-01 | 10000 | -0.370668 |
5 | 1.000000e-01 | 100000 | -0.370668 |
6 | 1.000000e-02 | 1000 | -0.096781 |
7 | 1.000000e-02 | 10000 | -0.099300 |
8 | 1.000000e-02 | 100000 | -0.099300 |
9 | 1.000000e-03 | 1000 | -0.133619 |
10 | 1.000000e-03 | 10000 | -0.123795 |
11 | 1.000000e-03 | 100000 | -0.123795 |
12 | 1.000000e-04 | 1000 | -2.604917 |
13 | 1.000000e-04 | 10000 | -2.685110 |
14 | 1.000000e-04 | 100000 | -2.685110 |
15 | 1.000000e-05 | 1000 | -6.852690 |
16 | 1.000000e-05 | 10000 | -9.704656 |
17 | 1.000000e-05 | 100000 | -9.700653 |
18 | 1.000000e-06 | 1000 | -7.619127 |
19 | 1.000000e-06 | 10000 | -11.394790 |
20 | 1.000000e-06 | 100000 | -79.759712 |
21 | 1.000000e-07 | 1000 | -7.698354 |
22 | 1.000000e-07 | 10000 | -11.629701 |
23 | 1.000000e-07 | 100000 | -114.360442 |
24 | 1.000000e-08 | 1000 | -7.706308 |
25 | 1.000000e-08 | 10000 | -11.653957 |
26 | 1.000000e-08 | 100000 | -118.122378 |
27 | 1.000000e-09 | 1000 | -7.707103 |
28 | 1.000000e-09 | 10000 | -11.656379 |
29 | 1.000000e-09 | 100000 | -118.501525 |