유효한 인덱스를 반환하지 않는 Tensorflow 텍스트 생성 (Tensorflow text generation not returning valid index)


문제 설명

유효한 인덱스를 반환하지 않는 Tensorflow 텍스트 생성 (Tensorflow text generation not returning valid index)

Tensorflow 모델을 학습시켜 텍스트를 생성하려고 합니다. 저는 주로 Tensorflow 웹사이트의 코드를 사용하고 있지만 텍스트를 생성하려고 하면 모델이 word_index에 없는 인덱스를 반환합니다.

텍스트 생성 기능:

6
model = create_model(vocab_size = vocab_size,
  embed_dim=embed_dim,
  rnn_neurons=rnn_neurons,
  batch_size=1)

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

model.build(tf.TensorShape([1, None]))

char_2_index = tokenizer.word_index
index_2_char = {ind:char for char, ind in char_2_index.items()}

def generate_text(model, start_string):

  num_generate = 1000

  input_eval = [char_2_index[s] for s in start_string]
  input_eval = tf.expand_dims(input_eval, 0)

  text_generated = []

  temperature = 1.0

  model.reset_states()
  for i in range(num_generate):
      print(text_generated)
      predictions = model(input_eval)

      predictions = tf.squeeze(predictions, 0)

      predictions = predictions / temperature
      print(predictions)
      predicted_id = tf.random.categorical(predictions, num_samples=1)[‑1,0].numpy()
      print(predicted_id)

      input_eval = tf.expand_dims([predicted_id], 0)


      text_generated.append(index_2_char[predicted_id])

  return (start_string + ''.join(text_generated))

오류

KeyError                                  Traceback (most recent call last)
<ipython‑input‑52‑9517558352c4> in <module>()
‑‑‑‑> 1 print(generate_text(model, start_string=u"Is Baby yoda "))

<ipython‑input‑47‑75973c66de6c> in generate_text(model, start_string)
     37 
     38 
‑‑‑> 39       text_generated.append(index_2_char[predicted_id])
     40 
     41   return (start_string + ''.join(text_generated))

KeyError: 133

단어 색인 및 훈련 텍스트에는 대문자와 소문자만 포함됩니다.

편집 자세한 내용은 내 데이터 준비 및 구조입니다.

구조 [['SENTENCE'], ['SENTENCE2']...]

데이터 준비

tokenizer = keras.preprocessing.text.Tokenizer(num_words=209, lower=False, char_level=True, filters='#$%&()*+‑<=>@[\\]^_`{|}~\t\n')
tokenizer.fit_on_texts(df['title'].values)
df['encoded_with_keras'] = tokenizer.texts_to_sequences(df['title'].values)

dataset = df['encoded_with_keras'].values
dataset = tf.keras.preprocessing.sequence.pad_sequences(dataset, padding='post')

dataset = dataset.flatten()

dataset = tf.data.Dataset.from_tensor_slices(dataset)

sequences = dataset.batch(seq_len+1, drop_remainder=True)

def create_seq_targets(seq):
    input_txt = seq[:‑1]
    target_txt = seq[1:]
    return input_txt, target_txt

dataset = sequences.map(create_seq_targets)

dataset = dataset.shuffle(buffer_size).batch(batch_size, drop_remainder=True)


참조 솔루션

방법 1:

It seems that vocab_size used in create_model(...) is not equal to the length of index_2_char.

(by GrepThisVladimir Sotnikov)

참조 문서

  1. Tensorflow text generation not returning valid index (CC BY‑SA 2.5/3.0/4.0)

#Python #Artificial-Intelligence #recurrent-neural-network #Machine-Learning #tensorflow2.0






관련 질문

Python - 파일 이름에 특수 문자가 있는 파일의 이름을 바꿀 수 없습니다. (Python - Unable to rename a file with special characters in the file name)

구조화된 배열의 dtype을 변경하면 문자열 데이터가 0이 됩니다. (Changing dtype of structured array zeros out string data)

목록 목록의 효과적인 구현 (Effective implementation of list of lists)

for 루프를 중단하지 않고 if 문을 중지하고 다른 if에 영향을 줍니다. (Stop if statement without breaking for loop and affect other ifs)

기본 숫자를 10 ^ 9 이상으로 늘리면 코드가 작동하지 않습니다. (Code fails to work when i increase the base numbers to anything over 10 ^ 9)

사용자 지정 대화 상자 PyQT5를 닫고 데이터 가져오기 (Close and get data from a custom dialog PyQT5)

Enthought Canopy의 Python: csv 파일 조작 (Python in Enthought Canopy: manipulating csv files)

학생의 이름을 인쇄하려고 하는 것이 잘못된 것은 무엇입니까? (What is wrong with trying to print the name of the student?)

다단계 열 테이블에 부분합 열 추가 (Adding a subtotal column to a multilevel column table)

여러 함수의 변수를 다른 함수로 사용 (Use variables from multiple functions into another function)

리프 텐서의 값을 업데이트하는 적절한 방법은 무엇입니까(예: 경사하강법 업데이트 단계 중) (What's the proper way to update a leaf tensor's values (e.g. during the update step of gradient descent))

Boto3: 조직 단위의 AMI에 시작 권한을 추가하려고 하면 ParamValidationError가 발생합니다. (Boto3: trying to add launch permission to AMI for an organizational unit raises ParamValidationError)







코멘트