// queries.go — database query helpers (recovered from a GitLab file view;
// page navigation text removed).

package db
    
    import (
    	"database/sql"
    	"encoding/hex"
    	"fmt"
    	"strings"

    	"github.com/lib/pq"
    	"github.com/pkg/errors"
    )
    
    // SelectUserByToken returns a user object from a token.
    func SelectUserByToken(token string) (User, error) {
    	var user User
    
    Dean's avatar
    Dean committed
    	var bucketCapacity sql.NullInt64
    
    	err := DB.QueryRow(selectUserByToken, token).
    
    Dean's avatar
    Dean committed
    		Scan(&user.ID, &user.Username, &user.Email, &user.IsAdmin, &user.IsBlocked, &bucketCapacity)
    	if err != nil {
    		return user, err
    	}
    	if bucketCapacity.Valid {
    		user.BucketCapacity = bucketCapacity.Int64
    	}
    
    	return user, err
    }
    
    // ObjectKeyExists reports whether the given bucket/key pair is already in
    // use by another object.
    //
    // It runs countOfObjectByBucketAndRandomKey with (bucket, key) and returns
    // true when the resulting count is positive.
    func ObjectKeyExists(bucket, key string) (bool, error) {
    	var count int
    	err := DB.QueryRow(countOfObjectByBucketAndRandomKey, bucket, key).Scan(&count)
    	if err != nil {
    		return false, err
    	}
    	return count > 0, nil
    }
    
    // InsertShortURL inserts a short URL object into the database.
    func InsertShortURL(bucket, key, destURL string, associatedUser *string) error {
    
    	if !strings.HasPrefix(key, "/") {
    		key = "/" + key
    	}
    	result, err := DB.Exec(insertShortURL,
    		bucket+key,
    		bucket,
    		key,
    		key[1:],
    		destURL,
    		associatedUser)
    
    Dean's avatar
    Dean committed
    	if err != nil {
    		return err
    	}
    	rows, err := result.RowsAffected()
    	if err != nil {
    		return err
    	}
    	if rows != 1 {
    		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
    	}
    	return nil
    
    }
    
    // InsertFile inserts a file object into the database.
    //
    // key is normalised to start with "/" before use. The insert must affect
    // exactly one row; otherwise an error is returned.
    //
    // NOTE(review): the DB.Exec argument list below appears to have lost lines
    // when this file was extracted: sha256Hash and associatedUser are never
    // passed, the call has no closing parenthesis, and the surrounding
    // "Dean's avatar"/"Dean committed" lines are not Go. Restore the call from
    // version control (compare InsertShortURL's argument order) before building.
    func InsertFile(bucket, key, ext, contentType string, contentLength int64, md5Hash, sha256Hash []byte, associatedUser *string) error {
    
    	if !strings.HasPrefix(key, "/") {
    		key = "/" + key
    	}
    
    Dean's avatar
    Dean committed
    	result, err := DB.Exec(insertFile,
    
    Dean's avatar
    Dean committed
    		bucket+key+ext,
    
    Dean's avatar
    Dean committed
    		key+ext,
    
    		contentType,
    		contentLength,
    		md5Hash,
    
    Dean's avatar
    Dean committed
    	if err != nil {
    		return err
    	}
    	rows, err := result.RowsAffected()
    	if err != nil {
    		return err
    	}
    	// Anything other than a single inserted row is treated as a failure.
    	if rows != 1 {
    		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
    	}
    	return nil
    
    }
    
    // CheckUserExistsByUsernameOrEmail reports whether the
    // selectUserByUsernameOrEmail query finds a user for the given username
    // (lower-cased before querying) or email address. "No rows" is reported as
    // (false, nil); any other error is returned as-is.
    func CheckUserExistsByUsernameOrEmail(username, email string) (bool, error) {
    	var matchedID string
    	lowered := strings.ToLower(username)
    	switch err := DB.QueryRow(selectUserByUsernameOrEmail, lowered, email).Scan(&matchedID); {
    	case err == sql.ErrNoRows:
    		return false, nil
    	case err != nil:
    		return false, err
    	default:
    		return true, nil
    	}
    }
    
    // InsertUser inserts a user into the database.
    func InsertUser(id, username, email string) error {
    
    Dean's avatar
    Dean committed
    	result, err := DB.Exec(insertUser, id, username, strings.ToLower(username), email)
    	if err != nil {
    		return err
    	}
    	rows, err := result.RowsAffected()
    	if err != nil {
    		return err
    	}
    	if rows != 1 {
    		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
    	}
    	return nil
    
    }
    
    // InsertToken inserts a token into the database.
    func InsertToken(userID, token string) error {
    
    Dean's avatar
    Dean committed
    	result, err := DB.Exec(insertToken, userID, token)
    	if err != nil {
    		return err
    	}
    	rows, err := result.RowsAffected()
    	if err != nil {
    		return err
    	}
    	if rows != 1 {
    		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
    	}
    	return nil
    
    
    // scanner is anything that can copy the current row's columns into a set of
    // destinations. Both *sql.Row (from QueryRow) and *sql.Rows (from Query)
    // satisfy it, which lets the scan helpers serve single and multi-row queries.
    type scanner interface {
    	Scan(dst ...interface{}) error
    }
    
    // scanObject scans an object row into an Object.
    func scanObject(scanner scanner) (Object, error) {
    	object := Object{}
    	var destURL sql.NullString
    	var contentType sql.NullString
    	var contentLength sql.NullInt64
    	var deletedAt pq.NullTime
    	var deleteReason sql.NullString
    	var associatedUser sql.NullString
    
    Dean's avatar
    Dean committed
    	var md5Hash []byte
    	var sha256Hash []byte
    
    	err := scanner.Scan(&object.Bucket,
    		&object.Key,
    		&object.Directory,
    		&object.Type,
    		&destURL,
    		&contentType,
    		&contentLength,
    		&object.CreatedAt,
    		&deletedAt,
    		&deleteReason,
    		&md5Hash,
    
    Dean's avatar
    Dean committed
    		&sha256Hash,
    
    		&associatedUser)
    	if err != nil {
    		return Object{}, errors.Wrap(err, "failed to Scan row from query")
    	}
    
    Dean's avatar
    Dean committed
    	if object.Type < 0 || object.Type > 2 {
    		return Object{}, errors.Wrap(err, "invalid Object scanned from query")
    	}
    	if object.Type != 1 {
    		if md5Hash != nil && len(md5Hash) == 16 {
    			object.MD5HashBytes = md5Hash
    			md5String := hex.EncodeToString(md5Hash)
    			object.MD5Hash = &md5String
    		}
    		if sha256Hash != nil && len(sha256Hash) == 32 {
    			object.SHA256HashBytes = sha256Hash
    			sha256String := hex.EncodeToString(sha256Hash)
    			object.SHA256Hash = &sha256String
    		}
    	}
    
    	if object.Type == 0 {
    
    Dean's avatar
    Dean committed
    		if !contentType.Valid || !contentLength.Valid || object.MD5HashBytes == nil || object.SHA256HashBytes == nil {
    
    			return Object{}, errors.Wrap(err, "invalid Object scanned from query")
    		}
    		object.ContentType = &contentType.String
    		object.ContentLength = &contentLength.Int64
    
    Dean's avatar
    Dean committed
    	}
    	if object.Type == 1 {
    
    		if !destURL.Valid {
    			return Object{}, errors.Wrap(err, "invalid Object scanned from query")
    		}
    		object.DestURL = &destURL.String
    	}
    
    Dean's avatar
    Dean committed
    	if object.Type == 2 {
    		if deletedAt.Valid {
    			object.DeletedAt = &deletedAt.Time
    			if deleteReason.Valid {
    				object.DeleteReason = &deleteReason.String
    			}
    		} else {
    			if deleteReason.Valid {
    				return Object{}, errors.Wrap(err, "invalid Object scanned from query")
    			}
    
    Dean's avatar
    Dean committed
    	} else if deletedAt.Valid || deleteReason.Valid {
    		return Object{}, errors.Wrap(err, "invalid Object scanned from query")
    
    	}
    	if associatedUser.Valid {
    		object.AssociatedUser = &associatedUser.String
    	}
    	return object, nil
    }
    
    // ListObjectsByAssociatedUser returns all objects (paginated) associated with a
    // user.
    
    func ListObjectsByAssociatedUser(userID string, typ int, asc bool, offset, limit int) ([]Object, error) {
    
    	objects := []Object{}
    	order := "DESC"
    	if asc {
    		order = "ASC"
    	}
    
    Spotlight Deveaux's avatar
    Spotlight Deveaux committed
    	typeFilter := "!= 2"
    
    	if typ != -1 {
    		typeFilter = fmt.Sprintf("= %v", typ)
    	}
    
    	rows, err := DB.Query(fmt.Sprintf(listObjectsByAssociatedUser, typeFilter, order), userID, limit, offset)
    
    	if err != nil {
    		return objects, err
    	}
    	defer rows.Close()
    
    	// Scan into Objects
    	for rows.Next() {
    		object, err := scanObject(rows)
    		if err != nil {
    			return []Object{}, err
    		}
    		objects = append(objects, object)
    	}
    	return objects, nil
    }
    
    
    Spotlight Deveaux's avatar
    Spotlight Deveaux committed
    // CountObjectsByAssociatedUser returns the count of all objects associated with a user.
    func CountObjectsByAssociatedUser(userID string, typ int) (int, error) {
    	typeFilter := "!= 2"
    	if typ != -1 {
    		typeFilter = fmt.Sprintf("= %v", typ)
    	}
    
    	var count int
    	query := fmt.Sprintf(countObjectsByAssociatedUser, typeFilter)
    	fmt.Printf(query)
    	err := DB.QueryRow(query, userID).Scan(&count)
    	if err != nil {
    		return 0, err
    	}
    	return count, nil
    }
    
    
    // GetObject looks up a single object by bucket and key and scans it into an
    // Object. The lookup value passed to the query is bucket + "/" + key.
    func GetObject(bucket, key string) (Object, error) {
    	bucketKey := bucket + "/" + key
    	return scanObject(DB.QueryRow(getObjectByBucketKey, bucketKey))
    }
    
    
    Dean's avatar
    Dean committed
    // CheckIfObjectExists returns true if an object exists with the specified hash.
    func CheckIfObjectExists(sha256 []byte) (bool, error) {
    	var count int
    	err := DB.QueryRow(countOfObjectsBySHA256, sha256).Scan(&count)
    	if err != nil {
    		return false, err
    	}
    	return count > 0, nil
    }
    
    
    Dean's avatar
    Dean committed
    // UpdateObjectToTombstoneByBucketKey sets an object to be a tombstone by bucket and key.
    func UpdateObjectToTombstoneByBucketKey(bucket, key string, reason *string, retainAssociatedUser bool, retainHashes bool) error {
    	sql := updateObjectToTombstoneByBucketKey
    	if retainAssociatedUser {
    		sql = updateObjectToTombstoneKeepAssociatedUserByBucketKey
    		if retainHashes {
    			sql = updateObjectToTombstoneKeepHashesAndAssociatedUserByBucketKey
    		}
    	}
    	result, err := DB.Exec(sql, reason, fmt.Sprintf("%s/%s", bucket, key))
    	if err != nil {
    		return err
    	}
    	rows, err := result.RowsAffected()
    	if err != nil {
    		return err
    	}
    	if rows != 1 {
    		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
    	}
    	return nil
    }
    
    // UpdateObjectToTombstoneBySHA256Hash turns every object whose SHA-256 hash
    // matches into a tombstone, recording reason (which may be nil).
    //
    // The query variant is chosen by the flags: retainAssociatedUser selects the
    // "keep associated user" variant, and retainHashes additionally selects the
    // "keep hashes and associated user" variant, but only when
    // retainAssociatedUser is also true; retainHashes on its own has no effect.
    // The number of affected rows is not checked, so matching no objects is not
    // an error.
    func UpdateObjectToTombstoneBySHA256Hash(sha256 []byte, reason *string, retainAssociatedUser bool, retainHashes bool) error {
    	// Named query (not sql) so the database/sql package isn't shadowed.
    	query := updateObjectToTombstoneBySHA256Hash
    	switch {
    	case retainAssociatedUser && retainHashes:
    		query = updateObjectToTombstoneKeepHashesAndAssociatedUserBySHA256Hash
    	case retainAssociatedUser:
    		query = updateObjectToTombstoneKeepAssociatedUserBySHA256Hash
    	}
    	_, err := DB.Exec(query, reason, sha256)
    	return err
    }
    
    // CheckIfFileBanExists reports whether a file ban is recorded for the given
    // SHA-256 hash.
    func CheckIfFileBanExists(sha256 []byte) (bool, error) {
    	var count int
    	err := DB.QueryRow(countOfFileBanBySHA256, sha256).Scan(&count)
    	if err != nil {
    		return false, err
    	}
    	return count > 0, nil
    }
    
    // InsertFileBan records a ban on the file with the given SHA-256 hash.
    //
    // reason must lie in 0..3. A non-nil but empty malwareName is passed to the
    // query as nil. An error is returned for an out-of-range reason, a failed
    // insert, or an insert that does not affect exactly one row.
    func InsertFileBan(sha256 []byte, didQuarantine bool, reason int, description, malwareName *string) error {
    	if reason > 3 || reason < 0 {
    		return errors.New("invalid file ban reason")
    	}
    	if malwareName != nil && len(*malwareName) == 0 {
    		malwareName = nil
    	}
    	res, err := DB.Exec(insertFileBan, sha256, didQuarantine, reason, description, malwareName)
    	if err != nil {
    		return err
    	}
    	affected, err := res.RowsAffected()
    	switch {
    	case err != nil:
    		return err
    	case affected != 1:
    		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", affected)
    	}
    	return nil
    }
    
    // scanFileBan scans a file_bans row into a FileBan and validates it.
    //
    // Columns are read, in order, into SHA256HashBytes, DidQuarantine, Reason,
    // Description and MalwareName. A row is rejected when Reason is outside
    // 0..3, or when MalwareName is set for any Reason other than 1. When the
    // stored hash is exactly 32 bytes, SHA256Hash is set to its hex encoding.
    func scanFileBan(scanner scanner) (FileBan, error) {
    	fileBan := FileBan{}
    	err := scanner.Scan(&fileBan.SHA256HashBytes,
    		&fileBan.DidQuarantine,
    		&fileBan.Reason,
    		&fileBan.Description,
    		&fileBan.MalwareName)
    	if err != nil {
    		return FileBan{}, errors.Wrap(err, "failed to Scan row from query")
    	}
    	// err is nil here and errors.Wrap(nil, ...) returns nil, so validation
    	// failures need a newly constructed error or they'd be reported as
    	// success.
    	if fileBan.Reason < 0 || fileBan.Reason > 3 {
    		return FileBan{}, errors.New("invalid FileBan scanned from query")
    	}
    	if fileBan.Reason != 1 && fileBan.MalwareName != nil {
    		return FileBan{}, errors.New("invalid FileBan scanned from query")
    	}
    	// len() of a nil slice is 0, so no separate nil check is needed.
    	if len(fileBan.SHA256HashBytes) == 32 {
    		sha256String := hex.EncodeToString(fileBan.SHA256HashBytes)
    		fileBan.SHA256Hash = &sha256String
    	}
    	return fileBan, nil
    }
    
    // ListBannedFiles returns all banned files. On any error an empty slice is
    // returned.
    func ListBannedFiles() ([]FileBan, error) {
    	bannedFiles := []FileBan{}
    	rows, err := DB.Query(listBannedFiles)
    	if err != nil {
    		return bannedFiles, err
    	}
    	defer rows.Close()

    	for rows.Next() {
    		bannedFile, err := scanFileBan(rows)
    		if err != nil {
    			return []FileBan{}, err
    		}
    		bannedFiles = append(bannedFiles, bannedFile)
    	}
    	// rows.Next also returns false when iteration fails; report that rather
    	// than returning a silently truncated list.
    	if err := rows.Err(); err != nil {
    		return []FileBan{}, err
    	}
    	return bannedFiles, nil
    }
    
    // GetBannedFile fetches the ban recorded for the given SHA-256 hash and
    // scans it into a FileBan.
    func GetBannedFile(sha256 []byte) (FileBan, error) {
    	return scanFileBan(DB.QueryRow(getBannedFileBySHA256Hash, sha256))
    }
    
    // DeleteBannedFile removes the ban on the file with the given SHA-256 hash,
    // effectively unbanning it. An error is returned if the delete fails or
    // does not affect exactly one row.
    func DeleteBannedFile(sha256 []byte) error {
    	res, err := DB.Exec(deleteFileBan, sha256)
    	if err != nil {
    		return err
    	}
    	n, err := res.RowsAffected()
    	if err != nil {
    		return err
    	}
    	if n == 1 {
    		return nil
    	}
    	return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", n)
    }