Enable signed hosts provied in query parameters

This commit is contained in:
Bolke de Bruin 2022-08-17 19:12:28 +02:00
parent 8bc3e25f83
commit cb8b269478
5 changed files with 132 additions and 15 deletions

View File

@ -27,6 +27,7 @@ const (
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
type QueryInfoFunc func(context.Context, string, string) (string, error)
type Config struct {
SessionKey []byte
@ -34,6 +35,8 @@ type Config struct {
SessionStore string
PAATokenGenerator TokenGeneratorFunc
UserTokenGenerator UserTokenGeneratorFunc
QueryInfo QueryInfoFunc
QueryTokenIssuer string
EnableUserToken bool
OAuth2Config *oauth2.Config
store sessions.Store
@ -159,39 +162,53 @@ func (c *Config) selectRandomHost() string {
return host
}
func (c *Config) getHost(u *url.URL) (string, error) {
var host string
func (c *Config) getHost(ctx context.Context, u *url.URL) (string, error) {
switch c.HostSelection {
case "roundrobin":
host = c.selectRandomHost()
return c.selectRandomHost(), nil
case "signed":
case "unsigned":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
host, err := c.QueryInfo(ctx, hosts[0], c.QueryTokenIssuer)
if err != nil {
return "", err
}
found := false
for _, check := range c.Hosts {
if check == hosts[0] {
host = hosts[0]
if check == host {
found = true
break
}
}
if !found {
log.Printf("Invalid host %s specified in client request", hosts[0])
return "", errors.New("invalid host specified in query parameter")
log.Printf("Invalid host %s specified in token", hosts[0])
return "", errors.New("invalid host specified in query token")
}
return host, nil
case "unsigned":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
for _, check := range c.Hosts {
if check == hosts[0] {
return hosts[0], nil
}
}
// not found
log.Printf("Invalid host %s specified in client request", hosts[0])
return "", errors.New("invalid host specified in query parameter")
case "any":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
host = hosts[0]
return hosts[0], nil
default:
host = c.selectRandomHost()
return c.selectRandomHost(), nil
}
return host, nil
}
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
@ -205,7 +222,7 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
}
// determine host to connect to
host, err := c.getHost(r.URL)
host, err := c.getHost(ctx, r.URL)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return

View File

@ -1,12 +1,15 @@
package api
import (
"context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"net/url"
"testing"
)
var (
hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"}
key = []byte("thisisasessionkeyreplacethisjetzt")
)
func contains(needle string, haystack []string) bool {
@ -19,6 +22,7 @@ func contains(needle string, haystack []string) bool {
}
func TestGetHost(t *testing.T) {
ctx := context.Background()
c := Config{
HostSelection: "roundrobin",
Hosts: hosts,
@ -28,7 +32,7 @@ func TestGetHost(t *testing.T) {
}
vals := u.Query()
host, err := c.getHost(u)
host, err := c.getHost(ctx, u)
if err != nil {
t.Fatalf("#{err}")
}
@ -40,14 +44,14 @@ func TestGetHost(t *testing.T) {
c.HostSelection = "unsigned"
vals.Set("host", "in.valid.host")
u.RawQuery = vals.Encode()
host, err = c.getHost(u)
host, err = c.getHost(ctx, u)
if err == nil {
t.Fatalf("Accepted host %s is not in hosts list", host)
}
vals.Set("host", hosts[0])
u.RawQuery = vals.Encode()
host, err = c.getHost(u)
host, err = c.getHost(ctx, u)
if err != nil {
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
}
@ -55,4 +59,35 @@ func TestGetHost(t *testing.T) {
t.Fatalf("host %s is not equal to input %s", host, hosts[0])
}
// check any
c.HostSelection = "any"
test := "bla.bla.com"
vals.Set("host", test)
u.RawQuery = vals.Encode()
host, err = c.getHost(ctx, u)
if err != nil {
t.Fatalf("%s is not accepted", host)
}
if test != host {
t.Fatalf("Returned host %s is not equal to input host %s", host, test)
}
// check signed
c.HostSelection = "signed"
c.QueryInfo = security.QueryInfo
issuer := "rdpgwtest"
security.QuerySigningKey = key
queryToken, err := security.GenerateQueryToken(ctx, hosts[0], issuer)
if err != nil {
t.Fatalf("cannot generate token")
}
vals.Set("host", queryToken)
u.RawQuery = vals.Encode()
host, err = c.getHost(ctx, u)
if err != nil {
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
}
if host != hosts[0] {
t.Fatalf("%s does not equal %s", host, hosts[0])
}
}

View File

@ -58,6 +58,8 @@ type SecurityConfig struct {
PAATokenSigningKey string `koanf:"paatokensigningkey"`
UserTokenEncryptionKey string `koanf:"usertokenencryptionkey"`
UserTokenSigningKey string `koanf:"usertokensigningkey"`
QueryTokenSigningKey string `koanf:"querytokensigningkey"`
QueryTokenIssuer string `koanf:"querytokenissuer"`
VerifyClientIp bool `koanf:"verifyclientip"`
EnableUserToken bool `koanf:"enableusertoken"`
}
@ -176,6 +178,10 @@ func Load(configFile string) Configuration {
log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random")
}
if Conf.Server.HostSelection == "signed" && len(Conf.Security.QueryTokenSigningKey) == 0 {
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
}
return Conf
}

View File

@ -40,6 +40,7 @@ func main() {
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey)
security.QuerySigningKey = []byte(conf.Security.QueryTokenSigningKey)
// set oidc config
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
@ -74,6 +75,8 @@ func main() {
OIDCTokenVerifier: verifier,
PAATokenGenerator: security.GeneratePAAToken,
UserTokenGenerator: security.GenerateUserToken,
QueryInfo: security.QueryInfo,
QueryTokenIssuer: conf.Security.QueryTokenIssuer,
EnableUserToken: conf.Security.EnableUserToken,
SessionKey: []byte(conf.Server.SessionKey),
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),

View File

@ -19,6 +19,7 @@ var (
EncryptionKey []byte
UserSigningKey []byte
UserEncryptionKey []byte
QuerySigningKey []byte
OIDCProvider *oidc.Provider
Oauth2Config oauth2.Config
)
@ -221,6 +222,61 @@ func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
return standard, nil
}
func QueryInfo(ctx context.Context, tokenString string, issuer string) (string, error) {
standard := jwt.Claims{}
token, err := jwt.ParseSigned(tokenString)
if err != nil {
log.Printf("Cannot get token %s", err)
return "", errors.New("cannot get token")
}
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
log.Printf("signature validation failure: %s", err)
return "", errors.New("signature validation failure")
}
err = token.Claims(QuerySigningKey, &standard)
if err = token.Claims(QuerySigningKey, &standard); err != nil {
log.Printf("cannot verify signature %s", err)
return "", errors.New("cannot verify signature")
}
// go-jose doesnt verify the expiry
err = standard.Validate(jwt.Expected{
Issuer: issuer,
Time: time.Now(),
})
if err != nil {
log.Printf("token validation failed due to %s", err)
return "", fmt.Errorf("token validation failed due to %s", err)
}
return standard.Subject, nil
}
// GenerateQueryToken this is a helper function for testing
func GenerateQueryToken(ctx context.Context, query string, issuer string) (string, error) {
if len(QuerySigningKey) < 32 {
return "", errors.New("query token encryption key not long enough or not specified")
}
claims := jwt.Claims{
Subject: query,
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
Issuer: issuer,
}
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: QuerySigningKey},
(&jose.SignerOptions{}).WithBase64(true))
if err != nil {
log.Printf("Cannot encrypt user token due to %s", err)
return "", err
}
token, err := jwt.Signed(sig).Claims(claims).CompactSerialize()
return token, err
}
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
if !ok {