Update demo_imgs.py

This commit is contained in:
HTensor 2023-04-29 23:18:22 +08:00
parent a1cc25351d
commit 30bf6c9147

View File

@ -15,6 +15,7 @@ from PIL import Image
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import os import os
import cv2 import cv2
from torch.profiler import profile, record_function, ProfilerActivity
def load_image(imfile): def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8) img = np.array(Image.open(imfile)).astype(np.uint8)
@ -44,8 +45,14 @@ def demo(args):
padder = InputPadder(image1.shape, divis_by=32) padder = InputPadder(image1.shape, divis_by=32)
image1, image2 = padder.pad(image1, image2) image1, image2 = padder.pad(image1, image2)
with profile(
disp = model(image1, image2, iters=args.valid_iters, test_mode=True) activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
with_stack=True,
) as prof:
disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=10))
prof.export_chrome_trace("./trace.json")
# disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
disp = disp.cpu().numpy() disp = disp.cpu().numpy()
disp = padder.unpad(disp) disp = padder.unpad(disp)
file_stem = imfile1.split('/')[-2] file_stem = imfile1.split('/')[-2]