155 lines
4.1 KiB
Go
155 lines
4.1 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"github.com/bolkedebruin/rdpgw/cmd/auth/config"
|
||
"github.com/bolkedebruin/rdpgw/cmd/auth/database"
|
||
"github.com/bolkedebruin/rdpgw/cmd/auth/ntlm"
|
||
"github.com/bolkedebruin/rdpgw/shared/auth"
|
||
"github.com/msteinert/pam/v2"
|
||
"github.com/thought-machine/go-flags"
|
||
"google.golang.org/grpc"
|
||
"log"
|
||
"net"
|
||
"os"
|
||
"syscall"
|
||
)
|
||
|
||
const (
|
||
protocol = "unix"
|
||
)
|
||
|
||
var opts struct {
|
||
ServiceName string `short:"n" long:"name" default:"rdpgw" description:"the PAM service name to use"`
|
||
SocketAddr string `short:"s" long:"socket" default:"/tmp/rdpgw-auth.sock" description:"the location of the socket"`
|
||
ConfigFile string `short:"c" long:"conf" default:"rdpgw-auth.yaml" description:"users config file for NTLM (yaml)"`
|
||
}
|
||
|
||
type AuthServiceImpl struct {
|
||
auth.UnimplementedAuthenticateServer
|
||
|
||
serviceName string
|
||
ntlm *ntlm.NTLMAuth
|
||
}
|
||
|
||
var conf config.Configuration
|
||
var _ auth.AuthenticateServer = (*AuthServiceImpl)(nil)
|
||
|
||
func NewAuthService(serviceName string, database database.Database) auth.AuthenticateServer {
|
||
s := &AuthServiceImpl{
|
||
serviceName: serviceName,
|
||
ntlm: ntlm.NewNTLMAuth(database),
|
||
}
|
||
return s
|
||
}
|
||
|
||
func (s *AuthServiceImpl) Authenticate(ctx context.Context, message *auth.UserPass) (*auth.AuthResponse, error) {
|
||
t, err := pam.StartFunc(s.serviceName, message.Username, func(s pam.Style, msg string) (string, error) {
|
||
switch s {
|
||
case pam.PromptEchoOff:
|
||
return message.Password, nil
|
||
case pam.PromptEchoOn, pam.ErrorMsg, pam.TextInfo:
|
||
return "", nil
|
||
}
|
||
return "", errors.New("unrecognized PAM message style")
|
||
})
|
||
|
||
r := &auth.AuthResponse{}
|
||
r.Authenticated = false
|
||
|
||
if err != nil {
|
||
log.Printf("Error authenticating user: %s due to: %s", message.Username, err)
|
||
r.Error = err.Error()
|
||
return r, err
|
||
}
|
||
defer func() {
|
||
err := t.End()
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "end: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
}()
|
||
if err = t.Authenticate(0); err != nil {
|
||
log.Printf("Authentication for user: %s failed due to: %s", message.Username, err)
|
||
r.Error = err.Error()
|
||
return r, nil
|
||
}
|
||
|
||
if err = t.AcctMgmt(0); err != nil {
|
||
log.Printf("Account authorization for user: %s failed due to %s", message.Username, err)
|
||
r.Error = err.Error()
|
||
return r, nil
|
||
}
|
||
|
||
log.Printf("User: %s authenticated", message.Username)
|
||
r.Authenticated = true
|
||
return r, nil
|
||
}
|
||
|
||
func (s *AuthServiceImpl) NTLM(ctx context.Context, message *auth.NtlmRequest) (*auth.NtlmResponse, error) {
|
||
r, err := s.ntlm.Authenticate(message)
|
||
|
||
if err != nil {
|
||
log.Printf("[%s] NTLM failed: %s", message.Session, err)
|
||
} else if r.Authenticated {
|
||
log.Printf("[%s] User: %s authenticated using NTLM", message.Session, r.Username)
|
||
} else if r.NtlmMessage != "" {
|
||
log.Printf("[%s] Sending NTLM challenge", message.Session)
|
||
}
|
||
|
||
return r, err
|
||
}
|
||
|
||
func main() {
|
||
_, err := flags.Parse(&opts)
|
||
if err != nil {
|
||
var fErr *flags.Error
|
||
if errors.As(err, &fErr) {
|
||
if fErr.Type == flags.ErrHelp {
|
||
fmt.Printf("Acknowledgements:\n")
|
||
fmt.Printf(" - This product includes software developed by the Thomson Reuters Global Resources. (go-ntlm - https://github.com/m7913d/go-ntlm - BSD-4 License)\n")
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
conf = config.Load(opts.ConfigFile)
|
||
|
||
log.Printf("Starting auth server on %s", opts.SocketAddr)
|
||
cleanup := func() {
|
||
if _, err := os.Stat(opts.SocketAddr); err == nil {
|
||
if err := os.RemoveAll(opts.SocketAddr); err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
}
|
||
}
|
||
cleanup()
|
||
|
||
oldUmask := syscall.Umask(0)
|
||
listener, err := net.Listen(protocol, opts.SocketAddr)
|
||
syscall.Umask(oldUmask)
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
|
||
server := grpc.NewServer()
|
||
|
||
// 根据配置选择使用API认证或本地配置认证
|
||
var db database.Database
|
||
log.Printf("使用API认证,API URL: %s", conf.ApiAuth.ApiUrl)
|
||
log.Printf("使用API认证,API URL: %s", conf.ApiAuth.Enabled)
|
||
if conf.ApiAuth.Enabled && conf.ApiAuth.ApiUrl != "" {
|
||
log.Printf("使用API认证,API URL: %s", conf.ApiAuth.ApiUrl)
|
||
db = database.NewApiDb(conf.ApiAuth.ApiUrl)
|
||
} else {
|
||
log.Printf("使用本地配置文件认证")
|
||
db = database.NewConfig(conf.Users)
|
||
}
|
||
|
||
service := NewAuthService(opts.ServiceName, db)
|
||
auth.RegisterAuthenticateServer(server, service)
|
||
server.Serve(listener)
|
||
}
|