package main

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"os/signal"
	"sort"
	"syscall"
	"time"

	"owo.codes/whats-this/api/lib/db"
	"owo.codes/whats-this/api/lib/middleware"
	"owo.codes/whats-this/api/lib/ratelimiter"
	"owo.codes/whats-this/api/lib/routes"

	"github.com/go-chi/chi"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	"github.com/spf13/pflag"
	"github.com/spf13/viper"
)

// Build configuration.
const (
	// configLocationUnix is the default path of the TOML configuration file.
	configLocationUnix = "/etc/whats-this/api/config.toml"
	// shutdownTimeout bounds how long a graceful HTTP shutdown may take before
	// the server is closed forcefully.
	shutdownTimeout = 10 * time.Second
	version         = "1.7.1"
)

// printConfiguration prints every value in config to stdout as one
// "key: value" line per leaf value. Keys are sorted alphabetically at each
// nesting level, and nested maps are flattened into dot notation by recursing
// with prefix extended by "<key>.". Call it with prefix "" for the root map.
func printConfiguration(prefix string, config map[string]interface{}) {
	keys := make([]string, 0, len(config))
	for k := range config {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	for _, k := range keys {
		if nested, ok := config[k].(map[string]interface{}); ok {
			printConfiguration(prefix+k+".", nested)
			continue
		}
		fmt.Printf("%s%s: %+v\n", prefix, k, config[k])
	}
}

// init does four things: it parses command-line flags, loads the TOML
// configuration file into viper, sets the global zerolog level and validates
// the required configuration values. It exits the process if any step fails.
// It also exits (status 0) after printing the configuration when
// --print-config is given.
func init() {
	// Flag configuration
	flags := pflag.NewFlagSet("whats-this-api", pflag.ExitOnError)
	flags.IntP("log-level", "l", 1, "Set zerolog logging level (5=panic, 4=fatal, 3=error, 2=warn, 1=info, 0=debug)")
	configFile := flags.StringP("config-file", "c", configLocationUnix, fmt.Sprintf("Path to configuration file, defaults to %s", configLocationUnix))
	printConfig := flags.BoolP("print-config", "p", false, "Prints configuration and exits")
	// Parse expects the argument list without the command name. With
	// pflag.ExitOnError it exits the process itself on a parse error.
	flags.Parse(os.Args[1:])

	// Configuration defaults
	viper.SetDefault("database.objectBucket", "public")
	viper.SetDefault("http.listenAddress", ":49544")
	// The flag's default (1, info) acts as the config default unless the
	// flag is given explicitly. BindPFlag only fails for a nil flag, and the
	// flag is defined above.
	viper.BindPFlag("log.level", flags.Lookup("log-level"))
	viper.SetDefault("pomf.maxFilesPerUpload", 3)
	viper.SetDefault("ratelimiter.enable", false)
	viper.SetDefault("ratelimiter.defaultBucketCapacity", 50)
	viper.SetDefault("ratelimiter.bucketExpiryDuration", time.Second*30)
	viper.SetDefault("ratelimiter.uploadPomfCost", ratelimiter.UploadPomfCost)
	viper.SetDefault("ratelimiter.shortenPolrCost", ratelimiter.ShortenPolrCost)
	viper.SetDefault("ratelimiter.meCost", ratelimiter.MeCost)
	viper.SetDefault("ratelimiter.listObjectsCost", ratelimiter.ListObjectsCost)
	viper.SetDefault("ratelimiter.objectCost", ratelimiter.ObjectCost)
	viper.SetDefault("ratelimiter.deleteObjectCost", ratelimiter.DeleteObjectCost)
	viper.SetDefault("fileWebhook.enable", false)

	// Load configuration file
	viper.SetConfigType("toml")
	file, err := os.Open(*configFile)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to open configuration file (%s) for reading: %s\n", *configFile, err)
		os.Exit(1)
	}
	err = viper.ReadConfig(file)
	// Close before checking err so the file is released on both paths.
	file.Close()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to parse configuration file (%s): %s\n", *configFile, err)
		os.Exit(1)
	}

	// Configure logger
	zerolog.TimeFieldFormat = ""
	if lvl := viper.GetInt("log.level"); 0 <= lvl && lvl <= 5 {
		zerolog.SetGlobalLevel(zerolog.Level(lvl))
	} else {
		viper.Set("log.level", 1)
		zerolog.SetGlobalLevel(zerolog.InfoLevel)
		log.Warn().Int("log.level", lvl).Msg("Invalid log level, defaulting to 1 (info)")
	}
	log.Debug().Uint8("level", uint8(zerolog.GlobalLevel())).Msg("Set logger level")

	// Print configuration variables in alphabetical order
	if *printConfig {
		log.Info().Msg("Printing configuration values to Stdout")
		printConfiguration("", viper.AllSettings())
		os.Exit(0)
	}

	validateConfig()
}

