Skip to content
Snippets Groups Projects
queries.go 10.91 KiB
package db

import (
	"database/sql"
	"encoding/hex"
	"fmt"
	"strings"

	"github.com/lib/pq"
	"github.com/pkg/errors"
)

// SelectUserByToken looks up the user that owns the given token.
//
// Any error from the query or scan (including sql.ErrNoRows when nothing
// matches) is returned unchanged alongside the user value as scanned so far.
// A NULL bucket capacity leaves User.BucketCapacity at its zero value.
func SelectUserByToken(token string) (User, error) {
	var (
		u        User
		capacity sql.NullInt64
	)
	row := DB.QueryRow(selectUserByToken, token)
	if err := row.Scan(&u.ID, &u.Username, &u.Email, &u.IsAdmin, &u.IsBlocked, &capacity); err != nil {
		return u, err
	}
	if capacity.Valid {
		u.BucketCapacity = capacity.Int64
	}
	return u, nil
}

// ObjectKeyExists reports whether key is already used by some object in
// bucket, according to the countOfObjectByBucketAndRandomKey query.
func ObjectKeyExists(bucket, key string) (bool, error) {
	count := 0
	if err := DB.QueryRow(countOfObjectByBucketAndRandomKey, bucket, key).Scan(&count); err != nil {
		return false, err
	}
	inUse := count > 0
	return inUse, nil
}

// InsertShortURL stores a short-URL object in bucket that points at destURL.
//
// key is normalised to begin with "/". The insert receives, in order: the full
// bucket+key path, the bucket, the normalised key, the key without its leading
// slash, the destination URL and the optional associated user. An error is
// returned if the statement fails or does not affect exactly one row.
func InsertShortURL(bucket, key, destURL string, associatedUser *string) error {
	normalized := key
	if !strings.HasPrefix(normalized, "/") {
		normalized = "/" + normalized
	}
	args := []interface{}{
		bucket + normalized,
		bucket,
		normalized,
		normalized[1:],
		destURL,
		associatedUser,
	}
	res, err := DB.Exec(insertShortURL, args...)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 1 {
		return nil
	}
	return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", affected)
}

// InsertFile stores a file object in bucket.
//
// key is normalised to begin with "/" and ext is appended to it for the stored
// path and key. The bare key, without its leading slash and without ext, is
// passed separately. An error is returned if the statement fails or does not
// affect exactly one row.
func InsertFile(bucket, key, ext, contentType string, contentLength int64, md5Hash, sha256Hash []byte, associatedUser *string) error {
	// Ensure exactly the original single-leading-slash normalisation.
	key = "/" + strings.TrimPrefix(key, "/")
	keyWithExt := key + ext

	res, err := DB.Exec(insertFile,
		bucket+keyWithExt,
		bucket,
		keyWithExt,
		key[1:],
		contentType,
		contentLength,
		md5Hash,
		sha256Hash,
		associatedUser)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case n != 1:
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", n)
	}
	return nil
}

// CheckUserExistsByUsernameOrEmail reports whether some user already has the
// given username (compared in lower case) or email address. A missing row is
// not an error; any other query failure is returned.
func CheckUserExistsByUsernameOrEmail(username, email string) (bool, error) {
	var id string
	lowered := strings.ToLower(username)
	switch err := DB.QueryRow(selectUserByUsernameOrEmail, lowered, email).Scan(&id); err {
	case nil:
		return true, nil
	case sql.ErrNoRows:
		return false, nil
	default:
		return false, err
	}
}

// InsertUser stores a new user with the given id, username and email. The
// username is passed twice: as given and lower-cased. An error is returned if
// the statement fails or does not affect exactly one row.
func InsertUser(id, username, email string) error {
	lowerUsername := strings.ToLower(username)
	res, err := DB.Exec(insertUser, id, username, lowerUsername, email)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case n != 1:
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", n)
	}
	return nil
}

// InsertToken stores token for the user identified by userID. An error is
// returned if the statement fails or does not affect exactly one row.
func InsertToken(userID, token string) error {
	res, execErr := DB.Exec(insertToken, userID, token)
	if execErr != nil {
		return execErr
	}
	affected, countErr := res.RowsAffected()
	if countErr != nil {
		return countErr
	}
	if affected == 1 {
		return nil
	}
	return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", affected)
}

