Методы кластеризации#
K-means#
Идея проста - фиксируем количество кластеров \(k\). По итогу, каждый \(x_i\) должен быть в том кластере, к центру масс которого он ближе всего. Алгоритм:
Алгоритм k-means
Инициализируем центроиды (центры масс) кластеров \(\mu_1, \dots, \mu_k \in \mathbb{R}^n\).
Повторять до сходимости:
Для каждой точки \(x^{(i)}\) находим ближайший к ней центр масс
Для каждого \(j\) переопределяем центры масс исходя из ближайших точек на прошлом шаге
В алгоритме \(k\) - это параметр, обозначающий количество кластеров, фиксируется изначально. Центроиды \(\mu_j\) представляют собой наши текущие предположения о положении центров кластеров. Инициализация центроидов может происходить различными способами, например, мы можем случайным образом выбрать \(k\) точек из обучающего множества.
Шаги алгоритма:
Алгоритм очень прост и, как правило, работает хорошо на данных, где «истинные» кластеры далеки друг от друга. Тем не менее, он будет плохо работать, если «истинные» кластеры находятся близко к друг другу, или далеки по своей форме от шара.
Show code cell source
from bokeh.io import output_notebook, show
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, Button, CustomJS
from bokeh.layouts import column, row
import numpy as np
from bokeh.resources import INLINE
output_notebook(INLINE, hide_banner = True)
# Создаем начальные данные
np.random.seed(0)
cluster1 = np.random.normal([2.5, 2.5], 0.5, (50, 2))
cluster2 = np.random.normal([4.5, 4.5], 0.5, (50, 2))
data = np.vstack([cluster1, cluster2])
# Инициализируем центроиды с фиксированными начальными значениями
centroids = np.array([[2. + 2*_, 3. + 2*_] for _ in range(2)])
# Создаем источники данных для точек, центроидов и линий
source_points = ColumnDataSource(data=dict(x=data[:, 0], y=data[:, 1], color=["blue"] * 50 + ["green"] * 50))
source_centroids = ColumnDataSource(data=dict(x=centroids[:, 0], y=centroids[:, 1]))
source_lines = ColumnDataSource(data=dict(x_start=[], y_start=[], x_end=[], y_end=[], color=[]))
# Создаем фигуру
p = figure(width=400, height=400, title="K-means Clustering", x_range=(1, 6), y_range=(1, 6),
tools="", toolbar_location=None) # Отключаем toolbar и инструменты
p.scatter('x', 'y', color='color', source=source_points, size=6, alpha=0.6)
p.scatter('x', 'y', color="red", source=source_centroids, size=8, alpha=0.8)
p.segment('x_start', 'y_start', 'x_end', 'y_end', color='color', source=source_lines, line_dash="dotted", alpha=0.5)
# JavaScript для запуска анимации
callback = CustomJS(args=dict(source_points=source_points, source_centroids=source_centroids, source_lines=source_lines), code="""
const data_points = source_points.data;
const data_centroids = source_centroids.data;
const lines = source_lines.data;
const x_points = data_points['x'];
const y_points = data_points['y'];
let centroids_x = data_centroids['x'];
let centroids_y = data_centroids['y'];
function update() {
// Расстояния и метки для точек относительно центроидов
const labels = Array.from({length: x_points.length}, () => 0);
for (let i = 0; i < x_points.length; i++) {
const dists = [
Math.hypot(x_points[i] - centroids_x[0], y_points[i] - centroids_y[0]),
Math.hypot(x_points[i] - centroids_x[1], y_points[i] - centroids_y[1])
];
labels[i] = dists[0] < dists[1] ? 0 : 1;
}
// Вычисляем новые центроиды
const new_centroids_x = [0, 0];
const new_centroids_y = [0, 0];
const counts = [0, 0];
for (let i = 0; i < labels.length; i++) {
const label = labels[i];
new_centroids_x[label] += x_points[i];
new_centroids_y[label] += y_points[i];
counts[label]++;
}
// Избегаем деления на ноль и вылетов на бесконечность
for (let j = 0; j < 2; j++) {
if (counts[j] > 0) {
new_centroids_x[j] /= counts[j];
new_centroids_y[j] /= counts[j];
} else {
new_centroids_x[j] = centroids_x[j];
new_centroids_y[j] = centroids_y[j];
}
}
// Периодическое обновление источников данных для линий
lines['x_start'] = [];
lines['y_start'] = [];
lines['x_end'] = [];
lines['y_end'] = [];
lines['color'] = [];
for (let i = 0; i < labels.length; i++) {
const color = labels[i] === 0 ? "blue" : "green";
lines['x_start'].push(x_points[i]);
lines['y_start'].push(y_points[i]);
lines['x_end'].push(new_centroids_x[labels[i]]);
lines['y_end'].push(new_centroids_y[labels[i]]);
lines['color'].push(color);
}
source_lines.change.emit();
// Плавное перемещение центроидов к новым позициям
const smooth_factor = 0.02; // Уменьшено для более плавного перемещения
centroids_x[0] += smooth_factor * (new_centroids_x[0] - centroids_x[0]);
centroids_x[1] += smooth_factor * (new_centroids_x[1] - centroids_x[1]);
centroids_y[0] += smooth_factor * (new_centroids_y[0] - centroids_y[0]);
centroids_y[1] += smooth_factor * (new_centroids_y[1] - centroids_y[1]);
// Обновление цвета точек в зависимости от их кластера
const colors = Array.from(labels, label => label === 0 ? "blue" : "green");
data_points['color'] = colors;
source_points.change.emit();
source_centroids.change.emit();
}
function animate() {
if (window.intervalID) {
clearInterval(window.intervalID); // Сброс текущей анимации, если она есть
}
window.intervalID = setInterval(update, 30); // Более частое обновление для плавности
}
animate(); // Старт анимации при загрузке
""")
# Кнопка для запуска анимации
button_start = Button(label="Start K-means Animation", button_type="success")
button_start.js_on_click(callback)
# Кнопка для рестарта
button_restart = Button(label="Restart Animation", button_type="warning")
button_restart.js_on_click(CustomJS(args=dict(source_points=source_points, source_centroids=source_centroids, source_lines=source_lines), code="""
// Генерация новых случайных позиций центроидов
source_centroids.data['x'] = [1 + Math.random() * 5, 1 + Math.random() * 5];
source_centroids.data['y'] = [1 + Math.random() * 5, 1 + Math.random() * 5];
source_centroids.change.emit();
// Очистка данных линий
source_lines.data['x_start'] = [];
source_lines.data['y_start'] = [];
source_lines.data['x_end'] = [];
source_lines.data['y_end'] = [];
source_lines.data['color'] = [];
source_lines.change.emit();
// Немедленное перекрашивание точек
const data_points = source_points.data;
const x_points = data_points['x'];
const y_points = data_points['y'];
const centroids_x = source_centroids.data['x'];
const centroids_y = source_centroids.data['y'];
const labels = Array.from({length: x_points.length}, () => 0);
for (let i = 0; i < x_points.length; i++) {
const dists = [
Math.hypot(x_points[i] - centroids_x[0], y_points[i] - centroids_y[0]),
Math.hypot(x_points[i] - centroids_x[1], y_points[i] - centroids_y[1])
];
labels[i] = dists[0] < dists[1] ? 0 : 1;
}
const colors = Array.from(labels, label => label === 0 ? "blue" : "green");
data_points['color'] = colors;
source_points.change.emit();
if (window.intervalID) {
clearInterval(window.intervalID); // Сброс текущей анимации
}
"""))
# Располагаем кнопки в одном ряду и над графиком
layout = column(row(button_start, button_restart), p)
show(layout)