shtaxxx日記

コンピュータアーキテクチャについて研究している研究者の日記や技術紹介

KLダイバージェンスを用いた異なる量子化パラメータの比較

NVIDIAのINT8量子化では、 KL Divergenceを使って活性値の分布をより正確に近似できる量子化パラメータ(最大値と最小値)を探索しているらしいので、それを実装してみた。

量子化前後でヒストグラムのビンの幅が異なるので、そのままではKLダイバージェンスの計算ができない。そこで、量子化後のヒストグラム量子化前のビンの幅を使ったヒストグラムに変換して、KLダイバージェンスを計算することにした。

以下の例だと、q1はq2はそれぞれ(-20.0, 20.0)と(-7.0, 7.0)という異なる最小値・最大値でクリップするようにしており、16個のビンでヒストグラムを作る(=4ビット量子化)場合にどちらが良いかをKL-divで判定する。 この場合、q1よりもq2の方がKL-divが小さく、元の分布を正確に表現できることがわかる。

from __future__ import absolute_import
from __future__ import print_function

import matplotlib.pyplot as plt
import numpy as np

N = 1000 * 1000
loc = 0
scale = 2
epsilon = 0.00001

ref_num_bins = 1024
q_num_bins = 16

dist = np.random.normal(loc, scale, N)
q1_dist = np.clip(dist, -20.0, 20.0)
q2_dist = np.clip(dist, -7.0, 7.0)

ref_hist, ref_bins = np.histogram(dist, bins=ref_num_bins, density=True)
q1_hist, q1_bins = np.histogram(q1_dist, bins=q_num_bins, density=True)
q2_hist, q2_bins = np.histogram(q2_dist, bins=q_num_bins, density=True)


def to_hist_with_orig_bins(targ_hist, targ_bins, orig_hist, orig_bins):
    targ_v = 0.0
    targ_i = 0
    targ_bin = targ_bins[0]
    ret_hist = np.zeros_like(orig_hist)

    for i, orig_bin in enumerate(orig_bins[:-1]):
        if targ_bin <= orig_bin:
            if targ_i < len(targ_bins) - 1:
                targ_v = targ_hist[targ_i]
                targ_i += 1
                targ_bin = targ_bins[targ_i]
            else:
                targ_v = 0.0
                targ_bin = orig_bin.max() + 1.0

        ret_hist[i] = targ_v

    return ret_hist


c_q1_hist = to_hist_with_orig_bins(q1_hist, q1_bins, ref_hist, ref_bins)
c_q2_hist = to_hist_with_orig_bins(q2_hist, q2_bins, ref_hist, ref_bins)

pad_ref_bins = np.pad(ref_bins, [1, 0], 'constant')
sumd = np.sum((ref_bins - pad_ref_bins[:-1])[1:])


ref_hist = (ref_hist + epsilon) / (1.0 + epsilon * sumd)
c_q1_hist = (c_q1_hist + epsilon) / (1.0 + epsilon * sumd)
c_q2_hist = (c_q2_hist + epsilon) / (1.0 + epsilon * sumd)

kl_ref = np.sum(ref_hist * np.log(ref_hist / ref_hist))
kl_c_q1 = np.sum(ref_hist * np.log(ref_hist / c_q1_hist))
kl_c_q2 = np.sum(ref_hist * np.log(ref_hist / c_q2_hist))


def to_labels(bins):
    labels = []
    for i in range(len(bins) - 1):
        labels.append((bins[i] + bins[i + 1]) / 2)
    return labels


ref_labels = to_labels(ref_bins)
q1_labels = to_labels(q1_bins)
q2_labels = to_labels(q2_bins)

plt.figure(figsize=(10, 5))

#plt.bar(ref_labels, ref_hist, label='ref')

plt.plot(ref_labels, ref_hist, label='ref')
plt.plot(q1_labels, q1_hist, label='q1')
plt.plot(q2_labels, q2_hist, label='q2')

plt.plot(ref_labels, c_q1_hist, label='q1 KL=%f' % kl_c_q1)
plt.plot(ref_labels, c_q2_hist, label='q2 KL=%f' % kl_c_q2)

plt.legend(title='histogram', loc='best')
plt.grid()

# plt.show()
plt.savefig('out.png')

https://gist.github.com/shtaxxx/6ca20df2cb7933291fdb9cb02ccf2088

f:id:sxhxtxa:20190422121826p:plain

参考

https://paper.hatenadiary.jp/entry/2018/10/13/164539