Pytorch 的 gather、sequeeze 和 unsqueeze 到 Tensorflow Keras
Pytorch's gather, sequeeze and unsqueeze to Tensorflow Keras
我正在将代码从 pytorch 迁移到 tensorflow,在计算损失的函数中,我有下面一行需要迁移到 tensorflow。
state_action_values = net(t_states_features).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
我找到了 tf.gather
和 tf.gather_nd
,但我不确定哪个更合适以及如何使用,另外 unsqueeze 的替代方案可能是 tf.expand_dims
?
为了更清楚地了解该行的结果,我使用 print 语句将其分成多个部分。
print("net result")
state_action_values = net(t_states_features)
print(state_action_values)
print("gather result")
state_action_values = state_action_values.gather(1, actions_v.unsqueeze(-1))
print(state_action_values)
print("last squeeze")
state_action_values = state_action_values.squeeze(-1)
net result
tensor([[ 45.6878, -14.9495, 59.3737],
[ 33.5737, -10.4617, 39.0078],
[ 67.7197, -22.8818, 85.7977],
[ 94.7701, -33.2053, 120.5519],
[ nan, nan, nan],
[ 84.7324, -29.2101, 108.0821],
[ 67.7193, -22.7702, 86.9558],
[113.6835, -38.7149, 142.6167],
[ 61.9260, -20.1968, 79.8010],
[ 51.6152, -17.7391, 66.0719],
[ 73.6565, -21.5699, 98.9463],
[ 84.0761, -26.5016, 107.6888],
[ 60.9459, -20.1257, 76.4105],
[103.2883, -35.4035, 130.4503],
[ 37.1156, -13.5180, 47.1067],
[ nan, nan, nan],
[ 55.6286, -18.5239, 71.9837],
[ 55.3858, -18.7892, 71.1197],
[ 50.2419, -17.2959, 66.7059],
[ 82.5715, -30.0302, 108.4984],
[ -0.8662, -1.1861, 1.6033],
[112.4620, -38.6416, 142.4556],
[ 57.8702, -19.8080, 74.7656],
[ 45.8418, -15.7436, 57.3367],
[ 81.6596, -27.5002, 104.6002],
[ 57.1507, -21.8001, 67.7933],
[ 35.0414, -11.8199, 47.6573],
[ 67.7085, -23.1017, 85.4623],
[ 40.6284, -12.4578, 58.9603],
[ 68.6394, -23.1481, 87.0832],
[ 27.0549, -8.6635, 34.0150],
[ 25.4071, -8.5511, 34.0285],
[ 62.9161, -22.1693, 78.7965],
[ 85.4505, -28.1487, 108.6252],
[ 67.6665, -23.2376, 85.7117],
[ 60.7806, -20.2784, 77.1022],
[ 66.5209, -21.5674, 88.5561],
[ 61.6637, -20.9891, 72.3873],
[ 45.1634, -15.4678, 61.4886],
[ 66.8119, -23.1250, 85.6189],
[ nan, nan, nan],
[ 67.8166, -24.8342, 84.6706],
[ 86.2114, -29.5941, 107.8025],
[ 66.2716, -23.3309, 83.9700],
[101.2122, -35.3554, 127.4772],
[ 61.0749, -19.4720, 78.5588],
[ 50.4058, -16.1262, 63.1010],
[ 27.7543, -9.3767, 35.7448],
[ 67.7810, -23.4962, 83.6030],
[ 35.0103, -11.7238, 44.7983],
[ 55.7402, -19.0223, 70.3627],
[ 67.9733, -22.0783, 85.1893],
[ 60.5253, -20.3157, 79.7312],
[ 67.2404, -21.5205, 81.4499],
[ 57.9502, -20.7747, 70.9109],
[ 87.6536, -31.4256, 112.6491],
[ 90.3668, -30.7755, 116.6192],
[ 59.0660, -19.6988, 75.0723],
[ 50.0969, -17.4135, 62.6556],
[ 28.8703, -9.0950, 34.5749],
[ 68.4053, -22.0715, 88.2302],
[ 69.1397, -21.4236, 84.7833],
[ 23.8506, -8.1834, 30.8318],
[ 58.4296, -20.2432, 73.8116],
[ 87.5317, -29.0606, 110.0389],
[ nan, nan, nan],
[ 88.6387, -30.6154, 112.4239],
[ 51.6089, -16.1073, 66.2757],
[ 94.3989, -32.1473, 119.0358],
[ 82.7449, -30.7778, 102.8537],
[ 74.3067, -26.6585, 98.2536],
[ 77.0881, -26.5706, 98.3553],
[ 28.5688, -9.2949, 41.1165],
[ 86.1560, -26.9364, 107.0244],
[ 41.8914, -16.9703, 57.3840],
[ 88.8886, -29.7008, 108.2697],
[ 61.1243, -20.7566, 77.2257],
[ 85.1174, -28.7558, 107.3853],
[ 81.7256, -27.9047, 104.5006],
[ 51.2663, -16.5880, 67.1428],
[ 46.9150, -12.7457, 61.3240],
[ 36.1758, -12.9769, 47.7178],
[ 85.5846, -29.4141, 107.9649],
[ 59.9424, -20.8349, 75.3359],
[ 62.6516, -22.1235, 81.6903],
[104.7664, -34.5876, 129.9478],
[ 64.4671, -23.3980, 83.9093],
[ 69.6928, -23.6567, 89.6024],
[ 60.4407, -19.6136, 75.9350],
[ 33.4921, -10.3434, 44.9537],
[ 57.9112, -19.4174, 74.3050],
[ 24.8262, -9.3637, 30.1057],
[ 85.3776, -28.9097, 110.1310],
[ 63.8175, -22.3843, 81.0308],
[ 34.6040, -12.3217, 46.0356],
[ 88.3740, -29.5049, 110.2897],
[ 66.8196, -22.5860, 85.5386],
[ 58.9767, -22.0601, 78.7086],
[ 83.2090, -26.3499, 113.5105],
[ 54.8450, -17.7980, 68.1161],
[ nan, nan, nan],
[ 85.0846, -29.2494, 107.6780],
[ 76.9251, -26.2295, 98.4755],
[ 98.2907, -32.8878, 124.9192],
[ 91.1387, -30.8262, 115.3978],
[ 73.1062, -24.9450, 90.0967],
[ 27.6564, -8.6114, 35.4470],
[ 71.8508, -25.1529, 95.5165],
[ 69.7275, -20.1357, 86.9620],
[ 67.0907, -21.9245, 84.8853],
[ 77.3163, -25.5980, 92.7700],
[ 63.0082, -21.0345, 78.7311],
[ 68.0553, -22.4280, 84.8031],
[ 5.8148, -2.3171, 8.0620],
[103.3399, -35.1769, 130.7801],
[ 54.8769, -18.6822, 70.4657],
[ 58.4446, -18.9764, 75.5509],
[ 91.0071, -31.2706, 112.6401],
[ 84.6577, -29.2644, 104.6046],
[ 45.4887, -15.8309, 59.0498],
[ 56.3384, -18.9264, 78.8834],
[ 63.5109, -21.3169, 81.5144],
[ 79.4635, -29.8681, 100.5056],
[ 27.6559, -10.0517, 35.6012],
[ 76.3909, -24.1689, 93.6133],
[ 34.3802, -11.5272, 45.8650],
[ 60.3553, -20.1693, 76.5371],
[ 56.0590, -18.6468, 69.8981]], grad_fn=<AddmmBackward0>)
gather result
tensor([[ 59.3737],
[-10.4617],
[ 67.7197],
[ 94.7701],
[ nan],
[-29.2101],
[ 67.7193],
[-38.7149],
[-20.1968],
[ 66.0719],
[ 98.9463],
[107.6888],
[-20.1257],
[-35.4035],
[ 47.1067],
[ nan],
[ 55.6286],
[-18.7892],
[ 66.7059],
[-30.0302],
[ 1.6033],
[112.4620],
[ 74.7656],
[-15.7436],
[ 81.6596],
[-21.8001],
[ 35.0414],
[-23.1017],
[ 40.6284],
[ 68.6394],
[ 34.0150],
[ 34.0285],
[ 78.7965],
[-28.1487],
[ 67.6665],
[-20.2784],
[-21.5674],
[ 72.3873],
[-15.4678],
[ 85.6189],
[ nan],
[-24.8342],
[-29.5941],
[-23.3309],
[101.2122],
[-19.4720],
[-16.1262],
[ -9.3767],
[-23.4962],
[-11.7238],
[ 70.3627],
[-22.0783],
[-20.3157],
[ 67.2404],
[-20.7747],
[112.6491],
[-30.7755],
[-19.6988],
[ 50.0969],
[ 34.5749],
[ 88.2302],
[-21.4236],
[ -8.1834],
[ 73.8116],
[110.0389],
[ nan],
[112.4239],
[-16.1073],
[-32.1473],
[-30.7778],
[ 98.2536],
[ 98.3553],
[ 28.5688],
[107.0244],
[-16.9703],
[-29.7008],
[ 77.2257],
[-28.7558],
[-27.9047],
[ 67.1428],
[-12.7457],
[ 47.7178],
[-29.4141],
[ 59.9424],
[-22.1235],
[129.9478],
[-23.3980],
[-23.6567],
[ 75.9350],
[-10.3434],
[-19.4174],
[ 30.1057],
[ 85.3776],
[ 63.8175],
[ 46.0356],
[-29.5049],
[-22.5860],
[-22.0601],
[113.5105],
[-17.7980],
[ nan],
[-29.2494],
[ 76.9251],
[-32.8878],
[115.3978],
[-24.9450],
[ 35.4470],
[ 95.5165],
[ 86.9620],
[-21.9245],
[-25.5980],
[ 78.7311],
[-22.4280],
[ 5.8148],
[103.3399],
[ 70.4657],
[ 58.4446],
[ 91.0071],
[104.6046],
[ 45.4887],
[-18.9264],
[ 63.5109],
[ 79.4635],
[-10.0517],
[ 76.3909],
[ 34.3802],
[-20.1693],
[-18.6468]], grad_fn=<GatherBackward0>)
last squeeze
tensor([ 59.3737, -10.4617, 67.7197, 94.7701, nan, -29.2101, 67.7193,
-38.7149, -20.1968, 66.0719, 98.9463, 107.6888, -20.1257, -35.4035,
47.1067, nan, 55.6286, -18.7892, 66.7059, -30.0302, 1.6033,
112.4620, 74.7656, -15.7436, 81.6596, -21.8001, 35.0414, -23.1017,
40.6284, 68.6394, 34.0150, 34.0285, 78.7965, -28.1487, 67.6665,
-20.2784, -21.5674, 72.3873, -15.4678, 85.6189, nan, -24.8342,
-29.5941, -23.3309, 101.2122, -19.4720, -16.1262, -9.3767, -23.4962,
-11.7238, 70.3627, -22.0783, -20.3157, 67.2404, -20.7747, 112.6491,
-30.7755, -19.6988, 50.0969, 34.5749, 88.2302, -21.4236, -8.1834,
73.8116, 110.0389, nan, 112.4239, -16.1073, -32.1473, -30.7778,
98.2536, 98.3553, 28.5688, 107.0244, -16.9703, -29.7008, 77.2257,
-28.7558, -27.9047, 67.1428, -12.7457, 47.7178, -29.4141, 59.9424,
-22.1235, 129.9478, -23.3980, -23.6567, 75.9350, -10.3434, -19.4174,
30.1057, 85.3776, 63.8175, 46.0356, -29.5049, -22.5860, -22.0601,
113.5105, -17.7980, nan, -29.2494, 76.9251, -32.8878, 115.3978,
-24.9450, 35.4470, 95.5165, 86.9620, -21.9245, -25.5980, 78.7311,
-22.4280, 5.8148, 103.3399, 70.4657, 58.4446, 91.0071, 104.6046,
45.4887, -18.9264, 63.5109, 79.4635, -10.0517, 76.3909, 34.3802,
-20.1693, -18.6468], grad_fn=<SqueezeBackward1>)
编辑 1:打印 actions_v
actions_v
tensor([2, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 2, 0, 2, 0, 1,
2, 0, 2, 1, 1, 0, 2, 1, 0, 0, 2, 1, 1, 1, 0, 1, 0, 1, 1, 2, 1, 1, 2, 1,
0, 2, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1, 2, 0, 0, 1, 1, 2, 0, 0, 2, 0, 0,
1, 1, 2, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 2, 0, 2, 0, 1, 1, 2, 1, 2, 2,
2, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 2, 1, 1, 0, 1, 0, 1, 2,
2, 1, 0, 2, 0, 0, 2, 1])
gather_nd
采用与输入张量具有相同维度的输入,并将输出具有这些索引值的张量(这就是您想要的)。
gather
将输出切片(但你可以随心所欲地给出索引形状,输出张量将只是一堆根据索引形状构造的切片)这不是你想要的想要。
所以你应该首先使索引与初始矩阵的维度相匹配:
indices = tf.transpose(tf.stack((tf.range(tf.shape(state_action_values)[0]),actions_v)))
然后gather_nd
:
state_action_values = tf.gather_nd(state_action_values,indices)
凯文
我正在将代码从 pytorch 迁移到 tensorflow,在计算损失的函数中,我有下面一行需要迁移到 tensorflow。
state_action_values = net(t_states_features).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
我找到了 tf.gather
和 tf.gather_nd
,但我不确定哪个更合适以及如何使用,另外 unsqueeze 的替代方案可能是 tf.expand_dims
?
为了更清楚地了解该行的结果,我使用 print 语句将其分成多个部分。
print("net result")
state_action_values = net(t_states_features)
print(state_action_values)
print("gather result")
state_action_values = state_action_values.gather(1, actions_v.unsqueeze(-1))
print(state_action_values)
print("last squeeze")
state_action_values = state_action_values.squeeze(-1)
net result
tensor([[ 45.6878, -14.9495, 59.3737],
[ 33.5737, -10.4617, 39.0078],
[ 67.7197, -22.8818, 85.7977],
[ 94.7701, -33.2053, 120.5519],
[ nan, nan, nan],
[ 84.7324, -29.2101, 108.0821],
[ 67.7193, -22.7702, 86.9558],
[113.6835, -38.7149, 142.6167],
[ 61.9260, -20.1968, 79.8010],
[ 51.6152, -17.7391, 66.0719],
[ 73.6565, -21.5699, 98.9463],
[ 84.0761, -26.5016, 107.6888],
[ 60.9459, -20.1257, 76.4105],
[103.2883, -35.4035, 130.4503],
[ 37.1156, -13.5180, 47.1067],
[ nan, nan, nan],
[ 55.6286, -18.5239, 71.9837],
[ 55.3858, -18.7892, 71.1197],
[ 50.2419, -17.2959, 66.7059],
[ 82.5715, -30.0302, 108.4984],
[ -0.8662, -1.1861, 1.6033],
[112.4620, -38.6416, 142.4556],
[ 57.8702, -19.8080, 74.7656],
[ 45.8418, -15.7436, 57.3367],
[ 81.6596, -27.5002, 104.6002],
[ 57.1507, -21.8001, 67.7933],
[ 35.0414, -11.8199, 47.6573],
[ 67.7085, -23.1017, 85.4623],
[ 40.6284, -12.4578, 58.9603],
[ 68.6394, -23.1481, 87.0832],
[ 27.0549, -8.6635, 34.0150],
[ 25.4071, -8.5511, 34.0285],
[ 62.9161, -22.1693, 78.7965],
[ 85.4505, -28.1487, 108.6252],
[ 67.6665, -23.2376, 85.7117],
[ 60.7806, -20.2784, 77.1022],
[ 66.5209, -21.5674, 88.5561],
[ 61.6637, -20.9891, 72.3873],
[ 45.1634, -15.4678, 61.4886],
[ 66.8119, -23.1250, 85.6189],
[ nan, nan, nan],
[ 67.8166, -24.8342, 84.6706],
[ 86.2114, -29.5941, 107.8025],
[ 66.2716, -23.3309, 83.9700],
[101.2122, -35.3554, 127.4772],
[ 61.0749, -19.4720, 78.5588],
[ 50.4058, -16.1262, 63.1010],
[ 27.7543, -9.3767, 35.7448],
[ 67.7810, -23.4962, 83.6030],
[ 35.0103, -11.7238, 44.7983],
[ 55.7402, -19.0223, 70.3627],
[ 67.9733, -22.0783, 85.1893],
[ 60.5253, -20.3157, 79.7312],
[ 67.2404, -21.5205, 81.4499],
[ 57.9502, -20.7747, 70.9109],
[ 87.6536, -31.4256, 112.6491],
[ 90.3668, -30.7755, 116.6192],
[ 59.0660, -19.6988, 75.0723],
[ 50.0969, -17.4135, 62.6556],
[ 28.8703, -9.0950, 34.5749],
[ 68.4053, -22.0715, 88.2302],
[ 69.1397, -21.4236, 84.7833],
[ 23.8506, -8.1834, 30.8318],
[ 58.4296, -20.2432, 73.8116],
[ 87.5317, -29.0606, 110.0389],
[ nan, nan, nan],
[ 88.6387, -30.6154, 112.4239],
[ 51.6089, -16.1073, 66.2757],
[ 94.3989, -32.1473, 119.0358],
[ 82.7449, -30.7778, 102.8537],
[ 74.3067, -26.6585, 98.2536],
[ 77.0881, -26.5706, 98.3553],
[ 28.5688, -9.2949, 41.1165],
[ 86.1560, -26.9364, 107.0244],
[ 41.8914, -16.9703, 57.3840],
[ 88.8886, -29.7008, 108.2697],
[ 61.1243, -20.7566, 77.2257],
[ 85.1174, -28.7558, 107.3853],
[ 81.7256, -27.9047, 104.5006],
[ 51.2663, -16.5880, 67.1428],
[ 46.9150, -12.7457, 61.3240],
[ 36.1758, -12.9769, 47.7178],
[ 85.5846, -29.4141, 107.9649],
[ 59.9424, -20.8349, 75.3359],
[ 62.6516, -22.1235, 81.6903],
[104.7664, -34.5876, 129.9478],
[ 64.4671, -23.3980, 83.9093],
[ 69.6928, -23.6567, 89.6024],
[ 60.4407, -19.6136, 75.9350],
[ 33.4921, -10.3434, 44.9537],
[ 57.9112, -19.4174, 74.3050],
[ 24.8262, -9.3637, 30.1057],
[ 85.3776, -28.9097, 110.1310],
[ 63.8175, -22.3843, 81.0308],
[ 34.6040, -12.3217, 46.0356],
[ 88.3740, -29.5049, 110.2897],
[ 66.8196, -22.5860, 85.5386],
[ 58.9767, -22.0601, 78.7086],
[ 83.2090, -26.3499, 113.5105],
[ 54.8450, -17.7980, 68.1161],
[ nan, nan, nan],
[ 85.0846, -29.2494, 107.6780],
[ 76.9251, -26.2295, 98.4755],
[ 98.2907, -32.8878, 124.9192],
[ 91.1387, -30.8262, 115.3978],
[ 73.1062, -24.9450, 90.0967],
[ 27.6564, -8.6114, 35.4470],
[ 71.8508, -25.1529, 95.5165],
[ 69.7275, -20.1357, 86.9620],
[ 67.0907, -21.9245, 84.8853],
[ 77.3163, -25.5980, 92.7700],
[ 63.0082, -21.0345, 78.7311],
[ 68.0553, -22.4280, 84.8031],
[ 5.8148, -2.3171, 8.0620],
[103.3399, -35.1769, 130.7801],
[ 54.8769, -18.6822, 70.4657],
[ 58.4446, -18.9764, 75.5509],
[ 91.0071, -31.2706, 112.6401],
[ 84.6577, -29.2644, 104.6046],
[ 45.4887, -15.8309, 59.0498],
[ 56.3384, -18.9264, 78.8834],
[ 63.5109, -21.3169, 81.5144],
[ 79.4635, -29.8681, 100.5056],
[ 27.6559, -10.0517, 35.6012],
[ 76.3909, -24.1689, 93.6133],
[ 34.3802, -11.5272, 45.8650],
[ 60.3553, -20.1693, 76.5371],
[ 56.0590, -18.6468, 69.8981]], grad_fn=<AddmmBackward0>)
gather result
tensor([[ 59.3737],
[-10.4617],
[ 67.7197],
[ 94.7701],
[ nan],
[-29.2101],
[ 67.7193],
[-38.7149],
[-20.1968],
[ 66.0719],
[ 98.9463],
[107.6888],
[-20.1257],
[-35.4035],
[ 47.1067],
[ nan],
[ 55.6286],
[-18.7892],
[ 66.7059],
[-30.0302],
[ 1.6033],
[112.4620],
[ 74.7656],
[-15.7436],
[ 81.6596],
[-21.8001],
[ 35.0414],
[-23.1017],
[ 40.6284],
[ 68.6394],
[ 34.0150],
[ 34.0285],
[ 78.7965],
[-28.1487],
[ 67.6665],
[-20.2784],
[-21.5674],
[ 72.3873],
[-15.4678],
[ 85.6189],
[ nan],
[-24.8342],
[-29.5941],
[-23.3309],
[101.2122],
[-19.4720],
[-16.1262],
[ -9.3767],
[-23.4962],
[-11.7238],
[ 70.3627],
[-22.0783],
[-20.3157],
[ 67.2404],
[-20.7747],
[112.6491],
[-30.7755],
[-19.6988],
[ 50.0969],
[ 34.5749],
[ 88.2302],
[-21.4236],
[ -8.1834],
[ 73.8116],
[110.0389],
[ nan],
[112.4239],
[-16.1073],
[-32.1473],
[-30.7778],
[ 98.2536],
[ 98.3553],
[ 28.5688],
[107.0244],
[-16.9703],
[-29.7008],
[ 77.2257],
[-28.7558],
[-27.9047],
[ 67.1428],
[-12.7457],
[ 47.7178],
[-29.4141],
[ 59.9424],
[-22.1235],
[129.9478],
[-23.3980],
[-23.6567],
[ 75.9350],
[-10.3434],
[-19.4174],
[ 30.1057],
[ 85.3776],
[ 63.8175],
[ 46.0356],
[-29.5049],
[-22.5860],
[-22.0601],
[113.5105],
[-17.7980],
[ nan],
[-29.2494],
[ 76.9251],
[-32.8878],
[115.3978],
[-24.9450],
[ 35.4470],
[ 95.5165],
[ 86.9620],
[-21.9245],
[-25.5980],
[ 78.7311],
[-22.4280],
[ 5.8148],
[103.3399],
[ 70.4657],
[ 58.4446],
[ 91.0071],
[104.6046],
[ 45.4887],
[-18.9264],
[ 63.5109],
[ 79.4635],
[-10.0517],
[ 76.3909],
[ 34.3802],
[-20.1693],
[-18.6468]], grad_fn=<GatherBackward0>)
last squeeze
tensor([ 59.3737, -10.4617, 67.7197, 94.7701, nan, -29.2101, 67.7193,
-38.7149, -20.1968, 66.0719, 98.9463, 107.6888, -20.1257, -35.4035,
47.1067, nan, 55.6286, -18.7892, 66.7059, -30.0302, 1.6033,
112.4620, 74.7656, -15.7436, 81.6596, -21.8001, 35.0414, -23.1017,
40.6284, 68.6394, 34.0150, 34.0285, 78.7965, -28.1487, 67.6665,
-20.2784, -21.5674, 72.3873, -15.4678, 85.6189, nan, -24.8342,
-29.5941, -23.3309, 101.2122, -19.4720, -16.1262, -9.3767, -23.4962,
-11.7238, 70.3627, -22.0783, -20.3157, 67.2404, -20.7747, 112.6491,
-30.7755, -19.6988, 50.0969, 34.5749, 88.2302, -21.4236, -8.1834,
73.8116, 110.0389, nan, 112.4239, -16.1073, -32.1473, -30.7778,
98.2536, 98.3553, 28.5688, 107.0244, -16.9703, -29.7008, 77.2257,
-28.7558, -27.9047, 67.1428, -12.7457, 47.7178, -29.4141, 59.9424,
-22.1235, 129.9478, -23.3980, -23.6567, 75.9350, -10.3434, -19.4174,
30.1057, 85.3776, 63.8175, 46.0356, -29.5049, -22.5860, -22.0601,
113.5105, -17.7980, nan, -29.2494, 76.9251, -32.8878, 115.3978,
-24.9450, 35.4470, 95.5165, 86.9620, -21.9245, -25.5980, 78.7311,
-22.4280, 5.8148, 103.3399, 70.4657, 58.4446, 91.0071, 104.6046,
45.4887, -18.9264, 63.5109, 79.4635, -10.0517, 76.3909, 34.3802,
-20.1693, -18.6468], grad_fn=<SqueezeBackward1>)
编辑 1:打印 actions_v
actions_v
tensor([2, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 2, 0, 2, 0, 1,
2, 0, 2, 1, 1, 0, 2, 1, 0, 0, 2, 1, 1, 1, 0, 1, 0, 1, 1, 2, 1, 1, 2, 1,
0, 2, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1, 2, 0, 0, 1, 1, 2, 0, 0, 2, 0, 0,
1, 1, 2, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 2, 0, 2, 0, 1, 1, 2, 1, 2, 2,
2, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 2, 1, 1, 0, 1, 0, 1, 2,
2, 1, 0, 2, 0, 0, 2, 1])
gather_nd
采用与输入张量具有相同维度的输入,并将输出具有这些索引值的张量(这就是您想要的)。
gather
将输出切片(但你可以随心所欲地给出索引形状,输出张量将只是一堆根据索引形状构造的切片)这不是你想要的想要。
所以你应该首先使索引与初始矩阵的维度相匹配:
indices = tf.transpose(tf.stack((tf.range(tf.shape(state_action_values)[0]),actions_v)))
然后gather_nd
:
state_action_values = tf.gather_nd(state_action_values,indices)
凯文