test.py

Demo inference script: runs SpadeModel on every reference/pose-guide image pair, then fuses each generated face with its reference appearance and writes the results to output/.
import time
import argparse

import cv2
import numpy as np
import torch
from tqdm import tqdm

from model.spade_model import SpadeModel
from opt.configTrain import TrainOptions
from loader.dataset_loader_demo import DatasetLoaderDemo
from fusion.affineFace import *  # provides affineface_parts, fusion, lightEye

parser = argparse.ArgumentParser()
parser.add_argument('--pose_path', type=str, default='data/poseGuide/imgs', help='path to pose guide images')
parser.add_argument('--ref_path', type=str, default='data/reference/imgs', help='path to appearance/reference images')
parser.add_argument('--pose_lms', type=str, default='data/poseGuide/lms_poseGuide.out', help='path to pose guide landmark file')
parser.add_argument('--ref_lms', type=str, default='data/reference/lms_ref.out', help='path to reference landmark file')
args = parser.parse_args()

if __name__ == '__main__':
    trainConfig = TrainOptions()
    opt = trainConfig.get_config()  # namespace of arguments

    # init test dataset: pose guides define the target boundaries,
    # reference images define the appearance
    dataset = DatasetLoaderDemo(gaze=(opt.input_nc == 9), imgSize=256)
    root = args.pose_path
    path_Appears = args.pose_lms.format(root)  # pose guide landmark file
    dataset.loadBounds([path_Appears], head='{}/'.format(root))
    root = args.ref_path
    path_Appears = args.ref_lms.format(root)  # reference landmark file
    dataset.loadAppears([path_Appears], '{}/'.format(root))
    dataset.setAppearRule('sequence')

    # dataloader
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=12,
                                              drop_last=False)
    print('dataset size: {}\n'.format(dataset.shape()))

    # pair every reference with every pose guide, so the output sequence is
    # ref1-pose1, ref1-pose2, ref1-pose3, ..., ref2-pose1, ref2-pose2, ...
    boundNew = []
    appNew = []
    for aa in dataset.appearList:
        for bb in dataset.boundList:
            boundNew.append(bb)
            appNew.append(aa)
    dataset.boundList = boundNew
    dataset.appearList = appNew

    model = SpadeModel(opt)  # define model
    model.setup(opt)  # initialize schedulers (if isTrain), load pretrained models
    model.set_logger(opt)  # set writer to runs/test_res
    model.eval()

    iter_start_time = time.time()
    cnt = 1
    with torch.no_grad():
        for step, data in tqdm(enumerate(data_loader)):
            model.set_input(data)  # move the batch to the model's device
            model.forward()
            # fusionNet: fuse the warped reference parts with each generated face
            for i in range(data['img_src'].shape[0]):
                img_gen = model.fake_B.cpu().numpy()[i].transpose(1, 2, 0)  # CHW -> HWC
                img_gen = (img_gen * 0.5 + 0.5) * 255.0  # de-normalize [-1, 1] -> [0, 255]
                img_gen = img_gen.astype(np.uint8)
                img_gen = dataset.gammaTrans(img_gen, 2.0)  # model output image, 256x256x3
                # cv2.imwrite('output_noFusion/{}.jpg'.format(cnt), img_gen)
                lms_gen = data['pt_dst'].cpu().numpy()[i] / 255.0  # [146, 2]
                img_ref = data['img_src_np'].cpu().numpy()[i]
                lms_ref = data['pt_src'].cpu().numpy()[i] / 255.0
                # align reference face parts to the generated landmarks (affine warp)
                lms_ref_parts, img_ref_parts = affineface_parts(img_ref, lms_ref, lms_gen)
                # fusion
                fuse_parts, seg_ref_parts, seg_gen = fusion(img_ref_parts, lms_ref_parts, img_gen, lms_gen, 0.1)
                fuse_eye, mask_eye, img_eye = lightEye(img_ref, lms_ref, fuse_parts, lms_gen, 0.1)
                # res = np.hstack([img_ref, img_pose, img_gen, fuse_eye])
                cv2.imwrite('output/{}.jpg'.format(cnt), fuse_eye)
                cnt += 1
    iter_end_time = time.time()
    print('length of dataset:', len(dataset))
    print('time per img: ', (iter_end_time - iter_start_time) / len(dataset))
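For reference, a minimal invocation sketch. The paths simply mirror the argparse defaults above, so the data layout is an assumption about this repository; the remaining model options (e.g. input_nc, batch_size) come from TrainOptions, not from these flags:

python test.py \
    --pose_path data/poseGuide/imgs \
    --ref_path data/reference/imgs \
    --pose_lms data/poseGuide/lms_poseGuide.out \
    --ref_lms data/reference/lms_ref.out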