如何将嵌套数组输入 SVM 模型
How to feed a nested array into an SVM model
我的问题如下:
我有一个数组,其中包含对应于多个音频文件的特征向量。因此,例如,如果有 10 个音频文件,则此数组的长度为 10。
我有一个特征本身就是一个列表(这个列表包含音频文件特定特征的信息),对于给定的音频文件,特征向量如下所示:
array([0.03861840871664194, 187.72393405210002, 62.59881268743305,
0.2911392405063291,
array([4963.40332031, 3229.98046875, 2691.65039062, 3208.44726562,
4338.94042969, 4220.5078125 , 4166.67480469, 4801.90429688,
5555.56640625, 5910.86425781, 6115.4296875 , 5706.29882812,
4984.93652344, 2756.25 , 1991.82128906, 2551.68457031,
2734.71679688, 2906.98242188, 3143.84765625, 3219.21386719,
3186.9140625 , 3165.38085938, 3068.48144531, 2465.55175781,
2110.25390625, 2508.61816406, 2993.11523438, 3843.67675781,
4715.77148438, 5652.46582031, 5480.20019531, 5792.43164062,
5932.39746094, 6244.62890625, 6072.36328125, 6201.5625 ,
6158.49609375, 6201.5625 , 6233.86230469, 6061.59667969])],
dtype=object)
现在,当我尝试将此数据输入 svm 模型时:
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
X_train, X_val, y_train, y_val = train_test_split(X,y,test_size=0.3)
model = svm.SVC()
model.fit(X_train,y_train)
yt_p = model.predict(X_train)
yv_p = model.predict(X_val)
我收到这个错误ValueError: setting an array element with a sequence.
如何构造我的特征向量以便能够将其提供给支持向量机?
编辑:
这里我以X为例
如果我们有 5 个音频文件,那么 X 将是:
array([[0.017455393927437918, 227.66237105624407, 32.42076654734572,
0.3867924528301887,
array([1851.85546875, 2433.25195312, 3057.71484375, 3079.24804688,
3079.24804688, 3068.48144531, 3046.94824219, 3359.1796875 ,
3908.27636719, 4618.87207031, 4618.87207031, 4521.97265625,
4091.30859375, 3111.54785156, 3100.78125 , 2863.91601562,
1561.15722656, 1119.7265625 , 1065.89355469, 947.4609375 ,
979.76074219, 990.52734375, 990.52734375, 1356.59179688,
2077.95410156, 2993.11523438, 3025.41503906, 3068.48144531,
3079.24804688, 3090.01464844, 3100.78125 , 3111.54785156,
2993.11523438, 3100.78125 , 3079.24804688, 2853.14941406,
1205.859375 , 1281.22558594, 1614.99023438, 2131.78710938,
2325.5859375 , 2034.88769531, 1916.45507812, 1744.18945312,
1851.85546875, 2357.88574219, 2368.65234375, 1916.45507812,
1959.52148438, 1959.52148438, 1754.95605469, 1787.25585938,
2207.15332031])],
[0.03861840871664194, 187.72393405210002, 62.59881268743305,
0.2911392405063291,
array([4963.40332031, 3229.98046875, 2691.65039062, 3208.44726562,
4338.94042969, 4220.5078125 , 4166.67480469, 4801.90429688,
5555.56640625, 5910.86425781, 6115.4296875 , 5706.29882812,
4984.93652344, 2756.25 , 1991.82128906, 2551.68457031,
2734.71679688, 2906.98242188, 3143.84765625, 3219.21386719,
3186.9140625 , 3165.38085938, 3068.48144531, 2465.55175781,
2110.25390625, 2508.61816406, 2993.11523438, 3843.67675781,
4715.77148438, 5652.46582031, 5480.20019531, 5792.43164062,
5932.39746094, 6244.62890625, 6072.36328125, 6201.5625 ,
6158.49609375, 6201.5625 , 6233.86230469, 6061.59667969])],
[0.042435441297643324, 128.81225073038124, 20.912528554426807,
0.313953488372093,
array([4349.70703125, 4242.04101562, 4274.34082031, 4123.60839844,
4457.37304688, 4834.20410156, 4661.93847656, 4306.640625 ,
4231.27441406, 4543.50585938, 4435.83984375, 6201.5625 ,
8817.84667969, 8817.84667969, 742.89550781, 721.36230469,
732.12890625, 732.12890625, 710.59570312, 721.36230469,
925.92773438, 1119.7265625 , 1141.25976562, 1431.95800781,
7762.71972656, 7934.98535156, 7891.91894531, 7332.05566406,
3789.84375 , 2799.31640625, 2831.61621094, 2217.91992188,
581.39648438, 602.9296875 , 2217.91992188, 2228.68652344,
2368.65234375, 2519.38476562, 2863.91601562, 3682.17773438,
3649.87792969, 4188.20800781, 4112.84179688])],
[0.006295381642571726, 130.28309914454434, 5.193614287487564,
0.2411764705882353,
array([7978.05175781, 8010.3515625 , 8118.01757812, 8430.24902344,
8257.98339844, 8451.78222656, 8591.74804688, 8677.88085938,
8796.31347656, 8850.14648438, 8796.31347656, 8925.51269531,
6244.62890625, 344.53125 , 344.53125 , 1614.99023438,
2325.5859375 , 2971.58203125, 3316.11328125, 3617.578125 ,
3294.58007812, 2788.54980469, 2637.81738281, 2702.41699219,
2723.95019531, 3133.08105469, 3413.01269531, 5663.23242188,
5770.8984375 , 5577.09960938, 2228.68652344, 1604.22363281,
1690.35644531, 4123.60839844, 5566.33300781, 5803.19824219,
5749.36523438, 5846.26464844, 6772.19238281, 7073.65722656,
7622.75390625, 7859.61914062, 8236.45019531, 8441.015625 ,
8699.4140625 , 8807.08007812, 8742.48046875, 8667.11425781,
8710.18066406, 8947.04589844, 9140.84472656, 9130.078125 ,
8936.27929688, 8925.51269531, 8947.04589844, 8925.51269531,
9097.77832031, 9205.44433594, 9194.67773438, 9140.84472656,
9162.37792969, 9043.9453125 , 9162.37792969, 9108.54492188,
9183.91113281, 9280.81054688, 9270.04394531, 9108.54492188,
9076.24511719, 9356.17675781, 9226.97753906, 9216.2109375 ,
9248.51074219, 9140.84472656, 9237.74414062, 9334.64355469,
9259.27734375, 9226.97753906, 9216.2109375 , 9108.54492188,
9183.91113281, 9216.2109375 , 9248.51074219, 9259.27734375,
9183.91113281])],
[0.017070271599460656, 171.91660927761163, 26.854424936811768,
0.11188811188811189,
array([4715.77148438, 4629.63867188, 4898.80371094, 5275.63476562,
4941.87011719, 4532.73925781, 4618.87207031, 4995.703125 ,
4705.00488281, 4500.43945312, 4188.20800781, 4371.24023438,
4457.37304688, 4188.20800781, 4909.5703125 , 4877.27050781,
6761.42578125, 7708.88671875, 7719.65332031, 7956.51855469,
8484.08203125, 9033.17871094, 9043.9453125 , 9000.87890625,
9011.64550781, 9011.64550781, 9000.87890625, 9108.54492188,
8817.84667969, 6686.05957031, 1808.7890625 , 1830.32226562,
1851.85546875, 1636.5234375 , 1022.82714844, 1281.22558594,
1927.22167969, 1948.75488281, 1302.75878906, 1399.65820312,
1873.38867188, 1959.52148438, 7245.92285156, 9011.64550781,
9420.77636719, 9549.97558594, 9453.07617188, 9431.54296875,
9410.00976562, 9248.51074219, 9151.61132812, 9194.67773438,
8968.57910156, 8634.81445312, 8268.75 , 7439.72167969,
5501.73339844, 5232.56835938, 5103.36914062, 7052.12402344,
7299.75585938, 7127.49023438, 7192.08984375, 5673.99902344,
5523.26660156, 5986.23046875, 6729.12597656, 6309.22851562,
5135.66894531, 5081.8359375 , 5329.46777344, 5404.83398438])]],
dtype=object)
您可以通过两种方式将包含列表的特征提供给您的模型:
- 将列表视为附加功能
- 使用您认为合适的函数(最小值、中值、平均值、最大值、求和等)将其所有元素映射到一个数字中。
要尝试第一个选项:
# Convert `X` to data frame
X = pd.DataFrame(X)
# Rename columns
X.columns = ['feature_' + str(i + 1) for i in range(X.shape[1])]
# Convert the feature with lists inside to long format
x = X['feature_5'].explode().to_frame()
# Create counter by observation so we can pivot
x['observation_id'] = x.groupby(level=0).cumcount()
# Convert to dataset and rename all columns
x = x.pivot(columns='observation_id', values='feature_5').fillna(0)
x = x.add_prefix('list_element_')
# Drop `feature_5` from X
X.drop(columns='feature_5', axis=1, inplace=True)
# Concatenate X and x together
X = pd.concat([X, x], axis=1)
# Carry on as before
X_train, X_val, y_train, y_val = train_test_split(X,y,test_size=0.3)
model = svm.SVC()
model.fit(X_train,y_train)
第二个选项没有正确答案,只有您可以决定如何做,因为只有您知道列表的含义。但是,如果您想获取每个列表的平均值(例如)并将其用作特征:
# Get the mean of each list
means = [np.mean(array) for array in X[:, 4]]
# Replace the lists with `means`
X[:, 4] = means
然后进行拆分和拟合
我的问题如下:
我有一个数组,其中包含对应于多个音频文件的特征向量。因此,例如,如果有 10 个音频文件,则此数组的长度为 10。
我有一个特征本身就是一个列表(这个列表包含音频文件特定特征的信息),对于给定的音频文件,特征向量如下所示:
array([0.03861840871664194, 187.72393405210002, 62.59881268743305,
0.2911392405063291,
array([4963.40332031, 3229.98046875, 2691.65039062, 3208.44726562,
4338.94042969, 4220.5078125 , 4166.67480469, 4801.90429688,
5555.56640625, 5910.86425781, 6115.4296875 , 5706.29882812,
4984.93652344, 2756.25 , 1991.82128906, 2551.68457031,
2734.71679688, 2906.98242188, 3143.84765625, 3219.21386719,
3186.9140625 , 3165.38085938, 3068.48144531, 2465.55175781,
2110.25390625, 2508.61816406, 2993.11523438, 3843.67675781,
4715.77148438, 5652.46582031, 5480.20019531, 5792.43164062,
5932.39746094, 6244.62890625, 6072.36328125, 6201.5625 ,
6158.49609375, 6201.5625 , 6233.86230469, 6061.59667969])],
dtype=object)
现在,当我尝试将此数据输入 svm 模型时:
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
X_train, X_val, y_train, y_val = train_test_split(X,y,test_size=0.3)
model = svm.SVC()
model.fit(X_train,y_train)
yt_p = model.predict(X_train)
yv_p = model.predict(X_val)
我收到这个错误ValueError: setting an array element with a sequence.
如何构造我的特征向量以便能够将其提供给支持向量机?
编辑:
这里我以X为例
如果我们有 5 个音频文件,那么 X 将是:
array([[0.017455393927437918, 227.66237105624407, 32.42076654734572,
0.3867924528301887,
array([1851.85546875, 2433.25195312, 3057.71484375, 3079.24804688,
3079.24804688, 3068.48144531, 3046.94824219, 3359.1796875 ,
3908.27636719, 4618.87207031, 4618.87207031, 4521.97265625,
4091.30859375, 3111.54785156, 3100.78125 , 2863.91601562,
1561.15722656, 1119.7265625 , 1065.89355469, 947.4609375 ,
979.76074219, 990.52734375, 990.52734375, 1356.59179688,
2077.95410156, 2993.11523438, 3025.41503906, 3068.48144531,
3079.24804688, 3090.01464844, 3100.78125 , 3111.54785156,
2993.11523438, 3100.78125 , 3079.24804688, 2853.14941406,
1205.859375 , 1281.22558594, 1614.99023438, 2131.78710938,
2325.5859375 , 2034.88769531, 1916.45507812, 1744.18945312,
1851.85546875, 2357.88574219, 2368.65234375, 1916.45507812,
1959.52148438, 1959.52148438, 1754.95605469, 1787.25585938,
2207.15332031])],
[0.03861840871664194, 187.72393405210002, 62.59881268743305,
0.2911392405063291,
array([4963.40332031, 3229.98046875, 2691.65039062, 3208.44726562,
4338.94042969, 4220.5078125 , 4166.67480469, 4801.90429688,
5555.56640625, 5910.86425781, 6115.4296875 , 5706.29882812,
4984.93652344, 2756.25 , 1991.82128906, 2551.68457031,
2734.71679688, 2906.98242188, 3143.84765625, 3219.21386719,
3186.9140625 , 3165.38085938, 3068.48144531, 2465.55175781,
2110.25390625, 2508.61816406, 2993.11523438, 3843.67675781,
4715.77148438, 5652.46582031, 5480.20019531, 5792.43164062,
5932.39746094, 6244.62890625, 6072.36328125, 6201.5625 ,
6158.49609375, 6201.5625 , 6233.86230469, 6061.59667969])],
[0.042435441297643324, 128.81225073038124, 20.912528554426807,
0.313953488372093,
array([4349.70703125, 4242.04101562, 4274.34082031, 4123.60839844,
4457.37304688, 4834.20410156, 4661.93847656, 4306.640625 ,
4231.27441406, 4543.50585938, 4435.83984375, 6201.5625 ,
8817.84667969, 8817.84667969, 742.89550781, 721.36230469,
732.12890625, 732.12890625, 710.59570312, 721.36230469,
925.92773438, 1119.7265625 , 1141.25976562, 1431.95800781,
7762.71972656, 7934.98535156, 7891.91894531, 7332.05566406,
3789.84375 , 2799.31640625, 2831.61621094, 2217.91992188,
581.39648438, 602.9296875 , 2217.91992188, 2228.68652344,
2368.65234375, 2519.38476562, 2863.91601562, 3682.17773438,
3649.87792969, 4188.20800781, 4112.84179688])],
[0.006295381642571726, 130.28309914454434, 5.193614287487564,
0.2411764705882353,
array([7978.05175781, 8010.3515625 , 8118.01757812, 8430.24902344,
8257.98339844, 8451.78222656, 8591.74804688, 8677.88085938,
8796.31347656, 8850.14648438, 8796.31347656, 8925.51269531,
6244.62890625, 344.53125 , 344.53125 , 1614.99023438,
2325.5859375 , 2971.58203125, 3316.11328125, 3617.578125 ,
3294.58007812, 2788.54980469, 2637.81738281, 2702.41699219,
2723.95019531, 3133.08105469, 3413.01269531, 5663.23242188,
5770.8984375 , 5577.09960938, 2228.68652344, 1604.22363281,
1690.35644531, 4123.60839844, 5566.33300781, 5803.19824219,
5749.36523438, 5846.26464844, 6772.19238281, 7073.65722656,
7622.75390625, 7859.61914062, 8236.45019531, 8441.015625 ,
8699.4140625 , 8807.08007812, 8742.48046875, 8667.11425781,
8710.18066406, 8947.04589844, 9140.84472656, 9130.078125 ,
8936.27929688, 8925.51269531, 8947.04589844, 8925.51269531,
9097.77832031, 9205.44433594, 9194.67773438, 9140.84472656,
9162.37792969, 9043.9453125 , 9162.37792969, 9108.54492188,
9183.91113281, 9280.81054688, 9270.04394531, 9108.54492188,
9076.24511719, 9356.17675781, 9226.97753906, 9216.2109375 ,
9248.51074219, 9140.84472656, 9237.74414062, 9334.64355469,
9259.27734375, 9226.97753906, 9216.2109375 , 9108.54492188,
9183.91113281, 9216.2109375 , 9248.51074219, 9259.27734375,
9183.91113281])],
[0.017070271599460656, 171.91660927761163, 26.854424936811768,
0.11188811188811189,
array([4715.77148438, 4629.63867188, 4898.80371094, 5275.63476562,
4941.87011719, 4532.73925781, 4618.87207031, 4995.703125 ,
4705.00488281, 4500.43945312, 4188.20800781, 4371.24023438,
4457.37304688, 4188.20800781, 4909.5703125 , 4877.27050781,
6761.42578125, 7708.88671875, 7719.65332031, 7956.51855469,
8484.08203125, 9033.17871094, 9043.9453125 , 9000.87890625,
9011.64550781, 9011.64550781, 9000.87890625, 9108.54492188,
8817.84667969, 6686.05957031, 1808.7890625 , 1830.32226562,
1851.85546875, 1636.5234375 , 1022.82714844, 1281.22558594,
1927.22167969, 1948.75488281, 1302.75878906, 1399.65820312,
1873.38867188, 1959.52148438, 7245.92285156, 9011.64550781,
9420.77636719, 9549.97558594, 9453.07617188, 9431.54296875,
9410.00976562, 9248.51074219, 9151.61132812, 9194.67773438,
8968.57910156, 8634.81445312, 8268.75 , 7439.72167969,
5501.73339844, 5232.56835938, 5103.36914062, 7052.12402344,
7299.75585938, 7127.49023438, 7192.08984375, 5673.99902344,
5523.26660156, 5986.23046875, 6729.12597656, 6309.22851562,
5135.66894531, 5081.8359375 , 5329.46777344, 5404.83398438])]],
dtype=object)
您可以通过两种方式将包含列表的特征提供给您的模型:
- 将列表视为附加功能
- 使用您认为合适的函数(最小值、中值、平均值、最大值、求和等)将其所有元素映射到一个数字中。
要尝试第一个选项:
# Convert `X` to data frame
X = pd.DataFrame(X)
# Rename columns
X.columns = ['feature_' + str(i + 1) for i in range(X.shape[1])]
# Convert the feature with lists inside to long format
x = X['feature_5'].explode().to_frame()
# Create counter by observation so we can pivot
x['observation_id'] = x.groupby(level=0).cumcount()
# Convert to dataset and rename all columns
x = x.pivot(columns='observation_id', values='feature_5').fillna(0)
x = x.add_prefix('list_element_')
# Drop `feature_5` from X
X.drop(columns='feature_5', axis=1, inplace=True)
# Concatenate X and x together
X = pd.concat([X, x], axis=1)
# Carry on as before
X_train, X_val, y_train, y_val = train_test_split(X,y,test_size=0.3)
model = svm.SVC()
model.fit(X_train,y_train)
第二个选项没有正确答案,只有您可以决定如何做,因为只有您知道列表的含义。但是,如果您想获取每个列表的平均值(例如)并将其用作特征:
# Get the mean of each list
means = [np.mean(array) for array in X[:, 4]]
# Replace the lists with `means`
X[:, 4] = means
然后进行拆分和拟合