// validateConfig exits the process via log.Fatal at the first configuration
// value that is required but unset. Checks run in a fixed order, so the
// reported key is deterministic when several are missing.
func validateConfig() {
	requireString := func(key string) {
		if viper.GetString(key) == "" {
			log.Fatal().Msg("Configuration: " + key + " is required")
		}
	}

	requireString("database.connectionURL")
	requireString("database.objectBucket")
	requireString("http.listenAddress")
	if viper.GetInt64("http.maximumRequestSize") == 0 {
		log.Fatal().Msg("Configuration: http.maximumRequestSize is required")
	}
	requireString("polr.resultURL")
	requireString("files.quarantineLocation")
	requireString("files.storageLocation")
	requireString("files.tempLocation")

	if viper.GetBool("ratelimiter.enable") && viper.GetString("ratelimiter.redisURL") == "" {
		log.Fatal().Msg("Configuration: ratelimiter.redisURL is required when ratelimiter is enabled")
	}
	if viper.GetBool("fileWebhook.enable") && viper.GetString("fileWebhook.url") == "" {
		log.Fatal().Msg("Configuration: fileWebhook.url is required when fileWebhook is enabled")
	}
}

// newRouter builds the chi router: the global middleware stack, every API
// route and a plain-text 405 handler.
func newRouter() *chi.Mux {
	r := chi.NewRouter()

	// Mount middleware (order matters: each wraps the ones after it)
	r.Use(middleware.Recoverer) // TODO: include request ID in error logs
	r.Use(middleware.RequestID)
	r.Use(middleware.CORSHeaders([]string{"*"}))
	r.Use(middleware.StatusEndpoint("/health"))
	r.Use(middleware.Authenticator)

	// Route handlers
	r.Get("/shorten/polr", routes.ShortenPolr(false))
	r.Get("/shorten/polr/associated", routes.ShortenPolr(true))
	r.Post("/upload/pomf", routes.UploadPomf(false, false))
	r.Post("/upload/pomf/associated", routes.UploadPomf(true, false))
	r.Post("/upload/simple", routes.UploadPomf(false, true))
	r.Post("/upload/simple/associated", routes.UploadPomf(true, true))
	r.Post("/users", routes.CreateUser)
	r.Get("/users/me", routes.Me)
	r.Get("/objects", routes.ListObjects)
	r.Get("/objects/*", routes.Object)
	r.Delete("/objects/*", routes.DeleteObject)
	r.Post("/bannedfiles", routes.BanFile)
	r.Get("/bannedfiles", routes.ListBannedFiles)
	r.Get("/bannedfiles/*", routes.BannedFile)
	r.Delete("/bannedfiles/*", routes.DeleteBannedFile)

	// MethodNotAllowed handler
	r.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		io.WriteString(w, "405 method not allowed")
	}))

	return r
}

// handleShutdown blocks until a signal arrives on sigs, then:
//   - cancels the base context via cancelBase;
//   - gracefully shuts server down, allowing up to shutdownTimeout for
//     in-flight requests before closing it forcefully;
//   - sends the desired process exit status (0 on a clean shutdown, 1
//     otherwise) on exitCode.
//
// A second signal on sigs during shutdown exits the process immediately with
// status 1.
func handleShutdown(server *http.Server, cancelBase context.CancelFunc, sigs <-chan os.Signal, exitCode chan<- int) {
	sig := <-sigs
	go func() {
		<-sigs
		os.Exit(1)
	}()

	log.Info().Str("cause", sig.String()).Dur("timeout", shutdownTimeout).Msg("Shutting down API worker")
	cancelBase()

	// Shutdown HTTP server
	ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
	defer cancel()
	log.Info().Dur("timeout", shutdownTimeout).Msg("Shutting down HTTP server")
	if err := server.Shutdown(ctx); err != nil {
		log.Warn().Err(err).Msg("Failed to shutdown HTTP server gracefully, closing forcefully")
		server.Close()
		log.Info().Msg("Finished shutting down with errors")
		exitCode <- 1
		return
	}
	log.Info().Msg("Successfully shutdown HTTP server gracefully")
	log.Info().Msg("Finished shutting down")
	exitCode <- 0
}

func main() {
	baseCtx, baseCtxCancel := context.WithCancel(context.Background())
	defer baseCtxCancel()

	// Connect to database
	if err := db.Connect("postgres", viper.GetString("database.connectionURL")); err != nil {
		log.Fatal().Err(err).Msg("failed to connect to and ping the database")
	}

	// Connect to Redis for ratelimiter
	if viper.GetBool("ratelimiter.enable") {
		if err := ratelimiter.RedisConnect(viper.GetString("ratelimiter.redisURL")); err != nil {
			log.Fatal().Err(err).Msg("failed to connect to and ping Redis for ratelimiting")
		}
	}

	r := newRouter()

	// Create HTTP server on specified listening address
	listenAddress := viper.GetString("http.listenAddress")
	server := &http.Server{
		Addr:    listenAddress,
		Handler: chi.ServerBaseContext(baseCtx, r),
	}

	// Listen for interrupts (^C) and termination requests and exit
	// gracefully. Register before starting the server so no early signal
	// falls through to the default (immediate termination) behavior.
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
	exitCode := make(chan int, 1)
	go handleShutdown(server, baseCtxCancel, sigs, exitCode)

	// Start HTTP server
	log.Info().Str("listenAddress", listenAddress).Msg("Starting HTTP server")
	if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		log.Fatal().Err(err).Msg("Failed to start HTTP server")
	}

	// ListenAndServe returns ErrServerClosed as soon as Shutdown starts; wait
	// for the shutdown to finish draining requests before exiting. main does
	// not read sigs here, so a second signal always reaches the force-exit
	// goroutine.
	os.Exit(<-exitCode)
}