如何将嵌套数组输入 SVM 模型

How to feed a nested array into an SVM model

我的问题如下:

我有一个数组,其中包含对应于多个音频文件的特征向量。因此,例如,如果有 10 个音频文件,则此数组的长度为 10。

我有一个特征本身就是一个列表(这个列表包含音频文件特定特征的信息),对于给定的音频文件,特征向量如下所示:

array([0.03861840871664194, 187.72393405210002, 62.59881268743305,
       0.2911392405063291,
       array([4963.40332031, 3229.98046875, 2691.65039062, 3208.44726562,
       4338.94042969, 4220.5078125 , 4166.67480469, 4801.90429688,
       5555.56640625, 5910.86425781, 6115.4296875 , 5706.29882812,
       4984.93652344, 2756.25      , 1991.82128906, 2551.68457031,
       2734.71679688, 2906.98242188, 3143.84765625, 3219.21386719,
       3186.9140625 , 3165.38085938, 3068.48144531, 2465.55175781,
       2110.25390625, 2508.61816406, 2993.11523438, 3843.67675781,
       4715.77148438, 5652.46582031, 5480.20019531, 5792.43164062,
       5932.39746094, 6244.62890625, 6072.36328125, 6201.5625    ,
       6158.49609375, 6201.5625    , 6233.86230469, 6061.59667969])],
      dtype=object)

现在,当我尝试将此数据输入 svm 模型时:

from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt


X_train, X_val, y_train, y_val = train_test_split(X,y,test_size=0.3)

model  = svm.SVC()
model.fit(X_train,y_train)

yt_p = model.predict(X_train)
yv_p = model.predict(X_val)

我收到这个错误ValueError: setting an array element with a sequence.

如何构造我的特征向量以便能够将其提供给支持向量机?

编辑:

这里我以X为例

如果我们有 5 个音频文件,那么 X 将是:

