feature: Basic auth for web console

This commit is contained in:
zekiahmetbayar 2023-10-26 14:10:31 +03:00
parent d1b889456d
commit 7f20425625

View File

@ -2,6 +2,8 @@ package main
import ( import (
"bytes" "bytes"
"crypto/sha256"
"crypto/subtle"
"crypto/tls" "crypto/tls"
"flag" "flag"
"fmt" "fmt"
@ -25,6 +27,11 @@ type Config struct {
// Server port // Server port
Port int Port int
// Basic auth options
EnableBasicAuth bool
BasicAuthUsername string
BasicAuthPassword string
// Redis connection options // Redis connection options
RedisAddr string RedisAddr string
RedisDB int RedisDB int
@ -72,6 +79,9 @@ func parseFlags(progname string, args []string) (cfg *Config, output string, err
flags.BoolVar(&conf.EnableMetricsExporter, "enable-metrics-exporter", getEnvOrDefaultBool("ENABLE_METRICS_EXPORTER", false), "enable prometheus metrics exporter to expose queue metrics") flags.BoolVar(&conf.EnableMetricsExporter, "enable-metrics-exporter", getEnvOrDefaultBool("ENABLE_METRICS_EXPORTER", false), "enable prometheus metrics exporter to expose queue metrics")
flags.StringVar(&conf.PrometheusServerAddr, "prometheus-addr", getEnvDefaultString("PROMETHEUS_ADDR", ""), "address of prometheus server to query time series") flags.StringVar(&conf.PrometheusServerAddr, "prometheus-addr", getEnvDefaultString("PROMETHEUS_ADDR", ""), "address of prometheus server to query time series")
flags.BoolVar(&conf.ReadOnly, "read-only", getEnvOrDefaultBool("READ_ONLY", false), "restrict to read-only mode") flags.BoolVar(&conf.ReadOnly, "read-only", getEnvOrDefaultBool("READ_ONLY", false), "restrict to read-only mode")
flags.BoolVar(&conf.EnableBasicAuth, "enable-basic-auth", getEnvOrDefaultBool("ENABLE_BASIC_AUTH", false), "enable basic auth for web console")
flags.StringVar(&conf.BasicAuthUsername, "basic-auth-username", getEnvDefaultString("BASIC_AUTH_USERNAME", "administrator"), "username for web console's basic auth")
flags.StringVar(&conf.BasicAuthPassword, "basic-auth-password", getEnvDefaultString("BASIC_AUTH_PASSWORD", "Passw0rd!!!"), "password for web console's basic auth")
err = flags.Parse(args) err = flags.Parse(args)
if err != nil { if err != nil {
@ -159,8 +169,18 @@ func main() {
c := cors.New(cors.Options{ c := cors.New(cors.Options{
AllowedMethods: []string{"GET", "POST", "DELETE"}, AllowedMethods: []string{"GET", "POST", "DELETE"},
}) })
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/", c.Handler(h)) // Is basic auth enabled for web console
switch cfg.EnableBasicAuth {
case true:
// Serve with basic auth
mux.Handle("/", basicAuth(c.Handler(h), cfg.BasicAuthUsername, cfg.BasicAuthPassword))
default:
// Serve without basic auth
mux.Handle("/", c.Handler(h))
}
if cfg.EnableMetricsExporter { if cfg.EnableMetricsExporter {
// Using NewPedanticRegistry here to test the implementation of Collectors and Metrics. // Using NewPedanticRegistry here to test the implementation of Collectors and Metrics.
reg := prometheus.NewPedanticRegistry() reg := prometheus.NewPedanticRegistry()
@ -187,6 +207,36 @@ func main() {
log.Fatal(srv.ListenAndServe()) log.Fatal(srv.ListenAndServe())
} }
// Basic auth
func basicAuth(next http.Handler, expectedUsername, expectedPassword string) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get username and password from user
username, password, ok := r.BasicAuth()
if ok {
// Get hash of username and password input
usernameHash := sha256.Sum256([]byte(username))
passwordHash := sha256.Sum256([]byte(password))
// Get hash of expected username and password
expectedUsernameHash := sha256.Sum256([]byte(expectedUsername))
expectedPasswordHash := sha256.Sum256([]byte(expectedPassword))
// Check username and passwords are match
usernameMatch := (subtle.ConstantTimeCompare(usernameHash[:], expectedUsernameHash[:]) == 1)
passwordMatch := (subtle.ConstantTimeCompare(passwordHash[:], expectedPasswordHash[:]) == 1)
if usernameMatch && passwordMatch {
// Serve
next.ServeHTTP(w, r)
return
}
}
// Set error, unauthorized
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
})
}
func payloadFormatterFunc(cfg *Config) func(string, []byte) string { func payloadFormatterFunc(cfg *Config) func(string, []byte) string {
return func(taskType string, payload []byte) string { return func(taskType string, payload []byte) string {
payloadStr := asynqmon.DefaultPayloadFormatter.FormatPayload(taskType, payload) payloadStr := asynqmon.DefaultPayloadFormatter.FormatPayload(taskType, payload)