Skip to content

斑马化模型 PyTorch模型导出ONNX以及ONNX推理

将一个把马变成斑马的网络导出为ONNX格式,并使用ONNX格式进行推理

GitHub: Horse-to-Zebra

1. 模型导出

import torch.onnx  

from Zebraizer import ResNetGenerator  


def Convert_ONNX(model, dummy_input=None, output_path=None, opset_version=10):
    """Export a PyTorch model to an ONNX file.

    The model is switched to eval mode (and left in eval mode) and exported
    with its trained weights embedded. Axes 0, 2 and 3 of both the input and
    the output are declared dynamic and named ``batch_size``, ``height`` and
    ``width``, so the exported graph is not tied to the dummy input's size.

    Args:
        model: The module to export.
        dummy_input: Example input used to run the model during export (or a
            tuple for multiple inputs). Required despite the ``None`` default.
        output_path: Where to save the ONNX file. Required despite the
            ``None`` default.
        opset_version: ONNX opset version to export to. Defaults to 10, the
            value that was previously hard-coded.

    Raises:
        ValueError: If ``dummy_input`` or ``output_path`` is not provided.
    """
    # Fail early with a clear message instead of an obscure error raised
    # from inside the exporter.
    if dummy_input is None:
        raise ValueError("dummy_input is required to export the model")
    if output_path is None:
        raise ValueError("output_path is required to export the model")

    # Axes that may vary at inference time (shared by input and output).
    variable_axes = {0: 'batch_size', 2: 'height', 3: 'width'}

    model.eval()
    torch.onnx.export(
        model,  # model being run
        dummy_input,  # model input (or a tuple for multiple inputs)
        f=output_path,  # where to save the model
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=opset_version,  # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=['modelInput'],  # the model's input names
        output_names=['modelOutput'],  # the model's output names
        dynamic_axes={
            'modelInput': dict(variable_axes),
            'modelOutput': dict(variable_axes),
        },
    )
    print(" ")
    print('Model has been converted to ONNX')


if __name__ == "__main__":
    # Raw strings: these Windows paths contain backslash sequences such as
    # "\w", "\d" and "\p", which are invalid escapes in a normal literal and
    # trigger a SyntaxWarning on recent Python versions. The resulting path
    # values are unchanged.
    model_path = r"D:\workspace\dlwpt-code\data\p1ch2\horse2zebra_0.4.0.pth"  # pretrained weights
    output_path = r"D:\workspace\dlwpt-code\zevraizer\horse2zebra_0.4.0.onnx"

    netG = ResNetGenerator()
    # Load the weights onto the CPU so the export also works on a machine
    # without the device the checkpoint was saved from.
    model_data = torch.load(model_path, map_location="cpu")
    netG.load_state_dict(model_data)

    # The dummy input only drives the export; batch, height and width are
    # declared dynamic by Convert_ONNX.
    dummy_input = torch.randn([1, 3, 250, 250], requires_grad=False)
    Convert_ONNX(model=netG, dummy_input=dummy_input, output_path=output_path)

2. ONNX 推理

这里的一些维度转换,使用 torch 和 numpy 之间的互操作会更方便,但这里还是尽可能地使用 numpy 来实现

import numpy as np  
import onnxruntime  
from PIL import Image  
from torchvision import transforms  

onxx_path = r"D:\workspace\dlwpt-code\zevraizer\horse2zebra_0.4.0.onnx"

# Load the ONNX model.
# Providers are listed in order of preference, but only those offered by the
# installed onnxruntime build are requested, so a CPU-only install does not
# ask for CUDA.
preferred_providers = [
    'CUDAExecutionProvider',  # requires cuDNN
    'CPUExecutionProvider',
]
available_providers = onnxruntime.get_available_providers()
ort_session = onnxruntime.InferenceSession(
    onxx_path,
    providers=[p for p in preferred_providers if p in available_providers],
)

# Preprocessing pipeline for the input image.
preprocess = transforms.Compose([
    transforms.Resize([550, 550]),  # optional: the model's input height/width are dynamic
])

# Force 3-channel RGB so grayscale / RGBA / palette images still yield an
# H,W,3 array for the transpose below.
img = Image.open("data/p1ch2/horse.jpg").convert("RGB")

img_t = preprocess(img)
# Work with a numpy array from here on.
img_t = np.array(img_t).astype(np.float32)  # H,W,C
img_t = img_t.transpose(2, 0, 1)  # H,W,C -> C,H,W
batch_t = np.expand_dims(img_t, axis=0)  # B,C,H,W
batch_t = batch_t / 255.0  # scale pixel values to [0,1]

# Run the model on the input image.
ort_inputs = {ort_session.get_inputs()[0].name: batch_t}

ort_outs = ort_session.run(None, ort_inputs)

# Postprocess the output.
img_out_t2 = ort_outs[0][0]  # [(B,C,H,W),....]  -> C,H,W
img_out_t = np.transpose(img_out_t2, (1, 2, 0))  # C,H,W -> H,W,C
out_t = (img_out_t.squeeze() + 1.0) / 2.0  # [-1,1] -> [0,1] (assumes the model outputs values in [-1,1])

# Convert to 8-bit explicitly: clip guards against values slightly outside
# [0,1] wrapping around in uint8, and building the image directly avoids
# relying on how ToPILImage treats float ndarrays, which differs between
# torchvision versions.
out_u8 = (np.clip(out_t, 0.0, 1.0) * 255.0).round().astype(np.uint8)

# Display the result.
out_img = Image.fromarray(out_u8)
out_img.show()