Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

XSegNet.py 3.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
  1. import os
  2. import pickle
  3. from functools import partial
  4. from pathlib import Path
  5. import cv2
  6. import numpy as np
  7. from core.interact import interact as io
  8. from core.leras import nn
  9. class XSegNet(object):
  10. VERSION = 1
  11. def __init__ (self, name,
  12. resolution=256,
  13. load_weights=True,
  14. weights_file_root=None,
  15. training=False,
  16. place_model_on_cpu=False,
  17. run_on_cpu=False,
  18. optimizer=None,
  19. data_format="NHWC",
  20. raise_on_no_model_files=False):
  21. self.resolution = resolution
  22. self.weights_file_root = Path(weights_file_root) if weights_file_root is not None else Path(__file__).parent
  23. nn.initialize(data_format=data_format)
  24. tf = nn.tf
  25. model_name = f'{name}_{resolution}'
  26. self.model_filename_list = []
  27. with tf.device ('/CPU:0'):
  28. #Place holders on CPU
  29. self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) )
  30. self.target_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,1) )
  31. # Initializing model classes
  32. with tf.device ('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name):
  33. self.model = nn.XSeg(3, 32, 1, name=name)
  34. self.model_weights = self.model.get_weights()
  35. if training:
  36. if optimizer is None:
  37. raise ValueError("Optimizer should be provided for training mode.")
  38. self.opt = optimizer
  39. self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
  40. self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
  41. self.model_filename_list += [ [self.model, f'{model_name}.npy'] ]
  42. if not training:
  43. with tf.device ('/CPU:0' if run_on_cpu else nn.tf_default_device_name):
  44. _, pred = self.model(self.input_t)
  45. def net_run(input_np):
  46. return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0]
  47. self.net_run = net_run
  48. self.initialized = True
  49. # Loading/initializing all models/optimizers weights
  50. for model, filename in self.model_filename_list:
  51. do_init = not load_weights
  52. if not do_init:
  53. model_file_path = self.weights_file_root / filename
  54. do_init = not model.load_weights( model_file_path )
  55. if do_init:
  56. if raise_on_no_model_files:
  57. raise Exception(f'{model_file_path} does not exists.')
  58. if not training:
  59. self.initialized = False
  60. break
  61. if do_init:
  62. model.init_weights()
  63. def get_resolution(self):
  64. return self.resolution
  65. def flow(self, x, pretrain=False):
  66. return self.model(x, pretrain=pretrain)
  67. def get_weights(self):
  68. return self.model_weights
  69. def save_weights(self):
  70. for model, filename in io.progress_bar_generator(self.model_filename_list, "Saving", leave=False):
  71. model.save_weights( self.weights_file_root / filename )
  72. def extract (self, input_image):
  73. if not self.initialized:
  74. return 0.5*np.ones ( (self.resolution, self.resolution, 1), nn.floatx.as_numpy_dtype )
  75. input_shape_len = len(input_image.shape)
  76. if input_shape_len == 3:
  77. input_image = input_image[None,...]
  78. result = np.clip ( self.net_run(input_image), 0, 1.0 )
  79. result[result < 0.1] = 0 #get rid of noise
  80. if input_shape_len == 3:
  81. result = result[0]
  82. return result
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...