PyTorch's gather, squeeze and unsqueeze to TensorFlow Keras

I am migrating code from PyTorch to TensorFlow, and in the function that computes the loss I have the following line that I need to port to TensorFlow.

state_action_values = net(t_states_features).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)

I found tf.gather and tf.gather_nd, but I'm not sure which one is more suitable here or how to use it. Also, would the replacement for unsqueeze be tf.expand_dims?
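For reference, a minimal standalone sketch (with made-up values) of how tf.expand_dims and tf.squeeze mirror unsqueeze(-1) and squeeze(-1):

  import tensorflow as tf

  actions = tf.constant([2, 0, 1])

  # tf.expand_dims adds a size-1 axis, like unsqueeze(-1) in PyTorch.
  expanded = tf.expand_dims(actions, axis=-1)   # shape (3, 1)

  # tf.squeeze removes a size-1 axis, like squeeze(-1) in PyTorch.
  squeezed = tf.squeeze(expanded, axis=-1)      # back to shape (3,)

  print(expanded.shape, squeezed.shape)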

To get a clearer picture of what that line produces, I split it into parts with print statements.

  print("net result")
  state_action_values = net(t_states_features)
  print(state_action_values)
  print("gather result")
  state_action_values = state_action_values.gather(1, actions_v.unsqueeze(-1))
  print(state_action_values)
  print("last squeeze")
  state_action_values = state_action_values.squeeze(-1)
  print(state_action_values)

net result
tensor([[ 45.6878, -14.9495,  59.3737],
        [ 33.5737, -10.4617,  39.0078],
        [ 67.7197, -22.8818,  85.7977],
        [ 94.7701, -33.2053, 120.5519],
        [     nan,      nan,      nan],
        [ 84.7324, -29.2101, 108.0821],
        [ 67.7193, -22.7702,  86.9558],
        [113.6835, -38.7149, 142.6167],
        [ 61.9260, -20.1968,  79.8010],
        [ 51.6152, -17.7391,  66.0719],
        [ 73.6565, -21.5699,  98.9463],
        [ 84.0761, -26.5016, 107.6888],
        [ 60.9459, -20.1257,  76.4105],
        [103.2883, -35.4035, 130.4503],
        [ 37.1156, -13.5180,  47.1067],
        [     nan,      nan,      nan],
        [ 55.6286, -18.5239,  71.9837],
        [ 55.3858, -18.7892,  71.1197],
        [ 50.2419, -17.2959,  66.7059],
        [ 82.5715, -30.0302, 108.4984],
        [ -0.8662,  -1.1861,   1.6033],
        [112.4620, -38.6416, 142.4556],
        [ 57.8702, -19.8080,  74.7656],
        [ 45.8418, -15.7436,  57.3367],
        [ 81.6596, -27.5002, 104.6002],
        [ 57.1507, -21.8001,  67.7933],
        [ 35.0414, -11.8199,  47.6573],
        [ 67.7085, -23.1017,  85.4623],
        [ 40.6284, -12.4578,  58.9603],
        [ 68.6394, -23.1481,  87.0832],
        [ 27.0549,  -8.6635,  34.0150],
        [ 25.4071,  -8.5511,  34.0285],
        [ 62.9161, -22.1693,  78.7965],
        [ 85.4505, -28.1487, 108.6252],
        [ 67.6665, -23.2376,  85.7117],
        [ 60.7806, -20.2784,  77.1022],
        [ 66.5209, -21.5674,  88.5561],
        [ 61.6637, -20.9891,  72.3873],
        [ 45.1634, -15.4678,  61.4886],
        [ 66.8119, -23.1250,  85.6189],
        [     nan,      nan,      nan],
        [ 67.8166, -24.8342,  84.6706],
        [ 86.2114, -29.5941, 107.8025],
        [ 66.2716, -23.3309,  83.9700],
        [101.2122, -35.3554, 127.4772],
        [ 61.0749, -19.4720,  78.5588],
        [ 50.4058, -16.1262,  63.1010],
        [ 27.7543,  -9.3767,  35.7448],
        [ 67.7810, -23.4962,  83.6030],
        [ 35.0103, -11.7238,  44.7983],
        [ 55.7402, -19.0223,  70.3627],
        [ 67.9733, -22.0783,  85.1893],
        [ 60.5253, -20.3157,  79.7312],
        [ 67.2404, -21.5205,  81.4499],
        [ 57.9502, -20.7747,  70.9109],
        [ 87.6536, -31.4256, 112.6491],
        [ 90.3668, -30.7755, 116.6192],
        [ 59.0660, -19.6988,  75.0723],
        [ 50.0969, -17.4135,  62.6556],
        [ 28.8703,  -9.0950,  34.5749],
        [ 68.4053, -22.0715,  88.2302],
        [ 69.1397, -21.4236,  84.7833],
        [ 23.8506,  -8.1834,  30.8318],
        [ 58.4296, -20.2432,  73.8116],
        [ 87.5317, -29.0606, 110.0389],
        [     nan,      nan,      nan],
        [ 88.6387, -30.6154, 112.4239],
        [ 51.6089, -16.1073,  66.2757],
        [ 94.3989, -32.1473, 119.0358],
        [ 82.7449, -30.7778, 102.8537],
        [ 74.3067, -26.6585,  98.2536],
        [ 77.0881, -26.5706,  98.3553],
        [ 28.5688,  -9.2949,  41.1165],
        [ 86.1560, -26.9364, 107.0244],
        [ 41.8914, -16.9703,  57.3840],
        [ 88.8886, -29.7008, 108.2697],
        [ 61.1243, -20.7566,  77.2257],
        [ 85.1174, -28.7558, 107.3853],
        [ 81.7256, -27.9047, 104.5006],
        [ 51.2663, -16.5880,  67.1428],
        [ 46.9150, -12.7457,  61.3240],
        [ 36.1758, -12.9769,  47.7178],
        [ 85.5846, -29.4141, 107.9649],
        [ 59.9424, -20.8349,  75.3359],
        [ 62.6516, -22.1235,  81.6903],
        [104.7664, -34.5876, 129.9478],
        [ 64.4671, -23.3980,  83.9093],
        [ 69.6928, -23.6567,  89.6024],
        [ 60.4407, -19.6136,  75.9350],
        [ 33.4921, -10.3434,  44.9537],
        [ 57.9112, -19.4174,  74.3050],
        [ 24.8262,  -9.3637,  30.1057],
        [ 85.3776, -28.9097, 110.1310],
        [ 63.8175, -22.3843,  81.0308],
        [ 34.6040, -12.3217,  46.0356],
        [ 88.3740, -29.5049, 110.2897],
        [ 66.8196, -22.5860,  85.5386],
        [ 58.9767, -22.0601,  78.7086],
        [ 83.2090, -26.3499, 113.5105],
        [ 54.8450, -17.7980,  68.1161],
        [     nan,      nan,      nan],
        [ 85.0846, -29.2494, 107.6780],
        [ 76.9251, -26.2295,  98.4755],
        [ 98.2907, -32.8878, 124.9192],
        [ 91.1387, -30.8262, 115.3978],
        [ 73.1062, -24.9450,  90.0967],
        [ 27.6564,  -8.6114,  35.4470],
        [ 71.8508, -25.1529,  95.5165],
        [ 69.7275, -20.1357,  86.9620],
        [ 67.0907, -21.9245,  84.8853],
        [ 77.3163, -25.5980,  92.7700],
        [ 63.0082, -21.0345,  78.7311],
        [ 68.0553, -22.4280,  84.8031],
        [  5.8148,  -2.3171,   8.0620],
        [103.3399, -35.1769, 130.7801],
        [ 54.8769, -18.6822,  70.4657],
        [ 58.4446, -18.9764,  75.5509],
        [ 91.0071, -31.2706, 112.6401],
        [ 84.6577, -29.2644, 104.6046],
        [ 45.4887, -15.8309,  59.0498],
        [ 56.3384, -18.9264,  78.8834],
        [ 63.5109, -21.3169,  81.5144],
        [ 79.4635, -29.8681, 100.5056],
        [ 27.6559, -10.0517,  35.6012],
        [ 76.3909, -24.1689,  93.6133],
        [ 34.3802, -11.5272,  45.8650],
        [ 60.3553, -20.1693,  76.5371],
        [ 56.0590, -18.6468,  69.8981]], grad_fn=<AddmmBackward0>)
gather result
tensor([[ 59.3737],
        [-10.4617],
        [ 67.7197],
        [ 94.7701],
        [     nan],
        [-29.2101],
        [ 67.7193],
        [-38.7149],
        [-20.1968],
        [ 66.0719],
        [ 98.9463],
        [107.6888],
        [-20.1257],
        [-35.4035],
        [ 47.1067],
        [     nan],
        [ 55.6286],
        [-18.7892],
        [ 66.7059],
        [-30.0302],
        [  1.6033],
        [112.4620],
        [ 74.7656],
        [-15.7436],
        [ 81.6596],
        [-21.8001],
        [ 35.0414],
        [-23.1017],
        [ 40.6284],
        [ 68.6394],
        [ 34.0150],
        [ 34.0285],
        [ 78.7965],
        [-28.1487],
        [ 67.6665],
        [-20.2784],
        [-21.5674],
        [ 72.3873],
        [-15.4678],
        [ 85.6189],
        [     nan],
        [-24.8342],
        [-29.5941],
        [-23.3309],
        [101.2122],
        [-19.4720],
        [-16.1262],
        [ -9.3767],
        [-23.4962],
        [-11.7238],
        [ 70.3627],
        [-22.0783],
        [-20.3157],
        [ 67.2404],
        [-20.7747],
        [112.6491],
        [-30.7755],
        [-19.6988],
        [ 50.0969],
        [ 34.5749],
        [ 88.2302],
        [-21.4236],
        [ -8.1834],
        [ 73.8116],
        [110.0389],
        [     nan],
        [112.4239],
        [-16.1073],
        [-32.1473],
        [-30.7778],
        [ 98.2536],
        [ 98.3553],
        [ 28.5688],
        [107.0244],
        [-16.9703],
        [-29.7008],
        [ 77.2257],
        [-28.7558],
        [-27.9047],
        [ 67.1428],
        [-12.7457],
        [ 47.7178],
        [-29.4141],
        [ 59.9424],
        [-22.1235],
        [129.9478],
        [-23.3980],
        [-23.6567],
        [ 75.9350],
        [-10.3434],
        [-19.4174],
        [ 30.1057],
        [ 85.3776],
        [ 63.8175],
        [ 46.0356],
        [-29.5049],
        [-22.5860],
        [-22.0601],
        [113.5105],
        [-17.7980],
        [     nan],
        [-29.2494],
        [ 76.9251],
        [-32.8878],
        [115.3978],
        [-24.9450],
        [ 35.4470],
        [ 95.5165],
        [ 86.9620],
        [-21.9245],
        [-25.5980],
        [ 78.7311],
        [-22.4280],
        [  5.8148],
        [103.3399],
        [ 70.4657],
        [ 58.4446],
        [ 91.0071],
        [104.6046],
        [ 45.4887],
        [-18.9264],
        [ 63.5109],
        [ 79.4635],
        [-10.0517],
        [ 76.3909],
        [ 34.3802],
        [-20.1693],
        [-18.6468]], grad_fn=<GatherBackward0>)
last squeeze
tensor([ 59.3737, -10.4617,  67.7197,  94.7701,      nan, -29.2101,  67.7193,
        -38.7149, -20.1968,  66.0719,  98.9463, 107.6888, -20.1257, -35.4035,
         47.1067,      nan,  55.6286, -18.7892,  66.7059, -30.0302,   1.6033,
        112.4620,  74.7656, -15.7436,  81.6596, -21.8001,  35.0414, -23.1017,
         40.6284,  68.6394,  34.0150,  34.0285,  78.7965, -28.1487,  67.6665,
        -20.2784, -21.5674,  72.3873, -15.4678,  85.6189,      nan, -24.8342,
        -29.5941, -23.3309, 101.2122, -19.4720, -16.1262,  -9.3767, -23.4962,
        -11.7238,  70.3627, -22.0783, -20.3157,  67.2404, -20.7747, 112.6491,
        -30.7755, -19.6988,  50.0969,  34.5749,  88.2302, -21.4236,  -8.1834,
         73.8116, 110.0389,      nan, 112.4239, -16.1073, -32.1473, -30.7778,
         98.2536,  98.3553,  28.5688, 107.0244, -16.9703, -29.7008,  77.2257,
        -28.7558, -27.9047,  67.1428, -12.7457,  47.7178, -29.4141,  59.9424,
        -22.1235, 129.9478, -23.3980, -23.6567,  75.9350, -10.3434, -19.4174,
         30.1057,  85.3776,  63.8175,  46.0356, -29.5049, -22.5860, -22.0601,
        113.5105, -17.7980,      nan, -29.2494,  76.9251, -32.8878, 115.3978,
        -24.9450,  35.4470,  95.5165,  86.9620, -21.9245, -25.5980,  78.7311,
        -22.4280,   5.8148, 103.3399,  70.4657,  58.4446,  91.0071, 104.6046,
         45.4887, -18.9264,  63.5109,  79.4635, -10.0517,  76.3909,  34.3802,
        -20.1693, -18.6468], grad_fn=<SqueezeBackward1>)

Edit 1: printing actions_v

actions_v
tensor([2, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 2, 0, 2, 0, 1,
        2, 0, 2, 1, 1, 0, 2, 1, 0, 0, 2, 1, 1, 1, 0, 1, 0, 1, 1, 2, 1, 1, 2, 1,
        0, 2, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1, 2, 0, 0, 1, 1, 2, 0, 0, 2, 0, 0,
        1, 1, 2, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 2, 0, 2, 0, 1, 1, 2, 1, 2, 2,
        2, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 2, 1, 1, 0, 1, 0, 1, 2,
        2, 1, 0, 2, 0, 0, 2, 1])

gather_nd takes indices with the same dimensionality as the input tensor and outputs a tensor of the values at those indices (which is what you want).

gather outputs slices (you can give the indices any shape you like, and the output tensor will simply be a collection of slices arranged according to that shape), which is not what you want. A toy comparison of the two ops is sketched below.
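A small example with made-up values illustrating the difference:

  import tensorflow as tf

  values = tf.constant([[10., 11., 12.],
                        [20., 21., 22.],
                        [30., 31., 32.]])

  # tf.gather picks whole slices along an axis: here, entire rows 0 and 2.
  print(tf.gather(values, [0, 2]))                       # [[10. 11. 12.], [30. 31. 32.]]

  # tf.gather_nd picks individual elements by full (row, column) coordinates,
  # which matches what the per-row PyTorch gather(1, ...) call does.
  print(tf.gather_nd(values, [[0, 2], [1, 0], [2, 1]]))  # [12. 20. 31.]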

So you should first build indices that match the dimensions of the initial matrix:

indices = tf.transpose(tf.stack((tf.range(tf.shape(state_action_values)[0]),actions_v)))

Then apply gather_nd:

state_action_values  = tf.gather_nd(state_action_values,indices)
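Putting it together, a minimal runnable sketch (state_action_values and actions_v are hypothetical stand-ins for net(t_states_features) and your action tensor):

  import tensorflow as tf

  state_action_values = tf.constant([[45.7, -14.9, 59.4],
                                     [33.6, -10.5, 39.0],
                                     [67.7, -22.9, 85.8]])
  actions_v = tf.constant([2, 0, 1])

  # Pair each row index with its chosen action to form (row, column) coordinates...
  indices = tf.transpose(tf.stack((tf.range(tf.shape(state_action_values)[0]), actions_v)))

  # ...then pick one value per row, mirroring gather(1, ...).squeeze(-1) in PyTorch.
  selected = tf.gather_nd(state_action_values, indices)
  print(selected)   # [59.4, 33.6, -22.9]

As a side note, the same per-row selection should also be expressible more compactly as tf.gather(state_action_values, actions_v, axis=1, batch_dims=1).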

Kevin