Skip to content
Snippets Groups Projects
queries.go 5.87 KiB
Newer Older
package db

import (
	"database/sql"
	"fmt"
	"strings"
Dean's avatar
Dean committed

	"github.com/lib/pq"
Dean's avatar
Dean committed
	"github.com/pkg/errors"
)

// SelectUserByToken returns a user object from a token.
func SelectUserByToken(token string) (User, error) {
	var user User
Dean's avatar
Dean committed
	var bucketCapacity sql.NullInt64
	err := DB.QueryRow(selectUserByToken, token).
Dean's avatar
Dean committed
		Scan(&user.ID, &user.Username, &user.Email, &user.IsAdmin, &user.IsBlocked, &bucketCapacity)
	if err != nil {
		return user, err
	}
	if bucketCapacity.Valid {
		user.BucketCapacity = bucketCapacity.Int64
	}
	return user, err
}

// ObjectKeyExists reports whether the given key is already in use by another
// object in bucket.
//
// It runs the countOfObjectByBucketAndRandomKey query with (bucket, key) and
// returns true when the resulting count is greater than zero.
func ObjectKeyExists(bucket, key string) (bool, error) {
	// Declared locally: the original scanned into an undeclared (or shared
	// package-level) variable, which either fails to compile or races
	// between concurrent callers.
	var count int
	err := DB.QueryRow(countOfObjectByBucketAndRandomKey, bucket, key).Scan(&count)
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

// InsertShortURL inserts a short URL object into the database.
func InsertShortURL(bucket, key, destURL string, associatedUser *string) error {
	if !strings.HasPrefix(key, "/") {
		key = "/" + key
	}
	result, err := DB.Exec(insertShortURL,
		bucket+key,
		bucket,
		key,
		key[1:],
		destURL,
		associatedUser)
Dean's avatar
Dean committed
	if err != nil {
		return err
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rows != 1 {
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
	}
	return nil
}

// InsertFile inserts a file object into the database.
func InsertFile(bucket, key, ext, contentType string, contentLength int64, md5Hash string, associatedUser *string) error {
	if !strings.HasPrefix(key, "/") {
		key = "/" + key
	}
Dean's avatar
Dean committed
	result, err := DB.Exec(insertFile,
		bucket+key,
		bucket,
Dean's avatar
Dean committed
		key+ext,
		contentType,
		contentLength,
		md5Hash,
		associatedUser)
Dean's avatar
Dean committed
	if err != nil {
		return err
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rows != 1 {
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
	}
	return nil
}

// CheckUserExistsByUsernameOrEmail reports whether some user matches the given
// username or email address.
//
// The username is lower-cased before the lookup; the email is passed as-is.
// A query that yields no row (sql.ErrNoRows) means "no such user" and is not
// treated as an error. Any other failure returns false with the error.
func CheckUserExistsByUsernameOrEmail(username, email string) (bool, error) {
	var userID string
	lowered := strings.ToLower(username)
	switch err := DB.QueryRow(selectUserByUsernameOrEmail, lowered, email).Scan(&userID); {
	case err == sql.ErrNoRows:
		return false, nil
	case err != nil:
		return false, err
	}
	return true, nil
}

// InsertUser inserts a user into the database.
func InsertUser(id, username, email string) error {
Dean's avatar
Dean committed
	result, err := DB.Exec(insertUser, id, username, strings.ToLower(username), email)
	if err != nil {
		return err
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rows != 1 {
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
	}
	return nil
}

// InsertToken inserts a token into the database.
func InsertToken(userID, token string) error {
Dean's avatar
Dean committed
	result, err := DB.Exec(insertToken, userID, token)
	if err != nil {
		return err
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rows != 1 {
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
	}
	return nil

// scanner is the common row-reading capability shared by single-row and
// multi-row query results (e.g. what DB.QueryRow and DB.Query hand back):
// copying the current row's columns into caller-supplied destinations.
type scanner interface {
	// Scan copies the current row's columns into dst, in column order.
	Scan(dst ...interface{}) error
}

// scanObject scans one object row into an Object.
//
// Columns are read in this order: bucket, key, directory, type, destination
// URL, content type, content length, created-at, deleted-at, delete reason,
// MD5 hash, associated user. Nullable columns that are NULL leave the
// corresponding Object pointer field nil.
//
// Consistency rules (any violation returns an error):
//   - type 0 (presumably file objects) requires content type, content length
//     and MD5 hash;
//   - type 1 (presumably short URLs) requires a destination URL;
//   - a delete reason is only allowed when a deletion time is set.
//
// Scan failures are wrapped with github.com/pkg/errors.
func scanObject(row scanner) (Object, error) {
	var (
		object         Object
		destURL        sql.NullString
		contentType    sql.NullString
		contentLength  sql.NullInt64
		deletedAt      pq.NullTime
		deleteReason   sql.NullString
		associatedUser sql.NullString
		md5Hash        sql.NullString
	)
	err := row.Scan(&object.Bucket,
		&object.Key,
		&object.Directory,
		&object.Type,
		&destURL,
		&contentType,
		&contentLength,
		&object.CreatedAt,
		&deletedAt,
		&deleteReason,
		&md5Hash,
		&associatedUser)
	if err != nil {
		return Object{}, errors.Wrap(err, "failed to Scan row from query")
	}

	// err is nil from here on. errors.Wrap(nil, ...) returns nil, so the
	// validation failures below must create a fresh error. Wrapping err would
	// report an invalid row as a successful, empty Object.
	switch object.Type {
	case 0:
		if !contentType.Valid || !contentLength.Valid || !md5Hash.Valid {
			return Object{}, errors.New("invalid Object scanned from query")
		}
		object.ContentType = &contentType.String
		object.ContentLength = &contentLength.Int64
		object.MD5Hash = &md5Hash.String
	case 1:
		if !destURL.Valid {
			return Object{}, errors.New("invalid Object scanned from query")
		}
		object.DestURL = &destURL.String
	}

	if !deletedAt.Valid && deleteReason.Valid {
		return Object{}, errors.New("invalid Object scanned from query")
	}
	if deletedAt.Valid {
		object.DeletedAt = &deletedAt.Time
		if deleteReason.Valid {
			object.DeleteReason = &deleteReason.String
		}
	}
	if associatedUser.Valid {
		object.AssociatedUser = &associatedUser.String
	}
	return object, nil
}

// ListObjectsByAssociatedUser returns one page of the objects associated with
// userID.
//
// asc selects "ASC" instead of the default "DESC" ordering. The keyword is
// spliced into listObjectsByAssociatedUser with fmt.Sprintf, which is safe
// only because it is always one of those two fixed literals. limit and offset
// are bound as query parameters. On any error an empty, non-nil slice is
// returned with the error.
func ListObjectsByAssociatedUser(userID string, asc bool, offset, limit int) ([]Object, error) {
	order := "DESC"
	if asc {
		order = "ASC"
	}
	rows, err := DB.Query(fmt.Sprintf(listObjectsByAssociatedUser, order), userID, limit, offset)
	if err != nil {
		return []Object{}, err
	}
	defer rows.Close()

	objects := []Object{}
	for rows.Next() {
		object, err := scanObject(rows)
		if err != nil {
			return []Object{}, err
		}
		objects = append(objects, object)
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails. Without this check, a failed iteration would be
	// returned as a silently truncated page.
	if err := rows.Err(); err != nil {
		return []Object{}, err
	}
	return objects, nil
}

// GetObject returns the object stored under bucket and key.
//
// The lookup ID is derived exactly as InsertFile and InsertShortURL derive
// theirs: key gets a leading "/" if it lacks one and is appended to bucket.
// So "x" and "/x" address the same object; a plain "%s/%s" join would turn
// "/x" into a doubled-slash ID that does not match the inserted row.
// Errors come back as produced by scanObject.
func GetObject(bucket, key string) (Object, error) {
	if !strings.HasPrefix(key, "/") {
		key = "/" + key
	}
	return scanObject(DB.QueryRow(getObjectByBucketKey, bucket+key))
}

// UpdateObjectToTombstone marks the object stored under bucket and key as a
// tombstone, recording reason (which may be nil).
//
// The target ID is derived exactly as InsertFile and InsertShortURL derive
// theirs: key gets a leading "/" if it lacks one and is appended to bucket,
// so "x" and "/x" address the same object. Exactly one row must be affected,
// otherwise an error is returned.
func UpdateObjectToTombstone(bucket, key string, reason *string) error {
	if !strings.HasPrefix(key, "/") {
		key = "/" + key
	}
	result, err := DB.Exec(updateObjectToTombstone, reason, bucket+key)
	if err != nil {
		return err
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rows != 1 {
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
	}
	return nil
}