committed by
rpo
5 changed files with 311 additions and 35 deletions
@ -0,0 +1,86 @@ |
|||||
|
package keycloak |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"fmt" |
||||
|
"net/url" |
||||
|
"time" |
||||
|
|
||||
|
oidc "github.com/coreos/go-oidc" |
||||
|
"github.com/pkg/errors" |
||||
|
) |
||||
|
|
||||
|
// OidcVerifierProvider is an interface for a provider of OidcVerifier instances
|
||||
|
type OidcVerifierProvider interface { |
||||
|
GetOidcVerifier(realm string) (OidcVerifier, error) |
||||
|
} |
||||
|
|
||||
|
// OidcVerifier is an interface for OIDC token verifiers
|
||||
|
type OidcVerifier interface { |
||||
|
Verify(accessToken string) error |
||||
|
} |
||||
|
|
||||
|
type verifierCache struct { |
||||
|
duration time.Duration |
||||
|
errorTolerance time.Duration |
||||
|
tokenProviderURL *url.URL |
||||
|
verifiers map[string]cachedVerifier |
||||
|
} |
||||
|
|
||||
|
type cachedVerifier struct { |
||||
|
verifier *oidc.IDTokenVerifier |
||||
|
createdAt time.Time |
||||
|
expireAt time.Time |
||||
|
invalidateOnErrorAt time.Time |
||||
|
} |
||||
|
|
||||
|
// NewVerifierCache create an instance of OIDC verifier cache
|
||||
|
func NewVerifierCache(tokenProviderURL *url.URL, timeToLive time.Duration, errorTolerance time.Duration) OidcVerifierProvider { |
||||
|
return &verifierCache{ |
||||
|
duration: timeToLive, |
||||
|
errorTolerance: errorTolerance, |
||||
|
tokenProviderURL: tokenProviderURL, |
||||
|
verifiers: make(map[string]cachedVerifier), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (vc *verifierCache) GetOidcVerifier(realm string) (OidcVerifier, error) { |
||||
|
v, ok := vc.verifiers[realm] |
||||
|
if ok && v.isValid() { |
||||
|
return &v, nil |
||||
|
} |
||||
|
var oidcProvider *oidc.Provider |
||||
|
{ |
||||
|
var err error |
||||
|
var issuer = fmt.Sprintf("%s/auth/realms/%s", vc.tokenProviderURL.String(), realm) |
||||
|
oidcProvider, err = oidc.NewProvider(context.Background(), issuer) |
||||
|
if err != nil { |
||||
|
return nil, errors.Wrap(err, "could not create oidc provider") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
ov := oidcProvider.Verifier(&oidc.Config{SkipClientIDCheck: true}) |
||||
|
res := cachedVerifier{ |
||||
|
createdAt: time.Now(), |
||||
|
expireAt: time.Now().Add(vc.duration), |
||||
|
invalidateOnErrorAt: time.Now().Add(vc.errorTolerance), |
||||
|
verifier: ov, |
||||
|
} |
||||
|
vc.verifiers[realm] = res |
||||
|
|
||||
|
return &res, nil |
||||
|
} |
||||
|
|
||||
|
func (cv *cachedVerifier) isValid() bool { |
||||
|
return time.Now().Before(cv.expireAt) |
||||
|
} |
||||
|
|
||||
|
func (cv *cachedVerifier) Verify(accessToken string) error { |
||||
|
_, err := cv.verifier.Verify(context.Background(), accessToken) |
||||
|
if err != nil && time.Now().After(cv.invalidateOnErrorAt) { |
||||
|
// An error occured and current time is after invalidateOnErrorAt
|
||||
|
// Let's make this verifier expire
|
||||
|
cv.expireAt = cv.createdAt |
||||
|
} |
||||
|
return err |
||||
|
} |
||||
@ -0,0 +1,114 @@ |
|||||
|
package keycloak |
||||
|
|
||||
|
//go:generate mockgen -destination=./mock/authmanager.go -package=mock -mock_names=AuthorizationManager=AuthorizationManager github.com/cloudtrust/common-service/security AuthorizationManager
|
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"encoding/json" |
||||
|
"net/http" |
||||
|
"net/http/httptest" |
||||
|
"net/url" |
||||
|
"strings" |
||||
|
"testing" |
||||
|
"time" |
||||
|
|
||||
|
http_transport "github.com/go-kit/kit/transport/http" |
||||
|
"github.com/gorilla/mux" |
||||
|
"github.com/stretchr/testify/assert" |
||||
|
) |
||||
|
|
||||
|
func decodeRequest(_ context.Context, req *http.Request) (interface{}, error) { |
||||
|
res := map[string]string{"realm": mux.Vars(req)["realm"], "host": req.Host} |
||||
|
return res, nil |
||||
|
} |
||||
|
|
||||
|
func encodeReply(_ context.Context, w http.ResponseWriter, rep interface{}) error { |
||||
|
if rep == nil { |
||||
|
w.WriteHeader(404) |
||||
|
return nil |
||||
|
} |
||||
|
w.Header().Set("Content-Type", "application/json; charset=utf-8") |
||||
|
w.WriteHeader(200) |
||||
|
|
||||
|
var json, err = json.Marshal(rep) |
||||
|
if err == nil { |
||||
|
w.Write(json) |
||||
|
} |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func errorHandler(_ context.Context, _ error, w http.ResponseWriter) { |
||||
|
w.WriteHeader(500) |
||||
|
} |
||||
|
|
||||
|
func endpoint(_ context.Context, request interface{}) (response interface{}, err error) { |
||||
|
var query = request.(map[string]string) |
||||
|
if !strings.Contains(query["realm"], "realm") { |
||||
|
return nil, nil |
||||
|
} |
||||
|
return map[string]string{ |
||||
|
"issuer": "http://" + query["host"] + "/auth/realms/" + query["realm"], |
||||
|
"authorization_endpoint": "", |
||||
|
"token_endpoint": "", |
||||
|
"jwks_uri": "", |
||||
|
"userinfo_endpoint": "", |
||||
|
}, nil |
||||
|
} |
||||
|
|
||||
|
func TestGetOidcVerifier(t *testing.T) { |
||||
|
verifierHandler := http_transport.NewServer(endpoint, decodeRequest, encodeReply, http_transport.ServerErrorEncoder(errorHandler)) |
||||
|
|
||||
|
r := mux.NewRouter() |
||||
|
r.Handle("/auth/realms/{realm}/.well-known/openid-configuration", verifierHandler) |
||||
|
|
||||
|
ts := httptest.NewServer(r) |
||||
|
defer ts.Close() |
||||
|
|
||||
|
url, _ := url.Parse(ts.URL) |
||||
|
|
||||
|
{ |
||||
|
// First test with a verifier which hardly expires
|
||||
|
verifier := NewVerifierCache(url, time.Minute, 10*time.Minute) |
||||
|
|
||||
|
{ |
||||
|
// Unknown realm: can't get verifier
|
||||
|
_, err := verifier.GetOidcVerifier("unknown") |
||||
|
assert.NotNil(t, err) |
||||
|
} |
||||
|
|
||||
|
v1, e := verifier.GetOidcVerifier("realm1") |
||||
|
assert.Nil(t, e) |
||||
|
{ |
||||
|
// Ask for the same realm before its verifier expires
|
||||
|
v2, _ := verifier.GetOidcVerifier("realm1") |
||||
|
assert.Equal(t, v1, v2) |
||||
|
} |
||||
|
{ |
||||
|
// Ask for a different verifier
|
||||
|
v3, _ := verifier.GetOidcVerifier("realm2") |
||||
|
assert.NotEqual(t, v1, v3) |
||||
|
} |
||||
|
|
||||
|
time.Sleep(100 * time.Millisecond) |
||||
|
assert.NotNil(t, v1.Verify("abcdef")) |
||||
|
} |
||||
|
|
||||
|
{ |
||||
|
// Now, test with a verifier which quickly expires on error
|
||||
|
verifier := NewVerifierCache(url, time.Minute, time.Millisecond) |
||||
|
v1, _ := verifier.GetOidcVerifier("realm1") |
||||
|
time.Sleep(100 * time.Millisecond) |
||||
|
{ |
||||
|
// Ask for the same realm before its verifier expires
|
||||
|
v2, _ := verifier.GetOidcVerifier("realm1") |
||||
|
assert.Equal(t, v1, v2) |
||||
|
} |
||||
|
{ |
||||
|
// Verify an invalid token
|
||||
|
assert.NotNil(t, v1.Verify("abcdef")) |
||||
|
// Ask for the same realm before its verifier expires but after an error occured
|
||||
|
v2, _ := verifier.GetOidcVerifier("realm1") |
||||
|
assert.NotEqual(t, v1, v2) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
Loading…
Reference in new issue