package main

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"os/signal"
	"sort"
	"time"

	"owo.codes/whats-this/api/lib/db"
	"owo.codes/whats-this/api/lib/middleware"
	"owo.codes/whats-this/api/lib/ratelimiter"
	"owo.codes/whats-this/api/lib/routes"

	"github.com/go-chi/chi"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	"github.com/spf13/pflag"
	"github.com/spf13/viper"
)

// Build configuration.
//
// version is the version string for this build (it is not referenced
// elsewhere in this file). configLocationUnix is the default value of the
// --config-file flag. shutdownTimeout bounds how long the graceful HTTP
// shutdown may run before the server is closed forcefully.
const (
	version            = "1.7.1"
	configLocationUnix = "/etc/whats-this/api/config.toml"
	shutdownTimeout    = 10 * time.Second
)

// printConfiguration writes every leaf value of config to standard output as
// one "key: value" line, visiting keys in alphabetical order at each nesting
// level. Values that are themselves map[string]interface{} are descended into
// recursively, and their keys are joined to the parent with dots (for example
// "http.listenAddress"). prefix is prepended verbatim to every key and should
// be "" for the top-level call.
func printConfiguration(prefix string, config map[string]interface{}) {
	names := make([]string, 0, len(config))
	for name := range config {
		names = append(names, name)
	}
	sort.Strings(names)

	for _, name := range names {
		value := config[name]
		nested, isMap := value.(map[string]interface{})
		if !isMap {
			fmt.Printf("%s%s: %+v\n", prefix, name, value)
			continue
		}
		printConfiguration(prefix+name+".", nested)
	}
}

