模型是 MobileNet V3 Small + 微调分类器,使用 PyTorch 训练后导出为 ONNX 模型。
pom:
<!-- ONNX Runtime Java binding, version 1.17.3; supplies the ai.onnxruntime classes imported by the inference code. -->
<dependencies>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.17.3</version>
</dependency>
</dependencies>
<!-- ONNX Runtime Java binding, version 1.17.3; supplies the ai.onnxruntime classes imported by the inference code. -->
<dependencies>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.17.3</version>
</dependency>
</dependencies>
<!-- ONNX Runtime Java binding, version 1.17.3; supplies the ai.onnxruntime classes imported by the inference code. -->
<dependencies> <dependency> <groupId>com.microsoft.onnxruntime</groupId> <artifactId>onnxruntime</artifactId> <version>1.17.3</version> </dependency> </dependencies>
代码:
package org.example;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;
public class OnnxModelInference {
private OrtSession session;
public OnnxModelInference(String modelPath) throws Exception {
OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
byte[] modelArray = Files.readAllBytes(Paths.get(modelPath));
session = env.createSession(modelArray, sessionOptions);
}
public float[][][][] preprocessImage(String imagePath) throws IOException {
// read img
BufferedImage image = ImageIO.read(new File(imagePath));
// resize to 224
BufferedImage resizedImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB);
Graphics2D g2d = resizedImage.createGraphics();
g2d.setColor(Color.WHITE); // 或根据需要选择背景颜色
g2d.fillRect(0, 0, 224, 224);
g2d.drawImage(image, 0, 0, 224, 224, null);
g2d.dispose();
// ImageNet norm
float[] mean = {0.485f, 0.456f, 0.406f};
float[] std = {0.229f, 0.224f, 0.225f};
float[][][][] floatArray = new float[1][3][224][224];
for (int y = 0; y < 224; y++) {
for (int x = 0; x < 224; x++) {
Color color = new Color(resizedImage.getRGB(x, y));
floatArray[0][0][y][x] = (color.getRed() / 255.0f - mean[0]) / std[0]; // R
floatArray[0][1][y][x] = (color.getGreen() / 255.0f - mean[1]) / std[1]; // G
floatArray[0][2][y][x] = (color.getBlue() / 255.0f - mean[2]) / std[2]; // B
}
}
return floatArray;
}
public String predict(String imagePath) throws Exception {
float[][][][] inputData = preprocessImage(imagePath);
OnnxTensor inputTensor = OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), inputData);
HashMap<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input.1", inputTensor); // "input"是模型输入节点的名称,请根据实际情况调整
// num of class names
String[] classes = {"c1", "c2"....};
// predict
try (OrtSession.Result results = session.run(inputs)) {
float[][] output = (float[][]) results.get(0).getValue();
int maxIndex = 0;
for (int i = 1; i < output[0].length; i++) {
maxIndex = output[0][i] > output[0][maxIndex] ? i : maxIndex;
}
return classes[maxIndex]; // i -> class name
}
}
public static void main(String[] args) {
try {
OnnxModelInference inferencer = new OnnxModelInference("path/tray.onnx");
// traverse dir with recursive
String dir = "path/img_test/";
long start = System.currentTimeMillis();
long cnt = 0;
for (File file : new File(dir).listFiles()) {
if (file.isDirectory()) {
for (File subFile : file.listFiles()) {
String result = inferencer.predict(subFile.getAbsolutePath());
System.out.println(subFile.getAbsolutePath() + " : " + result);
cnt++;
}
}
}
long end = System.currentTimeMillis();
System.out.println("Cost Avg: " + (end - start) / cnt + "ms");
} catch (Exception e) {
e.printStackTrace();
}
}
}
package org.example;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;
public class OnnxModelInference {
private OrtSession session;
public OnnxModelInference(String modelPath) throws Exception {
OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
byte[] modelArray = Files.readAllBytes(Paths.get(modelPath));
session = env.createSession(modelArray, sessionOptions);
}
public float[][][][] preprocessImage(String imagePath) throws IOException {
// read img
BufferedImage image = ImageIO.read(new File(imagePath));
// resize to 224
BufferedImage resizedImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB);
Graphics2D g2d = resizedImage.createGraphics();
g2d.setColor(Color.WHITE); // 或根据需要选择背景颜色
g2d.fillRect(0, 0, 224, 224);
g2d.drawImage(image, 0, 0, 224, 224, null);
g2d.dispose();
// ImageNet norm
float[] mean = {0.485f, 0.456f, 0.406f};
float[] std = {0.229f, 0.224f, 0.225f};
float[][][][] floatArray = new float[1][3][224][224];
for (int y = 0; y < 224; y++) {
for (int x = 0; x < 224; x++) {
Color color = new Color(resizedImage.getRGB(x, y));
floatArray[0][0][y][x] = (color.getRed() / 255.0f - mean[0]) / std[0]; // R
floatArray[0][1][y][x] = (color.getGreen() / 255.0f - mean[1]) / std[1]; // G
floatArray[0][2][y][x] = (color.getBlue() / 255.0f - mean[2]) / std[2]; // B
}
}
return floatArray;
}
public String predict(String imagePath) throws Exception {
float[][][][] inputData = preprocessImage(imagePath);
OnnxTensor inputTensor = OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), inputData);
HashMap<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input.1", inputTensor); // "input"是模型输入节点的名称,请根据实际情况调整
// num of class names
String[] classes = {"c1", "c2"....};
// predict
try (OrtSession.Result results = session.run(inputs)) {
float[][] output = (float[][]) results.get(0).getValue();
int maxIndex = 0;
for (int i = 1; i < output[0].length; i++) {
maxIndex = output[0][i] > output[0][maxIndex] ? i : maxIndex;
}
return classes[maxIndex]; // i -> class name
}
}
public static void main(String[] args) {
try {
OnnxModelInference inferencer = new OnnxModelInference("path/tray.onnx");
// traverse dir with recursive
String dir = "path/img_test/";
long start = System.currentTimeMillis();
long cnt = 0;
for (File file : new File(dir).listFiles()) {
if (file.isDirectory()) {
for (File subFile : file.listFiles()) {
String result = inferencer.predict(subFile.getAbsolutePath());
System.out.println(subFile.getAbsolutePath() + " : " + result);
cnt++;
}
}
}
long end = System.currentTimeMillis();
System.out.println("Cost Avg: " + (end - start) / cnt + "ms");
} catch (Exception e) {
e.printStackTrace();
}
}
}
package org.example; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtSession; import javax.imageio.ImageIO; import java.awt.*; import java.awt.image.BufferedImage; import java.io.File; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; import java.util.HashMap; public class OnnxModelInference { private OrtSession session; public OnnxModelInference(String modelPath) throws Exception { OrtEnvironment env = OrtEnvironment.getEnvironment(); OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions(); byte[] modelArray = Files.readAllBytes(Paths.get(modelPath)); session = env.createSession(modelArray, sessionOptions); } public float[][][][] preprocessImage(String imagePath) throws IOException { // read img BufferedImage image = ImageIO.read(new File(imagePath)); // resize to 224 BufferedImage resizedImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB); Graphics2D g2d = resizedImage.createGraphics(); g2d.setColor(Color.WHITE); // 或根据需要选择背景颜色 g2d.fillRect(0, 0, 224, 224); g2d.drawImage(image, 0, 0, 224, 224, null); g2d.dispose(); // ImageNet norm float[] mean = {0.485f, 0.456f, 0.406f}; float[] std = {0.229f, 0.224f, 0.225f}; float[][][][] floatArray = new float[1][3][224][224]; for (int y = 0; y < 224; y++) { for (int x = 0; x < 224; x++) { Color color = new Color(resizedImage.getRGB(x, y)); floatArray[0][0][y][x] = (color.getRed() / 255.0f - mean[0]) / std[0]; // R floatArray[0][1][y][x] = (color.getGreen() / 255.0f - mean[1]) / std[1]; // G floatArray[0][2][y][x] = (color.getBlue() / 255.0f - mean[2]) / std[2]; // B } } return floatArray; } public String predict(String imagePath) throws Exception { float[][][][] inputData = preprocessImage(imagePath); OnnxTensor inputTensor = OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), inputData); HashMap<String, OnnxTensor> inputs = new HashMap<>(); inputs.put("input.1", inputTensor); // "input"是模型输入节点的名称,请根据实际情况调整 
// num of class names String[] classes = {"c1", "c2"....}; // predict try (OrtSession.Result results = session.run(inputs)) { float[][] output = (float[][]) results.get(0).getValue(); int maxIndex = 0; for (int i = 1; i < output[0].length; i++) { maxIndex = output[0][i] > output[0][maxIndex] ? i : maxIndex; } return classes[maxIndex]; // i -> class name } } public static void main(String[] args) { try { OnnxModelInference inferencer = new OnnxModelInference("path/tray.onnx"); // traverse dir with recursive String dir = "path/img_test/"; long start = System.currentTimeMillis(); long cnt = 0; for (File file : new File(dir).listFiles()) { if (file.isDirectory()) { for (File subFile : file.listFiles()) { String result = inferencer.predict(subFile.getAbsolutePath()); System.out.println(subFile.getAbsolutePath() + " : " + result); cnt++; } } } long end = System.currentTimeMillis(); System.out.println("Cost Avg: " + (end - start) / cnt + "ms"); } catch (Exception e) { e.printStackTrace(); } } }
速度还挺快:在我的 i7-12700 上,纯 CPU 推理平均约 15 ms/张。