如何重载 `__eq__` 来比较 pandas DataFrames 和 Series?
How do I overload `__eq__` to compare pandas DataFrames and Series?
为清楚起见,我从代码中摘录了相关部分并使用了通用名称。我有一个 class Foo(),它将一个 DataFrame 存储为属性。
import pandas as pd
import pandas.util.testing as pdt
class Foo():
    """Keep the raw ``bar`` data alongside a DataFrame built from it.

    Equality compares the two instances' attribute dictionaries directly.
    Because one attribute is a DataFrame, ``DataFrame == DataFrame`` yields
    an element-wise frame whose truth value is ambiguous, so comparing two
    instances that hold distinct DataFrame objects raises ``ValueError``.
    """

    def __init__(self, bar):
        self.bar = bar                  # dict of dicts
        self.df = pd.DataFrame(bar)     # pandas object built from ``bar``

    def __eq__(self, other):
        # Foreign types: let Python try the reflected comparison.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return vars(self) == vars(other)

    def __ne__(self, other):
        outcome = self.__eq__(other)
        return outcome if outcome is NotImplemented else not outcome
但是,当我尝试比较 Foo 的两个实例时,会得到一个异常,提示比较两个 DataFrame 时真值存在歧义(如果 Foo.__dict__ 中没有 'df' 键,比较应该可以正常工作)。
# Reproduce the failure: two Foo instances built from the same data.
# d1.copy() is a shallow copy, so d1 and d2 share the same Series objects.
d1 = {'A' : pd.Series([1, 2], index=['a', 'b']),
'B' : pd.Series([1, 2], index=['a', 'b'])}
d2 = d1.copy()
foo1 = Foo(d1)
foo2 = Foo(d2)
foo1.bar # dict
foo1.df # pandas DataFrame
foo1 == foo2 # ValueError (the truth value of a DataFrame is ambiguous)
[Out] ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
幸运的是,pandas 提供了用于断言两个 DataFrame 或 Series 是否相等的实用函数。如果可能的话,我想在比较操作中使用这些函数。
# Raises AssertionError when the frames differ; returns silently when equal.
pdt.assert_frame_equal(
    left=pd.DataFrame(d1),
    right=pd.DataFrame(d2),
)  # no raises
有几个选项可以解决两个 Foo 实例的比较问题:
- 比较 __dict__ 的一个副本 new_dict,其中去掉了 df 键
- 从 __dict__ 中删除 df 键(不理想)
- 不比较整个 __dict__,而只比较元组中列出的那部分属性
- 重载 __eq__,使其支持 pandas DataFrame 的比较
从长远来看,最后一个选项似乎最稳健,但我不确定最佳做法。最终,我想重构 __eq__,使其比较 Foo.__dict__ 中的所有项,包括 DataFrame(和 Series)。有什么办法可以做到这一点吗?
来自以下帖子的解决方案:
Comparing two pandas dataframes for differences
Pandas DataFrames with NaNs equality comparison
def df_equal(self):
    """Return True if the frames ``csvdata`` and ``csvdata_old`` are equal.

    ``self`` is unused. The two frames are looked up as globals, as in the
    thread this snippet was quoted from; they are not defined in this file
    (NOTE(review): confirm where they come from).

    Only a failed equality assertion is reported as ``False``. Any other
    error, such as a missing name, propagates instead of being silently
    turned into "not equal".
    """
    try:
        # Resolve through the pandas namespace; a bare ``assert_frame_equal``
        # is never imported in this file.
        pd.testing.assert_frame_equal(csvdata, csvdata_old)
    except AssertionError:
        return False
    return True
对于由 DataFrame 组成的字典:
def df_equal(df1, df2):
    """Return True if two pandas objects are equal, False otherwise.

    Two Series are compared with ``pandas.testing.assert_series_equal``.
    Everything else goes through ``pandas.testing.assert_frame_equal``,
    which rejects non-DataFrame inputs with an ``AssertionError``.

    Only comparison failures map to ``False``; unrelated errors (for
    example a ``NameError``) propagate.
    """
    if isinstance(df1, pd.Series) and isinstance(df2, pd.Series):
        # Columns taken from a frame (``frame[key]``) are Series, which
        # assert_frame_equal would reject outright.
        check = pd.testing.assert_series_equal
    else:
        check = pd.testing.assert_frame_equal
    try:
        check(df1, df2)
    except (AssertionError, TypeError, ValueError):
        return False
    return True
