Уведомления

Группа в Telegram: присоединиться

#1 Май 31, 2020 23:43:44

Volodya
Зарегистрирован: 2020-02-13
Сообщения: 22
Репутация: +  0  -
Профиль   Отправить e-mail  

Импорт сохраненной в Python модели нейронной сети в Java

Здравствуйте!

Пытаюсь импортировать в Java обученную и сохраненную в Python модель нейронной сети.
Вы дает следующее исключение:

Exception in thread “main” java.lang.NoClassDefFoundError: org/deeplearning4j/nn/weights/IWeightInit
at org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDense.<init>(KerasDense.java:96)
at org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getKerasLayerFromConfig(KerasLayerUtils.java:220)
at org.deeplearning4j.nn.modelimport.keras.KerasModel.prepareLayers(KerasModel.java:218)
at org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel.<init>(KerasSequentialModel.java:110)
at org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel.<init>(KerasSequentialModel.java:57)
at org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder.buildSequential(KerasModelBuilder.java:322)
at org.deeplearning4j.nn.modelimport.keras.KerasModelImport.importKerasSequentialModelAndWeights(KerasModelImport.java:223)
at NeuralNetwork.main(NeuralNetwork.java:21)
Caused by: java.lang.ClassNotFoundException: org.deeplearning4j.nn.weights.IWeightInit
at java.net.URLClassLoader.findClass(URLClassLoader.java:382)
at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:349)
at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
… 8 more

При чем если смотреть через отладчик, то он вроде как считывает файл H5 (см. рисунок).

Модель нейронной сети, построенная и сохраненная в Python:
 model_fully_connected = Sequential()
model_fully_connected.add(keras.layers.Dense(17, activation='tanh', input_shape=(x_train.shape[1],), W_regularizer=l2(l2_lambda)))
model_fully_connected.add(keras.layers.Dense(17, activation='tanh', W_regularizer=l2(l2_lambda)))
model_fully_connected.add(keras.layers.LeakyReLU (alpha=0.1))
model_fully_connected.add(keras.layers.Dense(17, activation='tanh', W_regularizer=l2(l2_lambda)))
model_fully_connected.add(keras.layers.LeakyReLU (alpha=0.1))
model_fully_connected.add(keras.layers.Dense(17, activation='tanh', W_regularizer=l2(l2_lambda)))
model_fully_connected.add(keras.layers.Dense(1))
model_fully_connected.compile(optimizer='adam', loss='mse', metrics=["mae", "mse"])
history=model_fully_connected.fit(x_train, y_train, epochs=10, batch_size=1, verbose=2, validation_data=(x_test, y_test))
# #Сохранение обученной нейронной сети
model_fully_connected.save("trained _neural_network.H5",True,True)

Код импорта в Java:
 MultiLayerNetwork modelMultiLayer=null;
        KerasModelImport kerasModelImport=new KerasModelImport();
        try {            modelMultiLayer=kerasModelImport.importKerasSequentialModelAndWeights("E:\\Java\\neuralwork\\trained _neural_network.H5");
        } catch (IOException e) {
            e.printStackTrace();
        } catch (InvalidKerasConfigurationException e) {
            e.printStackTrace();
        } catch (UnsupportedKerasConfigurationException e) {
            e.printStackTrace();
        }
        System.out.println(modelMultiLayer.conf());

Библиотеки, которые использую в Java для импорта:
 <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>1.0.0-beta2</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native-platform</artifactId>
            <version>1.0.0-beta2</version>
        </dependency>
        <dependency>
            <groupId>com.google.cloud.dataflow</groupId>
            <artifactId>google-cloud-dataflow-java-sdk-all</artifactId>
            <version>2.2.0</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-modelimport</artifactId>
            <version>1.0.0-beta7</version>
        </dependency>

В чем здесь проблема может быть?


Прикреплённый файлы:
attachment Import_model_debager_2.PNG (131,5 KБ)

Офлайн

Board footer

Модераторировать

Powered by DjangoBB

Lo-Fi Version