1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
|
- #!/usr/bin/env python3
- """
- Test script for the Iris Model API
- """
- import requests
- import json
- import time
- import subprocess
- import sys
- from multiprocessing import Process
def start_server():
    """Run the FastAPI server (``app.py``) and block until it exits.

    Intended as a ``multiprocessing.Process`` target so the caller can stop
    the server later. ``app.py`` is resolved against the current working
    directory and run with the same interpreter as this script.

    ``Process.terminate()`` sends SIGTERM on POSIX. With the default SIGTERM
    action this wrapper process would die instantly and orphan the ``app.py``
    child, leaving it running and holding the port. This function therefore
    replaces the process's SIGTERM handler with one that raises
    ``SystemExit``, so the ``finally`` below can stop the server first.
    """
    import signal

    def _exit_on_sigterm(signum, frame):
        raise SystemExit(128 + signum)

    try:
        signal.signal(signal.SIGTERM, _exit_on_sigterm)
    except ValueError:
        # signal.signal() is only allowed in the main thread; elsewhere SIGTERM
        # cannot be intercepted, but the server still runs normally.
        pass

    server = subprocess.Popen([sys.executable, "app.py"])
    try:
        server.wait()
    finally:
        if server.poll() is None:
            server.terminate()
            try:
                server.wait(timeout=10)
            except subprocess.TimeoutExpired:
                # The server ignored SIGTERM; force it down so the port is freed.
                server.kill()
                server.wait()
def _wait_for_server(base_url, timeout, interval=0.25):
    """Poll ``base_url`` until it answers an HTTP request or ``timeout`` elapses.

    Args:
        base_url: Server root, e.g. ``"http://localhost:8000"``.
        timeout: Maximum number of seconds to keep polling.
        interval: Seconds to sleep between failed attempts.

    Returns:
        True as soon as the server responds (any status code), False if the
        deadline passes first.
    """
    deadline = time.monotonic() + timeout
    while True:
        try:
            requests.get(f"{base_url}/", timeout=1)
        except requests.exceptions.RequestException:
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)
        else:
            return True


def test_api(base_url="http://localhost:8000", startup_timeout=10.0):
    """Exercise the Iris API's root and ``/predict`` endpoints, printing results.

    Args:
        base_url: Root URL of the API under test.
        startup_timeout: Seconds to wait for the server to start answering.
            Polling stops as soon as it responds, so an already-running
            server is not delayed.

    Each prediction is compared against its expected class. Invalid input
    must be rejected with a 4xx status; a 5xx counts as a failure. Network
    and unexpected errors are reported, not raised.
    """
    # requests has no default timeout; without one a hung server blocks forever.
    request_timeout = 10
    class_names = ["Setosa", "Versicolor", "Virginica"]
    test_cases = [
        {
            "features": [5.1, 3.5, 1.4, 0.2],
            "expected_class": "Setosa",
            "description": "Typical Setosa sample",
        },
        {
            "features": [6.2, 2.9, 4.3, 1.3],
            "expected_class": "Versicolor",
            "description": "Typical Versicolor sample",
        },
        {
            "features": [7.3, 2.9, 6.3, 1.8],
            "expected_class": "Virginica",
            "description": "Typical Virginica sample",
        },
    ]

    print("⏳ Waiting for server to start...")
    if not _wait_for_server(base_url, startup_timeout):
        print("❌ Could not connect to API. Make sure the server is running.")
        return

    failures = 0
    try:
        # Test root endpoint
        print("🧪 Testing root endpoint...")
        response = requests.get(f"{base_url}/", timeout=request_timeout)
        if response.ok:
            print(f"✅ Root endpoint: {response.json()}")
        else:
            print(f"❌ Root endpoint failed: {response.status_code} - {response.text}")
            failures += 1

        # Test prediction endpoint with different iris samples
        print("\n🧪 Testing prediction endpoint...")
        for i, case in enumerate(test_cases, 1):
            response = requests.post(
                f"{base_url}/predict",
                json={"features": case["features"]},
                timeout=request_timeout,
            )
            if response.status_code != 200:
                print(f"❌ Test {i} failed: {response.status_code} - {response.text}")
                failures += 1
                continue

            prediction = response.json()["prediction"]
            # Guard the lookup: a non-int or out-of-range index would otherwise
            # abort every remaining check through the generic handler below.
            if isinstance(prediction, int) and 0 <= prediction < len(class_names):
                predicted_class = class_names[prediction]
            else:
                predicted_class = f"<unknown: {prediction!r}>"

            passed = predicted_class == case["expected_class"]
            if not passed:
                failures += 1

            print(f"{'✅' if passed else '❌'} Test {i}: {case['description']}")
            print(f" Features: {case['features']}")
            print(f" Predicted: {predicted_class} (class {prediction})")
            print(f" Expected: {case['expected_class']}")
            print()

        # Test error handling: the wrong number of features must be rejected
        # as a client error (4xx), not crash the server (5xx).
        print("🧪 Testing error handling...")
        response = requests.post(
            f"{base_url}/predict",
            json={"features": [1, 2]},  # Wrong number of features
            timeout=request_timeout,
        )
        if 400 <= response.status_code < 500:
            print("✅ Error handling works correctly")
        elif response.status_code >= 500:
            print(
                f"❌ Invalid input caused a server error: "
                f"{response.status_code} - {response.text}"
            )
            failures += 1
        else:
            print("⚠️ API should reject invalid input")
            failures += 1
    except requests.exceptions.ConnectionError:
        print("❌ Could not connect to API. Make sure the server is running.")
        return
    except requests.exceptions.Timeout:
        print("❌ API request timed out.")
        return
    except Exception as e:
        # Top-level boundary of the test run: report instead of crashing.
        print(f"❌ Test failed: {e}")
        return

    if failures:
        print(f"\n⚠️ Tests completed with {failures} failure(s).")
    else:
        print("\n🎉 All tests completed!")
if __name__ == "__main__":
    print("🚀 Starting API Test Suite")
    print("=" * 40)

    # Check if a server is already running. Only the probe sits inside the
    # try, so nothing raised later is mistaken for "server not running".
    try:
        requests.get("http://localhost:8000/", timeout=1)
    except requests.exceptions.ConnectionError:
        server_running = False
    except requests.exceptions.Timeout:
        # The connection was accepted but no reply came in time: something
        # already owns the port, so starting a second server would fail.
        server_running = True
    else:
        server_running = True

    if server_running:
        print("✅ Server is already running")
        test_api()
    else:
        print("🔄 Starting server...")
        # Run the server in a child process so it can be stopped afterwards.
        server_process = Process(target=start_server)
        server_process.start()
        try:
            test_api()
        finally:
            print("🛑 Stopping server...")
            server_process.terminate()
            server_process.join(timeout=10)
            if server_process.is_alive():
                # terminate() did not take effect; force it so the script exits.
                server_process.kill()
                server_process.join()

    print("\n✨ Test suite finished!")
|