Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#868 Draw fix

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:hotfix/SG-000-fix_draw
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
  1. import torch
  2. import torch.nn as nn
  3. class PixelShuffle(nn.Module):
  4. """
  5. Equivalent to nn.PixelShuffle.
  6. nn.PixelShuffle module is translated to `DepthToSpace` layer in ONNX, some compilation frameworks (i.e tflite),
  7. doesn't support this layer. In that case this module should be used, it's translated to
  8. reshape / transpose / reshape operations in ONNX graph.
  9. """
  10. def __init__(self, upscale_factor: int):
  11. super().__init__()
  12. self.scale = upscale_factor
  13. def forward(self, x: torch.Tensor):
  14. b, c, h, w = x.size()
  15. x = x.reshape(b, torch.div(c, self.scale * self.scale, rounding_mode="trunc"), self.scale, self.scale, h, w)
  16. x = x.permute(0, 1, 4, 2, 5, 3)
  17. x = x.reshape(b, torch.div(c, self.scale * self.scale, rounding_mode="trunc"), h * self.scale, w * self.scale)
  18. return x
Discard
Tip!

Press p or to see the previous file or, n or to see the next file