Quantize an ONNX model in Python
This video shows the steps and code to generate an ONNX model quantized with `QuantType.QInt8` in Python.
It quantizes `gpt2_model.onnx`, which was created earlier with the `torch.onnx.export` API.
For the code that creates that model, refer to: https://programmerworld.co/ai-gen-ai-generative-artificial-intelligence/how-to-create-local-onnx-file-for-text-generation-in-python-using-a-ai-llm-model/
I hope you like this video. For any questions, suggestions, or appreciation, please contact us at: https://programmerworld.co/contact/ or email us at: programmerworld1990@gmail.com
Details:
Python Code:
quantize_model.py
from onnxruntime.quantization import quantize_dynamic, QuantType
# Source: the float32 GPT-2 graph exported earlier with torch.onnx.export.
# Target: the same graph with its weights stored as signed 8-bit integers.
onnx_model_path = "gpt2_model.onnx"
quantized_model_path = "gpt2_model_quantized.onnx"
weight_type = QuantType.QInt8

# Dynamic quantization needs no calibration data: weights are converted
# ahead of time, activation ranges are computed at inference time.
quantize_dynamic(
    onnx_model_path,
    quantized_model_path,
    weight_type=weight_type,
)
print(f"The quantized model has been saved as '{quantized_model_path}'!")
generate_text_from_onnx.py
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import onnx
from onnxruntime import InferenceSession
import os

# Load the pre-trained GPT-2 model and tokenizer
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Tokenize a prompt: it is the tracing example for the export and also the
# starting text for generation further down.
text = "Write a story about sunflowers"
inputs = tokenizer(text, return_tensors="pt")
dummy_input = inputs['input_ids']

# Keep the float32 export and the quantized model in SEPARATE files.
# Exporting directly into the quantized file name would overwrite the output
# of quantize_model.py, and inference would then run on float32 weights.
onnx_file_path = "gpt2_model.onnx"
quantized_model_path = "gpt2_model_quantized.onnx"

# Export the model to ONNX with dynamic batch and sequence-length axes
torch.onnx.export(model, dummy_input, onnx_file_path,
                  input_names=['input_ids'],
                  output_names=['logits'],
                  dynamic_axes={'input_ids': {0: 'batch_size', 1: 'sequence_length'},
                                'logits': {0: 'batch_size', 1: 'sequence_length'}},
                  opset_version=14)

# Validate the exported ONNX model; report the path that was really written.
onnx_model = onnx.load(onnx_file_path)
onnx.checker.check_model(onnx_model)
print(f"The model is valid and has been saved as '{onnx_file_path}'!")

# Run inference on the quantized model produced by quantize_model.py when it
# exists; otherwise fall back to the float32 export written above.
inference_model_path = (quantized_model_path if os.path.exists(quantized_model_path)
                        else onnx_file_path)
print(f"Running inference with '{inference_model_path}'")
session = InferenceSession(inference_model_path)
# Function to apply top-p (nucleus) sampling
def top_p_sampling(logits, top_p=0.9):
    """Mask every logit outside the top-p (nucleus) set with ``-inf``.

    Tokens are ranked by score; the smallest prefix whose cumulative softmax
    probability exceeds ``top_p`` is kept (the token that crosses the
    threshold is kept, and the most likely token is always kept).

    Args:
        logits: Tensor of unnormalised scores with the vocabulary on the last
            dimension. Any number of leading (batch) dimensions is accepted,
            including none (a 1-D tensor).
        top_p: Cumulative-probability threshold. Values >= 1 disable filtering.

    Returns:
        A NEW tensor with the same shape as ``logits`` in which filtered-out
        entries are ``-inf``. The input tensor is left unmodified.
    """
    if top_p >= 1.0:
        # A float cumsum can end slightly above 1.0, which would otherwise drop
        # tail tokens even though no filtering was requested.
        return logits.clone()

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_to_remove = cumulative_probs > top_p
    # Shift the mask right by one so the token that crosses the threshold is
    # kept; position 0 (the most likely token) is never removed.
    sorted_to_remove = torch.cat(
        [torch.zeros_like(sorted_to_remove[..., :1]), sorted_to_remove[..., :-1]],
        dim=-1,
    )
    # Map the mask from sorted order back to vocabulary order in one
    # vectorised step (no per-batch Python loop, works for any leading dims).
    to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
    return logits.masked_fill(to_remove, float('-inf'))
