Методы кластеризации#

K-means#

Идея проста - фиксируем количество кластеров \(k\). По итогу, каждый \(x_i\) должен быть в том кластере, к центру масс которого он ближе всего. Алгоритм:

Алгоритм k-means

  • Инициализируем центроиды (центры масс) кластеров \(\mu_1, \dots, \mu_k \in \mathbb{R}^n\).

  • Повторять до сходимости:

    • Для каждой точки \(x^{(i)}\) находим ближайший к ней центр масс

\[c^{(i)} = \arg \underset{j}{\min}{\Vert x^{(i)} - \mu_j \Vert^2}\]
  • Для каждого \(j\) переопределяем центры масс исходя из ближайших точек на прошлом шаге

\[\mu_j = \frac{\sum_{i=1}^{N} 1\{c^{(i)} = j\} \cdot x^{(i)}}{\sum_{i=1}^{N} 1\{c^{(i)} = j\}}\]

В алгоритме \(k\) - это параметр, обозначающий количество кластеров, фиксируется изначально. Центроиды \(\mu_j\) представляют собой наши текущие предположения о положении центров кластеров. Инициализация центроидов может происходить различными способами, например, мы можем случайным образом выбрать \(k\) точек из обучающего множества.

Шаги алгоритма:

Three different assumptions about the number of clusters

Алгоритм очень прост и, как правило, работает хорошо на данных, где «истинные» кластеры далеки друг от друга. Тем не менее, он будет плохо работать, если «истинные» кластеры находятся близко к друг другу, или далеки по своей форме от шара.

Hide code cell source
from bokeh.io import output_notebook, show
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, Button, CustomJS
from bokeh.layouts import column, row
import numpy as np

from bokeh.resources import INLINE
output_notebook(INLINE, hide_banner = True)

# Создаем начальные данные
np.random.seed(0)
cluster1 = np.random.normal([2.5, 2.5], 0.5, (50, 2))
cluster2 = np.random.normal([4.5, 4.5], 0.5, (50, 2))
data = np.vstack([cluster1, cluster2])

# Инициализируем центроиды с фиксированными начальными значениями
centroids = np.array([[2. + 2*_, 3. + 2*_] for _ in range(2)])

# Создаем источники данных для точек, центроидов и линий
source_points = ColumnDataSource(data=dict(x=data[:, 0], y=data[:, 1], color=["blue"] * 50 + ["green"] * 50))
source_centroids = ColumnDataSource(data=dict(x=centroids[:, 0], y=centroids[:, 1]))
source_lines = ColumnDataSource(data=dict(x_start=[], y_start=[], x_end=[], y_end=[], color=[]))

# Создаем фигуру
p = figure(width=400, height=400, title="K-means Clustering", x_range=(1, 6), y_range=(1, 6),
           tools="", toolbar_location=None)  # Отключаем toolbar и инструменты
p.scatter('x', 'y', color='color', source=source_points, size=6, alpha=0.6)
p.scatter('x', 'y', color="red", source=source_centroids, size=8, alpha=0.8)
p.segment('x_start', 'y_start', 'x_end', 'y_end', color='color', source=source_lines, line_dash="dotted", alpha=0.5)

