I trained an object detection model in Python with YOLO. I tested the .pt model and got good results, and the exported .onnx model gives the same accurate results in Python. But when I use that same .onnx model in a React Native CLI app, the results are wrong and look random.
The model is in main/assets and it is imported correctly, but the output is always different from the Python test results.
import * as ort from 'onnxruntime-react-native';
import {Platform} from 'react-native';
import * as RNFS from 'react-native-fs';
import ImageResizer from 'react-native-image-resizer';
import {toByteArray} from 'base64-js';
const IMAGE_SIZE = 640;
const CHANNELS = 3;
export type DetectionResult = {
className: string;
confidence: number;
bbox: {
x: number;
y: number;
width: number;
height: number;
};
};
class ImageDetectionService {
private static instance: ImageDetectionService;
private session: ort.InferenceSession | null = null;
private classes: string[];
private isInitialized: boolean;
private modelPath: string | null;
private constructor() {
this.session = null;
this.classes = [
'classe1',
'classe2',
'classe3',
'classe4',
];
this.modelPath = null;
this.isInitialized = false;
}
static getInstance(): ImageDetectionService {
if (!ImageDetectionService.instance) {
ImageDetectionService.instance = new ImageDetectionService();
}
return ImageDetectionService.instance;
}
private async loadModelFromAssets(): Promise<string> {
if (Platform.OS === 'android') {
const tempDirPath = `${RNFS.CachesDirectoryPath}/models`;
const tempModelPath = `${tempDirPath}/best.onnx`;
try {
await RNFS.mkdir(tempDirPath);
const tempModelExists = await RNFS.exists(tempModelPath);
if (!tempModelExists) {
await RNFS.copyFileAssets('best.onnx', tempModelPath);
}
return `file://${tempModelPath}`;
} catch (error) {
console.error('Error copying model file:', error);
throw error;
}
} else {
return 'best.onnx';
}
}
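// Note: because of the exists() check above, a newly re-exported best.onnx in
// main/assets will never overwrite an older copy already sitting in the cache,
// so the app can keep running a stale model. A minimal refresh sketch
// (hypothetical helper, not called anywhere above):
private async refreshCachedModel(): Promise<void> {
  if (Platform.OS !== 'android') {
    return;
  }
  const tempDirPath = `${RNFS.CachesDirectoryPath}/models`;
  const tempModelPath = `${tempDirPath}/best.onnx`;
  await RNFS.mkdir(tempDirPath);
  if (await RNFS.exists(tempModelPath)) {
    await RNFS.unlink(tempModelPath); // drop the cached copy
  }
  await RNFS.copyFileAssets('best.onnx', tempModelPath);
}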
// Initialize the ONNX Runtime session and log the model's input/output names for debugging
async initialize(): Promise<void> {
try {
this.modelPath = await this.loadModelFromAssets();
console.log('Loading model from path:', this.modelPath);
this.session = await ort.InferenceSession.create(this.modelPath);
// Debug model information
if (this.session) {
// inputNames / outputNames are plain properties on the session, not promises
const inputNames = this.session.inputNames;
const outputNames = this.session.outputNames;
console.log('Model input names:', inputNames);
console.log('Model output names:', outputNames);
}
this.isInitialized = true;
console.log('ONNX model loaded successfully');
} catch (error) {
console.error('Error initializing ONNX model:', error);
throw error;
}
}
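// Debugging sketch: run the session once on an all-zero input and log each
// output's dims, to confirm the runtime really reports (1, 8, 8400) before
// touching pre/post-processing. Hypothetical helper, not called above.
async logOutputShape(): Promise<void> {
  if (!this.session) {
    return;
  }
  const zeros = new Float32Array(1 * CHANNELS * IMAGE_SIZE * IMAGE_SIZE);
  const feeds = {
    images: new ort.Tensor('float32', zeros, [1, CHANNELS, IMAGE_SIZE, IMAGE_SIZE]),
  };
  const results = await this.session.run(feeds);
  for (const name of Object.keys(results)) {
    console.log(`Output '${name}' dims:`, results[name].dims);
  }
}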
private async preprocessImage(imageUri: string): Promise<Float32Array> {
try {
const resizedImage = await ImageResizer.createResizedImage(
imageUri,
IMAGE_SIZE,
IMAGE_SIZE,
'JPEG',
100,
0,
undefined,
false,
{onlyScaleDown: true},
);
const base64Data = await RNFS.readFile(resizedImage.uri, 'base64');
const rawImageData = toByteArray(base64Data);
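// Note: readFile(..., 'base64') returns the resized file's contents, so
// rawImageData here is still JPEG-compressed bytes rather than a raw RGB
// pixel buffer (see the decode sketch after this method).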
const float32Data = new Float32Array(IMAGE_SIZE * IMAGE_SIZE * CHANNELS);
const MEAN = [0.485, 0.456, 0.406];
const STD = [0.229, 0.224, 0.225];
for (let i = 0; i < rawImageData.length; i += 3) {
const r = rawImageData[i] / 255.0;
const g = rawImageData[i + 1] / 255.0;
const b = rawImageData[i + 2] / 255.0;
float32Data[i] = (r - MEAN[0]) / STD[0];
float32Data[i + 1] = (g - MEAN[1]) / STD[1];
float32Data[i + 2] = (b - MEAN[2]) / STD[2];
}
const nchwData = new Float32Array(1 * CHANNELS * IMAGE_SIZE * IMAGE_SIZE);
for (let c = 0; c < CHANNELS; c++) {
for (let h = 0; h < IMAGE_SIZE; h++) {
for (let w = 0; w < IMAGE_SIZE; w++) {
const srcIdx = h * IMAGE_SIZE * CHANNELS + w * CHANNELS + c;
const dstIdx = c * IMAGE_SIZE * IMAGE_SIZE + h * IMAGE_SIZE + w;
nchwData[dstIdx] = float32Data[srcIdx];
}
}
}
return nchwData;
} catch (error) {
console.error('Error preprocessing image:', error);
throw error;
}
}
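// Sketch of a preprocessing variant that decodes the JPEG to raw pixels before
// normalizing. Assumes the pure-JS 'jpeg-js' package is installed and
// `import {decode as decodeJpeg} from 'jpeg-js';` is added next to the other
// imports; decodeJpeg() returns {width, height, data} where data is an RGBA
// byte buffer. Hypothetical alternative, not wired into detectObjects above.
private async preprocessImageDecoded(imageUri: string): Promise<Float32Array> {
  const resized = await ImageResizer.createResizedImage(
    imageUri,
    IMAGE_SIZE,
    IMAGE_SIZE,
    'JPEG',
    100,
    0,
  );
  const base64Data = await RNFS.readFile(resized.uri, 'base64');
  const {width, height, data: rgba} = decodeJpeg(toByteArray(base64Data), {
    useTArray: true,
  });
  // Ultralytics YOLO models typically expect inputs scaled to [0, 1] with no
  // ImageNet mean/std normalization, so only divide by 255 here (assumption).
  const nchwData = new Float32Array(1 * CHANNELS * IMAGE_SIZE * IMAGE_SIZE);
  // Note: ImageResizer keeps the aspect ratio, so width/height can be smaller
  // than IMAGE_SIZE; a letterbox (scale + pad) step would match the Python
  // preprocessing more closely.
  for (let h = 0; h < height; h++) {
    for (let w = 0; w < width; w++) {
      const src = (h * width + w) * 4; // RGBA stride from jpeg-js
      for (let c = 0; c < CHANNELS; c++) {
        nchwData[c * IMAGE_SIZE * IMAGE_SIZE + h * IMAGE_SIZE + w] =
          rgba[src + c] / 255.0;
      }
    }
  }
  return nchwData;
}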
async detectObjects(imagePath: string): Promise<DetectionResult[]> {
if (!this.isInitialized || !this.session) {
throw new Error('Model not initialized. Call initialize() first.');
}
try {
const inputTensor = await this.preprocessImage(imagePath);
const feeds = {
images: new ort.Tensor('float32', inputTensor, [
1,
3,
IMAGE_SIZE,
IMAGE_SIZE,
]),
};
const results = await this.session.run(feeds);
return this.processResults(results);
} catch (error) {
console.error('Error running detection:', error);
throw error;
}
}
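// Sketch: the 'images' key above is hard-coded; using the name the session
// reports avoids a mismatch if the export used a different input name
// (assumes a single-input model). Hypothetical helper, not wired in above.
private buildFeeds(tensor: ort.Tensor): Record<string, ort.Tensor> {
  const inputName = this.session ? this.session.inputNames[0] : 'images';
  return {[inputName]: tensor};
}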
private calculateIoU(box1: any, box2: any): number {
const x1 = Math.max(box1.x, box2.x);
const y1 = Math.max(box1.y, box2.y);
const x2 = Math.min(box1.x + box1.width, box2.x + box2.width);
const y2 = Math.min(box1.y + box1.height, box2.y + box2.height);
if (x2 < x1 || y2 < y1) return 0;
const intersection = (x2 - x1) * (y2 - y1);
const area1 = box1.width * box1.height;
const area2 = box2.width * box2.height;
const union = area1 + area2 - intersection;
return intersection / union;
}
private nonMaxSuppression(
detections: DetectionResult[],
iouThreshold: number = 0.5,
): DetectionResult[] {
// Sort by confidence
const sorted = [...detections].sort((a, b) => b.confidence - a.confidence);
const selected: DetectionResult[] = [];
for (const detection of sorted) {
let shouldSelect = true;
for (const selectedDetection of selected) {
const iou = this.calculateIoU(detection.bbox, selectedDetection.bbox);
if (iou > iouThreshold) {
shouldSelect = false;
break;
}
}
if (shouldSelect) {
selected.push(detection);
}
}
return selected;
}
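// The NMS above is class-agnostic, so a strong box of one class can suppress
// an overlapping box of a different class. A per-class variant (sketch,
// reusing nonMaxSuppression above):
private nonMaxSuppressionPerClass(
  detections: DetectionResult[],
  iouThreshold: number = 0.5,
): DetectionResult[] {
  const byClass = new Map<string, DetectionResult[]>();
  for (const det of detections) {
    const group = byClass.get(det.className) || [];
    group.push(det);
    byClass.set(det.className, group);
  }
  const kept: DetectionResult[] = [];
  byClass.forEach(group => {
    kept.push(...this.nonMaxSuppression(group, iouThreshold));
  });
  return kept;
}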
private processResults(
results: Record<string, ort.Tensor>,
): DetectionResult[] {
const outputTensor = results['output0'];
if (!outputTensor) {
console.error('Output tensor is undefined');
return [];
}
const data = outputTensor.data as Float32Array;
const rawDetections: DetectionResult[] = [];
const numClasses = this.classes.length;
// The exported output shape is (1, 8, 8400): 8 values for each of 8400 candidate boxes.
// The loop below steps through the buffer 8 values at a time and reads
// [x, y, w, h, confidence] plus the 4 class scores from each step; see the
// output-layout sketch after the class.
const CONFIDENCE_THRESHOLD = 0.3;
const NUM_BOXES = 8400;
const BOX_DATA = 8;
for (let box = 0; box < NUM_BOXES; box++) {
const baseOffset = box * BOX_DATA;
// Get box coordinates
const x = data[baseOffset + 0];
const y = data[baseOffset + 1];
const width = data[baseOffset + 2];
const height = data[baseOffset + 3];
// Get confidence score
const confidence = data[baseOffset + 4];
if (confidence > CONFIDENCE_THRESHOLD) {
// Find class with highest probability
let maxClassProb = 0;
let classId = 0;
for (let c = 0; c < numClasses; c++) {
const classProb = data[baseOffset + 5 + c];
if (classProb > maxClassProb) {
maxClassProb = classProb;
classId = c;
}
}
const detection: DetectionResult = {
bbox: {
x: x,
y: y,
width: width,
height: height,
},
className: this.classes[classId],
confidence: confidence,
};
rawDetections.push(detection);
}
}
// Apply NMS to filter overlapping boxes
const finalDetections = this.nonMaxSuppression(rawDetections);
console.log(finalDetections);
return finalDetections;
}
}
export default ImageDetectionService;
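For comparison, here is a decoding sketch that assumes the export follows the standard Ultralytics YOLOv8 output layout (an assumption I have not verified for my model): the 8 values per box are [cx, cy, w, h] plus the 4 class scores, with no separate objectness score, and the (1, 8, 8400) tensor is attribute-major, so a value is read as data[attr * 8400 + box] rather than data[box * 8 + attr].

export function decodeYoloV8Output(
  data: Float32Array,
  classes: string[],
  confThreshold: number = 0.3,
): DetectionResult[] {
  const NUM_BOXES = 8400;
  const detections: DetectionResult[] = [];
  for (let box = 0; box < NUM_BOXES; box++) {
    // Attribute-major layout: attribute a of box b is at data[a * NUM_BOXES + b]
    const cx = data[0 * NUM_BOXES + box];
    const cy = data[1 * NUM_BOXES + box];
    const w = data[2 * NUM_BOXES + box];
    const h = data[3 * NUM_BOXES + box];
    // No separate objectness score: confidence is the best class score
    let bestScore = 0;
    let bestClass = 0;
    for (let c = 0; c < classes.length; c++) {
      const score = data[(4 + c) * NUM_BOXES + box];
      if (score > bestScore) {
        bestScore = score;
        bestClass = c;
      }
    }
    if (bestScore > confThreshold) {
      detections.push({
        className: classes[bestClass],
        confidence: bestScore,
        // cx/cy are box centers in the 640x640 input space; convert to top-left
        bbox: {x: cx - w / 2, y: cy - h / 2, width: w, height: h},
      });
    }
  }
  return detections;
}

The boxes here are still in the 640x640 input space; they would need to be scaled back to the original image size before being displayed.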