Zebraizer: Exporting a PyTorch Model to ONNX and Running ONNX Inference
Export a network that turns horses into zebras to the ONNX format, then run inference with the ONNX model.
GitHub: Horse-to-Zebra
1. Model Export
import torch
import torch.onnx

from Zebraizer import ResNetGenerator


def Convert_ONNX(model, dummy_input=None, output_path=None):
    model.eval()
    torch.onnx.export(
        model,                         # model being run
        dummy_input,                   # model input (or a tuple for multiple inputs)
        f=output_path,                 # where to save the model
        export_params=True,            # store the trained parameter weights inside the model file
        opset_version=10,              # the ONNX opset version to export the model to
        do_constant_folding=True,      # whether to execute constant folding for optimization
        input_names=['modelInput'],    # the model's input names
        output_names=['modelOutput'],  # the model's output names
        dynamic_axes={                 # variable-length axes
            'modelInput': {0: 'batch_size', 2: 'height', 3: 'width'},
            'modelOutput': {0: 'batch_size', 2: 'height', 3: 'width'}
        }
    )
    print('Model has been converted to ONNX')


if __name__ == "__main__":
    netG = ResNetGenerator()
    model_path = r"D:\workspace\dlwpt-code\data\p1ch2\horse2zebra_0.4.0.pth"    # trained weights
    output_path = r"D:\workspace\dlwpt-code\zevraizer\horse2zebra_0.4.0.onnx"
    model_data = torch.load(model_path)
    netG.load_state_dict(model_data)
    dummy_input = torch.randn(1, 3, 250, 250)   # B,C,H,W; H/W are arbitrary thanks to dynamic_axes
    Convert_ONNX(model=netG, dummy_input=dummy_input, output_path=output_path)
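Before relying on the exported file, it is worth a quick sanity check. The sketch below is an addition (assuming the onnx package is installed; netG and output_path carry over from the script above): it validates the graph structure, then compares ONNX Runtime's output against the PyTorch model on an input whose resolution differs from the 250x250 dummy, which also exercises the dynamic height/width axes.

import numpy as np
import onnx
import onnxruntime

onnx.checker.check_model(onnx.load(output_path))   # structural validation

x = torch.randn(1, 3, 300, 400)                    # deliberately not 250x250
with torch.no_grad():
    torch_out = netG(x)

ort_session = onnxruntime.InferenceSession(output_path, providers=['CPUExecutionProvider'])
ort_out = ort_session.run(None, {'modelInput': x.numpy()})[0]

np.testing.assert_allclose(torch_out.numpy(), ort_out, rtol=1e-3, atol=1e-5)
print('ONNX Runtime output matches PyTorch')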
2. ONNX Inference
For the dimension conversions here, moving back and forth between torch tensors and numpy arrays would be more convenient, but the code below sticks to numpy wherever possible (a torch-based variant is sketched at the end of this section).
import numpy as np
import onnxruntime
from PIL import Image
from torchvision import transforms
onnx_path = r"D:\workspace\dlwpt-code\zevraizer\horse2zebra_0.4.0.onnx"

# Load the ONNX model
ort_session = onnxruntime.InferenceSession(
    onnx_path,
    providers=[
        'CUDAExecutionProvider',   # requires CUDA/cuDNN; onnxruntime falls back to the next provider
        'CPUExecutionProvider'
    ]
)
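# (Optional addition) Inspect the session's I/O metadata: with dynamic_axes the
# input shape is reported symbolically, e.g. ['batch_size', 3, 'height', 'width'].
for node in ort_session.get_inputs() + ort_session.get_outputs():
    print(node.name, node.shape, node.type)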
# Preprocess the input image
preprocess = transforms.Compose([
    transforms.Resize([550, 550]),   # optional: the model's input height/width are dynamic
    # transforms.ToTensor()          # skipped; the conversion is done in numpy below
])
img = Image.open("data/p1ch2/horse.jpg")
img_t = preprocess(img)

# Convert with numpy
img_t = np.array(img_t).astype(np.float32)   # H,W,C
img_t = img_t.transpose(2, 0, 1)             # H,W,C -> C,H,W
batch_t = np.expand_dims(img_t, axis=0)      # B,C,H,W
batch_t = batch_t / 255.0                    # scale uint8 pixel values to [0, 1]
# Run the model on the input image
ort_inputs = {ort_session.get_inputs()[0].name: batch_t}
ort_outs = ort_session.run(None, ort_inputs)

# Postprocess the output
img_out_t = ort_outs[0][0]                       # [(B,C,H,W), ...] -> C,H,W
img_out_t = np.transpose(img_out_t, (1, 2, 0))   # C,H,W -> H,W,C
out_t = (img_out_t + 1.0) / 2.0                  # generator output: [-1,1] -> [0,1]

# Display the result; PIL needs uint8, and ToPILImage does not accept a float
# H,W,C numpy array, so convert by hand
out_img = Image.fromarray((out_t * 255.0).clip(0, 255).astype(np.uint8))
out_img.show()
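As mentioned at the top of this section, the same pre/post-processing is shorter with torch/numpy interop. A minimal sketch of that variant (reusing img and ort_session from above): ToTensor performs the HWC-to-CHW transpose and the [0, 1] scaling in one step, and ToPILImage accepts a CHW float tensor directly.

import torch

img_t = transforms.ToTensor()(img)    # PIL H,W,C uint8 -> C,H,W float32 in [0, 1]
batch_t = img_t.unsqueeze(0)          # add batch dim: B,C,H,W
ort_outs = ort_session.run(None, {ort_session.get_inputs()[0].name: batch_t.numpy()})
out_t = (torch.from_numpy(ort_outs[0][0]) + 1.0) / 2.0   # C,H,W, [-1,1] -> [0,1]
transforms.ToPILImage()(out_t).show()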