Как использовать алгоритм K-средних для автоматического сегментирования изображения

До сих пор большинство методов, которые мы рассмотрели, требовали, чтобы мы вручную сегментировали изображение по его характеристикам. Но на самом деле мы можем использовать алгоритмы неконтролируемой кластеризации, чтобы сделать это за нас. В этой статье мы рассмотрим, как это сделать.

Давай начнем!

Как всегда, мы начинаем с импорта необходимых библиотек Python.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import colors
from skimage.color import rgb2gray, rgb2hsv, hsv2rgb
from skimage.io import imread, imshow
from sklearn.cluster import KMeans

Отлично, давайте теперь импортируем изображение, с которым мы будем работать.

dog = imread('beach_doggo.PNG')
plt.figure(num=None, figsize=(8, 6), dpi=80)
imshow(dog);

Мы знаем, что изображение, по сути, представляет собой трехмерную матрицу, в которой каждый отдельный пиксель содержит значение для красного, зеленого и синего каналов. Но на самом деле мы можем использовать любимую библиотеку Pandas для хранения каждого пикселя как отдельной точки данных. Код ниже делает именно это.

def image_to_pandas(image):
    df = pd.DataFrame([image[:,:,0].flatten(),
                       image[:,:,1].flatten(),
                       image[:,:,2].flatten()]).T
    df.columns = [‘Red_Channel’,’Green_Channel’,’Blue_Channel’]
    return df
df_doggo = image_to_pandas(dog)
df_doggo.head(5)

Это упрощает манипулирование изображением, так как его легче рассматривать как данные, которые могут быть введены в алгоритм машинного обучения. В нашем случае мы будем использовать алгоритм K Means для кластеризации изображения.

plt.figure(num=None, figsize=(8, 6), dpi=80)
kmeans = KMeans(n_clusters=  4, random_state = 42).fit(df_doggo)
result = kmeans.labels_.reshape(dog.shape[0],dog.shape[1])
imshow(result, cmap='viridis')
plt.show()

Как мы видим, изображение сгруппировано в 4 отдельных региона. Визуализируем каждый регион отдельно.

fig, axes = plt.subplots(2,2, figsize=(12, 12))
for n, ax in enumerate(axes.flatten()):
    ax.imshow(result==[n], cmap='gray');
    ax.set_axis_off()
    
fig.tight_layout()

Как мы видим, алгоритм разбивает изображение на основе значений пикселей R, G и B. Одним из прискорбных недостатков, конечно же, является то, что это полностью неконтролируемый алгоритм обучения. Его не особо заботит значение какого-либо конкретного кластера. В качестве доказательства мы можем видеть, что и второй, и четвертый кластеры имеют заметную часть собаки (заштрихованную половину и незатененную половину). Возможно, запуск 4 кластеров является чрезмерным, давайте повторим кластеризацию, но установим количество кластеров на 3.

Отлично, мы видим, что собака выходит как единое целое. Теперь давайте посмотрим, что произойдет, если мы применим каждый кластер как отдельную маску к нашему изображению.

fig, axes = plt.subplots(1,3, figsize=(15, 12))
for n, ax in enumerate(axes.flatten()):
    dog = imread('beach_doggo.png')
    dog[:, :, 0] = dog[:, :, 0]*(result==[n])
    dog[:, :, 1] = dog[:, :, 1]*(result==[n])
    dog[:, :, 2] = dog[:, :, 2]*(result==[n])
    ax.imshow(dog);
    ax.set_axis_off()
fig.tight_layout()

Мы видим, что алгоритм генерирует три отдельных кластера: песок, живые существа и небо. Конечно, сам алгоритм не заботится об этих кластерах, а только о том, что они имеют одинаковые значения RGB. Мы, люди, должны интерпретировать эти кластеры.

Прежде чем мы уйдем, я думаю, было бы полезно показать, как выглядит наше изображение, если бы мы просто изобразили его на трехмерном графике.

def pixel_plotter(df):
    x_3d = df['Red_Channel']
    y_3d = df['Green_Channel']
    z_3d = df['Blue_Channel']
    
    color_list = list(zip(df['Red_Channel'].to_list(),
                          df['Blue_Channel'].to_list(),
                          df['Green_Channel'].to_list()))
    norm = colors.Normalize(vmin=0,vmax=1.)
    norm.autoscale(color_list)
    p_color = norm(color_list).tolist()
    
    fig = plt.figure(figsize=(12,10))
    ax_3d = plt.axes(projection='3d')
    ax_3d.scatter3D(xs = x_3d, ys =  y_3d, zs = z_3d, 
                    c = p_color, alpha = 0.55);
    
    ax_3d.set_xlim3d(0, x_3d.max())
    ax_3d.set_ylim3d(0, y_3d.max())
    ax_3d.set_zlim3d(0, z_3d.max())
    ax_3d.invert_zaxis()
    
    
    ax_3d.view_init(-165, 60)
pixel_plotter(df_doggo)

Следует иметь в виду, что именно так алгоритм определяет «близость». Если мы применим к этому графику алгоритм K-средних, то способ, которым он сегментирует изображение, станет поразительно ясным.

df_doggo['cluster'] = result.flatten()
def pixel_plotter_clusters(df):
    x_3d = df['Red_Channel']
    y_3d = df['Green_Channel']
    z_3d = df['Blue_Channel']
    
    fig = plt.figure(figsize=(12,10))
    ax_3d = plt.axes(projection='3d')
    ax_3d.scatter3D(xs = x_3d, ys =  y_3d, zs = z_3d, 
                    c = df['cluster'], alpha = 0.55);
    
    ax_3d.set_xlim3d(0, x_3d.max())
    ax_3d.set_ylim3d(0, y_3d.max())
    ax_3d.set_zlim3d(0, z_3d.max())
    ax_3d.invert_zaxis()
    
    
    ax_3d.view_init(-165, 60)
pixel_plotter_clusters(df_doggo)

В заключение

Алгоритм K-средних - это популярный алгоритм обучения без учителя, который должен быть комфортно использовать любому специалисту по данным. Хотя это довольно упрощенно, он может быть особенно эффективным для изображений, которые имеют очень четкие различия в пикселях. В следующих статьях мы рассмотрим другие алгоритмы машинного обучения, которые мы можем использовать для сегментации изображений, а также для точной настройки гиперпараметров. Но пока я надеюсь, что вы теперь можете представить, как использовать этот метод в своих собственных задачах.