diff --git a/go.mod b/go.mod index 7f061ce526832b1e17bdf6bdf9ffb319f7967fa8..497c678c3f5e3294e11478022a1bba1b3b923c16 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,17 @@ module owo.codes/whats-this/api require ( + github.com/BurntSushi/toml v0.3.1 // indirect github.com/akamensky/base58 v0.0.0-20170920141933-92b0f56f531a github.com/go-chi/chi v4.0.1+incompatible github.com/go-chi/render v1.0.1 - github.com/go-chi/valve v0.0.0-20170920024740-9e45288364f4 + github.com/go-redis/redis v6.15.1+incompatible github.com/gofrs/uuid v3.2.0+incompatible github.com/lib/pq v1.0.0 - github.com/o1egl/paseto v1.0.0 + github.com/onsi/ginkgo v1.7.0 // indirect + github.com/onsi/gomega v1.4.3 // indirect github.com/pkg/errors v0.8.1 github.com/rs/zerolog v1.11.0 - github.com/satori/go.uuid v1.2.0 github.com/spf13/pflag v1.0.3 github.com/spf13/viper v1.3.1 - golang.org/x/crypto v0.0.0-20190219172222-a4c6cb3142f2 ) diff --git a/go.sum b/go.sum index fec3130aee3cc74f37e40f21ce1b54632e251b09..4f436ae5ee93e48cb03cadf8ffb8ab2061a4c364 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,12 @@ -github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= -github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= -github.com/aead/chacha20poly1305 v0.0.0-20170617001512-233f39982aeb h1:6Z/wqhPFZ7y5ksCEV/V5MXOazLaeu/EW97CU5rz8NWk= -github.com/aead/chacha20poly1305 v0.0.0-20170617001512-233f39982aeb/go.mod h1:UzH9IX1MMqOcwhoNOIjmTQeAxrFgzs50j4golQtXXxU= -github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= -github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us= +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/akamensky/base58 v0.0.0-20170920141933-92b0f56f531a h1:ndZwlx4H28xTc2WeU9tsGyv+y8xBp2HH31x0xi41c7M= github.com/akamensky/base58 v0.0.0-20170920141933-92b0f56f531a/go.mod h1:dl0ldiQPIl8Io2dgdNbAtqeuq0Os0SxIGCn8IVGH3zo= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= @@ -17,30 +14,35 @@ github.com/go-chi/chi v4.0.1+incompatible h1:RSRC5qmFPtO90t7pTL0DBMNpZFsb/sHF3RX github.com/go-chi/chi v4.0.1+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns= -github.com/go-chi/valve v0.0.0-20170920024740-9e45288364f4 h1:JYZmrkBDj6LwUbsRysF9tpLnz59npoZSI3KG2XHqvHw= -github.com/go-chi/valve v0.0.0-20170920024740-9e45288364f4/go.mod h1:F4ZINQr5T71wO1JOmdQsGTBew+njUAXn65LLGjuagwY= +github.com/go-redis/redis v6.15.1+incompatible h1:BZ9s4/vHrIqwOb0OPtTQ5uABxETJ3NRuUNoSUurnkew= +github.com/go-redis/redis v6.15.1+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/magiconair/properties v1.8.0 h1:LLgXmsheXeRoUOBOjtwPQCWIYqM/LU1ayDtDePerRcY= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/o1egl/paseto v1.0.0 h1:bwpvPu2au176w4IBlhbyUv/S5VPptERIA99Oap5qUd0= -github.com/o1egl/paseto v1.0.0/go.mod h1:5HxsZPmw/3RI2pAwGo1HhOOwSdvBpcuVzO7uDkm+CLU= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M= github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= -github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= -github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/spf13/afero v1.1.2 h1:m8/z1t7/fwjysjQRYbP0RD+bUIF/8tJwPdEZsI83ACI= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/cast v1.3.0 h1:oget//CVOEoFewqQxwr0Ej5yjygnqGkvggSE/gB35Q8= @@ -51,18 +53,26 @@ github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/viper v1.3.1 h1:5+8j8FTpnFV4nEImW/ofkzEt8VoOiLXxdYIDsB73T38= github.com/spf13/viper v1.3.1/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= +github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= -golang.org/x/crypto v0.0.0-20181025213731-e84da0312774/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190219172222-a4c6cb3142f2 h1:NwxKRvbkH5MsNkvOtPZi3/3kmI8CAzs3mtv+GLQMkNo= -golang.org/x/crypto v0.0.0-20190219172222-a4c6cb3142f2/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a h1:1n5lsVfiQW3yfsRGu98756EH1YthsFqr/5mxHduZW2A= golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/lib/apierrors/errors.go b/lib/apierrors/errors.go index c826d540a017f9d1d186c5b2c7d7e9f6afd2b9ec..8b0b1c35f32fa4bba85e1239fc09f1ef0733449c 100644 --- a/lib/apierrors/errors.go +++ b/lib/apierrors/errors.go @@ -40,6 +40,9 @@ var ( // AlreadyDeleted is a 410 gone error. AlreadyDeleted = APIError{false, 410, "already deleted", false} + + // InsufficientTokens is a 429 too many requests error. + InsufficientTokens = APIError{false, 429, "too many requests", false} ) // Pomf errors diff --git a/lib/db/models.go b/lib/db/models.go index 73306d123df532b2e6ffaf42a0bb075849188123..3340a2da18c7fe1864477f17eae6d906ddf8bfbb 100644 --- a/lib/db/models.go +++ b/lib/db/models.go @@ -4,11 +4,12 @@ import "time" // User represents a user from the database. type User struct { - ID string `json:"id"` - Username string `json:"username"` - Email string `json:"email"` - IsAdmin bool `json:"is_admin"` - IsBlocked bool `json:"is_blocked"` + ID string `json:"id"` + Username string `json:"username"` + Email string `json:"email"` + IsAdmin bool `json:"is_admin"` + IsBlocked bool `json:"is_blocked"` + BucketCapacity int64 `json:"-"` } // Object represents an object from the database. diff --git a/lib/db/queries.go b/lib/db/queries.go index 19757756b09562954bb6cbcfef49afa9e1be36b2..e48cd917855d07b3fc3e3f44a56f7d3feaa8cfdc 100644 --- a/lib/db/queries.go +++ b/lib/db/queries.go @@ -12,8 +12,15 @@ import ( // SelectUserByToken returns a user object from a token. func SelectUserByToken(token string) (User, error) { var user User + var bucketCapacity sql.NullInt64 err := DB.QueryRow(selectUserByToken, token). - Scan(&user.ID, &user.Username, &user.Email, &user.IsAdmin, &user.IsBlocked) + Scan(&user.ID, &user.Username, &user.Email, &user.IsAdmin, &user.IsBlocked, &bucketCapacity) + if err != nil { + return user, err + } + if bucketCapacity.Valid { + user.BucketCapacity = bucketCapacity.Int64 + } return user, err } diff --git a/lib/db/sql.go b/lib/db/sql.go index 8d9d2491f92cb5edf9a9f43c8e7d8daf74026a7f..52ee9681e9a3652a7437249308aaca57d4f29d12 100644 --- a/lib/db/sql.go +++ b/lib/db/sql.go @@ -6,7 +6,8 @@ SELECT username, email, is_admin, - is_blocked + is_blocked, + bucket_capacity FROM users u, tokens t diff --git a/lib/middleware/authentication.go b/lib/middleware/authentication.go index a7e8c0428d60583dd7b9cb9b4bd1efb42795ba1d..2903cdfd2c5af3b422d0362df673e39d2cebbff4 100644 --- a/lib/middleware/authentication.go +++ b/lib/middleware/authentication.go @@ -9,18 +9,23 @@ import ( "owo.codes/whats-this/api/lib/apierrors" "owo.codes/whats-this/api/lib/db" + "owo.codes/whats-this/api/lib/ratelimiter" "github.com/rs/zerolog/log" + "github.com/spf13/viper" ) // UUIDRegex is a UUID regex. See http://stackoverflow.com/a/13653180. var UUIDRegex = regexp.MustCompile("^[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$") -type authKey struct{} +type authKey string // AuthorizedUserKey is the context value key for storing request authorization // information. -var AuthorizedUserKey authKey +var AuthorizedUserKey = authKey("AuthorizedUserKey") + +// BucketKey is the context value key for storing the request ratelimit bucket. +var BucketKey = authKey("BucketKey") // Authenticator creates middleware that authenticates incoming requests using // a token and checking it against the database. @@ -59,6 +64,16 @@ func Authenticator(next http.Handler) http.Handler { panic(apierrors.Unauthorized) } ctx = context.WithValue(ctx, AuthorizedUserKey, user) + + // Create a bucket and add to request + if viper.GetBool("ratelimiter.enable") { + bucketCapacity := user.BucketCapacity + if bucketCapacity < 0 { + bucketCapacity = 0 + } + bucket := ratelimiter.NewBucket(user.ID, bucketCapacity, 0) + ctx = context.WithValue(ctx, BucketKey, bucket) + } next.ServeHTTP(w, r.WithContext(ctx)) } return http.HandlerFunc(fn) @@ -69,3 +84,15 @@ func GetAuthorizedUser(r *http.Request) db.User { user, _ := r.Context().Value(AuthorizedUserKey).(db.User) return user } + +// GetBucket returns the current authorized user's ratelimit bucket. +func GetBucket(r *http.Request) ratelimiter.Bucket { + if !viper.GetBool("ratelimiter.enable") { + return ratelimiter.NoopBucket + } + bucket, ok := r.Context().Value(BucketKey).(ratelimiter.Bucket) + if !ok || bucket == nil { + return ratelimiter.EmptyBucket + } + return bucket +} diff --git a/lib/ratelimiter/bucket.go b/lib/ratelimiter/bucket.go new file mode 100644 index 0000000000000000000000000000000000000000..8a5a94008f630aba5e2b4a9d7fd29fa4e72460e4 --- /dev/null +++ b/lib/ratelimiter/bucket.go @@ -0,0 +1,210 @@ +package ratelimiter + +import ( + "fmt" + "net/http" + "strconv" + "sync" + "time" + + "github.com/go-redis/redis" + "github.com/rs/zerolog/log" + "github.com/spf13/viper" +) + +// Redis key format string for bucket ID. +const redisKeyFormat = "whats-this:api:bucket:%s" + +// onceNoCapacityWarning is used for creating a one-time warning in +// NewBucketWithRedis. +var onceNoCapacityWarning sync.Once + +// onceNoExpiryWarning is used for creating a one-time warning in +// NewBucketWithRedis. +var onceNoExpiryWarning sync.Once + +// Bucket is an interface that implements token bucket methods. +type Bucket interface { + Take(tokens int64) error + TakeWithHeaders(w http.ResponseWriter, tokens int64) error + Reset() error +} + +// noopBucket is a bucket that does nothing for use when ratelimiting is +// disabled (infinite capacity bucket). +type noopBucket struct{} + +// Take implements Bucket (noop). +func (b *noopBucket) Take(tokens int64) error { + return nil +} + +// TakeWithHeaders implements Bucket (noop). +func (b *noopBucket) TakeWithHeaders(w http.ResponseWriter, tokens int64) error { + return nil +} + +// Reset implements Bucket (noop). +func (b *noopBucket) Reset() error { + return nil +} + +// NoopBucket is a reusable noopBucket (Bucket that has infinite capacity). +var NoopBucket Bucket = &noopBucket{} + +// emptyBucket is a bucket that is always empty. +type emptyBucket struct{} + +// Take implements Bucket (always empty). +func (b *emptyBucket) Take(tokens int64) error { + return InsufficientTokens +} + +// TakeWithHeaders implements Bucket (always empty). +func (b *emptyBucket) TakeWithHeaders(w http.ResponseWriter, tokens int64) error { + return InsufficientTokens +} + +// Reset implements Bucket (always empty). +func (b *emptyBucket) Reset() error { + return nil +} + +// EmptyBucket is a reusable emptyBucket (Bucket that is always empty). +var EmptyBucket Bucket = &emptyBucket{} + +// RedisBucket represents a named bucket stored in Redis with a set number of +// tokens. +type RedisBucket struct { + r *redis.Client + ID string + Capacity int64 + Expiry time.Duration +} + +// NewBucket creates a Bucket with the default Redis client. If the capacity is +// set to 0, the limit will be set to the default limit in the configuration. +func NewBucket(id string, capacity int64, expiry time.Duration) Bucket { + return NewBucketWithRedis(rc, id, capacity, expiry) +} + +// NewBucketWithRedis creates a Bucket with a custom Redis client. If the +// capacity is set to 0, the limit will be set to the default limit in the +// configuration. +func NewBucketWithRedis(client *redis.Client, id string, capacity int64, expiry time.Duration) Bucket { + if client == nil { + panic("Invalid Redis client supplied") + } + if id == "" { + panic("id should be specified") + } + if capacity == 0 { + // Get default bucket capacity from Viper + if c := viper.GetInt64("ratelimiter.defaultBucketCapacity"); c > 0 { + capacity = c + } else { + onceNoCapacityWarning.Do(func() { + log.Warn().Msg("RedisBucket made with no capacity, returning a noopBucket " + + "(infinite capacity Bucket)") + }) + return NoopBucket + } + } + if capacity < 0 { + panic("RedisBucket limit is less than 0") + } + if expiry == 0 { + // Get default bucket expiry from Viper + if e := viper.GetDuration("ratelimiter.bucketExpiryDuration"); e > 0 { + expiry = e + } else { + onceNoExpiryWarning.Do(func() { + log.Warn().Msg("RedisBucket made with no expiry duration, returning a noopBucket " + + "(infinite capacity Bucket)") + }) + return NoopBucket + } + } + if expiry < 0 { + panic("RedisBucket expiry is less than 0") + } + return &RedisBucket{ + r: client, + ID: id, + Capacity: capacity, + Expiry: expiry, + } +} + +// key returns the key for this Bucket in Redis. +func (b *RedisBucket) key() string { + return fmt.Sprintf(redisKeyFormat, b.ID) +} + +// Take implements Bucket for Redis. +func (b *RedisBucket) Take(tokens int64) error { + _, _, err := b.take(tokens) + return err +} + +// TakeWithHeaders implements Bucket for Redis. +func (b *RedisBucket) TakeWithHeaders(w http.ResponseWriter, tokens int64) error { + current, expiry, err := b.take(tokens) + if err != nil && err != InsufficientTokens { + return err + } + if current < 0 { + // Don't set headers if there are negative tokens in the bucket + return nil + } + w.Header().Set("X-Ratelimiter-Capacity", strconv.Itoa(int(b.Capacity))) + w.Header().Set("X-Ratelimiter-Remaining", strconv.Itoa(int(current))) + w.Header().Set("X-Ratelimiter-Cost", strconv.Itoa(int(tokens))) + expirySeconds := time.Now().Add(expiry).UTC().Unix() + w.Header().Set("X-Ratelimiter-Expiry", strconv.Itoa(int(expirySeconds))) + return err +} + +// take is the internal take function used by TakeWithHeaders and Take. +func (b *RedisBucket) take(tokens int64) (int64, time.Duration, error) { + key := b.key() + currentStr, err := b.r.Get(key).Result() + var current int64 + if err != nil && err != redis.Nil { + return -1, 0, err + } + expiry := b.Expiry + if err == redis.Nil { + _, err = b.r.Set(key, b.Capacity, b.Expiry).Result() + if err != nil { + return -1, 0, err + } + current = b.Capacity + } else { + c, err := strconv.Atoi(currentStr) + if err != nil { + return -1, 0, err + } + current = int64(c) + + // Get key TTL + expiry, err = b.r.TTL(key).Result() + if err != nil { + return -1, 0, err + } + } + if current < tokens { + return current, expiry, InsufficientTokens + } + _, err = b.r.IncrBy(key, -tokens).Result() + if err != nil { + return current, expiry, err + } + return current - tokens, expiry, nil +} + +// Reset implements Bucket for Redis. +func (b *RedisBucket) Reset() error { + _, err := b.r.Set(b.key(), b.Capacity, b.Expiry).Result() + return err +} diff --git a/lib/ratelimiter/connect.go b/lib/ratelimiter/connect.go new file mode 100644 index 0000000000000000000000000000000000000000..d0a5ed4f33e75681c03467ec75fffebe3de939cd --- /dev/null +++ b/lib/ratelimiter/connect.go @@ -0,0 +1,24 @@ +package ratelimiter + +import ( + "github.com/go-redis/redis" +) + +var rc *redis.Client + +// RedisConnect connects to Redis and sets the default ratelimiter connection +// for RedisBuckets (use NewBucket to make a RedisBucket with the default +// connection). +func RedisConnect(url string) error { + opt, err := redis.ParseURL(url) + if err != nil { + return err + } + r := redis.NewClient(opt) + _, err = r.Ping().Result() + if err != nil { + return err + } + rc = r + return nil +} diff --git a/lib/ratelimiter/costs.go b/lib/ratelimiter/costs.go new file mode 100644 index 0000000000000000000000000000000000000000..85eefa8c498ee782513e02ee70c30140ed3c0fdd --- /dev/null +++ b/lib/ratelimiter/costs.go @@ -0,0 +1,13 @@ +package ratelimiter + +// Default costs for each route. +const ( + UploadPomfCost = 10 + ShortenPolrCost = 5 + + // CreateUserCost = 0 // unused + MeCost = 3 + ListObjectsCost = 10 + ObjectCost = 3 + DeleteObjectCost = 5 +) diff --git a/lib/ratelimiter/errors.go b/lib/ratelimiter/errors.go new file mode 100644 index 0000000000000000000000000000000000000000..d8eba6f71228208644fe2b1808534a16da93eaf1 --- /dev/null +++ b/lib/ratelimiter/errors.go @@ -0,0 +1,14 @@ +package ratelimiter + +type bucketError struct { + Err string +} + +// Error implements error. +func (e *bucketError) Error() string { + return e.Err +} + +// InsufficientTokens means the requested amount of tokens can't be taken +// because there isn't enough tokens in the bucket at the moment. +var InsufficientTokens error = &bucketError{"insufficient tokens in Bucket for Take operation"} diff --git a/lib/routes/deleteobject.go b/lib/routes/deleteobject.go index 6ce06f4b45e96608a86647f991d4f0cd168cc911..93e2ae07250295fbda69e42c907fdcb9d9338c64 100644 --- a/lib/routes/deleteobject.go +++ b/lib/routes/deleteobject.go @@ -11,6 +11,7 @@ import ( "owo.codes/whats-this/api/lib/apierrors" "owo.codes/whats-this/api/lib/db" "owo.codes/whats-this/api/lib/middleware" + "owo.codes/whats-this/api/lib/ratelimiter" "github.com/go-chi/render" "github.com/pkg/errors" @@ -30,6 +31,16 @@ func DeleteObject(w http.ResponseWriter, r *http.Request) { panic(apierrors.Unauthorized) } + // Apply ratelimits + bucket := middleware.GetBucket(r) + err := bucket.TakeWithHeaders(w, viper.GetInt64("ratelimiter.deleteObjectCost")) + if err == ratelimiter.InsufficientTokens { + panic(apierrors.InsufficientTokens) + } + if err != nil { + panic(apierrors.InternalServerError) + } + // Get the key key := r.URL.Path if strings.HasPrefix(key, "/objects/") { diff --git a/lib/routes/listobjects.go b/lib/routes/listobjects.go index 08b0ce3fec3da290a1ed2aa7d0bf1139368c18cf..8647c39a3544abd2d8d64e690e46652c6b913a90 100644 --- a/lib/routes/listobjects.go +++ b/lib/routes/listobjects.go @@ -7,9 +7,11 @@ import ( "owo.codes/whats-this/api/lib/apierrors" "owo.codes/whats-this/api/lib/db" "owo.codes/whats-this/api/lib/middleware" + "owo.codes/whats-this/api/lib/ratelimiter" "github.com/go-chi/render" "github.com/rs/zerolog/log" + "github.com/spf13/viper" ) // Maximum objects per page @@ -29,6 +31,16 @@ func ListObjects(w http.ResponseWriter, r *http.Request) { panic(apierrors.Unauthorized) } + // Apply ratelimits + bucket := middleware.GetBucket(r) + err := bucket.TakeWithHeaders(w, viper.GetInt64("ratelimiter.listObjectsCost")) + if err == ratelimiter.InsufficientTokens { + panic(apierrors.InsufficientTokens) + } + if err != nil { + panic(apierrors.InternalServerError) + } + // Determine offset and limit information query := r.URL.Query() l := query.Get("limit") diff --git a/lib/routes/me.go b/lib/routes/me.go index 3d6f17645c0844db6eeca98b9350842785c2de45..8a2f77901eae41b7d8d7f2f1c9af4c75d7f2ba6f 100644 --- a/lib/routes/me.go +++ b/lib/routes/me.go @@ -5,8 +5,10 @@ import ( "owo.codes/whats-this/api/lib/apierrors" "owo.codes/whats-this/api/lib/middleware" + "owo.codes/whats-this/api/lib/ratelimiter" "github.com/go-chi/render" + "github.com/spf13/viper" ) // meResponse is the response type for /users/me. @@ -34,6 +36,16 @@ func Me(w http.ResponseWriter, r *http.Request) { panic(apierrors.Unauthorized) } + // Apply ratelimits + bucket := middleware.GetBucket(r) + err := bucket.TakeWithHeaders(w, viper.GetInt64("ratelimiter.meCost")) + if err == ratelimiter.InsufficientTokens { + panic(apierrors.InsufficientTokens) + } + if err != nil { + panic(apierrors.InternalServerError) + } + // Return response u := meUser{ UserID: user.ID, diff --git a/lib/routes/object.go b/lib/routes/object.go index 6c9e24022f79e1441fbbe816b60fc52205109e06..471ba282ab6488dda8f7f90815cbc0e54c9dbcd1 100644 --- a/lib/routes/object.go +++ b/lib/routes/object.go @@ -8,6 +8,7 @@ import ( "owo.codes/whats-this/api/lib/apierrors" "owo.codes/whats-this/api/lib/db" "owo.codes/whats-this/api/lib/middleware" + "owo.codes/whats-this/api/lib/ratelimiter" "github.com/go-chi/render" "github.com/pkg/errors" @@ -30,6 +31,16 @@ func Object(w http.ResponseWriter, r *http.Request) { panic(apierrors.Unauthorized) } + // Apply ratelimits + bucket := middleware.GetBucket(r) + err := bucket.TakeWithHeaders(w, viper.GetInt64("ratelimiter.objectCost")) + if err == ratelimiter.InsufficientTokens { + panic(apierrors.InsufficientTokens) + } + if err != nil { + panic(apierrors.InternalServerError) + } + // Get the key key := r.URL.Path if strings.HasPrefix(key, "/objects/") { diff --git a/lib/routes/shortenpolr.go b/lib/routes/shortenpolr.go index 5e6beacc43a5579650b526682a62c4d669c4f451..7ec2b096333a247951fa6cefa9635e45f7330318 100644 --- a/lib/routes/shortenpolr.go +++ b/lib/routes/shortenpolr.go @@ -11,6 +11,7 @@ import ( "owo.codes/whats-this/api/lib/apierrors" "owo.codes/whats-this/api/lib/db" "owo.codes/whats-this/api/lib/middleware" + "owo.codes/whats-this/api/lib/ratelimiter" "owo.codes/whats-this/api/lib/util" "github.com/rs/zerolog/log" @@ -55,6 +56,16 @@ func ShortenPolr(associateObjectsWithUser bool) func(http.ResponseWriter, *http. panic(apierrors.InvalidPolrAction) } + // Apply ratelimits + rBucket := middleware.GetBucket(r) + err := rBucket.TakeWithHeaders(w, viper.GetInt64("ratelimiter.shortenPolrCost")) + if err == ratelimiter.InsufficientTokens { + panic(apierrors.InsufficientTokens) + } + if err != nil { + panic(apierrors.InternalServerError) + } + // Get URL urlString := strings.TrimSpace(query.Get("url")) u, err := url.Parse(urlString) diff --git a/lib/routes/uploadpomf.go b/lib/routes/uploadpomf.go index cf8118d357dadcbd1f31930a96e4610b744a4734..bf411932f78c4bb3b15493309ecad845c1e2b154 100644 --- a/lib/routes/uploadpomf.go +++ b/lib/routes/uploadpomf.go @@ -13,6 +13,7 @@ import ( "owo.codes/whats-this/api/lib/apierrors" "owo.codes/whats-this/api/lib/db" "owo.codes/whats-this/api/lib/middleware" + "owo.codes/whats-this/api/lib/ratelimiter" "owo.codes/whats-this/api/lib/util" "github.com/go-chi/render" @@ -55,6 +56,16 @@ func UploadPomf(associateObjectsWithUser bool) func(http.ResponseWriter, *http.R panic(apierrors.Unauthorized) } + // Apply ratelimits + rBucket := middleware.GetBucket(r) + err := rBucket.TakeWithHeaders(w, viper.GetInt64("ratelimiter.uploadPomfCost")) + if err == ratelimiter.InsufficientTokens { + panic(apierrors.InsufficientTokens) + } + if err != nil { + panic(apierrors.InternalServerError) + } + // Check Content-Length if supplied contentLength := r.Header.Get("Content-Length") if contentLength == "" { diff --git a/main.go b/main.go index 1c7102c82aa087693e19f50418519621833445bb..79d59c766d4fad19730500a11f451f3545b08d50 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,7 @@ import ( "owo.codes/whats-this/api/lib/db" "owo.codes/whats-this/api/lib/middleware" + "owo.codes/whats-this/api/lib/ratelimiter" "owo.codes/whats-this/api/lib/routes" "github.com/go-chi/chi" @@ -62,6 +63,15 @@ func init() { viper.SetDefault("database.objectBucket", "public") viper.SetDefault("http.listenAddress", ":49544") viper.BindPFlag("log.level", flags.Lookup("log-level")) // default is 1 (info) + viper.SetDefault("ratelimiter.enable", false) + viper.SetDefault("ratelimiter.defaultBucketCapacity", 50) + viper.SetDefault("ratelimiter.bucketExpiryDuration", time.Second*30) + viper.SetDefault("ratelimiter.uploadPomfCost", ratelimiter.UploadPomfCost) + viper.SetDefault("ratelimiter.shortenPolrCost", ratelimiter.ShortenPolrCost) + viper.SetDefault("ratelimiter.meCost", ratelimiter.MeCost) + viper.SetDefault("ratelimiter.listObjectsCost", ratelimiter.ListObjectsCost) + viper.SetDefault("ratelimiter.objectCost", ratelimiter.ObjectCost) + viper.SetDefault("ratelimiter.deleteObjectCost", ratelimiter.DeleteObjectCost) // Load configuration file viper.SetConfigType("toml") @@ -118,6 +128,12 @@ func init() { if viper.GetString("pomf.storageLocation") == "" { log.Fatal().Msg("Configuration: pomf.storageLocation is required") } + if viper.GetString("pomf.tempLocation") == "" { + log.Fatal().Msg("Configuration: pomf.tempLocation is required") + } + if viper.GetBool("ratelimiter.enable") && viper.GetString("ratelimiter.redisURL") == "" { + log.Fatal().Msg("Configuration: ratelimiter.redisURL is required when ratelimiter is enabled") + } } func main() { @@ -130,6 +146,14 @@ func main() { log.Fatal().Err(err).Msg("failed to connect to and ping the database") } + // Connect to Redis for ratelimiter + if viper.GetBool("ratelimiter.enable") { + err := ratelimiter.RedisConnect(viper.GetString("ratelimiter.redisURL")) + if err != nil { + log.Fatal().Err(err).Msg("failed to connect to and ping Redis for ratelimiting") + } + } + // Mount middleware r := chi.NewRouter() r.Use(middleware.Recoverer)