Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#609 Ci fix

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/infra-000_ci
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
  1. from functools import partial
  2. from typing import Type, Union, Dict
  3. import torch
  4. from torch import nn
  5. def get_builtin_activation_type(activation: Union[str, None], **kwargs) -> Type:
  6. """
  7. Returns activation class by its name from torch.nn namespace. This function support all modules available from
  8. torch.nn and also their lower-case aliases.
  9. On top of that, it supports a few aliaes: leaky_relu (LeakyReLU), swish (silu).
  10. >>> act_cls = get_activation_type("LeakyReLU", inplace=True, slope=0.01)
  11. >>> act = act_cls()
  12. Args:
  13. activation: Activation function name (E.g. ReLU). If None will return nn.Identity
  14. **kwargs: Extra arguments to pass to constructor during instantiation (E.g. inplace=True)
  15. Returns:
  16. Type of the activation function that is ready to be instantiated
  17. """
  18. if activation is None:
  19. activation_cls = nn.Identity
  20. else:
  21. lowercase_aliases: Dict[str, str] = dict((k.lower(), k) for k in torch.nn.__dict__.keys())
  22. # Register additional aliases
  23. lowercase_aliases["leaky_relu"] = "LeakyReLU" # LeakyRelu in snake_case
  24. lowercase_aliases["swish"] = "SiLU" # Swish shich is equivalent to SiLU
  25. lowercase_aliases["none"] = "Identity"
  26. if activation in lowercase_aliases:
  27. activation = lowercase_aliases[activation]
  28. if activation not in torch.nn.__dict__:
  29. raise KeyError(f"Requested activation function {activation} is not known")
  30. activation_cls = torch.nn.__dict__[activation]
  31. if len(kwargs):
  32. activation_cls = partial(activation_cls, **kwargs)
  33. return activation_cls
Discard
Tip!

Press p or to see the previous file or, n or to see the next file