GPT-2 — TypeError:无法按照 'safe' 规则将数组数据从 dtype('O') 转换为 dtype('int64')



我正在使用 GPT-2、Python 3.9 和 TensorFlow 2.5。把它接入 Flask(在终端中运行 flask)时,出现了以下错误:

TypeError: Cannot cast array data from dtype('O') to dtype('int64') according to the rule 'safe'

下面是generator.py 中的代码
#!/usr/bin/env python3
import fire
import json
import os
import numpy as np
import tensorflow.compat.v1 as tf
# import model, sample, encoder
from text_generator import model
from text_generator import sample
from text_generator import encoder

class AI:
    """Thin wrapper that generates text with a fine-tuned GPT-2 checkpoint."""

    def generate_text(self, input_text):
        """Generate a continuation of *input_text* with the GPT-2 model.

        Each call builds a fresh TF1-style graph and session, restores the
        latest checkpoint found in ``<dir of this file>/models/<model_name>``
        and runs ``nsamples // batch_size`` sampling rounds. Only the last
        decoded sample is kept.

        Args:
            input_text: Prompt string, encoded with the model's encoder.

        Returns:
            The decoded text of the last sample, with the first
            ``len(context_tokens)`` tokens of each output row dropped
            (presumably the echoed prompt). The value is also stored on
            ``self.response``.

        Raises:
            ValueError: If ``length`` exceeds the model's ``n_ctx`` window,
                or if no checkpoint exists in the model directory.
        """
        model_name = '117M_Trained'
        # Must be None or an int. The previous `seed = None,` (trailing comma)
        # made this the tuple (None,), which np.random.seed turns into an
        # object array and rejects with "Cannot cast array data from
        # dtype('O') to dtype('int64') according to the rule 'safe'".
        seed = None
        nsamples = 1
        batch_size = 1
        length = 150
        temperature = 1
        top_k = 40
        top_p = 1
        models_dir = 'models'
        self.response = ''
        models_dir = os.path.expanduser(os.path.expandvars(models_dir))
        if batch_size is None:
            batch_size = 1
        # Internal invariant on the constants above: samples come in whole batches.
        assert nsamples % batch_size == 0
        enc = encoder.get_encoder(model_name, models_dir)
        hparams = model.default_hparams()
        # os.path.join rather than '+': when __file__ has no directory part,
        # dirname() is '' and '' + '/models' would point at the filesystem root.
        cur_path = os.path.join(os.path.dirname(__file__), 'models', model_name)
        with open(os.path.join(cur_path, 'hparams.json'), encoding='utf-8') as f:
            hparams.override_from_dict(json.load(f))
        if length is None:
            length = hparams.n_ctx // 2
        elif length > hparams.n_ctx:
            raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
        with tf.Session(graph=tf.Graph()) as sess:
            # Prompt token ids: batch_size rows of variable length.
            context = tf.placeholder(tf.int32, [batch_size, None])
            np.random.seed(seed)
            tf.set_random_seed(seed)
            output = sample.sample_sequence(
                hparams=hparams, length=length,
                context=context,
                batch_size=batch_size,
                temperature=temperature, top_k=top_k, top_p=top_p,
            )
            saver = tf.train.Saver()
            ckpt = tf.train.latest_checkpoint(cur_path)
            if ckpt is None:
                # Fail with a clear message instead of passing None to restore().
                raise ValueError("No checkpoint found in %r" % cur_path)
            saver.restore(sess, ckpt)
            context_tokens = enc.encode(input_text)
            for _ in range(nsamples // batch_size):
                out = sess.run(output, feed_dict={
                    context: [context_tokens for _ in range(batch_size)],
                })[:, len(context_tokens):]
                for i in range(batch_size):
                    # Each decoded sample overwrites the previous one.
                    self.response = enc.decode(out[i])
        return self.response

# Module-level instance; routes.py imports it via `from .generator import ai`.
ai = AI()
# NOTE(review): this runs a full generation (model load + sampling) every time
# the module is imported, e.g. on `flask run`; consider moving it under
# `if __name__ == "__main__":` if nothing depends on the `text` global.
text = ai.generate_text(input_text='How are you?')
print(text)

非常感谢任何帮助 🙏。附:下面是完整的回溯(traceback)信息。

* Serving Flask app 'text_generator' (lazy loading)
* Environment: development
* Debug mode: on
2021-09-14 19:58:08.687907: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Traceback (most recent call last):
File "_mt19937.pyx", line 178, in numpy.random._mt19937.MT19937._legacy_seeding
TypeError: 'tuple' object cannot be interpreted as an integer
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/dusandev/miniconda3/bin/flask", line 8, in <module>
sys.exit(main())
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 990, in main
cli.main(args=sys.argv[1:])
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 596, in main
return super().main(*args, **kwargs)
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/click/core.py", line 1062, in main
rv = self.invoke(ctx)
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/click/core.py", line 1668, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/click/core.py", line 1404, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/click/core.py", line 763, in invoke
return __callback(*args, **kwargs)
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/click/decorators.py", line 84, in new_func
return ctx.invoke(f, obj, *args, **kwargs)
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/click/core.py", line 763, in invoke
return __callback(*args, **kwargs)
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 845, in run_command
app = DispatchingApp(info.load_app, use_eager_loading=eager_loading)
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 321, in __init__
self._load_unlocked()
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 346, in _load_unlocked
self._app = rv = self.loader()
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 402, in load_app
app = locate_app(self, import_name, name)
File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 256, in locate_app
__import__(module_name)
File "/Users/dusandev/Desktop/AI/text_generator/__init__.py", line 2, in <module>
from .routes import generator
File "/Users/dusandev/Desktop/AI/text_generator/routes.py", line 2, in <module>
from .generator import ai
File "/Users/dusandev/Desktop/AI/text_generator/generator.py", line 74, in <module>
text = ai.generate_text('How are you?')
File "/Users/dusandev/Desktop/AI/text_generator/generator.py", line 46, in generate_text
np.random.seed(seed)
File "mtrand.pyx", line 244, in numpy.random.mtrand.RandomState.seed
File "_mt19937.pyx", line 166, in numpy.random._mt19937.MT19937._legacy_seeding
File "_mt19937.pyx", line 186, in numpy.random._mt19937.MT19937._legacy_seeding
TypeError: Cannot cast array data from dtype('O') to dtype('int64') according to the rule 'safe'

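问题出在 `seed = None,` 这一行:末尾多余的逗号使 `seed` 变成了元组 `(None,)`,随后它被传给 `np.random.seed(seed)`。该函数只接受 `None`、整数或整数数组;`(None,)` 会被转换成 object 类型(dtype('O'))的数组,因此报出无法转换为 int64 的错误。去掉逗号,写成 `seed = None` 即可。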

相关内容

最新更新