RNN:在训练模型后从文本输入中获取预测

RNN: Get prediction from a text input after the model is trained

我是 RNN 的新手,我一直在研究一个小型二元标签分类器。我已经能够得到一个稳定的模型,结果令人满意。

但是,我很难使用该模型对新输入进行分类,我想知道你们是否可以帮助我。请参阅下面的代码以供参考。

非常感谢。

from tensorflow.keras import preprocessing
from sklearn.utils import shuffle
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.models import Model
from tensorflow.keras import models
from tensorflow.keras.layers import LSTM, Activation, Dense, Dropout, Input, 
Embedding
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.preprocessing import sequence, text
from tensorflow.keras.callbacks import EarlyStopping
from matplotlib import pyplot

class tensor_rnn():
def __init__(self, hidden_layers=3):
    self.data_path = 'C:\\Users\cmazz\PycharmProjects\InvestmentAnalysis_2.0\Sentiment\Finance_Articles\'
    # self.corp_paths = corpora_paths
    self.h_layers = hidden_layers
    self.num_words = []
    good = pd.read_csv(self.data_path + 'GoodO.csv')
    good['Polarity'] = 'pos'
    for line in good['Head'].tolist():
        counter = len(line.split())
        self.num_words.append(counter)
    bad = pd.read_csv(self.data_path + 'BadO.csv')
    bad['Polarity'] = 'neg'
    for line in bad['Head'].tolist():
        counter = len(line.split())
        self.num_words.append(counter)
    self.features = pd.concat([good, bad]).reset_index(drop=True)
    self.features = shuffle(self.features)

    self.max_len = len(max(self.features['Head'].tolist()))
    # self.train, self.test = train_test_split(features, test_size=0.33, random_state=42)
    X = self.features['Head']
    Y = self.features['Polarity']
    le = LabelEncoder()
    Y = le.fit_transform(Y)
    Y = Y.reshape(-1, 1)
    self.X_train, self.X_test, self.Y_train, self.Y_test = train_test_split(X, Y, test_size=0.30)
    self.tok = preprocessing.text.Tokenizer(num_words=len(self.num_words))
    self.tok.fit_on_texts(self.X_train)
    sequences = self.tok.texts_to_sequences(self.X_train)
    self.sequences_matrix = preprocessing.sequence.pad_sequences(sequences, maxlen=self.max_len)

def RNN(self):
    inputs = Input(name='inputs', shape=[self.max_len])
    layer = Embedding(len(self.num_words), 30, input_length=self.max_len)(inputs)
    # layer = LSTM(64, return_sequences=True)(layer)
    layer = LSTM(32)(layer)
    layer = Dense(256, name='FC1')(layer)
    layer = Activation('relu')(layer)
    layer = Dropout(0.5)(layer)
    layer = Dense(1, name='out_layer')(layer)
    layer = Activation('sigmoid')(layer)
    model = Model(inputs=inputs, outputs=layer)
    return model

def model_train(self):
    self.model = self.RNN()
    self.model.summary()
    self.model.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])   # RMSprop()

def model_test(self):
    self.history = self.model.fit(self.sequences_matrix, self.Y_train, batch_size=100, epochs=3,
              validation_split=0.30, callbacks=[EarlyStopping(monitor='val_loss', min_delta=0.0001)])
    test_sequences = self.tok.texts_to_sequences(self.X_test)
    test_sequences_matrix = sequence.pad_sequences(test_sequences, maxlen=self.max_len)
    accr = self.model.evaluate(test_sequences_matrix, self.Y_test)
    print('Test set\n  Loss: {:0.3f}\n  Accuracy: {:0.3f}'.format(accr[0], accr[1]))


if __name__ == "__main__":
    a = tensor_rnn()
    a.model_train()
    a.model_test()
    a.model.save('C:\\Users\cmazz\PycharmProjects\'
                              'InvestmentAnalysis_2.0\RNN_Model.h5', 
    include_optimizer=True)
     b = models.load_model('C:\\Users\cmazz\PycharmProjects\'
                              'InvestmentAnalysis_2.0\RNN_Model.h5')
    stringy = ['Fund managers back away from Amazon as they cut FANG exposure']
    prediction = b.predict(np.array(stringy))
    print(prediction)

当我 运行 我的代码出现以下错误:

ValueError: Error when checking input: expected inputs to have shape (39,) but got array with shape (1,)

根据 ValueError 和 prediction = b.predict(np.array(stringy)),我认为您需要标记化您的输入字符串。