Часть 2

Добро пожаловать во вторую часть развертывания модели Tensorflow на Android. Предварительным условием для этого руководства является часть 1. Если вы еще не проверяли его, сделайте это, прежде чем продолжить, поскольку мы будем использовать код из части 1. В этой части серии мы сосредоточимся на создании модели нейронной сети.

Мы будем использовать TensorFlow для написания нашей модели нейронной сети. Существуют варианты архитектуры нейронной сети, но я буду использовать нейронную сеть с прямой связью. Если вы не знакомы с концепцией нейронной сети, обратитесь к этому сообщению. Нейронная сеть прямого распространения - это искусственная нейронная сеть, в которой связи между узлами (нейронами) не образуют цикла. Поток данных идет в прямом направлении. Простейший вид нейронной сети - это сеть однослойного персептрона, которая состоит из одного слоя выходных узлов; входы подаются непосредственно на выходы через серию весов. Таким образом, ее можно рассматривать как простейшую сеть с прямой связью.

Не теряя времени, давайте создадим собственные перцептроны. Итак, откройте новый файл с именем ‘neuralNetworkModel.py’ и добавьте следующий импорт. Убедитесь, что весь ваш импорт работает.

Наша первая функция - получить данные обучения из «dataHelpers.py». Эта функция просто импортирует данные и регистрирует любые ошибки, если таковые имеются.

Мы уже преобразовали весь набор классов объектов из текста в целое число, но наши метки (Y) по-прежнему имеют категориальный характер. Некоторые алгоритмы машинного обучения по своей сути хорошо работают с одним горячим кодированием, поэтому мы конвертируем наши метки в один горячий вектор представления. Длина вектора равна количеству классов (в нашем случае 6). Каждый столбец вектора представляет собой уникальную метку. Предположим, что если метка класса равна 4, один горячий вектор будет иметь «1» в столбце «4», а остальные будут равны нулю.

Функция «getBatch» просто возвращает часть набора данных после преобразования меток в одни горячие векторы. Мы увидим, почему мы возвращаем пакеты данных при обучении нашей нейронной сети.

Мы закончили с вспомогательными функциями, которые нам нужны для нашей модели нейронной сети, и осталось только определить нашу модель. Наша модель нейронной сети будет простой сетью прямого распространения с двумя скрытыми слоями. Каждый скрытый слой будет иметь 800 узлов. Не стесняйтесь изменять количество узлов в скрытых слоях. Вы также можете добавить больше скрытых слоев без особых изменений, за исключением определения нового скрытого слоя. Эти изменения повлияют на точность нашей модели, поэтому поиск оптимального количества узлов и скрытых слоев может значительно улучшить производительность нашей модели. Помимо скрытых слоев, есть входной слой и выходной слой. Количество узлов во входном слое равно длине нашего вектора признаков, а количество узлов в выходном слое равно 6 (количество классов). Есть еще несколько параметров модели, таких как скорость обучения, количество циклов обучения, объем данных, которые необходимо вводить в модель на каждой итерации (размер пакета). Вы можете изменить это значение, чтобы увидеть, какой из них дает лучший результат (оптимальная скорость обучения 0,01).

Мы определяем два заполнителя: «inputTensor» и «outputTensor». Они используются для передачи данных в нейронную сеть во время обучения. InputTensor предназначен для подачи векторов признаков, а «outputTensor» - для подачи одного вектора горячей метки. Затем мы определяем наши веса и смещения (значение присваивается случайным образом) для каждого скрытого слоя. Например, если наш входной слой имел 10 узлов, а первый скрытый слой имел 5 узлов, то общее количество весов, назначенных между входным слоем и скрытым слоем, было бы 50. То есть каждый узел слоя связан со всеми узлами следующий слой.

