Newer
Older
package routes
import (
"io"
"net/http"
"os"
"path/filepath"
"regexp"
"strconv"
"owo.codes/whats-this/api/lib/apierrors"
"owo.codes/whats-this/api/lib/db"
"owo.codes/whats-this/api/lib/middleware"
"owo.codes/whats-this/api/lib/util"
"github.com/go-chi/render"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
)
// Constants governing how Pomf uploads are parsed.
const (
	// maxMemory is the number of bytes of a multipart body that
	// ParseMultipartForm keeps in memory before spilling file parts to
	// temporary files.
	maxMemory = 1000 * 1000 * 50 // 50 MB

	// fieldName is the multipart/form-data field that carries uploaded files.
	fieldName = "files[]"

	// maxExtensionLength caps the length in bytes (leading dot included) of
	// an extension returned by determineExtension.
	maxExtensionLength = 20
)

// Extension patterns. Only lowercase ASCII letters, digits, '_' and '-' are
// accepted between dots, and both patterns are anchored to the end of the name.
var (
	// fileExtRegex matches a trailing single extension such as ".png".
	fileExtRegex = regexp.MustCompile(`(\.[a-z0-9_-]+)$`)

	// gzExtRegex matches a trailing compound gzip extension such as ".tar.gz".
	gzExtRegex = regexp.MustCompile(`(\.[a-z0-9_-]+\.gz)$`)
)

