From 2f9d2021c33258362c40dcca193622e0034b621a Mon Sep 17 00:00:00 2001 From: Ken Hibino Date: Mon, 2 May 2022 06:47:29 -0700 Subject: [PATCH] Refactor flag parsing --- cmd/asynqmon/main.go | 159 +++++++++++++++++++++++++++---------------- 1 file changed, 100 insertions(+), 59 deletions(-) diff --git a/cmd/asynqmon/main.go b/cmd/asynqmon/main.go index e759d18..d046fbc 100644 --- a/cmd/asynqmon/main.go +++ b/cmd/asynqmon/main.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "crypto/tls" "flag" "fmt" @@ -20,50 +21,78 @@ import ( "github.com/rs/cors" ) -// Command-line flags -var ( - flagPort int - flagRedisAddr string - flagRedisDB int - flagRedisPassword string - flagRedisTLS string - flagRedisURL string - flagRedisInsecureTLS bool - flagRedisClusterNodes string - flagMaxPayloadLength int - flagMaxResultLength int - flagEnableMetricsExporter bool - flagPrometheusServerAddr string - flagReadOnly bool -) +// Config holds configurations for the running program +// provided by the user via command line. +type Config struct { + // Server port + Port int -func init() { - flag.IntVar(&flagPort, "port", getEnvOrDefaultInt("PORT", 8080), "port number to use for web ui server") - flag.StringVar(&flagRedisAddr, "redis-addr", getEnvDefaultString("REDIS_ADDR", "127.0.0.1:6379"), "address of redis server to connect to") - flag.IntVar(&flagRedisDB, "redis-db", getEnvOrDefaultInt("REDIS_DB", 0), "redis database number") - flag.StringVar(&flagRedisPassword, "redis-password", getEnvDefaultString("REDIS_PASSWORD", ""), "password to use when connecting to redis server") - flag.StringVar(&flagRedisTLS, "redis-tls", getEnvDefaultString("REDIS_TLS", ""), "server name for TLS validation used when connecting to redis server") - flag.StringVar(&flagRedisURL, "redis-url", getEnvDefaultString("REDIS_URL", ""), "URL to redis server") - flag.BoolVar(&flagRedisInsecureTLS, "redis-insecure-tls", getEnvOrDefaultBool("REDIS_INSECURE_TLS", false), "disable TLS certificate host checks") - flag.StringVar(&flagRedisClusterNodes, "redis-cluster-nodes", getEnvDefaultString("REDIS_CLUSTER_NODES", ""), "comma separated list of host:port addresses of cluster nodes") - flag.IntVar(&flagMaxPayloadLength, "max-payload-length", getEnvOrDefaultInt("MAX_PAYLOAD_LENGTH", 200), "maximum number of utf8 characters printed in the payload cell in the Web UI") - flag.IntVar(&flagMaxResultLength, "max-result-length", getEnvOrDefaultInt("MAX_RESULT_LENGTH", 200), "maximum number of utf8 characters printed in the result cell in the Web UI") - flag.BoolVar(&flagEnableMetricsExporter, "enable-metrics-exporter", getEnvOrDefaultBool("ENABLE_METRICS_EXPORTER", false), "enable prometheus metrics exporter to expose queue metrics") - flag.StringVar(&flagPrometheusServerAddr, "prometheus-addr", getEnvDefaultString("PROMETHEUS_ADDR", ""), "address of prometheus server to query time series") - flag.BoolVar(&flagReadOnly, "read-only", getEnvOrDefaultBool("READ_ONLY", false), "restrict to read-only mode") + // Redis connection options + RedisAddr string + RedisDB int + RedisPassword string + RedisTLS string + RedisURL string + RedisInsecureTLS bool + RedisClusterNodes string + + // UI related configs + ReadOnly bool + MaxPayloadLength int + MaxResultLength int + + // Prometheus related configs + EnableMetricsExporter bool + PrometheusServerAddr string + + // Args are the positional (non-flag) command line arguments + Args []string +} + +// parseFlags parses the command-line arguments provided to the program. +// Typically os.Args[0] is provided as 'progname' and os.args[1:] as 'args'. +// Returns the Config in case parsing succeeded, or an error. In any case, the +// output of the flag.Parse is returned in output. +// +// Reference: https://eli.thegreenplace.net/2020/testing-flag-parsing-in-go-programs/ +func parseFlags(progname string, args []string) (cfg *Config, output string, err error) { + flags := flag.NewFlagSet(progname, flag.ContinueOnError) + var buf bytes.Buffer + flags.SetOutput(&buf) + + var conf Config + flags.IntVar(&conf.Port, "port", getEnvOrDefaultInt("PORT", 8080), "port number to use for web ui server") + flags.StringVar(&conf.RedisAddr, "redis-addr", getEnvDefaultString("REDIS_ADDR", "127.0.0.1:6379"), "address of redis server to connect to") + flags.IntVar(&conf.RedisDB, "redis-db", getEnvOrDefaultInt("REDIS_DB", 0), "redis database number") + flags.StringVar(&conf.RedisPassword, "redis-password", getEnvDefaultString("REDIS_PASSWORD", ""), "password to use when connecting to redis server") + flags.StringVar(&conf.RedisTLS, "redis-tls", getEnvDefaultString("REDIS_TLS", ""), "server name for TLS validation used when connecting to redis server") + flags.StringVar(&conf.RedisURL, "redis-url", getEnvDefaultString("REDIS_URL", ""), "URL to redis server") + flags.BoolVar(&conf.RedisInsecureTLS, "redis-insecure-tls", getEnvOrDefaultBool("REDIS_INSECURE_TLS", false), "disable TLS certificate host checks") + flags.StringVar(&conf.RedisClusterNodes, "redis-cluster-nodes", getEnvDefaultString("REDIS_CLUSTER_NODES", ""), "comma separated list of host:port addresses of cluster nodes") + flags.IntVar(&conf.MaxPayloadLength, "max-payload-length", getEnvOrDefaultInt("MAX_PAYLOAD_LENGTH", 200), "maximum number of utf8 characters printed in the payload cell in the Web UI") + flags.IntVar(&conf.MaxResultLength, "max-result-length", getEnvOrDefaultInt("MAX_RESULT_LENGTH", 200), "maximum number of utf8 characters printed in the result cell in the Web UI") + flags.BoolVar(&conf.EnableMetricsExporter, "enable-metrics-exporter", getEnvOrDefaultBool("ENABLE_METRICS_EXPORTER", false), "enable prometheus metrics exporter to expose queue metrics") + flags.StringVar(&conf.PrometheusServerAddr, "prometheus-addr", getEnvDefaultString("PROMETHEUS_ADDR", ""), "address of prometheus server to query time series") + flags.BoolVar(&conf.ReadOnly, "read-only", getEnvOrDefaultBool("READ_ONLY", false), "restrict to read-only mode") + + err = flags.Parse(args) + if err != nil { + return nil, buf.String(), err + } + conf.Args = flags.Args() + return &conf, buf.String(), nil } // TODO: Write test and refactor this code. -// IDEA: https://eli.thegreenplace.net/2020/testing-flag-parsing-in-go-programs/ -func getRedisOptionsFromFlags() (asynq.RedisConnOpt, error) { +func makeRedisConnOpt(cfg *Config) (asynq.RedisConnOpt, error) { var opts redis.UniversalOptions - if flagRedisClusterNodes != "" { - opts.Addrs = strings.Split(flagRedisClusterNodes, ",") - opts.Password = flagRedisPassword + if cfg.RedisClusterNodes != "" { + opts.Addrs = strings.Split(cfg.RedisClusterNodes, ",") + opts.Password = cfg.RedisPassword } else { - if flagRedisURL != "" { - res, err := redis.ParseURL(flagRedisURL) + if cfg.RedisURL != "" { + res, err := redis.ParseURL(cfg.RedisURL) if err != nil { return nil, err } @@ -72,23 +101,23 @@ func getRedisOptionsFromFlags() (asynq.RedisConnOpt, error) { opts.Password = res.Password } else { - opts.Addrs = []string{flagRedisAddr} - opts.DB = flagRedisDB - opts.Password = flagRedisPassword + opts.Addrs = []string{cfg.RedisAddr} + opts.DB = cfg.RedisDB + opts.Password = cfg.RedisPassword } } - if flagRedisTLS != "" { - opts.TLSConfig = &tls.Config{ServerName: flagRedisTLS} + if cfg.RedisTLS != "" { + opts.TLSConfig = &tls.Config{ServerName: cfg.RedisTLS} } - if flagRedisInsecureTLS { + if cfg.RedisInsecureTLS { if opts.TLSConfig == nil { opts.TLSConfig = &tls.Config{} } opts.TLSConfig.InsecureSkipVerify = true } - if flagRedisClusterNodes != "" { + if cfg.RedisClusterNodes != "" { return asynq.RedisClusterClientOpt{ Addrs: opts.Addrs, Password: opts.Password, @@ -104,19 +133,27 @@ func getRedisOptionsFromFlags() (asynq.RedisConnOpt, error) { } func main() { - flag.Parse() + cfg, output, err := parseFlags(os.Args[0], os.Args[1:]) + if err == flag.ErrHelp { + fmt.Println(output) + os.Exit(2) + } else if err != nil { + fmt.Printf("error: %v\n", err) + fmt.Println(output) + os.Exit(1) + } - redisConnOpt, err := getRedisOptionsFromFlags() + redisConnOpt, err := makeRedisConnOpt(cfg) if err != nil { log.Fatal(err) } h := asynqmon.New(asynqmon.Options{ RedisConnOpt: redisConnOpt, - PayloadFormatter: asynqmon.PayloadFormatterFunc(formatPayload), - ResultFormatter: asynqmon.ResultFormatterFunc(formatResult), - PrometheusAddress: flagPrometheusServerAddr, - ReadOnly: flagReadOnly, + PayloadFormatter: asynqmon.PayloadFormatterFunc(payloadFormatterFunc(cfg)), + ResultFormatter: asynqmon.ResultFormatterFunc(resultFormatterFunc(cfg)), + PrometheusAddress: cfg.PrometheusServerAddr, + ReadOnly: cfg.ReadOnly, }) defer h.Close() @@ -125,7 +162,7 @@ func main() { }) mux := http.NewServeMux() mux.Handle("/", c.Handler(h)) - if flagEnableMetricsExporter { + if cfg.EnableMetricsExporter { // Using NewPedanticRegistry here to test the implementation of Collectors and Metrics. reg := prometheus.NewPedanticRegistry() @@ -142,23 +179,27 @@ func main() { srv := &http.Server{ Handler: mux, - Addr: fmt.Sprintf(":%d", flagPort), + Addr: fmt.Sprintf(":%d", cfg.Port), WriteTimeout: 10 * time.Second, ReadTimeout: 10 * time.Second, } - fmt.Printf("Asynq Monitoring WebUI server is listening on port %d\n", flagPort) + fmt.Printf("Asynq Monitoring WebUI server is listening on port %d\n", cfg.Port) log.Fatal(srv.ListenAndServe()) } -func formatPayload(taskType string, payload []byte) string { - payloadStr := asynqmon.DefaultPayloadFormatter.FormatPayload(taskType, payload) - return truncate(payloadStr, flagMaxPayloadLength) +func payloadFormatterFunc(cfg *Config) func(string, []byte) string { + return func(taskType string, payload []byte) string { + payloadStr := asynqmon.DefaultPayloadFormatter.FormatPayload(taskType, payload) + return truncate(payloadStr, cfg.MaxPayloadLength) + } } -func formatResult(taskType string, result []byte) string { - resultStr := asynqmon.DefaultResultFormatter.FormatResult(taskType, result) - return truncate(resultStr, flagMaxResultLength) +func resultFormatterFunc(cfg *Config) func(string, []byte) string { + return func(taskType string, result []byte) string { + resultStr := asynqmon.DefaultResultFormatter.FormatResult(taskType, result) + return truncate(resultStr, cfg.MaxResultLength) + } } // truncates string s to limit length (in utf8).