block_transformers.py

from abc import ABC, abstractmethod
import warnings
from collections import defaultdict

import numpy as np
from sklearn.ensemble import BaseEnsemble
from sklearn.ensemble._forest import _generate_unsampled_indices, _generate_sample_indices

from .local_stumps import make_stumps, tree_feature_transform

class BlockPartitionedData:
    """
    Abstraction for a feature matrix in which the columns are grouped into
    blocks.

    Parameters
    ----------
    data_blocks: list of ndarray
        Blocks of feature columns
    common_block: ndarray
        A set of feature columns that should be common to all blocks
    """

    def __init__(self, data_blocks, common_block=None):
        self.n_blocks = len(data_blocks)
        self.n_samples = data_blocks[0].shape[0]
        self._data_blocks = data_blocks
        self._common_block = common_block
        self._create_block_indices()
        self._means = [np.mean(data_block, axis=0) for data_block in
                       self._data_blocks]

    def get_all_data(self):
        """
        Returns
        -------
        all_data: ndarray
            The data matrix obtained by concatenating all feature blocks
            together
        """
        if self._common_block is None:
            all_data = np.hstack(self._data_blocks)
        else:
            # Common block is appended at the end
            all_data = np.hstack(self._data_blocks + [self._common_block])
        return all_data

    def _create_block_indices(self):
        self._block_indices_dict = dict({})
        start_index = 0
        for k in range(self.n_blocks):
            stop_index = start_index + self._data_blocks[k].shape[1]
            self._block_indices_dict[k] = list(range(start_index, stop_index))
            start_index = stop_index
        if self._common_block is None:
            self._common_block_indices = []
        else:
            stop_index = start_index + self._common_block.shape[1]
            self._common_block_indices = list(range(start_index, stop_index))

    def get_block_indices(self, k):
        """
        Parameters
        ----------
        k: int
            The index of the desired feature block

        Returns
        -------
        block_indices: list of int
            The indices of the features in the desired block
        """
        block_indices = self._common_block_indices + self._block_indices_dict[k]
        return block_indices

    def get_block(self, k):
        """
        Parameters
        ----------
        k: int
            The index of the desired feature block

        Returns
        -------
        block: ndarray
            The desired feature block
        """
        if self._common_block is None:
            block = self._data_blocks[k]
        else:
            block = np.hstack([self._common_block, self._data_blocks[k]])
        return block

    def get_all_except_block_indices(self, k):
        """
        Parameters
        ----------
        k: int
            The index of the feature block to be excluded

        Returns
        -------
        all_except_block_indices: list of int
            The indices of the features not in the excluded block
        """
        if k not in self._block_indices_dict.keys():
            raise ValueError(f"{k} is not a block index.")
        all_except_block_indices = []
        for block_no, block_indices in self._block_indices_dict.items():
            if block_no != k:
                all_except_block_indices += block_indices
        all_except_block_indices += self._common_block_indices
        return all_except_block_indices

    def get_all_except_block(self, k):
        """
        Parameters
        ----------
        k: int
            The index of the feature block to be excluded

        Returns
        -------
        all_except_block: ndarray
            The features not in the excluded block
        """
        all_data = self.get_all_data()
        all_except_block_indices = self.get_all_except_block_indices(k)
        all_except_block = all_data[:, all_except_block_indices]
        return all_except_block

    def get_modified_data(self, k, mode="keep_k"):
        """
        Modify the data by either imputing the mean of each feature in block k
        (keep_rest) or imputing the mean of each feature not in block k
        (keep_k). Return the full data matrix with the modified data.

        Parameters
        ----------
        k: int
            The index of the reference feature block
        mode: string in {"keep_k", "keep_rest"}
            Mode for the method. "keep_k" imputes the mean of each feature not
            in block k; "keep_rest" imputes the mean of each feature in block k

        Returns
        -------
        all_data: ndarray
            The full (modified) data matrix obtained by concatenating all
            feature blocks together
        """
        modified_blocks = [np.outer(np.ones(self.n_samples), self._means[i])
                           for i in range(self.n_blocks)]
        if mode == "keep_k":
            data_blocks = \
                [self._data_blocks[i] if i == k else modified_blocks[i] for
                 i in range(self.n_blocks)]
        elif mode == "keep_rest":
            data_blocks = \
                [modified_blocks[i] if i == k else self._data_blocks[i] for
                 i in range(self.n_blocks)]
        else:
            raise ValueError("Unsupported mode.")
        if self._common_block is None:
            all_data = np.hstack(data_blocks)
        else:
            all_data = np.hstack(data_blocks + [self._common_block])
        return all_data

    def train_test_split(self, train_indices, test_indices):
        """
        Split the data into training and test partitions given the
        training and test indices. Return the training and test
        block partitioned data objects.

        Parameters
        ----------
        train_indices: array-like of shape (n_train_samples,)
            The indices corresponding to the training samples
        test_indices: array-like of shape (n_test_samples,)
            The indices corresponding to the test samples

        Returns
        -------
        train_blocked_data: BlockPartitionedData
            The training block partitioned data set
        test_blocked_data: BlockPartitionedData
            The test block partitioned data set
        """
        # Note: get_block() prepends the common block (if any) to each block,
        # so common features are folded into every block of the split data
        train_blocks = [self.get_block(k)[train_indices, :] for
                        k in range(self.n_blocks)]
        train_blocked_data = BlockPartitionedData(train_blocks)
        test_blocks = [self.get_block(k)[test_indices, :] for
                       k in range(self.n_blocks)]
        test_blocked_data = BlockPartitionedData(test_blocks)
        return train_blocked_data, test_blocked_data

    def __repr__(self):
        return self.get_all_data().__repr__()
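
# Example (hypothetical usage sketch, not part of the library API): partition
# six feature columns into two blocks of three and query the block structure.
#
#     rng = np.random.default_rng(0)
#     blocked = BlockPartitionedData([rng.normal(size=(100, 3)),
#                                     rng.normal(size=(100, 3))])
#     blocked.get_block_indices(1)                 # [3, 4, 5]
#     blocked.get_all_except_block(0).shape        # (100, 3)
#     blocked.get_modified_data(0, mode="keep_k")  # block 1 replaced by its column means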


class BlockTransformerBase(ABC):
    """
    An interface for block transformers: objects that transform a data matrix
    into a BlockPartitionedData object comprising one block of engineered
    features for each original feature.
    """

    def __init__(self):
        self._centers = {}
        self._scales = {}
        self.is_fitted = False

    def fit(self, X):
        """
        Fit (or train) the block transformer using the data matrix X.

        Parameters
        ----------
        X: ndarray
            The data matrix to be used in training
        """
        for k in range(X.shape[1]):
            self._fit_one_feature(X, k)
        self.is_fitted = True

    def check_is_fitted(self):
        """
        Check if the transformer has been fitted. Raise an error if it has
        not previously been fitted.
        """
        if not self.is_fitted:
            raise AttributeError("Transformer has not yet been fitted.")

    def transform_one_feature(self, X, k, center=True, normalize=False):
        """
        Obtain the block of engineered features associated with the original
        feature with index k, using the (previously) fitted transformer.

        Parameters
        ----------
        X: ndarray
            The data matrix to be transformed
        k: int
            Index of the feature in X to be transformed
        center: bool
            Flag for whether to center the transformed data
        normalize: bool
            Flag for whether to rescale the transformed data to have unit
            variance

        Returns
        -------
        data_block: ndarray
            The block of engineered features associated with the original
            feature with index k
        """
        data_block = self._transform_one_feature(X, k)
        data_block = self._center_and_normalize(data_block, k, center, normalize)
        return data_block

    def transform(self, X, center=True, normalize=False):
        """
        Transform a data matrix into a BlockPartitionedData object comprising
        one block for each original feature in X, using the (previously)
        fitted transformer.

        Parameters
        ----------
        X: ndarray
            The data matrix to be transformed
        center: bool
            Flag for whether to center the transformed data
        normalize: bool
            Flag for whether to rescale the transformed data to have unit
            variance

        Returns
        -------
        blocked_data: BlockPartitionedData object
            The transformed data
        """
        self.check_is_fitted()
        n_features = X.shape[1]
        data_blocks = [self.transform_one_feature(X, k, center, normalize) for
                       k in range(n_features)]
        blocked_data = BlockPartitionedData(data_blocks)
        return blocked_data

    def fit_transform_one_feature(self, X, k, center=True, normalize=False):
        """
        Fit the transformer and obtain the block of engineered features
        associated with the original feature with index k using this fitted
        transformer.

        Parameters
        ----------
        X: ndarray
            The data matrix to be fitted and transformed
        k: int
            Index of the feature in X to be fitted and transformed
        center: bool
            Flag for whether to center the transformed data
        normalize: bool
            Flag for whether to rescale the transformed data to have unit
            variance

        Returns
        -------
        data_block: ndarray
            The block of engineered features associated with the original
            feature with index k
        """
        data_block = self._fit_transform_one_feature(X, k)
        data_block = self._center_and_normalize(data_block, k, center, normalize)
        return data_block

    def fit_transform(self, X, center=True, normalize=False):
        """
        Fit the transformer and transform a data matrix into a
        BlockPartitionedData object comprising one block for each original
        feature in X using this fitted transformer.

        Parameters
        ----------
        X: ndarray
            The data matrix to be transformed
        center: bool
            Flag for whether to center the transformed data
        normalize: bool
            Flag for whether to rescale the transformed data to have unit
            variance

        Returns
        -------
        blocked_data: BlockPartitionedData object
            The transformed data
        """
        n_features = X.shape[1]
        data_blocks = [self.fit_transform_one_feature(X, k, center, normalize) for
                       k in range(n_features)]
        blocked_data = BlockPartitionedData(data_blocks)
        self.is_fitted = True
        return blocked_data

    @abstractmethod
    def _fit_one_feature(self, X, k):
        pass

    @abstractmethod
    def _transform_one_feature(self, X, k):
        pass

    def _fit_transform_one_feature(self, X, k):
        self._fit_one_feature(X, k)
        return self._transform_one_feature(X, k)

    def _center_and_normalize(self, data_block, k, center=True, normalize=False):
        if center:
            data_block = data_block - self._centers[k]
        if normalize:
            if np.any(self._scales[k] == 0):
                # Cannot divide by a zero scale; warn and skip rescaling
                warnings.warn("No rescaling done. "
                              "At least one feature is constant.")
            else:
                data_block = data_block / self._scales[k]
        return data_block


class IdentityTransformer(BlockTransformerBase, ABC):
    """
    Block transformer that creates a block partitioned data object with each
    block k containing only the original feature k.
    """

    def _fit_one_feature(self, X, k):
        self._centers[k] = np.mean(X[:, [k]])
        self._scales[k] = np.std(X[:, [k]])

    def _transform_one_feature(self, X, k):
        return X[:, [k]]
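
# Example (hypothetical usage sketch): with an IdentityTransformer, each block
# is just the corresponding centered column, so transform() round-trips X.
#
#     X = np.arange(12.0).reshape(4, 3)
#     identity = IdentityTransformer()
#     blocked = identity.fit_transform(X, center=True)
#     np.allclose(blocked.get_all_data(), X - X.mean(axis=0))  # True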


class TreeTransformer(BlockTransformerBase, ABC):
    """
    A block transformer that transforms data using a representation built from
    local decision stumps from a tree or tree ensemble. The transformer also
    comes with metadata on the local decision stumps and methods that allow for
    transformations using sub-representations corresponding to each of the
    original features.

    Parameters
    ----------
    estimator: scikit-learn estimator
        The scikit-learn tree or tree ensemble estimator object.
    data: ndarray
        A data matrix that can be used to update the number of samples in each
        node of the tree(s) in the supplied estimator object. This affects
        the node values of the resulting engineered features.
    """

    def __init__(self, estimator, data=None):
        super().__init__()
        self.estimator = estimator
        self.oob_seed = self.estimator.random_state
        # Check if single tree or tree ensemble
        if isinstance(estimator, BaseEnsemble):
            tree_models = estimator.estimators_
            if data is not None:
                # If a data matrix is supplied, use it to update the number
                # of samples in each node
                for tree_model in tree_models:
                    _update_n_node_samples(tree_model, data)
        else:
            tree_models = [estimator]
        # Make stumps for each tree
        all_stumps = []
        for tree_model in tree_models:
            tree_stumps = make_stumps(tree_model.tree_)
            all_stumps += tree_stumps
        # Identify the stumps that split on feature k, for each k
        self.stumps = defaultdict(list)
        for stump in all_stumps:
            self.stumps[stump.feature].append(stump)
        self.n_splits = {k: len(stumps) for k, stumps in self.stumps.items()}

    def _fit_one_feature(self, X, k):
        stump_features = tree_feature_transform(self.stumps[k], X)
        self._centers[k] = np.mean(stump_features, axis=0)
        self._scales[k] = np.std(stump_features, axis=0)

    def _transform_one_feature(self, X, k):
        return tree_feature_transform(self.stumps[k], X)

    def _fit_transform_one_feature(self, X, k):
        stump_features = tree_feature_transform(self.stumps[k], X)
        self._centers[k] = np.mean(stump_features, axis=0)
        self._scales[k] = np.std(stump_features, axis=0)
        return stump_features
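
# Example (hypothetical usage sketch, assuming a fitted tree `tree` and data
# X, y): build the stump representation from a scikit-learn tree; block k then
# holds one engineered column per split on feature k.
#
#     from sklearn.tree import DecisionTreeRegressor
#     tree = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)
#     tree_transformer = TreeTransformer(tree)
#     blocked = tree_transformer.fit_transform(X)
#     tree_transformer.n_splits  # e.g. {0: 4, 2: 3}: number of splits per feature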


class CompositeTransformer(BlockTransformerBase, ABC):
    """
    A block transformer that is built by concatenating the blocks of the same
    index from a list of block transformers.

    Parameters
    ----------
    block_transformer_list: list of BlockTransformer objects
        The list of block transformers to combine
    rescale_mode: string in {"max", "mean", None}
        Flag for the type of rescaling to be done to the blocks from different
        base transformers. If "max", divide each block by the max standard
        deviation of a column within the block. If "mean", divide each block
        by the mean standard deviation of a column within the block. If None,
        do not rescale.
    drop_features: bool
        Flag for whether to return an empty block if the block from the first
        transformer in the list is trivial.
    """

    def __init__(self, block_transformer_list, rescale_mode=None, drop_features=True):
        super().__init__()
        self.block_transformer_list = block_transformer_list
        assert len(self.block_transformer_list) > 0, \
            "Need at least one base transformer."
        # Inherit the OOB seed from the first base transformer that has one
        self.oob_seed = None
        for transformer in block_transformer_list:
            if hasattr(transformer, "oob_seed") and \
                    transformer.oob_seed is not None:
                self.oob_seed = transformer.oob_seed
                break
        self.rescale_mode = rescale_mode
        self.drop_features = drop_features
        self._rescale_factors = {}
        self._trivial_block_indices = {}

    def _fit_one_feature(self, X, k):
        data_blocks = []
        for block_transformer in self.block_transformer_list:
            data_block = block_transformer.fit_transform_one_feature(
                X, k, center=False, normalize=False)
            data_blocks.append(data_block)
        # Handle trivial blocks
        self._trivial_block_indices[k] = \
            [idx for idx, data_block in enumerate(data_blocks) if
             _empty_or_constant(data_block)]
        if (0 in self._trivial_block_indices[k] and self.drop_features) or \
                (len(self._trivial_block_indices[k]) == len(data_blocks)):
            # If the first block is trivial and self.drop_features is True,
            # or all blocks are trivial, use placeholder center and scale
            self._centers[k] = np.array([0])
            self._scales[k] = np.array([1])
            return
        else:
            # Remove trivial blocks
            for idx in reversed(self._trivial_block_indices[k]):
                data_blocks.pop(idx)
        self._rescale_factors[k] = _get_rescale_factors(data_blocks, self.rescale_mode)
        composite_block = np.hstack(
            [data_block / scale_factor for data_block, scale_factor in
             zip(data_blocks, self._rescale_factors[k])]
        )
        self._centers[k] = composite_block.mean(axis=0)
        self._scales[k] = composite_block.std(axis=0)

    def _transform_one_feature(self, X, k):
        data_blocks = []
        for block_transformer in self.block_transformer_list:
            data_block = block_transformer.transform_one_feature(
                X, k, center=False, normalize=False)
            data_blocks.append(data_block)
        # Handle trivial blocks
        if (0 in self._trivial_block_indices[k] and self.drop_features) or \
                (len(self._trivial_block_indices[k]) == len(data_blocks)):
            # If the first block is trivial and self.drop_features is True,
            # or all blocks are trivial, return an empty block
            return np.empty((X.shape[0], 0))
        else:
            # Remove trivial blocks
            for idx in reversed(self._trivial_block_indices[k]):
                data_blocks.pop(idx)
        composite_block = np.hstack(
            [data_block / scale_factor for data_block, scale_factor in
             zip(data_blocks, self._rescale_factors[k])]
        )
        return composite_block

    def _fit_transform_one_feature(self, X, k):
        data_blocks = []
        for block_transformer in self.block_transformer_list:
            data_block = block_transformer.fit_transform_one_feature(
                X, k, center=False, normalize=False)
            data_blocks.append(data_block)
        # Handle trivial blocks
        self._trivial_block_indices[k] = \
            [idx for idx, data_block in enumerate(data_blocks) if
             _empty_or_constant(data_block)]
        if (0 in self._trivial_block_indices[k] and self.drop_features) or \
                (len(self._trivial_block_indices[k]) == len(data_blocks)):
            # If the first block is trivial and self.drop_features is True,
            # or all blocks are trivial, return an empty block
            self._centers[k] = np.array([0])
            self._scales[k] = np.array([1])
            return np.empty((X.shape[0], 0))
        else:
            # Remove trivial blocks
            for idx in reversed(self._trivial_block_indices[k]):
                data_blocks.pop(idx)
        self._rescale_factors[k] = _get_rescale_factors(data_blocks, self.rescale_mode)
        composite_block = np.hstack(
            [data_block / scale_factor for data_block, scale_factor in
             zip(data_blocks, self._rescale_factors[k])]
        )
        self._centers[k] = composite_block.mean(axis=0)
        self._scales[k] = composite_block.std(axis=0)
        return composite_block


class MDIPlusDefaultTransformer(CompositeTransformer, ABC):
    """
    Default block transformer used in MDI+. For each original feature, this
    forms a block comprising the local decision stumps, from a single tree
    model, that split on the feature, and appends the original feature.

    Parameters
    ----------
    tree_model: scikit-learn estimator
        The scikit-learn tree estimator object.
    rescale_mode: string in {"max", "mean", None}
        Flag for the type of rescaling to be done to the blocks from different
        base transformers. If "max", divide each block by the max standard
        deviation of a column within the block. If "mean", divide each block
        by the mean standard deviation of a column within the block. If None,
        do not rescale.
    drop_features: bool
        Flag for whether to return an empty block if the block from the first
        transformer in the list is trivial.
    """

    def __init__(self, tree_model, rescale_mode="max", drop_features=True):
        super().__init__([TreeTransformer(tree_model), IdentityTransformer()],
                         rescale_mode, drop_features)
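
# Example (hypothetical usage sketch, assuming a fitted tree `tree` and data
# X): the MDI+ default representation, i.e. the stump features of each split
# on a feature plus the raw feature itself.
#
#     transformer = MDIPlusDefaultTransformer(tree, rescale_mode="max")
#     blocked = transformer.fit_transform(X)
#     blocked.get_block(0).shape  # (n_samples, n_splits_on_feature_0 + 1)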


def _update_n_node_samples(tree, X):
    node_indicators = tree.decision_path(X)
    new_n_node_samples = node_indicators.getnnz(axis=0)
    for i in range(len(new_n_node_samples)):
        tree.tree_.n_node_samples[i] = new_n_node_samples[i]


def _get_rescale_factors(data_blocks, rescale_mode):
    if rescale_mode == "max":
        scale_factors = np.array([max(data_block.std(axis=0)) for
                                  data_block in data_blocks])
    elif rescale_mode == "mean":
        scale_factors = np.array([np.mean(data_block.std(axis=0)) for
                                  data_block in data_blocks])
    elif rescale_mode is None:
        scale_factors = np.ones(len(data_blocks))
    else:
        raise ValueError("Invalid rescale mode.")
    # Normalize so that the first block's scale factor is 1
    scale_factors = scale_factors / scale_factors[0]
    return scale_factors


def _empty_or_constant(data_block):
    return data_block.shape[1] == 0 or max(data_block.std(axis=0)) == 0


def _blocked_train_test_split(blocked_data, y, oob_seed):
    n_samples = len(y)
    # In-bag (bootstrap) samples form the training set and out-of-bag samples
    # form the test set, both regenerated from the tree's random seed
    train_indices = _generate_sample_indices(oob_seed, n_samples, n_samples)
    test_indices = _generate_unsampled_indices(oob_seed, n_samples, n_samples)
    train_blocked_data, test_blocked_data = \
        blocked_data.train_test_split(train_indices, test_indices)
    if y.ndim > 1:
        y_train = y[train_indices, :]
        y_test = y[test_indices, :]
    else:
        y_train = y[train_indices]
        y_test = y[test_indices]
    return train_blocked_data, test_blocked_data, y_train, y_test, train_indices, test_indices
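
# Example (hypothetical usage sketch, assuming data X, y): recover a forest
# tree's bootstrap split and partition the blocked data accordingly. Each
# sub-estimator of a fitted forest carries the integer seed that generated
# its bootstrap sample, which TreeTransformer exposes as oob_seed.
#
#     from sklearn.ensemble import RandomForestRegressor
#     rf = RandomForestRegressor(n_estimators=10, random_state=0).fit(X, y)
#     tree = rf.estimators_[0]
#     transformer = MDIPlusDefaultTransformer(tree)
#     blocked = transformer.fit_transform(X)
#     train_blocked, test_blocked, y_train, y_test, train_idx, test_idx = \
#         _blocked_train_test_split(blocked, y, transformer.oob_seed)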