diff --git a/handlers/auth.go b/handlers/auth.go index 574914e..33e0a7d 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -20,7 +20,6 @@ type jwtClaims struct { OAuthID string OAuthName string OAuthPicture string - OAuthEmail string } type oAuthUser struct { @@ -36,6 +35,13 @@ type oAuthUser struct { Hd string `json:"hd"` } +type checkResponse struct { + ID string + Name string + Picture string + Provider string +} + func (h *Handler) initOAuth() { h.oAuthConf = &oauth2.Config{ ClientID: h.config.OAuth.Google.ClientID, @@ -64,7 +70,7 @@ func (h *Handler) authMiddleware(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ - "error": "'authorization' header not set", + "error": "'Authorization' header not set", }) return } @@ -98,12 +104,11 @@ func (h *Handler) handleGoogleCheck(c *gin.Context) { return h.config.Secret, nil }) if claims, ok := token.Claims.(*jwtClaims); ok && token.Valid { - c.JSON(http.StatusOK, gin.H{ - "ID": claims.OAuthID, - "Email": claims.OAuthEmail, - "Name": claims.OAuthName, - "Picture": claims.OAuthPicture, - "Provider": claims.OAuthProvider, + c.JSON(http.StatusOK, checkResponse{ + ID: claims.OAuthID, + Name: claims.OAuthName, + Picture: claims.OAuthPicture, + Provider: claims.OAuthProvider, }) } else { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -150,7 +155,6 @@ func (h *Handler) handleGoogleCallback(c *gin.Context) { user.Sub, user.Name, user.Picture, - user.Email, }) tokenString, err := token.SignedString(h.config.Secret) diff --git a/handlers/auth_test.go b/handlers/auth_test.go new file mode 100644 index 0000000..22a517c --- /dev/null +++ b/handlers/auth_test.go @@ -0,0 +1,154 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + jwt "github.com/dgrijalva/jwt-go" + "github.com/maxibanki/golang-url-shortener/config" + "github.com/maxibanki/golang-url-shortener/store" + "github.com/pkg/errors" + "golang.org/x/oauth2/google" +) + +const ( + testingDBName = "main.db" +) + +var ( + secret = []byte("our really great secret") + server *httptest.Server + closeServer func() error + handler *Handler + testingClaimData = jwtClaims{ + jwt.StandardClaims{ + ExpiresAt: time.Now().Add(time.Hour * 24 * 365).Unix(), + }, + "google", + "sub sub sub", + "name", + "url", + } + tokenString string +) + +func TestCreateBackend(t *testing.T) { + store, err := store.New(config.Store{ + DBPath: testingDBName, + ShortedIDLength: 4, + }) + if err != nil { + t.Fatalf("could not create store: %v", err) + } + handler, err := New(config.Handlers{ + ListenAddr: ":8080", + Secret: secret, + BaseURL: "http://127.0.0.1", + }, *store) + if err != nil { + t.Fatalf("could not create handler: %v", err) + } + handler.DoNotCheckConfigViaGet = true + server = httptest.NewServer(handler.engine) + closeServer = func() error { + server.Close() + if err := handler.CloseStore(); err != nil { + return errors.Wrap(err, "could not close store") + } + if err := os.Remove(testingDBName); err != nil { + return errors.Wrap(err, "could not remove testing db") + } + return nil + } +} + +func TestHandleGoogleRedirect(t *testing.T) { + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, // don't follow redirects + } + resp, err := client.Get(server.URL + "/api/v1/login") + if err != nil { + t.Fatalf("could not get login request: %v", err) + } + if resp.StatusCode != http.StatusTemporaryRedirect { + t.Fatalf("expected status code: %d; got: %d", http.StatusTemporaryRedirect, resp.StatusCode) + } + location := resp.Header.Get("Location") + if !strings.HasPrefix(location, google.Endpoint.AuthURL) { + t.Fatalf("redirect is not correct, got: %s", location) + } +} + +func TestCreateNewJWT(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, testingClaimData) + var err error + tokenString, err = token.SignedString(secret) + if err != nil { + t.Fatalf("could not sign token: %v", err) + } +} + +func TestCheckToken(t *testing.T) { + body, err := json.Marshal(map[string]string{ + "Token": tokenString, + }) + if err != nil { + t.Fatalf("could not post to the backend: %v", err) + } + resp, err := http.Post(server.URL+"/api/v1/check", "application/json", bytes.NewBuffer(body)) + if err != nil { + t.Fatalf("could not execute get request: %v", err) + } + var data checkResponse + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + t.Fatalf("could not decode json: %v", err) + } + tt := []struct { + name string + currentValue string + expectedValue string + }{ + { + name: "ID", + currentValue: data.ID, + expectedValue: testingClaimData.OAuthID, + }, + { + name: "Name", + currentValue: data.Name, + expectedValue: testingClaimData.OAuthName, + }, + { + name: "Picture", + currentValue: data.Picture, + expectedValue: testingClaimData.OAuthPicture, + }, + { + name: "Provider", + currentValue: data.Provider, + expectedValue: testingClaimData.OAuthProvider, + }, + } + for _, tc := range tt { + t.Run(fmt.Sprintf("Checking: %s", tc.name), func(t *testing.T) { + if tc.currentValue != tc.expectedValue { + t.Fatalf("incorrect jwt value: %s; expected: %s", tc.expectedValue, tc.currentValue) + } + }) + } + +} +func TestCloseBackend(t *testing.T) { + if err := closeServer(); err != nil { + t.Fatalf("could not close server: %v", err) + } +} diff --git a/handlers/handlers.go b/handlers/handlers.go index d21ac6d..5f4c9e5 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -1,4 +1,4 @@ -// Package handlers provides the http functionality +// Package handlers provides the http functionality for the URL Shortener //go:generate esc -o static.go -pkg handlers -prefix ../static/build ../static/build //go:generate esc -o tmpls/tmpls.go -pkg tmpls -include ^*\.tmpl -prefix tmpls tmpls package handlers diff --git a/handlers/handlers_test.go b/handlers/handlers_test.go index c83510c..cc1931d 100644 --- a/handlers/handlers_test.go +++ b/handlers/handlers_test.go @@ -5,30 +5,30 @@ import ( "encoding/json" "io/ioutil" "net/http" - "net/http/httptest" "net/url" - "os" "strings" "testing" "github.com/gin-gonic/gin" - "github.com/maxibanki/golang-url-shortener/config" "github.com/maxibanki/golang-url-shortener/store" - "github.com/pkg/errors" ) const ( - testingDBName = "main.db" - testURL = "https://www.google.de/" + testURL = "https://www.google.de/" ) -var server *httptest.Server +// var server *httptest.Server + +func TestCreateB(t *testing.T) { + TestCreateBackend(t) +} func TestCreateEntry(t *testing.T) { tt := []struct { name string ignoreResponse bool contentType string + authToken string response gin.H requestBody URLUtil statusCode int @@ -59,11 +59,6 @@ func TestCreateEntry(t *testing.T) { ignoreResponse: true, }, } - cleanup, err := getBackend() - if err != nil { - t.Fatalf("could not create backend: %v", err) - } - defer cleanup() for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { var reqBody []byte @@ -87,7 +82,7 @@ func TestCreateEntry(t *testing.T) { return } var parsed URLUtil - if err = json.Unmarshal(respBody, &parsed); err != nil { + if err := json.Unmarshal(respBody, &parsed); err != nil { t.Fatalf("could not unmarshal data: %v", err) } t.Run("test if shorted URL is correct", func(t *testing.T) { @@ -98,12 +93,6 @@ func TestCreateEntry(t *testing.T) { } func TestHandleInfo(t *testing.T) { - cleanup, err := getBackend() - if err != nil { - t.Fatalf("could not create backend: %v", err) - } - defer cleanup() - t.Run("check existing entry", func(t *testing.T) { reqBody, err := json.Marshal(store.Entry{ URL: testURL, @@ -124,7 +113,13 @@ func TestHandleInfo(t *testing.T) { if err != nil { t.Fatalf("could not marshal the body: %v", err) } - resp, err := http.Post(server.URL+"/api/v1/info", "application/json; charset=utf-8", bytes.NewBuffer(body)) + req, err := http.NewRequest("POST", server.URL+"/api/v1/protected/info", bytes.NewBuffer(body)) + if err != nil { + t.Fatalf("could not create request %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", tokenString) + resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("could not post to the backend: %v", err) } @@ -140,7 +135,13 @@ func TestHandleInfo(t *testing.T) { } }) t.Run("invalid body", func(t *testing.T) { - resp, err := http.Post(server.URL+"/api/v1/info", "appplication/json", bytes.NewBuffer(nil)) + req, err := http.NewRequest("POST", server.URL+"/api/v1/protected/info", bytes.NewBuffer(nil)) + if err != nil { + t.Fatalf("could not create request %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", tokenString) + resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("could not post to the backend: %v", err) } @@ -153,17 +154,20 @@ func TestHandleInfo(t *testing.T) { } body = bytes.TrimSpace(body) raw := makeJSON(t, gin.H{ - "error": "Key: '.ID' Error:Field validation for 'ID' failed on the 'required' tag", + "error": "EOF", }) if string(body) != raw { t.Fatalf("body is not the expected one: %s", body) } }) t.Run("no ID provided", func(t *testing.T) { + req, err := http.NewRequest("POST", server.URL+"/api/v1/protected/info", bytes.NewBufferString("{}")) if err != nil { - t.Fatalf("could not marshal the body: %v", err) + t.Fatalf("could not create request %v", err) } - resp, err := http.Post(server.URL+"/api/v1/info", "appplication/json", bytes.NewBufferString("{}")) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", tokenString) + resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("could not post to the backend: %v", err) } @@ -193,9 +197,15 @@ func makeJSON(t *testing.T, data interface{}) string { } func createEntryWithJSON(t *testing.T, reqBody []byte, contentType string, statusCode int) []byte { - resp, err := http.Post(server.URL+"/api/v1/create", "application/json", bytes.NewBuffer(reqBody)) + req, err := http.NewRequest("POST", server.URL+"/api/v1/protected/create", bytes.NewBuffer(reqBody)) + if err != nil { + t.Fatalf("could not create request %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", tokenString) + resp, err := http.DefaultClient.Do(req) if err != nil { - t.Fatalf("could not post to backend %v", err) + t.Fatalf("could not do request: %v", err) } respBody, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -234,27 +244,6 @@ func testRedirect(t *testing.T, shortURL, longURL string) { } } -func getBackend() (func(), error) { - store, err := store.New(config.Store{ - DBPath: testingDBName, - ShortedIDLength: 4, - }) - if err != nil { - return nil, errors.Wrap(err, "could not create store") - } - handler, err := New(config.Handlers{ - ListenAddr: ":8080", - Secret: []byte("our really great secret"), - BaseURL: "http://127.0.0.1", - }, *store) - if err != nil { - return nil, errors.Wrap(err, "could not create handler") - } - handler.DoNotCheckConfigViaGet = true - server = httptest.NewServer(handler.engine) - return func() { - server.Close() - handler.CloseStore() - os.Remove(testingDBName) - }, nil +func TestCloseB(t *testing.T) { + TestCloseBackend(t) } diff --git a/handlers/public_test.go b/handlers/public_test.go new file mode 100644 index 0000000..d2e1d98 --- /dev/null +++ b/handlers/public_test.go @@ -0,0 +1,2 @@ +package handlers +