Skip to content

Commit

Permalink
Merge pull request #23 from shy1st/master
Browse files Browse the repository at this point in the history
Optimize to load policy by reducing the number of database connections.
  • Loading branch information
hsluoyz authored Jan 13, 2021
2 parents 2a21149 + 1e9f3a4 commit 1a5eb49
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 64 deletions.
129 changes: 67 additions & 62 deletions src/main/java/org/casbin/adapter/JDBCAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.casbin.jcasbin.persist.Helper;

import javax.sql.DataSource;
import java.math.BigDecimal;
import java.sql.*;
import java.util.*;

Expand All @@ -41,8 +40,9 @@ class CasbinRule {
* It can load policy from JDBC supported database or save policy to it.
*/
public class JDBCAdapter implements Adapter {
private DataSource dataSource = null;
private DataSource dataSource;
private final int batchSize = 1000;
private Connection conn;

/**
* JDBCAdapter is the constructor for JDBCAdapter.
Expand All @@ -67,62 +67,62 @@ public JDBCAdapter(DataSource dataSource) throws Exception {
}

private void migrate() throws SQLException {
try (Connection conn = dataSource.getConnection(); Statement stmt = conn.createStatement()) {
String sql = "CREATE TABLE IF NOT EXISTS casbin_rule(id int NOT NULL PRIMARY KEY auto_increment, ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))";
String productName = conn.getMetaData().getDatabaseProductName();

switch (productName) {
case "Oracle":
sql = "declare begin execute immediate 'CREATE TABLE CASBIN_RULE(id NUMBER(5, 0) not NULL primary key, ptype VARCHAR(100) not NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))'; " +
"exception when others then " +
"if SQLCODE = -955 then " +
"null; " +
"else raise; " +
"end if; " +
"end;";
break;
case "Microsoft SQL Server":
sql = "IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='casbin_rule' and xtype='U') CREATE TABLE casbin_rule(id int NOT NULL primary key identity(1, 1), ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))";
break;
case "PostgreSQL":
sql = "CREATE SEQUENCE IF NOT EXISTS CASBIN_SEQUENCE START 1;";
break;
}
conn = dataSource.getConnection();
Statement stmt = conn.createStatement();
String sql = "CREATE TABLE IF NOT EXISTS casbin_rule(id int NOT NULL PRIMARY KEY auto_increment, ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))";
String productName = conn.getMetaData().getDatabaseProductName();

switch (productName) {
case "Oracle":
sql = "declare begin execute immediate 'CREATE TABLE CASBIN_RULE(id NUMBER(5, 0) not NULL primary key, ptype VARCHAR(100) not NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))'; " +
"exception when others then " +
"if SQLCODE = -955 then " +
"null; " +
"else raise; " +
"end if; " +
"end;";
break;
case "Microsoft SQL Server":
sql = "IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='casbin_rule' and xtype='U') CREATE TABLE casbin_rule(id int NOT NULL primary key identity(1, 1), ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))";
break;
case "PostgreSQL":
sql = "CREATE SEQUENCE IF NOT EXISTS CASBIN_SEQUENCE START 1;";
break;
}

stmt.executeUpdate(sql);
if (productName.equals("Oracle")) {
sql = "declare " +
"V_NUM number;" +
"BEGIN " +
"V_NUM := 0; " +
"select count(0) into V_NUM from user_sequences where sequence_name = 'CASBIN_SEQUENCE';" +
"if V_NUM > 0 then " +
"null;" +
"else " +
"execute immediate 'CREATE SEQUENCE casbin_sequence increment by 1 start with 1 nomaxvalue nocycle nocache';" +
"end if;END;";
stmt.executeUpdate(sql);
sql = "declare " +
"V_NUM number;" +
"BEGIN " +
"V_NUM := 0;" +
"select count(0) into V_NUM from user_triggers where trigger_name = 'CASBIN_ID_AUTOINCREMENT';" +
"if V_NUM > 0 then " +
"null;" +
"else " +
"execute immediate 'create trigger casbin_id_autoincrement before "+
" insert on CASBIN_RULE for each row "+
" when (new.id is null) "+
" begin "+
" select casbin_sequence.nextval into:new.id from dual;"+
" end;';" +
"end if;" +
"END;";
stmt.executeUpdate(sql);
} else if (productName.equals("PostgreSQL")) {
sql = "CREATE TABLE IF NOT EXISTS casbin_rule(id int NOT NULL PRIMARY KEY default nextval('CASBIN_SEQUENCE'::regclass), ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))";
stmt.executeUpdate(sql);
if (productName.equals("Oracle")) {
sql = "declare " +
"V_NUM number;" +
"BEGIN " +
"V_NUM := 0; " +
"select count(0) into V_NUM from user_sequences where sequence_name = 'CASBIN_SEQUENCE';" +
"if V_NUM > 0 then " +
"null;" +
"else " +
"execute immediate 'CREATE SEQUENCE casbin_sequence increment by 1 start with 1 nomaxvalue nocycle nocache';" +
"end if;END;";
stmt.executeUpdate(sql);
sql = "declare " +
"V_NUM number;" +
"BEGIN " +
"V_NUM := 0;" +
"select count(0) into V_NUM from user_triggers where trigger_name = 'CASBIN_ID_AUTOINCREMENT';" +
"if V_NUM > 0 then " +
"null;" +
"else " +
"execute immediate 'create trigger casbin_id_autoincrement before "+
" insert on CASBIN_RULE for each row "+
" when (new.id is null) "+
" begin "+
" select casbin_sequence.nextval into:new.id from dual;"+
" end;';" +
"end if;" +
"END;";
stmt.executeUpdate(sql);
} else if (productName.equals("PostgreSQL")) {
sql = "CREATE TABLE IF NOT EXISTS casbin_rule(id int NOT NULL PRIMARY KEY default nextval('CASBIN_SEQUENCE'::regclass), ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))";
stmt.executeUpdate(sql);
}
}
}

Expand Down Expand Up @@ -155,8 +155,7 @@ private void loadPolicyLine(CasbinRule line, Model model) {
*/
@Override
public void loadPolicy(Model model) {
try (Connection conn = dataSource.getConnection()) {
Statement stmt = conn.createStatement();
try (Statement stmt = conn.createStatement()) {
ResultSet rSet = stmt.executeQuery("SELECT * FROM casbin_rule");
ResultSetMetaData rData = rSet.getMetaData();
while (rSet.next()) {
Expand Down Expand Up @@ -221,7 +220,7 @@ public void savePolicy(Model model) {
String cleanSql = "delete from casbin_rule";
String addSql = "INSERT INTO casbin_rule (ptype,v0,v1,v2,v3,v4,v5) VALUES(?,?,?,?,?,?,?)";

try (Connection conn = dataSource.getConnection()) {
try {
conn.setAutoCommit(false);

int count = 0;
Expand Down Expand Up @@ -298,7 +297,7 @@ public void addPolicy(String sec, String ptype, List<String> rule) {

String sql = "INSERT INTO casbin_rule (ptype,v0,v1,v2,v3,v4,v5) VALUES(?,?,?,?,?,?,?)";

try (Connection conn = dataSource.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) {
try(PreparedStatement ps = conn.prepareStatement(sql)) {
CasbinRule line = savePolicyLine(ptype, rule);

ps.setString(1, line.ptype);
Expand All @@ -309,7 +308,6 @@ public void addPolicy(String sec, String ptype, List<String> rule) {
ps.setString(6, line.v4);
ps.setString(7, line.v5);
ps.addBatch();

ps.executeBatch();
} catch (SQLException e) {
e.printStackTrace();
Expand Down Expand Up @@ -340,7 +338,7 @@ public void removeFilteredPolicy(String sec, String ptype, int fieldIndex, Strin
columnIndex++;
}

try (Connection conn = dataSource.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) {
try (PreparedStatement ps = conn.prepareStatement(sql)) {
ps.setString(1, ptype);
for (int j = 0; j < values.size(); j++) {
ps.setString(j + 2, values.get(j));
Expand All @@ -354,4 +352,11 @@ public void removeFilteredPolicy(String sec, String ptype, int fieldIndex, Strin
throw new Error(e);
}
}

/**
* Close the Connection.
*/
public void close() throws SQLException {
conn.close();
}
}
3 changes: 1 addition & 2 deletions src/main/java/org/casbin/adapter/JDBCDataSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.sql.SQLFeatureNotSupportedException;
import java.util.logging.Logger;


public class JDBCDataSource implements DataSource {
private String driver;
private String url;
Expand Down Expand Up @@ -64,7 +63,7 @@ public PrintWriter getLogWriter() throws SQLException {
}

@Override
public void setLogWriter(PrintWriter out) throws SQLException {
public void setLogWriter(PrintWriter out) throws SQLException{

}

Expand Down
25 changes: 25 additions & 0 deletions src/test/java/org/casbin/adapter/JDBCAdapterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import org.junit.Test;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

Expand Down Expand Up @@ -43,6 +44,14 @@ public void testMySQLAdapter() {
}

testAdapter(adapters);

adapters.forEach(adapter -> {
try {
adapter.close();
} catch (SQLException sqlException) {
sqlException.printStackTrace();
}
});
}

@Test
Expand All @@ -59,6 +68,14 @@ public void testPgAdapter() {
}

testAdapter(adapters);

adapters.forEach(adapter -> {
try {
adapter.close();
} catch (SQLException sqlException) {
sqlException.printStackTrace();
}
});
}

@Test
Expand All @@ -75,5 +92,13 @@ public void testSQLServerAdapter() {
}

testAdapter(adapters);

adapters.forEach(adapter -> {
try {
adapter.close();
} catch (SQLException sqlException) {
sqlException.printStackTrace();
}
});
}
}

0 comments on commit 1a5eb49

Please sign in to comment.