def __eq__(self, other):
    """Compare the mapping ``self.df`` against the mapping *other*.

    Returns False when the key sets differ, or when ``df_equal`` reports any
    pair of values as unequal. Values are checked in key order, stopping at
    the first mismatch. Returns True otherwise.

    NOTE(review): ``other`` is indexed directly, so it is expected to be a
    mapping (e.g. a dict of DataFrames) rather than another instance of the
    owning class -- confirm against callers.
    """
    frames = self.df
    if frames.keys() != other.keys():
        return False
    return all(df_equal(frames[key], other[key]) for key in frames.keys())
下面的代码似乎完全解决了我原来的问题。它同时支持 pandas DataFrame 和 Series。欢迎提出简化建议。
这里的诀窍是:__eq__ 分别比较 __dict__ 和 pandas 对象,最后组合两者的真值并返回结果。这里利用了一个有趣的特性:如果第一个操作数为真,and 会返回第二个操作数的值。
使用错误处理和外部比较函数的想法受到@ate50eggs 提交的答案的启发。非常感谢。
import pandas as pd
import pandas.util.testing as pdt
def ndframe_equal(ndf1, ndf2):
    """Return True if *ndf1* and *ndf2* are equal pandas objects of one kind.

    Two DataFrames are compared with ``pandas.testing.assert_frame_equal``
    and two Series with ``pandas.testing.assert_series_equal``. A
    DataFrame/Series mix, or any non-pandas input, is reported as not equal.

    Parameters
    ----------
    ndf1, ndf2 : object
        Values to compare; normally ``pd.DataFrame`` or ``pd.Series``.

    Returns
    -------
    bool
    """
    if isinstance(ndf1, pd.DataFrame) and isinstance(ndf2, pd.DataFrame):
        check = pd.testing.assert_frame_equal
    elif isinstance(ndf1, pd.Series) and isinstance(ndf2, pd.Series):
        check = pd.testing.assert_series_equal
    else:
        # Mismatched kinds (or non-pandas values) can never be equal.
        return False
    try:
        check(ndf1, ndf2)
    except (ValueError, AssertionError, AttributeError):
        return False
    return True


def _values_equal(left, right):
    """Return True if two attribute values are equal, treating pandas specially.

    pandas objects are routed through ``ndframe_equal`` because ``==`` on
    them is element-wise and cannot be used in a boolean context.

    Dicts, and lists/tuples of the same type, are compared item by item with
    this same function. That way pandas objects nested inside them (e.g.
    ``Foo.bar`` holding Series) are handled too.

    Anything else is compared by identity and then ``==``, the same way
    ``dict.__eq__`` compares values.
    """
    pandas_types = (pd.DataFrame, pd.Series)
    if isinstance(left, pandas_types) or isinstance(right, pandas_types):
        return ndframe_equal(left, right)
    if isinstance(left, dict) and isinstance(right, dict):
        return left.keys() == right.keys() and all(
            _values_equal(left[key], right[key]) for key in left)
    if isinstance(left, (list, tuple)) and type(left) is type(right):
        return len(left) == len(right) and all(
            map(_values_equal, left, right))
    return left is right or bool(left == right)


