Enable signed hosts provied in query parameters
This commit is contained in:
parent
8bc3e25f83
commit
cb8b269478
@ -27,6 +27,7 @@ const (
|
||||
|
||||
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
|
||||
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
|
||||
type QueryInfoFunc func(context.Context, string, string) (string, error)
|
||||
|
||||
type Config struct {
|
||||
SessionKey []byte
|
||||
@ -34,6 +35,8 @@ type Config struct {
|
||||
SessionStore string
|
||||
PAATokenGenerator TokenGeneratorFunc
|
||||
UserTokenGenerator UserTokenGeneratorFunc
|
||||
QueryInfo QueryInfoFunc
|
||||
QueryTokenIssuer string
|
||||
EnableUserToken bool
|
||||
OAuth2Config *oauth2.Config
|
||||
store sessions.Store
|
||||
@ -159,39 +162,53 @@ func (c *Config) selectRandomHost() string {
|
||||
return host
|
||||
}
|
||||
|
||||
func (c *Config) getHost(u *url.URL) (string, error) {
|
||||
var host string
|
||||
func (c *Config) getHost(ctx context.Context, u *url.URL) (string, error) {
|
||||
switch c.HostSelection {
|
||||
case "roundrobin":
|
||||
host = c.selectRandomHost()
|
||||
return c.selectRandomHost(), nil
|
||||
case "signed":
|
||||
case "unsigned":
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
}
|
||||
host, err := c.QueryInfo(ctx, hosts[0], c.QueryTokenIssuer)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
found := false
|
||||
for _, check := range c.Hosts {
|
||||
if check == hosts[0] {
|
||||
host = hosts[0]
|
||||
if check == host {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
log.Printf("Invalid host %s specified in client request", hosts[0])
|
||||
return "", errors.New("invalid host specified in query parameter")
|
||||
log.Printf("Invalid host %s specified in token", hosts[0])
|
||||
return "", errors.New("invalid host specified in query token")
|
||||
}
|
||||
return host, nil
|
||||
case "unsigned":
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
}
|
||||
for _, check := range c.Hosts {
|
||||
if check == hosts[0] {
|
||||
return hosts[0], nil
|
||||
}
|
||||
}
|
||||
// not found
|
||||
log.Printf("Invalid host %s specified in client request", hosts[0])
|
||||
return "", errors.New("invalid host specified in query parameter")
|
||||
case "any":
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
}
|
||||
host = hosts[0]
|
||||
return hosts[0], nil
|
||||
default:
|
||||
host = c.selectRandomHost()
|
||||
return c.selectRandomHost(), nil
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
@ -205,7 +222,7 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// determine host to connect to
|
||||
host, err := c.getHost(r.URL)
|
||||
host, err := c.getHost(ctx, r.URL)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"}
|
||||
key = []byte("thisisasessionkeyreplacethisjetzt")
|
||||
)
|
||||
|
||||
func contains(needle string, haystack []string) bool {
|
||||
@ -19,6 +22,7 @@ func contains(needle string, haystack []string) bool {
|
||||
}
|
||||
|
||||
func TestGetHost(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
c := Config{
|
||||
HostSelection: "roundrobin",
|
||||
Hosts: hosts,
|
||||
@ -28,7 +32,7 @@ func TestGetHost(t *testing.T) {
|
||||
}
|
||||
vals := u.Query()
|
||||
|
||||
host, err := c.getHost(u)
|
||||
host, err := c.getHost(ctx, u)
|
||||
if err != nil {
|
||||
t.Fatalf("#{err}")
|
||||
}
|
||||
@ -40,14 +44,14 @@ func TestGetHost(t *testing.T) {
|
||||
c.HostSelection = "unsigned"
|
||||
vals.Set("host", "in.valid.host")
|
||||
u.RawQuery = vals.Encode()
|
||||
host, err = c.getHost(u)
|
||||
host, err = c.getHost(ctx, u)
|
||||
if err == nil {
|
||||
t.Fatalf("Accepted host %s is not in hosts list", host)
|
||||
}
|
||||
|
||||
vals.Set("host", hosts[0])
|
||||
u.RawQuery = vals.Encode()
|
||||
host, err = c.getHost(u)
|
||||
host, err = c.getHost(ctx, u)
|
||||
if err != nil {
|
||||
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
|
||||
}
|
||||
@ -55,4 +59,35 @@ func TestGetHost(t *testing.T) {
|
||||
t.Fatalf("host %s is not equal to input %s", host, hosts[0])
|
||||
}
|
||||
|
||||
// check any
|
||||
c.HostSelection = "any"
|
||||
test := "bla.bla.com"
|
||||
vals.Set("host", test)
|
||||
u.RawQuery = vals.Encode()
|
||||
host, err = c.getHost(ctx, u)
|
||||
if err != nil {
|
||||
t.Fatalf("%s is not accepted", host)
|
||||
}
|
||||
if test != host {
|
||||
t.Fatalf("Returned host %s is not equal to input host %s", host, test)
|
||||
}
|
||||
|
||||
// check signed
|
||||
c.HostSelection = "signed"
|
||||
c.QueryInfo = security.QueryInfo
|
||||
issuer := "rdpgwtest"
|
||||
security.QuerySigningKey = key
|
||||
queryToken, err := security.GenerateQueryToken(ctx, hosts[0], issuer)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot generate token")
|
||||
}
|
||||
vals.Set("host", queryToken)
|
||||
u.RawQuery = vals.Encode()
|
||||
host, err = c.getHost(ctx, u)
|
||||
if err != nil {
|
||||
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
|
||||
}
|
||||
if host != hosts[0] {
|
||||
t.Fatalf("%s does not equal %s", host, hosts[0])
|
||||
}
|
||||
}
|
||||
|
||||
@ -58,6 +58,8 @@ type SecurityConfig struct {
|
||||
PAATokenSigningKey string `koanf:"paatokensigningkey"`
|
||||
UserTokenEncryptionKey string `koanf:"usertokenencryptionkey"`
|
||||
UserTokenSigningKey string `koanf:"usertokensigningkey"`
|
||||
QueryTokenSigningKey string `koanf:"querytokensigningkey"`
|
||||
QueryTokenIssuer string `koanf:"querytokenissuer"`
|
||||
VerifyClientIp bool `koanf:"verifyclientip"`
|
||||
EnableUserToken bool `koanf:"enableusertoken"`
|
||||
}
|
||||
@ -176,6 +178,10 @@ func Load(configFile string) Configuration {
|
||||
log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random")
|
||||
}
|
||||
|
||||
if Conf.Server.HostSelection == "signed" && len(Conf.Security.QueryTokenSigningKey) == 0 {
|
||||
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
|
||||
}
|
||||
|
||||
return Conf
|
||||
|
||||
}
|
||||
|
||||
@ -40,6 +40,7 @@ func main() {
|
||||
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
|
||||
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
|
||||
security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey)
|
||||
security.QuerySigningKey = []byte(conf.Security.QueryTokenSigningKey)
|
||||
|
||||
// set oidc config
|
||||
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
|
||||
@ -74,6 +75,8 @@ func main() {
|
||||
OIDCTokenVerifier: verifier,
|
||||
PAATokenGenerator: security.GeneratePAAToken,
|
||||
UserTokenGenerator: security.GenerateUserToken,
|
||||
QueryInfo: security.QueryInfo,
|
||||
QueryTokenIssuer: conf.Security.QueryTokenIssuer,
|
||||
EnableUserToken: conf.Security.EnableUserToken,
|
||||
SessionKey: []byte(conf.Server.SessionKey),
|
||||
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
|
||||
|
||||
@ -19,6 +19,7 @@ var (
|
||||
EncryptionKey []byte
|
||||
UserSigningKey []byte
|
||||
UserEncryptionKey []byte
|
||||
QuerySigningKey []byte
|
||||
OIDCProvider *oidc.Provider
|
||||
Oauth2Config oauth2.Config
|
||||
)
|
||||
@ -221,6 +222,61 @@ func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
|
||||
return standard, nil
|
||||
}
|
||||
|
||||
func QueryInfo(ctx context.Context, tokenString string, issuer string) (string, error) {
|
||||
standard := jwt.Claims{}
|
||||
token, err := jwt.ParseSigned(tokenString)
|
||||
if err != nil {
|
||||
log.Printf("Cannot get token %s", err)
|
||||
return "", errors.New("cannot get token")
|
||||
}
|
||||
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
|
||||
log.Printf("signature validation failure: %s", err)
|
||||
return "", errors.New("signature validation failure")
|
||||
}
|
||||
err = token.Claims(QuerySigningKey, &standard)
|
||||
if err = token.Claims(QuerySigningKey, &standard); err != nil {
|
||||
log.Printf("cannot verify signature %s", err)
|
||||
return "", errors.New("cannot verify signature")
|
||||
}
|
||||
|
||||
// go-jose doesnt verify the expiry
|
||||
err = standard.Validate(jwt.Expected{
|
||||
Issuer: issuer,
|
||||
Time: time.Now(),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("token validation failed due to %s", err)
|
||||
return "", fmt.Errorf("token validation failed due to %s", err)
|
||||
}
|
||||
|
||||
return standard.Subject, nil
|
||||
}
|
||||
|
||||
// GenerateQueryToken this is a helper function for testing
|
||||
func GenerateQueryToken(ctx context.Context, query string, issuer string) (string, error) {
|
||||
if len(QuerySigningKey) < 32 {
|
||||
return "", errors.New("query token encryption key not long enough or not specified")
|
||||
}
|
||||
|
||||
claims := jwt.Claims{
|
||||
Subject: query,
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
|
||||
Issuer: issuer,
|
||||
}
|
||||
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: QuerySigningKey},
|
||||
(&jose.SignerOptions{}).WithBase64(true))
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Cannot encrypt user token due to %s", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
token, err := jwt.Signed(sig).Claims(claims).CompactSerialize()
|
||||
return token, err
|
||||
}
|
||||
|
||||
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
|
||||
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
|
||||
if !ok {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user