"""Classify a single image with an ONNX model using ONNX Runtime."""

import cv2
import numpy as np
import onnxruntime as ort

# Per-channel (R, G, B) normalization statistics. These are the widely used
# ImageNet values -- confirm they match the model's training pipeline.
DEFAULT_MEAN = (0.485, 0.456, 0.406)
DEFAULT_STD = (0.229, 0.224, 0.225)
# Target (width, height), the argument order cv2.resize expects.
DEFAULT_INPUT_SIZE = (224, 224)
# Placeholder labels indexed by class id; replace with the model's real labels.
DEFAULT_CLASS_NAMES = ('class1', 'class2')


def load_model(model_path, providers=None):
    """Load an ONNX model and return an ``ort.InferenceSession``.

    Args:
        model_path: Path to the ``.onnx`` file.
        providers: Execution providers in priority order. When ``None``, every
            provider available in the installed onnxruntime build is used.
            Some builds (e.g. GPU builds since 1.9) refuse to create a
            session unless providers are given explicitly.

    Returns:
        The inference session.
    """
    if providers is None:
        providers = ort.get_available_providers()
    return ort.InferenceSession(model_path, providers=providers)


def preprocess_image(image_path, size=DEFAULT_INPUT_SIZE,
                     mean=DEFAULT_MEAN, std=DEFAULT_STD):
    """Read an image file and turn it into a batched model input tensor.

    Pipeline: BGR -> RGB, resize to *size*, scale pixels to [0, 1],
    normalize per channel with *mean* / *std*, reorder HWC -> CHW, and
    prepend a batch axis of length 1.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.
        size: Target ``(width, height)`` passed to ``cv2.resize``.
        mean: Per-channel RGB mean subtracted after scaling to [0, 1].
        std: Per-channel RGB standard deviation divided out after the
            mean subtraction.

    Returns:
        C-contiguous float32 array of shape ``(1, 3, height, width)``.

    Raises:
        FileNotFoundError: If the file cannot be read or decoded as an image.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
    image = cv2.resize(image, tuple(size))
    image = image.astype(np.float32) / 255.0  # scale to [0, 1]

    mean_val = np.asarray(mean, dtype=np.float32)
    std_val = np.asarray(std, dtype=np.float32)
    image = (image - mean_val) / std_val  # broadcasts over the channel axis

    # HWC -> CHW, then add the batch dimension: (1, 3, H, W).
    image = np.transpose(image, (2, 0, 1))[np.newaxis]
    # transpose returns a strided view; hand ORT a contiguous float32 buffer.
    return np.ascontiguousarray(image, dtype=np.float32)


def predict(session, input_image):
    """Run *session* on *input_image* and return its outputs.

    Only the model's first input and first output are used, so the returned
    list contains exactly one array.

    Args:
        session: An ``ort.InferenceSession``.
        input_image: Tensor fed to the model's first input.

    Returns:
        List with the value of the model's first output.
    """
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    return session.run([output_name], {input_name: input_image})


def main(model_path, image_path, class_names=None):
    """Classify one image and print the predicted class name.

    Args:
        model_path: Path to the ``.onnx`` model.
        image_path: Path to the image to classify.
        class_names: Labels indexed by class id. Defaults to
            ``DEFAULT_CLASS_NAMES``.

    Returns:
        The predicted class name, or ``None`` when the predicted index has
        no corresponding entry in *class_names*.
    """
    if class_names is None:
        class_names = DEFAULT_CLASS_NAMES

    session = load_model(model_path)
    img = preprocess_image(image_path)
    prediction = predict(session, img)

    # prediction[0] is the first output for a batch of one image. The argmax
    # runs over the flattened array, which equals the class index assuming the
    # output holds one score per class (e.g. shape (1, num_classes)) --
    # TODO confirm against the model.
    predicted_class_index = int(np.argmax(prediction[0]))
    if predicted_class_index < len(class_names):
        predicted_class_name = class_names[predicted_class_index]
        print(f"Predicted class: {predicted_class_name}")
        return predicted_class_name

    print("Predicted class index is out of bounds of the class_names array")
    return None


if __name__ == "__main__":
    # Placeholder paths; replace with a real model and image before running.
    model_path = 'your_path/model.onnx'
    image_path = 'your_path/test/class1/03a79bee1d5ac929f9b18b4b223aad04.jpg'
    main(model_path, image_path)