Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

utils.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
  1. from monotonic_align import maximum_path
  2. from monotonic_align import mask_from_lens
  3. from monotonic_align.core import maximum_path_c
  4. import numpy as np
  5. import torch
  6. import copy
  7. from torch import nn
  8. import torch.nn.functional as F
  9. import torchaudio
  10. import librosa
  11. import matplotlib.pyplot as plt
  12. from munch import Munch
  13. def maximum_path(neg_cent, mask):
  14. """ Cython optimized version.
  15. neg_cent: [b, t_t, t_s]
  16. mask: [b, t_t, t_s]
  17. """
  18. device = neg_cent.device
  19. dtype = neg_cent.dtype
  20. neg_cent = np.ascontiguousarray(neg_cent.data.cpu().numpy().astype(np.float32))
  21. path = np.ascontiguousarray(np.zeros(neg_cent.shape, dtype=np.int32))
  22. t_t_max = np.ascontiguousarray(mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32))
  23. t_s_max = np.ascontiguousarray(mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32))
  24. maximum_path_c(path, neg_cent, t_t_max, t_s_max)
  25. return torch.from_numpy(path).to(device=device, dtype=dtype)
  26. def get_data_path_list(train_path=None, val_path=None):
  27. if train_path is None:
  28. train_path = "Data/train_list.txt"
  29. if val_path is None:
  30. val_path = "Data/val_list.txt"
  31. with open(train_path, 'r', encoding='utf-8', errors='ignore') as f:
  32. train_list = f.readlines()
  33. with open(val_path, 'r', encoding='utf-8', errors='ignore') as f:
  34. val_list = f.readlines()
  35. return train_list, val_list
  36. def length_to_mask(lengths):
  37. mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
  38. mask = torch.gt(mask+1, lengths.unsqueeze(1))
  39. return mask
  40. # for norm consistency loss
  41. def log_norm(x, mean=-4, std=4, dim=2):
  42. """
  43. normalized log mel -> mel -> norm -> log(norm)
  44. """
  45. x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
  46. return x
  47. def get_image(arrs):
  48. plt.switch_backend('agg')
  49. fig = plt.figure()
  50. ax = plt.gca()
  51. ax.imshow(arrs)
  52. return fig
  53. def recursive_munch(d):
  54. if isinstance(d, dict):
  55. return Munch((k, recursive_munch(v)) for k, v in d.items())
  56. elif isinstance(d, list):
  57. return [recursive_munch(v) for v in d]
  58. else:
  59. return d
  60. def log_print(message, logger):
  61. logger.info(message)
  62. print(message)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...