Tensorflow RNN:两种不同类型的输入
Tensorflow RNN: input of two different types
我想为我的 LSTM RNN 单元提供 2 种类型的输入。我的输入由整数列表组成(即 [5,2,3,4,6,1,0, ...]
)。
但是,每个整数都分为 2 个不同的组,所以我想以 [[5,True],[2,False], [3,False], [4,True], ... ]
的方式标记每个整数。
我见过输入由相同类型的常量组成,输入维度为 2 或更高的情况。但我不确定 2 种不同的类型是否可以构成 1 个输入单位,例如 [5,True]
。如果这不可能,我正在考虑将 True 替换为整数 2,将 False 替换为整数 1,例如 [[5,2], [2,1], ...]
,其中输入维度为 2(不确定这是标记的好方法)。
标记每个整数以使其属于不同组的好方法是什么?
TensorFlow 支持嵌套元组作为 rnn 输入,参见 doc。但是,您需要编写自己的单元格 Class 来处理这种特定类型的输入。在这种情况下,它应该是这样的:
# Define your own cell which accept (integer, bool) input
class YourCell(tf.contrib.rnn.RNNCell):
# override relevant functions of base interface: RNNCell
# state_size, output_size, etc.
# The main body of computation logic goes in this function
def __call__(self, inputs, state, scope=None):
# note inputs variable contains inputs of only one time step
# for example, inputs = (5, True)
interger, boolean = inputs
# your computation
integer_input = [5, 2, 3, 4]
bool_input = [True, False, False, True]
inputs = [integer_input, bool_input]
cell = YourCell()
outputs = tf.nn.dynamic_rnn(inputs, cell)
我想为我的 LSTM RNN 单元提供 2 种类型的输入。我的输入由整数列表组成(即 [5,2,3,4,6,1,0, ...]
)。
但是,每个整数都分为 2 个不同的组,所以我想以 [[5,True],[2,False], [3,False], [4,True], ... ]
的方式标记每个整数。
我见过输入由相同类型的常量组成,输入维度为 2 或更高的情况。但我不确定 2 种不同的类型是否可以构成 1 个输入单位,例如 [5,True]
。如果这不可能,我正在考虑将 True 替换为整数 2,将 False 替换为整数 1,例如 [[5,2], [2,1], ...]
,其中输入维度为 2(不确定这是标记的好方法)。
标记每个整数以使其属于不同组的好方法是什么?
TensorFlow 支持嵌套元组作为 rnn 输入,参见 doc。但是,您需要编写自己的单元格 Class 来处理这种特定类型的输入。在这种情况下,它应该是这样的:
# Define your own cell which accept (integer, bool) input
class YourCell(tf.contrib.rnn.RNNCell):
# override relevant functions of base interface: RNNCell
# state_size, output_size, etc.
# The main body of computation logic goes in this function
def __call__(self, inputs, state, scope=None):
# note inputs variable contains inputs of only one time step
# for example, inputs = (5, True)
interger, boolean = inputs
# your computation
integer_input = [5, 2, 3, 4]
bool_input = [True, False, False, True]
inputs = [integer_input, bool_input]
cell = YourCell()
outputs = tf.nn.dynamic_rnn(inputs, cell)