// uploadpomf.go: Pomf-compatible multipart/form-data upload route.

package routes

import (
	"crypto/md5"
	"encoding/hex"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"strconv"

	"owo.codes/whats-this/api/lib/apierrors"
	"owo.codes/whats-this/api/lib/db"
	"owo.codes/whats-this/api/lib/middleware"
	"owo.codes/whats-this/api/lib/ratelimiter"
	"owo.codes/whats-this/api/lib/util"

	"github.com/go-chi/render"
	"github.com/rs/zerolog/log"
	"github.com/spf13/viper"
)

// Upload handling parameters.
const (
	// maxMemory is the maxMemory argument passed to ParseMultipartForm for
	// upload requests: 50 MB (decimal megabytes).
	maxMemory = 50 * 1000 * 1000

	// fieldName is the multipart/form-data field that carries uploaded files.
	fieldName = "files[]"
)

// fileExtRegex captures a filename's trailing extension: the match starts at the
// first "." after which only lowercase ASCII letters, digits, "_", "." and "-"
// appear up to the end of the name. "a.tar.gz" therefore yields ".tar.gz",
// while "IMG.JPG" (uppercase) yields no match at all.
var fileExtRegex = regexp.MustCompile(`(\.[a-z0-9_.-]+)$`)

// JSON bodies for Pomf upload responses. Field order matters: encoding/json
// emits struct fields in declaration order, which fixes the key order clients
// see.
type (
	// fileResponse reports the outcome for a single uploaded file. Failed
	// entries carry an error code and description; successful entries carry
	// the content hash, the stored object name (URL) and the size.
	fileResponse struct {
		Success     bool   `json:"success"`
		StatusCode  int    `json:"errorcode,omitempty"`
		Description string `json:"description,omitempty"`
		Hash        string `json:"hash,omitempty"`
		Name        string `json:"name,omitempty"`
		URL         string `json:"url,omitempty"`
		Size        int64  `json:"size,omitempty"`
	}

	// fullResponse is the top-level body returned for an upload request.
	fullResponse struct {
		Success bool           `json:"success"`
		Files   []fileResponse `json:"files"`
	}
)

// UploadPomf handles Pomf multipart/form-data upload requests.
Dean's avatar
Dean committed
func UploadPomf(associateObjectsWithUser bool) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		// Only authorized users can use this route
		user := middleware.GetAuthorizedUser(r)
		if user.ID == "" || user.IsBlocked {
			panic(apierrors.Unauthorized)
		}

