Refactor identity and http routing

This commit is contained in:
Bolke de Bruin 2022-10-18 09:36:41 +02:00
parent b42c3cd3cc
commit db98550455
18 changed files with 402 additions and 199 deletions

View File

@ -48,7 +48,7 @@ type ServerConfig struct {
SendBuf int `koanf:"sendbuf"`
ReceiveBuf int `koanf:"receivebuf"`
Tls string `koanf:"tls"`
Authentication string `koanf:"authentication"`
Authentication []string `koanf:"authentication"`
AuthSocket string `koanf:"authsocket"`
}
@ -206,15 +206,15 @@ func Load(configFile string) Configuration {
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
}
if Conf.Server.Authentication == "local" && Conf.Server.Tls == "disable" {
if Conf.Server.BasicAuthEnabled() && Conf.Server.Tls == "disable" {
log.Fatalf("basicauth=local and tls=disable are mutually exclusive")
}
if !Conf.Caps.TokenAuth && Conf.Server.Authentication == "openid" {
if !Conf.Caps.TokenAuth && Conf.Server.OpenIDEnabled() {
log.Fatalf("openid is configured but tokenauth disabled")
}
if Conf.Server.Authentication == AuthenticationKerberos && Conf.Kerberos.Keytab == "" {
if Conf.Server.KerberosEnabled() && Conf.Kerberos.Keytab == "" {
log.Fatalf("kerberos is configured but no keytab was specified")
}
@ -226,3 +226,24 @@ func Load(configFile string) Configuration {
return Conf
}
func (s *ServerConfig) OpenIDEnabled() bool {
return s.matchAuth("openid")
}
func (s *ServerConfig) KerberosEnabled() bool {
return s.matchAuth("kerberos")
}
func (s *ServerConfig) BasicAuthEnabled() bool {
return s.matchAuth("local")
}
func (s *ServerConfig) matchAuth(needle string) bool {
for _, q := range s.Authentication {
if q == needle {
return true
}
}
return false
}

View File

@ -0,0 +1,57 @@
package identity
import (
"context"
"net/http"
"time"
)
const (
CTXKey = "github.com/bolkedebruin/rdpgw/common/identity"
AttrRemoteAddr = "remoteAddr"
AttrClientIp = "clientIp"
AttrProxies = "proxyAddresses"
AttrAccessToken = "accessToken" // todo remove for security reasons
)
type Identity interface {
UserName() string
SetUserName(string)
DisplayName() string
SetDisplayName(string)
Domain() string
SetDomain(string)
Authenticated() bool
SetAuthenticated(bool)
AuthTime() time.Time
SetAuthTime(time2 time.Time)
SessionId() string
SetAttribute(string, interface{})
GetAttribute(string) interface{}
Attributes() map[string]interface{}
DelAttribute(string)
Email() string
SetEmail(string)
Expiry() time.Time
SetExpiry(time.Time)
Marshal() ([]byte, error)
Unmarshal([]byte) error
}
func AddToRequestCtx(id Identity, r *http.Request) *http.Request {
ctx := r.Context()
ctx = context.WithValue(ctx, CTXKey, id)
return r.WithContext(ctx)
}
func FromRequestCtx(r *http.Request) Identity {
return FromCtx(r.Context())
}
func FromCtx(ctx context.Context) Identity {
if id, ok := ctx.Value(CTXKey).(Identity); ok {
return id
}
return nil
}

View File

@ -0,0 +1,28 @@
package identity
import (
"log"
"testing"
)
func TestMarshalling(t *testing.T) {
u := NewUser()
u.SetUserName("ANAME")
u.SetAuthenticated(true)
u.SetDomain("DOMAIN")
c := NewUser()
data, err := u.Marshal()
if err != nil {
log.Fatalf("Cannot marshal %s", err)
}
err = c.Unmarshal(data)
if err != nil {
t.Fatalf("Error while unmarshalling: %s", err)
}
if u.UserName() != c.UserName() || u.Authenticated() != c.Authenticated() || u.Domain() != c.Domain() {
t.Fatalf("identities not equal: %+v != %+v", u, c)
}
}

View File

@ -1,60 +1,12 @@
package common
package identity
import (
"context"
"bytes"
"encoding/gob"
"github.com/google/uuid"
"net/http"
"time"
)
const (
CTXKey = "github.com/bolkedebruin/rdpgw/common/identity"
AttrRemoteAddr = "remoteAddr"
AttrClientIp = "clientIp"
AttrProxies = "proxyAddresses"
AttrAccessToken = "accessToken" // todo remove for security reasons
)
type Identity interface {
UserName() string
SetUserName(string)
DisplayName() string
SetDisplayName(string)
Domain() string
SetDomain(string)
Authenticated() bool
SetAuthenticated(bool)
AuthTime() time.Time
SetAuthTime(time2 time.Time)
SessionId() string
SetAttribute(string, interface{})
GetAttribute(string) interface{}
Attributes() map[string]interface{}
DelAttribute(string)
Email() string
SetEmail(string)
Expiry() time.Time
SetExpiry(time.Time)
}
func AddToRequestCtx(id Identity, r *http.Request) *http.Request {
ctx := r.Context()
ctx = context.WithValue(ctx, CTXKey, id)
return r.WithContext(ctx)
}
func FromRequestCtx(r *http.Request) Identity {
return FromCtx(r.Context())
}
func FromCtx(ctx context.Context) Identity {
if id, ok := ctx.Value(CTXKey).(Identity); ok {
return id
}
return nil
}
type User struct {
authenticated bool
domain string
@ -68,6 +20,19 @@ type User struct {
groupMembership map[string]bool
}
type user struct {
Authenticated bool
UserName string
Domain string
DisplayName string
Email string
AuthTime time.Time
SessionId string
Expiry time.Time
Attributes map[string]interface{}
GroupMembership map[string]bool
}
func NewUser() *User {
uuid := uuid.New().String()
return &User{
@ -158,3 +123,48 @@ func (u *User) Expiry() time.Time {
func (u *User) SetExpiry(t time.Time) {
u.expiry = t
}
func (u *User) Marshal() ([]byte, error) {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
uu := user{
Authenticated: u.authenticated,
UserName: u.userName,
Domain: u.domain,
DisplayName: u.displayName,
Email: u.email,
AuthTime: u.authTime,
SessionId: u.sessionId,
Expiry: u.expiry,
Attributes: u.attributes,
GroupMembership: u.groupMembership,
}
err := enc.Encode(uu)
if err != nil {
return []byte{}, err
}
return buf.Bytes(), nil
}
func (u *User) Unmarshal(b []byte) error {
buf := bytes.NewBuffer(b)
dec := gob.NewDecoder(buf)
var uu user
err := dec.Decode(&uu)
if err != nil {
return err
}
u.sessionId = uu.SessionId
u.userName = uu.UserName
u.domain = uu.Domain
u.displayName = uu.DisplayName
u.email = uu.Email
u.authenticated = uu.Authenticated
u.authTime = uu.AuthTime
u.expiry = uu.Expiry
u.attributes = uu.Attributes
u.groupMembership = uu.GroupMembership
return nil
}

View File

@ -7,14 +7,13 @@ import (
"github.com/bolkedebruin/gokrb5/v8/keytab"
"github.com/bolkedebruin/gokrb5/v8/service"
"github.com/bolkedebruin/gokrb5/v8/spnego"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/kdcproxy"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/web"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/sessions"
"github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/thought-machine/go-flags"
"golang.org/x/crypto/acme/autocert"
@ -26,13 +25,18 @@ import (
"strconv"
)
const (
gatewayEndPoint = "/remoteDesktopGateway/"
kdcProxyEndPoint = "/KdcProxy"
)
var opts struct {
ConfigFile string `short:"c" long:"conf" default:"rdpgw.yaml" description:"config file (yaml)"`
}
var conf config.Configuration
func initOIDC(callbackUrl *url.URL, store sessions.Store) *web.OIDC {
func initOIDC(callbackUrl *url.URL) *web.OIDC {
// set oidc config
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
if err != nil {
@ -56,7 +60,6 @@ func initOIDC(callbackUrl *url.URL, store sessions.Store) *web.OIDC {
o := web.OIDCConfig{
OAuth2Config: &oauthConfig,
OIDCTokenVerifier: verifier,
SessionStore: store,
}
return o.New()
@ -91,19 +94,13 @@ func main() {
security.Hosts = conf.Server.Hosts
// init session store
sessionConf := web.SessionManagerConf{
SessionKey: []byte(conf.Server.SessionKey),
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
StoreType: conf.Server.SessionStore,
}
store := sessionConf.Init()
web.InitStore([]byte(conf.Server.SessionKey), []byte(conf.Server.SessionEncryptionKey), conf.Server.SessionStore)
// configure web backend
w := &web.Config{
QueryInfo: security.QueryInfo,
QueryTokenIssuer: conf.Security.QueryTokenIssuer,
EnableUserToken: conf.Security.EnableUserToken,
SessionStore: store,
Hosts: conf.Server.Hosts,
HostSelection: conf.Server.HostSelection,
RdpOpts: web.RdpOpts{
@ -128,6 +125,7 @@ func main() {
log.Printf("Starting remote desktop gateway server")
cfg := &tls.Config{}
// configure tls security
if conf.Server.Tls == config.TlsDisable {
log.Printf("TLS disabled - rdp gw connections require tls, make sure to have a terminator")
} else {
@ -174,13 +172,7 @@ func main() {
}
}
server := http.Server{
Addr: ":" + strconv.Itoa(conf.Server.Port),
TLSConfig: cfg,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
}
// create the gateway
// gateway confg
gw := protocol.Gateway{
RedirectFlags: protocol.RedirectFlags{
Clipboard: conf.Caps.EnableClipboard,
@ -205,31 +197,72 @@ func main() {
gw.CheckHost = security.CheckHost
}
if conf.Server.Authentication == config.AuthenticationBasic {
h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
http.Handle("/remoteDesktopGateway/", common.EnrichContext(h.BasicAuth(gw.HandleGatewayProtocol)))
} else if conf.Server.Authentication == config.AuthenticationKerberos {
r := mux.NewRouter()
// ensure identity is set in context and get some extra info
r.Use(web.EnrichContext)
// prometheus metrics
r.Handle("/metrics", promhttp.Handler())
// for sso callbacks
r.HandleFunc("/tokeninfo", web.TokenInfo)
// gateway endpoint
rdp := r.PathPrefix(gatewayEndPoint).Subrouter()
// openid
if conf.Server.OpenIDEnabled() {
log.Printf("enabling openid extended authentication")
o := initOIDC(url)
r.Handle("/connect", o.Authenticated(http.HandlerFunc(h.HandleDownload)))
r.HandleFunc("/callback", o.HandleCallback)
// only enable un-auth endpoint for openid only config
if !conf.Server.KerberosEnabled() || !conf.Server.BasicAuthEnabled() {
rdp.Name("gw").HandlerFunc(gw.HandleGatewayProtocol)
}
}
// for stacking of authentication
auth := web.NewAuthMux()
// basic auth
if conf.Server.BasicAuthEnabled() {
log.Printf("enabling basic authentication")
q := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
rdp.Headers("Authorization", "Basic*").HandlerFunc(q.BasicAuth(gw.HandleGatewayProtocol))
auth.Register(`Basic realm="restricted", charset="UTF-8"`)
}
// spnego / kerberos
if conf.Server.KerberosEnabled() {
log.Printf("enabling kerberos authentication")
keytab, err := keytab.Load(conf.Kerberos.Keytab)
if err != nil {
log.Fatalf("Cannot load keytab: %s", err)
}
http.Handle("/remoteDesktopGateway/", common.EnrichContext(
spnego.SPNEGOKRB5Authenticate(
common.FixKerberosContext(http.HandlerFunc(gw.HandleGatewayProtocol)),
rdp.Headers("Authorization", "Negotiate*").Handler(
spnego.SPNEGOKRB5Authenticate(web.TransposeSPNEGOContext(http.HandlerFunc(gw.HandleGatewayProtocol)),
keytab,
service.Logger(log.Default()))),
)
service.Logger(log.Default())))
// kdcproxy
k := kdcproxy.InitKdcProxy(conf.Kerberos.Krb5Conf)
http.HandleFunc("/KdcProxy", k.Handler)
} else {
// openid
oidc := initOIDC(url, store)
http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload))))
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.Handle("/callback", common.EnrichContext(http.HandlerFunc(oidc.HandleCallback)))
r.HandleFunc(kdcProxyEndPoint, k.Handler).Methods("POST")
auth.Register("Negotiate")
}
// allow stacking of authentication
rdp.Use(auth.Route)
// setup server
server := http.Server{
Addr: ":" + strconv.Itoa(conf.Server.Port),
Handler: r,
TLSConfig: cfg,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
}
http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/tokeninfo", web.TokenInfo)
if conf.Server.Tls == config.TlsDisable {
err = server.ListenAndServe()

View File

@ -3,7 +3,7 @@ package protocol
import (
"context"
"errors"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"github.com/google/uuid"
"github.com/gorilla/websocket"
@ -61,14 +61,14 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
var t *Tunnel
ctx := r.Context()
id := common.FromRequestCtx(r)
id := identity.FromRequestCtx(r)
connId := r.Header.Get(rdgConnectionIdKey)
x, found := c.Get(connId)
if !found {
t = &Tunnel{
RDGId: connId,
RemoteAddr: id.GetAttribute(common.AttrRemoteAddr).(string),
RemoteAddr: id.GetAttribute(identity.AttrRemoteAddr).(string),
User: id,
}
} else {
@ -183,14 +183,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil)
id := common.FromRequestCtx(r)
id := identity.FromRequestCtx(r)
if r.Method == MethodRDGOUT {
out, err := transport.NewLegacy(w)
if err != nil {
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return
}
log.Printf("Opening RDGOUT for client %s", id.GetAttribute(common.AttrClientIp))
log.Printf("Opening RDGOUT for client %s", id.GetAttribute(identity.AttrClientIp))
t.transportOut = out
out.SendAccept(true)
@ -212,13 +212,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
t.transportIn = in
c.Set(t.RDGId, t, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", id.GetAttribute(common.AttrClientIp))
log.Printf("Opening RDGIN for client %s", id.GetAttribute(identity.AttrClientIp))
in.SendAccept(false)
// read some initial data
in.Drain()
log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(common.AttrClientIp))
log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(identity.AttrClientIp))
handler := NewProcessor(g, t)
RegisterTunnel(t, handler)
defer RemoveTunnel(t)

View File

@ -6,7 +6,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"io"
"log"
"net"
@ -51,7 +51,7 @@ func (p *Processor) Process(ctx context.Context) error {
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(common.AttrClientIp))
log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
if p.state != SERVER_STATE_INITIALIZED {
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
@ -81,7 +81,7 @@ func (p *Processor) Process(ctx context.Context) error {
_, cookie := p.tunnelRequest(pkt)
if p.gw.CheckPAACookie != nil {
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(common.AttrClientIp))
log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
p.tunnel.Write(msg)
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)

View File

@ -1,7 +1,7 @@
package protocol
import (
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"net"
"time"
@ -27,7 +27,7 @@ type Tunnel struct {
// The obtained client ip address
RemoteAddr string
// User
User common.Identity
User identity.Identity
// rwc is the underlying connection to the remote desktop server.
// It is of the type *net.TCPConn

View File

@ -2,7 +2,7 @@ package security
import (
"context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"testing"
)
@ -18,7 +18,7 @@ var (
)
func TestCheckHost(t *testing.T) {
info.User = common.NewUser()
info.User = identity.NewUser()
info.User.SetUserName("MYNAME")
ctx := context.WithValue(context.Background(), protocol.CtxTunnel, &info)

View File

@ -4,7 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-jose/go-jose/v3"
@ -46,10 +46,10 @@ func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
}
// use identity from context rather then set by tunnel
id := common.FromCtx(ctx)
if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(common.AttrClientIp) {
id := identity.FromCtx(ctx)
if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(identity.AttrClientIp) {
log.Printf("Current client ip address %s does not match token client ip %s",
id.GetAttribute(common.AttrClientIp), tunnel.RemoteAddr)
id.GetAttribute(identity.AttrClientIp), tunnel.RemoteAddr)
return false, nil
}
return next(ctx, host)
@ -129,11 +129,11 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
Subject: username,
}
id := common.FromCtx(ctx)
id := identity.FromCtx(ctx)
private := customClaims{
RemoteServer: server,
ClientIP: id.GetAttribute(common.AttrClientIp).(string),
AccessToken: id.GetAttribute(common.AttrAccessToken).(string),
ClientIP: id.GetAttribute(identity.AttrClientIp).(string),
AccessToken: id.GetAttribute(identity.AttrAccessToken).(string),
}
if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {

View File

@ -2,7 +2,7 @@ package web
import (
"context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/shared/auth"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
@ -53,11 +53,11 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
log.Printf("User %s is not authenticated for this service", username)
} else {
log.Printf("User %s authenticated", username)
id := common.FromRequestCtx(r)
id := identity.FromRequestCtx(r)
id.SetUserName(username)
id.SetAuthenticated(true)
id.SetAuthTime(time.Now())
next.ServeHTTP(w, common.AddToRequestCtx(id, r))
next.ServeHTTP(w, identity.AddToRequestCtx(id, r))
return
}
@ -66,7 +66,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
// username or password is wrong, then set a WWW-Authenticate
// header to inform the client that we expect them to use basic
// authentication and send a 401 Unauthorized response.
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
w.Header().Add("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
}
}

View File

@ -1,7 +1,7 @@
package common
package web
import (
"context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/jcmturner/goidentity/v6"
"log"
"net"
@ -9,16 +9,22 @@ import (
"strings"
)
const (
CtxAccessToken = "github.com/bolkedebruin/rdpgw/oidc/access_token"
)
func EnrichContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := FromRequestCtx(r)
if id == nil {
id = NewUser()
id, err := GetSessionIdentity(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if id == nil {
id = identity.NewUser()
if err := SaveSessionIdentity(r, w, id); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
log.Printf("Identity SessionId: %s, UserName: %s: Authenticated: %t",
id.SessionId(), id.UserName(), id.Authenticated())
@ -33,39 +39,30 @@ func EnrichContext(next http.Handler) http.Handler {
if len(ips) > 1 {
proxies = ips[1:]
}
id.SetAttribute(AttrClientIp, clientIp)
id.SetAttribute(AttrProxies, proxies)
id.SetAttribute(identity.AttrClientIp, clientIp)
id.SetAttribute(identity.AttrProxies, proxies)
}
id.SetAttribute(AttrRemoteAddr, r.RemoteAddr)
id.SetAttribute(identity.AttrRemoteAddr, r.RemoteAddr)
if h == "" {
clientIp, _, _ := net.SplitHostPort(r.RemoteAddr)
id.SetAttribute(AttrClientIp, clientIp)
id.SetAttribute(identity.AttrClientIp, clientIp)
}
next.ServeHTTP(w, AddToRequestCtx(id, r))
next.ServeHTTP(w, identity.AddToRequestCtx(id, r))
})
}
func FixKerberosContext(next http.Handler) http.Handler {
func TransposeSPNEGOContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gid := goidentity.FromHTTPRequestContext(r)
if gid != nil {
id := FromRequestCtx(r)
id := identity.FromRequestCtx(r)
id.SetUserName(gid.UserName())
id.SetAuthenticated(gid.Authenticated())
id.SetDomain(gid.Domain())
id.SetAuthTime(gid.AuthTime())
r = AddToRequestCtx(id, r)
r = identity.AddToRequestCtx(id, r)
}
next.ServeHTTP(w, r)
})
}
func GetAccessToken(ctx context.Context) string {
token, ok := ctx.Value(CtxAccessToken).(string)
if !ok {
log.Printf("cannot get access token from context")
return ""
}
return token
}

View File

@ -3,9 +3,8 @@ package web
import (
"encoding/hex"
"encoding/json"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/sessions"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"math/rand"
@ -14,23 +13,20 @@ import (
)
const (
CacheExpiration = time.Minute * 2
CleanupInterval = time.Minute * 5
sessionKeyAuthenticated = "authenticated"
oidcKeyUserName = "preferred_username"
CacheExpiration = time.Minute * 2
CleanupInterval = time.Minute * 5
oidcKeyUserName = "preferred_username"
)
type OIDC struct {
oAuth2Config *oauth2.Config
oidcTokenVerifier *oidc.IDTokenVerifier
stateStore *cache.Cache
sessionStore sessions.Store
}
type OIDCConfig struct {
OAuth2Config *oauth2.Config
OIDCTokenVerifier *oidc.IDTokenVerifier
SessionStore sessions.Store
}
func (c *OIDCConfig) New() *OIDC {
@ -38,7 +34,6 @@ func (c *OIDCConfig) New() *OIDC {
oAuth2Config: c.OAuth2Config,
oidcTokenVerifier: c.OIDCTokenVerifier,
stateStore: cache.New(CacheExpiration, CleanupInterval),
sessionStore: c.SessionStore,
}
}
@ -85,22 +80,13 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
return
}
session, err := h.sessionStore.Get(r, RdpGwSession)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
id := common.FromRequestCtx(r)
id := identity.FromRequestCtx(r)
id.SetUserName(data[oidcKeyUserName].(string))
id.SetAuthenticated(true)
id.SetAuthTime(time.Now())
id.SetAttribute(common.AttrAccessToken, oauth2Token.AccessToken)
id.SetAttribute(identity.AttrAccessToken, oauth2Token.AccessToken)
session.Options.MaxAge = MaxAge
session.Values[common.CTXKey] = id
if err = session.Save(r, w); err != nil {
if err = SaveSessionIdentity(r, w, id); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
@ -109,14 +95,9 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
func (h *OIDC) Authenticated(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := h.sessionStore.Get(r, RdpGwSession)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
id := identity.FromRequestCtx(r)
id := session.Values[common.CTXKey].(common.Identity)
if id == nil {
if !id.Authenticated() {
seed := make([]byte, 16)
rand.Read(seed)
state := hex.EncodeToString(seed)
@ -126,6 +107,6 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
}
// replace the identity with the one from the sessions
next.ServeHTTP(w, common.AddToRequestCtx(id, r))
next.ServeHTTP(w, r)
})
}

31
cmd/rdpgw/web/router.go Normal file
View File

@ -0,0 +1,31 @@
package web
import (
"net/http"
)
type AuthMux struct {
headers []string
}
func NewAuthMux() *AuthMux {
return &AuthMux{}
}
func (a *AuthMux) Register(s string) {
a.headers = append(a.headers, s)
}
func (a *AuthMux) Route(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := r.Header.Get("Authorization")
if h == "" {
for _, s := range a.headers {
w.Header().Add("WWW-Authenticate", s)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
}
next.ServeHTTP(w, r)
})
}

View File

@ -1,30 +1,75 @@
package web
import (
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/gorilla/sessions"
"log"
"net/http"
"os"
)
type SessionManagerConf struct {
SessionKey []byte
SessionEncryptionKey []byte
StoreType string
}
const (
rdpGwSession = "RDPGWSESSION"
MaxAge = 120
identityKey = "RDPGWID"
)
func (c *SessionManagerConf) Init() sessions.Store {
if len(c.SessionKey) < 32 {
var sessionStore sessions.Store
func InitStore(sessionKey []byte, encryptionKey []byte, storeType string) {
if len(sessionKey) < 32 {
log.Fatal("Session key too small")
}
if len(c.SessionEncryptionKey) < 32 {
if len(encryptionKey) < 32 {
log.Fatal("Session key too small")
}
if c.StoreType == "file" {
if storeType == "file" {
log.Println("Filesystem is used as session storage")
return sessions.NewFilesystemStore(os.TempDir(), c.SessionKey, c.SessionEncryptionKey)
sessionStore = sessions.NewFilesystemStore(os.TempDir(), sessionKey, encryptionKey)
} else {
log.Println("Cookies are used as session storage")
return sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey)
sessionStore = sessions.NewCookieStore(sessionKey, encryptionKey)
}
}
func GetSession(r *http.Request) (*sessions.Session, error) {
session, err := sessionStore.Get(r, rdpGwSession)
if err != nil {
return nil, err
}
return session, nil
}
func GetSessionIdentity(r *http.Request) (identity.Identity, error) {
s, err := GetSession(r)
if err != nil {
return nil, err
}
idData := s.Values[identityKey]
if idData == nil {
return nil, nil
}
id := identity.NewUser()
id.Unmarshal(idData.([]byte))
return id, nil
}
func SaveSessionIdentity(r *http.Request, w http.ResponseWriter, id identity.Identity) error {
session, err := GetSession(r)
if err != nil {
return err
}
session.Options.MaxAge = MaxAge
idData, err := id.Marshal()
if err != nil {
return err
}
session.Values[identityKey] = idData
return sessionStore.Save(r, w, session)
}

View File

@ -5,7 +5,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/gorilla/sessions"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"log"
"math/rand"
"net/http"
@ -14,17 +14,11 @@ import (
"time"
)
const (
RdpGwSession = "RDPGWSESSION"
MaxAge = 120
)
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
type QueryInfoFunc func(context.Context, string, string) (string, error)
type Config struct {
SessionStore sessions.Store
PAATokenGenerator TokenGeneratorFunc
UserTokenGenerator UserTokenGeneratorFunc
QueryInfo QueryInfoFunc
@ -46,7 +40,6 @@ type RdpOpts struct {
}
type Handler struct {
sessionStore sessions.Store
paaTokenGenerator TokenGeneratorFunc
enableUserToken bool
userTokenGenerator UserTokenGeneratorFunc
@ -63,7 +56,6 @@ func (c *Config) NewHandler() *Handler {
log.Fatal("Not enough hosts to connect to specified")
}
return &Handler{
sessionStore: c.SessionStore,
paaTokenGenerator: c.PAATokenGenerator,
enableUserToken: c.EnableUserToken,
userTokenGenerator: c.UserTokenGenerator,
@ -132,13 +124,13 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
}
func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
id := identity.FromRequestCtx(r)
ctx := r.Context()
userName, ok := ctx.Value("preferred_username").(string)
opts := h.rdpOpts
if !ok {
log.Printf("preferred_username not found in context")
if !id.Authenticated() {
log.Printf("unauthenticated user %s", id.UserName())
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
return
}
@ -149,13 +141,13 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
host = strings.Replace(host, "{{ preferred_username }}", id.UserName(), 1)
// split the username into user and domain
var user = userName
var user = id.UserName()
var domain = opts.DefaultDomain
if opts.SplitUserDomain {
creds := strings.SplitN(userName, "@", 2)
creds := strings.SplitN(id.UserName(), "@", 2)
user = creds[0]
if len(creds) > 1 {
domain = creds[1]
@ -203,6 +195,8 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
rdp.Connection.GatewayHostname = h.gatewayAddress.Host
rdp.Connection.GatewayCredentialsSource = SourceCookie
rdp.Connection.GatewayAccessToken = token
rdp.Connection.GatewayCredentialMethod = 1
rdp.Connection.GatewayUsageMethod = 1
rdp.Session.NetworkAutodetect = opts.NetworkAutoDetect != 0
rdp.Session.BandwidthAutodetect = opts.BandwidthAutoDetect != 0
rdp.Session.ConnectionType = opts.ConnectionType

View File

@ -2,6 +2,7 @@ package web
import (
"context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"net/http"
"net/http/httptest"
@ -113,9 +114,13 @@ func TestHandler_HandleDownload(t *testing.T) {
}
rr := httptest.NewRecorder()
id := identity.NewUser()
id.SetUserName(testuser)
id.SetAuthenticated(true)
req = identity.AddToRequestCtx(id, req)
ctx := req.Context()
ctx = context.WithValue(ctx, "preferred_username", testuser)
req = req.WithContext(ctx)
u, _ := url.Parse(gateway)
c := Config{

1
go.mod
View File

@ -8,6 +8,7 @@ require (
github.com/fatih/structs v1.1.0
github.com/go-jose/go-jose/v3 v3.0.0
github.com/google/uuid v1.1.2
github.com/gorilla/mux v1.8.0
github.com/gorilla/sessions v1.2.1
github.com/gorilla/websocket v1.5.0
github.com/jcmturner/gofork v1.7.6