金光华网站建设,如何做外贸营销型网站推广,兰州画册设计公司,衡器行业网站建设模板有些地方看的不是透彻#xff0c;后续继续补充#xff01;
继续看张量量化函数#xff0c;代码位于#xff1a;tools\pytorch-quantization\pytorch_quantization\tensor_quant.py
ScaledQuantDescriptor
量化的支持描述符:描述张量应该如何量化。QuantDescriptor和张量…有些地方看的不是透彻后续继续补充
继续看张量量化函数代码位于tools\pytorch-quantization\pytorch_quantization\tensor_quant.py
ScaledQuantDescriptor
量化的支持描述符:描述张量应该如何量化。QuantDescriptor和张量定义了量化张量。
class ScaledQuantDescriptor():def __init__(self, num_bits8, nameNone, **kwargs):if not isinstance(num_bits, int):raise TypeError(num_bits must be an integer, not {}..format(type(num_bits)))if num_bits 0:raise ValueError(num_bits must be 0, not {}..format(num_bits))if num_bits 0:logging.error(num_bits is 0. This will result in the tensor being quantized to all zeros. This mode should only be used for debugging purposes.)self._num_bits num_bitsif not isinstance(name, str) and name is not None:raise TypeError(name must be a string or None, not {}..format(type(name)))self._name nameself._fake_quant kwargs.pop(fake_quant, True)self._axis kwargs.pop(axis, None)if self._axis is not None:logging.debug(Meaning of axis has changed since v2.0. Make sure to update.)self._learn_amax kwargs.pop(learn_amax, False)if self._learn_amax and self._axis is not None:raise TypeError(axis is ignored and must be None when learn_amax is true, got {}..format(type(self._axis)))amax kwargs.pop(amax, None)if amax is not None:if not isinstance(amax, float) and not isinstance(amax, list) and not isinstance(amax, np.ndarray):raise TypeError(amax must be float, list or ndarray, not {}.format(type(amax)))# Make it single precision arrayself._amax np.array(amax, dtypenp.float32)else:self._amax amaxself._scale_amax kwargs.pop(scale_amax, None)self._calib_method kwargs.pop(calib_method, max)self._unsigned kwargs.pop(unsigned, False)self._narrow_range kwargs.pop(narrow_range, False)if kwargs:raise TypeError(Unused keys: {}.format(kwargs.keys()))参数
num_bits:int,量化位数,用于计算比例因子。默认值8。name:看起来很不错
关键字参数 fake_quant布尔值。如果为True则使用fake量化模式。默认为True axisNone, int或整数的tuple轴将利用自己的最大值以计算缩放因子默认None。 如果None默认值则使用per tensor scale。 确保在范围[-rankinput_tensorrank输入_tensor内。 例如对于KCRS权重张量quant_axis0将产生per channel scaling。 amax用户指定的绝对最大范围的float或list/ndarray。如果提供忽略quant_axis并使用它进行量化。如果learn_amax为True将用于初始化可学习的amax。默认None learn_amaxboolean如果为True学习amax。默认为False。 scale_amaxfloat如果提供将amax乘以scale_amax默认无。 calib_methodstring[“max”“histogram”]中的一个校准要使用的指标。除了 max calibration其他都是基于hisogram的。默认值“max”。 unsignedBoolean如果为True则使用无符号。默认为False。
Raises:
TypeError如果传入了不支持的类型。
Read-only properties:
fake_quant:name:learn_amax:scale_amax:axis:calib_method:num_bits:amax:unsigned:
QuantDescriptor定义了张量应该如何量化。预定义的QuantDescriptor张量描述符如下
QuantDescriptor ScaledQuantDescriptor# Predefined descriptors
QUANT_DESC_8BIT_PER_TENSOR QuantDescriptor(num_bits8)
QUANT_DESC_UNSIGNED_8BIT_PER_TENSOR QuantDescriptor(num_bits8, unsignedTrue)
QUANT_DESC_8BIT_CONV1D_WEIGHT_PER_CHANNEL QuantDescriptor(num_bits8, axis(0))
QUANT_DESC_8BIT_CONV2D_WEIGHT_PER_CHANNEL QuantDescriptor(num_bits8, axis(0))
QUANT_DESC_8BIT_CONV3D_WEIGHT_PER_CHANNEL QuantDescriptor(num_bits8, axis(0))
QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW QuantDescriptor(num_bits8, axis(0))
QUANT_DESC_8BIT_CONVTRANSPOSE1D_WEIGHT_PER_CHANNEL QuantDescriptor(num_bits8, axis(0))
QUANT_DESC_8BIT_CONVTRANSPOSE2D_WEIGHT_PER_CHANNEL QuantDescriptor(num_bits8, axis(0))
QUANT_DESC_8BIT_CONVTRANSPOSE3D_WEIGHT_PER_CHANNEL QuantDescriptor(num_bits8, axis(0))如果在QuantDescriptor中给出最amaxTensorQuantizer将使用它进行量化。否则TensorQuantizer将计算amax然后进行量化。amax被计算通过指定的axis轴。注意QuantDescriptor将剩余轴指定与max()轴相反。
例子
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizerquant_desc QuantDescriptor(num_bits4, fake_quantFalse, axis(0), unsignedTrue)接下来看量化函数pytorch_quantization提供3个自定义的张量量化函数算子继承torch.autograd.function实现函数的前向传播、反向传播
TensorQuantFunction
通用的张量量化函数TensorQuantFunction
class TensorQuantFunction(Function):一个输入张量输出一个量化张量。scale的粒度可以从amax的形状来解释forward
在前向过程中对浮点权重和激活进行伪量化并使用这些伪量化的权重和激活来执行层的操作 staticmethoddef forward(ctx, inputs, amax, num_bits8, unsignedFalse, narrow_rangeTrue):ctx.save_for_backward(inputs, amax)outputs, scale _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)# Check if scale overflows FP16if outputs.dtype torch.half and scale.max() 65504:raise ValueError(scale is too large for FP16 with amax{}.format(amax))return outputs, scale.to(inputs.dtype)output_dtype指示量化值是以整数还是浮点形式存储。希望将其存储在浮点中的原因是pytorch函数接受量化值它可能不接受整数输入例如Conv2D。 它使用2num_bits−12^{num\_bits-1}2num_bits−1值例如对于num_bits8使用[-127127]
遵循tensorflow约定传入最大值并用于确定比例而不是直接输入比例。尽管直接输入比例可能更自然。
参数 ctx一个用于向后存储张量的Context对象。 inputsfloat32型张量。 amaxfloat32型张量。输入将在[-amaxamax]范围内量化amax将广播到inputs tensor。 num_bits用于计算缩放因子的整数scale(2num_bits−1−1)/maxscale(2^{num\_bits-1}-1)/maxscale(2num_bits−1−1)/max。默认值8。 output_dtype张量的一种类型。torch.int32或torch.float32。希望存储为floatpytorch函数接受float量化值它可能不接受整数输入。 unsignedboolean使用无符号整数范围。例如对于num_bits8[0255]。默认为False。 narrow_range布尔值。使用对称整数范围进行有符号量化 例如对于num_bits8用[-127127]代替[-128127]。默认为True。
Returns: outputsoutput_dtype类型的张量。 scalefloat32型张量。outputs / scale将对输出张量进行反量化。
Raises:
ValueError:
backward
通过clipping实现直通估计。对于-amaxinputamax梯度直接通过否则梯度为零。 参数
ctx一个上下文对象其中保存了来自forward的张量。grad_outputsoutputs梯度张量。grad_scalescale梯度张量。
Returns:
grad_inputs梯度张量。 staticmethoddef backward(ctx, grad_outputs, grad_scale):Implements straight through estimation with clipping. For -amax input amaxthe gradient passes straight through, otherwise the gradient is zero.Args:ctx: A Context object with saved tensors from forward.grad_outputs: A tensor of gradient of outputs.grad_scale: A tensor of gradient of scale.Returns:grad_inputs: A tensor of gradient.inputs, amax ctx.saved_tensorszero grad_outputs.new_zeros(1) # create a zero tensor with the same type and devicegrad_inputs torch.where(inputs.abs() amax, grad_outputs, zero)return grad_inputs, None, None, None, Nonetensor_quant TensorQuantFunction.apply给TensorQuantFunction.apply赋予一个别名tensor_quant这样可以直接调用tensor_quant进行量化例如
from pytorch_quantization import tensor_quant# Generate random input. With fixed seed 12345, x should be
# tensor([0.9817, 0.8796, 0.9921, 0.4611, 0.0832, 0.1784, 0.3674, 0.5676, 0.3376, 0.2119])
torch.manual_seed(12345)
x torch.rand(10)# quantize tensor x. quant_x will be
# tensor([126., 113., 127., 59., 11., 23., 47., 73., 43., 27.])
# with scale128.0057
quant_x, scale tensor_quant.tensor_quant(x, x.abs().max())FakeTensorQuantFunction
class FakeTensorQuantFunction(Function):Fake version of TensorQuantFunctionSee comments of TensorQuantFunction, arguments are the same.staticmethoddef forward(ctx, inputs, amax, num_bits8, unsignedFalse, narrow_rangeTrue):ctx.save_for_backward(inputs, amax)outputs, scale _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)return outputs / scale.to(inputs.dtype)staticmethoddef backward(ctx, grad_outputs):inputs, amax ctx.saved_tensorszero grad_outputs.new_zeros(1)grad_inputs torch.where(inputs.abs() amax, grad_outputs, zero)return grad_inputs, None, None, None, None在向后过程中使用权重的渐变来更新浮点权重。为了处理量化梯度除了未定义的点之外几乎所有地方都是零可以使用 直通估计器 STE 它通过伪量化操作符传递梯度。 fake_tensor_quant FakeTensorQuantFunction.apply给TensorQuantFunction.apply赋予一个别名fake_tensor_quant这样可以直接调用fake_tensor_quant进行量化例如
from pytorch_quantization import tensor_quant# Generate random input. With fixed seed 12345, x should be
# tensor([0.9817, 0.8796, 0.9921, 0.4611, 0.0832, 0.1784, 0.3674, 0.5676, 0.3376, 0.2119])
torch.manual_seed(12345)
x torch.rand(10)# fake quantize tensor x. fake_quant_x will be
# tensor([0.9843, 0.8828, 0.9921, 0.4609, 0.0859, 0.1797, 0.3672, 0.5703, 0.3359, 0.2109])
fake_quant_x tensor_quant.fake_tensor_quant(x, x.abs().max())_tensor_quant
def _tensor_quant(inputs, amax, num_bits8, unsignedFalse, narrow_rangeTrue):Shared function body between TensorQuantFunction and FakeTensorQuantFunction# Fine scale, per channel scale will be handled by broadcasting, which could be tricky. Pop a warning.if isinstance(amax, torch.Tensor) and inputs.dim() ! amax.dim():logging.debug(amax %s has different shape than inputs %s. Make sure broadcast works as expected!,amax.size(), inputs.size())logging.debug({} bits quantization on shape {} tensor..format(num_bits, inputs.size()))if unsigned:if inputs.min() 0.:raise TypeError(Negative values encountered in unsigned quantization.)# Computation must be in FP32 to prevent potential over flow.input_dtype inputs.dtypeif inputs.dtype torch.half:inputs inputs.float()if amax.dtype torch.half:amax amax.float()min_amax amax.min()if min_amax 0:raise ValueError(Negative values in amax)max_bound torch.tensor((2.0**(num_bits - 1 int(unsigned))) - 1.0, deviceamax.device)if unsigned:min_bound 0elif narrow_range:min_bound -max_boundelse:min_bound -max_bound - 1scale max_bound / amaxepsilon 1. / (124)if min_amax epsilon: # Treat amax smaller than minimum representable of fp16 0zero_amax_mask (amax epsilon)scale[zero_amax_mask] 0 # Value quantized with amax0 should all be 0outputs torch.clamp((inputs * scale).round_(), min_bound, max_bound)if min_amax epsilon:scale[zero_amax_mask] 1. # Return 1 makes more sense for values quantized to 0 with amax0if input_dtype torch.half:outputs outputs.half()return outputs, scale
待梳理 文章转载自: http://www.morning.kjnfs.cn.gov.cn.kjnfs.cn http://www.morning.qkrzn.cn.gov.cn.qkrzn.cn http://www.morning.nmlpp.cn.gov.cn.nmlpp.cn http://www.morning.tgts.cn.gov.cn.tgts.cn http://www.morning.rqmqr.cn.gov.cn.rqmqr.cn http://www.morning.dglszn.com.gov.cn.dglszn.com http://www.morning.ngmjn.cn.gov.cn.ngmjn.cn http://www.morning.wjqbr.cn.gov.cn.wjqbr.cn http://www.morning.wwklf.cn.gov.cn.wwklf.cn http://www.morning.msbct.cn.gov.cn.msbct.cn http://www.morning.pamdeer.com.gov.cn.pamdeer.com http://www.morning.kflbf.cn.gov.cn.kflbf.cn http://www.morning.qtnmp.cn.gov.cn.qtnmp.cn http://www.morning.tgtwy.cn.gov.cn.tgtwy.cn http://www.morning.kmqlf.cn.gov.cn.kmqlf.cn http://www.morning.ckwxs.cn.gov.cn.ckwxs.cn http://www.morning.tdqhs.cn.gov.cn.tdqhs.cn http://www.morning.kfwqd.cn.gov.cn.kfwqd.cn http://www.morning.touziyou.cn.gov.cn.touziyou.cn http://www.morning.jsdntd.com.gov.cn.jsdntd.com http://www.morning.pyzt.cn.gov.cn.pyzt.cn http://www.morning.mmjqk.cn.gov.cn.mmjqk.cn http://www.morning.rysmn.cn.gov.cn.rysmn.cn http://www.morning.cfynn.cn.gov.cn.cfynn.cn http://www.morning.ndpzm.cn.gov.cn.ndpzm.cn http://www.morning.spghj.cn.gov.cn.spghj.cn http://www.morning.nqxdg.cn.gov.cn.nqxdg.cn http://www.morning.dshkp.cn.gov.cn.dshkp.cn http://www.morning.sgrwd.cn.gov.cn.sgrwd.cn http://www.morning.gkxyy.cn.gov.cn.gkxyy.cn http://www.morning.txysr.cn.gov.cn.txysr.cn http://www.morning.ktlxk.cn.gov.cn.ktlxk.cn http://www.morning.dthyq.cn.gov.cn.dthyq.cn http://www.morning.youngbase.cn.gov.cn.youngbase.cn http://www.morning.fqtzn.cn.gov.cn.fqtzn.cn http://www.morning.wnkbf.cn.gov.cn.wnkbf.cn http://www.morning.rgpbk.cn.gov.cn.rgpbk.cn http://www.morning.tqpr.cn.gov.cn.tqpr.cn http://www.morning.gjssk.cn.gov.cn.gjssk.cn http://www.morning.wsrcy.cn.gov.cn.wsrcy.cn http://www.morning.drspc.cn.gov.cn.drspc.cn http://www.morning.jmmz.cn.gov.cn.jmmz.cn http://www.morning.ymtbr.cn.gov.cn.ymtbr.cn http://www.morning.duckgpt.cn.gov.cn.duckgpt.cn http://www.morning.wfyqn.cn.gov.cn.wfyqn.cn http://www.morning.csxlm.cn.gov.cn.csxlm.cn http://www.morning.bqwrn.cn.gov.cn.bqwrn.cn http://www.morning.glswq.cn.gov.cn.glswq.cn http://www.morning.rtlrz.cn.gov.cn.rtlrz.cn http://www.morning.rlzxr.cn.gov.cn.rlzxr.cn http://www.morning.ctbr.cn.gov.cn.ctbr.cn http://www.morning.rzpkt.cn.gov.cn.rzpkt.cn http://www.morning.srkzd.cn.gov.cn.srkzd.cn http://www.morning.xesrd.com.gov.cn.xesrd.com http://www.morning.gnlyq.cn.gov.cn.gnlyq.cn http://www.morning.wynqg.cn.gov.cn.wynqg.cn http://www.morning.jbmsp.cn.gov.cn.jbmsp.cn http://www.morning.psgbk.cn.gov.cn.psgbk.cn http://www.morning.qlbmc.cn.gov.cn.qlbmc.cn http://www.morning.lgsqy.cn.gov.cn.lgsqy.cn http://www.morning.dbbcq.cn.gov.cn.dbbcq.cn http://www.morning.tmjhy.cn.gov.cn.tmjhy.cn http://www.morning.bhrkx.cn.gov.cn.bhrkx.cn http://www.morning.dgsx.cn.gov.cn.dgsx.cn http://www.morning.plqqn.cn.gov.cn.plqqn.cn http://www.morning.jhgxh.cn.gov.cn.jhgxh.cn http://www.morning.jpydf.cn.gov.cn.jpydf.cn http://www.morning.bpmdz.cn.gov.cn.bpmdz.cn http://www.morning.huayaosteel.cn.gov.cn.huayaosteel.cn http://www.morning.qtqjx.cn.gov.cn.qtqjx.cn http://www.morning.dlgjdg.cn.gov.cn.dlgjdg.cn http://www.morning.khntd.cn.gov.cn.khntd.cn http://www.morning.kxyqy.cn.gov.cn.kxyqy.cn http://www.morning.rqfzp.cn.gov.cn.rqfzp.cn http://www.morning.xwqxz.cn.gov.cn.xwqxz.cn http://www.morning.xwnnp.cn.gov.cn.xwnnp.cn http://www.morning.errnull.com.gov.cn.errnull.com http://www.morning.ktnt.cn.gov.cn.ktnt.cn http://www.morning.brqjs.cn.gov.cn.brqjs.cn http://www.morning.mhcys.cn.gov.cn.mhcys.cn