From e5ac34903d270a8ad524bf52df3482e1c6fd48b0 Mon Sep 17 00:00:00 2001 From: Max Schmitt Date: Wed, 15 Nov 2017 17:55:16 +0100 Subject: [PATCH] Cleaned up the oAuth stuff #23 --- .travis.yml | 2 +- descriptor.json => build/bintray.json | 0 handlers/auth.go | 155 ++++++-------------------- handlers/auth/auth.go | 100 +++++++++++++++++ handlers/auth/github.go | 2 + handlers/auth/google.go | 66 +++++++++++ handlers/auth_test.go | 40 ++----- handlers/handlers.go | 6 +- handlers/public.go | 5 +- handlers/public_test.go | 1 - handlers/utils.go | 8 -- main.go | 5 +- static/src/index.js | 2 +- util/util.go | 2 + 14 files changed, 218 insertions(+), 176 deletions(-) rename descriptor.json => build/bintray.json (100%) create mode 100644 handlers/auth/auth.go create mode 100644 handlers/auth/github.go create mode 100644 handlers/auth/google.go create mode 100644 util/util.go diff --git a/.travis.yml b/.travis.yml index 376243c..8db0e04 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,7 +15,7 @@ script: deploy: provider: bintray user: maxibanki - file: "descriptor.json" + file: "build/descriptor.json" key: secure: ErqvSFIlL3d9XuMj+T+hO6xZqll8Ubx0DEdHD6NJKi7sH7Be3b3/vUoPdAjdFOP70DhaccbncGCTPZ9hsNKdqYxZmuKx3WWwH4H4U5YdDIViXtH6+B5KdAmvdZIynaj+THQAbVAhr+QyvcqotNySPd3Ac1HCg2YAcUHme6y3FsiRJ79To80JWxTSR1G/oObmeoDn8R18wmH1gHl8KQ7ltC537Osb/H34bJ/hY94hRe8IEmoQE4yz/EP44kGXRb/F87i92y1mO081ZS1I1hs5Kbom43YoItqSVbJP/abPMyCsGDv2FGXaGqk5IVC1k+01pcAjqxCzMvXC272itc0E8OEWqE4qONN+m2S9tyALyOaUZ7j5meWLHQj49Rzo7XIWh1PvvEMovdl/wk/Oc9f0ZywPuvoRht5ZebgXbPWAMMNywwy0GKM4nU0DCyFm23mlzPh4iklo12gEUzq3YLc18RhAZuy4timeevrDCuJMQeQ3sqcQBKCQ+rdOxzVCKKl2sGpNaTJEYaHGT9KLCEGBLmvaB58RKgmGN6IIEwpxSm2SGoirfnQsr+DP+kaSvWPr6R/pZAhO1JzO+azaXvfr+hL2SMX6U7j5+SDmFGIFDwxok7ny1QUTQXKlNzA/ks9/vufe30hrTkph/MfEvM5mYVbfgAn5zZ0v+dJ2wCoe1go= notifications: diff --git a/descriptor.json b/build/bintray.json similarity index 100% rename from descriptor.json rename to build/bintray.json diff --git a/handlers/auth.go b/handlers/auth.go index e2b3e4f..fd8050f 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -1,12 +1,10 @@ package handlers import ( - "encoding/json" "fmt" - "io/ioutil" "net/http" - "time" + "github.com/maxibanki/golang-url-shortener/handlers/auth" "github.com/maxibanki/golang-url-shortener/util" "github.com/sirupsen/logrus" "github.com/spf13/viper" @@ -15,84 +13,47 @@ import ( "github.com/gin-gonic/contrib/sessions" "github.com/gin-gonic/gin" "github.com/pkg/errors" - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" ) -type jwtClaims struct { - jwt.StandardClaims - OAuthProvider string - OAuthID string - OAuthName string - OAuthPicture string -} - -type oAuthUser struct { - Sub string `json:"sub"` - Name string `json:"name"` - Picture string `json:"picture"` -} - -type checkResponse struct { - ID string - Name string - Picture string - Provider string -} - func (h *Handler) initOAuth() { - h.oAuthConf = &oauth2.Config{ - ClientID: viper.GetString("oAuth.Google.ClientID"), - ClientSecret: viper.GetString("oAuth.Google.ClientSecret"), - RedirectURL: viper.GetString("http.BaseURL") + "/api/v1/callback", - Scopes: []string{ - "https://www.googleapis.com/auth/userinfo.email", - }, - Endpoint: google.Endpoint, - } h.engine.Use(sessions.Sessions("backend", sessions.NewCookieStore(util.GetPrivateKey()))) - h.engine.GET("/api/v1/login", h.handleGoogleRedirect) - h.engine.GET("/api/v1/callback", h.handleGoogleCallback) + + auth.WithAdapterWrapper(auth.NewGoogleAdapter(viper.GetString("oAuth.Google.ClientID"), viper.GetString("oAuth.Google.ClientSecret"), viper.GetString("http.BaseURL")), h.engine.Group("/api/v1/auth/google")) + h.engine.POST("/api/v1/check", h.handleGoogleCheck) } -func (h *Handler) handleGoogleRedirect(c *gin.Context) { - state := h.randToken() - session := sessions.Default(c) - session.Set("state", state) - session.Save() - c.Redirect(http.StatusTemporaryRedirect, h.oAuthConf.AuthCodeURL(state)) +func (h *Handler) parseJWT(wt string) (*auth.JWTClaims, error) { + token, err := jwt.ParseWithClaims(wt, &auth.JWTClaims{}, func(token *jwt.Token) (interface{}, error) { + return util.GetPrivateKey(), nil + }) + if err != nil { + return nil, fmt.Errorf("could not parse token: %v", err) + } + if !token.Valid { + return nil, errors.New("token is not valid") + } + return token.Claims.(*auth.JWTClaims), nil } func (h *Handler) authMiddleware(c *gin.Context) { authError := func() error { - authHeader := c.GetHeader("Authorization") - if authHeader == "" { + wt := c.GetHeader("Authorization") + if wt == "" { return errors.New("'Authorization' header not set") } - token, err := jwt.ParseWithClaims(authHeader, &jwtClaims{}, func(token *jwt.Token) (interface{}, error) { - return util.GetPrivateKey(), nil - }) + claims, err := h.parseJWT(wt) if err != nil { - return fmt.Errorf("could not parse token: %v", err) - } - if !token.Valid { - return errors.New("token is not valid") + return err } - c.Set("user", token.Claims) + c.Set("user", claims) return nil }() if authError != nil { - if viper.GetBool("General.EnableDebugMode") { - c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ - "error": fmt.Sprintf("token is not valid: %v", authError), - }) - logrus.Debugf("Authentication middleware failed: %v\n", authError) - } else { - c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ - "error": "authentication failed", - }) - } + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ + "error": "authentication failed", + }) + logrus.Debugf("Authentication middleware failed: %v\n", authError) return } c.Next() @@ -106,69 +67,15 @@ func (h *Handler) handleGoogleCheck(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - token, err := jwt.ParseWithClaims(data.Token, &jwtClaims{}, func(token *jwt.Token) (interface{}, error) { - return util.GetPrivateKey(), nil - }) - if claims, ok := token.Claims.(*jwtClaims); ok && token.Valid { - c.JSON(http.StatusOK, checkResponse{ - ID: claims.OAuthID, - Name: claims.OAuthName, - Picture: claims.OAuthPicture, - Provider: claims.OAuthProvider, - }) - } else { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - } -} - -func (h *Handler) handleGoogleCallback(c *gin.Context) { - session := sessions.Default(c) - retrievedState := session.Get("state") - if retrievedState != c.Query("state") { - c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("invalid session state: %s", retrievedState)}) - return - } - - oAuthToken, err := h.oAuthConf.Exchange(oauth2.NoContext, c.Query("code")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not exchange code: %v", err)}) - return - } - - client := h.oAuthConf.Client(oauth2.NoContext, oAuthToken) - oAuthUserInfoReq, err := client.Get("https://www.googleapis.com/oauth2/v3/userinfo") - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not get user data: %v", err)}) - return - } - defer oAuthUserInfoReq.Body.Close() - data, err := ioutil.ReadAll(oAuthUserInfoReq.Body) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("could not read body: %v", err)}) - return - } - var user oAuthUser - if err = json.Unmarshal(data, &user); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("decoding user info failed: %v", err)}) - return - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtClaims{ - jwt.StandardClaims{ - ExpiresAt: time.Now().Add(time.Hour * 24 * 365).Unix(), - }, - "google", - user.Sub, - user.Name, - user.Picture, - }) - - tokenString, err := token.SignedString(util.GetPrivateKey()) + claims, err := h.parseJWT(data.Token) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("could not sign token: %v", err)}) + c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) return } - c.HTML(http.StatusOK, "token.tmpl", gin.H{ - "token": tokenString, + c.JSON(http.StatusOK, gin.H{ + "ID": claims.OAuthID, + "Name": claims.OAuthName, + "Picture": claims.OAuthPicture, + "Provider": claims.OAuthProvider, }) } diff --git a/handlers/auth/auth.go b/handlers/auth/auth.go new file mode 100644 index 0000000..64d9a62 --- /dev/null +++ b/handlers/auth/auth.go @@ -0,0 +1,100 @@ +package auth + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "net/http" + "time" + + jwt "github.com/dgrijalva/jwt-go" + "github.com/gin-gonic/contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/maxibanki/golang-url-shortener/util" + "github.com/pkg/errors" +) + +// Adapter will be implemented by each oAuth provider +type Adapter interface { + GetRedirectURl(state string) string + GetUserData(state, code string) (*user, error) + GetOAuthProviderName() string +} + +type user struct { + ID, Name, Picture string +} + +type JWTClaims struct { + jwt.StandardClaims + OAuthProvider string + OAuthID string + OAuthName string + OAuthPicture string +} + +type AdapterWrapper struct{ Adapter } + +func WithAdapterWrapper(a Adapter, h *gin.RouterGroup) *AdapterWrapper { + aw := &AdapterWrapper{a} + h.GET("/login", aw.HandleLogin) + h.GET("/callback", aw.HandleCallback) + return aw +} + +func (a *AdapterWrapper) HandleLogin(c *gin.Context) { + state := a.randToken() + session := sessions.Default(c) + session.Set("state", state) + session.Save() + c.Redirect(http.StatusTemporaryRedirect, a.GetRedirectURl(state)) +} + +func (a *AdapterWrapper) HandleCallback(c *gin.Context) { + session := sessions.Default(c) + sessionState := session.Get("state") + receivedState := c.Query("state") + if sessionState != receivedState { + c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("invalid session state: %s", sessionState)}) + return + } + + user, err := a.GetUserData(receivedState, c.Query("code")) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + return + } + + token, err := a.newJWT(user, a.GetOAuthProviderName()) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.HTML(http.StatusOK, "token.tmpl", gin.H{ + "token": token, + }) +} + +func (a *AdapterWrapper) newJWT(user *user, provider string) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, JWTClaims{ + jwt.StandardClaims{ + ExpiresAt: time.Now().Add(time.Hour * 24 * 365).Unix(), + }, + provider, + user.ID, + user.Name, + user.Picture, + }) + tokenString, err := token.SignedString(util.GetPrivateKey()) + if err != nil { + return "", errors.Wrap(err, "could not sign token") + } + return tokenString, nil +} + +func (a *AdapterWrapper) randToken() string { + b := make([]byte, 32) + rand.Read(b) + return base64.StdEncoding.EncodeToString(b) +} diff --git a/handlers/auth/github.go b/handlers/auth/github.go new file mode 100644 index 0000000..9a72978 --- /dev/null +++ b/handlers/auth/github.go @@ -0,0 +1,2 @@ +package auth + diff --git a/handlers/auth/google.go b/handlers/auth/google.go new file mode 100644 index 0000000..90faeec --- /dev/null +++ b/handlers/auth/google.go @@ -0,0 +1,66 @@ +package auth + +import ( + "context" + "encoding/json" + "io/ioutil" + + "github.com/pkg/errors" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +type googleAdapter struct { + config *oauth2.Config +} + +func NewGoogleAdapter(clientID, clientSecret, baseURL string) Adapter { + return &googleAdapter{&oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: baseURL + "/api/v1/auth/google/callback", + Scopes: []string{ + "https://www.googleapis.com/auth/userinfo.email", + }, + Endpoint: google.Endpoint, + }} +} + +func (a *googleAdapter) GetRedirectURl(state string) string { + return a.config.AuthCodeURL(state) +} + +func (a *googleAdapter) GetUserData(state, code string) (*user, error) { + oAuthToken, err := a.config.Exchange(context.Background(), code) + if err != nil { + return nil, errors.Wrap(err, "could not exchange code") + } + + client := a.config.Client(context.Background(), oAuthToken) + oAuthUserInfoReq, err := client.Get("https://www.googleapis.com/oauth2/v3/userinfo") + if err != nil { + return nil, errors.Wrap(err, "could not get user data") + } + defer oAuthUserInfoReq.Body.Close() + data, err := ioutil.ReadAll(oAuthUserInfoReq.Body) + if err != nil { + return nil, errors.Wrap(err, "could not read body") + } + var gUser struct { + Sub string `json:"sub"` + Name string `json:"name"` + Picture string `json:"picture"` + } + if err = json.Unmarshal(data, &gUser); err != nil { + return nil, errors.Wrap(err, "decoding user info failed") + } + return &user{ + ID: gUser.Sub, + Name: gUser.Name, + Picture: gUser.Picture, + }, nil +} + +func (a *googleAdapter) GetOAuthProviderName() string { + return "google" +} diff --git a/handlers/auth_test.go b/handlers/auth_test.go index 958ee4c..1e9ef85 100644 --- a/handlers/auth_test.go +++ b/handlers/auth_test.go @@ -6,28 +6,23 @@ import ( "fmt" "net/http" "net/http/httptest" - "strings" "testing" "time" jwt "github.com/dgrijalva/jwt-go" + "github.com/gin-gonic/gin" + "github.com/maxibanki/golang-url-shortener/handlers/auth" "github.com/maxibanki/golang-url-shortener/store" "github.com/maxibanki/golang-url-shortener/util" "github.com/pkg/errors" "github.com/spf13/viper" - "golang.org/x/oauth2/google" -) - -const ( - testingDBName = "main.db" ) var ( secret []byte server *httptest.Server closeServer func() error - handler *Handler - testingClaimData = jwtClaims{ + testingClaimData = auth.JWTClaims{ jwt.StandardClaims{ ExpiresAt: time.Now().Add(time.Hour * 24 * 365).Unix(), }, @@ -64,25 +59,6 @@ func TestCreateBackend(t *testing.T) { } } -func TestHandleGoogleRedirect(t *testing.T) { - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, // don't follow redirects - } - resp, err := client.Get(server.URL + "/api/v1/login") - if err != nil { - t.Fatalf("could not get login request: %v", err) - } - if resp.StatusCode != http.StatusTemporaryRedirect { - t.Fatalf("expected status code: %d; got: %d", http.StatusTemporaryRedirect, resp.StatusCode) - } - location := resp.Header.Get("Location") - if !strings.HasPrefix(location, google.Endpoint.AuthURL) { - t.Fatalf("redirect is not correct, got: %s", location) - } -} - func TestCreateNewJWT(t *testing.T) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, testingClaimData) var err error @@ -129,7 +105,7 @@ func TestCheckToken(t *testing.T) { if err != nil { t.Fatalf("could not execute get request: %v", err) } - var data checkResponse + var data gin.H if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { t.Fatalf("could not decode json: %v", err) } @@ -140,22 +116,22 @@ func TestCheckToken(t *testing.T) { }{ { name: "ID", - currentValue: data.ID, + currentValue: data["ID"].(string), expectedValue: testingClaimData.OAuthID, }, { name: "Name", - currentValue: data.Name, + currentValue: data["Name"].(string), expectedValue: testingClaimData.OAuthName, }, { name: "Picture", - currentValue: data.Picture, + currentValue: data["Picture"].(string), expectedValue: testingClaimData.OAuthPicture, }, { name: "Provider", - currentValue: data.Provider, + currentValue: data["Provider"].(string), expectedValue: testingClaimData.OAuthProvider, }, } diff --git a/handlers/handlers.go b/handlers/handlers.go index 94fd6f4..7aa1b91 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -15,15 +15,13 @@ import ( "github.com/maxibanki/golang-url-shortener/store" "github.com/maxibanki/golang-url-shortener/util" "github.com/pkg/errors" - "golang.org/x/oauth2" ) // Handler holds the funcs and attributes for the // http communication type Handler struct { - store store.Store - engine *gin.Engine - oAuthConf *oauth2.Config + store store.Store + engine *gin.Engine } // New initializes the http handlers diff --git a/handlers/public.go b/handlers/public.go index 2fa90b4..9545463 100644 --- a/handlers/public.go +++ b/handlers/public.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/maxibanki/golang-url-shortener/handlers/auth" "github.com/maxibanki/golang-url-shortener/store" ) @@ -28,7 +29,7 @@ func (h *Handler) handleLookup(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } - user := c.MustGet("user").(*jwtClaims) + user := c.MustGet("user").(*auth.JWTClaims) if entry.OAuthID != user.OAuthID || entry.OAuthProvider != user.OAuthProvider { c.JSON(http.StatusOK, store.Entry{ Public: store.EntryPublicData{ @@ -67,7 +68,7 @@ func (h *Handler) handleCreate(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - user := c.MustGet("user").(*jwtClaims) + user := c.MustGet("user").(*auth.JWTClaims) id, err := h.store.CreateEntry(store.Entry{ Public: store.EntryPublicData{ URL: data.URL, diff --git a/handlers/public_test.go b/handlers/public_test.go index 6138e2d..72fc268 100644 --- a/handlers/public_test.go +++ b/handlers/public_test.go @@ -24,7 +24,6 @@ func TestCreateEntry(t *testing.T) { name string ignoreResponse bool contentType string - authToken string response gin.H requestBody URLUtil statusCode int diff --git a/handlers/utils.go b/handlers/utils.go index edfafd9..c23fcf0 100644 --- a/handlers/utils.go +++ b/handlers/utils.go @@ -1,8 +1,6 @@ package handlers import ( - "crypto/rand" - "encoding/base64" "fmt" "github.com/gin-gonic/gin" @@ -15,9 +13,3 @@ func (h *Handler) getSchemaAndHost(c *gin.Context) string { } return fmt.Sprintf("%s://%s", protocol, c.Request.Host) } - -func (h *Handler) randToken() string { - b := make([]byte, 32) - rand.Read(b) - return base64.StdEncoding.EncodeToString(b) -} diff --git a/main.go b/main.go index a6b0869..8a62d58 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "log" "os" "os/signal" @@ -48,12 +47,12 @@ func initShortener() (func(), error) { } go func() { if err := handler.Listen(); err != nil { - log.Fatalf("could not listen to http handlers: %v", err) + logrus.Fatalf("could not listen to http handlers: %v", err) } }() return func() { if err = handler.CloseStore(); err != nil { - log.Printf("failed to stop the handlers: %v", err) + logrus.Printf("failed to stop the handlers: %v", err) } }, nil } diff --git a/static/src/index.js b/static/src/index.js index d494300..8f9d5b4 100644 --- a/static/src/index.js +++ b/static/src/index.js @@ -66,7 +66,7 @@ export default class BaseComponent extends Component { wHeight = 500; var wLeft = (window.screen.width / 2) - (wwidth / 2); var wTop = (window.screen.height / 2) - (wHeight / 2); - window.open('/api/v1/login', '', `width=${wwidth}, height=${wHeight}, top=${wTop}, left=${wLeft}`) + window.open('/api/v1/auth/google/login', '', `width=${wwidth}, height=${wHeight}, top=${wTop}, left=${wLeft}`) } handleLogout = () => { diff --git a/util/util.go b/util/util.go new file mode 100644 index 0000000..056120e --- /dev/null +++ b/util/util.go @@ -0,0 +1,2 @@ +// Package util implements helper functions for the complete Golang URL Shortener app +package util