Refactor flag parsing

This commit is contained in:
Ken Hibino 2022-05-02 06:47:29 -07:00
parent 9796da746b
commit 2f9d2021c3

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"flag" "flag"
"fmt" "fmt"
@ -20,50 +21,78 @@ import (
"github.com/rs/cors" "github.com/rs/cors"
) )
// Command-line flags // Config holds configurations for the running program
var ( // provided by the user via command line.
flagPort int type Config struct {
flagRedisAddr string // Server port
flagRedisDB int Port int
flagRedisPassword string
flagRedisTLS string
flagRedisURL string
flagRedisInsecureTLS bool
flagRedisClusterNodes string
flagMaxPayloadLength int
flagMaxResultLength int
flagEnableMetricsExporter bool
flagPrometheusServerAddr string
flagReadOnly bool
)
func init() { // Redis connection options
flag.IntVar(&flagPort, "port", getEnvOrDefaultInt("PORT", 8080), "port number to use for web ui server") RedisAddr string
flag.StringVar(&flagRedisAddr, "redis-addr", getEnvDefaultString("REDIS_ADDR", "127.0.0.1:6379"), "address of redis server to connect to") RedisDB int
flag.IntVar(&flagRedisDB, "redis-db", getEnvOrDefaultInt("REDIS_DB", 0), "redis database number") RedisPassword string
flag.StringVar(&flagRedisPassword, "redis-password", getEnvDefaultString("REDIS_PASSWORD", ""), "password to use when connecting to redis server") RedisTLS string
flag.StringVar(&flagRedisTLS, "redis-tls", getEnvDefaultString("REDIS_TLS", ""), "server name for TLS validation used when connecting to redis server") RedisURL string
flag.StringVar(&flagRedisURL, "redis-url", getEnvDefaultString("REDIS_URL", ""), "URL to redis server") RedisInsecureTLS bool
flag.BoolVar(&flagRedisInsecureTLS, "redis-insecure-tls", getEnvOrDefaultBool("REDIS_INSECURE_TLS", false), "disable TLS certificate host checks") RedisClusterNodes string
flag.StringVar(&flagRedisClusterNodes, "redis-cluster-nodes", getEnvDefaultString("REDIS_CLUSTER_NODES", ""), "comma separated list of host:port addresses of cluster nodes")
flag.IntVar(&flagMaxPayloadLength, "max-payload-length", getEnvOrDefaultInt("MAX_PAYLOAD_LENGTH", 200), "maximum number of utf8 characters printed in the payload cell in the Web UI") // UI related configs
flag.IntVar(&flagMaxResultLength, "max-result-length", getEnvOrDefaultInt("MAX_RESULT_LENGTH", 200), "maximum number of utf8 characters printed in the result cell in the Web UI") ReadOnly bool
flag.BoolVar(&flagEnableMetricsExporter, "enable-metrics-exporter", getEnvOrDefaultBool("ENABLE_METRICS_EXPORTER", false), "enable prometheus metrics exporter to expose queue metrics") MaxPayloadLength int
flag.StringVar(&flagPrometheusServerAddr, "prometheus-addr", getEnvDefaultString("PROMETHEUS_ADDR", ""), "address of prometheus server to query time series") MaxResultLength int
flag.BoolVar(&flagReadOnly, "read-only", getEnvOrDefaultBool("READ_ONLY", false), "restrict to read-only mode")
// Prometheus related configs
EnableMetricsExporter bool
PrometheusServerAddr string
// Args are the positional (non-flag) command line arguments
Args []string
}
// parseFlags parses the command-line arguments provided to the program.
// Typically os.Args[0] is provided as 'progname' and os.args[1:] as 'args'.
// Returns the Config in case parsing succeeded, or an error. In any case, the
// output of the flag.Parse is returned in output.
//
// Reference: https://eli.thegreenplace.net/2020/testing-flag-parsing-in-go-programs/
func parseFlags(progname string, args []string) (cfg *Config, output string, err error) {
flags := flag.NewFlagSet(progname, flag.ContinueOnError)
var buf bytes.Buffer
flags.SetOutput(&buf)
var conf Config
flags.IntVar(&conf.Port, "port", getEnvOrDefaultInt("PORT", 8080), "port number to use for web ui server")
flags.StringVar(&conf.RedisAddr, "redis-addr", getEnvDefaultString("REDIS_ADDR", "127.0.0.1:6379"), "address of redis server to connect to")
flags.IntVar(&conf.RedisDB, "redis-db", getEnvOrDefaultInt("REDIS_DB", 0), "redis database number")
flags.StringVar(&conf.RedisPassword, "redis-password", getEnvDefaultString("REDIS_PASSWORD", ""), "password to use when connecting to redis server")
flags.StringVar(&conf.RedisTLS, "redis-tls", getEnvDefaultString("REDIS_TLS", ""), "server name for TLS validation used when connecting to redis server")
flags.StringVar(&conf.RedisURL, "redis-url", getEnvDefaultString("REDIS_URL", ""), "URL to redis server")
flags.BoolVar(&conf.RedisInsecureTLS, "redis-insecure-tls", getEnvOrDefaultBool("REDIS_INSECURE_TLS", false), "disable TLS certificate host checks")
flags.StringVar(&conf.RedisClusterNodes, "redis-cluster-nodes", getEnvDefaultString("REDIS_CLUSTER_NODES", ""), "comma separated list of host:port addresses of cluster nodes")
flags.IntVar(&conf.MaxPayloadLength, "max-payload-length", getEnvOrDefaultInt("MAX_PAYLOAD_LENGTH", 200), "maximum number of utf8 characters printed in the payload cell in the Web UI")
flags.IntVar(&conf.MaxResultLength, "max-result-length", getEnvOrDefaultInt("MAX_RESULT_LENGTH", 200), "maximum number of utf8 characters printed in the result cell in the Web UI")
flags.BoolVar(&conf.EnableMetricsExporter, "enable-metrics-exporter", getEnvOrDefaultBool("ENABLE_METRICS_EXPORTER", false), "enable prometheus metrics exporter to expose queue metrics")
flags.StringVar(&conf.PrometheusServerAddr, "prometheus-addr", getEnvDefaultString("PROMETHEUS_ADDR", ""), "address of prometheus server to query time series")
flags.BoolVar(&conf.ReadOnly, "read-only", getEnvOrDefaultBool("READ_ONLY", false), "restrict to read-only mode")
err = flags.Parse(args)
if err != nil {
return nil, buf.String(), err
}
conf.Args = flags.Args()
return &conf, buf.String(), nil
} }
// TODO: Write test and refactor this code. // TODO: Write test and refactor this code.
// IDEA: https://eli.thegreenplace.net/2020/testing-flag-parsing-in-go-programs/ func makeRedisConnOpt(cfg *Config) (asynq.RedisConnOpt, error) {
func getRedisOptionsFromFlags() (asynq.RedisConnOpt, error) {
var opts redis.UniversalOptions var opts redis.UniversalOptions
if flagRedisClusterNodes != "" { if cfg.RedisClusterNodes != "" {
opts.Addrs = strings.Split(flagRedisClusterNodes, ",") opts.Addrs = strings.Split(cfg.RedisClusterNodes, ",")
opts.Password = flagRedisPassword opts.Password = cfg.RedisPassword
} else { } else {
if flagRedisURL != "" { if cfg.RedisURL != "" {
res, err := redis.ParseURL(flagRedisURL) res, err := redis.ParseURL(cfg.RedisURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -72,23 +101,23 @@ func getRedisOptionsFromFlags() (asynq.RedisConnOpt, error) {
opts.Password = res.Password opts.Password = res.Password
} else { } else {
opts.Addrs = []string{flagRedisAddr} opts.Addrs = []string{cfg.RedisAddr}
opts.DB = flagRedisDB opts.DB = cfg.RedisDB
opts.Password = flagRedisPassword opts.Password = cfg.RedisPassword
} }
} }
if flagRedisTLS != "" { if cfg.RedisTLS != "" {
opts.TLSConfig = &tls.Config{ServerName: flagRedisTLS} opts.TLSConfig = &tls.Config{ServerName: cfg.RedisTLS}
} }
if flagRedisInsecureTLS { if cfg.RedisInsecureTLS {
if opts.TLSConfig == nil { if opts.TLSConfig == nil {
opts.TLSConfig = &tls.Config{} opts.TLSConfig = &tls.Config{}
} }
opts.TLSConfig.InsecureSkipVerify = true opts.TLSConfig.InsecureSkipVerify = true
} }
if flagRedisClusterNodes != "" { if cfg.RedisClusterNodes != "" {
return asynq.RedisClusterClientOpt{ return asynq.RedisClusterClientOpt{
Addrs: opts.Addrs, Addrs: opts.Addrs,
Password: opts.Password, Password: opts.Password,
@ -104,19 +133,27 @@ func getRedisOptionsFromFlags() (asynq.RedisConnOpt, error) {
} }
func main() { func main() {
flag.Parse() cfg, output, err := parseFlags(os.Args[0], os.Args[1:])
if err == flag.ErrHelp {
fmt.Println(output)
os.Exit(2)
} else if err != nil {
fmt.Printf("error: %v\n", err)
fmt.Println(output)
os.Exit(1)
}
redisConnOpt, err := getRedisOptionsFromFlags() redisConnOpt, err := makeRedisConnOpt(cfg)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
h := asynqmon.New(asynqmon.Options{ h := asynqmon.New(asynqmon.Options{
RedisConnOpt: redisConnOpt, RedisConnOpt: redisConnOpt,
PayloadFormatter: asynqmon.PayloadFormatterFunc(formatPayload), PayloadFormatter: asynqmon.PayloadFormatterFunc(payloadFormatterFunc(cfg)),
ResultFormatter: asynqmon.ResultFormatterFunc(formatResult), ResultFormatter: asynqmon.ResultFormatterFunc(resultFormatterFunc(cfg)),
PrometheusAddress: flagPrometheusServerAddr, PrometheusAddress: cfg.PrometheusServerAddr,
ReadOnly: flagReadOnly, ReadOnly: cfg.ReadOnly,
}) })
defer h.Close() defer h.Close()
@ -125,7 +162,7 @@ func main() {
}) })
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/", c.Handler(h)) mux.Handle("/", c.Handler(h))
if flagEnableMetricsExporter { if cfg.EnableMetricsExporter {
// Using NewPedanticRegistry here to test the implementation of Collectors and Metrics. // Using NewPedanticRegistry here to test the implementation of Collectors and Metrics.
reg := prometheus.NewPedanticRegistry() reg := prometheus.NewPedanticRegistry()
@ -142,23 +179,27 @@ func main() {
srv := &http.Server{ srv := &http.Server{
Handler: mux, Handler: mux,
Addr: fmt.Sprintf(":%d", flagPort), Addr: fmt.Sprintf(":%d", cfg.Port),
WriteTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second,
ReadTimeout: 10 * time.Second, ReadTimeout: 10 * time.Second,
} }
fmt.Printf("Asynq Monitoring WebUI server is listening on port %d\n", flagPort) fmt.Printf("Asynq Monitoring WebUI server is listening on port %d\n", cfg.Port)
log.Fatal(srv.ListenAndServe()) log.Fatal(srv.ListenAndServe())
} }
func formatPayload(taskType string, payload []byte) string { func payloadFormatterFunc(cfg *Config) func(string, []byte) string {
return func(taskType string, payload []byte) string {
payloadStr := asynqmon.DefaultPayloadFormatter.FormatPayload(taskType, payload) payloadStr := asynqmon.DefaultPayloadFormatter.FormatPayload(taskType, payload)
return truncate(payloadStr, flagMaxPayloadLength) return truncate(payloadStr, cfg.MaxPayloadLength)
}
} }
func formatResult(taskType string, result []byte) string { func resultFormatterFunc(cfg *Config) func(string, []byte) string {
return func(taskType string, result []byte) string {
resultStr := asynqmon.DefaultResultFormatter.FormatResult(taskType, result) resultStr := asynqmon.DefaultResultFormatter.FormatResult(taskType, result)
return truncate(resultStr, flagMaxResultLength) return truncate(resultStr, cfg.MaxResultLength)
}
} }
// truncates string s to limit length (in utf8). // truncates string s to limit length (in utf8).