LeetCode 1707. 数组中元素的最大异或

LeetCode 1707. Maximum XOR With an Element From Array

给你一个由非负整数组成的数组 nums。您还会得到一个查询数组,其中 queries[i] = [xi, mi].

第i个查询的答案是xi与nums中任意一个不超过mi的元素的最大按位异或值。换句话说,对于所有满足 nums[j] <= mi 的 j,答案是 max(nums[j] XOR xi)。如果nums中的所有元素都大于mi,那么答案就是-1。

Return 一个整数数组答案,其中 answer.length == queries.length 和 answer[i] 是第 i 个查询的答案。 这个python解法用了Trie,LeetCode还是显示TLE?

 import operator

 class TrieNode:
      def __init__(self):
          self.left=None
          self.right=None
   
 class Solution:
     def insert(self,head,x):
         curr=head
         for i in range(31,-1,-1):
             val = (x>>i) & 1
             if val==0:
                if not curr.left:
                   curr.left=TrieNode()
                       curr=curr.left
                   else:
                       curr=curr.left
               else:
                   if not curr.right:
                       curr.right=TrieNode()
                       curr=curr.right
                   else:
                       curr=curr.right
       
           
       def maximizeXor(self, nums: List[int], queries: List[List[int]]) -> List[int]:
           res=[-10]*len(queries)
           nums.sort()
           for i in range(len(queries)):
               queries[i].append(i)
           queries.sort(key=operator.itemgetter(1))
           head=TrieNode()
          
           for li in queries:
               max=0
               xi,mi,index=li[0],li[1],li[2]
               m=2**31
               node = head
               pos=0
               if mi<nums[0]:
                   res[index]=-1
                   continue
               for i in range(pos,len(nums)):
                   if mi<nums[i]:
                       pos=i
                       break
                   self.insert(node,nums[i])
               node=head
               for i in range(31,-1,-1):
                   val=(xi>>i)&1
                   if val==0:
                       if node.right:
                           max+=m
                           node=node.right
                       else:
                           node=node.left
                   else:
                       if node.left:
                           max+=m
                           node=node.left
                       else:
                           node=node.right
                   m>>=1
               res[index]=max
           return -1

这里是解决这个问题的替代Trie工具:

[注:1) max(x XOR y for y in A); 2)对MSB位做贪心; 3) 对查询进行排序]

class Trie:
    def __init__(self):
        self.root = {}
    
    def add(self, n):
        p = self.root
        for bitpos in range(31, -1, -1):
            bit = (n >> bitpos) & 1
            if bit not in p:
                p[bit] = {}
            p = p[bit]
    
    def query(self, n):
        p = self.root
        ret = 0
        if not p:
            return -1
        for bitpos in range(31, -1, -1):
            bit = (n >> bitpos) & 1
            inverse = 1 - bit
            if inverse in p:
                p = p[inverse]
                ret |= (1 << bitpos)
            else:
                p = p[bit]
                
        return ret

class Solution:
    def maximizeXor(self, nums: List[int], queries: List[List[int]]) -> List[int]:
        
        n = len(nums)
        trie = Trie()
        q = sorted(enumerate(queries), key = lambda x: x[1][1])
        nums.sort()
        res = [-1] * len(queries)
        i = 0
        for index, (x, m) in q:
            while i < n and nums[i] <= m:
                trie.add(nums[i])
                i += 1
            res[index] = trie.query(x)
        return res

问题是您正在为每个查询构建一个新的 Trie。更糟糕的是,使用线性搜索在 nums 中找到最大值 <= mi。简单地使用

你会过得更好
max((n for n in nums if n <= mi), key=lambda n: n^xi, default=-1) 

这里的解决方案是在一开始就构建 trie,然后使用该 trie 简单地过滤小于 mi 的值:

import math
import bisect

def dump(t, indent=''):
    if t is not None:
        print(indent, "bit=", t.bit, "val=", t.val, "lower=", t.lower)
        dump(t.left, indent + '\tl')
        dump(t.right, indent + '\tr')

class Trie:
    def __init__(self, bit, val, lower):
        self.bit = bit
        self.val = val
        self.lower = lower
        self.left = None
        self.right = None
        
    def solve(self, mi, xi):
        print('-------------------------------------------')
        print(self.bit, "mi(b)=", (mi >> self.bit) & 1, "xi(b)=", (xi >> self.bit) & 1, "mi=", mi, "xi=", xi)
        dump(self)
        
        if self.val is not None:
            # reached a leave of the trie => found matching value
            print("Leaf")
            return self.val
        
        if mi & (1 << self.bit) == 0:
            # the maximum has a zero-bit at this position => all values in the right subtree are > mi
            print("Left forced by max")
            return -1 if self.left is None else self.left.solve(mi, xi)
        
        # pick based on xor-value if possible
        if (xi >> self.bit) & 1 == 0 and self.right is not None and (mi > self.right.lower or mi == ~0):
            print("Right preferred by xi")
            return self.right.solve(mi, xi)
        elif (xi >> self.bit) & 1 == 1 and self.left is not None:
            print("Left preferred by xi")
            return self.left.solve(~0, xi)
        
        # pick whichever is available
        if self.right is not None and (mi > self.right.lower or mi == ~0):
            print("Only right available")
            return self.right.solve(mi, xi)
        elif self.left is not None:
            print("Only left available")
            return self.left.solve(~0, xi)
        else:
            print("None available")
            return -1
        
        
def build_trie(nums):
    nums.sort()
        
    # msb of max(nums)
    max_bit = int(math.log(nums[-1], 2))  # I'll just assume that nums is never empty
    print(max_bit)
        
    def node(start, end, bit, template):
        print(start, end, bit, template, nums[start:end])
        
        if end - start == 1:
            # reached a leaf
            return Trie(0, nums[start], nums[start])
        elif start == end:
            # a partition without values => no Trie-node
            return None
            
        # find pivot for partitioning based on bit-value of specified position (bit)
        part = bisect.bisect_left(nums, template | (1 << bit), start, end)
        print(part)
            
        # build nodes for paritioning       
        res = Trie(bit, None, nums[start])
        res.left = node(start, part, bit - 1, template)
        res.right = node(part, end, bit - 1, template | (1 << bit))
        return res
        
    return node(0, len(nums), max_bit, 0)

class Solution:
    def maximizeXor(self, nums: List[int], queries: List[List[int]]) -> List[int]:
        trie = build_trie(nums)
        return [trie.solve(mi if mi <= nums[-1] else ~0, xi) for xi, mi in queries]

我有点懒,只是用 ~0 表示可以忽略最大值,因为子树中的所有值都小于 mi。基本思想是 ~0 & x == x 对任何整数 x 都成立。不像 那样简单,但能够处理查询流。