Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
4aa5e0d91e | ||
30bf6c9147 |
@ -15,6 +15,7 @@ from PIL import Image
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
|
from torch.profiler import profile, record_function, ProfilerActivity
|
||||||
|
|
||||||
def load_image(imfile):
|
def load_image(imfile):
|
||||||
img = np.array(Image.open(imfile)).astype(np.uint8)
|
img = np.array(Image.open(imfile)).astype(np.uint8)
|
||||||
@ -44,8 +45,14 @@ def demo(args):
|
|||||||
|
|
||||||
padder = InputPadder(image1.shape, divis_by=32)
|
padder = InputPadder(image1.shape, divis_by=32)
|
||||||
image1, image2 = padder.pad(image1, image2)
|
image1, image2 = padder.pad(image1, image2)
|
||||||
|
with profile(
|
||||||
|
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||||
|
with_stack=True,
|
||||||
|
) as prof:
|
||||||
disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
|
disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
|
||||||
|
print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=10))
|
||||||
|
prof.export_chrome_trace("./trace.json")
|
||||||
|
# disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
|
||||||
disp = disp.cpu().numpy()
|
disp = disp.cpu().numpy()
|
||||||
disp = padder.unpad(disp)
|
disp = padder.unpad(disp)
|
||||||
file_stem = imfile1.split('/')[-2]
|
file_stem = imfile1.split('/')[-2]
|
||||||
@ -81,6 +88,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
|
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
|
||||||
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
|
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
|
||||||
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
||||||
|
parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user