Skip to content

Commit

Permalink
feat: LDAP 添加连接池支持 (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
RoninZc authored Jul 24, 2022
1 parent 019f7fd commit 396546d
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 21 deletions.
2 changes: 2 additions & 0 deletions config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ email:
ldap:
# ldap服务器地址
url: ldap://localhost:389
# ladp最大连接数设置
max-conn: 10
# ldap服务器基础DN
base-dn: "dc=eryajf,dc=net"
# ldap管理员DN
Expand Down
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ type RateLimitConfig struct {

type LdapConfig struct {
Url string `mapstructure:"url" json:"url"`
MaxConn int `mapstructure:"max-conn" json:"maxConn"`
BaseDN string `mapstructure:"base-dn" json:"baseDN"`
AdminDN string `mapstructure:"admin-dn" json:"adminDN"`
AdminPass string `mapstructure:"admin-pass" json:"adminPass"`
Expand Down
30 changes: 27 additions & 3 deletions public/client/openldap/openldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,16 @@ func GetAllDepts() (ret []*Dept, err error) {
[]string{}, // Here are the attributes returned by the query, provided as an array. If empty, all attributes are returned
nil,
)

// 获取 LDAP 连接
conn, err := common.GetLDAPConn()
defer common.PutLADPConn(conn)
if err != nil {
return nil, err
}

// Search through ldap built-in search
sr, err := common.LDAP.Search(searchRequest)
sr, err := conn.Search(searchRequest)
if err != nil {
return ret, err
}
Expand Down Expand Up @@ -81,8 +89,16 @@ func GetAllUsers() (ret []*User, err error) {
[]string{}, // Here are the attributes returned by the query, provided as an array. If empty, all attributes are returned
nil,
)

// 获取 LDAP 连接
conn, err := common.GetLDAPConn()
defer common.PutLADPConn(conn)
if err != nil {
return nil, err
}

// Search through ldap built-in search
sr, err := common.LDAP.Search(searchRequest)
sr, err := conn.Search(searchRequest)
if err != nil {
return ret, err
}
Expand Down Expand Up @@ -128,8 +144,16 @@ func GetUserDeptIds(udn string) (ret []string, err error) {
[]string{}, // Here are the attributes returned by the query, provided as an array. If empty, all attributes are returned
nil,
)

// 获取 LDAP 连接
conn, err := common.GetLDAPConn()
defer common.PutLADPConn(conn)
if err != nil {
return nil, err
}

// Search through ldap built-in search
sr, err := common.LDAP.Search(searchRequest)
sr, err := conn.Search(searchRequest)
if err != nil {
return ret, err
}
Expand Down
130 changes: 124 additions & 6 deletions public/common/ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,51 @@ package common

import (
"fmt"
"log"
"math/rand"
"net"
"sync"
"time"

"github.com/eryajf/go-ldap-admin/config"

ldap "github.com/go-ldap/ldap/v3"
)

// 全局ldap数据库变量
var LDAP *ldap.Conn
var ldapPool *LdapConnPool
var ldapInit = false
var ldapInitOne sync.Once

// Init 初始化连接
func InitLDAP() {
if ldapInit {
return
}

ldapInitOne.Do(func() {
ldapInit = true
})

// Dail有两个参数 network, address, 返回 (*Conn, error)
ldap, err := ldap.DialURL(config.Conf.Ldap.Url, ldap.DialWithDialer(&net.Dialer{Timeout: 5 * time.Second}))
ldapConn, err := ldap.DialURL(config.Conf.Ldap.Url, ldap.DialWithDialer(&net.Dialer{Timeout: 5 * time.Second}))
if err != nil {
Log.Panicf("初始化ldap连接异常: %v", err)
panic(fmt.Errorf("初始化ldap连接异常: %v", err))
}
err = ldap.Bind(config.Conf.Ldap.AdminDN, config.Conf.Ldap.AdminPass)
err = ldapConn.Bind(config.Conf.Ldap.AdminDN, config.Conf.Ldap.AdminPass)
if err != nil {
Log.Panicf("绑定admin账号异常: %v", err)
panic(fmt.Errorf("绑定admin账号异常: %v", err))
}

// 全局LDAP赋值
LDAP = ldap
// 全局变量赋值
ldapPool = &LdapConnPool{
conns: make([]*ldap.Conn, 0),
reqConns: make(map[uint64]chan *ldap.Conn),
openConn: 0,
maxOpen: config.Conf.Ldap.MaxConn,
}
PutLADPConn(ldapConn)

// 隐藏密码
showDsn := fmt.Sprintf(
Expand All @@ -39,3 +57,103 @@ func InitLDAP() {

Log.Info("初始化ldap完成! dsn: ", showDsn)
}

// GetLDAPConn 获取 LDAP 连接
func GetLDAPConn() (*ldap.Conn, error) {
return ldapPool.GetConnection()
}

// PutLDAPConn 放回 LDAP 连接
func PutLADPConn(conn *ldap.Conn) {
ldapPool.PutConnection(conn)
}

type LdapConnPool struct {
mu sync.Mutex
conns []*ldap.Conn
reqConns map[uint64]chan *ldap.Conn
openConn int
maxOpen int
}

// 获取一个 ladp Conn
func (lcp *LdapConnPool) GetConnection() (*ldap.Conn, error) {
lcp.mu.Lock()
// 判断当前连接池内是否存在连接
connNum := len(lcp.conns)
if connNum > 0 {
lcp.openConn++
conn := lcp.conns[0]
copy(lcp.conns, lcp.conns[1:])
lcp.conns = lcp.conns[:connNum-1]

lcp.mu.Unlock()
// 发现连接已经 close 重新获取连接
if conn.IsClosing() {
return initLDAPConn()
}
return conn, nil
}

// 当现有连接池为空时,并且当前超过最大连接限制
if lcp.maxOpen != 0 && lcp.openConn > lcp.maxOpen {
// 创建一个等待队列
req := make(chan *ldap.Conn, 1)
reqKey := lcp.nextRequestKeyLocked()
lcp.reqConns[reqKey] = req
lcp.mu.Unlock()

// 等待请求归还
return <-req, nil
} else {
lcp.openConn++
lcp.mu.Unlock()
return initLDAPConn()
}
}

func (lcp *LdapConnPool) PutConnection(conn *ldap.Conn) {
log.Println("放回了一个连接")
lcp.mu.Lock()
defer lcp.mu.Unlock()

// 先判断是否存在等待的队列
if num := len(lcp.reqConns); num > 0 {
var req chan *ldap.Conn
var reqKey uint64
for reqKey, req = range lcp.reqConns {
break
}
delete(lcp.reqConns, reqKey)
req <- conn
return
} else {
lcp.openConn--
if !conn.IsClosing() {
lcp.conns = append(lcp.conns, conn)
}
}
}

// 获取下一个请求令牌
func (lcp *LdapConnPool) nextRequestKeyLocked() uint64 {
for {
reqKey := rand.Uint64()
if _, ok := lcp.reqConns[reqKey]; !ok {
return reqKey
}
}
}

// 获取 ladp 连接
func initLDAPConn() (*ldap.Conn, error) {
ldap, err := ldap.DialURL(config.Conf.Ldap.Url, ldap.DialWithDialer(&net.Dialer{Timeout: 5 * time.Second}))
if err != nil {
return nil, err
}
err = ldap.Bind(config.Conf.Ldap.AdminDN, config.Conf.Ldap.AdminPass)
if err != nil {
return nil, err
}
return ldap, err
}
51 changes: 45 additions & 6 deletions service/ildap/group_ildap.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,36 @@ func (x GroupService) Add(g *model.Group) error { //organizationalUnit
add.Attribute(g.GroupType, []string{g.GroupName})
add.Attribute("description", []string{g.Remark})

return common.LDAP.Add(add)
// 获取 LDAP 连接
conn, err := common.GetLDAPConn()
defer common.PutLADPConn(conn)
if err != nil {
return err
}

return conn.Add(add)
}

// UpdateGroup 更新一个分组
func (x GroupService) Update(oldGroup, newGroup *model.Group) error {
modify := ldap.NewModifyRequest(oldGroup.GroupDN, nil)
modify.Replace("description", []string{newGroup.Remark})
err := common.LDAP.Modify(modify)

// 获取 LDAP 连接
conn, err := common.GetLDAPConn()
defer common.PutLADPConn(conn)
if err != nil {
return err
}

err = conn.Modify(modify)
if err != nil {
return err
}
// 如果配置文件允许修改分组名称,且分组名称发生了变化,那么执行修改分组名称
if config.Conf.Ldap.GroupNameModify && newGroup.GroupName != oldGroup.GroupName {
modify := ldap.NewModifyDNRequest(oldGroup.GroupDN, newGroup.GroupDN, true, "")
err := common.LDAP.ModifyDN(modify)
err := conn.ModifyDN(modify)
if err != nil {
return err
}
Expand All @@ -53,7 +68,15 @@ func (x GroupService) Update(oldGroup, newGroup *model.Group) error {
// Delete 删除资源
func (x GroupService) Delete(gdn string) error {
del := ldap.NewDelRequest(gdn, nil)
return common.LDAP.Del(del)

// 获取 LDAP 连接
conn, err := common.GetLDAPConn()
defer common.PutLADPConn(conn)
if err != nil {
return err
}

return conn.Del(del)
}

// AddUserToGroup 添加用户到分组
Expand All @@ -64,12 +87,28 @@ func (x GroupService) AddUserToGroup(dn, udn string) error {
}
newmr := ldap.NewModifyRequest(dn, nil)
newmr.Add("uniqueMember", []string{udn})
return common.LDAP.Modify(newmr)

// 获取 LDAP 连接
conn, err := common.GetLDAPConn()
defer common.PutLADPConn(conn)
if err != nil {
return err
}

return conn.Modify(newmr)
}

// DelUserFromGroup 将用户从分组删除
func (x GroupService) RemoveUserFromGroup(gdn, udn string) error {
newmr := ldap.NewModifyRequest(gdn, nil)
newmr.Delete("uniqueMember", []string{udn})
return common.LDAP.Modify(newmr)

// 获取 LDAP 连接
conn, err := common.GetLDAPConn()
defer common.PutLADPConn(conn)
if err != nil {
return err
}

return conn.Modify(newmr)
}
Loading

0 comments on commit 396546d

Please sign in to comment.