Browse Source

Cleaned up code and integrated deletion handling: fix #10

dependabot/npm_and_yarn/web/prismjs-1.21.0
Max Schmitt 8 years ago
parent
commit
6a4eb7302f
  1. 2
      .gitignore
  2. 4
      README.md
  3. 4
      build/info.sh
  4. 18
      handlers/auth.go
  5. 29
      handlers/auth/auth.go
  6. 3
      handlers/handlers.go
  7. 51
      handlers/public.go
  8. 5
      static/package.json
  9. 7
      static/src/About/About.js
  10. 10
      static/src/Card/Card.js
  11. 5
      static/src/Home/Home.js
  12. 2
      static/src/Lookup/Lookup.js
  13. 3
      static/src/ShareX/ShareX.js
  14. 35
      static/src/index.js
  15. 62
      store/store.go
  16. 16
      store/store_test.go
  17. 16
      store/util.go
  18. 2
      util/config.go
  19. 5
      util/private.go
  20. 2
      util/util.go

2
.gitignore

@ -14,6 +14,8 @@
.glide/
debug
*.db
*.lock
/config.*
/handlers/static.go
/handlers/tmpls/tmpls.go

4
README.md

@ -14,7 +14,7 @@
- Visitor Counting
- Expirable Links
- URL deletion
- Authorization System via OAuth 2.0 (Google, GitHub and Micrsoft)
- Authorization System via OAuth 2.0 (Google, GitHub and Microsoft)
- High performance database with [bolt](https://github.com/boltdb/bolt)
- Easy [ShareX](https://github.com/ShareX/ShareX) integration
- Dockerizable
@ -36,7 +36,7 @@
## Why did you built this
Just only because I want to extend my current self hosted URL shorter (which was really messy code) with some more features and learn about new techniques like:
Only because I just want to extend my current self hosted URL shorter (which was really messy code) with some more features and learn about new techniques like:
- Golang unit testing
- React

4
build/info.sh

@ -1,10 +1,12 @@
cat > util/info.go <<EOL
package util
// VersionInfo contains the generated information which is
// done at build time and used for the frontend page About
var VersionInfo = map[string]string{
"nodeJS": "`node --version`",
"commit": "`git rev-parse HEAD`",
"compilationTime": "`date`",
"compilationTime": "`date --iso-8601=seconds`",
"yarn": "`yarn --version`",
}
EOL

18
handlers/auth.go

@ -51,11 +51,11 @@ func (h *Handler) authMiddleware(c *gin.Context) {
authError := func() error {
wt := c.GetHeader("Authorization")
if wt == "" {
return errors.New("'Authorization' header not set")
return errors.New("Authorization header not set")
}
claims, err := h.parseJWT(wt)
if err != nil {
return err
return errors.Wrap(err, "could not parse JWT")
}
c.Set("user", claims)
return nil
@ -64,7 +64,7 @@ func (h *Handler) authMiddleware(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "authentication failed",
})
logrus.Debugf("Authentication middleware failed: %v\n", authError)
logrus.Debugf("Authentication middleware check failed: %v\n", authError)
return
}
c.Next()
@ -75,12 +75,12 @@ func (h *Handler) handleAuthCheck(c *gin.Context) {
Token string `binding:"required"`
}
if err := c.ShouldBind(&data); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
claims, err := h.parseJWT(data.Token)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
@ -90,3 +90,11 @@ func (h *Handler) handleAuthCheck(c *gin.Context) {
"Provider": claims.OAuthProvider,
})
}
func (h *Handler) oAuthPropertiesEquals(c *gin.Context, oauthID, oauthProvider string) bool {
user := c.MustGet("user").(*auth.JWTClaims)
if oauthID == user.OAuthID && oauthProvider == user.OAuthProvider {
return true
}
return false
}

29
handlers/auth/auth.go