// scanner is anything that can copy the current result row's columns into
// dest via Scan. It lets scanObject and scanFileBan accept both the
// single-row result of DB.QueryRow and the multi-row result of DB.Query
// (presumably *sql.Row and *sql.Rows — confirm against DB's type).
type scanner interface {
	Scan(dest ...interface{}) error
}

// scanObject scans one object row into an Object and validates it.
//
// Columns must be, in order: bucket, key, directory, type, destination URL,
// content type, content length, created-at, deleted-at, delete reason, MD5
// hash, SHA-256 hash and associated user.
//
// Validation rules:
//   - type must be 0, 1 or 2;
//   - for types other than 1, hashes of the expected digest length (16-byte
//     MD5, 32-byte SHA-256) are exposed as raw bytes and as hex strings;
//   - type 0 requires content type, content length and both hashes;
//   - type 1 requires a destination URL;
//   - type 2 may carry deleted-at (and then optionally a delete reason), but a
//     delete reason without deleted-at is invalid;
//   - types 0 and 1 must carry neither deleted-at nor a delete reason.
//
// Any violation returns a zero Object and a non-nil error.
func scanObject(s scanner) (Object, error) {
	const invalidMsg = "invalid Object scanned from query"

	object := Object{}
	var (
		destURL        sql.NullString
		contentType    sql.NullString
		contentLength  sql.NullInt64
		deletedAt      pq.NullTime
		deleteReason   sql.NullString
		associatedUser sql.NullString
		md5Hash        []byte
		sha256Hash     []byte
	)
	err := s.Scan(&object.Bucket,
		&object.Key,
		&object.Directory,
		&object.Type,
		&destURL,
		&contentType,
		&contentLength,
		&object.CreatedAt,
		&deletedAt,
		&deleteReason,
		&md5Hash,
		&sha256Hash,
		&associatedUser)
	if err != nil {
		return Object{}, errors.Wrap(err, "failed to Scan row from query")
	}

	// From here on err is nil, and errors.Wrap(nil, ...) returns nil. Wrapping
	// it would report invalid rows as success, so each validation failure
	// builds a fresh error instead.
	if object.Type < 0 || object.Type > 2 {
		return Object{}, errors.New(invalidMsg)
	}
	if object.Type != 1 {
		// Only accept hashes of the correct digest length; others stay unset.
		if len(md5Hash) == 16 {
			object.MD5HashBytes = md5Hash
			md5String := hex.EncodeToString(md5Hash)
			object.MD5Hash = &md5String
		}
		if len(sha256Hash) == 32 {
			object.SHA256HashBytes = sha256Hash
			sha256String := hex.EncodeToString(sha256Hash)
			object.SHA256Hash = &sha256String
		}
	}
	if object.Type == 0 {
		if !contentType.Valid || !contentLength.Valid || object.MD5HashBytes == nil || object.SHA256HashBytes == nil {
			return Object{}, errors.New(invalidMsg)
		}
		object.ContentType = &contentType.String
		object.ContentLength = &contentLength.Int64
	}
	if object.Type == 1 {
		if !destURL.Valid {
			return Object{}, errors.New(invalidMsg)
		}
		object.DestURL = &destURL.String
	}
	if object.Type == 2 {
		if deletedAt.Valid {
			object.DeletedAt = &deletedAt.Time
			if deleteReason.Valid {
				object.DeleteReason = &deleteReason.String
			}
		} else if deleteReason.Valid {
			return Object{}, errors.New(invalidMsg)
		}
	} else if deletedAt.Valid || deleteReason.Valid {
		return Object{}, errors.New(invalidMsg)
	}
	if associatedUser.Valid {
		object.AssociatedUser = &associatedUser.String
	}
	return object, nil
}

