diff --git a/src/main/java/org/casbin/adapter/JDBCAdapter.java b/src/main/java/org/casbin/adapter/JDBCAdapter.java index c29aa1b..eac56e9 100644 --- a/src/main/java/org/casbin/adapter/JDBCAdapter.java +++ b/src/main/java/org/casbin/adapter/JDBCAdapter.java @@ -21,7 +21,6 @@ import org.casbin.jcasbin.persist.Helper; import javax.sql.DataSource; -import java.math.BigDecimal; import java.sql.*; import java.util.*; @@ -41,8 +40,9 @@ class CasbinRule { * It can load policy from JDBC supported database or save policy to it. */ public class JDBCAdapter implements Adapter { - private DataSource dataSource = null; + private DataSource dataSource; private final int batchSize = 1000; + private Connection conn; /** * JDBCAdapter is the constructor for JDBCAdapter. @@ -67,62 +67,62 @@ public JDBCAdapter(DataSource dataSource) throws Exception { } private void migrate() throws SQLException { - try (Connection conn = dataSource.getConnection(); Statement stmt = conn.createStatement()) { - String sql = "CREATE TABLE IF NOT EXISTS casbin_rule(id int NOT NULL PRIMARY KEY auto_increment, ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))"; - String productName = conn.getMetaData().getDatabaseProductName(); - - switch (productName) { - case "Oracle": - sql = "declare begin execute immediate 'CREATE TABLE CASBIN_RULE(id NUMBER(5, 0) not NULL primary key, ptype VARCHAR(100) not NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))'; " + - "exception when others then " + - "if SQLCODE = -955 then " + - "null; " + - "else raise; " + - "end if; " + - "end;"; - break; - case "Microsoft SQL Server": - sql = "IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='casbin_rule' and xtype='U') CREATE TABLE casbin_rule(id int NOT NULL primary key identity(1, 1), ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))"; - break; - case "PostgreSQL": - sql = "CREATE SEQUENCE IF NOT EXISTS CASBIN_SEQUENCE START 1;"; - break; - } + conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(); + String sql = "CREATE TABLE IF NOT EXISTS casbin_rule(id int NOT NULL PRIMARY KEY auto_increment, ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))"; + String productName = conn.getMetaData().getDatabaseProductName(); + + switch (productName) { + case "Oracle": + sql = "declare begin execute immediate 'CREATE TABLE CASBIN_RULE(id NUMBER(5, 0) not NULL primary key, ptype VARCHAR(100) not NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))'; " + + "exception when others then " + + "if SQLCODE = -955 then " + + "null; " + + "else raise; " + + "end if; " + + "end;"; + break; + case "Microsoft SQL Server": + sql = "IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='casbin_rule' and xtype='U') CREATE TABLE casbin_rule(id int NOT NULL primary key identity(1, 1), ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))"; + break; + case "PostgreSQL": + sql = "CREATE SEQUENCE IF NOT EXISTS CASBIN_SEQUENCE START 1;"; + break; + } + stmt.executeUpdate(sql); + if (productName.equals("Oracle")) { + sql = "declare " + + "V_NUM number;" + + "BEGIN " + + "V_NUM := 0; " + + "select count(0) into V_NUM from user_sequences where sequence_name = 'CASBIN_SEQUENCE';" + + "if V_NUM > 0 then " + + "null;" + + "else " + + "execute immediate 'CREATE SEQUENCE casbin_sequence increment by 1 start with 1 nomaxvalue nocycle nocache';" + + "end if;END;"; + stmt.executeUpdate(sql); + sql = "declare " + + "V_NUM number;" + + "BEGIN " + + "V_NUM := 0;" + + "select count(0) into V_NUM from user_triggers where trigger_name = 'CASBIN_ID_AUTOINCREMENT';" + + "if V_NUM > 0 then " + + "null;" + + "else " + + "execute immediate 'create trigger casbin_id_autoincrement before "+ + " insert on CASBIN_RULE for each row "+ + " when (new.id is null) "+ + " begin "+ + " select casbin_sequence.nextval into:new.id from dual;"+ + " end;';" + + "end if;" + + "END;"; + stmt.executeUpdate(sql); + } else if (productName.equals("PostgreSQL")) { + sql = "CREATE TABLE IF NOT EXISTS casbin_rule(id int NOT NULL PRIMARY KEY default nextval('CASBIN_SEQUENCE'::regclass), ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))"; stmt.executeUpdate(sql); - if (productName.equals("Oracle")) { - sql = "declare " + - "V_NUM number;" + - "BEGIN " + - "V_NUM := 0; " + - "select count(0) into V_NUM from user_sequences where sequence_name = 'CASBIN_SEQUENCE';" + - "if V_NUM > 0 then " + - "null;" + - "else " + - "execute immediate 'CREATE SEQUENCE casbin_sequence increment by 1 start with 1 nomaxvalue nocycle nocache';" + - "end if;END;"; - stmt.executeUpdate(sql); - sql = "declare " + - "V_NUM number;" + - "BEGIN " + - "V_NUM := 0;" + - "select count(0) into V_NUM from user_triggers where trigger_name = 'CASBIN_ID_AUTOINCREMENT';" + - "if V_NUM > 0 then " + - "null;" + - "else " + - "execute immediate 'create trigger casbin_id_autoincrement before "+ - " insert on CASBIN_RULE for each row "+ - " when (new.id is null) "+ - " begin "+ - " select casbin_sequence.nextval into:new.id from dual;"+ - " end;';" + - "end if;" + - "END;"; - stmt.executeUpdate(sql); - } else if (productName.equals("PostgreSQL")) { - sql = "CREATE TABLE IF NOT EXISTS casbin_rule(id int NOT NULL PRIMARY KEY default nextval('CASBIN_SEQUENCE'::regclass), ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))"; - stmt.executeUpdate(sql); - } } } @@ -155,8 +155,7 @@ private void loadPolicyLine(CasbinRule line, Model model) { */ @Override public void loadPolicy(Model model) { - try (Connection conn = dataSource.getConnection()) { - Statement stmt = conn.createStatement(); + try (Statement stmt = conn.createStatement()) { ResultSet rSet = stmt.executeQuery("SELECT * FROM casbin_rule"); ResultSetMetaData rData = rSet.getMetaData(); while (rSet.next()) { @@ -221,7 +220,7 @@ public void savePolicy(Model model) { String cleanSql = "delete from casbin_rule"; String addSql = "INSERT INTO casbin_rule (ptype,v0,v1,v2,v3,v4,v5) VALUES(?,?,?,?,?,?,?)"; - try (Connection conn = dataSource.getConnection()) { + try { conn.setAutoCommit(false); int count = 0; @@ -298,7 +297,7 @@ public void addPolicy(String sec, String ptype, List rule) { String sql = "INSERT INTO casbin_rule (ptype,v0,v1,v2,v3,v4,v5) VALUES(?,?,?,?,?,?,?)"; - try (Connection conn = dataSource.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) { + try(PreparedStatement ps = conn.prepareStatement(sql)) { CasbinRule line = savePolicyLine(ptype, rule); ps.setString(1, line.ptype); @@ -309,7 +308,6 @@ public void addPolicy(String sec, String ptype, List rule) { ps.setString(6, line.v4); ps.setString(7, line.v5); ps.addBatch(); - ps.executeBatch(); } catch (SQLException e) { e.printStackTrace(); @@ -340,7 +338,7 @@ public void removeFilteredPolicy(String sec, String ptype, int fieldIndex, Strin columnIndex++; } - try (Connection conn = dataSource.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) { + try (PreparedStatement ps = conn.prepareStatement(sql)) { ps.setString(1, ptype); for (int j = 0; j < values.size(); j++) { ps.setString(j + 2, values.get(j)); @@ -354,4 +352,11 @@ public void removeFilteredPolicy(String sec, String ptype, int fieldIndex, Strin throw new Error(e); } } + + /** + * Close the Connection. + */ + public void close() throws SQLException { + conn.close(); + } } diff --git a/src/main/java/org/casbin/adapter/JDBCDataSource.java b/src/main/java/org/casbin/adapter/JDBCDataSource.java index 7c1e674..69b8984 100644 --- a/src/main/java/org/casbin/adapter/JDBCDataSource.java +++ b/src/main/java/org/casbin/adapter/JDBCDataSource.java @@ -22,7 +22,6 @@ import java.sql.SQLFeatureNotSupportedException; import java.util.logging.Logger; - public class JDBCDataSource implements DataSource { private String driver; private String url; @@ -64,7 +63,7 @@ public PrintWriter getLogWriter() throws SQLException { } @Override - public void setLogWriter(PrintWriter out) throws SQLException { + public void setLogWriter(PrintWriter out) throws SQLException{ } diff --git a/src/test/java/org/casbin/adapter/JDBCAdapterTest.java b/src/test/java/org/casbin/adapter/JDBCAdapterTest.java index 2491ecb..031fcd2 100644 --- a/src/test/java/org/casbin/adapter/JDBCAdapterTest.java +++ b/src/test/java/org/casbin/adapter/JDBCAdapterTest.java @@ -16,6 +16,7 @@ import org.junit.Test; +import java.sql.SQLException; import java.util.ArrayList; import java.util.List; @@ -43,6 +44,14 @@ public void testMySQLAdapter() { } testAdapter(adapters); + + adapters.forEach(adapter -> { + try { + adapter.close(); + } catch (SQLException sqlException) { + sqlException.printStackTrace(); + } + }); } @Test @@ -59,6 +68,14 @@ public void testPgAdapter() { } testAdapter(adapters); + + adapters.forEach(adapter -> { + try { + adapter.close(); + } catch (SQLException sqlException) { + sqlException.printStackTrace(); + } + }); } @Test @@ -75,5 +92,13 @@ public void testSQLServerAdapter() { } testAdapter(adapters); + + adapters.forEach(adapter -> { + try { + adapter.close(); + } catch (SQLException sqlException) { + sqlException.printStackTrace(); + } + }); } }