Browse Source

Refactoring VerifyToken and its underlying functions

master
levotrea 6 years ago
committed by GitHub
parent
commit
533900cd23
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 6
      Gopkg.lock
  2. 2
      Gopkg.toml
  3. 11
      api/keycloak_client.go
  4. 7
      integration/integration_test.go
  5. 28
      toolbox/issuer.go
  6. 13
      toolbox/issuer_test.go

6
Gopkg.lock

@ -2,12 +2,12 @@
[[projects]] [[projects]]
digest = "1:fcbbb02d897902b3460476fcc6723f21a1674c0b6bd4a5579de11d53983d3d70" digest = "1:2c7abdf1364ef3db046635ca5bb2cc26838ccb45a2b6608849e7b7fc6493423d"
name = "github.com/cloudtrust/common-service" name = "github.com/cloudtrust/common-service"
packages = ["errors"] packages = ["errors"]
pruneopts = "" pruneopts = ""
revision = "0dac96e146315562f548272f90e3f0ccf9ea7ddb" revision = "c9a387c12b76cd979ddf8be40168f3e19f643184"
version = "v2.2.0" version = "v2.2.1"
[[projects]] [[projects]]
digest = "1:bb7f91ab4d1c44a3bb2651c613463c134165bda0282fca891a63b88d1b501997" digest = "1:bb7f91ab4d1c44a3bb2651c613463c134165bda0282fca891a63b88d1b501997"

2
Gopkg.toml

@ -22,7 +22,7 @@
[[constraint]] [[constraint]]
name = "github.com/cloudtrust/common-service" name = "github.com/cloudtrust/common-service"
version = "v2.2.0" version = "2.2.1"
[[constraint]] [[constraint]]
name = "github.com/pkg/errors" name = "github.com/pkg/errors"

11
api/keycloak_client.go

