KLダイバージェンスを用いた異なる量子化パラメータの比較
NVIDIAのINT8量子化では、 KL Divergenceを使って活性値の分布をより正確に近似できる量子化パラメータ(最大値と最小値)を探索しているらしいので、それを実装してみた。
量子化前後でヒストグラムのビンの幅が異なるので、そのままではKLダイバージェンスの計算ができない。そこで、量子化後のヒストグラムを量子化前のビンの幅を使ったヒストグラムに変換して、KLダイバージェンスを計算することにした。
以下の例だと、q1はq2はそれぞれ(-20.0, 20.0)と(-7.0, 7.0)という異なる最小値・最大値でクリップするようにしており、16個のビンでヒストグラムを作る(=4ビット量子化)場合にどちらが良いかをKL-divで判定する。 この場合、q1よりもq2の方がKL-divが小さく、元の分布を正確に表現できることがわかる。
from __future__ import absolute_import from __future__ import print_function import matplotlib.pyplot as plt import numpy as np N = 1000 * 1000 loc = 0 scale = 2 epsilon = 0.00001 ref_num_bins = 1024 q_num_bins = 16 dist = np.random.normal(loc, scale, N) q1_dist = np.clip(dist, -20.0, 20.0) q2_dist = np.clip(dist, -7.0, 7.0) ref_hist, ref_bins = np.histogram(dist, bins=ref_num_bins, density=True) q1_hist, q1_bins = np.histogram(q1_dist, bins=q_num_bins, density=True) q2_hist, q2_bins = np.histogram(q2_dist, bins=q_num_bins, density=True) def to_hist_with_orig_bins(targ_hist, targ_bins, orig_hist, orig_bins): targ_v = 0.0 targ_i = 0 targ_bin = targ_bins[0] ret_hist = np.zeros_like(orig_hist) for i, orig_bin in enumerate(orig_bins[:-1]): if targ_bin <= orig_bin: if targ_i < len(targ_bins) - 1: targ_v = targ_hist[targ_i] targ_i += 1 targ_bin = targ_bins[targ_i] else: targ_v = 0.0 targ_bin = orig_bin.max() + 1.0 ret_hist[i] = targ_v return ret_hist c_q1_hist = to_hist_with_orig_bins(q1_hist, q1_bins, ref_hist, ref_bins) c_q2_hist = to_hist_with_orig_bins(q2_hist, q2_bins, ref_hist, ref_bins) pad_ref_bins = np.pad(ref_bins, [1, 0], 'constant') sumd = np.sum((ref_bins - pad_ref_bins[:-1])[1:]) ref_hist = (ref_hist + epsilon) / (1.0 + epsilon * sumd) c_q1_hist = (c_q1_hist + epsilon) / (1.0 + epsilon * sumd) c_q2_hist = (c_q2_hist + epsilon) / (1.0 + epsilon * sumd) kl_ref = np.sum(ref_hist * np.log(ref_hist / ref_hist)) kl_c_q1 = np.sum(ref_hist * np.log(ref_hist / c_q1_hist)) kl_c_q2 = np.sum(ref_hist * np.log(ref_hist / c_q2_hist)) def to_labels(bins): labels = [] for i in range(len(bins) - 1): labels.append((bins[i] + bins[i + 1]) / 2) return labels ref_labels = to_labels(ref_bins) q1_labels = to_labels(q1_bins) q2_labels = to_labels(q2_bins) plt.figure(figsize=(10, 5)) #plt.bar(ref_labels, ref_hist, label='ref') plt.plot(ref_labels, ref_hist, label='ref') plt.plot(q1_labels, q1_hist, label='q1') plt.plot(q2_labels, q2_hist, label='q2') plt.plot(ref_labels, c_q1_hist, label='q1 KL=%f' % kl_c_q1) plt.plot(ref_labels, c_q2_hist, label='q2 KL=%f' % kl_c_q2) plt.legend(title='histogram', loc='best') plt.grid() # plt.show() plt.savefig('out.png')
https://gist.github.com/shtaxxx/6ca20df2cb7933291fdb9cb02ccf2088