1.什么是DeepLearning4j?
DeepLearning4J(DL4J)是一套基于Java語言的神經網絡工具包,可以構建、定型和部署神經網絡。DL4J與Hadoop和Spark集成,支持分布式CPU和GPU,為商業環境(而非研究工具目的)所設計。Skymind是DL4J的商業支持機構。 Deeplearning4j擁有先進的技術,以即插即用為目標,通過更多預設的使用,避免多余的配置,讓非企業也能夠進行快速的原型制作。DL4J同時可以規模化定制。DL4J遵循Apache 2.0許可協議,一切以其為基礎的衍生作品均屬于衍生作品的作
Deeplearning4j的功能
Deeplearning4j包括了分布式、多線程的深度學習框架,以及普通的單線程深度學習框架。定型過程以集群進行,也就是說,Deeplearning4j可以快速處理大量數據。神經網絡可通過[迭代化簡]平行定型,與 Java、?Scala?和?Clojure?均兼容。Deeplearning4j在開放堆棧中作為模塊組件的功能,使之成為首個為微服務架構打造的深度學習框架。
???????
Deeplearning4j的組件
深度神經網絡能夠實現前所未有的準確度。對神經網絡的簡介請參見概覽頁。簡而言之,Deeplearning4j能夠讓你從各類淺層網絡(其中每一層在英文中被稱為layer)出發,設計深層神經網絡。這一靈活性使用戶可以根據所需,在分布式、生產級、能夠在分布式CPU或GPU的基礎上與Spark和Hadoop協同工作的框架內,整合受限玻爾茲曼機、其他自動編碼器、卷積網絡或遞歸網絡。 此處為我們已經建立的各個庫及其在系統整體中的所處位置: ?
DeepLearning4J用于設計神經網絡:
- Deeplearning4j(簡稱DL4J)是為Java和Scala編寫的首個商業級開源分布式深度學習
- DL4J與Hadoop和Spark集成,為商業環境(而非研究工具目的)所設計。
- 支持GPU和CPU
- 受到 Cloudera, Hortonwork, NVIDIA, Intel, IBM 等認證,可以在Spark, Flink, Hadoop 上運行
- 支持并行迭代算法架構
- DeepLearning4J的JavaDoc可在此處獲取
- DeepLearning4J示例的Github代碼庫請見此處。相關示例的簡介匯總請見此處。
- 開源工具 ASF 2.0許可證:github.com/deeplearning4j/deeplearning4j
2.訓練模型
訓練和測試數據集下載
https://raw.githubusercontent.com/zq2599/blog_download_files/master/files/mnist_png.tar.gz
MNIST簡介
- MNIST是經典的計算機視覺數據集,來源是National Institute of Standards and Technology (NIST,美國國家標準與技術研究所),包含各種手寫數字圖片,其中訓練集60,000張,測試集 10,000張,
- MNIST來源于250 個不同人的手寫,其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員.,測試集(test set) 也是同樣比例的手寫數字數據
- MNIST官網:http://yann.lecun.com/exdb/mnist/
數據集簡介
從MNIST官網下載的原始數據并非圖片文件,需要按官方給出的格式說明做解析處理才能轉為一張張圖片,這些事情顯然不是本篇的主題,因此咱們可以直接使用DL4J為我們準備好的數據集(下載地址稍后給出),該數據集中是一張張獨立的圖片,這些圖片所在目錄的名字就是該圖片具體的數字
模型訓練
LeNet-5簡介
LeNet-5 結構:
- 輸入層
圖片大小為 32×32×1,其中 1 表示為黑白圖像,只有一個 channel。
- 卷積層
filter 大小 5×5,filter 深度(個數)為 6,padding 為 0, 卷積步長?s=1=1,輸出矩陣大小為 28×28×6,其中 6 表示 filter 的個數。
- 池化層
average pooling,filter 大小 2×2(即?f=2=2),步長?s=2=2,no padding,輸出矩陣大小為 14×14×6。
- 卷積層
filter 大小 5×5,filter 個數為 16,padding 為 0, 卷積步長?s=1=1,輸出矩陣大小為 10×10×16,其中 16 表示 filter 的個數。
- 池化層
average pooling,filter 大小 2×2(即?f=2=2),步長?s=2=2,no padding,輸出矩陣大小為 5×5×16。注意,在該層結束,需要將 5×5×16 的矩陣flatten 成一個 400 維的向量。
- 全連接層(Fully Connected layer,FC)
neuron 數量為 120。
- 全連接層(Fully Connected layer,FC)
neuron 數量為 84。
- 全連接層,輸出層
現在版本的 LeNet-5 輸出層一般會采用 softmax 激活函數,在 LeNet-5 提出的論文中使用的激活函數不是 softmax,但其現在不常用。該層神經元數量為 10,代表 0~9 十個數字類別。(圖 1 其實少畫了一個表示全連接層的方框,而直接用?^y^?表示輸出層。) ?
/******************************************************************************** Copyright (c) 2020 Konduit K.K.* Copyright (c) 2015-2019 Skymind, Inc.** This program and the accompanying materials are made available under the* terms of the Apache License, Version 2.0 which is available at* https://www.apache.org/licenses/LICENSE-2.0.** Unless required by applicable law or agreed to in writing, software* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the* License for the specific language governing permissions and limitations* under the License.** SPDX-License-Identifier: Apache-2.0******************************************************************************/package com.et.dl4j.model;import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;/*** Implementation of LeNet-5 for handwritten digits image classification on MNIST dataset (99% accuracy)* <a href="http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf">[LeCun et al., 1998. Gradient based learning applied to document recognition]</a>* Some minor changes are made to the architecture like using ReLU and identity activation instead of* sigmoid/tanh, max pooling instead of avg pooling and softmax output layer.* <p>* This example will download 15 Mb of data on the first run.** @author hanlon* @author agibsonccc* @author fvaleri* @author dariuszzbyrad*/
@Slf4j
public class LeNetMNISTReLu {//dataset github:https://raw.githubusercontent.com/zq2599/blog_download_files/master/files/mnist_png.tar.gz// 存放文件的地址,請酌情修改
// private static final String BASE_PATH = System.getProperty("java.io.tmpdir") + "/mnist";private static final String BASE_PATH = "/Users/liuhaihua/Downloads";public static void main(String[] args) throws Exception {// 圖片像素高int height = 28;// 圖片像素寬int width = 28;// 因為是黑白圖像,所以顏色通道只有一個int channels = 1;// 分類結果,0-9,共十種數字int outputNum = 10;// 批大小int batchSize = 54;// 循環次數int nEpochs = 1;// 初始化偽隨機數的種子int seed = 1234;// 隨機數工具Random randNumGen = new Random(seed);log.info("檢查數據集文件夾是否存在:{}", BASE_PATH + "/mnist_png");if (!new File(BASE_PATH + "/mnist_png").exists()) {log.info("數據集文件不存在,請下載壓縮包并解壓到:{}", BASE_PATH);return;}// 標簽生成器,將指定文件的父目錄作為標簽ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();// 歸一化配置(像素值從0-255變為0-1)DataNormalization imageScaler = new ImagePreProcessingScaler();// 不論訓練集還是測試集,初始化操作都是相同套路:// 1. 讀取圖片,數據格式為NCHW// 2. 根據批大小創建的迭代器// 3. 將歸一化器作為預處理器log.info("訓練集的矢量化操作...");// 初始化訓練集File trainData = new File(BASE_PATH + "/mnist_png/training");FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);trainRR.initialize(trainSplit);DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);// 擬合數據(實現類中實際上什么也沒做)imageScaler.fit(trainIter);trainIter.setPreProcessor(imageScaler);log.info("測試集的矢量化操作...");// 初始化測試集,與前面的訓練集操作類似File testData = new File(BASE_PATH + "/mnist_png/testing");FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);testRR.initialize(testSplit);DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);testIter.setPreProcessor(imageScaler); // same normalization for better resultslog.info("配置神經網絡");// 在訓練中,將學習率配置為隨著迭代階梯性下降Map<Integer, Double> learningRateSchedule = new HashMap<>();learningRateSchedule.put(0, 0.06);learningRateSchedule.put(200, 0.05);learningRateSchedule.put(600, 0.028);learningRateSchedule.put(800, 0.0060);learningRateSchedule.put(1000, 0.001);// 超參數MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed)// L2正則化系數.l2(0.0005)// 梯度下降的學習率設置.updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))// 權重初始化.weightInit(WeightInit.XAVIER)// 準備分層.list()// 卷積層.layer(new ConvolutionLayer.Builder(5, 5).nIn(channels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())// 下采樣,即池化.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())// 卷積層.layer(new ConvolutionLayer.Builder(5, 5).stride(1, 1) // nIn need not specified in later layers.nOut(50).activation(Activation.IDENTITY).build())// 下采樣,即池化.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())// 稠密層,即全連接.layer(new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())// 輸出.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image.build();MultiLayerNetwork net = new MultiLayerNetwork(conf);net.init();// 每十個迭代打印一次損失函數值net.setListeners(new ScoreIterationListener(10));log.info("神經網絡共[{}]個參數", net.numParams());long startTime = System.currentTimeMillis();// 循環操作for (int i = 0; i < nEpochs; i++) {log.info("第[{}]個循環", i);net.fit(trainIter);Evaluation eval = net.evaluate(testIter);log.info(eval.stats());trainIter.reset();testIter.reset();}log.info("完成訓練和測試,耗時[{}]毫秒", System.currentTimeMillis()-startTime);// 保存模型File ministModelPath = new File(BASE_PATH + "/minist-model.zip");ModelSerializer.writeModel(net, ministModelPath, true);log.info("最新的MINIST模型保存在[{}]", ministModelPath.getPath());}
}
輸出模型文件和得分結果
3.編寫模型預測接口
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"><parent><artifactId>springboot-demo</artifactId><groupId>com.et</groupId><version>1.0-SNAPSHOT</version></parent><modelVersion>4.0.0</modelVersion><artifactId>Deeplearning4j</artifactId><properties><maven.compiler.source>8</maven.compiler.source><maven.compiler.target>8</maven.compiler.target><dl4j-master.version>1.0.0-beta7</dl4j-master.version><nd4j.backend>nd4j-native</nd4j.backend></properties><dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-autoconfigure</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId><scope>test</scope></dependency><dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId><version>1.18.20</version></dependency><dependency><groupId>ch.qos.logback</groupId><artifactId>logback-classic</artifactId></dependency><dependency><groupId>org.deeplearning4j</groupId><artifactId>deeplearning4j-core</artifactId><version>${dl4j-master.version}</version></dependency><dependency><groupId>org.nd4j</groupId><artifactId>${nd4j.backend}</artifactId><version>${dl4j-master.version}</version></dependency><!--用于本地GPU--><!-- <dependency>--><!-- <groupId>org.deeplearning4j</groupId>--><!-- <artifactId>deeplearning4j-cuda-9.2</artifactId>--><!-- <version>${dl4j-master.version}</version>--><!-- </dependency>--><!-- <dependency>--><!-- <groupId>org.nd4j</groupId>--><!-- <artifactId>nd4j-cuda-9.2-platform</artifactId>--><!-- <version>${dl4j-master.version}</version>--><!-- </dependency>--></dependencies>
</project>
cotroller
package com.et.dl4j.controller;import com.et.dl4j.service.PredictService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;import java.util.HashMap;
import java.util.Map;@RestController
public class HelloWorldController {@RequestMapping("/hello")public Map<String, Object> showHelloWorld(){Map<String, Object> map = new HashMap<>();map.put("msg", "HelloWorld");return map;}@AutowiredPredictService predictService;@PostMapping("/predict-with-black-background")public int predictWithBlackBackground(@RequestParam("file") MultipartFile file) throws Exception {// 訓練模型的時候,用的數字是白字黑底,// 因此如果上傳白字黑底的圖片,可以直接拿去識別,而無需反色處理return predictService.predict(file, false);}@PostMapping("/predict-with-white-background")public int predictWithWhiteBackground(@RequestParam("file") MultipartFile file) throws Exception {// 訓練模型的時候,用的數字是白字黑底,// 因此如果上傳黑字白底的圖片,就需要做反色處理,// 反色之后就是白字黑底了,可以拿去識別return predictService.predict(file, true);}
}
service
package com.et.dl4j.service;import org.springframework.web.multipart.MultipartFile;public interface PredictService {/*** 取得上傳的圖片,做轉換后識別成數字* @param file 上傳的文件* @param isNeedRevert 是否要做反色處理* @return*/int predict(MultipartFile file, boolean isNeedRevert) throws Exception ;
}
package com.et.dl4j.service.impl;
import com.et.dl4j.service.PredictService;
import com.et.dl4j.util.ImageFileUtil;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;import javax.annotation.PostConstruct;
import java.io.File;@Service
@Slf4j
public class PredictServiceImpl implements PredictService {/*** -1表示識別失敗*/private static final int RLT_INVALID = -1;/*** 模型文件的位置*/@Value("${predict.modelpath}")private String modelPath;/*** 處理圖片文件的目錄*/@Value("${predict.imagefilepath}")private String imageFilePath;/*** 神經網絡*/private MultiLayerNetwork net;/*** bean實例化成功就加載模型*/@PostConstructprivate void loadModel() {log.info("load model from [{}]", modelPath);// 加載模型try {net = ModelSerializer.restoreMultiLayerNetwork(new File(modelPath));log.info("module summary\n{}", net.summary());} catch (Exception exception) {log.error("loadModel error", exception);}}@Overridepublic int predict(MultipartFile file, boolean isNeedRevert) throws Exception {log.info("start predict, file [{}], isNeedRevert [{}]", file.getOriginalFilename(), isNeedRevert);// 先存文件String rawFileName = ImageFileUtil.save(imageFilePath, file);if (null==rawFileName) {return RLT_INVALID;}// 反色處理后的文件名String revertFileName = null;// 調整大小后的文件名String resizeFileName;// 是否需要反色處理if (isNeedRevert) {// 把原始文件做反色處理,返回結果是反色處理后的新文件revertFileName = ImageFileUtil.colorRevert(imageFilePath, rawFileName);// 把反色處理后調整為28*28大小的文件resizeFileName = ImageFileUtil.resize(imageFilePath, revertFileName);} else {// 直接把原始文件調整為28*28大小的文件resizeFileName = ImageFileUtil.resize(imageFilePath, rawFileName);}// 現在已經得到了結果反色和調整大小處理過后的文件,// 那么原始文件和反色處理過的文件就可以刪除了ImageFileUtil.clear(imageFilePath, rawFileName, revertFileName);// 取出該黑白圖片的特征INDArray features = ImageFileUtil.getGrayImageFeatures(imageFilePath, resizeFileName);// 將特征傳給模型去識別return net.predict(features)[0];}
}
application.properties
# 上傳文件總的最大值
spring.servlet.multipart.max-request-size=1024MB# 單個文件的最大值
spring.servlet.multipart.max-file-size=10MB# 處理圖片文件的目錄
predict.imagefilepath=/Users/liuhaihua/Downloads/images/# 模型所在位置
predict.modelpath=/Users/liuhaihua/Downloads/minist-model.zip
工具類
package com.et.dl4j.util;import lombok.extern.slf4j.Slf4j;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.springframework.web.multipart.MultipartFile;import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.UUID;@Slf4j
public class ImageFileUtil {/*** 調整后的文件寬度*/public static final int RESIZE_WIDTH = 28;/*** 調整后的文件高度*/public static final int RESIZE_HEIGHT = 28;/*** 將上傳的文件存在服務器上* @param base 要處理的文件所在的目錄* @param file 要處理的文件* @return*/public static String save(String base, MultipartFile file) {// 檢查是否為空if (file.isEmpty()) {log.error("invalid file");return null;}// 文件名來自原始文件String fileName = file.getOriginalFilename();// 要保存的位置File dest = new File(base + fileName);// 開始保存try {file.transferTo(dest);} catch (IOException e) {log.error("upload fail", e);return null;}return fileName;}/*** 將圖片轉為28*28像素* @param base 處理文件的目錄* @param fileName 待調整的文件名* @return*/public static String resize(String base, String fileName) {// 新文件名是原文件名在加個隨機數后綴,而且擴展名固定為pngString resizeFileName = fileName.substring(0, fileName.lastIndexOf(".")) + "-" + UUID.randomUUID() + ".png";log.info("start resize, from [{}] to [{}]", fileName, resizeFileName);try {// 讀原始文件BufferedImage bufferedImage = ImageIO.read(new File(base + fileName));// 縮放后的實例Image image = bufferedImage.getScaledInstance(RESIZE_WIDTH, RESIZE_HEIGHT, Image.SCALE_SMOOTH);BufferedImage resizeBufferedImage = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);Graphics graphics = resizeBufferedImage.getGraphics();// 繪圖graphics.drawImage(image, 0, 0, null);graphics.dispose();// 轉換后的圖片寫文件ImageIO.write(resizeBufferedImage, "png", new File(base + resizeFileName));} catch (Exception exception) {log.info("resize error from [{}] to [{}], {}", fileName, resizeFileName, exception);resizeFileName = null;}log.info("finish resize, from [{}] to [{}]", fileName, resizeFileName);return resizeFileName;}/*** 將RGB轉為int數字* @param alpha* @param red* @param green* @param blue* @return*/private static int colorToRGB(int alpha, int red, int green, int blue) {int pixel = 0;pixel += alpha;pixel = pixel << 8;pixel += red;pixel = pixel << 8;pixel += green;pixel = pixel << 8;pixel += blue;return pixel;}/*** 反色處理* @param base 處理文件的目錄* @param src 用于處理的源文件* @return 反色處理后的新文件* @throws IOException*/public static String colorRevert(String base, String src) throws IOException {int color, r, g, b, pixel;// 讀原始文件BufferedImage srcImage = ImageIO.read(new File(base + src));// 修改后的文件BufferedImage destImage = new BufferedImage(srcImage.getWidth(), srcImage.getHeight(), srcImage.getType());for (int i=0; i<srcImage.getWidth(); i++) {for (int j=0; j<srcImage.getHeight(); j++) {color = srcImage.getRGB(i, j);r = (color >> 16) & 0xff;g = (color >> 8) & 0xff;b = color & 0xff;pixel = colorToRGB(255, 0xff - r, 0xff - g, 0xff - b);destImage.setRGB(i, j, pixel);}}// 反射文件的名字String revertFileName = src.substring(0, src.lastIndexOf(".")) + "-revert.png";// 轉換后的圖片寫文件ImageIO.write(destImage, "png", new File(base + revertFileName));return revertFileName;}/*** 取黑白圖片的特征* @param base* @param fileName* @return* @throws Exception*/public static INDArray getGrayImageFeatures(String base, String fileName) throws Exception {log.info("start getImageFeatures [{}]", base + fileName);// 和訓練模型時一樣的設置ImageRecordReader imageRecordReader = new ImageRecordReader(RESIZE_HEIGHT, RESIZE_WIDTH, 1);FileSplit fileSplit = new FileSplit(new File(base + fileName),NativeImageLoader.ALLOWED_FORMATS);imageRecordReader.initialize(fileSplit);DataSetIterator dataSetIterator = new RecordReaderDataSetIterator(imageRecordReader, 1);dataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1));// 取特征return dataSetIterator.next().getFeatures();}/*** 批量清理文件* @param base 處理文件的目錄* @param fileNames 待清理文件集合*/public static void clear(String base, String...fileNames) {for (String fileName : fileNames) {if (null==fileName) {continue;}File file = new File(base + fileName);if (file.exists()) {file.delete();}}}}
DemoApplication.java
package com.et.dl4j;import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;@SpringBootApplication
public class DemoApplication {public static void main(String[] args) {SpringApplication.run(DemoApplication.class, args);}
}
以上只是一些關鍵代碼,所有代碼請參見下面代碼倉庫
代碼倉庫
- https://github.com/Harries/springboot-demo
4.測試
啟動Spring Boot應用,上傳圖片測試
- 如果用戶輸入的是黑底白字的圖片,只需要將上述流程中的反色處理去掉即可
- 為白底黑字圖片提供專用接口predict-with-white-background
- 為黑底白字圖片提供專用接口predict-with-black-background
5.引用
- 關于我們 - Deeplearning4j: Open-source, Distributed Deep Learning for the JVM
- DL4J實戰之三:經典卷積實例(LeNet-5)_multilayerconfiguration 參數-CSDN博客
- Spring Boot集成DeepLearning4j實現圖片數字識別 | Harries Blog?