Dean's avatar
Dean committed
		// Apply ratelimits
		rBucket := middleware.GetBucket(r)
		err := rBucket.TakeWithHeaders(w, viper.GetInt64("ratelimiter.uploadPomfCost"))
		if err == ratelimiter.InsufficientTokens {
			panic(apierrors.InsufficientTokens)
		}
		if err != nil {
			panic(apierrors.InternalServerError)
		}

		// Check Content-Length if supplied
		contentLength := r.Header.Get("Content-Length")
		if contentLength == "" {
			panic(apierrors.ContentLengthRequired)
		}
		length, err := strconv.ParseInt(contentLength, 10, 64)
		if err != nil {
			panic(apierrors.InvalidContentLength)
		}
		if length > viper.GetInt64("http.maximumRequestSize") {
			panic(apierrors.BodyTooLarge)
		}

		// Parse form body
		err = r.ParseMultipartForm(maxMemory)
		if err != nil {
			log.Warn().Err(err).Msg("failed to parse multipart/form-data body")
			panic(apierrors.InternalServerError)
		}
		defer func() {
			err := r.MultipartForm.RemoveAll()
			if err != nil {
				log.Warn().Err(err).Msg("failed to remove temporary files associated with form")
			}
		}()

		// Loop over each file and copy them to the destination
		files, ok := r.MultipartForm.File[fieldName]
		if !ok || len(files) == 0 {
			panic(apierrors.NoFilesInRequest)
		}
		if viper.GetInt("pomf.maxFilesPerUpload") > 0 && len(files) > viper.GetInt("pomf.maxFilesPerUpload") {
			panic(apierrors.TooManyFilesInRequest)
		}
		fileResponses := []fileResponse{}
		for _, file := range files {
			if file.Size > viper.GetInt64("http.maximumRequestSize") {
				panic(apierrors.BodyTooLarge)
			}

			// Generate a file key
			key, err := util.GenerateAvailableKey(5)
			if err != nil {
				log.Error().Err(err).Msg("failed to generate available key within 5 attempts for file object")
				if len(files) == 1 {
					panic(apierrors.InternalServerError)
				}
				fileResponses = append(fileResponses, fileResponse{
					Success:     false,
					StatusCode:  500,
					Description: "internal server error",
					Name:        file.Filename,
				})
				continue
			}

			// Determine object extension for response.
			ext := ""
			if file.Filename != "" {
				ext = fileExtRegex.FindString(file.Filename)
			}
			if len(ext) > 20 {
				ext = ext[0:20] // limit ext length to 20 chars
			}

			// Determine Content-Type
			contentType := "application/octet-stream"
			if ct := file.Header.Get("content-type"); ct != "" {
				contentType = ct
			}

			// Open file for reading
			f, err := file.Open()
			if err != nil {
				log.Error().Err(err).Msg("failed to open multipart file")
				if len(files) == 1 {
					panic(apierrors.InternalServerError)
				}
				fileResponses = append(fileResponses, fileResponse{
					Success:     false,
					StatusCode:  500,
					Description: "internal server error",
					Name:        file.Filename,
				})
				continue
			}

			// Write file to MD5 and to temp file
			tempPath := filepath.Join(viper.GetString("pomf.tempLocation"), key+ext)
			tempFile, err := os.Create(tempPath)
			if err != nil {
				log.Error().Err(err).Msg("failed to create destination file")
				if len(files) == 1 {
					panic(apierrors.InternalServerError)
				}
				fileResponses = append(fileResponses, fileResponse{
					Success:     false,
					StatusCode:  500,
					Description: "internal server error",
					Name:        file.Filename,
				})
				continue
			}
			writer := io.MultiWriter(hash, tempFile)
			_, err = io.Copy(writer, f)
			tempFile.Close()
				log.Error().Err(err).Msg("failed to write to MD5 hasher and temporary path")
				err = os.Remove(tempPath)
					log.Error().Err(err).Msg("failed to delete temporary file after error")
				}
				if len(files) == 1 {
					panic(apierrors.InternalServerError)
				}
				fileResponses = append(fileResponses, fileResponse{
					Success:     false,
					StatusCode:  500,
					Description: "internal server error",
					Name:        file.Filename,
				})
				continue
			}

			// Move file to the destination
			destPath := filepath.Join(viper.GetString("pomf.storageLocation"), key+ext)
			err = os.Rename(tempPath, destPath)
			if err != nil {
				log.Error().Err(err).Msg("failed to move file to the destination")
				err = os.Remove(tempPath)
				if err != nil {
					log.Error().Err(err).Msg("failed to delete temporary file after error")
				}
				panic(apierrors.InternalServerError)
			}
			// Get MD5 hash digest
			md5Hash := hex.EncodeToString(hash.Sum(nil))

			// Insert object into database
			var associatedUser *string
			if associateObjectsWithUser {
				associatedUser = &user.ID
			}
			bucket := viper.GetString("database.objectBucket")
			err = db.InsertFile(bucket, key, ext, contentType, file.Size, md5Hash, associatedUser)
			if err != nil {
				log.Error().Err(err).Msg("failed to create DB object for file upload")
				err = os.Remove(destPath)
				if err != nil {
					log.Error().Err(err).Msg("failed to delete destination file after error")
				}
				if len(files) == 1 {
					panic(apierrors.InternalServerError)
				}
				fileResponses = append(fileResponses, fileResponse{
					Success:     false,
					StatusCode:  500,
					Description: "internal server error",
					Name:        file.Filename,
				})
				continue
			}

			fileResponses = append(fileResponses, fileResponse{
				Success: true,
				Hash:    md5Hash,
				Name:    file.Filename,
				URL:     key + ext,
				Size:    file.Size,
			})
		}

		// Determine suitable response code (500 if all files failed or 207 if ANY files failed)
		statusCode := http.StatusOK
		failedFiles := 0
		for _, fResponse := range fileResponses {
			if !fResponse.Success {
				failedFiles++
				statusCode = http.StatusMultiStatus // we don't conform to 207 spec but whatever
			}
		}
		if failedFiles == len(fileResponses) {
			statusCode = http.StatusInternalServerError
		}

		// Return response
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(statusCode)
		render.JSON(w, r, fullResponse{Success: true, Files: fileResponses})
	}
}