diff --git a/Gopkg.lock b/Gopkg.lock index 3e683f2..a7ba409 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -2,12 +2,20 @@ [[projects]] - digest = "1:d8ee1b165eb7f4fd9ada718e1e7eeb0bc1fd462592d0bd823df694443f448681" + digest = "1:379d34d9efc755fab444199f007819fe99718640f9ccfbdd3f0430340bb02b07" name = "github.com/coreos/go-oidc" packages = ["."] pruneopts = "" - revision = "1180514eaf4d9f38d0d19eef639a1d695e066e72" - version = "v2.0.0" + revision = "2be1c5b8a260760503f66dc0996e102b683b3ac3" + version = "v2.1.0" + +[[projects]] + digest = "1:0deddd908b6b4b768cfc272c16ee61e7088a60f7fe2f06c547bd3d8e1f8b8e77" + name = "github.com/davecgh/go-spew" + packages = ["spew"] + pruneopts = "" + revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" + version = "v1.1.1" [[projects]] digest = "1:0c07a9cb3d3c845439a4fcaae6c8bdd0e7727cbbd3acf1e032e5d4a2dc132306" @@ -18,12 +26,49 @@ version = "v2.0.0" [[projects]] - digest = "1:529d738b7976c3848cae5cf3a8036440166835e389c1f617af701eeb12a0518d" + digest = "1:183b1cb81b770d8033281c5629a4847a2ed7614068bb33c5a9a159d1226b23f0" + name = "github.com/go-kit/kit" + packages = [ + "endpoint", + "log", + "transport", + "transport/http", + ] + pruneopts = "" + revision = "150a65a7ec6156b4b640c1fd55f26fd3d475d656" + version = "v0.9.0" + +[[projects]] + digest = "1:df89444601379b2e1ee82bf8e6b72af9901cbeed4b469fa380a519c89c339310" + name = "github.com/go-logfmt/logfmt" + packages = ["."] + pruneopts = "" + revision = "07c9b44f60d7ffdfb7d8efe1ad539965737836dc" + version = "v0.4.0" + +[[projects]] + digest = "1:b852d2b62be24e445fcdbad9ce3015b44c207815d631230dfce3f14e7803f5bf" name = "github.com/golang/protobuf" packages = ["proto"] pruneopts = "" - revision = "b5d812f8a3706043e23a9cd5babf2e5423744d30" - version = "v1.3.1" + revision = "6c65a5562fc06764971b7c5d05c76c75e84bdbf7" + version = "v1.3.2" + +[[projects]] + digest = "1:883e2fdbdd0e577187bd8106fec775b1176059af267a7f40eba5308955c67d52" + name = "github.com/gorilla/mux" + packages = ["."] + pruneopts = "" + revision = "00bdffe0f3c77e27d2cf6f5c70232a2d3e4d9c15" + version = "v1.7.3" + +[[projects]] + branch = "master" + digest = "1:1ed9eeebdf24aadfbca57eb50e6455bd1d2474525e0f0d4454de8c8e9bc7ee9a" + name = "github.com/kr/logfmt" + packages = ["."] + pruneopts = "" + revision = "b84e30acd515aadc4b783ad4ff83aff3299bdfe0" [[projects]] digest = "1:1d7e1867c49a6dd9856598ef7c3123604ea3daabf5b83f303ff457bcbc410b1d" @@ -33,6 +78,14 @@ revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4" version = "v0.8.1" +[[projects]] + digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + pruneopts = "" + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + [[projects]] branch = "master" digest = "1:de5481dda0c081b66450e391bbb1a5c4435b13e3c0bbf0133ba1a5baeda7b7af" @@ -52,9 +105,17 @@ revision = "298182f68c66c05229eb03ac171abe6e309ee79a" version = "v1.0.3" +[[projects]] + digest = "1:f7b541897bcde05a04a044c342ddc7425aab7e331f37b47fbb486cd16324b48e" + name = "github.com/stretchr/testify" + packages = ["assert"] + pruneopts = "" + revision = "221dbe5ed46703ee255b1da0dec05086f5035f62" + version = "v1.4.0" + [[projects]] branch = "master" - digest = "1:086760278d762dbb0e9a26e09b57f04c89178c86467d8d94fae47d64c222f328" + digest = "1:a530f8e0c0ee8a3b440f9f0b0e9f4e5d5e47cfe3a581086ce32cd8ba114ddf4f" name = "golang.org/x/crypto" packages = [ "ed25519", @@ -62,11 +123,11 @@ "pbkdf2", ] pruneopts = "" - revision = "4def268fd1a49955bfb3dda92fe3db4f924f2285" + revision = "9756ffdc24725223350eb3266ffb92590d28f278" [[projects]] branch = "master" - digest = "1:31cd6e3c114e17c5f0c9e8b0bcaa3025ab3c221ce36323c7ce1acaa753d0d0aa" + digest = "1:87c06c289123bf8be0a776c57ca40ce075f6c598a905ff2ff8ba40fba0d5d17c" name = "golang.org/x/net" packages = [ "context", @@ -75,7 +136,7 @@ "publicsuffix", ] pruneopts = "" - revision = "da137c7871d730100384dbcf36e6f8fa493aef5b" + revision = "ba9fcec4b297b415637633c5a6e8fa592e4a16c3" [[projects]] branch = "master" @@ -114,7 +175,7 @@ version = "v0.3.2" [[projects]] - digest = "1:47f391ee443f578f01168347818cb234ed819521e49e4d2c8dd2fb80d48ee41a" + digest = "1:0568e577f790e9bd0420521cff50580f9b38165a38f217ce68f55c4bbaa97066" name = "google.golang.org/appengine" packages = [ "internal", @@ -126,8 +187,8 @@ "urlfetch", ] pruneopts = "" - revision = "b2f4a3cf3c67576a2ee09e1fe62656a5086ce880" - version = "v1.6.1" + revision = "5f2a59506353b8d5ba8cbbcd9f3c1f41f1eaf079" + version = "v1.6.2" [[projects]] digest = "1:e3250d192192f02fbb143d50de437cbe967d6be7bd9fad671600942a33269d08" @@ -164,17 +225,29 @@ revision = "730df5f748271903322feb182be83b43ebbbe27d" version = "v2.3.1" +[[projects]] + digest = "1:cedccf16b71e86db87a24f8d4c70b0a855872eb967cb906a66b95de56aefbd0d" + name = "gopkg.in/yaml.v2" + packages = ["."] + pruneopts = "" + revision = "51d6538a90f86fe93ac480b35f37b2be17fef232" + version = "v2.2.2" + [solve-meta] analyzer-name = "dep" analyzer-version = 1 input-imports = [ "github.com/coreos/go-oidc", "github.com/gbrlsnchs/jwt", + "github.com/go-kit/kit/transport/http", + "github.com/gorilla/mux", "github.com/pkg/errors", "github.com/spf13/pflag", + "github.com/stretchr/testify/assert", "gopkg.in/h2non/gentleman.v2", "gopkg.in/h2non/gentleman.v2/plugin", "gopkg.in/h2non/gentleman.v2/plugins/body", + "gopkg.in/h2non/gentleman.v2/plugins/headers", "gopkg.in/h2non/gentleman.v2/plugins/query", "gopkg.in/h2non/gentleman.v2/plugins/timeout", "gopkg.in/h2non/gentleman.v2/plugins/url", diff --git a/client_role_mappings.go b/client_role_mappings.go index a06cadb..6c448ed 100644 --- a/client_role_mappings.go +++ b/client_role_mappings.go @@ -10,7 +10,7 @@ const ( realmRoleMappingPath = "/auth/admin/realms/:realm/users/:id/role-mappings/realm" ) -// AddClientRoleMapping add client-level roles to the user role mapping. +// AddClientRolesToUserRoleMapping add client-level roles to the user role mapping. func (c *Client) AddClientRolesToUserRoleMapping(accessToken string, realmName, userID, clientID string, roles []RoleRepresentation) error { _, err := c.post(accessToken, nil, url.Path(clientRoleMappingPath), url.Param("realm", realmName), url.Param("id", userID), url.Param("client", clientID), body.JSON(roles)) return err @@ -28,6 +28,7 @@ func (c *Client) DeleteClientRolesFromUserRoleMapping(accessToken string, realmN return c.delete(accessToken, url.Path(clientRoleMappingPath), url.Param("realm", realmName), url.Param("id", userID), url.Param("client", clientID)) } +// GetRealmLevelRoleMappings gets realm level role mappings func (c *Client) GetRealmLevelRoleMappings(accessToken string, realmName, userID string) ([]RoleRepresentation, error) { var resp = []RoleRepresentation{} var err = c.get(accessToken, &resp, url.Path(realmRoleMappingPath), url.Param("realm", realmName), url.Param("id", userID)) diff --git a/keycloak_client.go b/keycloak_client.go index da6926b..51ea9cc 100644 --- a/keycloak_client.go +++ b/keycloak_client.go @@ -1,14 +1,12 @@ package keycloak import ( - "context" "encoding/json" "fmt" "net/http" "net/url" "time" - oidc "github.com/coreos/go-oidc" "github.com/pkg/errors" "gopkg.in/h2non/gentleman.v2" "gopkg.in/h2non/gentleman.v2/plugin" @@ -23,13 +21,15 @@ type Config struct { AddrTokenProvider string AddrAPI string Timeout time.Duration + CacheTTL time.Duration + ErrorTolerance time.Duration } // Client is the keycloak client. type Client struct { - tokenProviderURL *url.URL apiURL *url.URL httpClient *gentleman.Client + verifierProvider OidcVerifierProvider } // HTTPError is returned when an error occured while contacting the keycloak instance. @@ -68,14 +68,26 @@ func New(config Config) (*Client, error) { httpClient = httpClient.Use(timeout.Request(config.Timeout)) } - return &Client{ - tokenProviderURL: uToken, + // Use default values when clients are not initializing these values + cacheTTL := config.CacheTTL + if cacheTTL == 0 { + cacheTTL = 15 * time.Minute + } + errTolerance := config.ErrorTolerance + if errTolerance == 0 { + errTolerance = time.Minute + } + + var client = &Client{ apiURL: uAPI, httpClient: httpClient, - }, nil + verifierProvider: NewVerifierCache(uToken, cacheTTL, errTolerance), + } + + return client, nil } -// getToken returns a valid token from keycloak. +// GetToken returns a valid token from keycloak. func (c *Client) GetToken(realm string, username string, password string) (string, error) { var req *gentleman.Request { @@ -121,22 +133,12 @@ func (c *Client) GetToken(realm string, username string, password string) (strin return accessToken.(string), nil } -// verifyToken token verify a token. It returns an error it is malformed, expired,... +// VerifyToken verifies a token. It returns an error it is malformed, expired,... func (c *Client) VerifyToken(realmName string, accessToken string) error { - var oidcProvider *oidc.Provider - { - var err error - var issuer = fmt.Sprintf("%s/auth/realms/%s", c.tokenProviderURL.String(), realmName) - oidcProvider, err = oidc.NewProvider(context.Background(), issuer) - if err != nil { - return errors.Wrap(err, "could not create oidc provider") - } + verifier, err := c.verifierProvider.GetOidcVerifier(realmName) + if err != nil { + err = verifier.Verify(accessToken) } - - var v = oidcProvider.Verifier(&oidc.Config{SkipClientIDCheck: true}) - - var err error - _, err = v.Verify(context.Background(), accessToken) return err } diff --git a/oidc_verifier.go b/oidc_verifier.go new file mode 100644 index 0000000..1b59853 --- /dev/null +++ b/oidc_verifier.go @@ -0,0 +1,86 @@ +package keycloak + +import ( + "context" + "fmt" + "net/url" + "time" + + oidc "github.com/coreos/go-oidc" + "github.com/pkg/errors" +) + +// OidcVerifierProvider is an interface for a provider of OidcVerifier instances +type OidcVerifierProvider interface { + GetOidcVerifier(realm string) (OidcVerifier, error) +} + +// OidcVerifier is an interface for OIDC token verifiers +type OidcVerifier interface { + Verify(accessToken string) error +} + +type verifierCache struct { + duration time.Duration + errorTolerance time.Duration + tokenProviderURL *url.URL + verifiers map[string]cachedVerifier +} + +type cachedVerifier struct { + verifier *oidc.IDTokenVerifier + createdAt time.Time + expireAt time.Time + invalidateOnErrorAt time.Time +} + +// NewVerifierCache create an instance of OIDC verifier cache +func NewVerifierCache(tokenProviderURL *url.URL, timeToLive time.Duration, errorTolerance time.Duration) OidcVerifierProvider { + return &verifierCache{ + duration: timeToLive, + errorTolerance: errorTolerance, + tokenProviderURL: tokenProviderURL, + verifiers: make(map[string]cachedVerifier), + } +} + +func (vc *verifierCache) GetOidcVerifier(realm string) (OidcVerifier, error) { + v, ok := vc.verifiers[realm] + if ok && v.isValid() { + return &v, nil + } + var oidcProvider *oidc.Provider + { + var err error + var issuer = fmt.Sprintf("%s/auth/realms/%s", vc.tokenProviderURL.String(), realm) + oidcProvider, err = oidc.NewProvider(context.Background(), issuer) + if err != nil { + return nil, errors.Wrap(err, "could not create oidc provider") + } + } + + ov := oidcProvider.Verifier(&oidc.Config{SkipClientIDCheck: true}) + res := cachedVerifier{ + createdAt: time.Now(), + expireAt: time.Now().Add(vc.duration), + invalidateOnErrorAt: time.Now().Add(vc.errorTolerance), + verifier: ov, + } + vc.verifiers[realm] = res + + return &res, nil +} + +func (cv *cachedVerifier) isValid() bool { + return time.Now().Before(cv.expireAt) +} + +func (cv *cachedVerifier) Verify(accessToken string) error { + _, err := cv.verifier.Verify(context.Background(), accessToken) + if err != nil && time.Now().After(cv.invalidateOnErrorAt) { + // An error occured and current time is after invalidateOnErrorAt + // Let's make this verifier expire + cv.expireAt = cv.createdAt + } + return err +} diff --git a/oidc_verifier_test.go b/oidc_verifier_test.go new file mode 100644 index 0000000..bf64f43 --- /dev/null +++ b/oidc_verifier_test.go @@ -0,0 +1,114 @@ +package keycloak + +//go:generate mockgen -destination=./mock/authmanager.go -package=mock -mock_names=AuthorizationManager=AuthorizationManager github.com/cloudtrust/common-service/security AuthorizationManager + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + http_transport "github.com/go-kit/kit/transport/http" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" +) + +func decodeRequest(_ context.Context, req *http.Request) (interface{}, error) { + res := map[string]string{"realm": mux.Vars(req)["realm"], "host": req.Host} + return res, nil +} + +func encodeReply(_ context.Context, w http.ResponseWriter, rep interface{}) error { + if rep == nil { + w.WriteHeader(404) + return nil + } + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + + var json, err = json.Marshal(rep) + if err == nil { + w.Write(json) + } + return nil +} + +func errorHandler(_ context.Context, _ error, w http.ResponseWriter) { + w.WriteHeader(500) +} + +func endpoint(_ context.Context, request interface{}) (response interface{}, err error) { + var query = request.(map[string]string) + if !strings.Contains(query["realm"], "realm") { + return nil, nil + } + return map[string]string{ + "issuer": "http://" + query["host"] + "/auth/realms/" + query["realm"], + "authorization_endpoint": "", + "token_endpoint": "", + "jwks_uri": "", + "userinfo_endpoint": "", + }, nil +} + +func TestGetOidcVerifier(t *testing.T) { + verifierHandler := http_transport.NewServer(endpoint, decodeRequest, encodeReply, http_transport.ServerErrorEncoder(errorHandler)) + + r := mux.NewRouter() + r.Handle("/auth/realms/{realm}/.well-known/openid-configuration", verifierHandler) + + ts := httptest.NewServer(r) + defer ts.Close() + + url, _ := url.Parse(ts.URL) + + { + // First test with a verifier which hardly expires + verifier := NewVerifierCache(url, time.Minute, 10*time.Minute) + + { + // Unknown realm: can't get verifier + _, err := verifier.GetOidcVerifier("unknown") + assert.NotNil(t, err) + } + + v1, e := verifier.GetOidcVerifier("realm1") + assert.Nil(t, e) + { + // Ask for the same realm before its verifier expires + v2, _ := verifier.GetOidcVerifier("realm1") + assert.Equal(t, v1, v2) + } + { + // Ask for a different verifier + v3, _ := verifier.GetOidcVerifier("realm2") + assert.NotEqual(t, v1, v3) + } + + time.Sleep(100 * time.Millisecond) + assert.NotNil(t, v1.Verify("abcdef")) + } + + { + // Now, test with a verifier which quickly expires on error + verifier := NewVerifierCache(url, time.Minute, time.Millisecond) + v1, _ := verifier.GetOidcVerifier("realm1") + time.Sleep(100 * time.Millisecond) + { + // Ask for the same realm before its verifier expires + v2, _ := verifier.GetOidcVerifier("realm1") + assert.Equal(t, v1, v2) + } + { + // Verify an invalid token + assert.NotNil(t, v1.Verify("abcdef")) + // Ask for the same realm before its verifier expires but after an error occured + v2, _ := verifier.GetOidcVerifier("realm1") + assert.NotEqual(t, v1, v2) + } + } +}