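This tutorial builds a simple image recognition model with DL4J (DeepLearning4J) and integrates it into a Spring Boot backend. We train a small convolutional neural network (CNN) on the MNIST dataset and then serve it from a Spring Boot application.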
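1. Set up the Spring Boot project
First, create a new Spring Boot project. You can use Spring Initializr (https://start.spring.io/) to generate the project skeleton quickly. Select the following dependencies:
- Spring Web
- Spring Boot DevTools
- Lombok (optional, reduces boilerplate)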
2. Add the DL4J dependencies
Add DL4J and the related dependencies to your pom.xml file:
xml
<dependencies>
    <!-- Spring Boot Web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>

    <!-- DL4J -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nn</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-ui</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>

    <!-- File Upload -->
    <dependency>
        <groupId>commons-fileupload</groupId>
        <artifactId>commons-fileupload</artifactId>
        <version>1.4</version>
    </dependency>
    <dependency>
        <groupId>commons-io</groupId>
        <artifactId>commons-io</artifactId>
        <version>2.11.0</version>
    </dependency>

    <!-- Lombok (optional) -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
</dependencies>
3. Train the DL4J model
We train a simple convolutional neural network (CNN) on the MNIST dataset. Create a new Java class, MnistModelTrainer.java, to train the model:
java
package com.example.scanapp;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.io.IOException;

public class MnistModelTrainer {

    public static void main(String[] args) throws IOException {
        int numEpochs = 10;
        int batchSize = 64;
        int numLabels = 10;
        int numRows = 28;
        int numColumns = 28;
        int numChannels = 1;

        // Load MNIST data
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        // Preprocess data; inference must apply the same scaler
        ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);
        mnistTrain.setPreProcessor(scaler);
        mnistTest.setPreProcessor(scaler);

        // Define the network architecture
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .updater(new Adam(0.001))
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(numChannels)
                        .nOut(20)
                        .stride(1, 1)
                        .activation(Activation.RELU)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        .nOut(50)
                        .stride(1, 1)
                        .activation(Activation.RELU)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(4, new DenseLayer.Builder()
                        .activation(Activation.RELU)
                        .nOut(500)
                        .build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(numLabels)
                        .activation(Activation.SOFTMAX)
                        .build())
                // MNIST rows arrive flattened (784 values); this reshapes them to
                // 28x28x1 and lets DL4J infer nIn for every layer after the first
                .setInputType(InputType.convolutionalFlat(numRows, numColumns, numChannels))
                .build();

        // Initialize the network
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(10));

        // Train the network
        for (int i = 0; i < numEpochs; i++) {
            model.fit(mnistTrain);
        }

        // Report accuracy on the held-out test set
        Evaluation eval = model.evaluate(mnistTest);
        System.out.println(eval.stats());

        // Save the model together with the updater state
        File locationToSave = new File("mnist-model.zip");
        boolean saveUpdater = true;
        ModelSerializer.writeModel(model, locationToSave, saveUpdater);
    }
}
Run the MnistModelTrainer class to train the model and save it to the file mnist-model.zip.
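To get a feel for the network defined above, it can help to trace how the 28x28 input shrinks on its way to the dense layer. The following is a minimal sketch in plain Java (no DL4J involved). It assumes unpadded convolution and pooling windows, which is how the 5x5/stride-1 and 2x2/stride-2 layers behave when no padding is configured. The class and method names are illustrative only and are not part of DL4J.

```java
public class CnnShapeCheck {

    /** Output size of an unpadded convolution or pooling window along one dimension. */
    public static int outputSize(int input, int kernel, int stride) {
        return (input - kernel) / stride + 1;
    }

    /** Traces one side of the input through conv5 -> pool2 -> conv5 -> pool2. */
    public static int finalSide(int side) {
        side = outputSize(side, 5, 1); // convolution 5x5, stride 1
        side = outputSize(side, 2, 2); // max pooling 2x2, stride 2
        side = outputSize(side, 5, 1); // convolution 5x5, stride 1
        side = outputSize(side, 2, 2); // max pooling 2x2, stride 2
        return side;
    }

    public static void main(String[] args) {
        int side = finalSide(28);
        int featureMaps = 50; // nOut of the second convolution layer
        System.out.println("final feature map: " + side + "x" + side);
        System.out.println("dense layer inputs: " + (side * side * featureMaps));
    }
}
```

Under these assumptions a 28x28 image ends up as 50 feature maps of 4x4, so the dense layer receives 800 values per example.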
4. Create the Spring Boot controller
Create a new controller that handles image upload and recognition:
java
package com.example.scanapp.controller;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

@RestController
public class ImageController {

    // Replace with the path to your trained model
    private static final String MODEL_PATH = "mnist-model.zip";

    private final MultiLayerNetwork model;
    private final ImagePreProcessingScaler scaler;

    public ImageController() throws IOException {
        // Load the model once at startup rather than on every request
        this.model = ModelSerializer.restoreMultiLayerNetwork(new File(MODEL_PATH));
        this.scaler = new ImagePreProcessingScaler(0, 1);
    }

    @PostMapping("/recognize")
    public String recognize(@RequestParam("image") MultipartFile file) {
        try {
            BufferedImage image = ImageIO.read(file.getInputStream());
            if (image == null) {
                // ImageIO.read returns null when no reader can decode the upload
                return "Error processing image";
            }

            // Redraw as a 28x28 grayscale image so uploads of any size are accepted
            BufferedImage gray = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
            Graphics2D g = gray.createGraphics();
            g.drawImage(image, 0, 0, 28, 28, null);
            g.dispose();

            // Shape: [batch, channels, height, width]
            INDArray imageArray = Nd4j.create(new int[]{1, 1, 28, 28});
            for (int i = 0; i < 28; i++) {
                for (int j = 0; j < 28; j++) {
                    int value = gray.getRaster().getSample(j, i, 0);
                    imageArray.putScalar(0, 0, i, j, value / 255.0);
                }
            }

            // Apply the same scaler the trainer used
            scaler.transform(imageArray);

            INDArray output = model.output(imageArray);
            int predictedClass = output.argMax(1).getInt(0);
            return "Predicted class: " + predictedClass;
        } catch (IOException e) {
            e.printStackTrace();
            return "Error processing image";
        }
    }
}
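The model expects a 28x28 single-channel input, while uploaded images can have any size and color format. The sketch below isolates that preprocessing step using only java.awt, so it can be tried without DL4J or Spring. The class name MnistPreprocess, the method toMnistInput, and the invert flag are hypothetical helpers introduced for illustration. The flag is there because MNIST digits are light strokes on a dark background, whereas scanned or photographed digits are often the opposite.

```java
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;

public class MnistPreprocess {

    /**
     * Scales any image to 28x28 grayscale and returns 784 values in [0, 1], row by row.
     * Pass invert = true for dark digits on a light background.
     */
    public static float[] toMnistInput(BufferedImage source, boolean invert) {
        BufferedImage gray = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = gray.createGraphics();
        g.drawImage(source, 0, 0, 28, 28, null); // resize and convert in one step
        g.dispose();

        float[] pixels = new float[28 * 28];
        for (int y = 0; y < 28; y++) {
            for (int x = 0; x < 28; x++) {
                float v = gray.getRaster().getSample(x, y, 0) / 255.0f;
                pixels[y * 28 + x] = invert ? 1.0f - v : v;
            }
        }
        return pixels;
    }

    public static void main(String[] args) {
        // Synthetic 56x56 all-white image standing in for an uploaded file
        BufferedImage white = new BufferedImage(56, 56, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = white.createGraphics();
        g.setColor(Color.WHITE);
        g.fillRect(0, 0, 56, 56);
        g.dispose();

        float[] asIs = toMnistInput(white, false);
        float[] inverted = toMnistInput(white, true);
        System.out.println(asIs.length + " values, first = " + asIs[0]);
        System.out.println("inverted first = " + inverted[0]);
    }
}
```

The returned array still has to be copied into the input array passed to the network, and any scaling that was applied at training time has to be repeated on it.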
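5. Test the API
You can use Postman or a similar tool to test the API. Send a POST request to the /recognize endpoint with an MNIST-style image file (a 28x28-pixel grayscale image) attached as the multipart form field named image.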
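If you would rather test from code than from Postman, the request can also be sent with the HTTP client that ships with Java 11 and later. This is a minimal sketch, not a full multipart implementation. The class name RecognizeClient and the boundary string are arbitrary choices, and the URL assumes the application runs locally on port 8080.

```java
import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

public class RecognizeClient {

    /** Builds a multipart/form-data body with a single file part named "image". */
    public static byte[] multipartBody(String boundary, String filename, byte[] content) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String header = "--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"image\"; filename=\"" + filename + "\"\r\n"
                + "Content-Type: application/octet-stream\r\n\r\n";
        out.writeBytes(header.getBytes(StandardCharsets.UTF_8));
        out.writeBytes(content);
        out.writeBytes(("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8));
        return out.toByteArray();
    }

    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            System.out.println("usage: java RecognizeClient <image-file>");
            return;
        }
        Path image = Path.of(args[0]);
        String boundary = "----RecognizeClientBoundary"; // arbitrary; must not occur in the file
        byte[] body = multipartBody(boundary, image.getFileName().toString(), Files.readAllBytes(image));

        HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:8080/recognize"))
                .header("Content-Type", "multipart/form-data; boundary=" + boundary)
                .POST(HttpRequest.BodyPublishers.ofByteArray(body))
                .build();

        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```

Run it with the path to a digit image as the only argument while the Spring Boot application is running; it prints the HTTP status code followed by the response text.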
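6. Run the Spring Boot application
Make sure the Spring Boot application starts cleanly. Because the controller loads mnist-model.zip from a relative path, the file must be in the directory you launch the application from. You can run the application with the following command: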
bash
mvn spring-boot:run
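7. Frontend integration (optional)
If you have a frontend application (for example, Vue.js), you can create a simple form that uploads an image and calls the backend API. Here is a simple Vue.js component: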
vue
<template>
  <div>
    <h1>Image Recognition</h1>
    <input type="file" @change="onFileChange" accept="image/*" />
    <button @click="uploadImage">Upload</button>
    <p v-if="result">{{ result }}</p>
  </div>
</template>

<script>
export default {
  data() {
    return {
      file: null,
      result: ''
    };
  },
  methods: {
    onFileChange(e) {
      this.file = e.target.files[0];
    },
    async uploadImage() {
      const formData = new FormData();
      formData.append('image', this.file);
      try {
        const response = await fetch('http://localhost:8080/recognize', {
          method: 'POST',
          body: formData
        });
        const data = await response.text();
        this.result = data;
      } catch (error) {
        console.error('Error uploading image:', error);
      }
    }
  }
};
</script>
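Add the component above to your Vue project, then run the frontend to test the whole flow. If the frontend is served from a different origin than the backend (for example, a dev server on another port), the browser will block the request unless the backend allows it through CORS, for instance with Spring's @CrossOrigin annotation on the controller.
With these steps in place, you should have a Spring Boot backend service that uses a DL4J model and a frontend that calls it for image recognition.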