@ -12,6 +12,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/maxibanki/golang-url-shortener/util"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// Adapter will be implemented by each oAuth provider
@ -49,10 +50,18 @@ func WithAdapterWrapper(a Adapter, h *gin.RouterGroup) *AdapterWrapper {
// HandleLogin handles the incoming http request for the oAuth process
// and redirects to the generated URL of the provider
func (a *AdapterWrapper) HandleLogin(c *gin.Context) {
state := a.randToken()
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
http.Error(c.Writer, fmt.Sprintf("could not read random state: %v", err), http.StatusInternalServerError)
return
}
state := base64.RawURLEncoding.EncodeToString(b)
session := sessions.Default(c)
session.Set("state", state)
session.Save()
if err := session.Save(); err != nil {
http.Error(c.Writer, fmt.Sprintf("could not save state to session: %v", err), http.StatusInternalServerError)
return
}
c.Redirect(http.StatusTemporaryRedirect, a.GetRedirectURL(state))
}
@ -62,17 +71,21 @@ func (a *AdapterWrapper) HandleCallback(c *gin.Context) {
sessionState := session.Get("state")
receivedState := c.Query("state")
if sessionState != receivedState {
c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("invalid session state: %s", sessionState)})
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("invalid session state: %s", sessionState)})
return
}
user, err := a.GetUserData(receivedState, c.Query("code"))
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
logrus.WithFields(logrus.Fields{
"Provider": a.GetOAuthProviderName(),
"Name": user.Name,
}).Info("New user logged in via oAuth")
token, err := a.newJWT(user, a.GetOAuthProviderName())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.HTML(http.StatusOK, "token.tmpl", gin.H{
@ -96,9 +109,3 @@ func (a *AdapterWrapper) newJWT(user *user, provider string) (string, error) {
}
return tokenString, nil
}
func (a *AdapterWrapper) randToken() string {
b := make([]byte, 32)
rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
}

3
handlers/handlers.go

@ -42,7 +42,7 @@ func New(store store.Store) (*Handler, error) {
}
if !DoNotPrivateKeyChecking {
if err := util.CheckForPrivateKey(); err != nil {
return nil, errors.Wrap(err, "could not check for privat key")
return nil, errors.Wrap(err, "could not check for private key")
}
}
h.initOAuth()
@ -73,6 +73,7 @@ func (h *Handler) setHandlers() error {
protected.POST("/lookup", h.handleLookup)
h.engine.GET("/api/v1/info", h.handleInfo)
h.engine.GET("/d/:id/:hash", h.handleDelete)
// Handling the shorted URLs, if no one exists, it checks
// in the filesystem and sets headers for caching

51
handlers/public.go

@ -3,6 +3,7 @@ package handlers
import (
"fmt"
"net/http"
"net/url"
"runtime"
"time"
@ -16,8 +17,8 @@ import (
// un- and marshalling
type urlUtil struct {
URL string `binding:"required"`
ID string
Expiration time.Time
ID, DeletionURL string
Expiration *time.Time `json:",omitempty"`
}
// handleLookup is the http handler for getting the infos
@ -26,16 +27,15 @@ func (h *Handler) handleLookup(c *gin.Context) {
ID string `binding:"required"`
}
if err := c.ShouldBind(&data); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
entry, err := h.store.GetEntryByID(data.ID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
user := c.MustGet("user").(*auth.JWTClaims)
if entry.OAuthID != user.OAuthID || entry.OAuthProvider != user.OAuthProvider {
if !h.oAuthPropertiesEquals(c, entry.OAuthID, entry.OAuthProvider) {
c.JSON(http.StatusOK, store.Entry{
Public: store.EntryPublicData{
URL: entry.Public.URL,
@ -48,25 +48,14 @@ func (h *Handler) handleLookup(c *gin.Context) {
// handleAccess handles the access for incoming requests
func (h *Handler) handleAccess(c *gin.Context) {
var id string
if len(c.Request.URL.Path) > 1 {
id = c.Request.URL.Path[1:]
}
entry, err := h.store.GetEntryByID(id)
if err == store.ErrIDIsEmpty || err == store.ErrNoEntryFound {
url, err := h.store.GetURLAndIncrease(c.Request.URL.Path[1:])
if err == store.ErrNoEntryFound {
return
} else if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
if time.Now().After(entry.Public.Expiration) && !entry.Public.Expiration.IsZero() {
return
}
if err := h.store.IncreaseVisitCounter(id); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
http.Error(c.Writer, fmt.Sprintf("could not get and crease visitor counter: %v, ", err), http.StatusInternalServerError)
return
}
c.Redirect(http.StatusTemporaryRedirect, entry.Public.URL)
c.Redirect(http.StatusTemporaryRedirect, url)
}
// handleCreate handles requests to create an entry
@ -77,7 +66,7 @@ func (h *Handler) handleCreate(c *gin.Context) {
return
}
user := c.MustGet("user").(*auth.JWTClaims)
id, err := h.store.CreateEntry(store.Entry{
id, delID, err := h.store.CreateEntry(store.Entry{
Public: store.EntryPublicData{
URL: data.URL,
Expiration: data.Expiration,
@ -90,8 +79,11 @@ func (h *Handler) handleCreate(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
data.URL = h.getSchemaAndHost(c) + "/" + id
c.JSON(http.StatusOK, data)
originURL := h.getURLOrigin(c)
c.JSON(http.StatusOK, urlUtil{
URL: fmt.Sprintf("%s/%s", originURL, id),
DeletionURL: fmt.Sprintf("%s/d/%s/%s", originURL, id, url.QueryEscape(delID)),
})
}
func (h *Handler) handleInfo(c *gin.Context) {
@ -104,8 +96,17 @@ func (h *Handler) handleInfo(c *gin.Context) {
}
c.JSON(http.StatusOK, info)
}
func (h *Handler) handleDelete(c *gin.Context) {
if err := h.store.DeleteEntry(c.Param("id"), c.Param("hash")); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
})
}
func (h *Handler) getSchemaAndHost(c *gin.Context) string {
func (h *Handler) getURLOrigin(c *gin.Context) string {
protocol := "http"
if c.Request.TLS != nil {
protocol = "https"

5
static/package.json

@ -16,12 +16,13 @@
"react-dom": "^16.1.1",
"react-prism": "^4.3.1",
"react-qr-svg": "^2.1.0",
"react-responsive": "^4.0.0",
"react-responsive": "^4.0.2",
"react-router": "^4.2.0",
"react-router-dom": "^4.2.2",
"react-scripts": "1.0.17",
"semantic-ui-css": "^2.2.12",
"semantic-ui-react": "^0.76.0"
"semantic-ui-react": "^0.76.0",
"toastr": "^2.1.2"
},
"scripts": {
"start": "react-scripts start",

7
static/src/About/About.js

@ -1,12 +1,13 @@
import React, { Component } from 'react'
import { Container, Table } from 'semantic-ui-react'
import moment from 'moment'
export default class AboutComponent extends Component {
state = {
info: null
}
componentWillMount() {
fetch("/api/v1/info").then(res => res.json()).then(d => this.setState({ info: d }))
componentWillReceiveProps = () => {
this.setState({ info: this.props.info })
}
render() {
const { info } = this.state
@ -31,7 +32,7 @@ export default class AboutComponent extends Component {
</Table.Row>
<Table.Row>
<Table.Cell>Compilation Time</Table.Cell>
<Table.Cell>{info.compilationTime}</Table.Cell>
<Table.Cell>{moment(info.compilationTime).fromNow()} ({info.compilationTime})</Table.Cell>
</Table.Row>
<Table.Row>
<Table.Cell>Commit Hash</Table.Cell>

10
static/src/Card/Card.js

@ -2,6 +2,7 @@ import React, { Component } from 'react'
import { Card, Icon, Button, Modal } from 'semantic-ui-react'
import { QRCode } from 'react-qr-svg';
import Clipboard from 'react-clipboard.js';
import toastr from 'toastr'
export default class CardComponent extends Component {
state = {
@ -15,6 +16,12 @@ export default class CardComponent extends Component {
}, 500)
}
}
onDeletonLinkCopy() {
toastr.info('Copied the deletion URL to the Clipboard')
}
onShortedURLSuccess() {
toastr.info('Copied the shorted URL to the Clipboard')
}
render() {
const { expireDate } = this.state
return (<Card key={this.key}>
@ -30,6 +37,7 @@ export default class CardComponent extends Component {
</Card.Meta>
<Card.Description>
{this.props.description}
{this.props.deletionURL && <Clipboard component="i" className="trash link icon" style={{ float: "right" }} button-title="Copy the deletion URL to the clipboard" data-clipboard-text={this.props.deletionURL} onSuccess={this.onDeletonLinkCopy} />}
</Card.Description>
</Card.Content>
<Card.Content extra>
@ -40,7 +48,7 @@ export default class CardComponent extends Component {
<QRCode style={{ width: '75%' }} value={this.props.description} />
</Modal.Content>
</Modal>
<Clipboard component='button' className='ui button' data-clipboard-text={this.props.description} button-title='Copy the Shortened URL to the Clipboard'>
<Clipboard component='button' className='ui button' data-clipboard-text={this.props.description} onSuccess={this.onShortedURLSuccess} button-title='Copy the Shortened URL to the Clipboard'>
<div>
<Icon name='clipboard' />
Copy to Clipboard

5
static/src/Home/Home.js

@ -61,7 +61,8 @@ export default class HomeComponent extends Component {
links: [...this.state.links, [
r.URL,
this.url,
this.state.setOptions.indexOf("expire") > -1 ? this.state.expiration : undefined
this.state.setOptions.indexOf("expire") > -1 ? this.state.expiration : undefined,
r.DeletionURL
]]
}))
}
@ -102,7 +103,7 @@ export default class HomeComponent extends Component {
</Form>
</Segment>
<Card.Group itemsPerRow="2" stackable style={{ marginTop: "1rem" }}>
{links.map((link, i) => <CustomCard key={i} header={new URL(link[1]).hostname} expireDate={link[2]} metaHeader={link[1]} description={link[0]} />)}
{links.map((link, i) => <CustomCard key={i} header={new URL(link[1]).hostname} expireDate={link[2]} metaHeader={link[1]} description={link[0]} deletionURL={link[3]}/>)}
</Card.Group>
</div >
)

2
static/src/Lookup/Lookup.js

@ -28,7 +28,7 @@ export default class LookupComponent extends Component {
this.VisitCount,
res.CratedOn,
res.LastVisit,
moment(res.Expiration)
res.Expiration ? moment(res.Expiration) : null
]]
}))
}

3
static/src/ShareX/ShareX.js

@ -22,7 +22,8 @@ export default class ShareXComponent extends Component {
Authorization: window.localStorage.getItem('token')
},
ResponseType: "Text",
URL: "$json:URL$"
URL: "$json:URL$",
DeletionURL: "$json:DeletionURL$"
}, null, 4),
currentStep: 0,
availableSteps: [

35
static/src/index.js

@ -2,7 +2,9 @@ import React, { Component } from 'react'
import ReactDOM from 'react-dom';
import { HashRouter, Route, Link } from 'react-router-dom'
import { Menu, Container, Modal, Button, Image, Icon } from 'semantic-ui-react'
import toastr from 'toastr'
import 'semantic-ui-css/semantic.min.css';
import 'toastr/build/toastr.css';
import About from './About/About'
import Home from './Home/Home'
@ -15,7 +17,7 @@ export default class BaseComponent extends Component {
userData: {},
authorized: false,
activeItem: "",
providers: []
info: null
}
onOAuthClose() {
@ -30,7 +32,10 @@ export default class BaseComponent extends Component {
}
loadInfo = () => {
fetch('/api/v1/info').then(d => d.json()).then(d => this.setState({ providers: d.providers }))
fetch('/api/v1/info')
.then(d => d.json())
.then(info => this.setState({ info }))
.catch(e => toastr.error(e))
}
checkAuth = () => {
@ -45,12 +50,14 @@ export default class BaseComponent extends Component {
headers: {
'Content-Type': 'application/json'
}
}).then(res => res.ok ? res.json() : Promise.reject(res.json())) // Check if the request was StatusOK, otherwise reject Promise
})
.then(res => res.ok ? res.json() : Promise.reject(`incorrect response status code: ${res.status}; text: ${res.statusText}`))
.then(d => {
that.setState({ userData: d })
that.setState({ authorized: true })
})
.catch(e => {
toastr.error(`Could not fetch info: ${e}`)
window.localStorage.removeItem('token');
that.setState({ authorized: false })
})
@ -69,7 +76,7 @@ export default class BaseComponent extends Component {
onOAuthClick = provider => {
window.addEventListener('message', this.onOAuthCallback, false);
var url = `${window.location.origin}/api/v1/auth/${provider}/login`;
if (!this._oAuthPopup) {
if (!this._oAuthPopup || this._oAuthPopup.closed) {
// Open the oAuth window that is it centered in the middle of the screen
var wwidth = 400,
wHeight = 500;
@ -87,7 +94,7 @@ export default class BaseComponent extends Component {
}
render() {
const { open, authorized, activeItem, userData, providers } = this.state
const { open, authorized, activeItem, userData, info } = this.state
if (!authorized) {
return (
<Modal size='tiny' open={open} onClose={this.onOAuthClose}>
@ -96,26 +103,26 @@ export default class BaseComponent extends Component {
</Modal.Header>
<Modal.Content>
<p>The following authentication services are currently available:</p>
<div className='ui center aligned segment'>
{providers.length === 0 && <p>There are currently no correct oAuth credentials maintained.</p>}
{providers.indexOf("google") !== -1 && <div>
{info && <div className='ui center aligned segment'>
{info.providers.length === 0 && <p>There are currently no correct oAuth credentials maintained.</p>}
{info.providers.indexOf("google") !== -1 && <div>
<Button className='ui google plus button' onClick={this.onOAuthClick.bind(this, "google")}>
<Icon name='google' /> Login with Google
</Button>
{providers.indexOf("github") !== -1 && <div className="ui divider"></div>}
{info.providers.indexOf("github") !== -1 && <div className="ui divider"></div>}
</div>}
{providers.indexOf("github") !== -1 && <div>
{info.providers.indexOf("github") !== -1 && <div>
<Button style={{ backgroundColor: "#333", color: "white" }} onClick={this.onOAuthClick.bind(this, "github")}>
<Icon name='github' /> Login with GitHub
</Button>
</div>}
{providers.indexOf("microsoft") !== -1 && <div>
{info.providers.indexOf("microsoft") !== -1 && <div>
<div className="ui divider"></div>
<Button style={{ backgroundColor: "#0067b8", color: "white" }} onClick={this.onOAuthClick.bind(this, "microsoft")}>
<Icon name='windows' /> Login with Microsoft
</Button>
</div>}
</div>
</div>}
</Modal.Content>
</Modal >
)
@ -147,7 +154,9 @@ export default class BaseComponent extends Component {
</Menu.Menu>
</Menu>
<Route exact path="/" component={Home} />
<Route path="/about" component={About} />
<Route path="/about" render={(props) => (
<About info={info} />
)} />
<Route path="/ShareX" component={ShareX} />
<Route path="/Lookup" component={Lookup} />
</Container>

62
store/store.go

@ -2,6 +2,9 @@
package store
import (
"crypto/hmac"
"crypto/sha512"
"encoding/base64"
"encoding/json"
"path/filepath"
"time"
@ -31,7 +34,8 @@ type Entry struct {
// EntryPublicData is the public part of an entry
type EntryPublicData struct {
CreatedOn, LastVisit, Expiration time.Time
CreatedOn, LastVisit time.Time
Expiration *time.Time `json:",omitempty"`
VisitCount int
URL string
}
@ -45,8 +49,8 @@ var ErrNoValidURL = errors.New("the given URL is no valid URL")
// ErrGeneratingIDFailed is returned when the 10 tries to generate an id failed
var ErrGeneratingIDFailed = errors.New("could not generate unique id, all ten tries failed")
// ErrIDIsEmpty is returned when the given ID is empty
var ErrIDIsEmpty = errors.New("the given ID is empty")
// ErrEntryIsExpired is returned when the entry is expired
var ErrEntryIsExpired = errors.New("entry is expired")
// New initializes the store with the db
func New() (*Store, error) {
@ -72,7 +76,7 @@ func New() (*Store, error) {
// GetEntryByID returns a unmarshalled entry of the db by a given ID
func (s *Store) GetEntryByID(id string) (*Entry, error) {
if id == "" {
return nil, ErrIDIsEmpty
return nil, ErrNoEntryFound
}
rawEntry, err := s.GetEntryByIDRaw(id)
if err != nil {
@ -102,6 +106,22 @@ func (s *Store) IncreaseVisitCounter(id string) error {
})
}
// GetURLAndIncrease Increases the visitor count, checks
// if the URL is expired and returns the origin URL
func (s *Store) GetURLAndIncrease(id string) (string, error) {
entry, err := s.GetEntryByID(id)
if err != nil {
return "", err
}
if entry.Public.Expiration != nil && time.Now().After(*entry.Public.Expiration) {
return "", ErrEntryIsExpired
}
if err := s.IncreaseVisitCounter(id); err != nil {
return "", errors.Wrap(err, "could not increase visitor counter")
}
return entry.Public.URL, nil
}
// GetEntryByIDRaw returns the raw data (JSON) of a data set
func (s *Store) GetEntryByIDRaw(id string) ([]byte, error) {
var raw []byte
@ -115,22 +135,44 @@ func (s *Store) GetEntryByIDRaw(id string) ([]byte, error) {
}
// CreateEntry creates a new record and returns his short id
func (s *Store) CreateEntry(entry Entry, givenID string) (string, error) {
func (s *Store) CreateEntry(entry Entry, givenID string) (string, string, error) {
if !govalidator.IsURL(entry.Public.URL) {
return "", ErrNoValidURL
return "", "", ErrNoValidURL
}
// try it 10 times to make a short URL
for i := 1; i <= 10; i++ {
id, err := s.createEntry(entry, givenID)
id, delID, err := s.createEntry(entry, givenID)
if err != nil && givenID != "" {
return "", err
return "", "", err
} else if err != nil {
logrus.Debugf("Could not create entry: %v", err)
continue
}
return id, nil
return id, delID, nil
}
return "", ErrGeneratingIDFailed
return "", "", ErrGeneratingIDFailed
}
// DeleteEntry deletes an Entry fully from the DB
func (s *Store) DeleteEntry(id, hash string) error {
mac := hmac.New(sha512.New, util.GetPrivateKey())
if _, err := mac.Write([]byte(id)); err != nil {
return errors.Wrap(err, "could not write hmac")
}
givenHmac, err := base64.RawURLEncoding.DecodeString(hash)
if err != nil {
return errors.Wrap(err, "could not decode base64")
}
if !hmac.Equal(mac.Sum(nil), givenHmac) {
return errors.New("hmac verification failed")
}
return s.db.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket(s.bucketName)
if bucket.Get([]byte(id)) == nil {
return errors.New("entry already deleted")
}
return bucket.Delete([]byte(id))
})
}
// Close closes the bolt db database

16
store/store_test.go

@ -2,6 +2,7 @@ package store
import (
"os"
"strings"
"testing"
"github.com/spf13/viper"
@ -54,12 +55,12 @@ func TestCreateEntry(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
defer cleanup(store)
_, err = store.CreateEntry(Entry{}, "")
_, _, err = store.CreateEntry(Entry{}, "")
if err != ErrNoValidURL {
t.Fatalf("unexpected error: %v", err)
}
for i := 1; i <= 100; i++ {
_, err := store.CreateEntry(Entry{
_, _, err := store.CreateEntry(Entry{
Public: EntryPublicData{
URL: "https://golang.org/",
},
@ -81,8 +82,8 @@ func TestGetEntryByID(t *testing.T) {
t.Fatalf("could not get expected '%v' error: %v", ErrNoEntryFound, err)
}
_, err = store.GetEntryByID("")
if err != ErrIDIsEmpty {
t.Fatalf("could not get expected '%v' error: %v", ErrIDIsEmpty, err)
if err != ErrNoEntryFound {
t.Fatalf("could not get expected '%v' error: %v", ErrNoEntryFound, err)
}
}
@ -92,7 +93,7 @@ func TestIncreaseVisitCounter(t *testing.T) {
t.Fatalf("could not create store: %v", err)
}
defer cleanup(store)
id, err := store.CreateEntry(Entry{
id, _, err := store.CreateEntry(Entry{
Public: EntryPublicData{
URL: "https://golang.org/",
},
@ -114,9 +115,8 @@ func TestIncreaseVisitCounter(t *testing.T) {
if entryBeforeInc.Public.VisitCount+1 != entryAfterInc.Public.VisitCount {
t.Fatalf("the increasement was not successful, the visit count is not correct")
}
errIDIsEmpty := "could not get entry by ID: the given ID is empty"
if err = store.IncreaseVisitCounter(""); err.Error() != errIDIsEmpty {
t.Fatalf("could not get expected '%v'; got: %v", errIDIsEmpty, err)
if err = store.IncreaseVisitCounter(""); !strings.Contains(err.Error(), ErrNoEntryFound.Error()) {
t.Fatalf("could not get expected '%v'; got: %v", ErrNoEntryFound, err)
}
}

16
store/util.go

@ -1,13 +1,17 @@
package store
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha512"
"encoding/base64"
"encoding/json"
"math/big"
"time"
"unicode"
"github.com/boltdb/bolt"
"github.com/maxibanki/golang-url-shortener/util"
"github.com/pkg/errors"
)
@ -27,19 +31,23 @@ func (s *Store) createEntryRaw(key, value []byte) error {
// createEntry creates a new entry with a randomly generated id. If on is present
// then the given ID is used
func (s *Store) createEntry(entry Entry, entryID string) (string, error) {
func (s *Store) createEntry(entry Entry, entryID string) (string, string, error) {
var err error
if entryID == "" {
if entryID, err = generateRandomString(s.idLength); err != nil {
return "", errors.Wrap(err, "could not generate random string")
return "", "", errors.Wrap(err, "could not generate random string")
}
}
entry.Public.CreatedOn = time.Now()
rawEntry, err := json.Marshal(entry)
if err != nil {
return "", err
return "", "", err
}
mac := hmac.New(sha512.New, util.GetPrivateKey())
if _, err := mac.Write([]byte(entryID)); err != nil {
return "", "", errors.Wrap(err, "could not write hmac")
}
return entryID, s.createEntryRaw([]byte(entryID), rawEntry)
return entryID, base64.RawURLEncoding.EncodeToString(mac.Sum(nil)), s.createEntryRaw([]byte(entryID), rawEntry)
}
// generateRandomString generates a random string with an predefined length

2
util/config.go

@ -14,7 +14,7 @@ import (
var (
dataDirPath string
// DoNotSetConfigName is used to predefine if the name of the config should be set.
// Used for the unit testing
// Used for unit testing
DoNotSetConfigName = false
)

5
util/private.go

@ -11,7 +11,8 @@ import (
var privateKey []byte
// CheckForPrivateKey checks if already an private key exists, if not it will be randomly generated
// CheckForPrivateKey checks if already an private key exists, if not it will
// be randomly generated and saved as a private.dat file in the data directory
func CheckForPrivateKey() error {
privateDatPath := filepath.Join(GetDataDir(), "private.dat")
privateDatContent, err := ioutil.ReadFile(privateDatPath)
@ -32,7 +33,7 @@ func CheckForPrivateKey() error {
return nil
}
// GetPrivateKey returns the private key from the loaded private key
// GetPrivateKey returns the private key
func GetPrivateKey() []byte {
return privateKey
}

2
util/util.go

@ -1,2 +0,0 @@
// Package util implements helper functions for the complete Golang URL Shortener app
package util
Loading…
Cancel
Save