Run an LLM (ONNX) in an Android App
Steps:
- Generate a quantized ONNX file for the LLM (distilbert-base-uncased) – Python
- Validate the ONNX model (optional) – Python
- Get the vocabulary file (used for tokenization) for the same model – https://huggingface.co/distilbert/distilbert-base-uncased/raw/main/vocab.txt
- Create the Android App code – Java
I hope you like this video. For any questions, suggestions, or appreciation, please contact us at: https://programmerworld.co/contact/ or email us at: programmerworld1990@gmail.com
Details:
Python code:
Generate the ONNX model:
# Export distilbert-base-uncased to ONNX, then shrink the exported graph with
# dynamic (weight-only, 8-bit) quantization. Produces two files in the working
# directory: distilbert_model.onnx and distilbert_model_quantized.onnx.
from transformers import DistilBertModel, DistilBertTokenizer
import torch
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

# Pre-trained encoder and its matching tokenizer (downloaded from the HF hub).
model_name = "distilbert-base-uncased"
model = DistilBertModel.from_pretrained(model_name)
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

# torch.onnx.export traces the model, so it needs one concrete example input.
text = "Once upon a time"
inputs = tokenizer(text, return_tensors="pt")
dummy_input = inputs['input_ids']

# Keep batch (axis 0) and sequence (axis 1) symbolic on both the input and the
# output so the exported graph accepts prompts of any length.
dynamic_axes = {
    name: {0: 'batch_size', 1: 'sequence_length'}
    for name in ('input_ids', 'last_hidden_state')
}

onnx_file_path = "distilbert_model.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_file_path,
    input_names=['input_ids'],
    output_names=['last_hidden_state'],
    dynamic_axes=dynamic_axes,
    opset_version=14,
)

# Weight-only quantization to unsigned 8-bit shrinks the file for on-device use.
quantized_model_path = "distilbert_model_quantized.onnx"
quantize_dynamic(onnx_file_path, quantized_model_path, weight_type=QuantType.QUInt8)
print(f"The quantized model has been saved as '{quantized_model_path}'!")
Tokenize the text:
from transformers import DistilBertTokenizer
import json

# Loaded once at import time and shared by every tokenize_text() call.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")


def tokenize_text(input_text):
    """Tokenize ``input_text`` with the distilbert-base-uncased tokenizer.

    Args:
        input_text: The text to encode.

    Returns:
        The token ids as a nested Python list with one inner list for the single
        input sequence (e.g. ``[[101, 5961, 2055, 2051, 102]]`` for
        "Poem about time"), including the special start/end tokens.
    """
    encoding = tokenizer(input_text, return_tensors="pt")
    # .tolist() turns the tensor into plain ints so the result is JSON-serializable.
    input_ids = encoding["input_ids"].tolist()
    return input_ids


if __name__ == "__main__":
    input_text = "Poem about time"
    tokenized_output = tokenize_text(input_text)
    # Consumed by the validation script, which reads tokenized_output.json.
    with open("tokenized_output.json", "w") as f:
        json.dump(tokenized_output, f)
Sample tokenized JSON file:
[[101, 5961, 2055, 2051, 102]]
Validate the model (optional):
# Sanity-check the quantized model: feed it the ids written by the tokenize
# script and print the raw model output.
import json

import numpy as np
import onnxruntime as ort

# Load the quantized ONNX model. The CPU provider is named explicitly because
# onnxruntime builds that ship several execution providers require the list.
onnx_model_path = "distilbert_model_quantized.onnx"
session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])

# Read the tokenized input (a nested list like [[101, ..., 102]]) from the JSON file.
with open("tokenized_output.json", "r") as f:
    input_ids = json.load(f)

# Pin int64 explicitly: the Android client feeds the same input as int64
# (LongBuffer), whereas converting a list of Python ints without a dtype uses
# NumPy's platform-default integer type, which can be int32.
inputs = {"input_ids": np.asarray(input_ids, dtype=np.int64)}

# Run the model; None requests every graph output.
results = session.run(None, inputs)

print("Output:", results)
Java code:
package com.programmerworld.onnxmodelruninandroid;
import static android.content.ContentValues.TAG;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.EditText;
import android.widget.TextView;
import androidx.appcompat.app.AppCompatActivity;
/**
 * Single-screen demo: reads a prompt from {@code R.id.editText}, runs it through
 * {@link OnnxModelRunner}, and shows the decoded result in {@code R.id.textView}.
 */
public class MainActivity extends AppCompatActivity {
    /**
     * Log tag for this activity. Declared locally so it shadows the statically imported
     * {@code ContentValues.TAG}, which belongs to an unrelated framework class.
     */
    private static final String TAG = "MainActivity";

    private OnnxModelRunner modelRunner;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
    }

    /**
     * Click handler for the Generate button (wired via {@code android:onClick="buttonGenerate"}).
     *
     * <p>Creates a fresh model runner, runs inference on the current prompt, displays the result
     * (or a failure message), and always closes the runner afterwards.
     *
     * <p>NOTE(review): model loading and inference run on the main thread here; consider moving
     * this potentially slow work to a background executor to avoid ANRs.
     *
     * @param view the clicked view (unused)
     */
    public void buttonGenerate(View view) {
        // Initialize the ONNX model runner
        modelRunner = new OnnxModelRunner(this);
        try {
            // Get the input prompt, e.g. "Once upon a time"
            EditText editText = findViewById(R.id.editText);
            String inputText = editText.getText().toString();

            // Run inference
            String outputText = modelRunner.runInference(inputText);

            // Display the output
            TextView textView = findViewById(R.id.textView);
            if (outputText != null && !outputText.isEmpty()) {
                textView.setText(outputText);
                Log.i(TAG, outputText);
            } else {
                textView.setText("Inference failed or returned empty output.");
            }
        } finally {
            // Release the native session even if inference or a view lookup throws.
            modelRunner.close();
        }
    }
}
package com.programmerworld.onnxmodelruninandroid;
// How to run an AI LLM model locally in your Android App?
import android.content.Context;
import android.content.res.AssetManager;
import android.util.Log;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.LongBuffer;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
public class OnnxModelRunner {
private static final String TAG = "OnnxModelRunner";
private OrtEnvironment env;
private OrtSession session;
private Context context;
public OnnxModelRunner(Context context) {
this.context = context; // Assign the context
try {
// Initialize the ONNX Runtime environment
env = OrtEnvironment.getEnvironment();
Log.d(TAG, "ONNX Runtime environment initialized.");
// Copy the ONNX model from assets to internal storage
String modelPath = copyAssetToInternalStorage(context, "distilbert_model_quantized.onnx");
Log.d(TAG, "Model path: " + modelPath);
session = env.createSession(modelPath);
Log.d(TAG, "ONNX model session created.");
} catch (OrtException | IOException e) {
Log.e(TAG, "Error initializing ONNX model: " + e.getMessage());
session = null;
}
}
private String copyAssetToInternalStorage(Context context, String assetFileName) throws IOException {
AssetManager assetManager = context.getAssets();
InputStream inputStream = assetManager.open(assetFileName);
File outFile = new File(context.getFilesDir(), assetFileName);
FileOutputStream outputStream = new FileOutputStream(outFile);
byte[] buffer = new byte[1024];
int read;
while ((read = inputStream.read(buffer)) != -1) {
outputStream.write(buffer, 0, read);
}
inputStream.close();
outputStream.close();
if (!outFile.exists()) {
Log.e(TAG, "Model file does not exist after copying!");
return null;
}
Log.d(TAG, "Model copied successfully to: " + outFile.getAbsolutePath());
return outFile.getAbsolutePath();
}
public String runInference(String inputText) {
if (session == null) {
Log.e(TAG, "ONNX session is not initialized.");
return null;
}
try {
// Read the tokenized output from JSON file
Tokenizer tokenizer = new Tokenizer(context);
long[] inputIds = tokenizer.tokenize(inputText);
// Create input tensor
OnnxTensor inputTensor = prepareInput(inputIds);
// Run the model
Map<String, OnnxTensor> inputMap = Collections.singletonMap("input_ids", inputTensor);
Result result = session.run(inputMap);
// Get the output tensor and decode it
Optional<OnnxValue> optionalOutput = result.get("last_hidden_state");
if (optionalOutput.isPresent()) {
OnnxTensor outputTensor = (OnnxTensor) optionalOutput.get();
// Convert output to meaningful text using the Tokenizer
return decodeOutput(outputTensor, tokenizer);
} else {
Log.e(TAG, "No output returned from the model.");
return null;
}
} catch (OrtException | IOException e) {
Log.e(TAG, "Error running ONNX model inference: " + e.getMessage());
return null;
}
}
private OnnxTensor prepareInput(long[] inputIds) throws OrtException {
// Create input tensor from tokenized input IDs
long[] shape = new long[]{1, inputIds.length}; // Batch size of 1
return OnnxTensor.createTensor(env, LongBuffer.wrap(inputIds), shape);
}
private String decodeOutput(OnnxTensor outputTensor, Tokenizer tokenizer) throws OrtException, IOException {
// Decode the output tensor to a string
float[][][] outputArray = (float[][][]) outputTensor.getValue();
StringBuilder sb = new StringBuilder();
for (float[][] sequence : outputArray) {
for (float[] tokenProbs : sequence) {
// Step 1: Find argmax (index of highest probability)
int maxIndex = 0;
float maxValue = tokenProbs[0];
for (int i = 1; i < tokenProbs.length; i++) {
if (tokenProbs[i] > maxValue) {
maxValue = tokenProbs[i];
maxIndex = i;
}
}
// Step 2: Convert token ID to text using tokenizer
String tokenText = tokenizer.decode(maxIndex);
// Step 3: Append token text
sb.append(tokenText).append(" "); // Add space between tokens
}
}
return sb.toString();
}
public void close() {
try {
if (session != null) {
session.close();
}
if (env != null) {
env.close();
}
} catch (OrtException e) {
Log.e(TAG, "Error closing ONNX model: " + e.getMessage());
}
}
}
package com.programmerworld.onnxmodelruninandroid;
import android.content.Context;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.text.Normalizer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
/**
 * Tokenizer for distilbert-base-uncased driven by the {@code vocab.txt} asset (one token per
 * line; the zero-based line number is the token id).
 *
 * <p>Approximates the uncased BERT tokenizer: lower-case, strip accents, split on whitespace and
 * punctuation, then greedy longest-match-first WordPiece with {@code ##} continuation pieces.
 * Chinese-character splitting and some control-character cleanup done by the Python tokenizer
 * are not replicated.
 */
public class Tokenizer {
    /** Prefix WordPiece puts on pieces that continue a word rather than start it. */
    private static final String CONTINUATION_PREFIX = "##";

    /** Words longer than this (in code points) map straight to [UNK], as in BERT's WordPiece. */
    private static final int MAX_CHARS_PER_WORD = 100;

    private final Map<Integer, String> idToToken = new HashMap<>();
    private final Map<String, Integer> tokenToId = new HashMap<>();

    public Tokenizer(Context context) {
        loadVocab(context);
    }

    /** Fills both lookup maps from the {@code vocab.txt} asset. */
    private void loadVocab(Context context) {
        try (InputStream is = context.getAssets().open("vocab.txt");
             BufferedReader reader =
                     new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
            String line;
            int index = 0;
            while ((line = reader.readLine()) != null) {
                idToToken.put(index, line);
                tokenToId.put(line, index);
                index++;
            }
        } catch (IOException e) {
            // Best effort: with no vocab every word becomes [UNK] and the special tokens fall
            // back to the hard-coded ids used in tokenize().
            e.printStackTrace();
        }
    }

    /**
     * Converts text into model input ids: {@code [CLS]}, the WordPiece ids, then {@code [SEP]}.
     *
     * @param inputText the text to encode; {@code null} is treated as empty
     * @return the token ids
     */
    public long[] tokenize(String inputText) {
        long unkId = idFor("[UNK]", 100);
        List<Long> ids = new ArrayList<>();
        ids.add(idFor("[CLS]", 101));
        for (String word : splitWords(inputText == null ? "" : inputText)) {
            appendWordPieces(word, unkId, ids);
        }
        ids.add(idFor("[SEP]", 102));

        long[] inputIds = new long[ids.size()];
        for (int i = 0; i < inputIds.length; i++) {
            inputIds[i] = ids.get(i);
        }
        return inputIds;
    }

    // Decode token ID to word
    public String decode(int tokenId) {
        return idToToken.getOrDefault(tokenId, "[UNK]");
    }

    /** Looks up a vocabulary id, falling back to {@code fallback} when the token is absent. */
    private long idFor(String token, int fallback) {
        return tokenToId.getOrDefault(token, fallback);
    }

    /**
     * Lower-cases (locale-independently), strips accents, and splits into words; every
     * punctuation character becomes its own word. Never yields empty words.
     */
    private static List<String> splitWords(String text) {
        // NFD separates accents into combining marks, which are then dropped below.
        String decomposed = Normalizer.normalize(text.toLowerCase(Locale.ROOT), Normalizer.Form.NFD);
        List<String> words = new ArrayList<>();
        StringBuilder current = new StringBuilder();
        int i = 0;
        while (i < decomposed.length()) {
            int cp = decomposed.codePointAt(i);
            i += Character.charCount(cp);
            if (isSpace(cp)) {
                flushWord(current, words);
            } else if (isPunctuation(cp)) {
                flushWord(current, words);
                words.add(new String(Character.toChars(cp)));
            } else if (!isDropped(cp)) {
                current.appendCodePoint(cp);
            }
        }
        flushWord(current, words);
        return words;
    }

    /** Moves a non-empty pending word into {@code words} and resets the builder. */
    private static void flushWord(StringBuilder current, List<String> words) {
        if (current.length() > 0) {
            words.add(current.toString());
            current.setLength(0);
        }
    }

    /** Whitespace test; also counts Unicode space separators such as U+00A0. */
    private static boolean isSpace(int cp) {
        return Character.isWhitespace(cp) || Character.getType(cp) == Character.SPACE_SEPARATOR;
    }

    /** Characters removed entirely: NUL, U+FFFD, other control characters, and accent marks. */
    private static boolean isDropped(int cp) {
        return cp == 0
                || cp == 0xFFFD
                || Character.isISOControl(cp)
                || Character.getType(cp) == Character.NON_SPACING_MARK;
    }

    /** BERT's punctuation rule: all non-alphanumeric printable ASCII plus Unicode P* categories. */
    private static boolean isPunctuation(int cp) {
        if ((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64)
                || (cp >= 91 && cp <= 96) || (cp >= 123 && cp <= 126)) {
            return true;
        }
        switch (Character.getType(cp)) {
            case Character.CONNECTOR_PUNCTUATION:
            case Character.DASH_PUNCTUATION:
            case Character.START_PUNCTUATION:
            case Character.END_PUNCTUATION:
            case Character.INITIAL_QUOTE_PUNCTUATION:
            case Character.FINAL_QUOTE_PUNCTUATION:
            case Character.OTHER_PUNCTUATION:
                return true;
            default:
                return false;
        }
    }

    /**
     * Greedy longest-match-first WordPiece: appends the piece ids for {@code word} to
     * {@code out}, or a single {@code unkId} if the word cannot be fully covered by the vocab.
     */
    private void appendWordPieces(String word, long unkId, List<Long> out) {
        if (word.codePointCount(0, word.length()) > MAX_CHARS_PER_WORD) {
            out.add(unkId);
            return;
        }
        List<Long> pieces = new ArrayList<>();
        int start = 0;
        while (start < word.length()) {
            int end = word.length();
            Integer matched = null;
            while (start < end) {
                String candidate = word.substring(start, end);
                if (start > 0) {
                    candidate = CONTINUATION_PREFIX + candidate;
                }
                matched = tokenToId.get(candidate);
                if (matched != null) {
                    break;
                }
                end--;
            }
            if (matched == null) {
                // Some suffix has no vocab entry: the whole word becomes [UNK].
                out.add(unkId);
                return;
            }
            pieces.add(matched.longValue());
            start = end;
        }
        out.addAll(pieces);
    }
}
Manifest file:
<?xml version="1.0" encoding="utf-8"?>
<!--
  No permissions are requested. The model and vocab.txt are read from the APK's assets, and the
  model is copied into the app's internal files directory (Context.getFilesDir()); neither needs
  a permission. The app code makes no network calls, so INTERNET is not needed either.
  READ/WRITE_EXTERNAL_STORAGE are intentionally absent (READ_EXTERNAL_STORAGE also has no effect
  for apps targeting API 33+).
-->
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools">

    <application
        android:allowBackup="true"
        android:dataExtractionRules="@xml/data_extraction_rules"
        android:fullBackupContent="@xml/backup_rules"
        android:icon="@mipmap/ic_launcher"
        android:label="@string/app_name"
        android:roundIcon="@mipmap/ic_launcher_round"
        android:supportsRtl="true"
        android:theme="@style/Theme.OnnxModelRunInAndroid"
        tools:targetApi="31">
        <activity
            android:name=".MainActivity"
            android:exported="true">
            <intent-filter>
                <action android:name="android.intent.action.MAIN" />
                <category android:name="android.intent.category.LAUNCHER" />
            </intent-filter>
        </activity>
    </application>
</manifest>
Gradle File:
// App-module build script (Gradle Kotlin DSL). The libs.* aliases come from the project's
// version catalog, which is not shown here.
plugins {
alias(libs.plugins.android.application)
}
android {
namespace = "com.programmerworld.onnxmodelruninandroid"
compileSdk = 35
defaultConfig {
applicationId = "com.programmerworld.onnxmodelruninandroid"
// minSdk 33 (Android 13): the app cannot be installed on older devices.
minSdk = 33
targetSdk = 35
versionCode = 1
versionName = "1.0"
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
// R8 shrinking is off for release builds.
// NOTE(review): if it is enabled, confirm ONNX Runtime's JNI-accessed classes are kept
// (extra keep rules may be needed).
isMinifyEnabled = false
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
}
}
// Java 11 language level for the app's Java sources.
compileOptions {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
}
}
dependencies {
implementation(libs.appcompat)
implementation(libs.material)
implementation(libs.activity)
implementation(libs.constraintlayout)
testImplementation(libs.junit)
androidTestImplementation(libs.ext.junit)
androidTestImplementation(libs.espresso.core)
// ONNX Runtime for Android: supplies the ai.onnxruntime classes (OrtEnvironment, OrtSession,
// OnnxTensor) used by OnnxModelRunner. Declared with an inline version rather than through the
// version catalog like the other dependencies.
implementation("com.microsoft.onnxruntime:onnxruntime-android:1.17.1")
}
Layout XML file:
<?xml version="1.0" encoding="utf-8"?>
<!--
  Prompt screen: prompt input at the top, the Generate button below it, and the model output
  below the button. Views are stacked with layout_below instead of fixed top margins and
  centerInParent so they cannot overlap on short screens. The input spans the available width
  instead of a fixed 388dp, which overflowed common 360dp-wide phones.
-->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:id="@+id/main"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:padding="10dp"
    tools:context=".MainActivity">

    <EditText
        android:id="@+id/editText"
        android:layout_width="match_parent"
        android:layout_height="118dp"
        android:layout_marginTop="70dp"
        android:ems="10"
        android:inputType="text"
        android:text="Once upon a time ..." />

    <Button
        android:id="@+id/button"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_below="@id/editText"
        android:layout_centerHorizontal="true"
        android:layout_marginTop="32dp"
        android:onClick="buttonGenerate"
        android:text="Generate" />

    <TextView
        android:id="@+id/textView"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_below="@id/button"
        android:layout_marginTop="16dp"
        android:text="Output will be displayed here"
        android:textSize="18sp" />
</RelativeLayout>
Screenshots:


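Model validation in Python (the printed array is the model's last_hidden_state output, one embedding vector per input token, not generated text):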

Output: [array([[[-0.3429286 , -0.07686502, -0.06207633, ..., 0.02597048,
0.18410149, 0.37661746],
[ 0.0625054 , 0.00804514, -0.24020125, ..., -0.10774948,
-0.07214016, 0.27932528],
[ 0.08012569, 0.2088192 , 0.07876888, ..., 0.05914421,
-0.21747732, 0.01284895],
[-0.0507772 , -0.13391055, -0.3511802 , ..., 0.03803144,
0.3148152 , -0.4408192 ],
[ 0.98205227, 0.2172539 , -0.20831339, ..., 0.05268615,
-0.7091678 , -0.29968455]]], dtype=float32)]