Chainer 中 NStepBiLSTM 中 xs 的类型是什么?
What is the type of xs in NStepBiLSTM in Chainer?
NStepBiLSTM 的手册说它的前向函数期望 xs 以变量包装序列列表的格式出现。但是我得到一个错误,暗示 xs 应该是一个 np 数组。我错过了什么?
我使用此函数将输入数组转换为数组形状 (n,1) 的变量(数组)列表。
def cut_data(data, batchsize):
q = data.shape[0] // batchsize
data = data[:q*batchsize]
data = data.reshape((batchsize, q))
xs = []
for i in range(q):
a = data[:,i].reshape(batchsize,1)
xs.append(Variable(a))
return xs
但是当我用这样的 xs 调用我的预测器时,我得到这个错误:
<ipython-input-27-cb554613ad71> in __call__(self, xs, ts)
10 batchlen = len(xs)
11 loss = F.sum(F.mean_squared_error(
---> 12 self.predictor(xs), ts, reduce='no')) / batch
13
14 chainer.report({'loss': loss}, self)
<ipython-input-28-ce9434e91153> in __call__(self, x)
12 def __call__(self, x):
13 self.h, self.c, y = self.lstm(self.h,self.c,x)
---> 14 output = self.out(y)
15 return output
16
~\Anaconda2\lib\site-packages\chainer\links\connection\linear.py in __call__(self, x)
127 in_size = functools.reduce(operator.mul, x.shape[1:], 1)
128 self._initialize_params(in_size)
--> 129 return linear.linear(x, self.W, self.b)
~\Anaconda2\lib\site-packages\chainer\functions\connection\linear.py in linear(x, W, b)
165
166 """
--> 167 if x.ndim > 2:
168 x = x.reshape(len(x), -1)
169
AttributeError: 'list' object has no attribute 'ndim'
这是我的简单网络:
class LSTM_RNN(Chain):
def __init__(self, n_hidden, n_input=1, n_out=1):
super(LSTM_RNN, self).__init__()
with self.init_scope():
self.lstm = L.NStepBiLSTM(n_layers=n_hidden, in_size=n_input, out_size=n_out, dropout=0.5)
self.out = L.Linear(n_hidden, n_out)
self.h = None
self.c = None
def __call__(self, x):
self.h, self.c, y = self.lstm(self.h,self.c,x)
output = self.out(y)
return output
def reset_state(self):
self.h = None
self.c = None
错误信息
---> 14 output = self.out(y)
表示你的错误是由self.out()
方法触发的,而不是self.lstm()
根据 the official API reference、L.NStepBiLSTM.run()
returns hy
、cy
和 ys
的元组,其中 ys
是一个列表。
您的代码
def __call__(self, x):
self.h, self.c, y = self.lstm(self.h,self.c,x)
output = self.out(y)
return output
表示y
(在官方文档中称为ys
)直接传递给self.out
,即L.Linear.__call__
。这导致 type-mismatch.
一般来说,ys
中的y
的形状各不相同,因为xs
中的x
可以是不同长度的序列
如果您需要更多帮助,欢迎随时提问!
NStepBiLSTM 的手册说它的前向函数期望 xs 以变量包装序列列表的格式出现。但是我得到一个错误,暗示 xs 应该是一个 np 数组。我错过了什么?
我使用此函数将输入数组转换为数组形状 (n,1) 的变量(数组)列表。
def cut_data(data, batchsize):
q = data.shape[0] // batchsize
data = data[:q*batchsize]
data = data.reshape((batchsize, q))
xs = []
for i in range(q):
a = data[:,i].reshape(batchsize,1)
xs.append(Variable(a))
return xs
但是当我用这样的 xs 调用我的预测器时,我得到这个错误:
<ipython-input-27-cb554613ad71> in __call__(self, xs, ts)
10 batchlen = len(xs)
11 loss = F.sum(F.mean_squared_error(
---> 12 self.predictor(xs), ts, reduce='no')) / batch
13
14 chainer.report({'loss': loss}, self)
<ipython-input-28-ce9434e91153> in __call__(self, x)
12 def __call__(self, x):
13 self.h, self.c, y = self.lstm(self.h,self.c,x)
---> 14 output = self.out(y)
15 return output
16
~\Anaconda2\lib\site-packages\chainer\links\connection\linear.py in __call__(self, x)
127 in_size = functools.reduce(operator.mul, x.shape[1:], 1)
128 self._initialize_params(in_size)
--> 129 return linear.linear(x, self.W, self.b)
~\Anaconda2\lib\site-packages\chainer\functions\connection\linear.py in linear(x, W, b)
165
166 """
--> 167 if x.ndim > 2:
168 x = x.reshape(len(x), -1)
169
AttributeError: 'list' object has no attribute 'ndim'
这是我的简单网络:
class LSTM_RNN(Chain):
def __init__(self, n_hidden, n_input=1, n_out=1):
super(LSTM_RNN, self).__init__()
with self.init_scope():
self.lstm = L.NStepBiLSTM(n_layers=n_hidden, in_size=n_input, out_size=n_out, dropout=0.5)
self.out = L.Linear(n_hidden, n_out)
self.h = None
self.c = None
def __call__(self, x):
self.h, self.c, y = self.lstm(self.h,self.c,x)
output = self.out(y)
return output
def reset_state(self):
self.h = None
self.c = None
错误信息
---> 14 output = self.out(y)
表示你的错误是由self.out()
方法触发的,而不是self.lstm()
根据 the official API reference、L.NStepBiLSTM.run()
returns hy
、cy
和 ys
的元组,其中 ys
是一个列表。
您的代码
def __call__(self, x):
self.h, self.c, y = self.lstm(self.h,self.c,x)
output = self.out(y)
return output
表示y
(在官方文档中称为ys
)直接传递给self.out
,即L.Linear.__call__
。这导致 type-mismatch.
一般来说,ys
中的y
的形状各不相同,因为xs
中的x
可以是不同长度的序列
如果您需要更多帮助,欢迎随时提问!