Refactor identity and http routing
This commit is contained in:
parent
b42c3cd3cc
commit
db98550455
@ -48,7 +48,7 @@ type ServerConfig struct {
|
||||
SendBuf int `koanf:"sendbuf"`
|
||||
ReceiveBuf int `koanf:"receivebuf"`
|
||||
Tls string `koanf:"tls"`
|
||||
Authentication string `koanf:"authentication"`
|
||||
Authentication []string `koanf:"authentication"`
|
||||
AuthSocket string `koanf:"authsocket"`
|
||||
}
|
||||
|
||||
@ -206,15 +206,15 @@ func Load(configFile string) Configuration {
|
||||
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
|
||||
}
|
||||
|
||||
if Conf.Server.Authentication == "local" && Conf.Server.Tls == "disable" {
|
||||
if Conf.Server.BasicAuthEnabled() && Conf.Server.Tls == "disable" {
|
||||
log.Fatalf("basicauth=local and tls=disable are mutually exclusive")
|
||||
}
|
||||
|
||||
if !Conf.Caps.TokenAuth && Conf.Server.Authentication == "openid" {
|
||||
if !Conf.Caps.TokenAuth && Conf.Server.OpenIDEnabled() {
|
||||
log.Fatalf("openid is configured but tokenauth disabled")
|
||||
}
|
||||
|
||||
if Conf.Server.Authentication == AuthenticationKerberos && Conf.Kerberos.Keytab == "" {
|
||||
if Conf.Server.KerberosEnabled() && Conf.Kerberos.Keytab == "" {
|
||||
log.Fatalf("kerberos is configured but no keytab was specified")
|
||||
}
|
||||
|
||||
@ -226,3 +226,24 @@ func Load(configFile string) Configuration {
|
||||
return Conf
|
||||
|
||||
}
|
||||
|
||||
func (s *ServerConfig) OpenIDEnabled() bool {
|
||||
return s.matchAuth("openid")
|
||||
}
|
||||
|
||||
func (s *ServerConfig) KerberosEnabled() bool {
|
||||
return s.matchAuth("kerberos")
|
||||
}
|
||||
|
||||
func (s *ServerConfig) BasicAuthEnabled() bool {
|
||||
return s.matchAuth("local")
|
||||
}
|
||||
|
||||
func (s *ServerConfig) matchAuth(needle string) bool {
|
||||
for _, q := range s.Authentication {
|
||||
if q == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
57
cmd/rdpgw/identity/identity.go
Normal file
57
cmd/rdpgw/identity/identity.go
Normal file
@ -0,0 +1,57 @@
|
||||
package identity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CTXKey = "github.com/bolkedebruin/rdpgw/common/identity"
|
||||
|
||||
AttrRemoteAddr = "remoteAddr"
|
||||
AttrClientIp = "clientIp"
|
||||
AttrProxies = "proxyAddresses"
|
||||
AttrAccessToken = "accessToken" // todo remove for security reasons
|
||||
)
|
||||
|
||||
type Identity interface {
|
||||
UserName() string
|
||||
SetUserName(string)
|
||||
DisplayName() string
|
||||
SetDisplayName(string)
|
||||
Domain() string
|
||||
SetDomain(string)
|
||||
Authenticated() bool
|
||||
SetAuthenticated(bool)
|
||||
AuthTime() time.Time
|
||||
SetAuthTime(time2 time.Time)
|
||||
SessionId() string
|
||||
SetAttribute(string, interface{})
|
||||
GetAttribute(string) interface{}
|
||||
Attributes() map[string]interface{}
|
||||
DelAttribute(string)
|
||||
Email() string
|
||||
SetEmail(string)
|
||||
Expiry() time.Time
|
||||
SetExpiry(time.Time)
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal([]byte) error
|
||||
}
|
||||
|
||||
func AddToRequestCtx(id Identity, r *http.Request) *http.Request {
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, CTXKey, id)
|
||||
return r.WithContext(ctx)
|
||||
}
|
||||
|
||||
func FromRequestCtx(r *http.Request) Identity {
|
||||
return FromCtx(r.Context())
|
||||
}
|
||||
|
||||
func FromCtx(ctx context.Context) Identity {
|
||||
if id, ok := ctx.Value(CTXKey).(Identity); ok {
|
||||
return id
|
||||
}
|
||||
return nil
|
||||
}
|
||||
28
cmd/rdpgw/identity/identity_test.go
Normal file
28
cmd/rdpgw/identity/identity_test.go
Normal file
@ -0,0 +1,28 @@
|
||||
package identity
|
||||
|
||||
import (
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMarshalling(t *testing.T) {
|
||||
u := NewUser()
|
||||
u.SetUserName("ANAME")
|
||||
u.SetAuthenticated(true)
|
||||
u.SetDomain("DOMAIN")
|
||||
|
||||
c := NewUser()
|
||||
data, err := u.Marshal()
|
||||
if err != nil {
|
||||
log.Fatalf("Cannot marshal %s", err)
|
||||
}
|
||||
|
||||
err = c.Unmarshal(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Error while unmarshalling: %s", err)
|
||||
}
|
||||
|
||||
if u.UserName() != c.UserName() || u.Authenticated() != c.Authenticated() || u.Domain() != c.Domain() {
|
||||
t.Fatalf("identities not equal: %+v != %+v", u, c)
|
||||
}
|
||||
}
|
||||
@ -1,60 +1,12 @@
|
||||
package common
|
||||
package identity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"github.com/google/uuid"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CTXKey = "github.com/bolkedebruin/rdpgw/common/identity"
|
||||
|
||||
AttrRemoteAddr = "remoteAddr"
|
||||
AttrClientIp = "clientIp"
|
||||
AttrProxies = "proxyAddresses"
|
||||
AttrAccessToken = "accessToken" // todo remove for security reasons
|
||||
)
|
||||
|
||||
type Identity interface {
|
||||
UserName() string
|
||||
SetUserName(string)
|
||||
DisplayName() string
|
||||
SetDisplayName(string)
|
||||
Domain() string
|
||||
SetDomain(string)
|
||||
Authenticated() bool
|
||||
SetAuthenticated(bool)
|
||||
AuthTime() time.Time
|
||||
SetAuthTime(time2 time.Time)
|
||||
SessionId() string
|
||||
SetAttribute(string, interface{})
|
||||
GetAttribute(string) interface{}
|
||||
Attributes() map[string]interface{}
|
||||
DelAttribute(string)
|
||||
Email() string
|
||||
SetEmail(string)
|
||||
Expiry() time.Time
|
||||
SetExpiry(time.Time)
|
||||
}
|
||||
|
||||
func AddToRequestCtx(id Identity, r *http.Request) *http.Request {
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, CTXKey, id)
|
||||
return r.WithContext(ctx)
|
||||
}
|
||||
|
||||
func FromRequestCtx(r *http.Request) Identity {
|
||||
return FromCtx(r.Context())
|
||||
}
|
||||
|
||||
func FromCtx(ctx context.Context) Identity {
|
||||
if id, ok := ctx.Value(CTXKey).(Identity); ok {
|
||||
return id
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type User struct {
|
||||
authenticated bool
|
||||
domain string
|
||||
@ -68,6 +20,19 @@ type User struct {
|
||||
groupMembership map[string]bool
|
||||
}
|
||||
|
||||
type user struct {
|
||||
Authenticated bool
|
||||
UserName string
|
||||
Domain string
|
||||
DisplayName string
|
||||
Email string
|
||||
AuthTime time.Time
|
||||
SessionId string
|
||||
Expiry time.Time
|
||||
Attributes map[string]interface{}
|
||||
GroupMembership map[string]bool
|
||||
}
|
||||
|
||||
func NewUser() *User {
|
||||
uuid := uuid.New().String()
|
||||
return &User{
|
||||
@ -158,3 +123,48 @@ func (u *User) Expiry() time.Time {
|
||||
func (u *User) SetExpiry(t time.Time) {
|
||||
u.expiry = t
|
||||
}
|
||||
|
||||
func (u *User) Marshal() ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
enc := gob.NewEncoder(buf)
|
||||
uu := user{
|
||||
Authenticated: u.authenticated,
|
||||
UserName: u.userName,
|
||||
Domain: u.domain,
|
||||
DisplayName: u.displayName,
|
||||
Email: u.email,
|
||||
AuthTime: u.authTime,
|
||||
SessionId: u.sessionId,
|
||||
Expiry: u.expiry,
|
||||
Attributes: u.attributes,
|
||||
GroupMembership: u.groupMembership,
|
||||
}
|
||||
err := enc.Encode(uu)
|
||||
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (u *User) Unmarshal(b []byte) error {
|
||||
buf := bytes.NewBuffer(b)
|
||||
dec := gob.NewDecoder(buf)
|
||||
var uu user
|
||||
err := dec.Decode(&uu)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.sessionId = uu.SessionId
|
||||
u.userName = uu.UserName
|
||||
u.domain = uu.Domain
|
||||
u.displayName = uu.DisplayName
|
||||
u.email = uu.Email
|
||||
u.authenticated = uu.Authenticated
|
||||
u.authTime = uu.AuthTime
|
||||
u.expiry = uu.Expiry
|
||||
u.attributes = uu.Attributes
|
||||
u.groupMembership = uu.GroupMembership
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -7,14 +7,13 @@ import (
|
||||
"github.com/bolkedebruin/gokrb5/v8/keytab"
|
||||
"github.com/bolkedebruin/gokrb5/v8/service"
|
||||
"github.com/bolkedebruin/gokrb5/v8/spnego"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/kdcproxy"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/web"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/thought-machine/go-flags"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
@ -26,13 +25,18 @@ import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
gatewayEndPoint = "/remoteDesktopGateway/"
|
||||
kdcProxyEndPoint = "/KdcProxy"
|
||||
)
|
||||
|
||||
var opts struct {
|
||||
ConfigFile string `short:"c" long:"conf" default:"rdpgw.yaml" description:"config file (yaml)"`
|
||||
}
|
||||
|
||||
var conf config.Configuration
|
||||
|
||||
func initOIDC(callbackUrl *url.URL, store sessions.Store) *web.OIDC {
|
||||
func initOIDC(callbackUrl *url.URL) *web.OIDC {
|
||||
// set oidc config
|
||||
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
|
||||
if err != nil {
|
||||
@ -56,7 +60,6 @@ func initOIDC(callbackUrl *url.URL, store sessions.Store) *web.OIDC {
|
||||
o := web.OIDCConfig{
|
||||
OAuth2Config: &oauthConfig,
|
||||
OIDCTokenVerifier: verifier,
|
||||
SessionStore: store,
|
||||
}
|
||||
|
||||
return o.New()
|
||||
@ -91,19 +94,13 @@ func main() {
|
||||
security.Hosts = conf.Server.Hosts
|
||||
|
||||
// init session store
|
||||
sessionConf := web.SessionManagerConf{
|
||||
SessionKey: []byte(conf.Server.SessionKey),
|
||||
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
|
||||
StoreType: conf.Server.SessionStore,
|
||||
}
|
||||
store := sessionConf.Init()
|
||||
web.InitStore([]byte(conf.Server.SessionKey), []byte(conf.Server.SessionEncryptionKey), conf.Server.SessionStore)
|
||||
|
||||
// configure web backend
|
||||
w := &web.Config{
|
||||
QueryInfo: security.QueryInfo,
|
||||
QueryTokenIssuer: conf.Security.QueryTokenIssuer,
|
||||
EnableUserToken: conf.Security.EnableUserToken,
|
||||
SessionStore: store,
|
||||
Hosts: conf.Server.Hosts,
|
||||
HostSelection: conf.Server.HostSelection,
|
||||
RdpOpts: web.RdpOpts{
|
||||
@ -128,6 +125,7 @@ func main() {
|
||||
log.Printf("Starting remote desktop gateway server")
|
||||
cfg := &tls.Config{}
|
||||
|
||||
// configure tls security
|
||||
if conf.Server.Tls == config.TlsDisable {
|
||||
log.Printf("TLS disabled - rdp gw connections require tls, make sure to have a terminator")
|
||||
} else {
|
||||
@ -174,13 +172,7 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
server := http.Server{
|
||||
Addr: ":" + strconv.Itoa(conf.Server.Port),
|
||||
TLSConfig: cfg,
|
||||
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
|
||||
}
|
||||
|
||||
// create the gateway
|
||||
// gateway confg
|
||||
gw := protocol.Gateway{
|
||||
RedirectFlags: protocol.RedirectFlags{
|
||||
Clipboard: conf.Caps.EnableClipboard,
|
||||
@ -205,31 +197,72 @@ func main() {
|
||||
gw.CheckHost = security.CheckHost
|
||||
}
|
||||
|
||||
if conf.Server.Authentication == config.AuthenticationBasic {
|
||||
h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
|
||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(h.BasicAuth(gw.HandleGatewayProtocol)))
|
||||
} else if conf.Server.Authentication == config.AuthenticationKerberos {
|
||||
r := mux.NewRouter()
|
||||
|
||||
// ensure identity is set in context and get some extra info
|
||||
r.Use(web.EnrichContext)
|
||||
|
||||
// prometheus metrics
|
||||
r.Handle("/metrics", promhttp.Handler())
|
||||
|
||||
// for sso callbacks
|
||||
r.HandleFunc("/tokeninfo", web.TokenInfo)
|
||||
|
||||
// gateway endpoint
|
||||
rdp := r.PathPrefix(gatewayEndPoint).Subrouter()
|
||||
|
||||
// openid
|
||||
if conf.Server.OpenIDEnabled() {
|
||||
log.Printf("enabling openid extended authentication")
|
||||
o := initOIDC(url)
|
||||
r.Handle("/connect", o.Authenticated(http.HandlerFunc(h.HandleDownload)))
|
||||
r.HandleFunc("/callback", o.HandleCallback)
|
||||
|
||||
// only enable un-auth endpoint for openid only config
|
||||
if !conf.Server.KerberosEnabled() || !conf.Server.BasicAuthEnabled() {
|
||||
rdp.Name("gw").HandlerFunc(gw.HandleGatewayProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
// for stacking of authentication
|
||||
auth := web.NewAuthMux()
|
||||
|
||||
// basic auth
|
||||
if conf.Server.BasicAuthEnabled() {
|
||||
log.Printf("enabling basic authentication")
|
||||
q := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
|
||||
rdp.Headers("Authorization", "Basic*").HandlerFunc(q.BasicAuth(gw.HandleGatewayProtocol))
|
||||
auth.Register(`Basic realm="restricted", charset="UTF-8"`)
|
||||
}
|
||||
|
||||
// spnego / kerberos
|
||||
if conf.Server.KerberosEnabled() {
|
||||
log.Printf("enabling kerberos authentication")
|
||||
keytab, err := keytab.Load(conf.Kerberos.Keytab)
|
||||
if err != nil {
|
||||
log.Fatalf("Cannot load keytab: %s", err)
|
||||
}
|
||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(
|
||||
spnego.SPNEGOKRB5Authenticate(
|
||||
common.FixKerberosContext(http.HandlerFunc(gw.HandleGatewayProtocol)),
|
||||
rdp.Headers("Authorization", "Negotiate*").Handler(
|
||||
spnego.SPNEGOKRB5Authenticate(web.TransposeSPNEGOContext(http.HandlerFunc(gw.HandleGatewayProtocol)),
|
||||
keytab,
|
||||
service.Logger(log.Default()))),
|
||||
)
|
||||
service.Logger(log.Default())))
|
||||
|
||||
// kdcproxy
|
||||
k := kdcproxy.InitKdcProxy(conf.Kerberos.Krb5Conf)
|
||||
http.HandleFunc("/KdcProxy", k.Handler)
|
||||
} else {
|
||||
// openid
|
||||
oidc := initOIDC(url, store)
|
||||
http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload))))
|
||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
||||
http.Handle("/callback", common.EnrichContext(http.HandlerFunc(oidc.HandleCallback)))
|
||||
r.HandleFunc(kdcProxyEndPoint, k.Handler).Methods("POST")
|
||||
auth.Register("Negotiate")
|
||||
}
|
||||
|
||||
// allow stacking of authentication
|
||||
rdp.Use(auth.Route)
|
||||
|
||||
// setup server
|
||||
server := http.Server{
|
||||
Addr: ":" + strconv.Itoa(conf.Server.Port),
|
||||
Handler: r,
|
||||
TLSConfig: cfg,
|
||||
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
|
||||
}
|
||||
http.Handle("/metrics", promhttp.Handler())
|
||||
http.HandleFunc("/tokeninfo", web.TokenInfo)
|
||||
|
||||
if conf.Server.Tls == config.TlsDisable {
|
||||
err = server.ListenAndServe()
|
||||
|
||||
@ -3,7 +3,7 @@ package protocol
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
@ -61,14 +61,14 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
||||
var t *Tunnel
|
||||
|
||||
ctx := r.Context()
|
||||
id := common.FromRequestCtx(r)
|
||||
id := identity.FromRequestCtx(r)
|
||||
|
||||
connId := r.Header.Get(rdgConnectionIdKey)
|
||||
x, found := c.Get(connId)
|
||||
if !found {
|
||||
t = &Tunnel{
|
||||
RDGId: connId,
|
||||
RemoteAddr: id.GetAttribute(common.AttrRemoteAddr).(string),
|
||||
RemoteAddr: id.GetAttribute(identity.AttrRemoteAddr).(string),
|
||||
User: id,
|
||||
}
|
||||
} else {
|
||||
@ -183,14 +183,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
|
||||
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
|
||||
log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil)
|
||||
|
||||
id := common.FromRequestCtx(r)
|
||||
id := identity.FromRequestCtx(r)
|
||||
if r.Method == MethodRDGOUT {
|
||||
out, err := transport.NewLegacy(w)
|
||||
if err != nil {
|
||||
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Opening RDGOUT for client %s", id.GetAttribute(common.AttrClientIp))
|
||||
log.Printf("Opening RDGOUT for client %s", id.GetAttribute(identity.AttrClientIp))
|
||||
|
||||
t.transportOut = out
|
||||
out.SendAccept(true)
|
||||
@ -212,13 +212,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
|
||||
t.transportIn = in
|
||||
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
||||
|
||||
log.Printf("Opening RDGIN for client %s", id.GetAttribute(common.AttrClientIp))
|
||||
log.Printf("Opening RDGIN for client %s", id.GetAttribute(identity.AttrClientIp))
|
||||
in.SendAccept(false)
|
||||
|
||||
// read some initial data
|
||||
in.Drain()
|
||||
|
||||
log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(common.AttrClientIp))
|
||||
log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(identity.AttrClientIp))
|
||||
handler := NewProcessor(g, t)
|
||||
RegisterTunnel(t, handler)
|
||||
defer RemoveTunnel(t)
|
||||
|
||||
@ -6,7 +6,7 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
@ -51,7 +51,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||
|
||||
switch pt {
|
||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||
log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(common.AttrClientIp))
|
||||
log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
|
||||
if p.state != SERVER_STATE_INITIALIZED {
|
||||
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
|
||||
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
|
||||
@ -81,7 +81,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||
_, cookie := p.tunnelRequest(pkt)
|
||||
if p.gw.CheckPAACookie != nil {
|
||||
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
|
||||
log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(common.AttrClientIp))
|
||||
log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
|
||||
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||
p.tunnel.Write(msg)
|
||||
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||
"net"
|
||||
"time"
|
||||
@ -27,7 +27,7 @@ type Tunnel struct {
|
||||
// The obtained client ip address
|
||||
RemoteAddr string
|
||||
// User
|
||||
User common.Identity
|
||||
User identity.Identity
|
||||
|
||||
// rwc is the underlying connection to the remote desktop server.
|
||||
// It is of the type *net.TCPConn
|
||||
|
||||
@ -2,7 +2,7 @@ package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
|
||||
"testing"
|
||||
)
|
||||
@ -18,7 +18,7 @@ var (
|
||||
)
|
||||
|
||||
func TestCheckHost(t *testing.T) {
|
||||
info.User = common.NewUser()
|
||||
info.User = identity.NewUser()
|
||||
info.User.SetUserName("MYNAME")
|
||||
|
||||
ctx := context.WithValue(context.Background(), protocol.CtxTunnel, &info)
|
||||
|
||||
@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
@ -46,10 +46,10 @@ func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
|
||||
}
|
||||
|
||||
// use identity from context rather then set by tunnel
|
||||
id := common.FromCtx(ctx)
|
||||
if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(common.AttrClientIp) {
|
||||
id := identity.FromCtx(ctx)
|
||||
if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(identity.AttrClientIp) {
|
||||
log.Printf("Current client ip address %s does not match token client ip %s",
|
||||
id.GetAttribute(common.AttrClientIp), tunnel.RemoteAddr)
|
||||
id.GetAttribute(identity.AttrClientIp), tunnel.RemoteAddr)
|
||||
return false, nil
|
||||
}
|
||||
return next(ctx, host)
|
||||
@ -129,11 +129,11 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
|
||||
Subject: username,
|
||||
}
|
||||
|
||||
id := common.FromCtx(ctx)
|
||||
id := identity.FromCtx(ctx)
|
||||
private := customClaims{
|
||||
RemoteServer: server,
|
||||
ClientIP: id.GetAttribute(common.AttrClientIp).(string),
|
||||
AccessToken: id.GetAttribute(common.AttrAccessToken).(string),
|
||||
ClientIP: id.GetAttribute(identity.AttrClientIp).(string),
|
||||
AccessToken: id.GetAttribute(identity.AttrAccessToken).(string),
|
||||
}
|
||||
|
||||
if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {
|
||||
|
||||
@ -2,7 +2,7 @@ package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"github.com/bolkedebruin/rdpgw/shared/auth"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
@ -53,11 +53,11 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
log.Printf("User %s is not authenticated for this service", username)
|
||||
} else {
|
||||
log.Printf("User %s authenticated", username)
|
||||
id := common.FromRequestCtx(r)
|
||||
id := identity.FromRequestCtx(r)
|
||||
id.SetUserName(username)
|
||||
id.SetAuthenticated(true)
|
||||
id.SetAuthTime(time.Now())
|
||||
next.ServeHTTP(w, common.AddToRequestCtx(id, r))
|
||||
next.ServeHTTP(w, identity.AddToRequestCtx(id, r))
|
||||
return
|
||||
}
|
||||
|
||||
@ -66,7 +66,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
// username or password is wrong, then set a WWW-Authenticate
|
||||
// header to inform the client that we expect them to use basic
|
||||
// authentication and send a 401 Unauthorized response.
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||
w.Header().Add("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
package common
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"github.com/jcmturner/goidentity/v6"
|
||||
"log"
|
||||
"net"
|
||||
@ -9,16 +9,22 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
CtxAccessToken = "github.com/bolkedebruin/rdpgw/oidc/access_token"
|
||||
)
|
||||
|
||||
func EnrichContext(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
id := FromRequestCtx(r)
|
||||
if id == nil {
|
||||
id = NewUser()
|
||||
id, err := GetSessionIdentity(r)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if id == nil {
|
||||
id = identity.NewUser()
|
||||
if err := SaveSessionIdentity(r, w, id); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Identity SessionId: %s, UserName: %s: Authenticated: %t",
|
||||
id.SessionId(), id.UserName(), id.Authenticated())
|
||||
|
||||
@ -33,39 +39,30 @@ func EnrichContext(next http.Handler) http.Handler {
|
||||
if len(ips) > 1 {
|
||||
proxies = ips[1:]
|
||||
}
|
||||
id.SetAttribute(AttrClientIp, clientIp)
|
||||
id.SetAttribute(AttrProxies, proxies)
|
||||
id.SetAttribute(identity.AttrClientIp, clientIp)
|
||||
id.SetAttribute(identity.AttrProxies, proxies)
|
||||
}
|
||||
|
||||
id.SetAttribute(AttrRemoteAddr, r.RemoteAddr)
|
||||
id.SetAttribute(identity.AttrRemoteAddr, r.RemoteAddr)
|
||||
if h == "" {
|
||||
clientIp, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
id.SetAttribute(AttrClientIp, clientIp)
|
||||
id.SetAttribute(identity.AttrClientIp, clientIp)
|
||||
}
|
||||
next.ServeHTTP(w, AddToRequestCtx(id, r))
|
||||
next.ServeHTTP(w, identity.AddToRequestCtx(id, r))
|
||||
})
|
||||
}
|
||||
|
||||
func FixKerberosContext(next http.Handler) http.Handler {
|
||||
func TransposeSPNEGOContext(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gid := goidentity.FromHTTPRequestContext(r)
|
||||
if gid != nil {
|
||||
id := FromRequestCtx(r)
|
||||
id := identity.FromRequestCtx(r)
|
||||
id.SetUserName(gid.UserName())
|
||||
id.SetAuthenticated(gid.Authenticated())
|
||||
id.SetDomain(gid.Domain())
|
||||
id.SetAuthTime(gid.AuthTime())
|
||||
r = AddToRequestCtx(id, r)
|
||||
r = identity.AddToRequestCtx(id, r)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func GetAccessToken(ctx context.Context) string {
|
||||
token, ok := ctx.Value(CtxAccessToken).(string)
|
||||
if !ok {
|
||||
log.Printf("cannot get access token from context")
|
||||
return ""
|
||||
}
|
||||
return token
|
||||
}
|
||||
@ -3,9 +3,8 @@ package web
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/oauth2"
|
||||
"math/rand"
|
||||
@ -14,23 +13,20 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
CacheExpiration = time.Minute * 2
|
||||
CleanupInterval = time.Minute * 5
|
||||
sessionKeyAuthenticated = "authenticated"
|
||||
oidcKeyUserName = "preferred_username"
|
||||
CacheExpiration = time.Minute * 2
|
||||
CleanupInterval = time.Minute * 5
|
||||
oidcKeyUserName = "preferred_username"
|
||||
)
|
||||
|
||||
type OIDC struct {
|
||||
oAuth2Config *oauth2.Config
|
||||
oidcTokenVerifier *oidc.IDTokenVerifier
|
||||
stateStore *cache.Cache
|
||||
sessionStore sessions.Store
|
||||
}
|
||||
|
||||
type OIDCConfig struct {
|
||||
OAuth2Config *oauth2.Config
|
||||
OIDCTokenVerifier *oidc.IDTokenVerifier
|
||||
SessionStore sessions.Store
|
||||
}
|
||||
|
||||
func (c *OIDCConfig) New() *OIDC {
|
||||
@ -38,7 +34,6 @@ func (c *OIDCConfig) New() *OIDC {
|
||||
oAuth2Config: c.OAuth2Config,
|
||||
oidcTokenVerifier: c.OIDCTokenVerifier,
|
||||
stateStore: cache.New(CacheExpiration, CleanupInterval),
|
||||
sessionStore: c.SessionStore,
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,22 +80,13 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
session, err := h.sessionStore.Get(r, RdpGwSession)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
id := common.FromRequestCtx(r)
|
||||
id := identity.FromRequestCtx(r)
|
||||
id.SetUserName(data[oidcKeyUserName].(string))
|
||||
id.SetAuthenticated(true)
|
||||
id.SetAuthTime(time.Now())
|
||||
id.SetAttribute(common.AttrAccessToken, oauth2Token.AccessToken)
|
||||
id.SetAttribute(identity.AttrAccessToken, oauth2Token.AccessToken)
|
||||
|
||||
session.Options.MaxAge = MaxAge
|
||||
session.Values[common.CTXKey] = id
|
||||
|
||||
if err = session.Save(r, w); err != nil {
|
||||
if err = SaveSessionIdentity(r, w, id); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
@ -109,14 +95,9 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (h *OIDC) Authenticated(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := h.sessionStore.Get(r, RdpGwSession)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
id := identity.FromRequestCtx(r)
|
||||
|
||||
id := session.Values[common.CTXKey].(common.Identity)
|
||||
if id == nil {
|
||||
if !id.Authenticated() {
|
||||
seed := make([]byte, 16)
|
||||
rand.Read(seed)
|
||||
state := hex.EncodeToString(seed)
|
||||
@ -126,6 +107,6 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// replace the identity with the one from the sessions
|
||||
next.ServeHTTP(w, common.AddToRequestCtx(id, r))
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
31
cmd/rdpgw/web/router.go
Normal file
31
cmd/rdpgw/web/router.go
Normal file
@ -0,0 +1,31 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type AuthMux struct {
|
||||
headers []string
|
||||
}
|
||||
|
||||
func NewAuthMux() *AuthMux {
|
||||
return &AuthMux{}
|
||||
}
|
||||
|
||||
func (a *AuthMux) Register(s string) {
|
||||
a.headers = append(a.headers, s)
|
||||
}
|
||||
|
||||
func (a *AuthMux) Route(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := r.Header.Get("Authorization")
|
||||
if h == "" {
|
||||
for _, s := range a.headers {
|
||||
w.Header().Add("WWW-Authenticate", s)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@ -1,30 +1,75 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"github.com/gorilla/sessions"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
type SessionManagerConf struct {
|
||||
SessionKey []byte
|
||||
SessionEncryptionKey []byte
|
||||
StoreType string
|
||||
}
|
||||
const (
|
||||
rdpGwSession = "RDPGWSESSION"
|
||||
MaxAge = 120
|
||||
identityKey = "RDPGWID"
|
||||
)
|
||||
|
||||
func (c *SessionManagerConf) Init() sessions.Store {
|
||||
if len(c.SessionKey) < 32 {
|
||||
var sessionStore sessions.Store
|
||||
|
||||
func InitStore(sessionKey []byte, encryptionKey []byte, storeType string) {
|
||||
if len(sessionKey) < 32 {
|
||||
log.Fatal("Session key too small")
|
||||
}
|
||||
if len(c.SessionEncryptionKey) < 32 {
|
||||
if len(encryptionKey) < 32 {
|
||||
log.Fatal("Session key too small")
|
||||
}
|
||||
|
||||
if c.StoreType == "file" {
|
||||
if storeType == "file" {
|
||||
log.Println("Filesystem is used as session storage")
|
||||
return sessions.NewFilesystemStore(os.TempDir(), c.SessionKey, c.SessionEncryptionKey)
|
||||
sessionStore = sessions.NewFilesystemStore(os.TempDir(), sessionKey, encryptionKey)
|
||||
} else {
|
||||
log.Println("Cookies are used as session storage")
|
||||
return sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey)
|
||||
sessionStore = sessions.NewCookieStore(sessionKey, encryptionKey)
|
||||
}
|
||||
}
|
||||
|
||||
func GetSession(r *http.Request) (*sessions.Session, error) {
|
||||
session, err := sessionStore.Get(r, rdpGwSession)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func GetSessionIdentity(r *http.Request) (identity.Identity, error) {
|
||||
s, err := GetSession(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idData := s.Values[identityKey]
|
||||
if idData == nil {
|
||||
return nil, nil
|
||||
|
||||
}
|
||||
id := identity.NewUser()
|
||||
id.Unmarshal(idData.([]byte))
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func SaveSessionIdentity(r *http.Request, w http.ResponseWriter, id identity.Identity) error {
|
||||
session, err := GetSession(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
session.Options.MaxAge = MaxAge
|
||||
|
||||
idData, err := id.Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
session.Values[identityKey] = idData
|
||||
|
||||
return sessionStore.Save(r, w, session)
|
||||
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
@ -14,17 +14,11 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
RdpGwSession = "RDPGWSESSION"
|
||||
MaxAge = 120
|
||||
)
|
||||
|
||||
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
|
||||
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
|
||||
type QueryInfoFunc func(context.Context, string, string) (string, error)
|
||||
|
||||
type Config struct {
|
||||
SessionStore sessions.Store
|
||||
PAATokenGenerator TokenGeneratorFunc
|
||||
UserTokenGenerator UserTokenGeneratorFunc
|
||||
QueryInfo QueryInfoFunc
|
||||
@ -46,7 +40,6 @@ type RdpOpts struct {
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
sessionStore sessions.Store
|
||||
paaTokenGenerator TokenGeneratorFunc
|
||||
enableUserToken bool
|
||||
userTokenGenerator UserTokenGeneratorFunc
|
||||
@ -63,7 +56,6 @@ func (c *Config) NewHandler() *Handler {
|
||||
log.Fatal("Not enough hosts to connect to specified")
|
||||
}
|
||||
return &Handler{
|
||||
sessionStore: c.SessionStore,
|
||||
paaTokenGenerator: c.PAATokenGenerator,
|
||||
enableUserToken: c.EnableUserToken,
|
||||
userTokenGenerator: c.UserTokenGenerator,
|
||||
@ -132,13 +124,13 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
|
||||
}
|
||||
|
||||
func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
id := identity.FromRequestCtx(r)
|
||||
ctx := r.Context()
|
||||
userName, ok := ctx.Value("preferred_username").(string)
|
||||
|
||||
opts := h.rdpOpts
|
||||
|
||||
if !ok {
|
||||
log.Printf("preferred_username not found in context")
|
||||
if !id.Authenticated() {
|
||||
log.Printf("unauthenticated user %s", id.UserName())
|
||||
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@ -149,13 +141,13 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
|
||||
host = strings.Replace(host, "{{ preferred_username }}", id.UserName(), 1)
|
||||
|
||||
// split the username into user and domain
|
||||
var user = userName
|
||||
var user = id.UserName()
|
||||
var domain = opts.DefaultDomain
|
||||
if opts.SplitUserDomain {
|
||||
creds := strings.SplitN(userName, "@", 2)
|
||||
creds := strings.SplitN(id.UserName(), "@", 2)
|
||||
user = creds[0]
|
||||
if len(creds) > 1 {
|
||||
domain = creds[1]
|
||||
@ -203,6 +195,8 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
rdp.Connection.GatewayHostname = h.gatewayAddress.Host
|
||||
rdp.Connection.GatewayCredentialsSource = SourceCookie
|
||||
rdp.Connection.GatewayAccessToken = token
|
||||
rdp.Connection.GatewayCredentialMethod = 1
|
||||
rdp.Connection.GatewayUsageMethod = 1
|
||||
rdp.Session.NetworkAutodetect = opts.NetworkAutoDetect != 0
|
||||
rdp.Session.BandwidthAutodetect = opts.BandwidthAutoDetect != 0
|
||||
rdp.Session.ConnectionType = opts.ConnectionType
|
||||
|
||||
@ -2,6 +2,7 @@ package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -113,9 +114,13 @@ func TestHandler_HandleDownload(t *testing.T) {
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
id := identity.NewUser()
|
||||
|
||||
id.SetUserName(testuser)
|
||||
id.SetAuthenticated(true)
|
||||
|
||||
req = identity.AddToRequestCtx(id, req)
|
||||
ctx := req.Context()
|
||||
ctx = context.WithValue(ctx, "preferred_username", testuser)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
u, _ := url.Parse(gateway)
|
||||
c := Config{
|
||||
|
||||
Loading…
Reference in New Issue
Block a user