165 lines
4.9 KiB
Go
165 lines
4.9 KiB
Go
package ntlm
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"errors"
|
|
"github.com/bolkedebruin/rdpgw/cmd/auth/database"
|
|
"github.com/bolkedebruin/rdpgw/shared/auth"
|
|
"github.com/patrickmn/go-cache"
|
|
"github.com/m7913d/go-ntlm/ntlm"
|
|
"fmt"
|
|
"log"
|
|
"time"
|
|
)
|
|
|
|
// Timing parameters handed to cache.New for the per-session NTLM context cache.
const (
	// cacheExpiration is the default lifetime of a cached NTLM context.
	cacheExpiration = time.Minute
	// cleanupInterval is the cleanup interval passed to the cache.
	cleanupInterval = 5 * time.Minute
)
|
|
|
|
type NTLMAuth struct {
|
|
contextCache *cache.Cache
|
|
|
|
// Information about the server, returned to the client during authentication
|
|
ServerName string // e.g. EXAMPLE1
|
|
DomainName string // e.g. EXAMPLE
|
|
DnsServerName string // e.g. example1.example.com
|
|
DnsDomainName string // e.g. example.com
|
|
DnsTreeName string // e.g. example.com
|
|
|
|
Database database.Database
|
|
}
|
|
|
|
func NewNTLMAuth (database database.Database) (*NTLMAuth) {
|
|
return &NTLMAuth{
|
|
contextCache: cache.New(cacheExpiration, cleanupInterval),
|
|
Database: database,
|
|
}
|
|
}
|
|
|
|
func (h *NTLMAuth) Authenticate(message *auth.NtlmRequest) (*auth.NtlmResponse, error) {
|
|
r := &auth.NtlmResponse{}
|
|
r.Authenticated = false
|
|
|
|
if message.Session == "" {
|
|
return r, errors.New("Invalid (empty) session specified")
|
|
}
|
|
|
|
if message.NtlmMessage == "" {
|
|
return r, errors.New("Empty NTLM message specified")
|
|
}
|
|
|
|
c := h.getContext(message.Session)
|
|
err := c.Authenticate(message.NtlmMessage, r)
|
|
|
|
if err != nil || r.Authenticated {
|
|
h.removeContext(message.Session)
|
|
}
|
|
|
|
return r, err
|
|
}
|
|
|
|
func (h *NTLMAuth) getContext (session string) (*ntlmContext) {
|
|
if c_, found := h.contextCache.Get(session); found {
|
|
if c, ok := c_.(*ntlmContext); ok {
|
|
return c
|
|
}
|
|
}
|
|
c := new(ntlmContext)
|
|
c.h = h
|
|
h.contextCache.Set(session, c, cache.DefaultExpiration)
|
|
return c
|
|
}
|
|
|
|
func (h *NTLMAuth) removeContext (session string) {
|
|
h.contextCache.Delete(session)
|
|
}
|
|
|
|
type ntlmContext struct {
|
|
session ntlm.ServerSession
|
|
h *NTLMAuth
|
|
}
|
|
|
|
func (c *ntlmContext) Authenticate(authorisationEncoded string, r *auth.NtlmResponse) (error) {
|
|
authorisation, err := base64.StdEncoding.DecodeString(authorisationEncoded)
|
|
if err != nil {
|
|
return errors.New(fmt.Sprintf("Failed to decode NTLM Authorisation header: %s", err))
|
|
}
|
|
|
|
nm, err := ntlm.ParseNegotiateMessage(authorisation)
|
|
if err == nil {
|
|
return c.negotiate(nm, r)
|
|
}
|
|
if (nm != nil && nm.MessageType == 1) {
|
|
return errors.New(fmt.Sprintf("Failed to parse NTLM Authorisation header: %s", err))
|
|
} else if c.session == nil {
|
|
return errors.New(fmt.Sprintf("New NTLM auth sequence should start with negotioate request"))
|
|
}
|
|
|
|
am, err := ntlm.ParseAuthenticateMessage(authorisation, 2)
|
|
if err == nil {
|
|
return c.authenticate(am, r)
|
|
}
|
|
|
|
return errors.New(fmt.Sprintf("Failed to parse NTLM Authorisation header: %s", err))
|
|
}
|
|
|
|
func (c *ntlmContext) negotiate(nm *ntlm.NegotiateMessage, r *auth.NtlmResponse) (error) {
|
|
session, err := ntlm.CreateServerSession(ntlm.Version2, ntlm.ConnectionOrientedMode)
|
|
|
|
if err != nil {
|
|
c.session = nil;
|
|
return errors.New(fmt.Sprintf("Failed to create NTLM server session: %s", err))
|
|
}
|
|
|
|
c.session = session
|
|
c.session.SetRequireNtHash(true)
|
|
c.session.SetDomainName(c.h.DomainName)
|
|
c.session.SetComputerName(c.h.ServerName)
|
|
c.session.SetDnsDomainName(c.h.DnsDomainName)
|
|
c.session.SetDnsComputerName(c.h.DnsServerName)
|
|
c.session.SetDnsTreeName(c.h.DnsTreeName)
|
|
|
|
err = c.session.ProcessNegotiateMessage(nm)
|
|
if err != nil {
|
|
return errors.New(fmt.Sprintf("Failed to process NTLM negotiate message: %s", err))
|
|
}
|
|
|
|
cm, err := c.session.GenerateChallengeMessage()
|
|
if err != nil {
|
|
return errors.New(fmt.Sprintf("Failed to generate NTLM challenge message: %s", err))
|
|
}
|
|
|
|
r.NtlmMessage = base64.StdEncoding.EncodeToString(cm.Bytes())
|
|
return nil
|
|
}
|
|
|
|
func (c *ntlmContext) authenticate(am *ntlm.AuthenticateMessage, r *auth.NtlmResponse) (error) {
|
|
if c.session == nil {
|
|
return errors.New(fmt.Sprintf("NTLM Authenticate requires active session: first call negotioate"))
|
|
}
|
|
|
|
username := am.UserName.String()
|
|
log.Printf("NTLM: Trying to validate user: %s", username)
|
|
|
|
password := c.h.Database.GetPassword(username)
|
|
if password == "" {
|
|
log.Printf("NTLM: unknown username specified: %s", username)
|
|
return nil
|
|
}
|
|
|
|
log.Printf("NTLM: Successfully retrieved password for user: %s", username)
|
|
c.session.SetUserInfo(username, password, "")
|
|
|
|
err := c.session.ProcessAuthenticateMessage(am)
|
|
if err != nil {
|
|
log.Printf("Failed to process NTLM authenticate message: %s", err)
|
|
return nil
|
|
}
|
|
|
|
r.Authenticated = true
|
|
r.Username = username
|
|
log.Printf("NTLM: User %s authenticated successfully", username)
|
|
return nil
|
|
}
|