二维数组的二进制搜索算法 Python

Binary search algorithm for 2D array Python

我学会了用一维数组的方式进行二分查找:

def exist(target, array):
  lower = -1 
  upper = len(array)
  while not (lower + 1 == upper):
    mid = (lower + upper)//2
    if target== array[mid]: 
      return True
    elif target< array[mid]: 
      upper = mid 
    else: 
      lower = mid 
  return False

但是,现在我面临着这个包含 employee_id 和 employee_birthyear:

的二维数组 list2
[[2, 1986],
 [4, 1950],
 [6, 1994],
 [9, 2004],
 [12, 1988],
 [13, 1964],
 [16, 1987],
 [18, 1989],
 [19, 1951],
 [20, 1991]]

我想使用上面的纯二进制搜索算法编写一个函数,该算法接受一年(作为整数)和 list2,以及 returns employee_id 的列表匹配 employee_birthyear.

我该怎么做?

这是我想到的:

lst2 = [ j for i in list2 for j in i]

def get_IDs_with_birthyear(year, lst2):
    lower = -1 
  upper = len(lst2)
  while not (lower + 1 == upper): 
    mid = (lower + upper)//2
    if year = lst2[mid]:
      return mid 
  return []

更新: 我尝试对 year 进行排序并进行二进制搜索,但是当同一年有多个 id 时,我无法检索所有 ID。

ids = [] 
def get_IDs_with_birthyear(year, employee_with_birthyear_list):
  data2 = sorted(employee_with_birthyear_list, key=lambda d: d[1])
  years = [d[1] for d in data2]
  id = [d[0] for d in data2]
  
  lower = -1 
  upper = len(years)
  while not (lower + 1 == upper): 
    mid = (lower + upper)//2
    if year == years[mid]:
      ids.append(id[mid])
      return ids
    elif year < years[mid]:
      upper = mid 
    else: 
      lower = mid
  return False 

结果应该是 [101, 201, 1999632, 1999649],但我只得到 [1999632]

result = get_IDs_with_birthyear(1949, ewby)

我的函数做错了什么?

二进制搜索算法需要根据您搜索的关键字对数据进行排序。在这种情况下,您的 2D 列表未按年份排序,因此您需要采用不同的方法(或者您需要按年份而不是员工 ID 对列表进行排序)。

一个选项是构建一个以年份为键的字典:

years = dict()
for emp,birth in employee_birthyear:
    years.setdefault(birth,[]).append(emp)

然后您可以获得任何给定出生年份的员工 ID 列表:

years[1950] # [4]

请注意,您会得到一份员工 ID 列表,因为同一年出生的员工可能不止一名

一旦构建了字典(O(N) 操作而不是 O(NlogN) 排序),按年的所有访问将在 O(1) 中 return。这避免了更改 employee_birthyear 列表中元素顺序的需要,并且比二进制搜索(即 O(logN))

具有更低的复杂性

您可以利用 bisect module 进行二进制搜索,并维护一个排序列表。

要按日期搜索,需要按日期对数据进行排序,并构建二级键(日期)列表。 bisect_left() 然后可用于查找记录的开始 >= 选择的开始日期。

>>> from bisect import bisect_left, bisect_right
>>>
>>> data  = [[2, 1986],[4, 1950],[6, 1994],[9, 2004],[12, 1988],
...          [13, 1964],[16, 1987],[18, 1989],[19, 1951],[20, 1991]]
>>>
>>> # Sort data by date, then construct array of just dates.
>>> data2 = sorted(data, key=lambda d: d[1])
>>> dates = [d[1] for d in data2]
>>>
>>> # Use bisect_left() to locate the start of records >= 1988.
>>> idx = bisect_left(dates, 1988)
>>> data2[idx]     # Use that index with data2 to get first record.
[12, 1988]
>>>
>>> # Iteratively get records between 1988 and 1991.
>>> idx = bisect_left([d[1] for d in data2], 1988)
>>> recs = []
>>> while idx < len(data2) and data2[idx][1] <= 1991:
...     recs.append(data2[idx])
...     idx += 1
...     
>>> recs
[[12, 1988], [18, 1989], [20, 1991]]
>>> 
>>> 

这是一个灵活的解决方案,因为您要查找的日期实际上可能不在任何记录中,因此字典查找将失败。通过获取第一条记录 >= 日期,您可以向前迭代以按升序查找后续记录。您还可以指定一个结束日期并使用 bisect_right() 找到它在数组中的位置,这样您就可以包含感兴趣的记录范围。

>>> # A function that returns records that fall between 
>>> # specific dates, inclusive.
>>> def get_range(sorted_data, year_begin, year_end):
...     dates     = [d[1] for d in sorted_data]
...     idx_start = bisect_left(dates, year_begin)
...     idx_end   = bisect_right(dates, year_end)
...     return sorted_data[idx_start:idx_end]
...     
>>> get_range(data2, 1988, 1994)
[[12, 1988], [18, 1989], [20, 1991], [6, 1994]]
>>>
>>> # Add some more records for 1988 & sort.
>>> data2.extend([[1, 1988], [3, 1988]])
>>> data2.sort(); data2.sort(key=lambda r: r[1])
>>> 
>>> # Retrieve all records with 1988.
>>> get_range(data2, 1988, 1988)
[[1, 1988], [3, 1988], [12, 1988]]

如果数据集比较小,只想获取特定日期的记录,则不需要排序或使用上述任何一种方法。您可以只过滤数据集。

>>> list(filter(lambda d: d[1] == 1988, data2))
[[1, 1988], [3, 1988], [12, 1988]]