1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
|
- #!/usr/bin/env python3
- import fire
- import json
- import os
- import numpy as np
- import tensorflow as tf
- import model, sample, encoder
def _read_prompt():
    """Read a non-empty prompt line from stdin.

    Prints a warning and asks again while the entered line is empty.
    Returns the prompt string, or None once stdin is exhausted (EOF, e.g.
    Ctrl-D or the end of piped input).
    """
    try:
        raw_text = input("Model prompt >>> ")
        while not raw_text:
            print('Prompt should not be empty!')
            raw_text = input("Model prompt >>> ")
    except EOFError:
        return None
    return raw_text


def interact_model(
    model_name='124M',
    seed=None,
    nsamples=1,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
    top_p=1,
    models_dir='models',
):
    """
    Interactively run the model: repeatedly read a prompt from stdin and print
    `nsamples` generated continuations of it, until stdin reaches EOF.

    :model_name=124M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :nsamples=1 : Number of samples to return total (per prompt)
    :batch_size=1 : Number of samples generated per session run (only affects
     speed/memory). Must be >= 1 and divide nsamples. None is treated as 1.
    :length=None : Number of tokens in generated text, if None (default), is
     half the model's context window (hparams.n_ctx // 2)
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=1 : Passed straight through to sample.sample_sequence; presumably
     the nucleus (top-p) probability threshold, with 1 meaning no restriction
     -- confirm against sample.py.
    :models_dir : path to parent folder containing model subfolders
     (i.e. contains the <model_name> folder). `~` and environment variables
     are expanded.

    Raises ValueError if batch_size is < 1 or does not divide nsamples, if
    length exceeds hparams.n_ctx, or if no checkpoint is found in
    <models_dir>/<model_name>.
    Returns None once stdin is exhausted.
    """
    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    # Explicit checks rather than `assert`, which `python -O` strips out;
    # batch_size == 0 would otherwise fail with a ZeroDivisionError below.
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %s" % batch_size)
    if nsamples % batch_size != 0:
        raise ValueError(
            "nsamples (%s) must be divisible by batch_size (%s)"
            % (nsamples, batch_size)
        )

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json'),
              encoding='utf-8') as f:
        hparams.override_from_dict(json.load(f))

    # NOTE(review): only `length` is checked against n_ctx; whether the prompt
    # tokens plus `length` must also fit is up to sample.sample_sequence -- confirm.
    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        # One row per sample in the batch; the prompt length stays dynamic.
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )

        saver = tf.train.Saver()
        ckpt_dir = os.path.join(models_dir, model_name)
        ckpt = tf.train.latest_checkpoint(ckpt_dir)
        if ckpt is None:
            # latest_checkpoint returns None when nothing is found; report the
            # directory instead of letting Saver.restore reject a None path.
            raise ValueError("No checkpoint found in %s" % ckpt_dir)
        saver.restore(sess, ckpt)

        while True:
            raw_text = _read_prompt()
            if raw_text is None:
                # stdin exhausted: leave cleanly instead of an EOFError traceback.
                break
            context_tokens = enc.encode(raw_text)
            generated = 0
            for _ in range(nsamples // batch_size):
                # Every row is fed the same prompt; keep only the tokens
                # after the prompt positions.
                out = sess.run(output, feed_dict={
                    context: [context_tokens] * batch_size
                })[:, len(context_tokens):]
                for i in range(batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
            print("=" * 80)
if __name__ == "__main__":
    # Run as a script: Python Fire exposes interact_model's keyword
    # arguments as command-line flags and invokes it.
    fire.Fire(interact_model)
|