Проблема: я загружаю простой VGG16 из сохраненной контрольной точки. Я хочу создать заметность изображения во время вывода. Когда я вычисляю градиенты (потери относительно входного изображения), необходимые для этого, я возвращаю все градиенты как ноль. Любые идеи относительно того, что мне здесь не хватает, очень ценятся!
Версия tf: tensorflow-2.0alpha-gpu
Модель:
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16 as KerasVGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Flatten, Dense
class VGG16(Model):
def __init__(self, num_classes, use_pretrained=True):
super(VGG16, self).__init__()
self.num_classes = num_classes
self.use_pretrained = use_pretrained
if use_pretrained:
self.base_model = KerasVGG16(weights='imagenet', include_top=False)
for layer in self.base_model.layers:
layer.trainable = False
else:
self.base_model = KerasVGG16(include_top=False)
self.flatten1 = Flatten(name='flatten')
self.dense1 = Dense(4096, activation='relu', name='fc1')
self.dense2 = Dense(100, activation='relu', name='fc2')
self.dense3 = Dense(self.num_classes, activation='softmax', name='predictions')
def call(self, inputs):
x = self.base_model(tf.cast(inputs, tf.float32))
x = self.flatten1(x)
x = self.dense1(x)
x = self.dense2(x)
x = self.dense3(x)
return x
Я обучаю эту модель, сохраняю ее на контрольной точке и загружаю обратно через:
model = VGG16(num_classes=2, use_pretrained=False)
checkpoint = tf.train.Checkpoint(net=model)
status = checkpoint.restore(tf.train.latest_checkpoint('./my_checkpoint'))
status.assert_consumed()
Я проверяю, правильно ли загружены грузы.
Получите тестовое изображение
# load my image and make sure its float
img = tf.convert_to_tensor(image, dtype=tf.float64)
support_class = tf.convert_to_tensor(support_class, dtype=tf.float64)
Получите градиенты:
with tf.GradientTape(persistent=True) as g_tape:
g_tape.watch(img)
#g_tape.watch(model.base_model.trainable_variables)
#g_tape.watch(model.trainable_variables)
loss = tf.losses.CategoricalCrossentropy()(support_class, model(img))
gradients_wrt_image = g_tape.gradient(loss,
img, unconnected_gradients=tf.UnconnectedGradients.NONE)
Когда я проверяю свои градиенты, они все равны нулю! Есть идеи, что мне не хватает? Заранее спасибо!
support_class
? - person Vlad   schedule 08.04.2019grads = g_tape.gradient(loss, img, unconnected_gradients=tf.UnconnectedGradients.NONE); print(tf.reduce_sum(grads, axis=None))
не выводит ноль. Может быть близко к нулю, но не к нулю - person Vlad   schedule 08.04.2019image = [np.random.normal(size=(32, 32, 3))]
. - person Vlad   schedule 08.04.2019