Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the possibility to connect to a specific database #61

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/connection-int.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void connection_handle_error(ConnectionObject *conn) {
}

int connection_run_without_results(ConnectionObject *conn, const char *query) {
int status = mg_session_run(conn->session, query, NULL, NULL, NULL, NULL);
int status = mg_session_run(conn->session, query, NULL, conn->extras, NULL, NULL);
if (status != 0) {
connection_handle_error(conn);
return -1;
Expand Down Expand Up @@ -87,7 +87,7 @@ int connection_run(ConnectionObject *conn, const char *query, PyObject *params,

const mg_list *mg_columns;
int status =
mg_session_run(conn->session, query, mg_params, NULL, &mg_columns, NULL);
mg_session_run(conn->session, query, mg_params, conn->extras, &mg_columns, NULL);
mg_map_destroy(mg_params);

if (status != 0) {
Expand Down
42 changes: 38 additions & 4 deletions src/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

static void connection_dealloc(ConnectionObject *conn) {
mg_session_destroy(conn->session);
mg_map_destroy(conn->extras);
Py_TYPE(conn)->tp_free(conn);
}

Expand All @@ -37,11 +38,36 @@ static int execute_trust_callback(const char *hostname, const char *ip_address,
return !status;
}

static mg_map *database_to_extras(const char *database) {
assert(databases);

mg_map *map = NULL;

map = mg_map_make_empty(1U);
if (!map) {
PyErr_SetString(PyExc_RuntimeError, "failed to create a mg_map");
goto cleanup;
}

mg_string* key = mg_string_make("db");
mg_value* value = mg_value_make_string(database);

if (mg_map_insert_unsafe2(map, key, value) != 0) {
mg_string_destroy(key);
abort();
}
return map;

cleanup:
mg_map_destroy(map);
return NULL;
}

static int connection_init(ConnectionObject *conn, PyObject *args,
PyObject *kwargs) {
static char *kwlist[] = {"host", "address", "port", "username",
"password", "client_name", "sslmode", "sslcert",
"sslkey", "trust_callback", "lazy", NULL};
"sslkey", "trust_callback", "lazy", "database", NULL};

const char *host = NULL;
const char *address = NULL;
Expand All @@ -54,11 +80,12 @@ static int connection_init(ConnectionObject *conn, PyObject *args,
const char *sslkey = NULL;
PyObject *trust_callback = NULL;
int lazy = 0;

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|$ssisssissOp", kwlist, &host,
const char *database = NULL;

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|$ssisssissOps", kwlist, &host,
&address, &port, &username, &password,
&client_name, &sslmode_int, &sslcert,
&sslkey, &trust_callback, &lazy)) {
&sslkey, &trust_callback, &lazy, &database)) {
return -1;
}

Expand Down Expand Up @@ -124,12 +151,17 @@ static int connection_init(ConnectionObject *conn, PyObject *args,
conn->status = CONN_STATUS_READY;
conn->lazy = 0;
conn->autocommit = 0;
conn->extras = NULL;

if (lazy) {
conn->lazy = 1;
conn->autocommit = 1;
}

if (database) {
conn->extras = database_to_extras(database);
}

return 0;
}

Expand Down Expand Up @@ -180,6 +212,8 @@ static PyObject *connection_close(ConnectionObject *conn, PyObject *args) {
// rollback any open transactions.
mg_session_destroy(conn->session);
conn->session = NULL;
mg_map_destroy(conn->extras);
conn->extras = NULL;
conn->status = CONN_STATUS_CLOSED;

Py_RETURN_NONE;
Expand Down
1 change: 1 addition & 0 deletions src/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ typedef struct ConnectionObject {
int status;
int autocommit;
int lazy;
mg_map *extras;
} ConnectionObject;
// clang-format on

Expand Down
8 changes: 6 additions & 2 deletions src/mgclientmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ static PyObject *mgclient_connect(PyObject *self, PyObject *args,
PyDoc_STRVAR(mgclient_connect_doc,
"connect(host=None, address=None, port=None, username=None, password=None,\n\
client_name=None, sslmode=mgclient.MG_SSLMODE_DISABLE,\n\
sslcert=None, sslkey=None, trust_callback=None, lazy=False)\n\
sslcert=None, sslkey=None, trust_callback=None, lazy=False, database=None)\n\
--\n\
\n\
Makes a new connection to the database server and returns a\n\
Expand Down Expand Up @@ -271,7 +271,11 @@ Currently recognized parameters are:\n\
\n\
* :obj:`lazy`\n\
\n\
If this is set to ``True``, a lazy connection is made. Default is ``False``.");
If this is set to ``True``, a lazy connection is made. Default is ``False``.\n\
\n\
* :obj:`database`\n\
\n\
If set, all queries executed will target the defined database. Default is ``None``.");
// clang-format on

static PyMethodDef mgclient_methods[] = {
Expand Down
148 changes: 148 additions & 0 deletions test/test_multi_tenancy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright (c) 2016-2020 Memgraph Ltd. [https://memgraph.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mgclient
import pytest
import tempfile

from common import start_memgraph, MEMGRAPH_PORT


def assert_db(cursor, db_name):
cursor.execute("SHOW DATABASE")
assert cursor.fetchall() == [(db_name, )]

def assert_data(cursor, db):
cursor.execute('MATCH (n:Node) RETURN n.db')
cursor.fetchall() == [(db,)]


@pytest.fixture(scope="function")
def memgraph_server():
# memgraph = start_memgraph()
yield "127.0.0.1", MEMGRAPH_PORT

# memgraph.kill()
# memgraph.wait()

# def test_connect_database_fail(memgraph_server):
# host, port = memgraph_server
# # Connected to a non existent database
# conn = mgclient.connect(
# host=host,
# port=port,
# lazy=True,
# database="does not exist")
# cursor = conn.cursor()
# with pytest.raises(mgclient.DatabaseError):
# cursor.execute("MATCH(n) RETURN n;")

# def test_connect_database(memgraph_server):
# host, port = memgraph_server
# conn = mgclient.connect(host=host, port=port, lazy=True)
# cursor = conn.cursor()

# #setup
# assert_db(cursor, "memgraph")
# cursor.execute('CREATE (:Node{db:"memgraph"})')
# cursor.fetchall()

# cursor.execute("CREATE DATABASE db1")
# cursor.fetchall()
# cursor.execute("USE DATABASE db1")
# cursor.fetchall()
# assert_db(cursor, "db1")
# cursor.execute('CREATE (:Node{db:"db1"})')
# cursor.fetchall()

# cursor.execute("CREATE DATABASE db2")
# cursor.fetchall()
# cursor.execute("USE DATABASE db2")
# cursor.fetchall()
# assert_db(cursor, "db2")
# cursor.execute('CREATE (:Node{db:"db2"})')
# cursor.fetchall()

# #connection tests
# #default
# conn = mgclient.connect(host=host, port=port, lazy=True)
# cursor = conn.cursor()
# assert_db(cursor, "memgraph")
# assert_data(cursor, "memgraph")

# #memgraph
# conn = mgclient.connect(host=host, port=port, lazy=True, database="memgraph")
# cursor = conn.cursor()
# assert_db(cursor, "memgraph")
# assert_data(cursor, "memgraph")

# #db1
# conn = mgclient.connect(host=host, port=port, lazy=True, database="db1")
# cursor = conn.cursor()
# assert_db(cursor, "db1")
# assert_data(cursor, "db1")

# #db2
# conn = mgclient.connect(host=host, port=port, lazy=True, database="db2")
# cursor = conn.cursor()
# assert_db(cursor, "db2")
# assert_data(cursor, "db2")

def test_connect_database_and_block(memgraph_server):
host, port = memgraph_server
conn = mgclient.connect(host=host, port=port, lazy=True)
cursor = conn.cursor()

#setup
assert_db(cursor, "memgraph")
cursor.execute("CREATE DATABASE db1")
cursor.fetchall()
cursor.execute("CREATE DATABASE db2")
cursor.fetchall()

#connection tests
#default <- should allow db switching
conn = mgclient.connect(host=host, port=port, lazy=True)
cursor = conn.cursor()
assert_db(cursor, "memgraph")
cursor.execute("USE DATABASE db1;")
cursor.fetchall()
assert_db(cursor, "db1")
cursor.execute("USE DATABASE db2;")
cursor.fetchall()
assert_db(cursor, "db2")

#memgraph
conn = mgclient.connect(host=host, port=port, lazy=True, database="memgraph")
cursor = conn.cursor()
assert_db(cursor, "memgraph")
with pytest.raises(mgclient.DatabaseError):
cursor.execute("USE DATABASE db2;")
print(cursor.fetchall())

#db1
conn = mgclient.connect(host=host, port=port, lazy=True, database="db1")
cursor = conn.cursor()
assert_db(cursor, "db1")
with pytest.raises(mgclient.DatabaseError):
cursor.execute("USE DATABASE db2;")
cursor.fetchall()

#db2
conn = mgclient.connect(host=host, port=port, lazy=True, database="db2")
cursor = conn.cursor()
assert_db(cursor, "db2")
with pytest.raises(mgclient.DatabaseError):
cursor.execute("USE DATABASE memgraph;")
cursor.fetchall()