Введение

K-ближайшие соседи (KNN) — это контролируемый алгоритм машинного обучения, используемый для выполнения задач классификации и регрессии.

KNN предсказывает правильный класс для тестовых данных, определяя расстояние между тестовыми данными и точками обучения. Алгоритм выбирает K количество точек, наиболее близких к тестовым данным. Затем он вычисляет вероятность того, что тестовые данные попадут в классы K. Выбирается класс с наибольшей вероятностью.

Алгоритм KNN называется алгоритмом ленивого обучения, поскольку он не изучает дискриминационную функцию из обучающего набора данных, а вместо этого запоминает обучающий набор данных.

Набор данных

Мы будем использовать набор данных клиентов, чтобы сделать прогноз для неизвестных данных. Используемый набор данных показывает возраст, доход и продукт, купленный покупателем, как показано ниже:

Найдем рекомендации для клиента возрастом 58 лет и доходом 51000.

Сначала мы рассчитаем расстояние, а затем, в зависимости от значения k, сможем получить ближайших k соседей.

По умолчанию значение k равно 1, но мы можем передать значение k во время создания экземпляра. Если k равно 1, он покажет 1 продукт, то есть 1 ближайшего соседа, а если k равен 2, он покажет 2 продукта, то есть 2 ближайших соседа.

Реализация алгоритма K-ближайших соседей в Java

Теперь мы хотим реализовать алгоритм KNN в java, используя приведенный выше набор данных. Набор данных был сохранен в файле CSV с именем customers.csv.

Мы прочитаем данные из CSV-файла и загрузим их в GridDB. Затем данные будут извлечены из GridDB для анализа с помощью алгоритма.

Импорт пакетов

Давайте сначала импортируем пакеты, которые нам нужно будет использовать:

import java.io.IOException;
import java.util.Collection;
import java.util.Properties;
import java.util.Scanner;
import java.io.File;
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;

Запись данных в GridDB

Мы хотим переместить данные из файла CSV в контейнер GridDB. Во-первых, давайте создадим схему контейнера как статический класс:

public static class Customers{
    
         @RowKey int customer;
         int age;
         Double income;
         String purchased_product;
        }

Приведенный выше класс подобен контейнеру или таблице SQL с четырьмя столбцами.

Давайте установим соединение с GridDB. Мы создадим экземпляр свойств, используя особенности нашей установки GridDB. Используйте следующий код:

Properties props = new Properties();
        props.setProperty("notificationAddress", "239.0.0.1");
        props.setProperty("notificationPort", "31999");
        props.setProperty("clusterName", "defaultCluster");
        props.setProperty("user", "admin");
        props.setProperty("password", "admin");
        GridStore store = GridStoreFactory.getInstance().getGridStore(props);

Измените приведенные выше данные, чтобы отразить специфику вашей установки GridDB.

Давайте выберем контейнер Customers, так как мы будем его использовать:

Collection<String, Customers> coll = store.putCollection("col01", Customers.class);

Создан экземпляр контейнера Customers, которому присвоено имя coll. Мы будем использовать этот экземпляр для ссылки на контейнер.

Храните данные в GridDB

Мы можем использовать следующий код Java для чтения данных из файла customers.csv и сохранения их в GridDB:

File file1 = new File("customers.csv");
                Scanner sc = new Scanner(file1);
                String data = sc.next();
 
                while (sc.hasNext()){
                        String scData = sc.next();
                        String dataList[] = scData.split(",");
                        String customer = dataList[0];
                        String age = dataList[1];
                        String income = dataList[2];
                        String purchased_product = dataList[3];
                        
                        
                        Customers customers = new Customers();
    
                        customers.customer = Integer.parseInt(customer);
                        customers.age = Integer.parseInt(age);
                        customers.income = Double.parseDouble(income);
                        customers.purchased_product = purchased_product;
                        coll.append(customers);
                 }

Мы создали объект customers с данными о покупателях. Затем объект был добавлен в контейнер GridDB.

Получить данные из GridDB

Пришло время извлечь данные из контейнера GridDB. Используйте следующий код:

Query<customers> query = coll.query("select *");
                RowSet</customers><customers> rs = query.fetch(false);
            RowSet res = query.fetch();</customers>

Оператор select * помогает нам запрашивать все данные из контейнера базы данных.

Построить классификатор

Пришло время создать классификатор с использованием алгоритма KNN и загруженных данных. Давайте импортируем библиотеки, которые будут использоваться для этого:

import java.io.IOException;
import java.util.Enumeration;
import java.text.DecimalFormat;
import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.classifiers.lazy.IBk;
import weka.classifiers.Evaluation;
import weka.core.Instance;
import weka.core.converters.ArffLoader;

Давайте теперь построим модель и распечатаем ее статистику:

res.setClassIndex(res.numAttributes() - 1);
        Classifier cls = new IBk(1);        
        cls.buildClassifier(res);
    
        System.out.println(cls);
       
        Evaluation evaluation = new Evaluation(res);
        evaluation.evaluateModel(cls, res);
        
        System.out.println(evaluation.toSummaryString());
        System.out.println(evaluation.toClassDetailsString());
        System.out.println(evaluation.toMatrixString());

Мы указали значение k при создании экземпляра IBk. Экземпляр IBk принимает целочисленный аргумент. Если вы передадите ему значение 1, он найдет 1 ближайшего соседа. Если вы передадите 2, он вычислит 2 ближайших соседей. Если вы не передадите никакого аргумента и вызовете его с помощью конструктора по умолчанию, он вычислит 1 ближайшего соседа. В нашем случае мы передали значение 1, поэтому мы предскажем 1 ближайшего соседа для клиента.

Скомпилируйте и запустите код

Во-первых, войдите в систему как пользователь gsadm. Переместите файл .java в папку bin вашей GridDB, расположенную по следующему пути:

/griddb_4.6.0–1_amd64/usr/griddb-4.6.0/bin

Затем выполните следующую команду на своем терминале Linux, чтобы указать путь к файлу gridstore.jar:

export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar

Затем выполните следующую команду, чтобы скомпилировать файл .java:

javac KNNeighbor.java

Запустите файл .class, созданный с помощью следующей команды:

java KNNeighbor

Модель KNN вернет 1 ближайшего соседа для клиента.

Первоначально опубликовано на https://griddb.net 15 декабря 2021 г.