Allow chaining of checks

This commit is contained in:
Bolke de Bruin 2022-08-25 12:12:21 +02:00
parent 9d2dc57e90
commit 768ee45974
3 changed files with 34 additions and 31 deletions

View File

@ -156,9 +156,9 @@ func main() {
}
if conf.Caps.TokenAuth {
gwConfig.VerifyTunnelAuthFunc = security.VerifyPAAToken
gwConfig.VerifyServerFunc = security.VerifyServerFunc
gwConfig.VerifyServerFunc = security.CheckSession(security.CheckHost)
} else {
gwConfig.VerifyServerFunc = security.BasicVerifyServer
gwConfig.VerifyServerFunc = security.CheckHost
}
gw := protocol.Gateway{
ServerConf: &gwConfig,

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"strings"
)
var (
@ -12,19 +13,20 @@ var (
HostSelection string
)
func BasicVerifyServer(ctx context.Context, host string) (bool, error) {
if HostSelection == "any" {
func CheckHost(ctx context.Context, host string) (bool, error) {
switch HostSelection {
case "any":
return true, nil
}
if HostSelection == "signed" {
// todo get from context
case "signed":
// todo get from context?
return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
}
if HostSelection == "roundrobin" || HostSelection == "unsigned" {
case "roundrobin", "unsigned":
log.Printf("Checking host")
username := ctx.Value("preferred_username").(string)
for _, h := range Hosts {
if username != "" {
h = strings.Replace(h, "{{ preferred_username }}", username, 1)
}
if h == host {
return true, nil
}

View File

@ -33,6 +33,27 @@ type customClaims struct {
AccessToken string `json:"accessToken"`
}
func CheckSession(next protocol.VerifyServerFunc) protocol.VerifyServerFunc {
return func(ctx context.Context, host string) (bool, error) {
s := getSessionInfo(ctx)
if s == nil {
return false, errors.New("no valid session info found in context")
}
if s.RemoteServer != host {
log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer)
return false, nil
}
if VerifyClientIP && s.ClientIp != common.GetClientIp(ctx) {
log.Printf("Current client ip address %s does not match token client ip %s",
common.GetClientIp(ctx), s.ClientIp)
return false, nil
}
return next(ctx, host)
}
}
func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
if tokenString == "" {
log.Printf("no token to parse")
@ -91,26 +112,6 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
return true, nil
}
func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
s := getSessionInfo(ctx)
if s == nil {
return false, errors.New("no valid session info found in context")
}
if s.RemoteServer != host {
log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer)
return false, nil
}
if VerifyClientIP && s.ClientIp != common.GetClientIp(ctx) {
log.Printf("Current client ip address %s does not match token client ip %s",
common.GetClientIp(ctx), s.ClientIp)
return false, nil
}
return true, nil
}
func GeneratePAAToken(ctx context.Context, username string, server string) (string, error) {
if len(SigningKey) < 32 {
return "", errors.New("token signing key not long enough or not specified")