Skip to content

Commit

Permalink
Variable: fix equal when names equal
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Aug 28, 2020
1 parent b445bbf commit 798316a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
19 changes: 13 additions & 6 deletions Orange/data/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,18 +215,25 @@ def test_hash(self):

def test_hash_eq(self):
a = ContinuousVariable("a")
a1 = ContinuousVariable("a")
b1 = ContinuousVariable("b", compute_value=Identity(a))
b2 = ContinuousVariable("b2", compute_value=Identity(b1))
b3 = ContinuousVariable("b")
self.assertEqual(a, b2)
self.assertEqual(b1, b2)
self.assertEqual(a, b1)
c1 = ContinuousVariable("c", compute_value=Identity(a))
c2 = ContinuousVariable("c", compute_value=Identity(a))
self.assertNotEqual(a, b2)
self.assertNotEqual(b1, b2)
self.assertNotEqual(a, b1)
self.assertNotEqual(b1, b3)
self.assertEqual(a, a1)
self.assertEqual(c1, c2)

self.assertEqual(hash(a), hash(b2))
self.assertEqual(hash(b1), hash(b2))
self.assertEqual(hash(a), hash(b1))
self.assertNotEqual(hash(a), hash(b2))
self.assertNotEqual(hash(b1), hash(b2))
self.assertNotEqual(hash(a), hash(b1))
self.assertNotEqual(hash(b1), hash(b3))
self.assertEqual(hash(a), hash(a1))
self.assertEqual(hash(c1), hash(c2))


def variabletest(varcls):
Expand Down
9 changes: 6 additions & 3 deletions Orange/data/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,15 @@ def __eq__(self, other):
var1 = self._get_identical_source(self)
var2 = self._get_identical_source(other)
# pylint: disable=protected-access
return var1.name == var2.name \
and var1._compute_value == var2._compute_value
return (
self.name == other.name
and var1.name == var2.name
and var1._compute_value == var2._compute_value
)

def __hash__(self):
var = self._get_identical_source(self)
return hash((var.name, type(self), var._compute_value))
return hash((self.name, var.name, type(self), var._compute_value))

@staticmethod
def _get_identical_source(var):
Expand Down

0 comments on commit 798316a

Please sign in to comment.