From 70d8ddb5dc1f3c4e7123508cf103c86c35b2b70d Mon Sep 17 00:00:00 2001 From: Max Schmitt Date: Sun, 26 Nov 2017 03:10:08 +0100 Subject: [PATCH] Switched from viper to something selfmade: fix #44 --- build/config.yaml | 20 ++++-- handlers/auth.go | 16 +++-- handlers/auth/github.go | 5 +- handlers/auth/google.go | 5 +- handlers/auth/microsoft.go | 5 +- handlers/auth_test.go | 3 - handlers/handlers.go | 6 +- main.go | 4 +- main_test.go | 4 +- store/store.go | 5 +- store/store_test.go | 19 +++-- util/config.go | 138 +++++++++++++++++++++++++------------ util/config_test.go | 6 +- util/private.go | 2 +- util/private_test.go | 2 +- 15 files changed, 145 insertions(+), 95 deletions(-) diff --git a/build/config.yaml b/build/config.yaml index 63b35a9..f0af7fb 100644 --- a/build/config.yaml +++ b/build/config.yaml @@ -1,8 +1,14 @@ -listen_addr: ':8080' # Consists of 'IP:Port', e.g. ':8080' listens on any IP and on Port 8080 -base_url: 'http://localhost:3000' # Origin URL, required for the authentication via OAuth -data_dir: ./data # Contains: the database and the private key -enable_debug_mode: true # Activates more detailed logging -shorted_id_length: 4 # Length of the random generated ID which is used for new shortened URLs +ListenAddr: ':8080' # Consists of 'IP:Port', e.g. ':8080' listens on any IP and on Port 8080 +BaseURL: 'http://localhost:3000' # Origin URL, required for the authentication via OAuth +DataDir: ./data # Contains: the database and the private key +EnableDebugMode: true # Activates more detailed logging +ShortedIDLength: 10 # Length of the random generated ID which is used for new shortened URLs Google: - ClientID: replace me # ClientID which you get from console.cloud.google.com - ClientSecret: replace me # ClientSecret which get from console.cloud.google.com + ClientID: replace me + ClientSecret: replace me +GitHub: + ClientID: replace me + ClientSecret: replace me +Microsoft: + ClientID: replace me + ClientSecret: 'replace me' diff --git a/handlers/auth.go b/handlers/auth.go index 5794d5a..18f2d0b 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -6,7 +6,6 @@ import ( "github.com/maxibanki/golang-url-shortener/handlers/auth" "github.com/maxibanki/golang-url-shortener/util" "github.com/sirupsen/logrus" - "github.com/spf13/viper" jwt "github.com/dgrijalva/jwt-go" "github.com/gin-gonic/contrib/sessions" @@ -18,16 +17,19 @@ func (h *Handler) initOAuth() { h.engine.Use(sessions.Sessions("backend", sessions.NewCookieStore(util.GetPrivateKey()))) h.providers = []string{} - if viper.GetString("Google.ClientSecret") != "" { - auth.WithAdapterWrapper(auth.NewGoogleAdapter(viper.GetString("Google.ClientID"), viper.GetString("Google.ClientSecret"), viper.GetString("base_url")), h.engine.Group("/api/v1/auth/google")) + google := util.GetConfig().Google + if google.Enabled() { + auth.WithAdapterWrapper(auth.NewGoogleAdapter(google.ClientID, google.ClientSecret), h.engine.Group("/api/v1/auth/google")) h.providers = append(h.providers, "google") } - if viper.GetString("GitHub.ClientSecret") != "" { - auth.WithAdapterWrapper(auth.NewGithubAdapter(viper.GetString("GitHub.ClientID"), viper.GetString("GitHub.ClientSecret"), viper.GetString("base_url")), h.engine.Group("/api/v1/auth/github")) + github := util.GetConfig().GitHub + if github.Enabled() { + auth.WithAdapterWrapper(auth.NewGithubAdapter(github.ClientID, github.ClientSecret), h.engine.Group("/api/v1/auth/github")) h.providers = append(h.providers, "github") } - if viper.GetString("Microsoft.ClientSecret") != "" { - auth.WithAdapterWrapper(auth.NewMicrosoftAdapter(viper.GetString("Microsoft.ClientID"), viper.GetString("Microsoft.ClientSecret"), viper.GetString("base_url")), h.engine.Group("/api/v1/auth/microsoft")) + microsoft := util.GetConfig().Microsoft + if microsoft.Enabled() { + auth.WithAdapterWrapper(auth.NewMicrosoftAdapter(microsoft.ClientID, microsoft.ClientSecret), h.engine.Group("/api/v1/auth/microsoft")) h.providers = append(h.providers, "microsoft") } diff --git a/handlers/auth/github.go b/handlers/auth/github.go index 34ac2bd..c6e06c1 100644 --- a/handlers/auth/github.go +++ b/handlers/auth/github.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" + "github.com/maxibanki/golang-url-shortener/util" "github.com/sirupsen/logrus" "golang.org/x/oauth2/github" @@ -17,11 +18,11 @@ type githubAdapter struct { } // NewGithubAdapter creates an oAuth adapter out of the credentials and the baseURL -func NewGithubAdapter(clientID, clientSecret, baseURL string) Adapter { +func NewGithubAdapter(clientID, clientSecret string) Adapter { return &githubAdapter{&oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, - RedirectURL: baseURL + "/api/v1/auth/github/callback", + RedirectURL: util.GetConfig().BaseURL + "/api/v1/auth/github/callback", Scopes: []string{ "(no scope)", }, diff --git a/handlers/auth/google.go b/handlers/auth/google.go index b2a4a19..ab81b34 100644 --- a/handlers/auth/google.go +++ b/handlers/auth/google.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" + "github.com/maxibanki/golang-url-shortener/util" "github.com/pkg/errors" "golang.org/x/oauth2" "golang.org/x/oauth2/google" @@ -14,11 +15,11 @@ type googleAdapter struct { } // NewGoogleAdapter creates an oAuth adapter out of the credentials and the baseURL -func NewGoogleAdapter(clientID, clientSecret, baseURL string) Adapter { +func NewGoogleAdapter(clientID, clientSecret string) Adapter { return &googleAdapter{&oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, - RedirectURL: baseURL + "/api/v1/auth/google/callback", + RedirectURL: util.GetConfig().BaseURL + "/api/v1/auth/google/callback", Scopes: []string{ "https://www.googleapis.com/auth/userinfo.email", }, diff --git a/handlers/auth/microsoft.go b/handlers/auth/microsoft.go index 538f4ee..b9cfa99 100644 --- a/handlers/auth/microsoft.go +++ b/handlers/auth/microsoft.go @@ -7,6 +7,7 @@ import ( "golang.org/x/oauth2/microsoft" + "github.com/maxibanki/golang-url-shortener/util" "github.com/sirupsen/logrus" "github.com/pkg/errors" @@ -18,11 +19,11 @@ type microsoftAdapter struct { } // NewMicrosoftAdapter creates an oAuth adapter out of the credentials and the baseURL -func NewMicrosoftAdapter(clientID, clientSecret, baseURL string) Adapter { +func NewMicrosoftAdapter(clientID, clientSecret string) Adapter { return µsoftAdapter{&oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, - RedirectURL: baseURL + "/api/v1/auth/microsoft/callback", + RedirectURL: util.GetConfig().BaseURL + "/api/v1/auth/microsoft/callback", Scopes: []string{ "wl.basic", }, diff --git a/handlers/auth_test.go b/handlers/auth_test.go index 56baac4..b56cb24 100644 --- a/handlers/auth_test.go +++ b/handlers/auth_test.go @@ -15,7 +15,6 @@ import ( "github.com/maxibanki/golang-url-shortener/store" "github.com/maxibanki/golang-url-shortener/util" "github.com/pkg/errors" - "github.com/spf13/viper" ) var ( @@ -36,8 +35,6 @@ var ( func TestCreateBackend(t *testing.T) { secret = util.GetPrivateKey() - viper.SetConfigName("test") - util.DoNotSetConfigName = true if err := util.ReadInConfig(); err != nil { t.Fatalf("could not reload config file: %v", err) } diff --git a/handlers/handlers.go b/handlers/handlers.go index 8c6cb22..fddda81 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -8,7 +8,6 @@ import ( "github.com/gin-gonic/contrib/ginrus" "github.com/sirupsen/logrus" - "github.com/spf13/viper" "github.com/gin-gonic/gin" "github.com/maxibanki/golang-url-shortener/handlers/tmpls" @@ -30,7 +29,7 @@ var DoNotPrivateKeyChecking = false // New initializes the http handlers func New(store store.Store) (*Handler, error) { - if !viper.GetBool("enable_debug_mode") { + if !util.GetConfig().EnableDebugMode { gin.SetMode(gin.ReleaseMode) } h := &Handler{ @@ -82,13 +81,14 @@ func (h *Handler) setHandlers() error { h.engine.NoRoute(h.handleAccess, func(c *gin.Context) { c.Header("Vary", "Accept-Encoding") c.Header("Cache-Control", "public, max-age=2592000") + c.Header("ETag", util.VersionInfo["commit"]) }, gin.WrapH(http.FileServer(FS(false)))) return nil } // Listen starts the http server func (h *Handler) Listen() error { - return h.engine.Run(viper.GetString("listen_addr")) + return h.engine.Run(util.GetConfig().ListenAddr) } // CloseStore stops the http server and the closes the db gracefully diff --git a/main.go b/main.go index 9b41318..93847c6 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,6 @@ import ( "github.com/shiena/ansicolor" "github.com/sirupsen/logrus" - "github.com/spf13/viper" "github.com/maxibanki/golang-url-shortener/handlers" "github.com/maxibanki/golang-url-shortener/store" @@ -15,6 +14,7 @@ import ( ) func main() { + os.Setenv("GUS_SHORTED_ID_LENGTH", "4") stop := make(chan os.Signal, 1) signal.Notify(stop, os.Interrupt) logrus.SetFormatter(&logrus.TextFormatter{ @@ -34,7 +34,7 @@ func initShortener() (func(), error) { if err := util.ReadInConfig(); err != nil { return nil, errors.Wrap(err, "could not reload config file") } - if viper.GetBool("enable_debug_mode") { + if util.GetConfig().EnableDebugMode { logrus.SetLevel(logrus.DebugLevel) } store, err := store.New() diff --git a/main_test.go b/main_test.go index 92aa338..0dfc81e 100644 --- a/main_test.go +++ b/main_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/spf13/viper" + "github.com/maxibanki/golang-url-shortener/util" ) func TestInitShortener(t *testing.T) { @@ -15,7 +15,7 @@ func TestInitShortener(t *testing.T) { } time.Sleep(time.Millisecond * 200) // Give the http server a second to boot up // We expect there a port is in use error - if _, err := net.Listen("tcp", viper.GetString("listen_addr")); err == nil { + if _, err := net.Listen("tcp", util.GetConfig().ListenAddr); err == nil { t.Fatalf("port is not in use: %v", err) } close() diff --git a/store/store.go b/store/store.go index 119f019..12ba8f0 100644 --- a/store/store.go +++ b/store/store.go @@ -12,7 +12,6 @@ import ( "github.com/maxibanki/golang-url-shortener/util" "github.com/pborman/uuid" "github.com/sirupsen/logrus" - "github.com/spf13/viper" "github.com/asaskevich/govalidator" "github.com/boltdb/bolt" @@ -66,7 +65,7 @@ var ( // New initializes the store with the db func New() (*Store, error) { - db, err := bolt.Open(filepath.Join(util.GetDataDir(), "main.db"), 0644, &bolt.Options{Timeout: 1 * time.Second}) + db, err := bolt.Open(filepath.Join(util.GetConfig().DataDir, "main.db"), 0644, &bolt.Options{Timeout: 1 * time.Second}) if err != nil { return nil, errors.Wrap(err, "could not open bolt DB database") } @@ -79,7 +78,7 @@ func New() (*Store, error) { } return &Store{ db: db, - idLength: viper.GetInt("shorted_id_length"), + idLength: util.GetConfig().ShortedIDLength, }, nil } diff --git a/store/store_test.go b/store/store_test.go index bd9e7f2..fa9d519 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -1,20 +1,17 @@ package store import ( - "os" "strings" "testing" - "github.com/spf13/viper" -) - -const ( - testingDBName = "test.db" + "github.com/maxibanki/golang-url-shortener/util" ) func TestGenerateRandomString(t *testing.T) { - viper.SetDefault("data_dir", "./data") - viper.SetDefault("shorted_id_length", 4) + util.SetConfig(util.Configuration{ + DataDir: "./data", + ShortedIDLength: 4, + }) tt := []struct { name string length int @@ -41,6 +38,9 @@ func TestGenerateRandomString(t *testing.T) { func TestNewStore(t *testing.T) { t.Run("create store with correct arguments", func(r *testing.T) { + if err := util.ReadInConfig(); err != nil { + t.Fatalf("could not read in config: %v", err) + } store, err := New() if err != nil { t.Fatalf("unexpected error: %v", err) @@ -121,7 +121,6 @@ func TestIncreaseVisitCounter(t *testing.T) { } func TestDelete(t *testing.T) { - viper.Set("shorted_id_length", 4) store, err := New() if err != nil { t.Fatalf("could not create store: %v", err) @@ -144,7 +143,6 @@ func TestDelete(t *testing.T) { } func TestGetURLAndIncrease(t *testing.T) { - viper.Set("shorted_id_length", 4) store, err := New() if err != nil { t.Fatalf("could not create store: %v", err) @@ -181,5 +179,4 @@ func TestGetURLAndIncrease(t *testing.T) { func cleanup(s *Store) { s.Close() - os.Remove(testingDBName) } diff --git a/util/config.go b/util/config.go index 75c82b9..8930792 100644 --- a/util/config.go +++ b/util/config.go @@ -1,71 +1,119 @@ package util import ( + "io/ioutil" "os" "path/filepath" - "strings" + "reflect" + "strconv" "github.com/sirupsen/logrus" "github.com/pkg/errors" - "github.com/spf13/viper" + "gopkg.in/yaml.v2" ) +type Configuration struct { + ListenAddr string `yaml:"ListenAddr" env:"LISTEN_ADDR"` + BaseURL string `yaml:"BaseURL" env:"BASE_URL"` + DataDir string `yaml:"DataDir" env:"DATA_DIR"` + EnableDebugMode bool `yaml:"EnableDebugMode" env:"ENABLE_DEBUG_MODE"` + ShortedIDLength int `yaml:"ShortedIDLength" env:"SHORTED_ID_LENGTH"` + Google oAuthConf `yaml:"Google" env:"GOOGLE"` + GitHub oAuthConf `yaml:"GitHub" env:"GITHUB"` + Microsoft oAuthConf `yaml:"Microsoft" env:"MICROSOFT"` +} + +type oAuthConf struct { + ClientID string `yaml:"ClientID" env:"CLIENT_ID"` + ClientSecret string `yaml:"ClientSecret" env:"CLIENT_SECRET"` +} + var ( - dataDirPath string - // DoNotSetConfigName is used to predefine if the name of the config should be set. - // Used for unit testing - DoNotSetConfigName = false + config = Configuration{ + ListenAddr: ":8080", + BaseURL: "http://localhost:3000", + DataDir: "data", + EnableDebugMode: false, + ShortedIDLength: 4, + } ) -// ReadInConfig loads the configuration and other needed folders for further usage +// ReadInConfig loads the Configuration and other needed folders for further usage func ReadInConfig() error { - viper.AutomaticEnv() - viper.SetEnvPrefix("gus") - viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) - if !DoNotSetConfigName { - viper.SetConfigName("config") - } - viper.AddConfigPath(".") - setConfigDefaults() - switch err := viper.ReadInConfig(); err.(type) { - case viper.ConfigFileNotFoundError: - logrus.Info("No configuration file found, using defaults and environment overrides.") - break - case nil: - break - default: + file, err := ioutil.ReadFile("config.yaml") + if err == nil { + if err := yaml.Unmarshal(file, &config); err != nil { + return errors.Wrap(err, "could not unmarshal yaml file") + } + } else if !os.IsNotExist(err) { return errors.Wrap(err, "could not read config file") } - return checkForDatadir() -} - -// setConfigDefaults sets the default values for the configuration -func setConfigDefaults() { - viper.SetDefault("listen_addr", ":8080") - viper.SetDefault("base_url", "http://localhost:3000") - - viper.SetDefault("data_dir", "data") - viper.SetDefault("enable_debug_mode", true) - viper.SetDefault("shorted_id_length", 4) + if err := config.ApplyEnvironmentConfig(); err != nil { + return errors.Wrap(err, "could not apply environment configuration") + } + config.DataDir, err = filepath.Abs(config.DataDir) + if err != nil { + return errors.Wrap(err, "could not get relative data dir path") + } + if _, err = os.Stat(config.DataDir); os.IsNotExist(err) { + if err = os.MkdirAll(config.DataDir, 0755); err != nil { + return errors.Wrap(err, "could not create config directory") + } + } + return nil } -// GetDataDir returns the absolute path of the data directory -func GetDataDir() string { - return dataDirPath +func (c *Configuration) ApplyEnvironmentConfig() error { + return c.setDefaultValue(reflect.ValueOf(c), reflect.TypeOf(*c), reflect.StructField{}, "GUS") } -// checkForDatadir checks for the data dir and creates it if it not exists -func checkForDatadir() error { - var err error - dataDirPath, err = filepath.Abs(viper.GetString("data_dir")) - if err != nil { - return errors.Wrap(err, "could not get relative data dir path") +func (c *Configuration) setDefaultValue(v reflect.Value, t reflect.Type, f reflect.StructField, prefix string) error { + if v.Kind() != reflect.Ptr { + return errors.New("Not a pointer value") } - if _, err = os.Stat(dataDirPath); os.IsNotExist(err) { - if err = os.MkdirAll(dataDirPath, 0755); err != nil { - return errors.Wrap(err, "could not create config directory") + v = reflect.Indirect(v) + fieldEnv, exists := f.Tag.Lookup("env") + env := os.Getenv(prefix + fieldEnv) + if exists && env != "" { + switch v.Kind() { + case reflect.Int: + envI, err := strconv.Atoi(env) + if err != nil { + logrus.Warningf("could not parse to int: %v", err) + break + } + v.SetInt(int64(envI)) + case reflect.String: + v.SetString(env) + case reflect.Bool: + envB, err := strconv.ParseBool(env) + if err != nil { + logrus.Warningf("could not parse to bool: %v", err) + break + } + v.SetBool(envB) + } + } + if v.Kind() == reflect.Struct { + // Iterate over the struct fields + for i := 0; i < v.NumField(); i++ { + if err := c.setDefaultValue(v.Field(i).Addr(), t, t.Field(i), prefix+fieldEnv+"_"); err != nil { + return err + } } } return nil } + +func (o oAuthConf) Enabled() bool { + return o.ClientSecret != "" +} + +func GetConfig() Configuration { + return config +} + +func SetConfig(c Configuration) { + config = c +} diff --git a/util/config_test.go b/util/config_test.go index 83e607f..17c6d19 100644 --- a/util/config_test.go +++ b/util/config_test.go @@ -2,14 +2,12 @@ package util import ( "testing" - - "github.com/spf13/viper" ) func TestReadInConfig(t *testing.T) { - DoNotSetConfigName = true - viper.SetConfigFile("test.yaml") if err := ReadInConfig(); err != nil { t.Fatalf("could not read in config file: %v", err) } + config := config + config.DataDir = "./test" } diff --git a/util/private.go b/util/private.go index d459498..1f9c4c5 100644 --- a/util/private.go +++ b/util/private.go @@ -14,7 +14,7 @@ var privateKey []byte // CheckForPrivateKey checks if already an private key exists, if not it will // be randomly generated and saved as a private.dat file in the data directory func CheckForPrivateKey() error { - privateDatPath := filepath.Join(GetDataDir(), "private.dat") + privateDatPath := filepath.Join(config.DataDir, "private.dat") privateDatContent, err := ioutil.ReadFile(privateDatPath) if err == nil { privateKey = privateDatContent diff --git a/util/private_test.go b/util/private_test.go index 8ad134f..96033b2 100644 --- a/util/private_test.go +++ b/util/private_test.go @@ -14,7 +14,7 @@ func TestCheckforPrivateKey(t *testing.T) { if GetPrivateKey() == nil { t.Fatalf("private key is nil") } - if err := os.RemoveAll(GetDataDir()); err != nil { + if err := os.RemoveAll(config.DataDir); err != nil { t.Fatalf("could not remove data dir: %v", err) } }