Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_threshold_layer.cpp 3.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
  1. #include <vector>
  2. #include "gtest/gtest.h"
  3. #include "caffe/blob.hpp"
  4. #include "caffe/common.hpp"
  5. #include "caffe/filler.hpp"
  6. #include "caffe/layers/threshold_layer.hpp"
  7. #include "caffe/test/test_caffe_main.hpp"
  8. namespace caffe {
  9. template <typename TypeParam>
  10. class ThresholdLayerTest : public MultiDeviceTest<TypeParam> {
  11. typedef typename TypeParam::Dtype Dtype;
  12. protected:
  13. ThresholdLayerTest()
  14. : blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)),
  15. blob_top_(new Blob<Dtype>()) {
  16. Caffe::set_random_seed(1701);
  17. // fill the values
  18. FillerParameter filler_param;
  19. GaussianFiller<Dtype> filler(filler_param);
  20. filler.Fill(this->blob_bottom_);
  21. blob_bottom_vec_.push_back(blob_bottom_);
  22. blob_top_vec_.push_back(blob_top_);
  23. }
  24. virtual ~ThresholdLayerTest() { delete blob_bottom_; delete blob_top_; }
  25. Blob<Dtype>* const blob_bottom_;
  26. Blob<Dtype>* const blob_top_;
  27. vector<Blob<Dtype>*> blob_bottom_vec_;
  28. vector<Blob<Dtype>*> blob_top_vec_;
  29. };
  30. TYPED_TEST_CASE(ThresholdLayerTest, TestDtypesAndDevices);
  31. TYPED_TEST(ThresholdLayerTest, TestSetup) {
  32. typedef typename TypeParam::Dtype Dtype;
  33. LayerParameter layer_param;
  34. ThresholdLayer<Dtype> layer(layer_param);
  35. layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
  36. EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
  37. EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
  38. EXPECT_EQ(this->blob_top_->height(), this->blob_bottom_->height());
  39. EXPECT_EQ(this->blob_top_->width(), this->blob_bottom_->width());
  40. }
  41. TYPED_TEST(ThresholdLayerTest, Test) {
  42. typedef typename TypeParam::Dtype Dtype;
  43. LayerParameter layer_param;
  44. ThresholdLayer<Dtype> layer(layer_param);
  45. layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
  46. layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
  47. // Now, check values
  48. const Dtype* bottom_data = this->blob_bottom_->cpu_data();
  49. const Dtype* top_data = this->blob_top_->cpu_data();
  50. const Dtype threshold_ = layer_param.threshold_param().threshold();
  51. for (int i = 0; i < this->blob_bottom_->count(); ++i) {
  52. EXPECT_GE(top_data[i], 0.);
  53. EXPECT_LE(top_data[i], 1.);
  54. if (top_data[i] == 0) {
  55. EXPECT_LE(bottom_data[i], threshold_);
  56. }
  57. if (top_data[i] == 1) {
  58. EXPECT_GT(bottom_data[i], threshold_);
  59. }
  60. }
  61. }
  62. TYPED_TEST(ThresholdLayerTest, Test2) {
  63. typedef typename TypeParam::Dtype Dtype;
  64. LayerParameter layer_param;
  65. ThresholdParameter* threshold_param =
  66. layer_param.mutable_threshold_param();
  67. threshold_param->set_threshold(0.5);
  68. ThresholdLayer<Dtype> layer(layer_param);
  69. layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
  70. layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
  71. // Now, check values
  72. const Dtype* bottom_data = this->blob_bottom_->cpu_data();
  73. const Dtype* top_data = this->blob_top_->cpu_data();
  74. const Dtype threshold_ = layer_param.threshold_param().threshold();
  75. EXPECT_FLOAT_EQ(threshold_, 0.5);
  76. for (int i = 0; i < this->blob_bottom_->count(); ++i) {
  77. EXPECT_GE(top_data[i], 0.);
  78. EXPECT_LE(top_data[i], 1.);
  79. if (top_data[i] == 0) {
  80. EXPECT_LE(bottom_data[i], threshold_);
  81. }
  82. if (top_data[i] == 1) {
  83. EXPECT_GT(bottom_data[i], threshold_);
  84. }
  85. }
  86. }
  87. } // namespace caffe
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...