// determineExtension returns the extension of filename, including the
// leading dot, or "" when the name has no acceptable extension.
//
// A compound ".<part>.gz" extension (e.g. ".tar.gz") is preferred over a bare
// ".gz"; otherwise the final ".<part>" suffix is used. Matching is
// case-sensitive, so "photo.JPG" yields "". The result is truncated to
// maxExtensionLength bytes.
func determineExtension(filename string) string {
	// gzExtRegex is anchored on `\.gz$`, so it can only match names that end
	// in ".gz"; no separate suffix check (and no extra import) is needed.
	ext := gzExtRegex.FindString(filename)
	if ext == "" {
		// FindString on "" returns "", so an empty filename needs no guard.
		ext = fileExtRegex.FindString(filename)
	}
	// Both patterns only admit ASCII, so byte truncation cannot split a rune.
	if len(ext) > maxExtensionLength {
		ext = ext[:maxExtensionLength]
	}
	return ext
}
// JSON payload types used by the Pomf upload handler.
type (
	// fileResponse reports the outcome for one file of a Pomf upload.
	// Failure entries may carry StatusCode and Description; success entries
	// carry Hash, URL and Size. Name echoes the client-supplied filename.
	fileResponse struct {
		Success     bool   `json:"success"`
		StatusCode  int    `json:"errorcode,omitempty"`
		Description string `json:"description,omitempty"`
		Hash        string `json:"hash,omitempty"` // hex-encoded MD5 digest, not SHA256
		Name        string `json:"name,omitempty"`
		URL         string `json:"url,omitempty"` // object key plus extension
		Size        int64  `json:"size,omitempty"`
	}

	// fullResponse is the top-level JSON body returned for an upload; the
	// per-file outcomes are listed in Files.
	fullResponse struct {
		Success bool           `json:"success"`
		Files   []fileResponse `json:"files"`
	}

	// fileWebhookRequest is the JSON body POSTed to the configured file
	// webhook after an upload.
	fileWebhookRequest struct {
		SHA256Hash string `json:"sha256_hash"`
	}
)
// UploadPomf handles Pomf multipart/form-data upload requests.
func UploadPomf(associateObjectsWithUser bool, simpleResponse bool) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
// Only authorized users can use this route
user := middleware.GetAuthorizedUser(r)
if user.ID == "" || user.IsBlocked {
panic(apierrors.Unauthorized)
}
// Apply ratelimits
rBucket := middleware.GetBucket(r)
err := rBucket.TakeWithHeaders(w, viper.GetInt64("ratelimiter.uploadPomfCost"))
if err == ratelimiter.InsufficientTokens {
panic(apierrors.InsufficientTokens)
}
if err != nil {
panic(apierrors.InternalServerError)
}
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// Check Content-Length if supplied
contentLength := r.Header.Get("Content-Length")
if contentLength == "" {
panic(apierrors.ContentLengthRequired)
}
length, err := strconv.ParseInt(contentLength, 10, 64)
if err != nil {
panic(apierrors.InvalidContentLength)
}
if length > viper.GetInt64("http.maximumRequestSize") {
panic(apierrors.BodyTooLarge)
}
// Parse form body
err = r.ParseMultipartForm(maxMemory)
if err != nil {
log.Warn().Err(err).Msg("failed to parse multipart/form-data body")
panic(apierrors.InternalServerError)
}
defer func() {
err := r.MultipartForm.RemoveAll()
if err != nil {
log.Warn().Err(err).Msg("failed to remove temporary files associated with form")
}
}()
// Loop over each file and copy them to the destination
files, ok := r.MultipartForm.File[fieldName]
if !ok || len(files) == 0 {
panic(apierrors.NoFilesInRequest)
}
if viper.GetInt("pomf.maxFilesPerUpload") > 0 && len(files) > viper.GetInt("pomf.maxFilesPerUpload") {
panic(apierrors.TooManyFilesInRequest)
}
fileResponses := []fileResponse{}
for _, file := range files {
if file.Size > viper.GetInt64("http.maximumRequestSize") {
panic(apierrors.BodyTooLarge)
}
// Generate a file key
bucket := viper.GetString("database.objectBucket")
key, err := util.GenerateAvailableKey(bucket, 5)
if err != nil {
log.Error().Err(err).Msg("failed to generate available key within 5 attempts for file object")
if len(files) == 1 {
panic(apierrors.InternalServerError)
}
fileResponses = append(fileResponses, fileResponse{
Success: false,
StatusCode: 500,
Description: "internal server error",
Name: file.Filename,
})
continue
}
// Determine object extension for response.
ext := determineExtension(file.Filename)
// Determine Content-Type
contentType := "application/octet-stream"
if ct := file.Header.Get("content-type"); ct != "" {
contentType = ct
}
// Open file for reading
f, err := file.Open()
if err != nil {
log.Error().Err(err).Msg("failed to open multipart file")
if len(files) == 1 {
panic(apierrors.InternalServerError)
}
fileResponses = append(fileResponses, fileResponse{
Success: false,
StatusCode: 500,
Description: "internal server error",
Name: file.Filename,
})
continue
}
// Write file to MD5 and SHA256 hashers and to temp file
md5Hash := md5.New()
sha256Hash := sha256.New()
tempPath := filepath.Join(viper.GetString("files.tempLocation"), key+ext)
log.Error().Err(err).Msg("failed to create temporary destination file")
if len(files) == 1 {
panic(apierrors.InternalServerError)
}
fileResponses = append(fileResponses, fileResponse{
Success: false,
StatusCode: 500,
Description: "internal server error",
Name: file.Filename,
})
continue
}
writer := io.MultiWriter(md5Hash, sha256Hash, tempFile)
_, err = io.Copy(writer, f)
log.Error().Err(err).Msg("failed to write to hashers and temporary path")
log.Error().Err(err).Msg("failed to delete temporary file after error")
}
if len(files) == 1 {
panic(apierrors.InternalServerError)
}
fileResponses = append(fileResponses, fileResponse{
Success: false,
StatusCode: 500,
Description: "internal server error",
Name: file.Filename,
})
continue
}
// Get checksums
md5Bytes := md5Hash.Sum(nil)
sha256Bytes := sha256Hash.Sum(nil)
// Check if file is banned
banned, err := db.CheckIfFileBanExists(sha256Bytes)
if err != nil {
log.Error().Err(err).Msg("failed to check if file is banned in UploadPomf")
panic(apierrors.InternalServerError)
}
if banned {
if len(files) == 1 {
panic(apierrors.FileIsBanned)
}
fileResponses = append(fileResponses, fileResponse{
})
continue
}
// Check if destination file exists, if not move temporary file to destination
destPath := filepath.Join(viper.GetString("files.storageLocation"), hex.EncodeToString(sha256Bytes))
_, err = os.Stat(destPath)
if !os.IsNotExist(err) {
log.Error().Err(err).Msg("failed to check if destination file exists")
err = os.Remove(tempPath)
if err != nil {
log.Error().Err(err).Msg("failed to delete temporary file after error")
}
if len(files) == 1 {
panic(apierrors.InternalServerError)
}
fileResponses = append(fileResponses, fileResponse{
})
}
// Move file to the destination
err = os.Rename(tempPath, destPath)
if err != nil {
log.Error().Err(err).Msg("failed to move file to the destination")
err = os.Remove(tempPath)
if err != nil {
log.Error().Err(err).Msg("failed to delete temporary file after error")
}
if len(files) == 1 {
panic(apierrors.InternalServerError)
}
fileResponses = append(fileResponses, fileResponse{
// Fire fileWebhook in goroutine
if viper.GetBool("fileWebhook.enable") {
go func() {
reqData := fileWebhookRequest{hex.EncodeToString(sha256Bytes)}
m, err := json.Marshal(reqData)
if err != nil {
log.Warn().Err(err).Msg("failed to marshal reqData in fileWebhook goroutine")
return
}
buf := bytes.NewBuffer(m)
resp, err := http.Post(viper.GetString("fileWebhook.url"), "application/json", buf)
if err != nil {
log.Warn().Err(err).Msg("failed to send request in fileWebhook goroutine")
return
}
if resp.StatusCode < 200 || resp.StatusCode > 399 {
log.Warn().Msgf("got unexpected status code from fileWebhook url: %v", resp.StatusCode)
}
}()
}
err = os.Remove(tempPath)
if err != nil {
log.Warn().Err(err).Msg("failed to delete temporary file")
// Insert object into database
var associatedUser *string
if associateObjectsWithUser {
associatedUser = &user.ID
}
err = db.InsertFile(bucket, key, ext, contentType, file.Size, md5Bytes, sha256Bytes, associatedUser)
if err != nil {
log.Error().Err(err).Msg("failed to create DB object for file upload")
if len(files) == 1 {
panic(apierrors.InternalServerError)
}
fileResponses = append(fileResponses, fileResponse{
Success: false,
StatusCode: 500,
Description: "internal server error",
Name: file.Filename,
})
continue
}
fileResponses = append(fileResponses, fileResponse{
Success: true,
Hash: hex.EncodeToString(md5Bytes),
Name: file.Filename,
URL: key + ext,
Size: file.Size,
})
}
// Determine suitable response code (500 if all files failed or 207 if ANY files failed)
statusCode := http.StatusOK
failedFiles := 0
for _, fResponse := range fileResponses {
if !fResponse.Success {
failedFiles++
statusCode = http.StatusMultiStatus // we don't conform to 207 spec but whatever
}
}
if failedFiles == len(fileResponses) {
statusCode = http.StatusInternalServerError
}
// Return response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
// Simple response
if simpleResponse {
for _, fResponse := range fileResponses {
w.Write([]byte(fmt.Sprintf("%s\n", fResponse.URL)))
}
} else {
render.JSON(w, r, fullResponse{Success: true, Files: fileResponses})
}