// init parses command-line flags, registers configuration defaults, loads the
// TOML configuration file into viper, configures the global zerolog logger and
// checks that every required configuration key is set.
//
// It terminates the process when the configuration file cannot be opened or
// parsed (exit status 1), when a required key is missing (via log.Fatal), and
// after printing the configuration when --print-config is given (exit status 0).
func init() {
	// Flag configuration
	flags := pflag.NewFlagSet("whats-this-api", pflag.ExitOnError)
	flags.IntP("log-level", "l", 1, "Set zerolog logging level (5=panic, 4=fatal, 3=error, 2=warn, 1=info, 0=debug)")
	configFile := flags.StringP("config-file", "c", configLocationUnix,
		fmt.Sprintf("Path to configuration file, defaults to %s", configLocationUnix))
	printConfig := flags.BoolP("print-config", "p", false, "Prints configuration and exits")
	// pflag expects the arguments without the command name, so skip
	// os.Args[0]. With ExitOnError, Parse exits the process itself on a bad
	// flag, so its error return carries no information here.
	_ = flags.Parse(os.Args[1:])

	// Configuration defaults
	viper.SetDefault("database.objectBucket", "public")
	viper.SetDefault("http.listenAddress", ":49544")
	viper.BindPFlag("log.level", flags.Lookup("log-level")) // default is 1 (info)
	viper.SetDefault("pomf.maxFilesPerUpload", 3)
	viper.SetDefault("ratelimiter.enable", false)
	viper.SetDefault("ratelimiter.defaultBucketCapacity", 50)
	viper.SetDefault("ratelimiter.bucketExpiryDuration", time.Second*30)
	viper.SetDefault("ratelimiter.uploadPomfCost", ratelimiter.UploadPomfCost)
	viper.SetDefault("ratelimiter.shortenPolrCost", ratelimiter.ShortenPolrCost)
	viper.SetDefault("ratelimiter.meCost", ratelimiter.MeCost)
	viper.SetDefault("ratelimiter.listObjectsCost", ratelimiter.ListObjectsCost)
	viper.SetDefault("ratelimiter.objectCost", ratelimiter.ObjectCost)
	viper.SetDefault("ratelimiter.deleteObjectCost", ratelimiter.DeleteObjectCost)
	viper.SetDefault("fileWebhook.enable", false)

	// Load configuration file
	viper.SetConfigType("toml")
	file, err := os.Open(*configFile)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to open configuration file (%s) for reading: %s\n", *configFile, err.Error())
		os.Exit(1)
	}
	err = viper.ReadConfig(file)
	// Close before inspecting err: os.Exit below would skip a deferred Close,
	// and the file is not needed after ReadConfig either way.
	file.Close()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to parse configuration file (%s): %s\n", *configFile, err.Error())
		os.Exit(1)
	}

	// Configure logger
	// NOTE(review): an empty TimeFieldFormat presumably selects Unix-epoch
	// timestamps (zerolog.TimeFormatUnix is "") — confirm for the pinned
	// zerolog version.
	zerolog.TimeFieldFormat = ""
	// Only levels 0 (debug) through 5 (panic), as listed in the --log-level
	// help text, are accepted; anything else falls back to 1 (info).
	if lvl := viper.GetInt("log.level"); 0 <= lvl && lvl <= 5 {
		zerolog.SetGlobalLevel(zerolog.Level(lvl))
	} else {
		viper.Set("log.level", 1)
		zerolog.SetGlobalLevel(zerolog.InfoLevel)
		log.Warn().Int("log.level", lvl).Msg("Invalid log level, defaulting to 1 (info)")
	}
	log.Debug().Uint8("level", uint8(zerolog.GlobalLevel())).Msg("Set logger level")

	// Print configuration variables in alphabetical order
	if *printConfig {
		log.Info().Msg("Printing configuration values to Stdout")
		settings := viper.AllSettings()
		printConfiguration("", settings)
		os.Exit(0)
	}

	// Ensure required configuration variables are set. log.Fatal exits, so
	// only the first missing key is reported.
	if viper.GetString("database.connectionURL") == "" {
		log.Fatal().Msg("Configuration: database.connectionURL is required")
	}
	if viper.GetString("database.objectBucket") == "" {
		log.Fatal().Msg("Configuration: database.objectBucket is required")
	}
	if viper.GetString("http.listenAddress") == "" {
		log.Fatal().Msg("Configuration: http.listenAddress is required")
	}
	if viper.GetInt64("http.maximumRequestSize") == 0 {
		log.Fatal().Msg("Configuration: http.maximumRequestSize is required")
	}
	if viper.GetString("polr.resultURL") == "" {
		log.Fatal().Msg("Configuration: polr.resultURL is required")
	}
	if viper.GetString("files.quarantineLocation") == "" {
		log.Fatal().Msg("Configuration: files.quarantineLocation is required")
	}
	if viper.GetString("files.storageLocation") == "" {
		log.Fatal().Msg("Configuration: files.storageLocation is required")
	}
	if viper.GetString("files.tempLocation") == "" {
		log.Fatal().Msg("Configuration: files.tempLocation is required")
	}
	if viper.GetBool("ratelimiter.enable") && viper.GetString("ratelimiter.redisURL") == "" {
		log.Fatal().Msg("Configuration: ratelimiter.redisURL is required when ratelimiter is enabled")
	}
	if viper.GetBool("fileWebhook.enable") && viper.GetString("fileWebhook.url") == "" {
		log.Fatal().Msg("Configuration: fileWebhook.url is required when fileWebhook is enabled")
	}
}

