Embedded targets and new hardware often need the bare model weights and the activations (the intermediate values produced at run time). This post shows one of the simplest ways to get both. Assuming you already have the model `model` and its .pt file, create a `weights` folder in the current directory and run the code below to dump the model's weights in both text and binary form:
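model.load_state_dict(state_dict)  # state_dict: the contents of the .pt file
global_index = 0
for name, param in model.named_parameters():
    print(name, param.size())
    values = param.detach().cpu().numpy()
    # Text form. NumPy abbreviates large arrays with "..." when printing.
    with open(f"weights/{global_index}-{name}.txt", "w") as f:
        print(values, file=f)
    # Binary form: raw values only, no header and no shape information.
    values.tofile(f"weights/{global_index}-{name}.bin")
    global_index += 1

For the binary files, you can inspect the floating-point values with `od -t f4 <binary file name>`. Here f4 means 4-byte floats, i.e. fp32.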
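As a cross-check on the `od -t f4` output, the same file can be read back with NumPy. The snippet below is only a minimal sketch: the array `weight` and the file name are hypothetical stand-ins for one dumped parameter, and the shape has to be supplied by hand because `tofile` writes no header.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for one parameter tensor, e.g. a 2x3 weight matrix.
weight = np.arange(6, dtype=np.float32).reshape(2, 3)

# Hypothetical file name, following the {index}-{name}.bin pattern.
path = os.path.join(tempfile.mkdtemp(), "0-example.weight.bin")
weight.tofile(path)  # raw fp32 values in native byte order, nothing else

# tofile() stores only the values, so dtype and shape must be given again.
flat = np.fromfile(path, dtype=np.float32)
restored = flat.reshape(2, 3)

print(flat)                   # the same values that od -t f4 lists
print(os.path.getsize(path))  # 6 values * 4 bytes each
```

Because the shape is lost, it helps to keep the printed `param.size()` output next to the .bin files.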
Dumping the intermediate values of the forward pass (the extra complexity here is necessary):
import torch

global_index = 0

def hook_fn(module, input, output):
    # Requires an existing activations/ folder in the current directory.
    global global_index
    module_name = str(module)
    module_name = module_name.replace(" ", "")
    module_name = module_name.replace("\n", "")
    # print(name)
    intermediate_outputs = {}
    # input is a tuple, output is a tensor
    for i, inp in enumerate(input):
        intermediate_outputs[f"{global_index}-{module_name}-input-{i}"] = inp
    intermediate_outputs[f"{global_index}-{module_name}-output"] = output
    module_name = module_name[0:200]  # make sure full path < 255
    print(intermediate_outputs)
    print("Size input:", end=" ")
    if isinstance(input, tuple):
        for i, inp in enumerate(input):
            if isinstance(inp, torch.Tensor):
                print(f"{i}-th Size: {inp.size()}", end=", ")
                # detach() is needed: numpy() fails on tensors that require grad
                inp.detach().cpu().numpy().tofile(f"activations/{global_index}-{module_name}-input-{i}.bin")
            else:
                print(f"{i}-th : {inp}", end=", ")
    elif isinstance(input, torch.Tensor):
        print(f"Size: {input.size()}")
        input.detach().cpu().numpy().tofile(f"activations/{global_index}-{module_name}-input.bin")
    print(f"Size output: {output.size()}")
    output.detach().cpu().numpy().tofile(f"activations/{global_index}-{module_name}-output.bin")
    # Increment last so the inputs and the output of one call share an index.
    global_index += 1

def register_hooks(model):
    for name, layer in model.named_children():
        # print(name, layer)  # dump all layers, layers.txt
        # Register the hook to the current layer
        layer.register_forward_hook(hook_fn)
        # Recursively apply the same to all submodules
        register_hooks(layer)

register_hooks(model)
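The hooks only fire when the model is actually run. The following is a minimal, self-contained sketch of that step, not the code above: the toy model, the input `x`, and the simplified hook `dump_output` are all hypothetical, and a temporary directory stands in for the `activations` folder.

```python
import os
import tempfile

import numpy as np
import torch
from torch import nn

# Hypothetical toy model and input, used only for illustration.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
model.eval()

out_dir = tempfile.mkdtemp()  # stands in for the activations/ folder
dumped = []                   # paths of the written files, in call order

def dump_output(module, inputs, output):
    # Simplified hook: writes only the output of each hooked layer.
    name = f"{len(dumped)}-{type(module).__name__}-output.bin"
    path = os.path.join(out_dir, name)
    # detach() drops the autograd graph, cpu() moves the data to host memory.
    output.detach().cpu().numpy().tofile(path)
    dumped.append(path)

# Hook the direct children only; keep the handles so hooks can be removed.
handles = [m.register_forward_hook(dump_output) for m in model.children()]

x = torch.ones(1, 4)
with torch.no_grad():  # inference only, no gradients needed
    y = model(x)

for h in handles:      # remove the hooks once the dump is done
    h.remove()

# The last dumped file holds the output of the final layer.
last = np.fromfile(dumped[-1], dtype=np.float32).reshape(tuple(y.shape))
print(dumped)
```

Removing the hooks afterwards keeps later forward passes from writing more files.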
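Here `register_hooks` is equivalent to the following version, which no longer needs recursion. One small difference: `named_modules()` also yields the model itself, so the top-level module gets a hook too.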
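def register_hooks(model):
    for name, layer in model.named_modules():
        # print(name, layer)  # dump all layers
        layer.register_forward_hook(hook_fn)

Note that nn.Sequential is handled as a whole here; for now there is no way to split it up and look at the intermediate values inside it.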
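To see how the recursive `named_children()` walk relates to the flat `named_modules()` loop, the sketch below lists the module names each one visits. The toy model and the helper `collect_recursive` are hypothetical and exist only for this comparison.

```python
import torch
from torch import nn

# Hypothetical toy model, used only for illustration.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

def collect_recursive(module, prefix=""):
    # Mirrors the recursive register_hooks: visit children, then descend.
    names = []
    for name, child in module.named_children():
        full = f"{prefix}.{name}" if prefix else name
        names.append(full)
        names.extend(collect_recursive(child, full))
    return names

recursive_names = collect_recursive(model)
flat_names = [name for name, _ in model.named_modules()]

print(recursive_names)
print(flat_names)  # starts with "" for the model itself
```

The only difference is the leading empty name, which is the root module. If a hook on the whole model is not wanted, that entry can be skipped in the loop.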