Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_api.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
  1. #!/usr/bin/env python3
  2. """
  3. Test script for the Iris Model API
  4. """
  5. import requests
  6. import json
  7. import time
  8. import subprocess
  9. import sys
  10. from multiprocessing import Process
  11. def start_server():
  12. """Start the FastAPI server in background"""
  13. subprocess.run([sys.executable, "app.py"])
  14. def test_api():
  15. """Test the API endpoints"""
  16. base_url = "http://localhost:8000"
  17. # Wait for server to start
  18. print("โณ Waiting for server to start...")
  19. time.sleep(3)
  20. try:
  21. # Test root endpoint
  22. print("๐Ÿงช Testing root endpoint...")
  23. response = requests.get(f"{base_url}/")
  24. print(f"โœ… Root endpoint: {response.json()}")
  25. # Test prediction endpoint with different iris samples
  26. test_cases = [
  27. {
  28. "features": [5.1, 3.5, 1.4, 0.2],
  29. "expected_class": "Setosa",
  30. "description": "Typical Setosa sample",
  31. },
  32. {
  33. "features": [6.2, 2.9, 4.3, 1.3],
  34. "expected_class": "Versicolor",
  35. "description": "Typical Versicolor sample",
  36. },
  37. {
  38. "features": [7.3, 2.9, 6.3, 1.8],
  39. "expected_class": "Virginica",
  40. "description": "Typical Virginica sample",
  41. },
  42. ]
  43. class_names = ["Setosa", "Versicolor", "Virginica"]
  44. print("\n๐Ÿงช Testing prediction endpoint...")
  45. for i, test_case in enumerate(test_cases, 1):
  46. response = requests.post(
  47. f"{base_url}/predict",
  48. json={"features": test_case["features"]},
  49. headers={"Content-Type": "application/json"},
  50. )
  51. if response.status_code == 200:
  52. result = response.json()
  53. predicted_class = class_names[result["prediction"]]
  54. print(f"โœ… Test {i}: {test_case['description']}")
  55. print(f" Features: {test_case['features']}")
  56. print(f" Predicted: {predicted_class} (class {result['prediction']})")
  57. print(f" Expected: {test_case['expected_class']}")
  58. print()
  59. else:
  60. print(f"โŒ Test {i} failed: {response.status_code} - {response.text}")
  61. # Test error handling
  62. print("๐Ÿงช Testing error handling...")
  63. response = requests.post(
  64. f"{base_url}/predict",
  65. json={"features": [1, 2]}, # Wrong number of features
  66. headers={"Content-Type": "application/json"},
  67. )
  68. if response.status_code != 200:
  69. print("โœ… Error handling works correctly")
  70. else:
  71. print("โš ๏ธ API should reject invalid input")
  72. print("\n๐ŸŽ‰ All tests completed!")
  73. except requests.exceptions.ConnectionError:
  74. print("โŒ Could not connect to API. Make sure the server is running.")
  75. except Exception as e:
  76. print(f"โŒ Test failed: {str(e)}")
  77. if __name__ == "__main__":
  78. print("๐Ÿš€ Starting API Test Suite")
  79. print("=" * 40)
  80. # Check if server is already running
  81. try:
  82. response = requests.get("http://localhost:8000/", timeout=1)
  83. print("โœ… Server is already running")
  84. test_api()
  85. except requests.exceptions.ConnectionError:
  86. print("๐Ÿ”„ Starting server...")
  87. # Start server in background process
  88. server_process = Process(target=start_server)
  89. server_process.start()
  90. try:
  91. test_api()
  92. finally:
  93. print("๐Ÿ›‘ Stopping server...")
  94. server_process.terminate()
  95. server_process.join()
  96. print("\nโœจ Test suite finished!")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...