# Function to generate text iteratively with top-p sampling and repetition penalty
def generate_text(session, tokenizer, initial_input, max_length=1000, top_p=0.9,
                  repetition_penalty=1.2, temperature=1.0, context_window=1024):
    """Autoregressively sample up to ``max_length`` new tokens from an ONNX LM.

    Each step runs ``session`` on the current token ids and takes the logits of
    the last position. It then applies the repetition penalty, temperature
    scaling and top-p filtering, samples one token per row, and appends it.

    Args:
        session: ONNX Runtime session taking a single ``"input_ids"`` input.
            Its first output is assumed to be logits shaped
            ``(batch, sequence, vocab)``, as in the export in this file.
        tokenizer: Tokenizer used to decode the result. Its ``eos_token_id``
            stops generation early.
        initial_input: Integer (int64) tensor of prompt token ids, shape
            ``(batch, sequence)``.
        max_length: Maximum number of new tokens to generate.
        top_p: Nucleus-sampling threshold passed to ``top_p_sampling``.
        repetition_penalty: CTRL-style penalty applied only to tokens already
            present in the sequence. Values > 1 discourage repeats; 1.0
            disables it. Positive scores are divided by it and negative scores
            multiplied by it.
        temperature: Divisor applied to all logits before sampling; values > 1
            flatten the distribution.
        context_window: Maximum number of most recent tokens fed to the model
            per step. The default, 1024, is GPT-2's maximum sequence length.
            ``None``/0 feeds the whole sequence.

    Returns:
        Decoded text for the first row of the batch: the prompt plus the
        generated tokens, including the EOS token if one was sampled. Returns
        ``""`` when ``max_length <= 0``.
    """
    if max_length <= 0:
        # Nothing is generated; keep the contract of returning an empty string.
        return ""

    eos_token_id = tokenizer.eos_token_id
    input_ids = initial_input
    for _ in range(max_length):
        # Feed at most `context_window` trailing tokens so long generations do
        # not exceed the positions the model was trained on.
        model_input = input_ids[:, -context_window:] if context_window else input_ids
        onnx_outputs = session.run(None, {"input_ids": model_input.contiguous().numpy()})
        # Only the scores of the last position decide the next token.
        logits = torch.tensor(onnx_outputs[0][:, -1, :])

        if repetition_penalty != 1.0:
            # Penalise only tokens already in the sequence: shrink positive
            # scores and push negative ones further down.
            seen_scores = logits.gather(-1, input_ids)
            seen_scores = torch.where(
                seen_scores < 0,
                seen_scores * repetition_penalty,
                seen_scores / repetition_penalty,
            )
            logits = logits.scatter(-1, input_ids, seen_scores)

        logits = logits / temperature
        filtered_logits = top_p_sampling(logits, top_p=top_p)
        probabilities = torch.nn.functional.softmax(filtered_logits, dim=-1)
        # multinomial on (batch, vocab) yields (batch, 1): exactly the shape
        # needed to append one token per row, for any batch size.
        next_token_id = torch.multinomial(probabilities, 1)
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

        # Stop once every row sampled the EOS token in this step.
        if eos_token_id is not None and bool((next_token_id == eos_token_id).all()):
            break

    # Decode once at the end rather than re-decoding the sequence every step.
    return tokenizer.decode(input_ids[0])
# Continue the tokenized prompt for up to 100 new tokens and show the result.
generated_text = generate_text(
    session=session,
    tokenizer=tokenizer,
    initial_input=dummy_input,
    max_length=100,
)
print("Generated text:", generated_text)
Screenshots:
data:image/s3,"s3://crabby-images/7a2c9/7a2c9e9a5b2a6bd751203737c3b32fe7ac65c609" alt=""
data:image/s3,"s3://crabby-images/48cbf/48cbff95d93e398f481578a14305a0ddb4fd0dde" alt=""