// ListObjectsByAssociatedUser returns one page of the objects associated with
// userID.
//
// typ selects a single object type. -1 selects every type except 2, the type
// scanObject treats as carrying deletion metadata. asc chooses ASC instead of
// DESC ordering for the listObjectsByAssociatedUser query. limit and offset
// are passed through as query parameters for pagination.
//
// On any error an empty, non-nil slice is returned with the error.
func ListObjectsByAssociatedUser(userID string, typ int, asc bool, offset, limit int) ([]Object, error) {
	order := "DESC"
	if asc {
		order = "ASC"
	}

	// typ is an int and order is one of two constants, so formatting them into
	// the query text cannot inject SQL.
	typeFilter := "!= 2"
	if typ != -1 {
		typeFilter = fmt.Sprintf("= %d", typ)
	}

	rows, err := DB.Query(fmt.Sprintf(listObjectsByAssociatedUser, typeFilter, order), userID, limit, offset)
	if err != nil {
		return []Object{}, err
	}
	defer rows.Close()

	objects := []Object{}
	for rows.Next() {
		object, err := scanObject(rows)
		if err != nil {
			return []Object{}, err
		}
		objects = append(objects, object)
	}
	// rows.Next returns false both at the end of the results and on an
	// iteration error. rows.Err tells them apart, so a truncated page is not
	// reported as success.
	if err := rows.Err(); err != nil {
		return []Object{}, err
	}
	return objects, nil
}

// CountObjectsByAssociatedUser returns the count of all objects associated with a user.
func CountObjectsByAssociatedUser(userID string, typ int) (int, error) {
	typeFilter := "!= 2"
	if typ != -1 {
		typeFilter = fmt.Sprintf("= %v", typ)
	}

	var count int
	query := fmt.Sprintf(countObjectsByAssociatedUser, typeFilter)
	err := DB.QueryRow(query, userID).Scan(&count)
	if err != nil {
		return 0, err
	}
	return count, nil
}

// GetObject fetches and validates the object stored at "bucket/key".
func GetObject(bucket, key string) (Object, error) {
	return scanObject(DB.QueryRow(getObjectByBucketKey, bucket+"/"+key))
}

// CheckIfObjectExists reports whether the countOfObjectsBySHA256 query finds
// at least one object with the given SHA-256 hash.
func CheckIfObjectExists(sha256 []byte) (bool, error) {
	var n int
	row := DB.QueryRow(countOfObjectsBySHA256, sha256)
	if err := row.Scan(&n); err != nil {
		return false, err
	}
	return 0 < n, nil
}
// UpdateObjectToTombstoneByBucketKey turns the object at "bucket/key" into a
// tombstone, recording reason (which may be nil).
//
// The statement is chosen from the retain flags:
//   - neither flag: updateObjectToTombstoneByBucketKey;
//   - retainAssociatedUser only: the KeepAssociatedUser variant;
//   - both flags: the KeepHashesAndAssociatedUser variant.
//
// retainHashes has no effect unless retainAssociatedUser is also true. An
// error is returned if the statement fails or does not affect exactly one row.
func UpdateObjectToTombstoneByBucketKey(bucket, key string, reason *string, retainAssociatedUser bool, retainHashes bool) error {
	// Named "query" rather than "sql" so it does not shadow the database/sql
	// package import.
	query := updateObjectToTombstoneByBucketKey
	switch {
	case retainAssociatedUser && retainHashes:
		query = updateObjectToTombstoneKeepHashesAndAssociatedUserByBucketKey
	case retainAssociatedUser:
		query = updateObjectToTombstoneKeepAssociatedUserByBucketKey
	}

	result, err := DB.Exec(query, reason, bucket+"/"+key)
	if err != nil {
		return err
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rows != 1 {
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
	}
	return nil
}

// UpdateObjectToTombstoneBySHA256Hash turns every object whose SHA-256 hash
// matches sha256 into a tombstone, recording reason (which may be nil).
//
// The statement is chosen from the retain flags:
//   - neither flag: updateObjectToTombstoneBySHA256Hash;
//   - retainAssociatedUser only: the KeepAssociatedUser variant;
//   - both flags: the KeepHashesAndAssociatedUser variant.
//
// retainHashes has no effect unless retainAssociatedUser is also true. The
// number of affected rows is not checked, so matching zero objects is not an
// error.
func UpdateObjectToTombstoneBySHA256Hash(sha256 []byte, reason *string, retainAssociatedUser bool, retainHashes bool) error {
	// Named "query" rather than "sql" so it does not shadow the database/sql
	// package import.
	query := updateObjectToTombstoneBySHA256Hash
	switch {
	case retainAssociatedUser && retainHashes:
		query = updateObjectToTombstoneKeepHashesAndAssociatedUserBySHA256Hash
	case retainAssociatedUser:
		query = updateObjectToTombstoneKeepAssociatedUserBySHA256Hash
	}

	_, err := DB.Exec(query, reason, sha256)
	return err
}

// CheckIfFileBanExists reports whether a file ban exists for the given SHA-256
// hash, according to the countOfFileBanBySHA256 query.
func CheckIfFileBanExists(sha256 []byte) (bool, error) {
	var count int
	if err := DB.QueryRow(countOfFileBanBySHA256, sha256).Scan(&count); err != nil {
		return false, err
	}
	return count > 0, nil
}

// InsertFileBan stores a ban for the file with the given SHA-256 hash.
//
// reason must be in the range 0..3. An empty malware name is stored as NULL.
// An error is returned if the reason is out of range, the statement fails, or
// it does not affect exactly one row.
func InsertFileBan(sha256 []byte, didQuarantine bool, reason int, description, malwareName *string) error {
	if reason > 3 || reason < 0 {
		return errors.New("invalid file ban reason")
	}

	// Only pass a malware name through when it is present and non-empty.
	var malware *string
	if malwareName != nil && *malwareName != "" {
		malware = malwareName
	}

	res, err := DB.Exec(insertFileBan, sha256, didQuarantine, reason, description, malware)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case n != 1:
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", n)
	}
	return nil
}

