Слой Lambda в Keras с keras.backend.one_hot дает TypeError

Я пытаюсь обучить CNN уровня персонажа, используя Keras. Я принимаю в качестве входных данных одно слово. Я уже преобразовал слова в списки индексов, но когда я пытаюсь передать их в one_hot, я получаю TypeError.

>>> X_train[0]
array([31, 14, 23, 29, 27, 18, 12, 30, 21, 10, 27,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0], dtype=uint8)
>>> X_train.shape
(2226641, 98)

Но когда я пытаюсь создать свою модель следующим образом:

k_model = Sequential()
k_model.add(Lambda(K.one_hot, arguments={'num_classes': 100}, input_shape=(98,), output_shape=(98,100)))
k_model.add(Conv1D(filters=16, kernel_size=5, strides=1, padding='valid'))

Я получаю TypeError: Value passed to parameter 'indices' has DataType float32 not in list of allowed values: uint8, int32, int64.

Очевидно, что он не доходит до того момента, когда X_train даже читается, так где же он получает значение с плавающей запятой?

Я хотел бы иметь форму экземпляра (98, 100), где 100 — количество классов.

Я не могу уместить весь набор данных в памяти.


person bendl    schedule 21.05.2018    source источник


Ответы (1)


Я бы предложил более чистое решение, которое даст тот же результат, как насчет:

k_model.add(Embedding(num_classes, num_classes,
                      embeddings_initializer='identity',
                      trainable=False,
                      name='onehot'))

По сути, вы встраиваете вещи, было бы разумнее использовать один с фиксированным весом. Это также дает вам гибкость, чтобы сделать встраивание обучаемым в будущем.

person nuric    schedule 21.05.2018
comment
Теперь я получаю сообщение об ошибке, когда пытаюсь тренироваться, говоря ValueError: Error when checking target: expected onehot to have 3 dimensions, but got array with shape (742214, 98) - person bendl; 21.05.2018
comment
Я также не вижу ничего о weights в документах Embedding. У вас есть ссылка на это? - person bendl; 21.05.2018
comment
К сожалению, он является частью базового слоя класс и явно не упоминается в документах. - person nuric; 21.05.2018
comment
благодарю за разъяснение. Есть ли у вас какие-либо предложения по поводу новой ошибки? - person bendl; 21.05.2018
comment
попробуйте добавить input_length=98 Я не знал, что вы используете фиксированную длину. - person nuric; 21.05.2018
comment
Если вы согласны со мной, репозиторий находится здесь - person bendl; 21.05.2018
comment
Внедрение выводит (образцы, 98, num_classes), но ваши данные Y являются 2d (образцы, 98), поэтому цель не соответствует, вы еще не обрабатываете ввод и ожидаете, что он предскажет цель. - person nuric; 21.05.2018