模型是 MobileNetV3-Small 主干 + 微调的分类头，用 PyTorch 训练后导出为 ONNX 模型。
Maven 依赖（pom.xml）：
<dependencies> <dependency> <groupId>com.microsoft.onnxruntime</groupId> <artifactId>onnxruntime</artifactId> <version>1.17.3</version> </dependency> </dependencies>
Java 推理代码：
package org.example; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtSession; import javax.imageio.ImageIO; import java.awt.*; import java.awt.image.BufferedImage; import java.io.File; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; import java.util.HashMap; public class OnnxModelInference { private OrtSession session; public OnnxModelInference(String modelPath) throws Exception { OrtEnvironment env = OrtEnvironment.getEnvironment(); OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions(); byte[] modelArray = Files.readAllBytes(Paths.get(modelPath)); session = env.createSession(modelArray, sessionOptions); } public float[][][][] preprocessImage(String imagePath) throws IOException { // read img BufferedImage image = ImageIO.read(new File(imagePath)); // resize to 224 BufferedImage resizedImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB); Graphics2D g2d = resizedImage.createGraphics(); g2d.setColor(Color.WHITE); // 或根据需要选择背景颜色 g2d.fillRect(0, 0, 224, 224); g2d.drawImage(image, 0, 0, 224, 224, null); g2d.dispose(); // ImageNet norm float[] mean = {0.485f, 0.456f, 0.406f}; float[] std = {0.229f, 0.224f, 0.225f}; float[][][][] floatArray = new float[1][3][224][224]; for (int y = 0; y < 224; y++) { for (int x = 0; x < 224; x++) { Color color = new Color(resizedImage.getRGB(x, y)); floatArray[0][0][y][x] = (color.getRed() / 255.0f - mean[0]) / std[0]; // R floatArray[0][1][y][x] = (color.getGreen() / 255.0f - mean[1]) / std[1]; // G floatArray[0][2][y][x] = (color.getBlue() / 255.0f - mean[2]) / std[2]; // B } } return floatArray; } public String predict(String imagePath) throws Exception { float[][][][] inputData = preprocessImage(imagePath); OnnxTensor inputTensor = OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), inputData); HashMap<String, OnnxTensor> inputs = new HashMap<>(); inputs.put("input.1", inputTensor); // "input"是模型输入节点的名称,请根据实际情况调整 
// num of class names String[] classes = {"c1", "c2"....}; // predict try (OrtSession.Result results = session.run(inputs)) { float[][] output = (float[][]) results.get(0).getValue(); int maxIndex = 0; for (int i = 1; i < output[0].length; i++) { maxIndex = output[0][i] > output[0][maxIndex] ? i : maxIndex; } return classes[maxIndex]; // i -> class name } } public static void main(String[] args) { try { OnnxModelInference inferencer = new OnnxModelInference("path/tray.onnx"); // traverse dir with recursive String dir = "path/img_test/"; long start = System.currentTimeMillis(); long cnt = 0; for (File file : new File(dir).listFiles()) { if (file.isDirectory()) { for (File subFile : file.listFiles()) { String result = inferencer.predict(subFile.getAbsolutePath()); System.out.println(subFile.getAbsolutePath() + " : " + result); cnt++; } } } long end = System.currentTimeMillis(); System.out.println("Cost Avg: " + (end - start) / cnt + "ms"); } catch (Exception e) { e.printStackTrace(); } } }
速度还挺快：在我的 i7-12700 上纯 CPU 推理约 15 ms/张（该平均耗时包含图片读取与预处理，不只是模型前向计算）。