diff --git a/CHANGELOG b/CHANGELOG index b21e252e..f6a5d972 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -5,3 +5,4 @@ [das-query-engine#214] Add retrieve_all_atoms method [#124] Changed count_atoms() to return more accurate numbers [das-query-engine#197] Changed get_all_links() to return a tuple +[#142] Changed add_link() and add_node() to work with get_atom returns diff --git a/hyperon_das_atomdb/database.py b/hyperon_das_atomdb/database.py index dfaca8f5..4be36953 100644 --- a/hyperon_das_atomdb/database.py +++ b/hyperon_das_atomdb/database.py @@ -51,6 +51,16 @@ def _transform_to_target_format( if kwargs.get('targets_document', False): targets_document = [self.get_atom(target) for target in answer['targets']] return answer, targets_document + elif kwargs.get('deep_representation', False): + + def _recursive_targets(targets, **kwargs): + return [self.get_atom(target, **kwargs) for target in targets] + + if 'targets' in answer: + deep_targets = _recursive_targets(answer['targets'], **kwargs) + answer['targets'] = deep_targets + + return answer return answer @@ -66,82 +76,83 @@ def _recursive_link_split(self, params: Dict[str, Any]) -> (str, Any): return (self.link_handle(atom_type, targets), composite_type) def _add_node(self, node_params: Dict[str, Any]) -> Dict[str, Any]: - reserved_parameters = ['_id', 'composite_type_hash', 'named_type'] + reserved_parameters = ['handle', '_id', 'composite_type_hash', 'named_type'] - if any(item in reserved_parameters for item in node_params.keys()): - raise AddNodeException( - message="This is a reserved field name in nodes", - details=str(reserved_parameters), - ) + valid_params = { + key: value for key, value in node_params.items() if key not in reserved_parameters + } + + node_type = valid_params.get('type') + node_name = valid_params.get('name') - node_type = node_params.get('type') - node_name = node_params.get('name') if node_type is None or node_name is None: raise AddNodeException( message='The "name" and "type" fields must be sent', - details=node_params, + details=valid_params, ) handle = self.node_handle(node_type, node_name) + node = { '_id': handle, 'composite_type_hash': ExpressionHasher.named_type_hash(node_type), 'name': node_name, 'named_type': node_type, } - node.update(node_params) + + node.update(valid_params) node.pop('type') + return (handle, node) def _add_link(self, link_params: Dict[str, Any], toplevel: bool = True) -> Dict[str, Any]: reserved_parameters = [ + 'handle', '_id', 'composite_type_hash', - 'is_toplevel', 'composite_type', + 'is_toplevel', 'named_type', 'named_type_hash', 'key_n', ] - if any( - item in reserved_parameters or re.search(AtomDB.key_pattern, item) - for item in link_params.keys() - ): - raise AddLinkException( - message="This is a reserved field name in links", - details=str(reserved_parameters), - ) + valid_params = { + key: value + for key, value in link_params.items() + if key not in reserved_parameters and not re.search(AtomDB.key_pattern, key) + } + + link_type = valid_params.get('type') + targets = valid_params.get('targets') - link_type = link_params.get('type') - targets = link_params.get('targets') if link_type is None or targets is None: raise AddLinkException( message='The "type" and "targets" fields must be sent', - details=link_params, + details=valid_params, ) link_type_hash = ExpressionHasher.named_type_hash(link_type) - targets_hash = [] composite_type = [link_type_hash] composite_type_hash = [link_type_hash] for target in targets: - if 'targets' not in target.keys(): + if not isinstance(target, dict): + raise ValueError('The target must be a dictionary') + if 'targets' not in target: atom = self.add_node(target) - atom_hash = ExpressionHasher.named_type_hash(atom['named_type']) + atom_hash = atom['composite_type_hash'] composite_type.append(atom_hash) else: atom = self.add_link(target, toplevel=False) - composite_type.append(atom['composite_type']) atom_hash = atom['composite_type_hash'] + composite_type.append(atom['composite_type']) composite_type_hash.append(atom_hash) targets_hash.append(atom['_id']) handle = ExpressionHasher.expression_hash(link_type_hash, targets_hash) - arity = len(targets) link = { '_id': handle, 'composite_type_hash': ExpressionHasher.composite_hash(composite_type_hash), @@ -150,10 +161,11 @@ def _add_link(self, link_params: Dict[str, Any], toplevel: bool = True) -> Dict[ 'named_type': link_type, 'named_type_hash': link_type_hash, } - for item in range(arity): + + for item in range(len(targets)): link[f'key_{item}'] = targets_hash[item] - link.update(link_params) + link.update(valid_params) link.pop('type') link.pop('targets') diff --git a/tests/integration/adapters/test_redis_mongo.py b/tests/integration/adapters/test_redis_mongo.py index 6dcd2348..8b5ba1b4 100644 --- a/tests/integration/adapters/test_redis_mongo.py +++ b/tests/integration/adapters/test_redis_mongo.py @@ -957,3 +957,40 @@ def test_retrieve_all_atoms(self, _cleanup): nodes = db.get_all_nodes('Concept') assert len(response) == len(links) + len(nodes) _db_down() + + def test_add_fields_to_atoms(self, _cleanup): + _db_up(Database.REDIS, Database.MONGO) + db = self._connect_db() + self._add_atoms(db) + db.commit() + human = db.node_handle('Concept', 'human') + monkey = db.node_handle('Concept', 'monkey') + link_handle = db.link_handle('Similarity', [human, monkey]) + + node_human = db.get_atom(human) + + assert node_human['handle'] == human + assert node_human['name'] == 'human' + assert node_human['named_type'] == 'Concept' + + node_human['score'] = 0.5 + + db.add_node(node_human) + db.commit() + + assert db.get_atom(human)['score'] == 0.5 + + link_similarity = db.get_atom(link_handle, deep_representation=True) + + assert link_similarity['handle'] == link_handle + assert link_similarity['type'] == 'Similarity' + assert link_similarity['targets'] == [db.get_atom(human), db.get_atom(monkey)] + + link_similarity['score'] = 0.5 + + db.add_link(link_similarity) + db.commit() + + assert db.get_atom(link_handle)['score'] == 0.5 + + _db_down() diff --git a/tests/unit/adapters/test_redis_mongo_db.py b/tests/unit/adapters/test_redis_mongo_db.py index 1733e515..9f96f45b 100644 --- a/tests/unit/adapters/test_redis_mongo_db.py +++ b/tests/unit/adapters/test_redis_mongo_db.py @@ -1826,28 +1826,6 @@ def test_add_link(self, database): added_atoms.clear() - def test_add_node_and_link_with_reserved_parameters(self, database): - with pytest.raises(AddNodeException) as exc: - database.add_node({'type': 'Concept', 'name': 'lion', 'named_type': 'Concept-type'}) - assert exc.value.message == 'This is a reserved field name in nodes' - assert exc.value.details == "['_id', 'composite_type_hash', 'named_type']" - with pytest.raises(AddLinkException) as exc: - database.add_link( - { - 'type': 'Concept', - 'targets': [ - {'type': 'Concept', 'name': 'test-1'}, - {'type': 'Concept', 'name': 'test-2'}, - ], - 'key_1': 'custom_key', - } - ) - assert exc.value.message == 'This is a reserved field name in links' - assert ( - exc.value.details - == "['_id', 'composite_type_hash', 'is_toplevel', 'composite_type', 'named_type', 'named_type_hash', 'key_n']" - ) - def test_get_incoming_links(self, database): h = database.get_node_handle('Concept', 'human') m = database.get_node_handle('Concept', 'monkey')