-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathApp.tsx
65 lines (59 loc) · 1.83 KB
/
App.tsx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import { StatusBar } from 'expo-status-bar';
import { Alert, Button, StyleSheet, Text, View } from 'react-native';
import * as ort from 'onnxruntime-react-native';
import { Asset } from 'expo-asset';
// ONNX Runtime session shared by the two button handlers. It is assigned only by
// loadModel() on success, so it is undefined at runtime until "Load model" has
// completed; runModel() reads it.
let myModel: ort.InferenceSession;
/**
 * Resolves the bundled `./assets/mnist.ort` model to a local file URI and creates
 * an ONNX Runtime inference session from it, storing the session in `myModel`.
 *
 * Every outcome is reported to the user with an alert:
 * - missing local URI: alerts the asset's name/uri and returns without a session;
 * - success: alerts the model's input and output names;
 * - any thrown error: alerts it and then rethrows it (the returned promise rejects).
 */
async function loadModel() {
  try {
    const assets = await Asset.loadAsync(require('./assets/mnist.ort'));
    // Narrow explicitly: indexing may yield undefined under strict index checks.
    const asset = assets[0];
    // The native runtime is given a local file URI, not the bundler module id.
    const modelUri = asset?.localUri;
    if (!modelUri) {
      // Interpolating the Asset object itself would only print "[object Object]",
      // so report the identifying fields instead.
      Alert.alert(
        'failed to get model URI',
        asset ? `name: ${asset.name}, uri: ${asset.uri}` : 'no asset returned');
      return;
    }
    myModel = await ort.InferenceSession.create(modelUri);
    Alert.alert(
      'model loaded successfully',
      `input names: ${myModel.inputNames}, output names: ${myModel.outputNames}`);
  } catch (e) {
    Alert.alert('failed to load model', `${e}`);
    throw e;
  }
}
/**
 * Runs a single inference on the loaded session with an all-zero 28x28 input and
 * alerts the shape and data of the model's first output.
 *
 * - No model loaded yet: alerts and returns without throwing.
 * - Model lacks an input or output name: alerts and returns.
 * - First output missing from the results: alerts the expected output name.
 * - Any thrown error (e.g. from `run`): alerts it, then rethrows it.
 */
async function runModel() {
  // `myModel` is only assigned once loadModel() succeeds. Without this guard,
  // pressing "Run" first fails with an opaque TypeError on `inputNames`.
  if (!myModel) {
    Alert.alert('model not loaded', 'press "Load model" first');
    return;
  }
  // Capture the session so a concurrent "Load model" cannot swap it mid-inference.
  const session = myModel;
  try {
    const inputName = session.inputNames[0];
    const outputName = session.outputNames[0];
    if (inputName === undefined || outputName === undefined) {
      Alert.alert(
        'failed to inference model',
        `unexpected model signature: inputs [${session.inputNames}], outputs [${session.outputNames}]`);
      return;
    }
    // Float32Array is zero-initialised, so this is a blank dummy image.
    // NOTE(review): the shape [1, 28, 28] must match the model's declared input;
    // some MNIST exports expect [1, 1, 28, 28] — confirm against mnist.ort.
    const inputData = new Float32Array(28 * 28);
    const feeds: Record<string, ort.Tensor> = {
      [inputName]: new ort.Tensor(inputData, [1, 28, 28]),
    };
    const fetches = await session.run(feeds);
    const output = fetches[outputName];
    if (!output) {
      Alert.alert('failed to get output', `${outputName}`);
    } else {
      Alert.alert(
        'model inference successfully',
        `output shape: ${output.dims}, output data: ${output.data}`);
    }
  } catch (e) {
    Alert.alert('failed to inference model', `${e}`);
    throw e;
  }
}
/**
 * Root screen of the demo: a caption, a button that loads the bundled model, a
 * button that runs one inference with it, and the Expo status bar. All of these
 * are laid out inside the `container` style.
 */
export default function App() {
  const { container } = styles;
  return (
    <View style={container}>
      <Text>using ONNX Runtime for React Native</Text>
      <Button title='Load model' onPress={loadModel} />
      <Button title='Run' onPress={runModel} />
      <StatusBar style="auto" />
    </View>
  );
}
// Centering rules for both flex axes. `as const` keeps the literal 'center'
// values so they type-check against React Native's flexbox style props.
const centered = {
  alignItems: 'center',
  justifyContent: 'center',
} as const;

// Root container: expands to fill available space (flex: 1), uses a white
// background, and centers its children.
const styles = StyleSheet.create({
  container: {
    flex: 1,
    backgroundColor: '#fff',
    ...centered,
  },
});