package db

import (
	"database/sql"
	"encoding/hex"
	"fmt"
	"strings"

	"github.com/lib/pq"
	"github.com/pkg/errors"
)

// SelectUserByToken returns the user that owns the given token.
//
// A NULL bucket capacity leaves User.BucketCapacity at its zero value. On any
// error (including sql.ErrNoRows when no row matches the token) a zero User is
// returned together with the error.
func SelectUserByToken(token string) (User, error) {
	var user User
	var bucketCapacity sql.NullInt64
	err := DB.QueryRow(selectUserByToken, token).
		Scan(&user.ID, &user.Username, &user.Email, &user.IsAdmin, &user.IsBlocked, &bucketCapacity)
	if err != nil {
		// Scan may have filled some destinations before failing; do not hand a
		// partially populated User back to the caller.
		return User{}, err
	}
	if bucketCapacity.Valid {
		user.BucketCapacity = bucketCapacity.Int64
	}
	return user, nil
}

// ObjectKeyExists reports whether the given key is already in use by another
// object in bucket, i.e. whether the countOfObjectByBucketAndRandomKey query
// returns a count greater than zero.
func ObjectKeyExists(bucket, key string) (bool, error) {
	var count int
	err := DB.QueryRow(countOfObjectByBucketAndRandomKey, bucket, key).Scan(&count)
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

// InsertShortURL inserts a short URL object into the database.
//
// key is normalized to start with "/". The values passed to insertShortURL are
// bucket+key, bucket, the normalized key, the key without its leading slash,
// destURL and associatedUser (which may be nil). An error is returned if the
// insert does not affect exactly one row.
func InsertShortURL(bucket, key, destURL string, associatedUser *string) error {
	if !strings.HasPrefix(key, "/") {
		key = "/" + key
	}
	result, err := DB.Exec(insertShortURL, bucket+key, bucket, key, key[1:], destURL, associatedUser)
	if err != nil {
		return err
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rows != 1 {
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
	}
	return nil
}

// InsertFile inserts a file object into the database.
//
// key is normalized to start with "/" and ext is appended to form the stored
// key. An error is returned if the insert does not affect exactly one row.
func InsertFile(bucket, key, ext, contentType string, contentLength int64, md5Hash, sha256Hash []byte, associatedUser *string) error { if !strings.HasPrefix(key, "/") { key = "/" + key } result, err := DB.Exec(insertFile, bucket+key+ext, bucket, key+ext, key[1:], contentType, contentLength, md5Hash, sha256Hash, associatedUser) if err != nil { return err } rows, err := result.RowsAffected() if err != nil { return err } if rows != 1 { return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows) } return nil } // CheckUserExistsByUsernameOrEmail returns a boolean specifying if a user // exists with the given username or email address. func CheckUserExistsByUsernameOrEmail(username, email string) (bool, error) { var id string err := DB.QueryRow(selectUserByUsernameOrEmail, strings.ToLower(username), email).Scan(&id) if err == sql.ErrNoRows { return false, nil } if err != nil { return false, err } return true, nil } // InsertUser inserts a user into the database. func InsertUser(id, username, email string) error { result, err := DB.Exec(insertUser, id, username, strings.ToLower(username), email) if err != nil { return err } rows, err := result.RowsAffected() if err != nil { return err } if rows != 1 { return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows) } return nil } // InsertToken inserts a token into the database. func InsertToken(userID, token string) error { result, err := DB.Exec(insertToken, userID, token) if err != nil { return err } rows, err := result.RowsAffected() if err != nil { return err } if rows != 1 { return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows) } return nil } // scanner is a database row/rows that can be scanned. type scanner interface { Scan(dest ...interface{}) error } // scanObject scans an object row into an Object. 
func scanObject(row scanner) (Object, error) {
	var (
		object         Object
		destURL        sql.NullString
		contentType    sql.NullString
		contentLength  sql.NullInt64
		deletedAt      pq.NullTime
		deleteReason   sql.NullString
		associatedUser sql.NullString
		md5Hash        []byte
		sha256Hash     []byte
	)
	err := row.Scan(&object.Bucket, &object.Key, &object.Directory, &object.Type, &destURL,
		&contentType, &contentLength, &object.CreatedAt, &deletedAt, &deleteReason,
		&md5Hash, &sha256Hash, &associatedUser)
	if err != nil {
		return Object{}, errors.Wrap(err, "failed to Scan row from query")
	}

	// err is nil from here on. errors.Wrap(nil, ...) returns nil, so every
	// validation failure below is built with errors.New. Otherwise an invalid
	// row would be reported as success with a zero Object.
	if object.Type < 0 || object.Type > 2 {
		return Object{}, errors.New("invalid Object scanned from query")
	}

	// Hashes are only read for non-type-1 objects. Values of the wrong length
	// are ignored here; for type 0 the check below then rejects the row.
	if object.Type != 1 {
		if len(md5Hash) == 16 {
			object.MD5HashBytes = md5Hash
			md5String := hex.EncodeToString(md5Hash)
			object.MD5Hash = &md5String
		}
		if len(sha256Hash) == 32 {
			object.SHA256HashBytes = sha256Hash
			sha256String := hex.EncodeToString(sha256Hash)
			object.SHA256Hash = &sha256String
		}
	}

	switch object.Type {
	case 0:
		// Type 0 (presumably files) needs content metadata and both hashes.
		if !contentType.Valid || !contentLength.Valid || object.MD5HashBytes == nil || object.SHA256HashBytes == nil {
			return Object{}, errors.New("invalid Object scanned from query")
		}
		object.ContentType = &contentType.String
		object.ContentLength = &contentLength.Int64
	case 1:
		// Type 1 (presumably short URLs) needs a destination URL.
		if !destURL.Valid {
			return Object{}, errors.New("invalid Object scanned from query")
		}
		object.DestURL = &destURL.String
	case 2:
		// Type 2 (presumably tombstones) may carry a deletion time, and a
		// deletion reason only alongside that time.
		if !deletedAt.Valid && deleteReason.Valid {
			return Object{}, errors.New("invalid Object scanned from query")
		}
		if deletedAt.Valid {
			object.DeletedAt = &deletedAt.Time
			if deleteReason.Valid {
				object.DeleteReason = &deleteReason.String
			}
		}
	}

	// Only type 2 rows may carry deletion metadata.
	if object.Type != 2 && (deletedAt.Valid || deleteReason.Valid) {
		return Object{}, errors.New("invalid Object scanned from query")
	}

	if associatedUser.Valid {
		object.AssociatedUser = &associatedUser.String
	}
	return object, nil
}

// objectTypeFilter returns the SQL comparison fragment used to filter objects
// by type. It yields "= <typ>" for a specific type, or "!= 2" (every type but 2)
// when typ is -1. typ is an int, so the fragment cannot carry injected SQL.
func objectTypeFilter(typ int) string {
	if typ == -1 {
		return "!= 2"
	}
	return fmt.Sprintf("= %v", typ)
}

// ListObjectsByAssociatedUser returns all objects (paginated) associated with a
// user.
//
// typ selects one object type, or -1 for every type except 2. Results are
// sorted ascending when asc is true and descending otherwise; the sort column
// is defined by the listObjectsByAssociatedUser query. offset and limit page
// the results. The returned slice is never nil, so it encodes as an empty list
// rather than null.
func ListObjectsByAssociatedUser(userID string, typ int, asc bool, offset, limit int) ([]Object, error) {
	objects := []Object{}
	order := "DESC"
	if asc {
		order = "ASC"
	}
	query := fmt.Sprintf(listObjectsByAssociatedUser, objectTypeFilter(typ), order)
	rows, err := DB.Query(query, userID, limit, offset)
	if err != nil {
		return objects, err
	}
	defer rows.Close()

	for rows.Next() {
		object, err := scanObject(rows)
		if err != nil {
			return []Object{}, err
		}
		objects = append(objects, object)
	}
	// Next returns false both when rows are exhausted and when iteration
	// failed; Err distinguishes the two so a failure is not mistaken for a
	// short page.
	if err := rows.Err(); err != nil {
		return []Object{}, err
	}
	return objects, nil
}

// CountObjectsByAssociatedUser returns the count of all objects associated with
// a user. typ follows the same convention as ListObjectsByAssociatedUser: one
// object type, or -1 for every type except 2.
func CountObjectsByAssociatedUser(userID string, typ int) (int, error) {
	var count int
	query := fmt.Sprintf(countObjectsByAssociatedUser, objectTypeFilter(typ))
	if err := DB.QueryRow(query, userID).Scan(&count); err != nil {
		return 0, err
	}
	return count, nil
}

// GetObject returns the object stored under bucket + "/" + key.
//
// NOTE(review): unlike InsertShortURL/InsertFile, key is not normalized here,
// so a key that already starts with "/" looks up "bucket//key". Confirm that
// callers pass the key without its leading slash.
func GetObject(bucket, key string) (Object, error) {
	row := DB.QueryRow(getObjectByBucketKey, fmt.Sprintf("%s/%s", bucket, key))
	return scanObject(row)
}

// CheckIfObjectExists returns true if an object exists with the specified hash.
func CheckIfObjectExists(sha256 []byte) (bool, error) {
	var count int
	if err := DB.QueryRow(countOfObjectsBySHA256, sha256).Scan(&count); err != nil {
		return false, err
	}
	return count > 0, nil
}

// UpdateObjectToTombstoneByBucketKey sets an object to be a tombstone by bucket and key.
func UpdateObjectToTombstoneByBucketKey(bucket, key string, reason *string, retainAssociatedUser bool, retainHashes bool) error { sql := updateObjectToTombstoneByBucketKey if retainAssociatedUser { sql = updateObjectToTombstoneKeepAssociatedUserByBucketKey if retainHashes { sql = updateObjectToTombstoneKeepHashesAndAssociatedUserByBucketKey } } result, err := DB.Exec(sql, reason, fmt.Sprintf("%s/%s", bucket, key)) if err != nil { return err } rows, err := result.RowsAffected() if err != nil { return err } if rows != 1 { return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows) } return nil } // UpdateObjectToTombstoneBySHA256Hash sets all objects with a matching SHA256 hash to be a tombstone. func UpdateObjectToTombstoneBySHA256Hash(sha256 []byte, reason *string, retainAssociatedUser bool, retainHashes bool) error { sql := updateObjectToTombstoneBySHA256Hash if retainAssociatedUser { sql = updateObjectToTombstoneKeepAssociatedUserBySHA256Hash if retainHashes { sql = updateObjectToTombstoneKeepHashesAndAssociatedUserBySHA256Hash } } _, err := DB.Exec(sql, reason, sha256) if err != nil { return err } return nil } // CheckIfFileBanExists returns a boolean specifying if the given key is already // in use by another object. func CheckIfFileBanExists(sha256 []byte) (bool, error) { var count int err := DB.QueryRow(countOfFileBanBySHA256, sha256).Scan(&count) if err != nil { return false, err } return count > 0, nil } // InsertFileBan inserts a file ban into the database. 
func InsertFileBan(sha256 []byte, didQuarantine bool, reason int, description, malwareName *string) error { if reason < 0 || reason > 3 { return errors.New("invalid file ban reason") } if malwareName != nil && *malwareName == "" { malwareName = nil } result, err := DB.Exec(insertFileBan, sha256, didQuarantine, reason, description, malwareName) if err != nil { return err } rows, err := result.RowsAffected() if err != nil { return err } if rows != 1 { return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows) } return nil } // scanFileBan scans an file_bans row into an FileBan. func scanFileBan(scanner scanner) (FileBan, error) { fileBan := FileBan{} err := scanner.Scan(&fileBan.SHA256HashBytes, &fileBan.DidQuarantine, &fileBan.Reason, &fileBan.Description, &fileBan.MalwareName) if err != nil { return FileBan{}, errors.Wrap(err, "failed to Scan row from query") } if fileBan.Reason < 0 || fileBan.Reason > 3 { return FileBan{}, errors.Wrap(err, "invalid FileBan scanned from query") } if fileBan.Reason != 1 && fileBan.MalwareName != nil { return FileBan{}, errors.Wrap(err, "invalid FileBan scanned from query") } if fileBan.SHA256HashBytes != nil && len(fileBan.SHA256HashBytes) == 32 { sha256String := hex.EncodeToString(fileBan.SHA256HashBytes) fileBan.SHA256Hash = &sha256String } return fileBan, nil } // ListBannedFiles returns all banned files. func ListBannedFiles() ([]FileBan, error) { bannedFiles := []FileBan{} rows, err := DB.Query(listBannedFiles) if err != nil { return bannedFiles, err } defer rows.Close() // Scan into bannedFiles for rows.Next() { bannedFile, err := scanFileBan(rows) if err != nil { return []FileBan{}, err } bannedFiles = append(bannedFiles, bannedFile) } return bannedFiles, nil } // GetBannedFile returns an FileBan. 
func GetBannedFile(sha256 []byte) (FileBan, error) { row := DB.QueryRow(getBannedFileBySHA256Hash, sha256) return scanFileBan(row) } // DeleteBannedFile deletes a banned file from the database, effectively unbanning it. func DeleteBannedFile(sha256 []byte) error { result, err := DB.Exec(deleteFileBan, sha256) if err != nil { return err } rows, err := result.RowsAffected() if err != nil { return err } if rows != 1 { return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows) } return nil }