Browse Source

Added unit tested and increased so the code coverage: fix #9

dependabot/npm_and_yarn/web/prismjs-1.21.0
Max Schmitt 8 years ago
parent
commit
19141a78a0
  1. 1
      .gitignore
  2. 10
      handlers/public.go
  3. 38
      handlers/public_test.go
  4. 9
      handlers/test.yaml
  5. 21
      main_test.go
  6. 15
      store/store.go
  7. 59
      store/store_test.go
  8. 11
      store/util.go
  9. 15
      util/config_test.go
  10. 20
      util/private_test.go
  11. 5
      util/test.yaml

1
.gitignore

@ -13,6 +13,7 @@
# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 # Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
.glide/ .glide/
debug debug
debug.test
*.db *.db
*.lock *.lock

10
handlers/public.go

@ -1,6 +1,7 @@
package handlers package handlers
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@ -82,7 +83,7 @@ func (h *Handler) handleCreate(c *gin.Context) {
originURL := h.getURLOrigin(c) originURL := h.getURLOrigin(c)
c.JSON(http.StatusOK, urlUtil{ c.JSON(http.StatusOK, urlUtil{
URL: fmt.Sprintf("%s/%s", originURL, id), URL: fmt.Sprintf("%s/%s", originURL, id),
DeletionURL: fmt.Sprintf("%s/d/%s/%s", originURL, id, url.QueryEscape(delID)), DeletionURL: fmt.Sprintf("%s/d/%s/%s", originURL, id, url.QueryEscape(base64.RawURLEncoding.EncodeToString(delID))),
}) })
} }
@ -97,7 +98,12 @@ func (h *Handler) handleInfo(c *gin.Context) {
c.JSON(http.StatusOK, info) c.JSON(http.StatusOK, info)
} }
func (h *Handler) handleDelete(c *gin.Context) { func (h *Handler) handleDelete(c *gin.Context) {
if err := h.store.DeleteEntry(c.Param("id"), c.Param("hash")); err != nil { givenHmac, err := base64.RawURLEncoding.DecodeString(c.Param("hash"))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("could not decode base64: %v", err)})
return
}
if err := h.store.DeleteEntry(c.Param("id"), givenHmac); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }

38
handlers/public_test.go

@ -239,6 +239,44 @@ func testRedirect(t *testing.T, shortURL, longURL string) {
} }
} }
func TestHandleApplicationInfo(t *testing.T) {
resp, err := http.Get(server.URL + "/api/v1/info")
if err != nil {
t.Fatalf("could not get application info: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status %d; got %d", http.StatusOK, resp.StatusCode)
}
}
func TestHandleDeletion(t *testing.T) {
reqBody, err := json.Marshal(gin.H{
"URL": testURL,
})
if err != nil {
t.Fatalf("could not marshal json: %v", err)
}
respBody := createEntryWithJSON(t, reqBody, "application/json; charset=utf-8", http.StatusOK)
var body urlUtil
if err := json.Unmarshal(respBody, &body); err != nil {
t.Fatal("could not unmarshal create response")
}
resp, err := http.Get(body.DeletionURL)
if err != nil {
t.Fatalf("could not send deletion http request")
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status: %d; got: %d", resp.StatusCode, http.StatusOK)
}
resp, err = http.Get(body.URL)
if err != nil {
t.Fatalf("could not send visit request: %v", err)
}
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("expected status: %d; got: %d", http.StatusNotFound, resp.StatusCode)
}
}
func TestCloseB(t *testing.T) { func TestCloseB(t *testing.T) {
TestCloseBackend(t) TestCloseBackend(t)
} }

9
handlers/test.yaml

@ -1,3 +1,12 @@
data_dir: ../data data_dir: ../data
enable_debug_mode: true enable_debug_mode: true
shorted_id_length: 4 shorted_id_length: 4
Google:
ClientID: so
ClientSecret: secret
GitHub:
ClientID: so
ClientSecret: secret
Microsoft:
ClientID: so
ClientSecret: secret

21
main_test.go

