You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
181 lines
4.6 KiB
181 lines
4.6 KiB
package handlers
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
jwt "github.com/dgrijalva/jwt-go"
|
|
"github.com/maxibanki/golang-url-shortener/store"
|
|
"github.com/maxibanki/golang-url-shortener/util"
|
|
"github.com/pkg/errors"
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/spf13/viper"
|
|
"golang.org/x/oauth2/google"
|
|
)
|
|
|
|
// testingDBName is the database file name reserved for the handler tests.
const testingDBName = "main.db"
|
|
|
|
var (
|
|
secret []byte
|
|
server *httptest.Server
|
|
closeServer func() error
|
|
handler *Handler
|
|
testingClaimData = jwtClaims{
|
|
jwt.StandardClaims{
|
|
ExpiresAt: time.Now().Add(time.Hour * 24 * 365).Unix(),
|
|
},
|
|
"google",
|
|
"sub sub sub",
|
|
"name",
|
|
"url",
|
|
}
|
|
tokenString string
|
|
)
|
|
|
|
func TestCreateBackend(t *testing.T) {
|
|
secret = util.GetPrivateKey()
|
|
viper.SetConfigName("config")
|
|
viper.AddConfigPath("../")
|
|
util.SetConfigDefaults()
|
|
err := viper.ReadInConfig()
|
|
if err != nil {
|
|
t.Fatalf("could not reload config file: %v", err)
|
|
}
|
|
if err := util.CheckForDatadir(); err != nil {
|
|
t.Fatalf("could not reload config file: %v", err)
|
|
}
|
|
store, err := store.New(logrus.New())
|
|
if err != nil {
|
|
t.Fatalf("could not create store: %v", err)
|
|
}
|
|
handler, err := New(*store, logrus.New(), true)
|
|
if err != nil {
|
|
t.Fatalf("could not create handler: %v", err)
|
|
}
|
|
server = httptest.NewServer(handler.engine)
|
|
closeServer = func() error {
|
|
server.Close()
|
|
if err := handler.CloseStore(); err != nil {
|
|
return errors.Wrap(err, "could not close store")
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func TestHandleGoogleRedirect(t *testing.T) {
|
|
client := &http.Client{
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
return http.ErrUseLastResponse
|
|
}, // don't follow redirects
|
|
}
|
|
resp, err := client.Get(server.URL + "/api/v1/login")
|
|
if err != nil {
|
|
t.Fatalf("could not get login request: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusTemporaryRedirect {
|
|
t.Fatalf("expected status code: %d; got: %d", http.StatusTemporaryRedirect, resp.StatusCode)
|
|
}
|
|
location := resp.Header.Get("Location")
|
|
if !strings.HasPrefix(location, google.Endpoint.AuthURL) {
|
|
t.Fatalf("redirect is not correct, got: %s", location)
|
|
}
|
|
}
|
|
|
|
func TestCreateNewJWT(t *testing.T) {
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, testingClaimData)
|
|
var err error
|
|
tokenString, err = token.SignedString(secret)
|
|
if err != nil {
|
|
t.Fatalf("could not sign token: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestForbiddenReqest(t *testing.T) {
|
|
resp, err := http.Post(server.URL+"/api/v1/protected/create", "application/json", nil)
|
|
if err != nil {
|
|
t.Fatalf("could not execute get request: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusForbidden {
|
|
t.Fatalf("incorrect status code: %d; got: %d", resp.StatusCode, http.StatusForbidden)
|
|
}
|
|
}
|
|
|
|
func TestInvalidToken(t *testing.T) {
|
|
req, err := http.NewRequest("POST", server.URL+"/api/v1/protected/create", nil)
|
|
if err != nil {
|
|
t.Fatalf("could not create request %v", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "incorrect one")
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("could not execute post request: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusForbidden {
|
|
t.Fatalf("incorrect status code: %d; got: %d", resp.StatusCode, http.StatusForbidden)
|
|
}
|
|
}
|
|
|
|
func TestCheckToken(t *testing.T) {
|
|
body, err := json.Marshal(map[string]string{
|
|
"Token": tokenString,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("could not post to the backend: %v", err)
|
|
}
|
|
resp, err := http.Post(server.URL+"/api/v1/check", "application/json", bytes.NewBuffer(body))
|
|
if err != nil {
|
|
t.Fatalf("could not execute get request: %v", err)
|
|
}
|
|
var data checkResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
|
t.Fatalf("could not decode json: %v", err)
|
|
}
|
|
tt := []struct {
|
|
name string
|
|
currentValue string
|
|
expectedValue string
|
|
}{
|
|
{
|
|
name: "ID",
|
|
currentValue: data.ID,
|
|
expectedValue: testingClaimData.OAuthID,
|
|
},
|
|
{
|
|
name: "Name",
|
|
currentValue: data.Name,
|
|
expectedValue: testingClaimData.OAuthName,
|
|
},
|
|
{
|
|
name: "Picture",
|
|
currentValue: data.Picture,
|
|
expectedValue: testingClaimData.OAuthPicture,
|
|
},
|
|
{
|
|
name: "Provider",
|
|
currentValue: data.Provider,
|
|
expectedValue: testingClaimData.OAuthProvider,
|
|
},
|
|
}
|
|
for _, tc := range tt {
|
|
t.Run(fmt.Sprintf("Checking: %s", tc.name), func(t *testing.T) {
|
|
if tc.currentValue != tc.expectedValue {
|
|
t.Fatalf("incorrect jwt value: %s; expected: %s", tc.expectedValue, tc.currentValue)
|
|
}
|
|
})
|
|
}
|
|
|
|
}
|
|
func TestCloseBackend(t *testing.T) {
|
|
if err := closeServer(); err != nil {
|
|
t.Fatalf("could not close server: %v", err)
|
|
}
|
|
}
|
|
|