15 changed files with 145 additions and 95 deletions
@ -1,8 +1,14 @@ |
|||
listen_addr: ':8080' # Consists of 'IP:Port', e.g. ':8080' listens on any IP and on Port 8080 |
|||
base_url: 'http://localhost:3000' # Origin URL, required for the authentication via OAuth |
|||
data_dir: ./data # Contains: the database and the private key |
|||
enable_debug_mode: true # Activates more detailed logging |
|||
shorted_id_length: 4 # Length of the random generated ID which is used for new shortened URLs |
|||
ListenAddr: ':8080' # Consists of 'IP:Port', e.g. ':8080' listens on any IP and on Port 8080 |
|||
BaseURL: 'http://localhost:3000' # Origin URL, required for the authentication via OAuth |
|||
DataDir: ./data # Contains: the database and the private key |
|||
EnableDebugMode: true # Activates more detailed logging |
|||
ShortedIDLength: 10 # Length of the random generated ID which is used for new shortened URLs |
|||
Google: |
|||
ClientID: replace me # ClientID which you get from console.cloud.google.com |
|||
ClientSecret: replace me # ClientSecret which get from console.cloud.google.com |
|||
ClientID: replace me |
|||
ClientSecret: replace me |
|||
GitHub: |
|||
ClientID: replace me |
|||
ClientSecret: replace me |
|||
Microsoft: |
|||
ClientID: replace me |
|||
ClientSecret: 'replace me' |
|||
|
|||
@ -1,71 +1,119 @@ |
|||
package util |
|||
|
|||
import ( |
|||
"io/ioutil" |
|||
"os" |
|||
"path/filepath" |
|||
"strings" |
|||
"reflect" |
|||
"strconv" |
|||
|
|||
"github.com/sirupsen/logrus" |
|||
|
|||
"github.com/pkg/errors" |
|||
"github.com/spf13/viper" |
|||
"gopkg.in/yaml.v2" |
|||
) |
|||
|
|||
type Configuration struct { |
|||
ListenAddr string `yaml:"ListenAddr" env:"LISTEN_ADDR"` |
|||
BaseURL string `yaml:"BaseURL" env:"BASE_URL"` |
|||
DataDir string `yaml:"DataDir" env:"DATA_DIR"` |
|||
EnableDebugMode bool `yaml:"EnableDebugMode" env:"ENABLE_DEBUG_MODE"` |
|||
ShortedIDLength int `yaml:"ShortedIDLength" env:"SHORTED_ID_LENGTH"` |
|||
Google oAuthConf `yaml:"Google" env:"GOOGLE"` |
|||
GitHub oAuthConf `yaml:"GitHub" env:"GITHUB"` |
|||
Microsoft oAuthConf `yaml:"Microsoft" env:"MICROSOFT"` |
|||
} |
|||
|
|||
type oAuthConf struct { |
|||
ClientID string `yaml:"ClientID" env:"CLIENT_ID"` |
|||
ClientSecret string `yaml:"ClientSecret" env:"CLIENT_SECRET"` |
|||
} |
|||
|
|||
var ( |
|||
dataDirPath string |
|||
// DoNotSetConfigName is used to predefine if the name of the config should be set.
|
|||
// Used for unit testing
|
|||
DoNotSetConfigName = false |
|||
config = Configuration{ |
|||
ListenAddr: ":8080", |
|||
BaseURL: "http://localhost:3000", |
|||
DataDir: "data", |
|||
EnableDebugMode: false, |
|||
ShortedIDLength: 4, |
|||
} |
|||
) |
|||
|
|||
// ReadInConfig loads the configuration and other needed folders for further usage
|
|||
// ReadInConfig loads the Configuration and other needed folders for further usage
|
|||
func ReadInConfig() error { |
|||
viper.AutomaticEnv() |
|||
viper.SetEnvPrefix("gus") |
|||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) |
|||
if !DoNotSetConfigName { |
|||
viper.SetConfigName("config") |
|||
} |
|||
viper.AddConfigPath(".") |
|||
setConfigDefaults() |
|||
switch err := viper.ReadInConfig(); err.(type) { |
|||
case viper.ConfigFileNotFoundError: |
|||
logrus.Info("No configuration file found, using defaults and environment overrides.") |
|||
break |
|||
case nil: |
|||
break |
|||
default: |
|||
file, err := ioutil.ReadFile("config.yaml") |
|||
if err == nil { |
|||
if err := yaml.Unmarshal(file, &config); err != nil { |
|||
return errors.Wrap(err, "could not unmarshal yaml file") |
|||
} |
|||
} else if !os.IsNotExist(err) { |
|||
return errors.Wrap(err, "could not read config file") |
|||
} |
|||
return checkForDatadir() |
|||
if err := config.ApplyEnvironmentConfig(); err != nil { |
|||
return errors.Wrap(err, "could not apply environment configuration") |
|||
} |
|||
|
|||
// setConfigDefaults sets the default values for the configuration
|
|||
func setConfigDefaults() { |
|||
viper.SetDefault("listen_addr", ":8080") |
|||
viper.SetDefault("base_url", "http://localhost:3000") |
|||
|
|||
viper.SetDefault("data_dir", "data") |
|||
viper.SetDefault("enable_debug_mode", true) |
|||
viper.SetDefault("shorted_id_length", 4) |
|||
config.DataDir, err = filepath.Abs(config.DataDir) |
|||
if err != nil { |
|||
return errors.Wrap(err, "could not get relative data dir path") |
|||
} |
|||
if _, err = os.Stat(config.DataDir); os.IsNotExist(err) { |
|||
if err = os.MkdirAll(config.DataDir, 0755); err != nil { |
|||
return errors.Wrap(err, "could not create config directory") |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// GetDataDir returns the absolute path of the data directory
|
|||
func GetDataDir() string { |
|||
return dataDirPath |
|||
func (c *Configuration) ApplyEnvironmentConfig() error { |
|||
return c.setDefaultValue(reflect.ValueOf(c), reflect.TypeOf(*c), reflect.StructField{}, "GUS") |
|||
} |
|||
|
|||
// checkForDatadir checks for the data dir and creates it if it not exists
|
|||
func checkForDatadir() error { |
|||
var err error |
|||
dataDirPath, err = filepath.Abs(viper.GetString("data_dir")) |
|||
func (c *Configuration) setDefaultValue(v reflect.Value, t reflect.Type, f reflect.StructField, prefix string) error { |
|||
if v.Kind() != reflect.Ptr { |
|||
return errors.New("Not a pointer value") |
|||
} |
|||
v = reflect.Indirect(v) |
|||
fieldEnv, exists := f.Tag.Lookup("env") |
|||
env := os.Getenv(prefix + fieldEnv) |
|||
if exists && env != "" { |
|||
switch v.Kind() { |
|||
case reflect.Int: |
|||
envI, err := strconv.Atoi(env) |
|||
if err != nil { |
|||
return errors.Wrap(err, "could not get relative data dir path") |
|||
logrus.Warningf("could not parse to int: %v", err) |
|||
break |
|||
} |
|||
v.SetInt(int64(envI)) |
|||
case reflect.String: |
|||
v.SetString(env) |
|||
case reflect.Bool: |
|||
envB, err := strconv.ParseBool(env) |
|||
if err != nil { |
|||
logrus.Warningf("could not parse to bool: %v", err) |
|||
break |
|||
} |
|||
v.SetBool(envB) |
|||
} |
|||
} |
|||
if v.Kind() == reflect.Struct { |
|||
// Iterate over the struct fields
|
|||
for i := 0; i < v.NumField(); i++ { |
|||
if err := c.setDefaultValue(v.Field(i).Addr(), t, t.Field(i), prefix+fieldEnv+"_"); err != nil { |
|||
return err |
|||
} |
|||
if _, err = os.Stat(dataDirPath); os.IsNotExist(err) { |
|||
if err = os.MkdirAll(dataDirPath, 0755); err != nil { |
|||
return errors.Wrap(err, "could not create config directory") |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func (o oAuthConf) Enabled() bool { |
|||
return o.ClientSecret != "" |
|||
} |
|||
|
|||
func GetConfig() Configuration { |
|||
return config |
|||
} |
|||
|
|||
func SetConfig(c Configuration) { |
|||
config = c |
|||
} |
|||
|
|||
Loading…
Reference in new issue