options.py
import os
import time
import argparse

import torch


def get_options(args=None):
    parser = argparse.ArgumentParser(
        description="Attention based model for solving the Travelling Salesman Problem with Reinforcement Learning")

    # Data
    parser.add_argument('--problem', default='tsp', help="The problem to solve, default 'tsp'")
    parser.add_argument('--graph_size', type=int, default=20, help="The size of the problem graph")
    parser.add_argument('--batch_size', type=int, default=512, help='Number of instances per batch during training')
    parser.add_argument('--epoch_size', type=int, default=1280000, help='Number of instances per epoch during training')
    parser.add_argument('--val_size', type=int, default=10000,
                        help='Number of instances used for reporting validation performance')
    parser.add_argument('--val_dataset', type=str, default=None, help='Dataset file to use for validation')

    # Model
    parser.add_argument('--model', default='attention', help="Model, 'attention' (default) or 'pointer'")
    parser.add_argument('--embedding_dim', type=int, default=128, help='Dimension of input embedding')
    parser.add_argument('--hidden_dim', type=int, default=128, help='Dimension of hidden layers in Enc/Dec')
    parser.add_argument('--n_encode_layers', type=int, default=3,
                        help='Number of layers in the encoder/critic network')
    parser.add_argument('--tanh_clipping', type=float, default=10.,
                        help='Clip the parameters to within +- this value using tanh. '
                             'Set to 0 to not perform any clipping.')
    parser.add_argument('--normalization', default='batch', help="Normalization type, 'batch' (default) or 'instance'")

    # Training
    parser.add_argument('--lr_model', type=float, default=1e-4, help="Set the learning rate for the actor network")
    parser.add_argument('--lr_critic', type=float, default=1e-4, help="Set the learning rate for the critic network")
    parser.add_argument('--lr_decay', type=float, default=1.0, help='Learning rate decay per epoch')
    parser.add_argument('--eval_only', action='store_true', help='Set this value to only evaluate model')
    parser.add_argument('--n_epochs', type=int, default=100, help='The number of epochs to train')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed to use')
    parser.add_argument('--max_grad_norm', type=float, default=1.0,
                        help='Maximum L2 norm for gradient clipping, default 1.0 (0 to disable clipping)')
    parser.add_argument('--no_cuda', action='store_true', help='Disable CUDA')
    parser.add_argument('--exp_beta', type=float, default=0.8,
                        help='Exponential moving average baseline decay (default 0.8)')
    parser.add_argument('--baseline', default=None,
                        help="Baseline to use: 'rollout', 'critic' or 'exponential'. Defaults to no baseline.")
    parser.add_argument('--bl_alpha', type=float, default=0.05,
                        help='Significance in the t-test for updating rollout baseline')
    parser.add_argument('--bl_warmup_epochs', type=int, default=None,
                        help='Number of epochs to warmup the baseline, default None means 1 for rollout (exponential '
                             'used for warmup phase), 0 otherwise. Can only be used with rollout baseline.')
    parser.add_argument('--eval_batch_size', type=int, default=1024,
                        help="Batch size to use during (baseline) evaluation")
    parser.add_argument('--checkpoint_encoder', action='store_true',
                        help='Set to decrease memory usage by checkpointing encoder')
    parser.add_argument('--shrink_size', type=int, default=None,
                        help='Shrink the batch size if at least this many instances in the batch are finished'
                             ' to save memory (default None means no shrinking)')
    parser.add_argument('--data_distribution', type=str, default=None,
                        help='Data distribution to use during training, defaults and options depend on problem.')

    # Misc
    parser.add_argument('--log_step', type=int, default=50, help='Log info every log_step steps')
    parser.add_argument('--log_dir', default='logs', help='Directory to write TensorBoard information to')
    parser.add_argument('--run_name', default='run', help='Name to identify the run')
    parser.add_argument('--output_dir', default='outputs', help='Directory to write output models to')
    parser.add_argument('--epoch_start', type=int, default=0,
                        help='Start at epoch # (relevant for learning rate decay)')
    parser.add_argument('--checkpoint_epochs', type=int, default=1,
                        help='Save checkpoint every n epochs (default 1), 0 to save no checkpoints')
    parser.add_argument('--load_path', help='Path to load model parameters and optimizer state from')
    parser.add_argument('--resume', help='Resume from previous checkpoint file')
    parser.add_argument('--no_tensorboard', action='store_true', help='Disable logging TensorBoard files')
    parser.add_argument('--no_progress_bar', action='store_true', help='Disable progress bar')

    opts = parser.parse_args(args)

    opts.use_cuda = torch.cuda.is_available() and not opts.no_cuda
    opts.run_name = "{}_{}".format(opts.run_name, time.strftime("%Y%m%dT%H%M%S"))
    opts.save_dir = os.path.join(
        opts.output_dir,
        "{}_{}".format(opts.problem, opts.graph_size),
        opts.run_name
    )
    if opts.bl_warmup_epochs is None:
        opts.bl_warmup_epochs = 1 if opts.baseline == 'rollout' else 0
    assert (opts.bl_warmup_epochs == 0) or (opts.baseline == 'rollout')
    assert opts.epoch_size % opts.batch_size == 0, "Epoch size must be integer multiple of batch size!"
    return opts
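
Usage sketch: get_options accepts an explicit argument list (or reads sys.argv when called with no arguments) and returns a namespace that also carries the derived fields use_cuda, run_name, save_dir and bl_warmup_epochs. The calling script and the specific flag values below are illustrative assumptions, not part of this file.

# Hypothetical usage sketch: parse a custom flag set instead of sys.argv.
from options import get_options

opts = get_options([
    '--problem', 'tsp',
    '--graph_size', '50',
    '--baseline', 'rollout',
    '--run_name', 'tsp50_rollout',
])

print(opts.use_cuda)          # True only if CUDA is available and --no_cuda was not passed
print(opts.save_dir)          # e.g. outputs/tsp_50/tsp50_rollout_<timestamp>
print(opts.bl_warmup_epochs)  # 1, since the rollout baseline defaults to one warmup epoch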