Skip to content
Snippets Groups Projects
uploadpomf.go 10.3 KiB
Newer Older
	"bytes"
	"crypto/md5"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
	"time"

	"owo.codes/whats-this/api/lib/apierrors"
	"owo.codes/whats-this/api/lib/db"
	"owo.codes/whats-this/api/lib/middleware"
	"owo.codes/whats-this/api/lib/ratelimiter"
	"owo.codes/whats-this/api/lib/util"

	"github.com/go-chi/render"
	"github.com/rs/zerolog/log"
	"github.com/spf13/viper"
)

const (
	// maxMemory is passed to ParseMultipartForm: up to this many bytes of
	// uploaded file parts are kept in memory, the remainder is written to
	// temporary files on disk (50 MB).
	maxMemory = 50 * 1000 * 1000

	// fieldName is the multipart/form-data field that carries the uploaded
	// files, as used by Pomf-compatible clients.
	fieldName = "files[]"
)

var (
	// fileExtRegex matches a single trailing extension such as ".png".
	fileExtRegex = regexp.MustCompile(`(\.[a-z0-9_-]+)$`)

	// gzExtRegex matches a trailing two-part extension ending in ".gz",
	// such as ".tar.gz".
	gzExtRegex = regexp.MustCompile(`(\.[a-z0-9_-]+\.gz)$`)
)

// determineExtension picks the extension (leading dot included) for an
// uploaded filename. A name ending in ".gz" first tries to keep the segment
// before it (".tar.gz"); if that does not match, only the last segment is
// used. Segments may contain only lowercase letters, digits, '_' and '-', so
// e.g. "IMG.JPG" yields "". The result is truncated to at most 20 bytes and is
// "" when nothing matches.
func determineExtension(filename string) string {
	const maxExtLen = 20

	var match string
	if strings.HasSuffix(filename, ".gz") {
		match = gzExtRegex.FindString(filename)
	}
	if match == "" {
		// An empty filename simply fails to match here as well.
		match = fileExtRegex.FindString(filename)
	}

	if len(match) <= maxExtLen {
		return match
	}
	// Every matched byte is ASCII, so byte slicing cannot split a rune.
	return match[:maxExtLen]
}

type (
	// fileResponse is one per-file entry of a Pomf upload response. Failed
	// files carry an error code and description; successful ones carry the
	// content hash, the object's key plus extension as URL, and the size.
	fileResponse struct {
		Success     bool   `json:"success"`
		StatusCode  int    `json:"errorcode,omitempty"`
		Description string `json:"description,omitempty"`
		Hash        string `json:"hash,omitempty"` // hex MD5 of the content, not SHA256
		Name        string `json:"name,omitempty"`
		URL         string `json:"url,omitempty"`
		Size        int64  `json:"size,omitempty"`
	}

	// fullResponse is the top-level JSON body returned by the upload route.
	fullResponse struct {
		Success bool           `json:"success"`
		Files   []fileResponse `json:"files"`
	}

	// fileWebhookRequest is the JSON payload POSTed to the file webhook
	// whenever new content is stored.
	fileWebhookRequest struct {
		SHA256Hash string `json:"sha256_hash"`
	}
)