class Foo(object):
    """Wrap ``bar`` together with a pandas object (``ndf``) built from it.

    Equality compares every instance attribute. pandas-valued attributes,
    and pandas objects nested in dicts/lists/tuples, are compared with
    ``ndframe_equal`` instead of ``==``. Because ``__eq__`` is defined
    without ``__hash__``, instances are unhashable in Python 3.
    """

    def __init__(self, bar):
        self.bar = bar
        try:
            self.ndf = pd.DataFrame(bar)
        except ValueError:
            # e.g. a dict of scalars cannot form a DataFrame without an
            # explicit index, so fall back to a Series.
            self.ndf = pd.Series(bar)

    def __eq__(self, other):
        """Return True if both instances hold equal attributes.

        Returns NotImplemented for objects that are not instances of this
        class, so Python can try the reflected comparison. Nothing is
        stored on either operand, so one comparison cannot affect the
        outcome of a later one.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        return _values_equal(vars(self), vars(other))

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
在 DataFrame 上测试下面的代码。
# Data for DataFrames
# REPL-style checks: each trailing comment gives the expected result.
d1 = {'A' : pd.Series([1, 2], index=['a', 'b']),
'B' : pd.Series([1, 2], index=['a', 'b'])}
# d1.copy() is a shallow copy: d2 shares d1's Series objects.
d2 = d1.copy()
# Same keys as d1 but a different index ('abc' instead of 'a') and, for
# 'B', different values, so Foo(d3) is expected to compare unequal.
d3 = {'A' : pd.Series([1, 2], index=['abc', 'b']),
'B' : pd.Series([9, 0], index=['abc', 'b'])}
# Test DataFrames
foo1 = Foo(d1)
foo2 = Foo(d2)
foo1.bar # dict of Series
foo1.ndf # pandas DataFrame
foo1 == foo2 # triggers _dict
#foo1.__dict__['_dict']
#foo1._dict
foo1 == foo2 # True
foo1 != foo2 # False
not foo1 == foo2 # False
not foo1 != foo2 # True
foo2 = Foo(d3)
foo1 == foo2 # False
foo1 != foo2 # True
not foo1 == foo2 # True
not foo1 != foo2 # False
最后测试另一个常见的 pandas 对象:Series。
# Data for Series
# REPL-style checks: each trailing comment gives the expected result.
# These dicts hold only scalars; per the note on foo4.ndf below, Foo stores
# them as a pandas Series.
s1 = {'a' : 0., 'b' : 1., 'c' : 2.}
s2 = s1.copy()
# Differs from s1 at 'b' and 'c', so Foo(s3) is expected to compare unequal.
s3 = {'a' : 0., 'b' : 4, 'c' : 5}
# Test Series
foo3 = Foo(s1)
foo4 = Foo(s2)
foo3.bar # dict
foo4.ndf # pandas Series
foo3 == foo4 # True
foo3 != foo4 # False
not foo3 == foo4 # False
not foo3 != foo4 # True
foo4 = Foo(s3)
foo3 == foo4 # False
foo3 != foo4 # True
not foo3 == foo4 # True
not foo3 != foo4 # False
为清楚起见,我从代码中摘录了相关部分并使用了通用名称。我有一个 class Foo(),它将一个 DataFrame 存储为属性。
import pandas as pd
import pandas.util.testing as pdt
class Foo():
    """Keep the raw ``bar`` data alongside a DataFrame built from it.

    Equality compares the two instances' attribute dictionaries directly.
    Because one attribute is a DataFrame, ``DataFrame == DataFrame`` yields
    an element-wise frame whose truth value is ambiguous, so comparing two
    instances that hold distinct DataFrame objects raises ``ValueError``.
    """

    def __init__(self, bar):
        self.bar = bar                  # dict of dicts
        self.df = pd.DataFrame(bar)     # pandas object built from ``bar``

    def __eq__(self, other):
        # Foreign types: let Python try the reflected comparison.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return vars(self) == vars(other)

    def __ne__(self, other):
        outcome = self.__eq__(other)
        return outcome if outcome is NotImplemented else not outcome
但是,当我尝试比较 Foo 的两个实例时,会得到一个异常,提示比较两个 DataFrame 时真值存在歧义(如果 Foo.__dict__ 中没有 'df' 键,比较应该可以正常工作)。
# Reproduce the failure: two Foo instances built from the same data.
# d1.copy() is a shallow copy, so d1 and d2 share the same Series objects.
d1 = {'A' : pd.Series([1, 2], index=['a', 'b']),
'B' : pd.Series([1, 2], index=['a', 'b'])}
d2 = d1.copy()
foo1 = Foo(d1)
foo2 = Foo(d2)
foo1.bar # dict
foo1.df # pandas DataFrame
foo1 == foo2 # ValueError (the truth value of a DataFrame is ambiguous)
[Out] ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
幸运的是,pandas 提供了用于断言两个 DataFrame 或 Series 是否相等的实用函数。如果可能的话,我想在比较操作中使用这些函数。
# Raises AssertionError when the frames differ; returns silently when equal.
pdt.assert_frame_equal(
    left=pd.DataFrame(d1),
    right=pd.DataFrame(d2),
)  # no raises
有几个选项可以解决两个 Foo 实例的比较问题:
- 比较 __dict__ 的一个副本 new_dict,其中去掉了 df 键
- 从 __dict__ 中删除 df 键(不理想)
- 不比较整个 __dict__,而只比较元组中列出的那部分属性
- 重载 __eq__,使其支持 pandas DataFrame 的比较
从长远来看,最后一个选项似乎最稳健,但我不确定最佳做法。最终,我想重构 __eq__,使其比较 Foo.__dict__ 中的所有项,包括 DataFrame(和 Series)。有什么办法可以做到这一点吗?
来自以下帖子的解决方案:
Comparing two pandas dataframes for differences
Pandas DataFrames with NaNs equality comparison
def df_equal(self):
    """Return True if the frames ``csvdata`` and ``csvdata_old`` are equal.

    ``self`` is unused. The two frames are looked up as globals, as in the
    thread this snippet was quoted from; they are not defined in this file
    (NOTE(review): confirm where they come from).

    Only a failed equality assertion is reported as ``False``. Any other
    error, such as a missing name, propagates instead of being silently
    turned into "not equal".
    """
    try:
        # Resolve through the pandas namespace; a bare ``assert_frame_equal``
        # is never imported in this file.
        pd.testing.assert_frame_equal(csvdata, csvdata_old)
    except AssertionError:
        return False
    return True
对于由 DataFrame 组成的字典:
def df_equal(df1, df2):
    """Return True if two pandas objects are equal, False otherwise.

    Two Series are compared with ``pandas.testing.assert_series_equal``.
    Everything else goes through ``pandas.testing.assert_frame_equal``,
    which rejects non-DataFrame inputs with an ``AssertionError``.

    Only comparison failures map to ``False``; unrelated errors (for
    example a ``NameError``) propagate.
    """
    if isinstance(df1, pd.Series) and isinstance(df2, pd.Series):
        # Columns taken from a frame (``frame[key]``) are Series, which
        # assert_frame_equal would reject outright.
        check = pd.testing.assert_series_equal
    else:
        check = pd.testing.assert_frame_equal
    try:
        check(df1, df2)
    except (AssertionError, TypeError, ValueError):
        return False
    return True
def __eq__(self, other):
    """Compare the mapping ``self.df`` against the mapping *other*.

    Returns False when the key sets differ, or when ``df_equal`` reports any
    pair of values as unequal. Values are checked in key order, stopping at
    the first mismatch. Returns True otherwise.

    NOTE(review): ``other`` is indexed directly, so it is expected to be a
    mapping (e.g. a dict of DataFrames) rather than another instance of the
    owning class -- confirm against callers.
    """
    frames = self.df
    if frames.keys() != other.keys():
        return False
    return all(df_equal(frames[key], other[key]) for key in frames.keys())
下面的代码似乎完全解决了我原来的问题。它同时支持 pandas DataFrame 和 Series。欢迎提出简化建议。
这里的诀窍是:__eq__ 分别比较 __dict__ 和 pandas 对象,最后组合两者的真值并返回结果。这里利用了一个有趣的特性:如果第一个操作数为真,and 会返回第二个操作数的值。
使用错误处理和外部比较函数的想法受到@ate50eggs 提交的答案的启发。非常感谢。
import pandas as pd
import pandas.util.testing as pdt
def ndframe_equal(ndf1, ndf2):
    """Return True if *ndf1* and *ndf2* are equal pandas objects of one kind.

    Two DataFrames are compared with ``pandas.testing.assert_frame_equal``
    and two Series with ``pandas.testing.assert_series_equal``. A
    DataFrame/Series mix, or any non-pandas input, is reported as not equal.

    Parameters
    ----------
    ndf1, ndf2 : object
        Values to compare; normally ``pd.DataFrame`` or ``pd.Series``.

    Returns
    -------
    bool
    """
    if isinstance(ndf1, pd.DataFrame) and isinstance(ndf2, pd.DataFrame):
        check = pd.testing.assert_frame_equal
    elif isinstance(ndf1, pd.Series) and isinstance(ndf2, pd.Series):
        check = pd.testing.assert_series_equal
    else:
        # Mismatched kinds (or non-pandas values) can never be equal.
        return False
    try:
        check(ndf1, ndf2)
    except (ValueError, AssertionError, AttributeError):
        return False
    return True


def _values_equal(left, right):
    """Return True if two attribute values are equal, treating pandas specially.

    pandas objects are routed through ``ndframe_equal`` because ``==`` on
    them is element-wise and cannot be used in a boolean context.

    Dicts, and lists/tuples of the same type, are compared item by item with
    this same function. That way pandas objects nested inside them (e.g.
    ``Foo.bar`` holding Series) are handled too.

    Anything else is compared by identity and then ``==``, the same way
    ``dict.__eq__`` compares values.
    """
    pandas_types = (pd.DataFrame, pd.Series)
    if isinstance(left, pandas_types) or isinstance(right, pandas_types):
        return ndframe_equal(left, right)
    if isinstance(left, dict) and isinstance(right, dict):
        return left.keys() == right.keys() and all(
            _values_equal(left[key], right[key]) for key in left)
    if isinstance(left, (list, tuple)) and type(left) is type(right):
        return len(left) == len(right) and all(
            map(_values_equal, left, right))
    return left is right or bool(left == right)


class Foo(object):
    """Wrap ``bar`` together with a pandas object (``ndf``) built from it.

    Equality compares every instance attribute. pandas-valued attributes,
    and pandas objects nested in dicts/lists/tuples, are compared with
    ``ndframe_equal`` instead of ``==``. Because ``__eq__`` is defined
    without ``__hash__``, instances are unhashable in Python 3.
    """

    def __init__(self, bar):
        self.bar = bar
        try:
            self.ndf = pd.DataFrame(bar)
        except ValueError:
            # e.g. a dict of scalars cannot form a DataFrame without an
            # explicit index, so fall back to a Series.
            self.ndf = pd.Series(bar)

    def __eq__(self, other):
        """Return True if both instances hold equal attributes.

        Returns NotImplemented for objects that are not instances of this
        class, so Python can try the reflected comparison. Nothing is
        stored on either operand, so one comparison cannot affect the
        outcome of a later one.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        return _values_equal(vars(self), vars(other))

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
在 DataFrame 上测试下面的代码。
# Data for DataFrames
# REPL-style checks: each trailing comment gives the expected result.
d1 = {'A' : pd.Series([1, 2], index=['a', 'b']),
'B' : pd.Series([1, 2], index=['a', 'b'])}
# d1.copy() is a shallow copy: d2 shares d1's Series objects.
d2 = d1.copy()
# Same keys as d1 but a different index ('abc' instead of 'a') and, for
# 'B', different values, so Foo(d3) is expected to compare unequal.
d3 = {'A' : pd.Series([1, 2], index=['abc', 'b']),
'B' : pd.Series([9, 0], index=['abc', 'b'])}
# Test DataFrames
foo1 = Foo(d1)
foo2 = Foo(d2)
foo1.bar # dict of Series
foo1.ndf # pandas DataFrame
foo1 == foo2 # triggers _dict
#foo1.__dict__['_dict']
#foo1._dict
foo1 == foo2 # True
foo1 != foo2 # False
not foo1 == foo2 # False
not foo1 != foo2 # True
foo2 = Foo(d3)
foo1 == foo2 # False
foo1 != foo2 # True
not foo1 == foo2 # True
not foo1 != foo2 # False
最后测试另一个常见的 pandas 对象:Series。
# Data for Series
# REPL-style checks: each trailing comment gives the expected result.
# These dicts hold only scalars; per the note on foo4.ndf below, Foo stores
# them as a pandas Series.
s1 = {'a' : 0., 'b' : 1., 'c' : 2.}
s2 = s1.copy()
# Differs from s1 at 'b' and 'c', so Foo(s3) is expected to compare unequal.
s3 = {'a' : 0., 'b' : 4, 'c' : 5}
# Test Series
foo3 = Foo(s1)
foo4 = Foo(s2)
foo3.bar # dict
foo4.ndf # pandas Series
foo3 == foo4 # True
foo3 != foo4 # False
not foo3 == foo4 # False
not foo3 != foo4 # True
foo4 = Foo(s3)
foo3 == foo4 # False
foo3 != foo4 # True
not foo3 == foo4 # True
not foo3 != foo4 # False