From 8b06b9b200d76d11454233defc10857c81d149f1 Mon Sep 17 00:00:00 2001
From: Dean Sheather <dean@deansheather.com>
Date: Tue, 5 Jan 2021 21:06:48 +1000
Subject: [PATCH] Add type filter to list objects endpoint

---
 lib/apierrors/errors.go   |  3 +++
 lib/db/queries.go         | 10 ++++++++--
 lib/db/sql.go             |  2 +-
 lib/routes/listobjects.go | 25 +++++++++++++++++++++----
 4 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/lib/apierrors/errors.go b/lib/apierrors/errors.go
index c73cd96..078e014 100644
--- a/lib/apierrors/errors.go
+++ b/lib/apierrors/errors.go
@@ -61,6 +61,9 @@ var (
 
 	// FileIsNotBanned is a 404 not found error.
 	FileIsNotBanned = APIError{false, 404, "specified file is not banned", false}
+
+	// InvalidObjectFilter is a 400 bad request error.
+	InvalidObjectFilter = APIError{false, 400, `invalid filter, must be "", "files" or "links"`, false}
 )
 
 // Pomf errors
diff --git a/lib/db/queries.go b/lib/db/queries.go
index 815efff..2057b8c 100644
--- a/lib/db/queries.go
+++ b/lib/db/queries.go
@@ -217,13 +217,19 @@ func scanObject(scanner scanner) (Object, error) {
 
 // ListObjectsByAssociatedUser returns all objects (paginated) associated with a
 // user.
-func ListObjectsByAssociatedUser(userID string, asc bool, offset, limit int) ([]Object, error) {
+func ListObjectsByAssociatedUser(userID string, typ int, asc bool, offset, limit int) ([]Object, error) {
 	objects := []Object{}
 	order := "DESC"
 	if asc {
 		order = "ASC"
 	}
-	rows, err := DB.Query(fmt.Sprintf(listObjectsByAssociatedUser, order), userID, limit, offset)
+
+	typeFilter := "!= 0"
+	if typ != -1 {
+		typeFilter = fmt.Sprintf("= %v", typ)
+	}
+
+	rows, err := DB.Query(fmt.Sprintf(listObjectsByAssociatedUser, typeFilter, order), userID, limit, offset)
 	if err != nil {
 		return objects, err
 	}
diff --git a/lib/db/sql.go b/lib/db/sql.go
index 1964b67..94300ba 100644
--- a/lib/db/sql.go
+++ b/lib/db/sql.go
@@ -84,7 +84,7 @@ FROM
 	objects
 WHERE
 	associated_user = $1 AND
-	"type" != 2
+	"type" %v
 ORDER BY
 	created_at %s
 LIMIT $2
diff --git a/lib/routes/listobjects.go b/lib/routes/listobjects.go
index c4a9058..1fcb05b 100644
--- a/lib/routes/listobjects.go
+++ b/lib/routes/listobjects.go
@@ -14,8 +14,14 @@ import (
 	"github.com/spf13/viper"
 )
 
-// Maximum objects per page
-const maxLimit = 100
+const (
+	// Maximum objects per page
+	maxLimit = 100
+
+	// filter keys for file vs short link.
+	filterFiles = "file"
+	filterLinks = "link"
+)
 
 // listObjectsResponse is the response format for ListObjects.
 type listObjectsResponse struct {
@@ -41,7 +47,7 @@ func ListObjects(w http.ResponseWriter, r *http.Request) {
 		panic(apierrors.InternalServerError)
 	}
 
-	// Determine offset and limit information
+	// Determine offset, limit and filter params
 	query := r.URL.Query()
 	l := query.Get("limit")
 	limit, err := strconv.Atoi(l)
@@ -60,9 +66,20 @@ func ListObjects(w http.ResponseWriter, r *http.Request) {
 	if query.Get("order") == "asc" {
 		asc = true
 	}
+	filter := query.Get("type")
+	if filter != "" && filter != filterFiles && filter != filterLinks {
+		panic(apierrors.InvalidObjectFilter)
+	}
+
+	f := -1
+	if filter == filterFiles {
+		f = 0
+	} else if filter == filterLinks {
+		f = 1
+	}
 
 	// Get the data
-	objects, err := db.ListObjectsByAssociatedUser(user.ID, asc, offset, limit)
+	objects, err := db.ListObjectsByAssociatedUser(user.ID, f, asc, offset, limit)
 	if err != nil {
 		log.Error().Err(err).Msg("failed to list objects for user")
 		panic(apierrors.InternalServerError)
-- 
GitLab