Files
cloudpods/pkg/util/ldaputils/ldaputils.go
2025-08-26 13:28:32 +08:00

257 lines
5.9 KiB
Go

// Copyright 2019 Yunion
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ldaputils
import (
"crypto/tls"
"encoding/hex"
"fmt"
"strings"
"github.com/go-ldap/ldap/v3"
"yunion.io/x/log"
"yunion.io/x/pkg/errors"
)
var (
// ErrUserNotFound is returned by Authenticate when the user search
// matches no entry.
ErrUserNotFound = errors.Error("not found")
// ErrUserDuplicate is returned by Authenticate when the user search
// matches more than one entry.
ErrUserDuplicate = errors.Error("user id duplicate")
// ErrUserBadCredential is returned by Authenticate when binding with the
// matched entry's DN and the supplied password fails.
ErrUserBadCredential = errors.Error("bad credential")
// binaryAttributes lists the attribute names that isBinaryAttr reports as
// binary (compared case-insensitively). Values of these attributes are
// hex-encoded by GetAttributeValue/GetAttributeValues and rewritten into
// backslash-escaped form when used as a Search condition.
binaryAttributes = []string{
"objectGUID",
"objectSid",
}
)
// SLDAPClient wraps a single go-ldap connection together with the settings
// needed to open it and bind to it.
type SLDAPClient struct {
// url is the server URL handed to ldap.DialURL by Connect.
url string
// account and password are the credentials used by bind; when account is
// empty no bind is performed.
account string
password string
// baseDN is the search base Search falls back to when called with an
// empty base.
baseDN string
// isDebug is stored by NewLDAPClient; it is not read anywhere in this file.
isDebug bool
// conn is the live connection; set by Connect and reset to nil by Close.
conn *ldap.Conn
}
// NewLDAPClient returns a client holding the given connection settings.
// No network activity happens here; call Connect to dial and bind.
func NewLDAPClient(url, account, password string, baseDN string, isDebug bool) *SLDAPClient {
	cli := new(SLDAPClient)
	cli.url = url
	cli.account = account
	cli.password = password
	cli.baseDN = baseDN
	cli.isDebug = isDebug
	return cli
}
// Connect dials the server at cli.url and then binds with the configured
// account (bind is a no-op when no account is set).
//
// Any connection left over from a previous Connect is closed first, and if
// the bind fails the freshly dialed connection is closed as well, so a
// failed or repeated Connect never leaks a socket. After an error cli.conn
// is nil.
//
// NOTE(security): TLS certificate verification is disabled
// (InsecureSkipVerify), so the server's identity is not checked on ldaps://
// URLs. This is kept for backward compatibility with existing deployments.
func (cli *SLDAPClient) Connect() error {
	// Drop a stale connection instead of overwriting (and leaking) it.
	cli.Close()

	conn, err := ldap.DialURL(cli.url, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
	if err != nil {
		return errors.Wrap(err, "DialURL")
	}
	cli.conn = conn

	if err := cli.bind(); err != nil {
		// The caller only gets an error back, so release the socket here.
		cli.Close()
		return err
	}
	return nil
}
// bind authenticates the current connection with the configured account and
// password. When no account is configured it does nothing and returns nil.
func (cli *SLDAPClient) bind() error {
	if len(cli.account) == 0 {
		return nil
	}
	if bindErr := cli.conn.Bind(cli.account, cli.password); bindErr != nil {
		return errors.Wrap(bindErr, "Bind")
	}
	return nil
}
// Close shuts down the underlying connection, if any, and clears it so that
// calling Close again is a no-op.
func (cli *SLDAPClient) Close() {
	if cli.conn == nil {
		return
	}
	cli.conn.Close()
	cli.conn = nil
}
// Authenticate looks up the single entry whose uidAttr equals uname and
// verifies passwd by binding as that entry's DN.
//
// baseDN, objClass, filter, fields and queryScope are passed through to
// Search. At most two entries are fetched, which is enough to tell "exactly
// one" from "duplicate".
//
// Returns:
//   - the matched entry on success;
//   - ErrUserNotFound when nothing matches;
//   - ErrUserDuplicate when more than one entry matches;
//   - ErrUserBadCredential when the bind with the user's DN fails;
//   - a wrapped error when the search itself fails.
//
// Whatever the outcome of the user bind, the connection is re-bound with the
// client's own account afterwards.
func (cli *SLDAPClient) Authenticate(baseDN string, objClass string, uidAttr string, uname string, passwd string, filter string, fields []string, queryScope int) (*ldap.Entry, error) {
	// uname comes from the login request and is spliced into the search
	// filter, so escape filter metacharacters ("*", "(", ")", "\", NUL) to
	// prevent filter injection and wildcard logins. Binary attributes are
	// left alone because Search re-encodes their value itself.
	uidValue := uname
	if !isBinaryAttr(uidAttr) {
		uidValue = ldap.EscapeFilter(uname)
	}
	attrMap := map[string]string{uidAttr: uidValue}

	var retEntry *ldap.Entry
	total, err := cli.Search(
		baseDN, objClass, attrMap, filter, fields, queryScope,
		2, 2,
		func(offset uint32, entry *ldap.Entry) error {
			retEntry = entry
			return nil
		},
	)
	if err != nil {
		return nil, errors.Wrap(err, "Search")
	}
	if total == 0 {
		return nil, ErrUserNotFound
	}
	if total > 1 {
		return nil, ErrUserDuplicate
	}

	// Restore the service-account identity once the user bind has been
	// attempted. A failure here leaves the connection with the wrong
	// identity for later searches, so make it visible in the log.
	defer func() {
		if rebindErr := cli.bind(); rebindErr != nil {
			log.Errorf("ldap: re-bind after user authentication failed: %v", rebindErr)
		}
	}()

	if err := cli.conn.Bind(retEntry.DN, passwd); err != nil {
		return nil, ErrUserBadCredential
	}
	return retEntry, nil
}
const (
// StopSearch is a sentinel an entry callback can return to make Search
// stop early; Search then returns the number of entries processed so far
// with a nil error.
StopSearch = errors.Error("stop dap search")
)
// Search runs a paged LDAP search and calls entryFunc for every entry found.
//
// The filter is the AND of:
//   - (objectClass=<objClass>) when objClass is set, or (objectClass=*) when
//     neither objClass nor condition is given;
//   - one (key=value) term per condition entry; values of binary attributes
//     are converted from hex to backslash-escaped bytes, all other values
//     are written verbatim (callers must escape untrusted input);
//   - filter, only when it is enclosed in parentheses; otherwise it is
//     ignored and a warning is logged.
//
// base defaults to the client's baseDN when empty, and an unknown queryScope
// falls back to ldap.ScopeWholeSubtree. pageSize is the page size requested
// from the server; limit, when > 0, caps the number of entries processed.
//
// entryFunc receives the zero-based offset of the entry. Returning
// StopSearch ends the search without error; any other error aborts it.
//
// The returned count is the number of entries processed. On a callback
// error the entry that failed is not counted.
//
// reference: https://zerokspot.com/weblog/2018/03/07/paging-in-gopkg-ldap/
func (cli *SLDAPClient) Search(
	base string, objClass string,
	condition map[string]string,
	filter string,
	fields []string,
	queryScope int,
	pageSize uint32, limit uint32,
	entryFunc func(offset uint32, entry *ldap.Entry) error,
) (uint32, error) {
	searches := strings.Builder{}
	if len(condition) == 0 && len(objClass) == 0 {
		searches.WriteString("(objectClass=*)")
	}
	if len(objClass) > 0 {
		searches.WriteString("(objectClass=")
		searches.WriteString(objClass)
		searches.WriteString(")")
	}
	for k, v := range condition {
		searches.WriteString("(")
		searches.WriteString(k)
		searches.WriteString("=")
		if isBinaryAttr(k) {
			v = toBinarySearchString(v)
		}
		searches.WriteString(v)
		searches.WriteString(")")
	}
	if len(filter) > 0 {
		if strings.HasPrefix(filter, "(") && strings.HasSuffix(filter, ")") {
			searches.WriteString(filter)
		} else {
			// Dropping a filter widens the result set; do not do it silently.
			log.Warningf("ldapSearch: filter %q ignored, it must be enclosed in parentheses", filter)
		}
	}
	searchStr := fmt.Sprintf("(&%s)", searches.String())

	if len(base) == 0 {
		base = cli.baseDN
	}
	if queryScope != ldap.ScopeWholeSubtree && queryScope != ldap.ScopeSingleLevel && queryScope != ldap.ScopeBaseObject {
		queryScope = ldap.ScopeWholeSubtree
	}

	log.Debugf("ldapSearch: %s", searchStr)

	offset := uint32(0)
	// The same paging control is reused for every page; only its cookie is
	// advanced between requests.
	paging := ldap.NewControlPaging(pageSize)
	for {
		searchRequest := ldap.NewSearchRequest(
			base, // The base dn to search
			queryScope, ldap.NeverDerefAliases, 0, 0, false,
			searchStr,
			fields, // A list attributes to retrieve
			[]ldap.Control{paging},
		)
		searchResult, err := cli.conn.Search(searchRequest)
		if err != nil {
			return offset, errors.Wrap(err, "Search")
		}

		pageTotal := len(searchResult.Entries)
		for i := 0; i < pageTotal; i++ {
			entry := searchResult.Entries[i]
			err := entryFunc(offset, entry)
			stopped := err != nil && errors.Cause(err) == StopSearch
			if err != nil && !stopped {
				return offset, errors.Wrapf(err, "process entry fail at %d-%d-%d", pageTotal, i, offset)
			}
			offset += 1
			if stopped || (limit > 0 && offset >= limit) {
				// stop, offset exceeds limit or receive StopSearch error
				return offset, nil
			}
		}

		resultCtrl := ldap.FindControl(searchResult.Controls, paging.GetControlType())
		if resultCtrl == nil {
			break
		}
		pagingCtrl, ok := resultCtrl.(*ldap.ControlPaging)
		if !ok || len(pagingCtrl.Cookie) == 0 {
			// No usable cookie means there is no next page. Breaking on an
			// unexpected control type also prevents re-sending the same
			// request forever.
			break
		}
		paging.SetCookie(pagingCtrl.Cookie)
	}
	return offset, nil
}
// isBinaryAttr reports whether attrName is one of the attributes listed in
// binaryAttributes. The comparison ignores case.
func isBinaryAttr(attrName string) bool {
	for idx := 0; idx < len(binaryAttributes); idx++ {
		if strings.EqualFold(binaryAttributes[idx], attrName) {
			return true
		}
	}
	return false
}
// toBinary decodes a hex string into the raw bytes it represents, returned
// as a string. Input that is not valid hex (odd length or non-hex
// characters) is returned unchanged.
func toBinary(val string) string {
	decoded, decodeErr := hex.DecodeString(val)
	if decodeErr != nil {
		return val
	}
	return string(decoded)
}
// toBinarySearchString converts a hex string such as "0a1b" into the
// backslash-escaped form `\0a\1b` used for binary values in LDAP search
// filters: every pair of characters is prefixed with a backslash.
//
// The input is not validated as hex. For odd-length input the trailing
// single character is emitted as its own escape (e.g. "abc" -> `\ab\c`)
// instead of panicking on an out-of-range slice; the result is not a valid
// filter escape, so the error surfaces when the filter is parsed.
func toBinarySearchString(val string) string {
	ret := strings.Builder{}
	// One extra backslash per character pair.
	ret.Grow(len(val) + (len(val)+1)/2)
	for i := 0; i < len(val); i += 2 {
		end := i + 2
		if end > len(val) {
			end = len(val)
		}
		ret.WriteString(`\`)
		ret.WriteString(val[i:end])
	}
	return ret.String()
}
// toHex returns the lowercase hexadecimal encoding of the bytes of val.
func toHex(val string) string {
	raw := []byte(val)
	encoded := make([]byte, hex.EncodedLen(len(raw)))
	hex.Encode(encoded, raw)
	return string(encoded)
}
// GetAttributeValue returns the first value of attribute key on entry e.
// For binary attributes (see isBinaryAttr) the value is returned
// hex-encoded; all other values are returned as-is.
func GetAttributeValue(e *ldap.Entry, key string) string {
	raw := e.GetAttributeValue(key)
	if !isBinaryAttr(key) {
		return raw
	}
	return toHex(raw)
}
// GetAttributeValues returns all values of attribute key on entry e.
// For binary attributes (see isBinaryAttr) each value is hex-encoded; all
// other values are returned as-is.
//
// The hex-encoded values are written to a new slice. The slice returned by
// e.GetAttributeValues is the entry's own storage, so encoding in place
// would corrupt the entry and double-encode the values on the next call.
func GetAttributeValues(e *ldap.Entry, key string) []string {
	vals := e.GetAttributeValues(key)
	if !isBinaryAttr(key) {
		return vals
	}
	encoded := make([]string, len(vals))
	for i, v := range vals {
		encoded[i] = toHex(v)
	}
	return encoded
}