Intersect 4 lists in Python with a list comprehension

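Is there a way to handle the following with a list comprehension? I have 4 lists: t1, x1(t1), t2, x2(t2). t1 and t2 have different lengths, and so do x1 and x2. I want to add up the values of x1 and x2 where t1 and t2 intersect, and for the values where t1 and t2 don't intersect, just append them to two new lists x and t.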

t1 = [0.0, 20.0, 40.0, 50.0, 60.0, 80.0]
x1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
t2 = [40.0, 50.0, 80.0]
x2 = [7.0, 8.0, 9.0]

So my new t and x would be:

t = [0.0, 20.0, 40.0, 50.0, 60.0, 80.0]
x = [1.0, 2.0, 10.0, 12.0, 5.0, 15.0]

Here is the code with a double loop. So far it only does the summation; it still needs to append the values of t1, t2, x1 and x2 that don't intersect:

x = []
t = []
for y in range(len(t1)):
    for z in range(len(t2)):
        if t1[y] == t2[z]:
            t.append(t1[y])
            x.append(x1[y] + x2[z])

So far, this works for summing the values of x1 and x2 when their corresponding t values match:

t2 = [0.0, 20.0, 40.0, 50.0, 60.0, 80.0]
x2 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
t1 = [40.0, 50.0, 80.0]
x1 = [7.0, 8.0, 9.0]

t = [i for i in t1 for j in t2 if i==j]
x = [sum(i) for i in [(x1[i], x2[j]) for i, k in enumerate(t1) for j, l in enumerate(t2) if t1[i]==t2[j]]]
print(t)
print(x)

Output:

t = [40.0, 50.0, 80.0]
x = [10.0, 12.0, 15.0]

First, since your data consists of "key" and "value" pairs, create dictionaries for it instead of trying to juggle the common indexes:

data1 = dict(zip(t1, x1))
data2 = dict(zip(t2, x2))

Now create a single dictionary from these two by combining their keys:

# sorted() keeps the combined keys in ascending order; a bare set has no guaranteed order
data3 = {key: data1.get(key, 0) + data2.get(key, 0) for key in sorted(data1.keys() | data2.keys())}

And there is your data. If you really need lists, then:

# keys() and values() return views in Python 3, so wrap them in list()
t = list(data3.keys())
x = list(data3.values())
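
As a rough sketch of the dictionary idea above wrapped into one reusable function: the name merge_series is hypothetical and not part of the original answer. Unlike dict(zip(...)), this variant also sums repeated t values within a single list, and it assumes t values match by exact float equality.

```python
import itertools


def merge_series(t1, x1, t2, x2):
    """Sum x values that share a t value; keep unmatched values as they are."""
    totals = {}
    # Walk both (t, x) series and accumulate each x under its t key.
    for t_val, x_val in itertools.chain(zip(t1, x1), zip(t2, x2)):
        totals[t_val] = totals.get(t_val, 0.0) + x_val
    # Sort the keys so the output lists come back in ascending t order.
    t = sorted(totals)
    x = [totals[key] for key in t]
    return t, x


t1 = [0.0, 20.0, 40.0, 50.0, 60.0, 80.0]
x1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
t2 = [40.0, 50.0, 80.0]
x2 = [7.0, 8.0, 9.0]

t, x = merge_series(t1, x1, t2, x2)
print(t)  # [0.0, 20.0, 40.0, 50.0, 60.0, 80.0]
print(x)  # [1.0, 2.0, 10.0, 12.0, 5.0, 15.0]
```

Neither input needs to be sorted here, because the keys are sorted once at the end.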

import itertools

t1 = [0.0, 20.0, 40.0, 50.0, 60.0, 80.0]
x1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
t2 = [40.0, 50.0, 80.0]
x2 = [7.0, 8.0, 9.0]

# Index pairs (i, j) such that t1[i] == t2[j]
t_product = list(itertools.product(t1, t2))
intersec_indexes = [(t1.index(a), t2.index(b)) for a, b in t_product if a == b]

# Assumes every value in t2 also appears in t1, so x1 lines up with the merged t
result_x = []
for index, value in enumerate(x1):
    # Guard against an empty list once all intersections have been consumed
    if intersec_indexes and index == intersec_indexes[0][0]:
        index_x1, index_x2 = intersec_indexes.pop(0)
        result_x.append(x1[index_x1] + x2[index_x2])
    else:
        result_x.append(value)

result_t = sorted(set(t1 + t2))

print(result_t)
print(result_x)

Result:

[0.0, 20.0, 40.0, 50.0, 60.0, 80.0]
[1.0, 2.0, 10.0, 12.0, 5.0, 15.0]

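You can use heapq.merge to "zip" the two series together in time order, and then itertools.groupby to find the coincidences. These are all linear-complexity operations (given that t1 and t2 are already sorted, which heapq.merge requires), so this should scale well: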

import heapq, itertools, operator
t, x = zip(*((k, sum(map(operator.itemgetter(1), v))) for k, v in itertools.groupby(heapq.merge(zip(t1, x1), zip(t2, x2)), operator.itemgetter(0))))
t
# (0.0, 20.0, 40.0, 50.0, 60.0, 80.0)
x
# (1.0, 2.0, 10.0, 12.0, 5.0, 15.0)

Step by step:

merged = heapq.merge(zip(t1, x1), zip(t2, x2))
# make list for printing
merged = list(merged)
merged
# [(0.0, 1.0), (20.0, 2.0), (40.0, 3.0), (40.0, 7.0), (50.0, 4.0), (50.0, 8.0), (60.0, 5.0), (80.0, 6.0), (80.0, 9.0)]
grouped = itertools.groupby(merged, operator.itemgetter(0))
# make printable
grouped = [(k, list(v)) for k, v in grouped]
grouped
# [(0.0, [(0.0, 1.0)]), (20.0, [(20.0, 2.0)]), (40.0, [(40.0, 3.0), (40.0, 7.0)]), (50.0, [(50.0, 4.0), (50.0, 8.0)]), (60.0, [(60.0, 5.0)]), (80.0, [(80.0, 6.0), (80.0, 9.0)])]
t, x = zip(*((k, sum(map(operator.itemgetter(1), v))) for k, v in grouped))
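
The steps above could also be packaged as a small function. This is only a sketch: the name merge_sorted_series and the explicit sortedness check are assumptions added for illustration and are not part of the original answer. It returns lists rather than tuples.

```python
import heapq
import itertools
import operator


def is_sorted(values):
    """Return True if values are in non-decreasing order."""
    return all(a <= b for a, b in zip(values, values[1:]))


def merge_sorted_series(t1, x1, t2, x2):
    """Merge two time-sorted (t, x) series, summing x where t coincides."""
    # heapq.merge only yields a sorted stream if each input is already sorted.
    if not (is_sorted(t1) and is_sorted(t2)):
        raise ValueError("t1 and t2 must each be sorted in ascending order")
    by_time = operator.itemgetter(0)
    # Compare pairs on t only, so x values never take part in the ordering.
    merged = heapq.merge(zip(t1, x1), zip(t2, x2), key=by_time)
    t, x = [], []
    # groupby collects consecutive pairs that share the same t value.
    for t_val, group in itertools.groupby(merged, key=by_time):
        t.append(t_val)
        x.append(sum(x_val for _, x_val in group))
    return t, x


t1 = [0.0, 20.0, 40.0, 50.0, 60.0, 80.0]
x1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
t2 = [40.0, 50.0, 80.0]
x2 = [7.0, 8.0, 9.0]

t, x = merge_sorted_series(t1, x1, t2, x2)
print(t)  # [0.0, 20.0, 40.0, 50.0, 60.0, 80.0]
print(x)  # [1.0, 2.0, 10.0, 12.0, 5.0, 15.0]
```

Raising on unsorted input is one possible design choice; sorting the pairs first would also work, at the cost of the linear running time.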