@ -0,0 +1,21 @@
package main
import (
"net/http"
"testing"
"time"
"github.com/spf13/viper"
)
func TestInitShortener(t *testing.T) {
close, err := initShortener()
if err != nil {
t.Fatalf("could not init shortener: %v", err)
}
time.Sleep(1) // Give the http server a second to boot up
if err := http.ListenAndServe(viper.GetString("listen_addr"), nil); err == nil {
t.Fatal("port is not in use")
}
close()
}

15
store/store.go

@ -4,7 +4,6 @@ package store
import ( import (
"crypto/hmac" "crypto/hmac"
"crypto/sha512" "crypto/sha512"
"encoding/base64"
"encoding/json" "encoding/json"
"path/filepath" "path/filepath"
"time" "time"
@ -135,34 +134,30 @@ func (s *Store) GetEntryByIDRaw(id string) ([]byte, error) {
} }
// CreateEntry creates a new record and returns his short id // CreateEntry creates a new record and returns his short id
func (s *Store) CreateEntry(entry Entry, givenID string) (string, string, error) { func (s *Store) CreateEntry(entry Entry, givenID string) (string, []byte, error) {
if !govalidator.IsURL(entry.Public.URL) { if !govalidator.IsURL(entry.Public.URL) {
return "", "", ErrNoValidURL return "", nil, ErrNoValidURL
} }
// try it 10 times to make a short URL // try it 10 times to make a short URL
for i := 1; i <= 10; i++ { for i := 1; i <= 10; i++ {
id, delID, err := s.createEntry(entry, givenID) id, delID, err := s.createEntry(entry, givenID)
if err != nil && givenID != "" { if err != nil && givenID != "" {
return "", "", err return "", nil, err
} else if err != nil { } else if err != nil {
logrus.Debugf("Could not create entry: %v", err) logrus.Debugf("Could not create entry: %v", err)
continue continue
} }
return id, delID, nil return id, delID, nil
} }
return "", "", ErrGeneratingIDFailed return "", nil, ErrGeneratingIDFailed
} }
// DeleteEntry deletes an Entry fully from the DB // DeleteEntry deletes an Entry fully from the DB
func (s *Store) DeleteEntry(id, hash string) error { func (s *Store) DeleteEntry(id string, givenHmac []byte) error {
mac := hmac.New(sha512.New, util.GetPrivateKey()) mac := hmac.New(sha512.New, util.GetPrivateKey())
if _, err := mac.Write([]byte(id)); err != nil { if _, err := mac.Write([]byte(id)); err != nil {
return errors.Wrap(err, "could not write hmac") return errors.Wrap(err, "could not write hmac")
} }
givenHmac, err := base64.RawURLEncoding.DecodeString(hash)
if err != nil {
return errors.Wrap(err, "could not decode base64")
}
if !hmac.Equal(mac.Sum(nil), givenHmac) { if !hmac.Equal(mac.Sum(nil), givenHmac) {
return errors.New("hmac verification failed") return errors.New("hmac verification failed")
} }

59
store/store_test.go

@ -120,6 +120,65 @@ func TestIncreaseVisitCounter(t *testing.T) {
} }
} }
func TestDelete(t *testing.T) {
viper.Set("shorted_id_length", 4)
store, err := New()
if err != nil {
t.Fatalf("could not create store: %v", err)
}
defer cleanup(store)
entryID, delHMac, err := store.CreateEntry(Entry{
Public: EntryPublicData{
URL: "https://golang.org/",
},
}, "")
if err != nil {
t.Fatalf("could not create entry: %v", err)
}
if err := store.DeleteEntry(entryID, delHMac); err != nil {
t.Fatalf("could not delete entry: %v", err)
}
if _, err := store.GetEntryByID(entryID); err != ErrNoEntryFound {
t.Fatalf("unexpected error: %v", err)
}
}
func TestGetURLAndIncrease(t *testing.T) {
viper.Set("shorted_id_length", 4)
store, err := New()
if err != nil {
t.Fatalf("could not create store: %v", err)
}
defer cleanup(store)
const url = "https://golang.org/"
entryID, _, err := store.CreateEntry(Entry{
Public: EntryPublicData{
URL: url,
},
}, "")
if err != nil {
t.Fatalf("could not create entry: %v", err)
}
entryOne, err := store.GetEntryByID(entryID)
if err != nil {
t.Fatalf("could not get entry: %v", err)
}
entryURL, err := store.GetURLAndIncrease(entryID)
if err != nil {
t.Fatalf("could not get URL and increase the visitor counter: %v", err)
}
if entryURL != url {
t.Fatalf("url is not the expected one")
}
entryTwo, err := store.GetEntryByID(entryID)
if err != nil {
t.Fatalf("could not get entry: %v", err)
}
if entryOne.Public.VisitCount+1 != entryTwo.Public.VisitCount {
t.Fatalf("visitor count does not increase")
}
}
func cleanup(s *Store) { func cleanup(s *Store) {
s.Close() s.Close()
os.Remove(testingDBName) os.Remove(testingDBName)

11
store/util.go

@ -4,7 +4,6 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/sha512" "crypto/sha512"
"encoding/base64"
"encoding/json" "encoding/json"
"math/big" "math/big"
"time" "time"
@ -31,23 +30,23 @@ func (s *Store) createEntryRaw(key, value []byte) error {
// createEntry creates a new entry with a randomly generated id. If on is present // createEntry creates a new entry with a randomly generated id. If on is present
// then the given ID is used // then the given ID is used
func (s *Store) createEntry(entry Entry, entryID string) (string, string, error) { func (s *Store) createEntry(entry Entry, entryID string) (string, []byte, error) {
var err error var err error
if entryID == "" { if entryID == "" {
if entryID, err = generateRandomString(s.idLength); err != nil { if entryID, err = generateRandomString(s.idLength); err != nil {
return "", "", errors.Wrap(err, "could not generate random string") return "", nil, errors.Wrap(err, "could not generate random string")
} }
} }
entry.Public.CreatedOn = time.Now() entry.Public.CreatedOn = time.Now()
rawEntry, err := json.Marshal(entry) rawEntry, err := json.Marshal(entry)
if err != nil { if err != nil {
return "", "", err return "", nil, err
} }
mac := hmac.New(sha512.New, util.GetPrivateKey()) mac := hmac.New(sha512.New, util.GetPrivateKey())
if _, err := mac.Write([]byte(entryID)); err != nil { if _, err := mac.Write([]byte(entryID)); err != nil {
return "", "", errors.Wrap(err, "could not write hmac") return "", nil, errors.Wrap(err, "could not write hmac")
} }
return entryID, base64.RawURLEncoding.EncodeToString(mac.Sum(nil)), s.createEntryRaw([]byte(entryID), rawEntry) return entryID, mac.Sum(nil), s.createEntryRaw([]byte(entryID), rawEntry)
} }
// generateRandomString generates a random string with an predefined length // generateRandomString generates a random string with an predefined length

15
util/config_test.go

@ -0,0 +1,15 @@
package util
import (
"testing"
"github.com/spf13/viper"
)
func TestReadInConfig(t *testing.T) {
DoNotSetConfigName = true
viper.SetConfigFile("test.yaml")
if err := ReadInConfig(); err != nil {
t.Fatalf("could not read in config file: %v", err)
}
}

20
util/private_test.go

@ -0,0 +1,20 @@
package util
import (
"os"
"testing"
)
func TestCheckforPrivateKey(t *testing.T) {
TestReadInConfig(t)
privateKey = nil
if err := CheckForPrivateKey(); err != nil {
t.Fatalf("could not check for private key: %v", err)
}
if GetPrivateKey() == nil {
t.Fatalf("private key is nil")
}
if err := os.RemoveAll(GetDataDir()); err != nil {
t.Fatalf("could not remove data dir: %v", err)
}
}

5
util/test.yaml

@ -0,0 +1,5 @@
listen_addr: ':8080'
base_url: 'http://localhost:3000'
data_dir: ./data
enable_debug_mode: true
shorted_id_length: 4
Loading…
Cancel
Save