Skip to content
Snippets Groups Projects
main.go 7.28 KiB
Newer Older
  • Learn to ignore specific revisions
  • Dean's avatar
    Dean committed
    package main
    
    import (
    	"context"
    	"fmt"
    	"io"
    	"net/http"
    	"os"
    	"os/signal"
    	"sort"
    	"time"
    
    	"owo.codes/whats-this/api/lib/db"
    	"owo.codes/whats-this/api/lib/middleware"
    
    Dean's avatar
    Dean committed
    	"owo.codes/whats-this/api/lib/ratelimiter"
    
    Dean's avatar
    Dean committed
    	"owo.codes/whats-this/api/lib/routes"
    
    	"github.com/go-chi/chi"
    	"github.com/rs/zerolog"
    	"github.com/rs/zerolog/log"
    	"github.com/spf13/pflag"
    	"github.com/spf13/viper"
    )
    
    // Build-time configuration for the API binary.
    const (
    	// configLocationUnix is the configuration file path used when the
    	// --config-file/-c flag is not given.
    	configLocationUnix = "/etc/whats-this/api/config.toml"

    	// shutdownTimeout bounds how long the HTTP server is given to shut down
    	// gracefully after an interrupt before it is closed forcefully.
    	shutdownTimeout time.Duration = 10 * time.Second

    	// version is the release version string of this build.
    	version = "1.6.5"
    )
    
    // printConfiguration writes every leaf value of a (possibly nested) settings
    // map to standard output as one "key: value" line. Keys are visited in
    // sort.Strings order at each nesting level. A value that is itself a
    // map[string]interface{} is descended into instead of printed, with its keys
    // joined to the parent's using dots, so {"a": {"b": 1}} prints "a.b: 1".
    // prefix is prepended verbatim to every key and is "" at the top level.
    func printConfiguration(prefix string, config map[string]interface{}) {
    	names := make([]string, 0, len(config))
    	for name := range config {
    		names = append(names, name)
    	}
    	sort.Strings(names)

    	for _, name := range names {
    		value := config[name]
    		nested, isMap := value.(map[string]interface{})
    		if !isMap {
    			fmt.Printf("%s%s: %+v\n", prefix, name, value)
    			continue
    		}
    		printConfiguration(prefix+name+".", nested)
    	}
    }
    
    func init() {
    	// Flag configuration
    	flags := pflag.NewFlagSet("whats-this-api", pflag.ExitOnError)
    	flags.IntP("log-level", "l", 1, "Set zerolog logging level (5=panic, 4=fatal, 3=error, 2=warn, 1=info, 0=debug)")
    	configFile := flags.StringP("config-file", "c", configLocationUnix,
    		fmt.Sprintf("Path to configuration file, defaults to %s", configLocationUnix))
    	printConfig := flags.BoolP("print-config", "p", false, "Prints configuration and exits")
    	flags.Parse(os.Args)
    
    	// Configuration defaults
    	viper.SetDefault("database.objectBucket", "public")
    	viper.SetDefault("http.listenAddress", ":49544")
    	viper.BindPFlag("log.level", flags.Lookup("log-level")) // default is 1 (info)
    
    Dean's avatar
    Dean committed
    	viper.SetDefault("ratelimiter.enable", false)
    	viper.SetDefault("ratelimiter.defaultBucketCapacity", 50)
    	viper.SetDefault("ratelimiter.bucketExpiryDuration", time.Second*30)
    	viper.SetDefault("ratelimiter.uploadPomfCost", ratelimiter.UploadPomfCost)
    	viper.SetDefault("ratelimiter.shortenPolrCost", ratelimiter.ShortenPolrCost)
    	viper.SetDefault("ratelimiter.meCost", ratelimiter.MeCost)
    	viper.SetDefault("ratelimiter.listObjectsCost", ratelimiter.ListObjectsCost)
    	viper.SetDefault("ratelimiter.objectCost", ratelimiter.ObjectCost)
    	viper.SetDefault("ratelimiter.deleteObjectCost", ratelimiter.DeleteObjectCost)
    
    Dean's avatar
    Dean committed
    
    	// Load configuration file
    	viper.SetConfigType("toml")
    	file, err := os.Open(*configFile)
    	if err != nil {
    		fmt.Fprintf(os.Stderr, "Failed to open configuration file (%s) for reading: %s", *configFile, err.Error())
    		os.Exit(1)
    		return
    	}
    	err = viper.ReadConfig(file)
    	if err != nil {
    		fmt.Fprintf(os.Stderr, "Failed to parse configuration file (%s): %s", *configFile, err.Error())
    		os.Exit(1)
    		return
    	}
    	file.Close()
    
    	// Configure logger
    	zerolog.TimeFieldFormat = ""
    	if lvl := viper.GetInt("log.level"); 0 <= lvl && lvl <= 5 {
    		zerolog.SetGlobalLevel(zerolog.Level(lvl))
    	} else {
    		viper.Set("log.level", 1)
    		zerolog.SetGlobalLevel(zerolog.InfoLevel)
    		log.Warn().Int("log.level", lvl).Msg("Invalid log level, defaulting to 1 (info)")
    	}
    	log.Debug().Uint8("level", uint8(zerolog.GlobalLevel())).Msg("Set logger level")
    
    	// Print configuration variables in alphabetical order
    	if *printConfig {
    		log.Info().Msg("Printing configuration values to Stdout")
    		settings := viper.AllSettings()
    		printConfiguration("", settings)
    		os.Exit(0)
    		return
    	}
    
    	// Ensure required configuration variables are set
    	if viper.GetString("database.connectionURL") == "" {
    		log.Fatal().Msg("Configuration: database.connectionURL is required")
    	}
    	if viper.GetString("database.objectBucket") == "" {
    		log.Fatal().Msg("Configuration: database.objectBucket is required")
    	}
    	if viper.GetString("http.listenAddress") == "" {
    		log.Fatal().Msg("Configuration: http.listenAddress is required")
    	}
    	if viper.GetInt64("http.maximumRequestSize") == 0 {
    		log.Fatal().Msg("Configuration: http.maximumRequestSize is required")
    	}
    	if viper.GetString("polr.resultURL") == "" {
    		log.Fatal().Msg("Configuration: polr.resultURL is required")
    	}
    	if viper.GetString("pomf.storageLocation") == "" {
    		log.Fatal().Msg("Configuration: pomf.storageLocation is required")
    	}
    
    Dean's avatar
    Dean committed
    	if viper.GetString("pomf.tempLocation") == "" {
    		log.Fatal().Msg("Configuration: pomf.tempLocation is required")
    	}
    	if viper.GetBool("ratelimiter.enable") && viper.GetString("ratelimiter.redisURL") == "" {
    		log.Fatal().Msg("Configuration: ratelimiter.redisURL is required when ratelimiter is enabled")
    	}
    
    Dean's avatar
    Dean committed
    }
    
    func main() {
    
    	baseCtx, baseCtxCancel := context.WithCancel(context.Background())
    	defer baseCtxCancel()
    
    Dean's avatar
    Dean committed
    
    	// Connect to database
    	err := db.Connect("postgres", viper.GetString("database.connectionURL"))
    	if err != nil {
    		log.Fatal().Err(err).Msg("failed to connect to and ping the database")
    	}
    
    
    Dean's avatar
    Dean committed
    	// Connect to Redis for ratelimiter
    	if viper.GetBool("ratelimiter.enable") {
    		err := ratelimiter.RedisConnect(viper.GetString("ratelimiter.redisURL"))
    		if err != nil {
    			log.Fatal().Err(err).Msg("failed to connect to and ping Redis for ratelimiting")
    		}
    	}
    
    
    Dean's avatar
    Dean committed
    	// Mount middleware
    	r := chi.NewRouter()
    	r.Use(middleware.Recoverer)
    
    	// TODO: include request ID in error logs
    
    Dean's avatar
    Dean committed
    	r.Use(middleware.RequestID)
    	r.Use(middleware.CORSHeaders([]string{"*"}))
    	r.Use(middleware.StatusEndpoint("/health"))
    	r.Use(middleware.Authenticator)
    
    	// Route handlers
    	r.Get("/shorten/polr", routes.ShortenPolr(false))
    	r.Get("/shorten/polr/associated", routes.ShortenPolr(true))
    	r.Post("/upload/pomf", routes.UploadPomf(false))
    	r.Post("/upload/pomf/associated", routes.UploadPomf(true))
    	r.Post("/users", routes.CreateUser)
    	r.Get("/users/me", routes.Me)
    
    	r.Get("/objects", routes.ListObjects)
    	r.Get("/objects/*", routes.Object)
    	r.Delete("/objects/*", routes.DeleteObject)
    
    Dean's avatar
    Dean committed
    
    	// MethodNotAllowed handler
    	r.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
    		w.WriteHeader(http.StatusMethodNotAllowed)
    		io.WriteString(w, "405 method not allowed")
    	}))
    
    	// Create HTTP server on specified listening address
    	listenAddress := viper.GetString("http.listenAddress")
    	server := http.Server{
    		Addr:    listenAddress,
    		Handler: chi.ServerBaseContext(baseCtx, r),
    	}
    
    	// Listen for interrupts (^C) and exit gracefully
    	c := make(chan os.Signal, 1)
    	signal.Notify(c, os.Interrupt)
    	go func() {
    		<-c
    		go func() {
    
    			<-c
    			os.Exit(1)
    
    Dean's avatar
    Dean committed
    		}()
    
    		log.Info().Str("cause", "interrupt").Dur("timeout", shutdownTimeout).Msg("Shutting down API worker")
    		baseCtxCancel()
    
    		// Shutdown HTTP server
    		ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
    		defer cancel()
    		log.Info().Dur("timeout", shutdownTimeout).Msg("Shutting down HTTP server")
    		err := server.Shutdown(ctx)
    		if err != nil {
    			log.Warn().Err(err).Msg("Failed to shutdown HTTP server gracefully, closing forecefully")
    			server.Close()
    
    Dean's avatar
    Dean committed
    			log.Info().Msg("Finished shutting down with errors")
    			os.Exit(1)
    
    		} else {
    			log.Info().Msg("Successfully shutdown HTTP server gracefully")
    
    Dean's avatar
    Dean committed
    			log.Info().Msg("Finished shutting down")
    			os.Exit(0)
    		}
    	}()
    
    	// Start HTTP server
    	log.Info().Str("listenAddress", listenAddress).Msg("Starting HTTP server")
    	err = server.ListenAndServe()
    	if err != nil && err != http.ErrServerClosed {
    		log.Fatal().Err(err).Msg("Failed to start HTTP server")
    	}
    	<-c
    }