@ -1,7 +1,6 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"regexp" "regexp"
@ -35,11 +34,11 @@ type AccountClient struct {
} }
// New returns a keycloak client. // New returns a keycloak client.
func New(config keycloak.Config, keyContextIssuerDomain interface{}) (*Client, error) { func New(config keycloak.Config) (*Client, error) {
var issuerMgr toolbox.IssuerManager var issuerMgr toolbox.IssuerManager
{ {
var err error var err error
issuerMgr, err = toolbox.NewIssuerManager(config, keyContextIssuerDomain) issuerMgr, err = toolbox.NewIssuerManager(config)
if err != nil { if err != nil {
return nil, errors.Wrap(err, keycloak.MsgErrCannotParse+"."+keycloak.TokenProviderURL) return nil, errors.Wrap(err, keycloak.MsgErrCannotParse+"."+keycloak.TokenProviderURL)
} }
@ -120,13 +119,13 @@ func (c *Client) GetToken(realm string, username string, password string) (strin
} }
// VerifyToken verifies a token. It returns an error it is malformed, expired,... // VerifyToken verifies a token. It returns an error it is malformed, expired,...
func (c *Client) VerifyToken(ctx context.Context, realmName string, accessToken string) error { func (c *Client) VerifyToken(issuer string, realmName string, accessToken string) error {
issuer, err := c.issuerManager.GetIssuer(ctx) oidcVerifierProvider, err := c.issuerManager.GetOidcVerifierProvider(issuer)
if err != nil { if err != nil {
return err return err
} }
verifier, err := issuer.GetOidcVerifier(realmName) verifier, err := oidcVerifierProvider.GetOidcVerifier(realmName)
if err != nil { if err != nil {
return err return err
} }

7
integration/integration_test.go

@ -1,7 +1,6 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"strings" "strings"
@ -18,13 +17,11 @@ const (
tstRealm = "__internal" tstRealm = "__internal"
reqRealm = "master" reqRealm = "master"
user = "version" user = "version"
keyContextIssuerDomain keyContext = iota
) )
func main() { func main() {
var conf = getKeycloakConfig() var conf = getKeycloakConfig()
var client, err = api.New(*conf, keyContextIssuerDomain) var client, err = api.New(*conf)
if err != nil { if err != nil {
log.Fatalf("could not create keycloak client: %v", err) log.Fatalf("could not create keycloak client: %v", err)
} }
@ -35,7 +32,7 @@ func main() {
log.Fatalf("could not get access token: %v", err) log.Fatalf("could not get access token: %v", err)
} }
err = client.VerifyToken(context.Background(), "master", accessToken) err = client.VerifyToken("issuer", "master", accessToken)
if err != nil { if err != nil {
log.Fatalf("could not validate access token: %v", err) log.Fatalf("could not validate access token: %v", err)
} }

28
toolbox/issuer.go

@ -1,7 +1,6 @@
package toolbox package toolbox
import ( import (
"context"
"errors" "errors"
"net/url" "net/url"
"regexp" "regexp"
@ -13,12 +12,11 @@ import (
// IssuerManager provides URL according to a given context // IssuerManager provides URL according to a given context
type IssuerManager interface { type IssuerManager interface {
GetIssuer(ctx context.Context) (OidcVerifierProvider, error) GetOidcVerifierProvider(issuer string) (OidcVerifierProvider, error)
} }
type issuerManager struct { type issuerManager struct {
domainToIssuer map[string]OidcVerifierProvider domainToVerifier map[string]OidcVerifierProvider
keyContextIssuerDomain interface{}
} }
func getProtocolAndDomain(URL string) string { func getProtocolAndDomain(URL string) string {
@ -32,7 +30,7 @@ func getProtocolAndDomain(URL string) string {
} }
// NewIssuerManager creates a new URLProvider // NewIssuerManager creates a new URLProvider
func NewIssuerManager(config keycloak.Config, keyContextIssuerDomain interface{}) (IssuerManager, error) { func NewIssuerManager(config keycloak.Config) (IssuerManager, error) {
URLs := config.AddrTokenProvider URLs := config.AddrTokenProvider
// Use default values when clients are not initializing these values // Use default values when clients are not initializing these values
cacheTTL := config.CacheTTL cacheTTL := config.CacheTTL
@ -44,29 +42,25 @@ func NewIssuerManager(config keycloak.Config, keyContextIssuerDomain interface{}
errTolerance = time.Minute errTolerance = time.Minute
} }
var domainToIssuer = make(map[string]OidcVerifierProvider) var domainToVerifier = make(map[string]OidcVerifierProvider)
for _, value := range strings.Split(URLs, " ") { for _, value := range strings.Split(URLs, " ") {
uToken, err := url.Parse(value) uToken, err := url.Parse(value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
issuer := NewVerifierCache(uToken, cacheTTL, errTolerance) verifier := NewVerifierCache(uToken, cacheTTL, errTolerance)
domainToIssuer[getProtocolAndDomain(value)] = issuer domainToVerifier[getProtocolAndDomain(value)] = verifier
} }
return &issuerManager{ return &issuerManager{
domainToIssuer: domainToIssuer, domainToVerifier: domainToVerifier,
keyContextIssuerDomain: keyContextIssuerDomain,
}, nil }, nil
} }
func (im *issuerManager) GetIssuer(ctx context.Context) (OidcVerifierProvider, error) { func (im *issuerManager) GetOidcVerifierProvider(issuer string) (OidcVerifierProvider, error) {
if rawValue := ctx.Value(im.keyContextIssuerDomain); rawValue != nil { issuerDomain := getProtocolAndDomain(issuer)
// The issuer domain has been found in the context if verifier, ok := im.domainToVerifier[issuerDomain]; ok {
issuerDomain := getProtocolAndDomain(rawValue.(string)) return verifier, nil
if issuer, ok := im.domainToIssuer[issuerDomain]; ok {
return issuer, nil
}
} }
return nil, errors.New("Unknown issuer") return nil, errors.New("Unknown issuer")
} }

13
toolbox/issuer_test.go

@ -1,7 +1,6 @@
package toolbox package toolbox
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
@ -23,7 +22,7 @@ func TestGetProtocolAndDomain(t *testing.T) {
func TestNewIssuerManager(t *testing.T) { func TestNewIssuerManager(t *testing.T) {
t.Run("Invalid URL", func(t *testing.T) { t.Run("Invalid URL", func(t *testing.T) {
_, err := NewIssuerManager(keycloak.Config{AddrTokenProvider: ":"}, keyContextIssuerDomain) _, err := NewIssuerManager(keycloak.Config{AddrTokenProvider: ":"})
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
@ -32,18 +31,18 @@ func TestNewIssuerManager(t *testing.T) {
otherDomainPath := "http://other.domain.com:2120/" otherDomainPath := "http://other.domain.com:2120/"
allDomains := fmt.Sprintf("%s %s %s", defaultPath, myDomainPath, otherDomainPath) allDomains := fmt.Sprintf("%s %s %s", defaultPath, myDomainPath, otherDomainPath)
prov, err := NewIssuerManager(keycloak.Config{AddrTokenProvider: allDomains}, keyContextIssuerDomain) prov, err := NewIssuerManager(keycloak.Config{AddrTokenProvider: allDomains})
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, prov) assert.NotNil(t, prov)
// No issuer provided with context // No issuer provided with context
issuerNoContext, _ := prov.GetIssuer(context.Background()) issuerNoContext, _ := prov.GetOidcVerifierProvider("")
// Unrecognized issuer provided in context // Unrecognized issuer provided in context
issuerDefault, _ := prov.GetIssuer(context.WithValue(context.Background(), keyContextIssuerDomain, "http://unknown.issuer.com/one/path")) issuerDefault, _ := prov.GetOidcVerifierProvider("http://unknown.issuer.com/one/path")
// Case insensitive // Case insensitive
issuerMyDomain, _ := prov.GetIssuer(context.WithValue(context.Background(), keyContextIssuerDomain, "http://MY.DOMAIN.COM/issuer")) issuerMyDomain, _ := prov.GetOidcVerifierProvider("http://MY.DOMAIN.COM/issuer")
// Other domain // Other domain
issuerOtherDomain, _ := prov.GetIssuer(context.WithValue(context.Background(), keyContextIssuerDomain, "http://other.domain.com:2120/any/thing/here")) issuerOtherDomain, _ := prov.GetOidcVerifierProvider("http://other.domain.com:2120/any/thing/here")
assert.Equal(t, issuerNoContext, issuerDefault) assert.Equal(t, issuerNoContext, issuerDefault)
assert.NotEqual(t, issuerNoContext, issuerMyDomain) assert.NotEqual(t, issuerNoContext, issuerMyDomain)

Loading…
Cancel
Save