# JavaScript для запуска анимации
callback = CustomJS(args=dict(source_points=source_points, source_centroids=source_centroids, source_lines=source_lines), code="""
    const data_points = source_points.data;
    const data_centroids = source_centroids.data;
    const lines = source_lines.data;
    const x_points = data_points['x'];
    const y_points = data_points['y'];
    let centroids_x = data_centroids['x'];
    let centroids_y = data_centroids['y'];

    function update() {
        // Расстояния и метки для точек относительно центроидов
        const labels = Array.from({length: x_points.length}, () => 0);
        for (let i = 0; i < x_points.length; i++) {
            const dists = [
                Math.hypot(x_points[i] - centroids_x[0], y_points[i] - centroids_y[0]),
                Math.hypot(x_points[i] - centroids_x[1], y_points[i] - centroids_y[1])
            ];
            labels[i] = dists[0] < dists[1] ? 0 : 1;
        }

        // Вычисляем новые центроиды
        const new_centroids_x = [0, 0];
        const new_centroids_y = [0, 0];
        const counts = [0, 0];

        for (let i = 0; i < labels.length; i++) {
            const label = labels[i];
            new_centroids_x[label] += x_points[i];
            new_centroids_y[label] += y_points[i];
            counts[label]++;
        }

        // Избегаем деления на ноль и вылетов на бесконечность
        for (let j = 0; j < 2; j++) {
            if (counts[j] > 0) {
                new_centroids_x[j] /= counts[j];
                new_centroids_y[j] /= counts[j];
            } else {
                new_centroids_x[j] = centroids_x[j];
                new_centroids_y[j] = centroids_y[j];
            }
        }

        // Периодическое обновление источников данных для линий
        lines['x_start'] = [];
        lines['y_start'] = [];
        lines['x_end'] = [];
        lines['y_end'] = [];
        lines['color'] = [];
        for (let i = 0; i < labels.length; i++) {
            const color = labels[i] === 0 ? "blue" : "green";
            lines['x_start'].push(x_points[i]);
            lines['y_start'].push(y_points[i]);
            lines['x_end'].push(new_centroids_x[labels[i]]);
            lines['y_end'].push(new_centroids_y[labels[i]]);
            lines['color'].push(color);
        }
        source_lines.change.emit();

        // Плавное перемещение центроидов к новым позициям
        const smooth_factor = 0.02;  // Уменьшено для более плавного перемещения
        centroids_x[0] += smooth_factor * (new_centroids_x[0] - centroids_x[0]);
        centroids_x[1] += smooth_factor * (new_centroids_x[1] - centroids_x[1]);
        centroids_y[0] += smooth_factor * (new_centroids_y[0] - centroids_y[0]);
        centroids_y[1] += smooth_factor * (new_centroids_y[1] - centroids_y[1]);

        // Обновление цвета точек в зависимости от их кластера
        const colors = Array.from(labels, label => label === 0 ? "blue" : "green");
        data_points['color'] = colors;

        source_points.change.emit();
        source_centroids.change.emit();
    }

    function animate() {
        if (window.intervalID) {
            clearInterval(window.intervalID);  // Сброс текущей анимации, если она есть
        }
        window.intervalID = setInterval(update, 30);  // Более частое обновление для плавности
    }

    animate();  // Старт анимации при загрузке
""")

# Кнопка для запуска анимации
button_start = Button(label="Start K-means Animation", button_type="success")
button_start.js_on_click(callback)

# Кнопка для рестарта
button_restart = Button(label="Restart Animation", button_type="warning")
button_restart.js_on_click(CustomJS(args=dict(source_points=source_points, source_centroids=source_centroids, source_lines=source_lines), code="""
    // Генерация новых случайных позиций центроидов
    source_centroids.data['x'] = [1 + Math.random() * 5, 1 + Math.random() * 5];
    source_centroids.data['y'] = [1 + Math.random() * 5, 1 + Math.random() * 5];
    source_centroids.change.emit();

    // Очистка данных линий
    source_lines.data['x_start'] = [];
    source_lines.data['y_start'] = [];
    source_lines.data['x_end'] = [];
    source_lines.data['y_end'] = [];
    source_lines.data['color'] = [];
    source_lines.change.emit();

    // Немедленное перекрашивание точек
    const data_points = source_points.data;
    const x_points = data_points['x'];
    const y_points = data_points['y'];
    const centroids_x = source_centroids.data['x'];
    const centroids_y = source_centroids.data['y'];

    const labels = Array.from({length: x_points.length}, () => 0);
    for (let i = 0; i < x_points.length; i++) {
        const dists = [
            Math.hypot(x_points[i] - centroids_x[0], y_points[i] - centroids_y[0]),
            Math.hypot(x_points[i] - centroids_x[1], y_points[i] - centroids_y[1])
        ];
        labels[i] = dists[0] < dists[1] ? 0 : 1;
    }

    const colors = Array.from(labels, label => label === 0 ? "blue" : "green");
    data_points['color'] = colors;
    source_points.change.emit();

    if (window.intervalID) {
        clearInterval(window.intervalID);  // Сброс текущей анимации
    }
"""))

# Располагаем кнопки в одном ряду и над графиком
layout = column(row(button_start, button_restart), p)
show(layout)