Newer
Older
package db
import (
"database/sql"
)
// SelectUserByToken looks up the user that owns the given token.
//
// It runs the selectUserByToken query and scans the single result row into a
// User. The bucket capacity column is scanned through a sql.NullInt64 and is
// only copied onto the User when the column is non-NULL. Any error from the
// query or the scan is returned unchanged, together with a zero-value User.
func SelectUserByToken(token string) (User, error) {
	var (
		user User
		// Scan target for the nullable bucket capacity column.
		bucketCapacity sql.NullInt64
	)
	err := DB.QueryRow(selectUserByToken, token).
		Scan(&user.ID, &user.Username, &user.Email, &user.IsAdmin, &user.IsBlocked, &bucketCapacity)
	if err != nil {
		// Do not hand back a partially populated User on failure.
		return User{}, err
	}
	if bucketCapacity.Valid {
		user.BucketCapacity = bucketCapacity.Int64
	}
	return user, nil
}
// ObjectKeyExists reports whether the given key is already in use by an object
// in the given bucket.
//
// It runs the countOfObjectByBucketAndRandomKey query with bucket and key and
// returns true when the resulting count is greater than zero. Query and scan
// errors are returned unchanged with a false result.
func ObjectKeyExists(bucket, key string) (bool, error) {
	// Scan target for the single count column.
	var count int
	err := DB.QueryRow(countOfObjectByBucketAndRandomKey, bucket, key).Scan(&count)
	if err != nil {
		return false, err
	}
	return count > 0, nil
}
// InsertShortURL stores a short URL object that points at destURL.
//
// The key is normalised to start with "/" before use. The insertShortURL
// statement receives, in order: bucket+key, bucket, key, the key without its
// leading slash, destURL and associatedUser. An error is returned when the
// statement fails or when it did not affect exactly one row.
func InsertShortURL(bucket, key, destURL string, associatedUser *string) error {
	if !strings.HasPrefix(key, "/") {
		key = "/" + key
	}
	bucketKey := bucket + key
	bareKey := key[1:] // safe: key is guaranteed to start with "/" here

	res, err := DB.Exec(insertShortURL, bucketKey, bucket, key, bareKey, destURL, associatedUser)
	if err != nil {
		return err
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case affected != 1:
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", affected)
	}
	return nil
}
// InsertFile inserts a file object into the database.
//
// The key is normalised to start with "/" before use. An error is returned
// when the statement fails or when it did not affect exactly one row.
//
// NOTE(review): in this copy of the file the statement that should assign
// `result` and `err` (presumably a DB.Exec call, as in InsertShortURL) is
// incomplete: only three dangling argument lines remain in the body, and the
// bucket, ext, sha256Hash and associatedUser parameters are never referenced.
// Restore the full call from version control before building — TODO confirm
// the exact argument list.
func InsertFile(bucket, key, ext, contentType string, contentLength int64, md5Hash, sha256Hash []byte, associatedUser *string) error {
if !strings.HasPrefix(key, "/") {
key = "/" + key
}
// NOTE(review): dangling arguments of a truncated call; the start of the
// statement and its remaining arguments are missing here.
contentType,
contentLength,
md5Hash,
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
// Exactly one row is expected to be affected; anything else is an error.
if rows != 1 {
return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
}
return nil
}
// CheckUserExistsByUsernameOrEmail reports whether any user matches the given
// username or email address.
//
// The username is lower-cased before being passed to the
// selectUserByUsernameOrEmail query; the email is passed through as given. A
// sql.ErrNoRows result means no such user and yields (false, nil); any other
// error is returned unchanged with a false result.
func CheckUserExistsByUsernameOrEmail(username, email string) (bool, error) {
	var matchedID string
	lowered := strings.ToLower(username)

	switch err := DB.QueryRow(selectUserByUsernameOrEmail, lowered, email).Scan(&matchedID); err {
	case nil:
		return true, nil
	case sql.ErrNoRows:
		return false, nil
	default:
		return false, err
	}
}
// InsertUser stores a new user row.
//
// The insertUser statement receives, in order: id, username, the lower-cased
// username and email. An error is returned when the statement fails or when it
// did not affect exactly one row.
func InsertUser(id, username, email string) error {
	usernameLower := strings.ToLower(username)

	res, err := DB.Exec(insertUser, id, username, usernameLower, email)
	if err != nil {
		return err
	}

	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 1 {
		return nil
	}
	return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", n)
}
// InsertToken inserts a token into the database.
func InsertToken(userID, token string) error {
result, err := DB.Exec(insertToken, userID, token)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows != 1 {
return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", rows)
}
return nil
// scanner abstracts over anything whose current row can be copied into
// destination values via Scan, so the same scan helper can be fed either a
// single-row or a multi-row query result.
type scanner interface {
	// Scan copies the columns of the current row into targets.
	Scan(targets ...interface{}) error
}
// scanObject scans an object row into an Object.
//
// Nullable columns are scanned through sql.Null*/pq.NullTime wrappers and are
// only copied onto the Object when valid. After scanning, the row is validated
// against object.Type, which must be within 0..2; rows that fail a validation
// check return an empty Object.
//
// NOTE(review): this copy of the function is damaged. md5Hash and sha256Hash
// are used without a visible declaration, sha256Hash is never passed to Scan,
// several branch headers and closing braces are missing, and a run of bare
// numbers (apparently line-number residue from extraction) sits in the middle
// of the body. Restore the function from version control before building.
//
// NOTE(review): the validation branches call errors.Wrap(err, ...) at points
// where err is already known to be nil. If errors is github.com/pkg/errors,
// Wrap returns nil for a nil error, so those branches would return an empty
// Object with a nil error instead of reporting the invalid row — confirm and
// consider errors.New for them.
func scanObject(scanner scanner) (Object, error) {
object := Object{}
// Scan targets for the columns that may be NULL.
var destURL sql.NullString
var contentType sql.NullString
var contentLength sql.NullInt64
var deletedAt pq.NullTime
var deleteReason sql.NullString
var associatedUser sql.NullString
err := scanner.Scan(&object.Bucket,
&object.Key,
&object.Directory,
&object.Type,
&destURL,
&contentType,
&contentLength,
&object.CreatedAt,
&deletedAt,
&deleteReason,
&md5Hash,
&associatedUser)
if err != nil {
return Object{}, errors.Wrap(err, "failed to Scan row from query")
}
// Only types 0, 1 and 2 are accepted.
if object.Type < 0 || object.Type > 2 {
return Object{}, errors.Wrap(err, "invalid Object scanned from query")
}
// Hashes are only considered for objects whose type is not 1, and only when
// they have the expected digest length (16 bytes for MD5, 32 for SHA-256).
// Both the raw bytes and a hex-encoded string are stored.
if object.Type != 1 {
if md5Hash != nil && len(md5Hash) == 16 {
object.MD5HashBytes = md5Hash
md5String := hex.EncodeToString(md5Hash)
object.MD5Hash = &md5String
}
if sha256Hash != nil && len(sha256Hash) == 32 {
object.SHA256HashBytes = sha256Hash
sha256String := hex.EncodeToString(sha256Hash)
object.SHA256Hash = &sha256String
}
}
// NOTE(review): the checks below look like the bodies of per-type branches
// whose `if object.Type == ...` headers were lost — TODO confirm.
if !contentType.Valid || !contentLength.Valid || object.MD5HashBytes == nil || object.SHA256HashBytes == nil {
return Object{}, errors.Wrap(err, "invalid Object scanned from query")
}
object.ContentType = &contentType.String
object.ContentLength = &contentLength.Int64
if !destURL.Valid {
return Object{}, errors.Wrap(err, "invalid Object scanned from query")
}
object.DestURL = &destURL.String
}
// For type 2 a deletion reason is only accepted together with a deletion
// timestamp; for other types neither column may be set.
if object.Type == 2 {
if deletedAt.Valid {
object.DeletedAt = &deletedAt.Time
if deleteReason.Valid {
object.DeleteReason = &deleteReason.String
}
} else {
if deleteReason.Valid {
return Object{}, errors.Wrap(err, "invalid Object scanned from query")
}
} else if deletedAt.Valid || deleteReason.Valid {
return Object{}, errors.Wrap(err, "invalid Object scanned from query")
// NOTE(review): the bare numbers below are not Go code; they look like
// line-number residue left behind by extraction.
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
}
// The associated user is optional and only set when the column is non-NULL.
if associatedUser.Valid {
object.AssociatedUser = &associatedUser.String
}
return object, nil
}
// ListObjectsByAssociatedUser returns one page of the objects associated with
// a user.
//
// asc selects ascending order; otherwise results are descending. The
// listObjectsByAssociatedUser statement is formatted with the sort direction
// and then executed with userID, limit and offset (in that order). On success
// the returned slice is non-nil, and empty when there are no rows. On a scan
// or iteration error an empty slice is returned with the error.
func ListObjectsByAssociatedUser(userID string, asc bool, offset, limit int) ([]Object, error) {
	objects := []Object{}

	// The sort direction is formatted into the statement rather than bound as
	// a parameter. Only these two fixed literals are ever used, so no
	// caller-controlled text reaches the SQL string.
	order := "DESC"
	if asc {
		order = "ASC"
	}

	rows, err := DB.Query(fmt.Sprintf(listObjectsByAssociatedUser, order), userID, limit, offset)
	if err != nil {
		return objects, err
	}
	defer rows.Close()

	// Scan into Objects
	for rows.Next() {
		object, err := scanObject(rows)
		if err != nil {
			return []Object{}, err
		}
		objects = append(objects, object)
	}
	// Next returns false both at end of data and on failure; Err tells the two
	// apart so a broken iteration is not reported as a short successful page.
	if err := rows.Err(); err != nil {
		return []Object{}, err
	}
	return objects, nil
}
// GetObject fetches a single object identified by bucket and key.
//
// The getObjectByBucketKey query is executed with the string "bucket/key" and
// the resulting row is handed to scanObject, whose result is returned as is.
func GetObject(bucket, key string) (Object, error) {
	bucketKey := bucket + "/" + key
	return scanObject(DB.QueryRow(getObjectByBucketKey, bucketKey))
}
// CheckIfObjectExists reports whether at least one object has the given
// SHA-256 hash.
//
// It runs the countOfObjectsBySHA256 query and returns true when the count is
// greater than zero. Query and scan errors are returned unchanged with a false
// result.
func CheckIfObjectExists(sha256 []byte) (bool, error) {
	var matches int
	if err := DB.QueryRow(countOfObjectsBySHA256, sha256).Scan(&matches); err != nil {
		return false, err
	}
	return matches > 0, nil
}
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
// UpdateObjectToTombstoneByBucketKey turns the object identified by bucket and
// key into a tombstone.
//
// The statement used depends on the retain flags: with retainAssociatedUser
// the associated user is kept, and with both retainAssociatedUser and
// retainHashes the hashes are kept as well. retainHashes has no effect unless
// retainAssociatedUser is also set. The statement is executed with reason and
// the string "bucket/key". An error is returned when the statement fails or
// when it did not affect exactly one row.
func UpdateObjectToTombstoneByBucketKey(bucket, key string, reason *string, retainAssociatedUser bool, retainHashes bool) error {
	query := updateObjectToTombstoneByBucketKey
	switch {
	case retainAssociatedUser && retainHashes:
		query = updateObjectToTombstoneKeepHashesAndAssociatedUserByBucketKey
	case retainAssociatedUser:
		query = updateObjectToTombstoneKeepAssociatedUserByBucketKey
	}

	res, err := DB.Exec(query, reason, bucket+"/"+key)
	if err != nil {
		return err
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case affected != 1:
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", affected)
	}
	return nil
}
// UpdateObjectToTombstoneBySHA256Hash turns every object with a matching
// SHA-256 hash into a tombstone.
//
// The statement used depends on the retain flags: with retainAssociatedUser
// the associated user is kept, and with both retainAssociatedUser and
// retainHashes the hashes are kept as well. retainHashes has no effect unless
// retainAssociatedUser is also set. The number of affected rows is not
// checked; only an execution error is reported.
func UpdateObjectToTombstoneBySHA256Hash(sha256 []byte, reason *string, retainAssociatedUser bool, retainHashes bool) error {
	query := updateObjectToTombstoneBySHA256Hash
	switch {
	case retainAssociatedUser && retainHashes:
		query = updateObjectToTombstoneKeepHashesAndAssociatedUserBySHA256Hash
	case retainAssociatedUser:
		query = updateObjectToTombstoneKeepAssociatedUserBySHA256Hash
	}

	_, err := DB.Exec(query, reason, sha256)
	return err
}
// CheckIfFileBanExists reports whether a file ban exists for the given SHA-256
// hash.
//
// It runs the countOfFileBanBySHA256 query and returns true when the count is
// greater than zero. Query and scan errors are returned unchanged with a false
// result.
func CheckIfFileBanExists(sha256 []byte) (bool, error) {
	var bans int
	if err := DB.QueryRow(countOfFileBanBySHA256, sha256).Scan(&bans); err != nil {
		return false, err
	}
	return bans > 0, nil
}
// InsertFileBan stores a ban for the file with the given SHA-256 hash.
//
// reason must be within 0..3, otherwise an error is returned before touching
// the database. A malwareName pointing at an empty string is treated as absent
// and stored as nil. The insertFileBan statement receives, in order: sha256,
// didQuarantine, reason, description and malwareName. An error is returned
// when the statement fails or when it did not affect exactly one row.
func InsertFileBan(sha256 []byte, didQuarantine bool, reason int, description, malwareName *string) error {
	if !(reason >= 0 && reason <= 3) {
		return errors.New("invalid file ban reason")
	}
	if malwareName != nil && len(*malwareName) == 0 {
		malwareName = nil
	}

	res, err := DB.Exec(insertFileBan, sha256, didQuarantine, reason, description, malwareName)
	if err != nil {
		return err
	}

	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 1 {
		return nil
	}
	return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", n)
}
// scanFileBan scans a file_bans row into a FileBan.
//
// The row is expected to provide, in order: the SHA-256 hash bytes, the
// quarantine flag, the reason, the description and the malware name. After
// scanning, the row is validated: the reason must be within 0..3, and a
// malware name is only allowed when the reason is 1. When the hash is exactly
// 32 bytes long its hex encoding is stored in SHA256Hash.
//
// A scan failure is returned wrapped with context; a validation failure is
// returned as a new error. In both cases the returned FileBan is empty.
func scanFileBan(row scanner) (FileBan, error) {
	fileBan := FileBan{}
	err := row.Scan(&fileBan.SHA256HashBytes,
		&fileBan.DidQuarantine,
		&fileBan.Reason,
		&fileBan.Description,
		&fileBan.MalwareName)
	if err != nil {
		return FileBan{}, errors.Wrap(err, "failed to Scan row from query")
	}

	// err is nil from here on. Wrapping a nil error yields nil, which would
	// report an invalid row as success, so validation failures create a fresh
	// error instead.
	if fileBan.Reason < 0 || fileBan.Reason > 3 {
		return FileBan{}, errors.New("invalid FileBan scanned from query")
	}
	if fileBan.Reason != 1 && fileBan.MalwareName != nil {
		return FileBan{}, errors.New("invalid FileBan scanned from query")
	}

	// len of a nil slice is 0, so no separate nil check is needed.
	if len(fileBan.SHA256HashBytes) == 32 {
		sha256String := hex.EncodeToString(fileBan.SHA256HashBytes)
		fileBan.SHA256Hash = &sha256String
	}
	return fileBan, nil
}
// ListBannedFiles returns all banned files.
//
// It runs the listBannedFiles query and converts every row with scanFileBan.
// On success the returned slice is non-nil, and empty when there are no rows.
// On a scan or iteration error an empty slice is returned with the error.
func ListBannedFiles() ([]FileBan, error) {
	bannedFiles := []FileBan{}
	rows, err := DB.Query(listBannedFiles)
	if err != nil {
		return bannedFiles, err
	}
	defer rows.Close()

	// Scan into bannedFiles
	for rows.Next() {
		bannedFile, err := scanFileBan(rows)
		if err != nil {
			return []FileBan{}, err
		}
		bannedFiles = append(bannedFiles, bannedFile)
	}
	// Next returns false both at end of data and on failure; Err tells the two
	// apart so a broken iteration is not reported as a complete list.
	if err := rows.Err(); err != nil {
		return []FileBan{}, err
	}
	return bannedFiles, nil
}
// GetBannedFile fetches the FileBan recorded for the given SHA-256 hash.
//
// The getBannedFileBySHA256Hash query is executed with sha256 and the
// resulting row is handed to scanFileBan, whose result is returned as is.
func GetBannedFile(sha256 []byte) (FileBan, error) {
	return scanFileBan(DB.QueryRow(getBannedFileBySHA256Hash, sha256))
}
// DeleteBannedFile removes the ban recorded for the given SHA-256 hash,
// effectively unbanning the file.
//
// An error is returned when the deleteFileBan statement fails or when it did
// not affect exactly one row.
func DeleteBannedFile(sha256 []byte) error {
	res, err := DB.Exec(deleteFileBan, sha256)
	if err != nil {
		return err
	}

	removed, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case removed != 1:
		return errors.Errorf("unexpected amount of rows affected: expected 1, got %v", removed)
	}
	return nil
}