Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

load_embedding.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
  1. # coding=utf-8
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """ Returning embedding of input text """
  17. import logging
  18. import os
  19. import sys
  20. from dataclasses import dataclass, field
  21. from typing import Dict, List, Optional, Tuple
  22. from torch.utils.data.dataloader import DataLoader
  23. import numpy as np
  24. import torch
  25. import h5py
  26. import pdb
  27. from tqdm import tqdm
  28. from transformers import (
  29. HfArgumentParser,
  30. )
  31. # Setup logging
  32. logging.basicConfig(
  33. format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
  34. datefmt="%m/%d/%Y %H:%M:%S",
  35. level=logging.INFO,
  36. )
  37. logger = logging.getLogger(__name__)
  38. @dataclass
  39. class DataArguments:
  40. """
  41. Arguments pertaining to what data we are going to input our model for training and eval.
  42. """
  43. indexed_path: str = field(
  44. metadata={"help": "indexed h5 file path"}
  45. )
  46. inputtext_path: str = field(
  47. metadata={"help": "The input text file path"}
  48. )
  49. def main():
  50. # We now keep distinct sets of args, for a cleaner separation of concerns.
  51. parser = HfArgumentParser((DataArguments))
  52. data_args = parser.parse_args_into_dataclasses()[0]
  53. with h5py.File(data_args.indexed_path, 'r') as f:
  54. with open(data_args.inputtext_path, 'r') as f_in:
  55. print("The number of keys in h5: {}".format(len(f)))
  56. for i, input in enumerate(f_in):
  57. entity_name = input.strip()
  58. embedding = f[entity_name]['embedding'][:]
  59. print("entity_name = {}".format(entity_name))
  60. print("embedding = {}".format(embedding))
  61. break
  62. def _mp_fn(index):
  63. # For xla_spawn (TPUs)
  64. main()
  65. if __name__ == "__main__":
  66. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...