// UploadPomf returns a handler for Pomf-compatible multipart/form-data upload
// requests.
//
// Every file in the "files[]" form field is streamed into a temporary file
// while its MD5 and SHA256 digests are computed. A file whose SHA256 is on the
// ban list is rejected. Otherwise it is moved into files.storageLocation under
// its hex SHA256 name, unless identical content is already stored there, and
// recorded in the database under a newly generated key. When
// associateObjectsWithUser is true, the database record is linked to the
// authorized user's ID.
//
// Request-level problems (unauthorized, rate limited, missing or oversized
// body, no files, too many files) abort the request by panicking with an
// apierrors value. Per-file problems do the same when the request carries a
// single file; with several files each failure becomes its own entry in the
// response and the remaining files are still processed.
//
// The response status is 200 when every file succeeded, 207 when some failed
// and 500 when all of them failed.
func UploadPomf(associateObjectsWithUser bool) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		// Only authorized users can use this route
		user := middleware.GetAuthorizedUser(r)
		if user.ID == "" || user.IsBlocked {
			panic(apierrors.Unauthorized)
		}

		// Apply ratelimits
		rBucket := middleware.GetBucket(r)
		err := rBucket.TakeWithHeaders(w, viper.GetInt64("ratelimiter.uploadPomfCost"))
		if err == ratelimiter.InsufficientTokens {
			panic(apierrors.InsufficientTokens)
		}
		if err != nil {
			panic(apierrors.InternalServerError)
		}

		// Check the declared Content-Length before reading the body.
		maxRequestSize := viper.GetInt64("http.maximumRequestSize")
		contentLength := r.Header.Get("Content-Length")
		if contentLength == "" {
			panic(apierrors.ContentLengthRequired)
		}
		length, err := strconv.ParseInt(contentLength, 10, 64)
		if err != nil {
			panic(apierrors.InvalidContentLength)
		}
		if length > maxRequestSize {
			panic(apierrors.BodyTooLarge)
		}

		// Parse form body; file parts beyond maxMemory spill to temporary files.
		if err := r.ParseMultipartForm(maxMemory); err != nil {
			log.Warn().Err(err).Msg("failed to parse multipart/form-data body")
			panic(apierrors.InternalServerError)
		}
		defer func() {
			if err := r.MultipartForm.RemoveAll(); err != nil {
				log.Warn().Err(err).Msg("failed to remove temporary files associated with form")
			}
		}()

		files, ok := r.MultipartForm.File[fieldName]
		if !ok || len(files) == 0 {
			panic(apierrors.NoFilesInRequest)
		}
		if maxFiles := viper.GetInt("pomf.maxFilesPerUpload"); maxFiles > 0 && len(files) > maxFiles {
			panic(apierrors.TooManyFilesInRequest)
		}

		// Configuration that does not change between files of one request.
		bucket := viper.GetString("database.objectBucket")
		tempDir := viper.GetString("files.tempLocation")
		storageDir := viper.GetString("files.storageLocation")
		var associatedUser *string
		if associateObjectsWithUser {
			associatedUser = &user.ID
		}

		fileResponses := make([]fileResponse, 0, len(files))
		singleFile := len(files) == 1

		// fail reports a per-file failure. A single-file request is aborted by
		// panicking with apiErr. Otherwise a failure entry is appended, and the
		// caller must `continue` with the next file.
		fail := func(name string, apiErr interface{}, statusCode int, description string) {
			if singleFile {
				panic(apiErr)
			}
			fileResponses = append(fileResponses, fileResponse{
				Success:     false,
				StatusCode:  statusCode,
				Description: description,
				Name:        name,
			})
		}
		// failInternal is fail for the generic internal-server-error case.
		failInternal := func(name string) {
			fail(name, apierrors.InternalServerError, http.StatusInternalServerError, "internal server error")
		}
		// removeTemp deletes a temporary file that is not going to be kept.
		removeTemp := func(path string) {
			if err := os.Remove(path); err != nil {
				log.Error().Err(err).Msg("failed to delete temporary file after error")
			}
		}
		// sendFileWebhook POSTs the hex SHA256 of newly stored content to
		// fileWebhook.url. It is run in its own goroutine, so failures are only
		// logged. The client timeout (an arbitrary bound) keeps an unresponsive
		// endpoint from pinning the goroutine forever.
		sendFileWebhook := func(sha256Hex string) {
			const fileWebhookTimeout = 30 * time.Second
			m, err := json.Marshal(fileWebhookRequest{SHA256Hash: sha256Hex})
			if err != nil {
				log.Warn().Err(err).Msg("failed to marshal reqData in fileWebhook goroutine")
				return
			}
			client := &http.Client{Timeout: fileWebhookTimeout}
			resp, err := client.Post(viper.GetString("fileWebhook.url"), "application/json", bytes.NewBuffer(m))
			if err != nil {
				log.Warn().Err(err).Msg("failed to send request in fileWebhook goroutine")
				return
			}
			defer resp.Body.Close()
			if resp.StatusCode < 200 || resp.StatusCode > 399 {
				log.Warn().Msgf("got unexpected status code from fileWebhook url: %v", resp.StatusCode)
			}
		}

		for _, file := range files {
			if file.Size > maxRequestSize {
				panic(apierrors.BodyTooLarge)
			}

			// Generate a file key
			key, err := util.GenerateAvailableKey(bucket, 5)
			if err != nil {
				log.Error().Err(err).Msg("failed to generate available key within 5 attempts for file object")
				failInternal(file.Filename)
				continue
			}

			// Determine object extension and Content-Type
			ext := determineExtension(file.Filename)
			contentType := "application/octet-stream"
			if ct := file.Header.Get("content-type"); ct != "" {
				contentType = ct
			}

			// Stream the upload into a temporary file, hashing it on the way.
			f, err := file.Open()
			if err != nil {
				log.Error().Err(err).Msg("failed to open multipart file")
				failInternal(file.Filename)
				continue
			}
			tempPath := filepath.Join(tempDir, key+ext)
			tempFile, err := os.Create(tempPath)
			if err != nil {
				_ = f.Close() // read side only; nothing useful to do on error
				log.Error().Err(err).Msg("failed to create temporary destination file")
				failInternal(file.Filename)
				continue
			}
			md5Hash := md5.New()
			sha256Hash := sha256.New()
			_, err = io.Copy(io.MultiWriter(md5Hash, sha256Hash, tempFile), f)
			_ = f.Close() // read side only; nothing useful to do on error
			// Close may report write failures not seen during the copy, so
			// treat its error like a copy error.
			if closeErr := tempFile.Close(); err == nil {
				err = closeErr
			}
			if err != nil {
				log.Error().Err(err).Msg("failed to write to hashers and temporary path")
				removeTemp(tempPath)
				failInternal(file.Filename)
				continue
			}

			// Get checksums
			md5Bytes := md5Hash.Sum(nil)
			sha256Bytes := sha256Hash.Sum(nil)
			sha256Hex := hex.EncodeToString(sha256Bytes)

			// Check if file is banned
			banned, err := db.CheckIfFileBanExists(sha256Bytes)
			if err != nil {
				log.Error().Err(err).Msg("failed to check if file is banned in UploadPomf")
				removeTemp(tempPath)
				failInternal(file.Filename)
				continue
			}
			if banned {
				removeTemp(tempPath)
				fail(file.Filename, apierrors.FileIsBanned, http.StatusConflict, "file is banned")
				continue
			}

			// Content is stored once per SHA256: keep the temporary copy only
			// if nothing is stored under that name yet.
			destPath := filepath.Join(storageDir, sha256Hex)
			_, statErr := os.Stat(destPath)
			switch {
			case statErr == nil:
				if err := os.Remove(tempPath); err != nil {
					log.Warn().Err(err).Msg("failed to delete temporary file")
				}
			case !os.IsNotExist(statErr):
				log.Error().Err(statErr).Msg("failed to check if destination file exists")
				removeTemp(tempPath)
				failInternal(file.Filename)
				continue
			default:
				if err := os.Rename(tempPath, destPath); err != nil {
					log.Error().Err(err).Msg("failed to move file to the destination")
					removeTemp(tempPath)
					failInternal(file.Filename)
					continue
				}
				if viper.GetBool("fileWebhook.enable") {
					go sendFileWebhook(sha256Hex)
				}
			}

			// Insert object into database
			err = db.InsertFile(bucket, key, ext, contentType, file.Size, md5Bytes, sha256Bytes, associatedUser)
			if err != nil {
				log.Error().Err(err).Msg("failed to create DB object for file upload")
				failInternal(file.Filename)
				continue
			}

			fileResponses = append(fileResponses, fileResponse{
				Success: true,
				Hash:    hex.EncodeToString(md5Bytes),
				Name:    file.Filename,
				URL:     key + ext,
				Size:    file.Size,
			})
		}

		// Determine suitable response code (500 if all files failed or 207 if ANY files failed)
		failedFiles := 0
		for _, fr := range fileResponses {
			if !fr.Success {
				failedFiles++
			}
		}
		statusCode := http.StatusOK
		switch {
		case failedFiles == len(fileResponses):
			statusCode = http.StatusInternalServerError
		case failedFiles > 0:
			statusCode = http.StatusMultiStatus // we don't conform to 207 spec but whatever
		}

		// Return response
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(statusCode)
		render.JSON(w, r, fullResponse{Success: true, Files: fileResponses})
	}
}