Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_train.py 6.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Unit tests for train.py.

Verifies the data-loading functions, pipeline construction, metric
computation, confusion-matrix saving, and the metrics bar chart.

Dependencies:
- pytest: test runner.
- pandas: data manipulation.
- sklearn: pipeline, models, and metrics.
- unittest.mock: to simulate configurations.
- pathlib: path handling.
"""
import sys
from pathlib import Path
from unittest.mock import Mock, patch

# Add the project's src directory to sys.path so "src.*" imports resolve
# when the tests are run from the repository root.
ROOT_DIR = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT_DIR / "src"))

import pytest
import pandas as pd
import numpy as np
from omegaconf import OmegaConf
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from src.train import load_data, build_pipeline, save_confusion_matrix, save_metrics_bar
from src.utils import compute_metrics

import matplotlib
matplotlib.use('Agg')  # Non-interactive backend to avoid Tkinter errors.

# Base path for the simulated (fixture) data used by the tests.
BASE_PATH = ROOT_DIR / "tests" / "data"
  30. @pytest.fixture
  31. def config(tmp_path):
  32. """Crea una configuración simulada para pruebas."""
  33. search_space = OmegaConf.create({"n_estimators": [100, 200], "max_depth": [None, 10, 20]})
  34. cv_config = Mock(folds=5, scoring="f1_macro")
  35. cv_config.__str__ = lambda self: "f1_macro"
  36. params_mock = Mock(random_state=42) # Mock para params con atributo random_state
  37. params_mock.__getitem__ = lambda self, key: {"random_state": 42}[key] # Simula dict para **params
  38. params_mock.keys = lambda: ["random_state"] # Necesario para **params
  39. return Mock(
  40. processed=Mock(
  41. X_train=Mock(path=str(BASE_PATH / "processed" / "X_train.csv")),
  42. X_test=Mock(path=str(BASE_PATH / "processed" / "X_test.csv")),
  43. y_train=Mock(path=str(BASE_PATH / "processed" / "y_train.csv")),
  44. y_test=Mock(path=str(BASE_PATH / "processed" / "y_test.csv"))
  45. ),
  46. model_config=Mock(
  47. _name="model_1",
  48. params=params_mock,
  49. search_space=search_space,
  50. cv=cv_config,
  51. optimization=Mock(n_iter=10),
  52. metrics=["accuracy", "f1_macro", "f1_per_class"]
  53. ),
  54. model=Mock(dir=str(tmp_path / "models"), name="rf_model.pkl", params_name="params.json"),
  55. process=Mock(
  56. features=["Salario_Mensual", "Deuda_Pendiente", "Edad", "debt_to_income"],
  57. target="Puntaje_Credito",
  58. target_classes=["Good", "Standard"]
  59. ),
  60. mlflow=Mock(tracking_uri="https://dagshub.com/JorgeDataScientist/MLOps_CreditScore.mlflow")
  61. )
  62. @pytest.fixture
  63. def processed_data():
  64. """Carga datos procesados simulados y filtra clases."""
  65. X_train = pd.read_csv(BASE_PATH / "processed" / "X_train.csv")
  66. X_test = pd.read_csv(BASE_PATH / "processed" / "X_test.csv")
  67. y_train = pd.read_csv(BASE_PATH / "processed" / "y_train.csv")
  68. y_test = pd.read_csv(BASE_PATH / "processed" / "y_test.csv")
  69. valid_classes = ["Good", "Standard"]
  70. y_train = y_train[y_train["Puntaje_Credito"].isin(valid_classes)]["Puntaje_Credito"].values
  71. y_test = y_test[y_test["Puntaje_Credito"].isin(valid_classes)]["Puntaje_Credito"].values
  72. train_mask = pd.read_csv(BASE_PATH / "processed" / "y_train.csv")["Puntaje_Credito"].isin(valid_classes)
  73. test_mask = pd.read_csv(BASE_PATH / "processed" / "y_test.csv")["Puntaje_Credito"].isin(valid_classes)
  74. X_train = X_train[train_mask]
  75. X_test = X_test[test_mask]
  76. return X_train, X_test, y_train, y_test
  77. def test_load_data(config, tmp_path):
  78. """Verifica que load_data carga datos correctamente."""
  79. with patch("src.train.get_original_cwd", return_value=str(tmp_path)):
  80. X_train, X_test, y_train, y_test = load_data(config)
  81. assert isinstance(X_train, pd.DataFrame)
  82. assert isinstance(X_test, pd.DataFrame)
  83. assert isinstance(y_train, np.ndarray)
  84. assert isinstance(y_test, np.ndarray)
  85. assert len(y_train) == len(X_train)
  86. assert len(y_test) == len(X_test)
  87. def test_build_pipeline(config):
  88. """Verifica que build_pipeline crea un Pipeline correcto."""
  89. pipeline = build_pipeline(config)
  90. assert isinstance(pipeline, Pipeline)
  91. assert isinstance(pipeline.named_steps["scaler"], StandardScaler)
  92. assert isinstance(pipeline.named_steps["model"], RandomForestClassifier)
  93. def test_compute_metrics(processed_data, config):
  94. """Verifica que compute_metrics calcula métricas correctamente."""
  95. X_train, X_test, y_train, y_test = processed_data
  96. model = RandomForestClassifier(random_state=42).fit(X_train, y_train)
  97. y_pred = model.predict(X_test)
  98. metrics = compute_metrics(y_test, y_pred, config)
  99. assert isinstance(metrics, dict)
  100. assert "test_accuracy" in metrics
  101. assert "test_f1_macro" in metrics
  102. assert all(key in metrics for key in ["test_f1_good", "test_f1_standard"])
  103. def test_save_confusion_matrix(tmp_path, processed_data, config):
  104. """Verifica que save_confusion_matrix guarda la gráfica correctamente."""
  105. X_train, X_test, y_train, y_test = processed_data
  106. model = RandomForestClassifier(random_state=42).fit(X_train, y_train)
  107. y_pred = model.predict(X_test)
  108. config.model_config._name = "model_1"
  109. with patch("src.train.Path", return_value=tmp_path):
  110. with patch("src.train.get_original_cwd", return_value=str(tmp_path)):
  111. save_confusion_matrix(y_test, y_pred, config)
  112. assert (tmp_path / "graphics" / "model_1" / "confusion_matrix.png").exists()
  113. def test_save_metrics_bar(tmp_path, processed_data, config):
  114. """Verifica que save_metrics_bar guarda la gráfica correctamente."""
  115. X_train, X_test, y_train, y_test = processed_data
  116. model = RandomForestClassifier(random_state=42).fit(X_train, y_train)
  117. y_pred = model.predict(X_test)
  118. metrics = compute_metrics(y_test, y_pred, config)
  119. config.model_config._name = "model_1"
  120. with patch("src.train.Path", return_value=tmp_path):
  121. with patch("src.train.get_original_cwd", return_value=str(tmp_path)):
  122. save_metrics_bar(metrics, config)
  123. assert (tmp_path / "graphics" / "model_1" / "metrics_bar.png").exists()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...