// main connects to the database (and to Redis when the ratelimiter is
// enabled), mounts middleware and route handlers on a chi router, and serves
// HTTP on http.listenAddress. The first interrupt (^C) starts a graceful
// shutdown bounded by shutdownTimeout. The process then exits with status 0
// on success or 1 if the server had to be closed forcefully. A second
// interrupt during shutdown exits immediately with status 1.
func main() {
	baseCtx, baseCtxCancel := context.WithCancel(context.Background())
	defer baseCtxCancel()

	// Connect to database
	err := db.Connect("postgres", viper.GetString("database.connectionURL"))
	if err != nil {
		log.Fatal().Err(err).Msg("failed to connect to and ping the database")
	}

	// Connect to Redis for ratelimiter
	if viper.GetBool("ratelimiter.enable") {
		if err := ratelimiter.RedisConnect(viper.GetString("ratelimiter.redisURL")); err != nil {
			log.Fatal().Err(err).Msg("failed to connect to and ping Redis for ratelimiting")
		}
	}

	// Mount middleware
	r := chi.NewRouter()
	r.Use(middleware.Recoverer)
	// TODO: include request ID in error logs
	r.Use(middleware.RequestID)
	r.Use(middleware.CORSHeaders([]string{"*"}))
	r.Use(middleware.StatusEndpoint("/health"))
	r.Use(middleware.Authenticator)

	// Route handlers
	r.Get("/shorten/polr", routes.ShortenPolr(false))
	r.Get("/shorten/polr/associated", routes.ShortenPolr(true))
	r.Post("/upload/pomf", routes.UploadPomf(false, false))
	r.Post("/upload/pomf/associated", routes.UploadPomf(true, false))
	r.Post("/upload/simple", routes.UploadPomf(false, true))
	r.Post("/upload/simple/associated", routes.UploadPomf(true, true))
	r.Post("/users", routes.CreateUser)
	r.Get("/users/me", routes.Me)
	r.Get("/objects", routes.ListObjects)
	r.Get("/objects/*", routes.Object)
	r.Delete("/objects/*", routes.DeleteObject)
	r.Post("/bannedfiles", routes.BanFile)
	r.Get("/bannedfiles", routes.ListBannedFiles)
	r.Get("/bannedfiles/*", routes.BannedFile)
	r.Delete("/bannedfiles/*", routes.DeleteBannedFile)

	// MethodNotAllowed handler
	r.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		io.WriteString(w, "405 method not allowed")
	}))

	// Create HTTP server on specified listening address
	listenAddress := viper.GetString("http.listenAddress")
	server := http.Server{
		Addr:    listenAddress,
		Handler: chi.ServerBaseContext(baseCtx, r),
	}

	// Listen for interrupts (^C) and exit gracefully.
	// NOTE(review): only os.Interrupt is handled. SIGTERM (what most process
	// supervisors send) still terminates the process without a graceful
	// shutdown; consider also registering syscall.SIGTERM.
	interrupts := make(chan os.Signal, 1)
	signal.Notify(interrupts, os.Interrupt)
	// exitCode receives the process exit status once shutdown has finished.
	// It is buffered so the shutdown goroutine never blocks on the send.
	exitCode := make(chan int, 1)
	go func() {
		<-interrupts
		// A second interrupt while shutting down forces an immediate exit.
		go func() {
			<-interrupts
			os.Exit(1)
		}()
		log.Info().Str("cause", "interrupt").Dur("timeout", shutdownTimeout).Msg("Shutting down API worker")
		// NOTE(review): presumably this also cancels the contexts of requests
		// still in flight, since chi.ServerBaseContext derives them from
		// baseCtx — confirm this is intended before the graceful Shutdown below.
		baseCtxCancel()

		// Shutdown HTTP server
		ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
		defer cancel()
		log.Info().Dur("timeout", shutdownTimeout).Msg("Shutting down HTTP server")
		if err := server.Shutdown(ctx); err != nil {
			log.Warn().Err(err).Msg("Failed to shutdown HTTP server gracefully, closing forcefully")
			server.Close()
			log.Info().Msg("Finished shutting down with errors")
			exitCode <- 1
			return
		}
		log.Info().Msg("Successfully shutdown HTTP server gracefully")
		log.Info().Msg("Finished shutting down")
		exitCode <- 0
	}()

	// Start HTTP server
	log.Info().Str("listenAddress", listenAddress).Msg("Starting HTTP server")
	err = server.ListenAndServe()
	if err != nil && err != http.ErrServerClosed {
		log.Fatal().Err(err).Msg("Failed to start HTTP server")
	}
	// ListenAndServe returns ErrServerClosed as soon as Shutdown is called,
	// before in-flight requests have drained. Wait for the shutdown goroutine
	// to report its result. Receiving from the interrupt channel here would
	// steal a second ^C and exit 0 instead of forcing exit 1.
	os.Exit(<-exitCode)
}