class aimet_tensorflow.quantsim.QuantizationSimModel(session: tf.compat.v1.Session, starting_op_names: List[str], output_op_names: List[str], quant_scheme: Union[str, QuantScheme] = 'tf_enhanced', rounding_mode: str = 'nearest', default_output_bw: int = 8, default_param_bw: int = 8, use_cuda: bool = True, config_file: str = None, default_data_type: QuantizationDataType = QuantizationDataType.int)

QuantizationSimModel class methods
export(self, path: str, filename_prefix: str, orig_sess: tf.compat.v1.Session = None)
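A minimal end-to-end sketch of this API (hypothetical values throughout: the trained session sess, the op names, the output path, and the filename prefix are taken from the MNIST example below and should be replaced with your own):

# Hypothetical usage sketch: build the simulation, compute encodings, then export.
sim = quantsim.QuantizationSimModel(sess,
                                    starting_op_names=['reshape_input'],
                                    output_op_names=['dense_1/BiasAdd'],
                                    quant_scheme=QuantScheme.post_training_tf_enhanced,
                                    default_output_bw=8,  # activation bitwidth
                                    default_param_bw=8)   # parameter bitwidth
sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=None)
# export() saves a copy of the model (without the simulation ops) together with
# an encodings file under `path`, named using `filename_prefix`
sim.export(path='./output', filename_prefix='quantized_mnist')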
Imports
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Import the tensorflow quantsim
from aimet_tensorflow import quantsim
from aimet_tensorflow.common import graph_eval
# Assumed module path for the MnistParser/TfRecordGenerator helpers used below
from aimet_tensorflow.common import tfrecord_generator as tf_gen
from aimet_tensorflow.utils import graph_saver
from aimet_common.defs import QuantScheme
Passing calibration data
def pass_calibration_data(session: tf.compat.v1.Session):
    """
    The user of the QuantizationSimModel API is expected to write this function based on their data set.
    This is not a working function and is provided only as a guideline.

    :param session: Model's session
    :return:
    """

    # User action required
    # The following line is an example of how to use the ImageNet validation data loader.
    # Replace it with your own dataset's validation data loader.
    data_loader = None  # Your dataset's data loader

    # User action required
    # For computing the activation encodings, around 1000 unlabelled data samples are required.
    # Edit the following 2 lines based on your data loader's batch size.
    # batch_size * max_batch_counter should be 1024
    batch_size = 64
    max_batch_counter = 16

    input_tensor = None  # input tensor in session
    train_tensor = None  # train tensor in session

    current_batch_counter = 0
    for input_batch, _ in data_loader:
        feed_dict = {input_tensor: input_batch,
                     train_tensor: False}

        # User action required
        # Replace the empty fetch list with your model's output op(s) so that
        # the inserted quantization ops observe the activations.
        session.run([], feed_dict=feed_dict)

        current_batch_counter += 1
        if current_batch_counter == max_batch_counter:
            break
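For concreteness, a hypothetical version of this callback for the MNIST graph used in the examples below (it assumes the TF 1.x MNIST tutorial data helper and the 'reshape_input'/'dense_1/BiasAdd' op names; adapt it to your own pipeline):

def pass_calibration_data_mnist(session: tf.compat.v1.Session):
    # Hypothetical concrete callback; all names below come from the MNIST example.
    mnist = input_data.read_data_sets('./data', one_hot=True)
    input_tensor = session.graph.get_tensor_by_name('reshape_input:0')
    output_tensor = session.graph.get_tensor_by_name('dense_1/BiasAdd:0')

    batch_size = 64
    max_batch_counter = 16  # 64 * 16 = 1024 calibration samples

    for _ in range(max_batch_counter):
        images, _ = mnist.train.next_batch(batch_size)
        # Fetching the output tensor makes the quantization ops observe activations
        session.run(output_tensor, feed_dict={input_tensor: images})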
Post-training quantization followed by fine-tuning (i.e. QAT)
def quantize_model():
    """
    Create the Quantization Simulation and finetune the model.
    :return:
    """
    tf.compat.v1.reset_default_graph()

    # load graph
    sess = graph_saver.load_model_from_meta('models/mnist_save.meta', 'models/mnist_save')

    # Create quantsim model to quantize the network using the default 8 bit params/activations
    sim = quantsim.QuantizationSimModel(sess, starting_op_names=['reshape_input'], output_op_names=['dense_1/BiasAdd'],
                                        quant_scheme=QuantScheme.post_training_tf_enhanced,
                                        config_file='../../../TrainingExtensions/common/src/python/aimet_common/'
                                                    'quantsim_config/default_config.json')

    # Compute encodings
    sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=None)

    # Do some finetuning
    # User action required
    # The following line of code illustrates that the model is getting finetuned.
    # Replace the following train() function with your pipeline's train() function.
    train(sim)
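    # Illustrative addition, not part of the original example: after fine-tuning,
    # export() can save the fine-tuned model together with its encodings.
    # The output path and filename prefix below are placeholder values.
    sim.export(path='./output', filename_prefix='mnist_after_qat')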
Quantize and fine-tune the trained model so the encodings are learned (i.e. range learning)
def quantization_aware_training_range_learning():
    """
    Running Quantize Range Learning Test
    """
    tf.compat.v1.reset_default_graph()

    # Allocate the generator you wish to use to provide the network with data
    parser2 = tf_gen.MnistParser(batch_size=100, data_inputs=['reshape_input'])
    generator = tf_gen.TfRecordGenerator(tfrecords=[os.path.join('data', 'mnist', 'validation.tfrecords')],
                                         parser=parser2)

    sess = graph_saver.load_model_from_meta('models/mnist_save.meta', 'models/mnist_save')

    # Create quantsim model to quantize the network using the default 8 bit params/activations
    # quant scheme set to range learning
    sim = quantsim.QuantizationSimModel(sess, ['reshape_input'], ['dense_1/BiasAdd'],
                                        quant_scheme=QuantScheme.training_range_learning_with_tf_init)

    # Initialize the model with encodings
    sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=None)

    # Train the model to fine-tune the encodings
    g = sim.session.graph
    sess = sim.session

    with g.as_default():

        parser2 = tf_gen.MnistParser(batch_size=100, data_inputs=['reshape_input'])
        generator2 = tf_gen.TfRecordGenerator(tfrecords=['data/mnist/validation.tfrecords'], parser=parser2)
        cross_entropy = g.get_operation_by_name('xent')
        train_step = g.get_operation_by_name("Adam")

        # do training: learn weights and quantization encodings simultaneously
        x = sim.session.graph.get_tensor_by_name("reshape_input:0")
        y = g.get_tensor_by_name("labels:0")
        fc1_w = g.get_tensor_by_name("dense_1/MatMul/ReadVariableOp:0")

        perf = graph_eval.evaluate_graph(sess, generator2, ['accuracy'], graph_eval.default_eval_func, 1)
        print('Quantized performance: ' + str(perf * 100))

        ce = g.get_tensor_by_name("xent:0")
        train_step = tf.compat.v1.train.AdamOptimizer(1e-3, name="TempAdam").minimize(ce)
        graph_eval.initialize_uninitialized_vars(sess)
        mnist = input_data.read_data_sets('./data', one_hot=True)

        for i in range(100):
            batch = mnist.train.next_batch(50)
            sess.run([train_step, fc1_w], feed_dict={x: batch[0], y: batch[1]})
            if i % 10 == 0:
                perf = graph_eval.evaluate_graph(sess, generator2, ['accuracy'], graph_eval.default_eval_func, 1)
                print('Quantized performance: ' + str(perf * 100))

    # close session
    sess.close()