Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_evaluate.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
  1. """Pruebas unitarias para evaluate.py.
  2. Verifica funciones de evaluación de modelos, cálculo de métricas y guardado de resultados.
  3. Dependencias:
  4. - pytest: Para ejecutar pruebas.
  5. - pandas: Para manipulación de datos.
  6. - sklearn: Para métricas y reportes.
  7. - joblib: Para cargar modelos simulados.
  8. - unittest.mock: Para simular configuraciones y rutas.
  9. - pathlib: Para manejo de rutas.
  10. """
  11. import pytest
  12. import pandas as pd
  13. import numpy as np
  14. from pathlib import Path
  15. from unittest.mock import Mock, patch, mock_open
  16. from src.evaluate import (
  17. load_data_and_model,
  18. save_metrics,
  19. save_class_report
  20. )
  21. from omegaconf import DictConfig
  22. @pytest.fixture
  23. def config():
  24. """Crea una configuración simulada para pruebas."""
  25. return Mock(
  26. model_config=Mock(_name="rf_base"),
  27. process=Mock(
  28. target_classes=["Good", "Standard", "Poor"],
  29. target="Puntaje_Credito"
  30. )
  31. )
  32. @pytest.fixture
  33. def sample_data():
  34. """Crea datos de prueba simulados."""
  35. X_test = pd.DataFrame({
  36. "Edad": [25, 30],
  37. "Salario_Mensual": [5000.0, 6000.0]
  38. })
  39. y_test = np.array(["Good", "Standard"])
  40. X_train = pd.DataFrame({
  41. "Edad": [20, 28],
  42. "Salario_Mensual": [4500.0, 5500.0]
  43. })
  44. y_train = np.array(["Good", "Poor"])
  45. return X_test, y_test, X_train, y_train
  46. def test_load_data_and_model(tmp_path, config):
  47. """Verifica que load_data_and_model carga datos y modelo correctamente."""
  48. base_path = tmp_path
  49. (base_path / "data/processed").mkdir(parents=True)
  50. (base_path / "models/rf_base").mkdir(parents=True)
  51. X_test = pd.DataFrame({"Edad": [25, 30]})
  52. y_test = pd.DataFrame({"Puntaje_Credito": ["Good", "Standard"]})
  53. X_test.to_csv(base_path / "data/processed/X_test.csv", index=False)
  54. y_test.to_csv(base_path / "data/processed/y_test.csv", index=False)
  55. mock_model = Mock()
  56. with patch("joblib.load", return_value=mock_model), \
  57. patch("hydra.utils.get_original_cwd", return_value=str(base_path)):
  58. X_test_loaded, y_test_loaded, model_loaded = load_data_and_model(config)
  59. assert isinstance(X_test_loaded, pd.DataFrame)
  60. assert isinstance(y_test_loaded, np.ndarray)
  61. assert X_test_loaded.shape == (2, 1)
  62. assert y_test_loaded.tolist() == ["Good", "Standard"]
  63. assert model_loaded == mock_model
  64. def test_save_metrics(tmp_path, config):
  65. """Verifica que save_metrics guarda métricas en un archivo CSV."""
  66. metrics = {"accuracy": 0.85, "f1_macro": 0.80, "roc_auc": 0.90}
  67. base_path = tmp_path
  68. with patch("hydra.utils.get_original_cwd", return_value=str(base_path)):
  69. save_metrics(metrics, config)
  70. metrics_path = base_path / "metrics/rf_base/metrics.csv"
  71. assert metrics_path.exists()
  72. metrics_df = pd.read_csv(metrics_path)
  73. assert metrics_df.shape == (1, 3)
  74. assert metrics_df["accuracy"].iloc[0] == 0.85
  75. assert metrics_df["f1_macro"].iloc[0] == 0.80
  76. assert metrics_df["roc_auc"].iloc[0] == 0.90
  77. def test_save_class_report(tmp_path, sample_data, config):
  78. """Verifica que save_class_report guarda el reporte de clasificación."""
  79. _, y_test, _, _ = sample_data
  80. y_pred = np.array(["Good", "Standard"])
  81. base_path = tmp_path
  82. with patch("hydra.utils.get_original_cwd", return_value=str(base_path)), \
  83. patch("builtins.open", mock_open()) as mocked_file:
  84. save_class_report(y_test, y_pred, config)
  85. report_path = base_path / "metrics/rf_base/class_report_rf_base.txt"
  86. assert report_path.parent.exists()
  87. mocked_file.assert_called_once_with(report_path, "w")
  88. assert config.process.target_classes == ["Good", "Standard", "Poor"]
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...