unordered_map for custom class 在插入相同的键时不会导致错误

unordered_map for custom class does not cause error when inserting the same key

我正在尝试找出关于将 unordered_map 用于自定义 class 的一些要点。下面是我用来练习的代码,我定义了一个简单的 class Line。我很困惑为什么在 main() 中插入 Line2 不会使程序输出 insert failedm 的值对于 Line1Line2 都是 3。请注意,由于我只比较 class Lineoperator== 函数中的第一个值(即 m),因此此代码中的 Line1Line2 应该具有相同的值钥匙。插入一个已经存在的密钥不应该是无效的吗?有人可以向我解释为什么吗?谢谢!

#include<iostream>                                                                                                                                                                                                                                                                                                                                                                                                                                          
#include<unordered_map>                                                                                                                                                                                                                    

using namespace std;                                                                                                                                                                                                                       
class Line {                                                                                                                                                                                                                               
public:                                                                                                                                                                                                                                    
  float m;                                                                                                                                                                                                                                 
  float c;                                                                                                                                                                                                                                 

  Line() {m = 0; c = 0;}                                                                                                                                                                                                                   
  Line(float mInput, float cInput) {m = mInput; c = cInput;}                                                                                                                                                                               
  float getM() const {return m;}                                                                                                                                                                                                           
  float getC() const {return c;}                                                                                                                                                                                                           
  void setM(float mInput) {m = mInput;}                                                                                                                                                                                                    
  void setC(float cInput) {c = cInput;}                                                                                                                                                                                                    

  bool operator==(const Line &anotherLine) const                                                                                                                                                                                           
    {                                                                                                                                                                                                                                      
      return (m == anotherLine.m);                                                                                                                                                                                                         
    }                                                                                                                                                                                                                                      
};                                                                                                                                                                                                                                         

namespace std                                                                                                                                                                                                                              
{                                                                                                                                                                                                                                          
  template <>                                                                                                                                                                                                                              
  struct hash<Line>                                                                                                                                                                                                                        
  {                                                                                                                                                                                                                                        
    size_t operator()(const Line& k) const                                                                                                                                                                                                 
      {                                                                                                                                                                                                                                    
        // Compute individual hash values for two data members and combine them using XOR and bit shifting                                                                                                                                 
        return ((hash<float>()(k.getM()) ^ (hash<float>()(k.getC()) << 1)) >> 1);                                                                                                                                                          
      }                                                                                                                                                                                                                                    
  };                                                                                                                                                                                                                                       
}                                                                                                                                                                                                                                          

int main()                                                                                                                                                                                                                                 
{                                                                                                                                                                                                                                          
  unordered_map<Line, int> t;                                                                                                                                                                                                              

  Line line1 = Line(3.0,4.0);                                                                                                                                                                                                              
  Line line2 = Line(3.0,5.0);                                                                                                                                                                                                              

  t.insert({line1, 1});                                                                                                                                                                                                                                                                                                                                                                                                                                      
  auto x = t.insert({line2, 2});                                                                                                                                                                                                           
  if (x.second == false)                                                                                                                                                                                                                   
    cout << "insert failed" << endl;                                                                                                                                                                                                       

  for(unordered_map<Line, int>::const_iterator it = t.begin(); it != t.end(); it++)                                                                                                                                                        
  {                                                                                                                                                                                                                                        
    Line t = it->first;                                                                                                                                                                                                                    
    cout << t.m << " " << t.c << "\n" ;                                                                                                                                                                                                    
  }                                                                                                                                                                                                                                        

  return 1;                                                                                                                                                                                                                                
}    

您的 hashoperator == 必须满足他们目前违反的一致性要求。当两个对象根据 == 相等时,它们的哈希码 必须 根据 hash 相等。换句话说,虽然不相等的对象可能具有相同的哈希码,但相等的对象必须具有相同的哈希码:

size_t operator()(const Line& k) const  {
    return hash<float>()(k.getM());
}   

由于您只比较一个组件是否相等,而忽略了另一个组件,因此您需要更改哈希函数以使用用于确定相等性的同一组件。

您在哈希中同时使用了 "m" 和 "c" 的值,因此如果 "m" 和 "c" 是相等的,在你的例子中不是这种情况。所以如果你这样做:

Line line1 = Line(3.0,4.0);                                                                                                                                                                                                              
Line line2 = Line(3.0,4.0);                                                                                                                                                                                                              

t.insert({line1, 1});                                                                                                                                                                                                                                                                                                                                                                                                                                      
auto x = t.insert({line2, 2});                                                                                                                                                                                                           
if (x.second == false)                                                                                                                                                                                                                   
  cout << "insert failed" << endl;  

你会看到它会打印 "insert failed"

您始终可以使用自定义函数在插入时比较键:

#include <iostream>
#include <unordered_map>

class Line {
private:
    float m;
    float c;
public:
    Line() { m = 0; c = 0; }
    Line(float mInput, float cInput) { m = mInput; c = cInput; }
    float getM() const { return m; }
    float getC() const { return c; }
};


struct hash
{
    size_t operator()(const Line& k) const 
    {
        return ((std::hash<float>()(k.getM()) ^ (std::hash<float>()(k.getC()) << 1)) >> 1);
    }
};

// custom key comparison
struct cmpKey
{
    bool operator() (Line const &l1, Line const &l2) const
    {
        return l1.getM() == l2.getM();
    }
};


int main()
{ 

    std::unordered_map<Line, int, hash, cmpKey> mymap; // with custom key comparisom

    Line line1 = Line(3.0, 4.0);
    Line line2 = Line(4.0, 5.0);
    Line line3 = Line(4.0, 4.0);

    auto x = mymap.insert({ line1, 1 });
    std::cout << std::boolalpha << "element inserted: " << x.second << std::endl;
    x = mymap.insert({ line2, 2 });
    std::cout << std::boolalpha << "element inserted: " << x.second << std::endl;
    x = mymap.insert({ line3, 3 });
    std::cout << std::boolalpha << "element inserted: " << x.second << std::endl;

    return 0;
}

打印:

element inserted: true
element inserted: true
element inserted: false