// scanFileBan scans one file_bans row into a FileBan and validates it.
//
// Columns must be, in order: SHA-256 hash, did-quarantine flag, reason,
// description and malware name. The reason must be in the range 0..3, and a
// malware name is only permitted when the reason is 1. When the stored hash is
// exactly 32 bytes it is also exposed as a hex string in SHA256Hash.
//
// Any violation returns a zero FileBan and a non-nil error.
func scanFileBan(s scanner) (FileBan, error) {
	const invalidMsg = "invalid FileBan scanned from query"

	fileBan := FileBan{}
	err := s.Scan(&fileBan.SHA256HashBytes,
		&fileBan.DidQuarantine,
		&fileBan.Reason,
		&fileBan.Description,
		&fileBan.MalwareName)
	if err != nil {
		return FileBan{}, errors.Wrap(err, "failed to Scan row from query")
	}

	// From here on err is nil, and errors.Wrap(nil, ...) returns nil. Each
	// validation failure therefore builds a fresh error instead.
	if fileBan.Reason < 0 || fileBan.Reason > 3 {
		return FileBan{}, errors.New(invalidMsg)
	}
	if fileBan.Reason != 1 && fileBan.MalwareName != nil {
		return FileBan{}, errors.New(invalidMsg)
	}
	if len(fileBan.SHA256HashBytes) == 32 {
		sha256String := hex.EncodeToString(fileBan.SHA256HashBytes)
		fileBan.SHA256Hash = &sha256String
	}
	return fileBan, nil
}

// ListBannedFiles returns every file ban produced by the listBannedFiles
// query, validated by scanFileBan.
//
// On any error an empty, non-nil slice is returned with the error.
func ListBannedFiles() ([]FileBan, error) {
	rows, err := DB.Query(listBannedFiles)
	if err != nil {
		return []FileBan{}, err
	}
	defer rows.Close()

	bannedFiles := []FileBan{}
	for rows.Next() {
		bannedFile, err := scanFileBan(rows)
		if err != nil {
			return []FileBan{}, err
		}
		bannedFiles = append(bannedFiles, bannedFile)
	}
	// rows.Next returns false both at the end of the results and on an
	// iteration error. Check rows.Err so a truncated list is not reported as
	// success.
	if err := rows.Err(); err != nil {
		return []FileBan{}, err
	}
	return bannedFiles, nil
}

// GetBannedFile fetches and validates the file ban for the given SHA-256 hash.
func GetBannedFile(sha256 []byte) (FileBan, error) {
	return scanFileBan(DB.QueryRow(getBannedFileBySHA256Hash, sha256))
}

// DeleteBannedFile removes the ban for the given SHA-256 hash, unbanning the
// file. An error is returned if the statement fails or does not affect exactly
// one row (for example, when no such ban exists).
func DeleteBannedFile(sha256 []byte) error {
	res, execErr := DB.Exec(deleteFileBan, sha256)
	if execErr != nil {
		return execErr
	}
	affected, countErr := res.RowsAffected()
	if countErr != nil {
		return countErr
	}
	if affected == 1 {
		return nil
	}
	return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", affected)
}