У вас есть изображение и вы хотите найти похожие изображения в папке с изображениями или в базе данных? Поиск изображений может быть очень полезен во многих приложениях. Например, поисковая система с функцией поиска изображений или помощь людям в поиске изображения, которое они ищут, но не совсем уверены, с чего начать.

В этой статье мы будем использовать ResNet50 и K-Nearest Neighbours для разработки классификатора K-NN. Мы используем набор данных Cifar-100, собранный Алексом Крижевским, Винодом Наиром и Джеффри Хинтоном. (Изучение нескольких слоев признаков из крошечных изображений, Алекс Крижевский, 2009 г.).

Для демонстрации мы создали меньший набор данных из исходного набора данных, случайным образом выбрав 1 изображение из каждого класса в тестовом наборе и 5 изображений из каждого класса в обучающем наборе. Таким образом, у нас есть набор данных из 500 обучающих изображений в качестве кандидатов и 100 тестовых изображений. Структура папок выглядит следующим образом:

-cifar100small
--train
---apple
---baby
---bear
...
--test
---apple
---baby
---bear
...

Эта структура папок упрощает загрузку данных с помощью API-интерфейса torchvision ImageFolder.

Теперь давайте начнем с захватывающей части!

  1. Загрузить данные
from torchvision import datasets, transforms
train_image_dir = "cifar100small/train" # path to train images
test_image_dir = "cifar100small/test" # path to test images
transform = transforms.Compose([
    transforms.Resize((224, 224,)),
    transforms.ToTensor(),
])

# Train dataset and dataloader (our candidates)
candidates_dataset = datasets.ImageFolder(train_image_dir, transform=transform)
candidates_loader = torch.utils.data.DataLoader(candidates_dataset, batch_size=64, shuffle=False, num_workers=2)
candidates_path = candidates_dataset.imgs

# Test dataset and dataloader
test_dataset = datasets.ImageFolder(test_image_dir, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

2. Загрузите предварительно обученную модель ResNet

import torch
from torchvision import models
# Use gpu if available 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use pretrained ResNet model
net = models.resnext50_32x4d(pretrained=True)
# Create a new network without the fully connected layers
net = torch.nn.Sequential(*(list(net.children())[:-1]))
net = net.to(device)

3. Рассчитайте характеристики кандидатов

import numpy as np
candidates = []
for data in candidates_loader:
    image, label = data
    image = image.to(device)
    output = net(image)
    output = output.detach().cpu().numpy().reshape(-1,
    np.prod(output.size()[1:]))
    candidates.append(output)
candidates = np.concatenate(candidates)

4. Соответствуйте кандидатам

from sklearn.neighbors import NearestNeighbors
k = 10
knn = NearestNeighbors(n_neighbors=k, metric="cosine")
knn.fit(candidates)

5. Начните прогнозировать!

def predict(test_loader):
    res = []
    for i, data in enumerate(test_loader):
        image, label = data
        image = image.to(device)
        output = net(image)
        output = output.detach().cpu().numpy().reshape(-1,
        np.prod(output.size()[1:]))
        
        _, neighbors_idx = knn.kneighbors(output)
        res.append((test_dataset.imgs[i][0], neighbors_idx[0],))
    return res
res = predict(test_loader)

Чтобы визуализировать результаты, вы можете использовать эту вспомогательную функцию:

import matplotlib.pyplot as plt
from PIL import Image
def plot_result(query_image, candidates_idx, candidates_path):
    plt.figure(figsize=(20,10))
    columns = 5
    # show query image
    image = Image.open(query_image)
    image = np.asarray(image)
    ax = plt.subplot(3, columns, 1)
    ax.set_title(f"Query: {query_image.split('/')[-1]}")
    plt.imshow(image)
    for i, idx in enumerate(candidates_idx):
        image_path = candidates_path[idx][0]
        image = Image.open(image_path)
        image = np.asarray(image)
        ax = plt.subplot(3, columns, i + 1 + 5)
        ax.set_title(image_path.split("/")[-1])
    plt.imshow(image)

А затем постройте некоторые результаты!

test_image_1, candidates_idx = res[5]
plot_result(test_image_1, candidates_idx, candidates_path)

Давайте посмотрим на некоторые результаты:

Пример 1: верблюд

Пример 2: кресло

Пример 3:

Пример 4:

Вуаля! Результаты выглядят неплохо, так как ближайшие соседи на самом деле очень похожи на изображение запроса. Конечно, в некоторых случаях это работает хуже, например, в примере с розовой кроватью. Но у него хотя бы есть диван, как у такого же соседа, который чем-то похож на кровать 😂

Если вам нравятся мои статьи, не забудьте поставить им лайк 👏 и поделиться ими с друзьями и коллегами! Ваше здоровье! 😉