Далее мы создадим наши перцептроны. Это простые линейные уравнения вида «f (x) = Wx + b», где W - вес, b - смещение, а x - входные данные (вектор признаков). Мы будем использовать функцию активации reLu на выходе «f (x)». В простейшем виде функция активации используется для масштабирования значения «f (x)» до 0 или 1 (вывод зависит от типа используемой функции. ). Порог используется для определения вывода, то есть, если «f (x)» выше определенного значения, на выходе будет 1, иначе 0. «output» - это слой, на котором мы строим наши прогнозы. . У него есть активация softmax, которая масштабирует вывод на каждом узле вывода от 0 до 1. После определения нашей функции потерь и оптимизатора мы можем начать обучение модели. Поскольку мы хотим развернуть эту модель на Android, мы хотели бы сохранить модель. Tensorflow использует графики для представления ваших вычислений с точки зрения зависимостей между отдельными операциями. Узлы и ребра графа показывают, как отдельные операции составляются вместе. Мы сохраняем график, используя функцию тензорного потока в строке 42.

Аналогичным образом класс Saver добавляет операции для сохранения и восстановления переменных в контрольных точках и из них. Он также предоставляет удобные методы для запуска этих операций. Контрольные точки представляют собой двоичные файлы в формате, который сопоставляет имена переменных со значениями тензора. Эти файлы контрольных точек содержат значения веса. Строки 40–60 относятся к обучению моделей. Во время обучения мы не передаем в модель весь набор данных сразу, а делаем это партиями (большие партии, похоже, ухудшают качество модели). Следовательно, мы загружаем 100 точек данных за один раз с помощью функции «getBatch». Данные вводятся в заполнители, которые мы определили ранее. К концу обучения вы должны получить точность 60% (плохо, но пока работает).

Заполнитель - это просто переменная, которой мы назначим данные позже. Это позволяет нам создавать наши операции и строить наш вычислительный граф без необходимости в данных. В терминологии Tensorflow мы затем вводим данные в граф через эти заполнители.

Закончив обучение модели, мы ее замораживаем. Что мы подразумеваем под замораживанием? Из документации по тензорному потоку у нас есть веса, хранящиеся в отдельных файлах контрольных точек, и на графике есть Variable операций, которые загружают последние значения при их инициализации. Часто бывает не очень удобно иметь отдельные файлы при развертывании в производственной среде, поэтому существует сценарий «freeze_graph.py, который берет определение графика и набор контрольных точек и фиксирует их вместе в один файл.

Это загружает GraphDef, извлекает значения для всех переменных из последнего файла контрольной точки, а затем заменяет каждый Variable op на Const, который имеет числовые данные для весов, хранящихся в его атрибутах. Затем он удаляет все посторонние узлы, которые не используются для прямого вывода, и сохраняет результирующий GraphDef в выходной файл ». Выходной файл имеет собственный формат файла ProtoBufs. Для более глубокого понимания того, как работает сохранение модели, обратитесь к this отличной публикации .

Один важный момент, который следует отметить при сохранении графика, - это имя переменной для наших входных данных в модель и выходных данных модели. Обратите внимание на строку 90–92, где мы определяем вход во входной узел модели («inputTensor») и выходной узел («output»). Мы будем использовать те же имена переменных в приложении для Android, когда использовали модель для прогнозирования.

Итак, теперь мы определили нашу модель и другие вспомогательные функции. Теперь все, что нам нужно сделать, это вызвать эти функции.

В приведенном выше коде мы разбиваем данные на обучающие и тестовые наборы. Это простое разделение набора данных. Лучшим разделением было бы использование cross_validation из sklearn. Я оставляю это как упражнение. После выполнения всего кода у вас должно быть несколько файлов в вашем каталоге, которые включают вашу модель («optimizedTfModel.pb»), контрольные точки и т. Д. Нашим основным фокусом будут файлы protoBuf, которые будут использоваться в приложение для Android.

Это все для второй части серии. В следующей части мы рассмотрим, как настроить наше приложение для Android, загрузить необходимые библиотеки и закончить написание шаблона приложения для Android. Надеюсь, это было полезно, и следите за обновлениями части 3: D.