Browse Source

refactored main code

dependabot/npm_and_yarn/web/prismjs-1.21.0
Max Schmitt 8 years ago
parent
commit
7b75be21e1
  1. 3
      handlers/auth.go
  2. 3
      handlers/auth_test.go
  3. 7
      handlers/handlers.go
  4. 15
      handlers/public.go
  5. 10
      handlers/public_test.go
  6. 15
      handlers/utils.go
  7. 2
      main.go
  8. 11
      store/store.go
  9. 23
      store/util.go
  10. 10
      util/config.go
  11. 18
      util/private.go

3
handlers/auth.go

@ -1,7 +1,6 @@
package handlers package handlers
import ( import (
"fmt"
"net/http" "net/http"
"github.com/maxibanki/golang-url-shortener/handlers/auth" "github.com/maxibanki/golang-url-shortener/handlers/auth"
@ -28,7 +27,7 @@ func (h *Handler) parseJWT(wt string) (*auth.JWTClaims, error) {
return util.GetPrivateKey(), nil return util.GetPrivateKey(), nil
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("could not parse token: %v", err) return nil, errors.Wrap(err, "could not parse token")
} }
if !token.Valid { if !token.Valid {
return nil, errors.New("token is not valid") return nil, errors.New("token is not valid")

3
handlers/auth_test.go

@ -45,7 +45,8 @@ func TestCreateBackend(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("could not create store: %v", err) t.Fatalf("could not create store: %v", err)
} }
handler, err := New(*store, true) DoNotPrivateKeyChecking = true
handler, err := New(*store)
if err != nil { if err != nil {
t.Fatalf("could not create handler: %v", err) t.Fatalf("could not create handler: %v", err)
} }

7
handlers/handlers.go

@ -24,8 +24,11 @@ type Handler struct {
engine *gin.Engine engine *gin.Engine
} }
// DoNotPrivateKeyChecking is used for testing
var DoNotPrivateKeyChecking = false
// New initializes the http handlers // New initializes the http handlers
func New(store store.Store, testing bool) (*Handler, error) { func New(store store.Store) (*Handler, error) {
if !viper.GetBool("General.EnableDebugMode") { if !viper.GetBool("General.EnableDebugMode") {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
@ -36,7 +39,7 @@ func New(store store.Store, testing bool) (*Handler, error) {
if err := h.setHandlers(); err != nil { if err := h.setHandlers(); err != nil {
return nil, errors.Wrap(err, "could not set handlers") return nil, errors.Wrap(err, "could not set handlers")
} }
if !testing { if !DoNotPrivateKeyChecking {
if err := util.CheckForPrivateKey(); err != nil { if err := util.CheckForPrivateKey(); err != nil {
return nil, errors.Wrap(err, "could not check for privat key") return nil, errors.Wrap(err, "could not check for privat key")
} }

15
handlers/public.go

@ -1,6 +1,7 @@
package handlers package handlers
import ( import (
"fmt"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -8,9 +9,9 @@ import (
"github.com/maxibanki/golang-url-shortener/store" "github.com/maxibanki/golang-url-shortener/store"
) )
// URLUtil is used to help in- and outgoing requests for json // urlUtil is used to help in- and outgoing requests for json
// un- and marshalling // un- and marshalling
type URLUtil struct { type urlUtil struct {
URL string `binding:"required"` URL string `binding:"required"`
ID string ID string
} }
@ -63,7 +64,7 @@ func (h *Handler) handleAccess(c *gin.Context) {
// handleCreate handles requests to create an entry // handleCreate handles requests to create an entry
func (h *Handler) handleCreate(c *gin.Context) { func (h *Handler) handleCreate(c *gin.Context) {
var data URLUtil var data urlUtil
if err := c.ShouldBind(&data); err != nil { if err := c.ShouldBind(&data); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
@ -84,3 +85,11 @@ func (h *Handler) handleCreate(c *gin.Context) {
data.URL = h.getSchemaAndHost(c) + "/" + id data.URL = h.getSchemaAndHost(c) + "/" + id
c.JSON(http.StatusOK, data) c.JSON(http.StatusOK, data)
} }
func (h *Handler) getSchemaAndHost(c *gin.Context) string {
protocol := "http"
if c.Request.TLS != nil {
protocol = "https"
}
return fmt.Sprintf("%s://%s", protocol, c.Request.Host)
}

10
handlers/public_test.go

@ -25,7 +25,7 @@ func TestCreateEntry(t *testing.T) {
ignoreResponse bool ignoreResponse bool
contentType string contentType string
response gin.H response gin.H
requestBody URLUtil requestBody urlUtil
statusCode int statusCode int
}{ }{
{ {
@ -37,7 +37,7 @@ func TestCreateEntry(t *testing.T) {
}, },
{ {
name: "short URL generation", name: "short URL generation",
requestBody: URLUtil{ requestBody: urlUtil{
URL: "https://www.google.de/", URL: "https://www.google.de/",
}, },
statusCode: http.StatusOK, statusCode: http.StatusOK,
@ -45,7 +45,7 @@ func TestCreateEntry(t *testing.T) {
}, },
{ {
name: "no valid URL", name: "no valid URL",
requestBody: URLUtil{ requestBody: urlUtil{
URL: "this is really not a URL", URL: "this is really not a URL",
}, },
statusCode: http.StatusBadRequest, statusCode: http.StatusBadRequest,
@ -76,7 +76,7 @@ func TestCreateEntry(t *testing.T) {
if tc.ignoreResponse { if tc.ignoreResponse {
return return
} }
var parsed URLUtil var parsed urlUtil
if err := json.Unmarshal(respBody, &parsed); err != nil { if err := json.Unmarshal(respBody, &parsed); err != nil {
t.Fatalf("could not unmarshal data: %v", err) t.Fatalf("could not unmarshal data: %v", err)
} }
@ -96,7 +96,7 @@ func TestHandleInfo(t *testing.T) {
t.Fatalf("could not marshal json: %v", err) t.Fatalf("could not marshal json: %v", err)
} }
respBody := createEntryWithJSON(t, reqBody, "application/json; charset=utf-8", http.StatusOK) respBody := createEntryWithJSON(t, reqBody, "application/json; charset=utf-8", http.StatusOK)
var parsed URLUtil var parsed urlUtil
if err = json.Unmarshal(respBody, &parsed); err != nil { if err = json.Unmarshal(respBody, &parsed); err != nil {
t.Fatalf("could not unmarshal data: %v", err) t.Fatalf("could not unmarshal data: %v", err)
} }

15
handlers/utils.go

@ -1,15 +0,0 @@
package handlers
import (
"fmt"
"github.com/gin-gonic/gin"
)
func (h *Handler) getSchemaAndHost(c *gin.Context) string {
protocol := "http"
if c.Request.TLS != nil {
protocol = "https"
}
return fmt.Sprintf("%s://%s", protocol, c.Request.Host)
}

2
main.go

@ -41,7 +41,7 @@ func initShortener() (func(), error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not create store") return nil, errors.Wrap(err, "could not create store")
} }
handler, err := handlers.New(*store, false) handler, err := handlers.New(*store)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not create handlers") return nil, errors.Wrap(err, "could not create handlers")
} }

11
store/store.go

@ -29,6 +29,7 @@ type Entry struct {
Public EntryPublicData Public EntryPublicData
} }
// EntryPublicData is the public part of an entry
type EntryPublicData struct { type EntryPublicData struct {
CreatedOn, LastVisit time.Time CreatedOn, LastVisit time.Time
VisitCount int VisitCount int
@ -73,12 +74,12 @@ func (s *Store) GetEntryByID(id string) (*Entry, error) {
if id == "" { if id == "" {
return nil, ErrIDIsEmpty return nil, ErrIDIsEmpty
} }
raw, err := s.GetEntryByIDRaw(id) rawEntry, err := s.GetEntryByIDRaw(id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var entry *Entry var entry *Entry
return entry, json.Unmarshal(raw, &entry) return entry, json.Unmarshal(rawEntry, &entry)
} }
// IncreaseVisitCounter increments the visit counter of an entry // IncreaseVisitCounter increments the visit counter of an entry
@ -93,26 +94,24 @@ func (s *Store) IncreaseVisitCounter(id string) error {
if err != nil { if err != nil {
return err return err
} }
err = s.db.Update(func(tx *bolt.Tx) error { return s.db.Update(func(tx *bolt.Tx) error {
if err := tx.Bucket(s.bucketName).Put([]byte(id), raw); err != nil { if err := tx.Bucket(s.bucketName).Put([]byte(id), raw); err != nil {
return errors.Wrap(err, "could not put updated visitor count JSON into the bucket") return errors.Wrap(err, "could not put updated visitor count JSON into the bucket")
} }
return nil return nil
}) })
return err
} }
// GetEntryByIDRaw returns the raw data (JSON) of a data set // GetEntryByIDRaw returns the raw data (JSON) of a data set
func (s *Store) GetEntryByIDRaw(id string) ([]byte, error) { func (s *Store) GetEntryByIDRaw(id string) ([]byte, error) {
var raw []byte var raw []byte
err := s.db.View(func(tx *bolt.Tx) error { return raw, s.db.View(func(tx *bolt.Tx) error {
raw = tx.Bucket(s.bucketName).Get([]byte(id)) raw = tx.Bucket(s.bucketName).Get([]byte(id))
if raw == nil { if raw == nil {
return ErrNoEntryFound return ErrNoEntryFound
} }
return nil return nil
}) })
return raw, err
} }
// CreateEntry creates a new record and returns his short id // CreateEntry creates a new record and returns his short id

23
store/util.go

@ -13,10 +13,9 @@ import (
// createEntryRaw creates a entry with the given key value pair // createEntryRaw creates a entry with the given key value pair
func (s *Store) createEntryRaw(key, value []byte) error { func (s *Store) createEntryRaw(key, value []byte) error {
err := s.db.Update(func(tx *bolt.Tx) error { return s.db.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket(s.bucketName) bucket := tx.Bucket(s.bucketName)
raw := bucket.Get(key) if raw := bucket.Get(key); raw != nil {
if raw != nil {
return errors.New("entry already exists") return errors.New("entry already exists")
} }
if err := bucket.Put(key, value); err != nil { if err := bucket.Put(key, value); err != nil {
@ -24,27 +23,23 @@ func (s *Store) createEntryRaw(key, value []byte) error {
} }
return nil return nil
}) })
return err
} }
// createEntry creates a new entry // createEntry creates a new entry with a randomly generated id. If on is present
func (s *Store) createEntry(entry Entry, givenID string) (string, error) { // then the given ID is used
var id string func (s *Store) createEntry(entry Entry, entryID string) (string, error) {
var err error var err error
if givenID != "" { if entryID == "" {
id = givenID if entryID, err = generateRandomString(s.idLength); err != nil {
} else {
id, err = generateRandomString(s.idLength)
if err != nil {
return "", errors.Wrap(err, "could not generate random string") return "", errors.Wrap(err, "could not generate random string")
} }
} }
entry.Public.CreatedOn = time.Now() entry.Public.CreatedOn = time.Now()
raw, err := json.Marshal(entry) rawEntry, err := json.Marshal(entry)
if err != nil { if err != nil {
return "", err return "", err
} }
return id, s.createEntryRaw([]byte(id), raw) return entryID, s.createEntryRaw([]byte(entryID), rawEntry)
} }
// generateRandomString generates a random string with an predefined length // generateRandomString generates a random string with an predefined length

10
util/config.go

@ -11,8 +11,8 @@ import (
var ( var (
dataDirPath string dataDirPath string
// DoNotSetConfigName is used to predefine if the ConfigName should be set. // DoNotSetConfigName is used to predefine if the name of the config should be set.
// used for the unit testing reason // Used for the unit testing
DoNotSetConfigName = false DoNotSetConfigName = false
) )
@ -26,8 +26,7 @@ func ReadInConfig() error {
} }
viper.AddConfigPath(".") viper.AddConfigPath(".")
setConfigDefaults() setConfigDefaults()
err := viper.ReadInConfig() if err := viper.ReadInConfig(); err != nil {
if err != nil {
return errors.Wrap(err, "could not reload config file") return errors.Wrap(err, "could not reload config file")
} }
return checkForDatadir() return checkForDatadir()
@ -56,8 +55,7 @@ func checkForDatadir() error {
return errors.Wrap(err, "could not get relative data dir path") return errors.Wrap(err, "could not get relative data dir path")
} }
if _, err = os.Stat(dataDirPath); os.IsNotExist(err) { if _, err = os.Stat(dataDirPath); os.IsNotExist(err) {
err = os.MkdirAll(dataDirPath, 0755) if err = os.MkdirAll(dataDirPath, 0755); err != nil {
if err != nil {
return errors.Wrap(err, "could not create config directory") return errors.Wrap(err, "could not create config directory")
} }
} }

18
util/private.go

@ -11,28 +11,28 @@ import (
var privateKey []byte var privateKey []byte
// CheckForPrivateKey checks if already an private key exists, if not one will be randomly generated // CheckForPrivateKey checks if already an private key exists, if not it will be randomly generated
func CheckForPrivateKey() error { func CheckForPrivateKey() error {
privateDat := filepath.Join(GetDataDir(), "private.dat") privateDatPath := filepath.Join(GetDataDir(), "private.dat")
d, err := ioutil.ReadFile(privateDat) privateDatContent, err := ioutil.ReadFile(privateDatPath)
if err == nil { if err == nil {
privateKey = d privateKey = privateDatContent
} else if os.IsNotExist(err) { } else if os.IsNotExist(err) {
b := make([]byte, 256) randomGeneratedKey := make([]byte, 256)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(randomGeneratedKey); err != nil {
return errors.Wrap(err, "could not read random bytes") return errors.Wrap(err, "could not read random bytes")
} }
if err = ioutil.WriteFile(privateDat, b, 0644); err != nil { if err = ioutil.WriteFile(privateDatPath, randomGeneratedKey, 0644); err != nil {
return errors.Wrap(err, "could not write private key") return errors.Wrap(err, "could not write private key")
} }
privateKey = b privateKey = randomGeneratedKey
} else if err != nil { } else if err != nil {
return errors.Wrap(err, "could not read private key") return errors.Wrap(err, "could not read private key")
} }
return nil return nil
} }
// GetPrivateKey returns the private key from the memory // GetPrivateKey returns the private key from the loaded private key
func GetPrivateKey() []byte { func GetPrivateKey() []byte {
return privateKey return privateKey
} }

Loading…
Cancel
Save