// queries.go — authored by Spotlight Deveaux.
package db
import (
"database/sql"
"encoding/hex"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/pkg/errors"
)
// SelectUserByToken returns the user associated with the given token.
//
// BucketCapacity is left at its zero value when the stored capacity is NULL.
// On any error (including sql.ErrNoRows from Row.Scan when nothing matches) a
// zero User is returned together with the error.
func SelectUserByToken(token string) (User, error) {
	var user User
	var bucketCapacity sql.NullInt64
	err := DB.QueryRow(selectUserByToken, token).
		Scan(&user.ID, &user.Username, &user.Email, &user.IsAdmin, &user.IsBlocked, &bucketCapacity)
	if err != nil {
		// Scan may have filled some destinations before failing; never hand a
		// half-populated User back to the caller.
		return User{}, err
	}
	if bucketCapacity.Valid {
		user.BucketCapacity = bucketCapacity.Int64
	}
	return user, nil
}
// ObjectKeyExists reports whether key is already taken by an object in bucket.
// It runs the countOfObjectByBucketAndRandomKey query and returns true when the
// resulting count is positive. A query error is returned alongside false.
func ObjectKeyExists(bucket, key string) (bool, error) {
	var n int
	if err := DB.QueryRow(countOfObjectByBucketAndRandomKey, bucket, key).Scan(&n); err != nil {
		return false, err
	}
	return n > 0, nil
}
// InsertShortURL inserts a short URL object into the database.
//
// key is normalized to start with "/". The insertShortURL statement receives,
// in order: bucket+key, bucket, the normalized key, the key without its leading
// slash, destURL and associatedUser (which may be nil). An error is returned if
// the insert fails or does not affect exactly one row.
func InsertShortURL(bucket, key, destURL string, associatedUser *string) error {
	slashKey := key
	if !strings.HasPrefix(slashKey, "/") {
		slashKey = "/" + slashKey
	}
	bareKey := slashKey[1:]

	res, err := DB.Exec(insertShortURL, bucket+slashKey, bucket, slashKey, bareKey, destURL, associatedUser)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case affected != 1:
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", affected)
	}
	return nil
}
// InsertFile inserts a file object into the database.
//
// key is normalized to start with "/" and ext is appended to it. The insertFile
// statement receives, in order: bucket+key+ext, bucket, key+ext, the key
// without its leading slash (and without ext), contentType, contentLength, the
// MD5 and SHA-256 hashes, and associatedUser (which may be nil). An error is
// returned if the insert fails or does not affect exactly one row.
func InsertFile(bucket, key, ext, contentType string, contentLength int64, md5Hash, sha256Hash []byte, associatedUser *string) error {
	slashKey := key
	if !strings.HasPrefix(slashKey, "/") {
		slashKey = "/" + slashKey
	}
	keyWithExt := slashKey + ext

	res, err := DB.Exec(insertFile,
		bucket+keyWithExt,
		bucket,
		keyWithExt,
		slashKey[1:],
		contentType,
		contentLength,
		md5Hash,
		sha256Hash,
		associatedUser)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case affected != 1:
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", affected)
	}
	return nil
}
// CheckUserExistsByUsernameOrEmail reports whether a user matches the given
// username or email address. The username is lowercased before the lookup;
// the email is passed through unchanged. sql.ErrNoRows is mapped to
// (false, nil), and any other error is returned alongside false.
func CheckUserExistsByUsernameOrEmail(username, email string) (bool, error) {
	var id string
	switch err := DB.QueryRow(selectUserByUsernameOrEmail, strings.ToLower(username), email).Scan(&id); {
	case err == sql.ErrNoRows:
		return false, nil
	case err != nil:
		return false, err
	default:
		return true, nil
	}
}
// InsertUser inserts a user into the database.
//
// The insertUser statement receives id, the username as given, the username
// lowercased, and email. NOTE(review): the lowercased copy presumably backs
// case-insensitive lookups — confirm against the insertUser query. An error is
// returned if the insert fails or does not affect exactly one row.
func InsertUser(id, username, email string) error {
	lowered := strings.ToLower(username)
	res, err := DB.Exec(insertUser, id, username, lowered, email)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 1 {
		return nil
	}
	return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", n)
}
// InsertToken stores token for the user identified by userID. An error is
// returned if the insert fails or does not affect exactly one row.
func InsertToken(userID, token string) error {
	res, err := DB.Exec(insertToken, userID, token)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 1 {
		return nil
	}
	return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", n)
}
// scanner abstracts over *sql.Row and *sql.Rows so the scan helpers work for
// both single-row and multi-row queries: Scan copies the current row's columns
// into the given destinations.
type scanner interface {
	Scan(dst ...interface{}) error
}
// scanObject scans a single object row into an Object and validates it.
//
// The row must supply, in order: bucket, key, directory, type, destination URL,
// content type, content length, creation time, deletion time, deletion reason,
// MD5 hash, SHA-256 hash and associated user.
//
// Validation rules, keyed on object.Type:
//   - the type must be 0, 1 or 2;
//   - for every type except 1, hashes of the expected length (16-byte MD5,
//     32-byte SHA-256) are kept and also exposed hex-encoded;
//   - type 0 requires a content type, a content length and both hashes;
//   - type 1 requires a destination URL;
//   - only type 2 may carry a deletion time, and a deletion reason is only
//     accepted alongside a deletion time.
//
// An error is returned if Scan fails or the row breaks any of these rules.
func scanObject(scanner scanner) (Object, error) {
	object := Object{}
	var (
		destURL        sql.NullString
		contentType    sql.NullString
		contentLength  sql.NullInt64
		deletedAt      pq.NullTime
		deleteReason   sql.NullString
		associatedUser sql.NullString
		md5Hash        []byte
		sha256Hash     []byte
	)
	err := scanner.Scan(&object.Bucket,
		&object.Key,
		&object.Directory,
		&object.Type,
		&destURL,
		&contentType,
		&contentLength,
		&object.CreatedAt,
		&deletedAt,
		&deleteReason,
		&md5Hash,
		&sha256Hash,
		&associatedUser)
	if err != nil {
		return Object{}, errors.Wrap(err, "failed to Scan row from query")
	}
	// From here on err is nil, so validation failures must build a fresh
	// error: errors.Wrap(nil, ...) returns nil and would hide the failure.
	if object.Type < 0 || object.Type > 2 {
		return Object{}, errors.Errorf("invalid Object scanned from query: unknown type %v", object.Type)
	}
	if object.Type != 1 {
		// len of a nil slice is 0, so the length checks also reject NULLs.
		if len(md5Hash) == 16 {
			object.MD5HashBytes = md5Hash
			md5String := hex.EncodeToString(md5Hash)
			object.MD5Hash = &md5String
		}
		if len(sha256Hash) == 32 {
			object.SHA256HashBytes = sha256Hash
			sha256String := hex.EncodeToString(sha256Hash)
			object.SHA256Hash = &sha256String
		}
	}
	if object.Type == 0 {
		if !contentType.Valid || !contentLength.Valid || object.MD5HashBytes == nil || object.SHA256HashBytes == nil {
			return Object{}, errors.New("invalid Object scanned from query: type 0 is missing content type, content length or hashes")
		}
		object.ContentType = &contentType.String
		object.ContentLength = &contentLength.Int64
	}
	if object.Type == 1 {
		if !destURL.Valid {
			return Object{}, errors.New("invalid Object scanned from query: type 1 is missing destination URL")
		}
		object.DestURL = &destURL.String
	}
	if object.Type == 2 {
		if deletedAt.Valid {
			object.DeletedAt = &deletedAt.Time
			if deleteReason.Valid {
				object.DeleteReason = &deleteReason.String
			}
		} else if deleteReason.Valid {
			return Object{}, errors.New("invalid Object scanned from query: delete reason without deletion time")
		}
	} else if deletedAt.Valid || deleteReason.Valid {
		return Object{}, errors.Errorf("invalid Object scanned from query: deletion data on type %v", object.Type)
	}
	if associatedUser.Valid {
		object.AssociatedUser = &associatedUser.String
	}
	return object, nil
}
// ListObjectsByAssociatedUser returns one page of objects associated with a
// user.
//
// typ selects the object type to list. -1 lists every type except 2, which
// scanObject treats as the deleted/tombstone type. asc chooses ascending
// rather than descending order. limit and offset are passed to the query as
// its second and third parameters.
//
// On any error (query, scan/validation, or row iteration) an empty slice is
// returned together with the error.
func ListObjectsByAssociatedUser(userID string, typ int, asc bool, offset, limit int) ([]Object, error) {
	objects := []Object{}
	order := "DESC"
	if asc {
		order = "ASC"
	}
	// Both fragments come from fixed strings and an int, so interpolating them
	// into the query text cannot inject SQL.
	typeFilter := "!= 2"
	if typ != -1 {
		typeFilter = fmt.Sprintf("= %v", typ)
	}
	rows, err := DB.Query(fmt.Sprintf(listObjectsByAssociatedUser, typeFilter, order), userID, limit, offset)
	if err != nil {
		return objects, err
	}
	defer rows.Close()
	for rows.Next() {
		object, err := scanObject(rows)
		if err != nil {
			return []Object{}, err
		}
		objects = append(objects, object)
	}
	// rows.Next returns false both when rows are exhausted and when iteration
	// fails; only rows.Err tells them apart. Without this check a truncated
	// page would look complete.
	if err := rows.Err(); err != nil {
		return []Object{}, err
	}
	return objects, nil
}
// CountObjectsByAssociatedUser returns the count of all objects associated with a user.
func CountObjectsByAssociatedUser(userID string, typ int) (int, error) {
typeFilter := "!= 2"
if typ != -1 {
typeFilter = fmt.Sprintf("= %v", typ)
}
var count int
query := fmt.Sprintf(countObjectsByAssociatedUser, typeFilter)
err := DB.QueryRow(query, userID).Scan(&count)
if err != nil {
return 0, err
}
return count, nil
}
// GetObject looks up the object stored under "bucket/key" and returns it as
// scanned and validated by scanObject. key is joined verbatim, so a leading
// "/" in key produces a double slash.
func GetObject(bucket, key string) (Object, error) {
	bucketKey := bucket + "/" + key
	return scanObject(DB.QueryRow(getObjectByBucketKey, bucketKey))
}
// CheckIfObjectExists reports whether at least one object has the given
// SHA-256 hash. A query error is returned alongside false.
func CheckIfObjectExists(sha256 []byte) (bool, error) {
	var n int
	if err := DB.QueryRow(countOfObjectsBySHA256, sha256).Scan(&n); err != nil {
		return false, err
	}
	return n > 0, nil
}
// UpdateObjectToTombstoneByBucketKey sets an object to be a tombstone by bucket and key.
func UpdateObjectToTombstoneByBucketKey(bucket, key string, reason *string, retainAssociatedUser bool, retainHashes bool) error {
sql := updateObjectToTombstoneByBucketKey
if retainAssociatedUser {
sql = updateObjectToTombstoneKeepAssociatedUserByBucketKey
if retainHashes {
sql = updateObjectToTombstoneKeepHashesAndAssociatedUserByBucketKey
}
}
result, err := DB.Exec(sql, reason, fmt.Sprintf("%s/%s", bucket, key))
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows != 1 {
return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
}
return nil
}
// UpdateObjectToTombstoneBySHA256Hash sets all objects with a matching SHA256 hash to be a tombstone.
func UpdateObjectToTombstoneBySHA256Hash(sha256 []byte, reason *string, retainAssociatedUser bool, retainHashes bool) error {
sql := updateObjectToTombstoneBySHA256Hash
if retainAssociatedUser {
sql = updateObjectToTombstoneKeepAssociatedUserBySHA256Hash
if retainHashes {
sql = updateObjectToTombstoneKeepHashesAndAssociatedUserBySHA256Hash
}
}
_, err := DB.Exec(sql, reason, sha256)
if err != nil {
return err
}
return nil
}
// CheckIfFileBanExists reports whether a file ban exists for the given SHA-256
// hash. It returns true when countOfFileBanBySHA256 yields a positive count;
// a query error is returned alongside false.
func CheckIfFileBanExists(sha256 []byte) (bool, error) {
	var count int
	err := DB.QueryRow(countOfFileBanBySHA256, sha256).Scan(&count)
	if err != nil {
		return false, err
	}
	return count > 0, nil
}
// InsertFileBan records a ban for the file with the given SHA-256 hash.
//
// reason must be in the range 0..3. An empty malwareName is stored as NULL,
// just like a nil one. An error is returned if reason is out of range, if the
// insert fails, or if it does not affect exactly one row.
func InsertFileBan(sha256 []byte, didQuarantine bool, reason int, description, malwareName *string) error {
	if !(reason >= 0 && reason <= 3) {
		return errors.New("invalid file ban reason")
	}
	// Collapse "" to nil so no empty malware names end up in the table.
	if malwareName != nil && len(*malwareName) == 0 {
		malwareName = nil
	}
	res, err := DB.Exec(insertFileBan, sha256, didQuarantine, reason, description, malwareName)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 1 {
		return nil
	}
	return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", n)
}
// scanFileBan scans a single file_bans row into a FileBan and validates it.
//
// The row must supply, in order: SHA-256 hash, quarantine flag, reason,
// description and malware name. The reason must be in the range 0..3, and a
// malware name is only accepted when the reason is 1. A hash of the expected
// 32-byte length is also exposed hex-encoded. An error is returned if Scan
// fails or the row breaks these rules.
func scanFileBan(scanner scanner) (FileBan, error) {
	fileBan := FileBan{}
	err := scanner.Scan(&fileBan.SHA256HashBytes,
		&fileBan.DidQuarantine,
		&fileBan.Reason,
		&fileBan.Description,
		&fileBan.MalwareName)
	if err != nil {
		return FileBan{}, errors.Wrap(err, "failed to Scan row from query")
	}
	// err is nil here, so validation failures must build a fresh error:
	// errors.Wrap(nil, ...) returns nil and would hide the failure.
	if fileBan.Reason < 0 || fileBan.Reason > 3 {
		return FileBan{}, errors.Errorf("invalid FileBan scanned from query: unknown reason %v", fileBan.Reason)
	}
	if fileBan.Reason != 1 && fileBan.MalwareName != nil {
		return FileBan{}, errors.Errorf("invalid FileBan scanned from query: malware name set for reason %v", fileBan.Reason)
	}
	// len of a nil slice is 0, so this also skips NULL hashes.
	if len(fileBan.SHA256HashBytes) == 32 {
		sha256String := hex.EncodeToString(fileBan.SHA256HashBytes)
		fileBan.SHA256Hash = &sha256String
	}
	return fileBan, nil
}
// ListBannedFiles returns every file ban, each validated by scanFileBan.
//
// On any error (query, scan/validation, or row iteration) an empty slice is
// returned together with the error. When there are no bans the result is an
// empty, non-nil slice.
func ListBannedFiles() ([]FileBan, error) {
	bannedFiles := []FileBan{}
	rows, err := DB.Query(listBannedFiles)
	if err != nil {
		return bannedFiles, err
	}
	defer rows.Close()
	for rows.Next() {
		bannedFile, err := scanFileBan(rows)
		if err != nil {
			return []FileBan{}, err
		}
		bannedFiles = append(bannedFiles, bannedFile)
	}
	// rows.Next also returns false when iteration fails; without this check a
	// truncated list would be reported as complete.
	if err := rows.Err(); err != nil {
		return []FileBan{}, err
	}
	return bannedFiles, nil
}
// GetBannedFile returns the FileBan recorded for the given SHA-256 hash, as
// scanned and validated by scanFileBan. When no ban matches, the returned
// error wraps the sql.ErrNoRows reported by Row.Scan.
func GetBannedFile(sha256 []byte) (FileBan, error) {
	return scanFileBan(DB.QueryRow(getBannedFileBySHA256Hash, sha256))
}
// DeleteBannedFile removes the file ban for the given SHA-256 hash, which
// unbans the file. An error is returned if the delete fails or does not affect
// exactly one row; in particular, affecting no rows is an error.
func DeleteBannedFile(sha256 []byte) error {
	res, err := DB.Exec(deleteFileBan, sha256)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 1 {
		return nil
	}
	return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", n)
}