array([[0.017455393927437918, 227.66237105624407, 32.42076654734572,
        0.3867924528301887,
        array([1851.85546875, 2433.25195312, 3057.71484375, 3079.24804688,
       3079.24804688, 3068.48144531, 3046.94824219, 3359.1796875 ,
       3908.27636719, 4618.87207031, 4618.87207031, 4521.97265625,
       4091.30859375, 3111.54785156, 3100.78125   , 2863.91601562,
       1561.15722656, 1119.7265625 , 1065.89355469,  947.4609375 ,
        979.76074219,  990.52734375,  990.52734375, 1356.59179688,
       2077.95410156, 2993.11523438, 3025.41503906, 3068.48144531,
       3079.24804688, 3090.01464844, 3100.78125   , 3111.54785156,
       2993.11523438, 3100.78125   , 3079.24804688, 2853.14941406,
       1205.859375  , 1281.22558594, 1614.99023438, 2131.78710938,
       2325.5859375 , 2034.88769531, 1916.45507812, 1744.18945312,
       1851.85546875, 2357.88574219, 2368.65234375, 1916.45507812,
       1959.52148438, 1959.52148438, 1754.95605469, 1787.25585938,
       2207.15332031])],
       [0.03861840871664194, 187.72393405210002, 62.59881268743305,
        0.2911392405063291,
        array([4963.40332031, 3229.98046875, 2691.65039062, 3208.44726562,
       4338.94042969, 4220.5078125 , 4166.67480469, 4801.90429688,
       5555.56640625, 5910.86425781, 6115.4296875 , 5706.29882812,
       4984.93652344, 2756.25      , 1991.82128906, 2551.68457031,
       2734.71679688, 2906.98242188, 3143.84765625, 3219.21386719,
       3186.9140625 , 3165.38085938, 3068.48144531, 2465.55175781,
       2110.25390625, 2508.61816406, 2993.11523438, 3843.67675781,
       4715.77148438, 5652.46582031, 5480.20019531, 5792.43164062,
       5932.39746094, 6244.62890625, 6072.36328125, 6201.5625    ,
       6158.49609375, 6201.5625    , 6233.86230469, 6061.59667969])],
       [0.042435441297643324, 128.81225073038124, 20.912528554426807,
        0.313953488372093,
        array([4349.70703125, 4242.04101562, 4274.34082031, 4123.60839844,
       4457.37304688, 4834.20410156, 4661.93847656, 4306.640625  ,
       4231.27441406, 4543.50585938, 4435.83984375, 6201.5625    ,
       8817.84667969, 8817.84667969,  742.89550781,  721.36230469,
        732.12890625,  732.12890625,  710.59570312,  721.36230469,
        925.92773438, 1119.7265625 , 1141.25976562, 1431.95800781,
       7762.71972656, 7934.98535156, 7891.91894531, 7332.05566406,
       3789.84375   , 2799.31640625, 2831.61621094, 2217.91992188,
        581.39648438,  602.9296875 , 2217.91992188, 2228.68652344,
       2368.65234375, 2519.38476562, 2863.91601562, 3682.17773438,
       3649.87792969, 4188.20800781, 4112.84179688])],
       [0.006295381642571726, 130.28309914454434, 5.193614287487564,
        0.2411764705882353,
        array([7978.05175781, 8010.3515625 , 8118.01757812, 8430.24902344,
       8257.98339844, 8451.78222656, 8591.74804688, 8677.88085938,
       8796.31347656, 8850.14648438, 8796.31347656, 8925.51269531,
       6244.62890625,  344.53125   ,  344.53125   , 1614.99023438,
       2325.5859375 , 2971.58203125, 3316.11328125, 3617.578125  ,
       3294.58007812, 2788.54980469, 2637.81738281, 2702.41699219,
       2723.95019531, 3133.08105469, 3413.01269531, 5663.23242188,
       5770.8984375 , 5577.09960938, 2228.68652344, 1604.22363281,
       1690.35644531, 4123.60839844, 5566.33300781, 5803.19824219,
       5749.36523438, 5846.26464844, 6772.19238281, 7073.65722656,
       7622.75390625, 7859.61914062, 8236.45019531, 8441.015625  ,
       8699.4140625 , 8807.08007812, 8742.48046875, 8667.11425781,
       8710.18066406, 8947.04589844, 9140.84472656, 9130.078125  ,
       8936.27929688, 8925.51269531, 8947.04589844, 8925.51269531,
       9097.77832031, 9205.44433594, 9194.67773438, 9140.84472656,
       9162.37792969, 9043.9453125 , 9162.37792969, 9108.54492188,
       9183.91113281, 9280.81054688, 9270.04394531, 9108.54492188,
       9076.24511719, 9356.17675781, 9226.97753906, 9216.2109375 ,
       9248.51074219, 9140.84472656, 9237.74414062, 9334.64355469,
       9259.27734375, 9226.97753906, 9216.2109375 , 9108.54492188,
       9183.91113281, 9216.2109375 , 9248.51074219, 9259.27734375,
       9183.91113281])],
       [0.017070271599460656, 171.91660927761163, 26.854424936811768,
        0.11188811188811189,
        array([4715.77148438, 4629.63867188, 4898.80371094, 5275.63476562,
       4941.87011719, 4532.73925781, 4618.87207031, 4995.703125  ,
       4705.00488281, 4500.43945312, 4188.20800781, 4371.24023438,
       4457.37304688, 4188.20800781, 4909.5703125 , 4877.27050781,
       6761.42578125, 7708.88671875, 7719.65332031, 7956.51855469,
       8484.08203125, 9033.17871094, 9043.9453125 , 9000.87890625,
       9011.64550781, 9011.64550781, 9000.87890625, 9108.54492188,
       8817.84667969, 6686.05957031, 1808.7890625 , 1830.32226562,
       1851.85546875, 1636.5234375 , 1022.82714844, 1281.22558594,
       1927.22167969, 1948.75488281, 1302.75878906, 1399.65820312,
       1873.38867188, 1959.52148438, 7245.92285156, 9011.64550781,
       9420.77636719, 9549.97558594, 9453.07617188, 9431.54296875,
       9410.00976562, 9248.51074219, 9151.61132812, 9194.67773438,
       8968.57910156, 8634.81445312, 8268.75      , 7439.72167969,
       5501.73339844, 5232.56835938, 5103.36914062, 7052.12402344,
       7299.75585938, 7127.49023438, 7192.08984375, 5673.99902344,
       5523.26660156, 5986.23046875, 6729.12597656, 6309.22851562,
       5135.66894531, 5081.8359375 , 5329.46777344, 5404.83398438])]],
      dtype=object)

您可以通过两种方式将包含列表的特征提供给您的模型:

  1. 将列表视为附加功能
  2. 使用您认为合适的函数(最小值、中值、平均值、最大值、求和等)将其所有元素映射到一个数字中。

要尝试第一个选项:

# Convert `X` to data frame
X = pd.DataFrame(X)

# Rename columns
X.columns = ['feature_' + str(i + 1) for i in range(X.shape[1])]

# Convert the feature with lists inside to long format
x = X['feature_5'].explode().to_frame()

# Create counter by observation so we can pivot
x['observation_id'] = x.groupby(level=0).cumcount()

# Convert to dataset and rename all columns
x = x.pivot(columns='observation_id', values='feature_5').fillna(0)
x = x.add_prefix('list_element_')

# Drop `feature_5` from X
X.drop(columns='feature_5', axis=1, inplace=True)

# Concatenate X and x together
X = pd.concat([X, x], axis=1)

# Carry on as before
X_train, X_val, y_train, y_val = train_test_split(X,y,test_size=0.3)
model  = svm.SVC()
model.fit(X_train,y_train)

第二个选项没有正确答案,只有您可以决定如何做,因为只有您知道列表的含义。但是,如果您想获取每个列表的平均值(例如)并将其用作特征:

# Get the mean of each list
means = [np.mean(array) for array in X[:, 4]]

# Replace the lists with `means`
X[:, 4] = means

然后进行拆分和拟合