Skip to content
Snippets Groups Projects
Commit 6ee39e58 authored by lordjbs's avatar lordjbs
Browse files

Added basic ratelimiting

parent d260b12b
No related branches found
No related tags found
No related merge requests found
{ {
"port": 5000, "port": 5000,
"database": "database.db", "database": "database.db",
"url": "http://localhost" "url": "http://localhost",
"ratelimit": {
"enabled": true,
"timeUntilClear": 60,
"maxRequests": 30
}
} }
\ No newline at end of file
...@@ -18,16 +18,33 @@ import utils ...@@ -18,16 +18,33 @@ import utils
from database import Database from database import Database
import json import json
import time import time
from ratelimit import Ratelimit
VERSION = "v2.0" VERSION = "v2.0"
print("shortnex " + VERSION + "\nmade by jbs") print("shortnex " + VERSION + "\nmade by jbs")
print("shortnex | Loading config and flask") print("shortnex | Loading config")
with open('config.json') as _config: with open('config.json') as _config:
data = json.load(_config) data = json.load(_config)
config = {"port": data["port"], "database": data["database"], "url": data["url"]} config = {"port": data["port"], "database": data["database"], "url": data["url"], "rEnabled": data["ratelimit"]["enabled"]}
print("shortnex | Done...")
db = Database(config.get("database")) db = Database(config.get("database"))
print("shortnex | Loading ratelimit service...")
ratelimits = Ratelimit()
if config["rEnabled"]:
try:
ratelimits.loop()
print("shortnex | Done.")
except Exception:
print("shortnex | Failed loading ratelimit service, you could report this issue on git... Exiting..")
exit(0)
else:
print("Ratelimit service is disabled...")
print("shortnex | Loading flask...")
app = Flask(__name__, static_url_path='/static/') app = Flask(__name__, static_url_path='/static/')
...@@ -39,9 +56,12 @@ def index(): ...@@ -39,9 +56,12 @@ def index():
# curl --header "Content-Type: application/json, charset=utf-8" --request POST --data '{"url":"https://example.org"}' http://localhost:5000/shorten # curl --header "Content-Type: application/json, charset=utf-8" --request POST --data '{"url":"https://example.org"}' http://localhost:5000/shorten
@app.route("/shorten", methods=['POST']) @app.route("/shorten", methods=['POST'])
def shorten(): def shorten():
if not ratelimits.check(request.remote_addr):
return {"success": "false", "error": "You are being ratelimited."}
if request.method != "POST": if request.method != "POST":
return {"success": "false", "error": "This route is POST only."} return {"success": "false", "error": "This route is POST only."}
content = request.get_json() content = request.get_json()
if not "url" in content: if not "url" in content:
return {"success": False, "error": "The parameter 'url' does not exist.", "code": 1} return {"success": False, "error": "The parameter 'url' does not exist.", "code": 1}
......
# shortnex v2.0
# made by jbs (https://github.com/lordjbs/)
# Copyright (C) 2018-2020 jbs
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from utils import set_interval
import json
import threading
with open('config.json') as _config:
data = json.load(_config)
config = {"timeUntilClear": data["ratelimit"]["timeUntilClear"], "maxRequests": data["ratelimit"]["maxRequests"], "enabled": data["ratelimit"]["enabled"]}
class Ratelimit:
def __init__(self):
self.currentIPs = {}
def loop(self):
self.set_interval(config["timeUntilClear"])
def check(self, ip):
if not config["enabled"]:
return True
if ip in self.currentIPs:
self.addOneRequestToIp(ip)
if self.currentIPs[ip] > config["maxRequests"]:
return False
else:
return True
else:
self.addIpToCurrentIps(ip)
return True
def addIp(self, ip):
if ip in self.currentIPs:
try:
self.addOneRequestToIp(ip)
except KeyError:
self.addIpToCurrentIps(ip)
else:
self.addIpToCurrentIps(ip)
def addIpToCurrentIps(self, ip):
self.currentIPs[ip] = 1
def addOneRequestToIp(self, ip):
self.currentIPs[ip] += 1
# TODO: add that if the number of requests exceed a given number above the maxrequests value to add the ip to the array again
def clearIps(self):
self.currentIPs = {}
# https://stackoverflow.com/questions/2697039/python-equivalent-of-setinterval/14035296#14035296
def set_interval(self, sec):
a = sec
def func_wrapper():
self.set_interval(a)
self.clearIps()
t = threading.Timer(sec, func_wrapper)
t.start()
return t
\ No newline at end of file
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import random import random
import re import re
import string import string
import threading
regex = re.compile( regex = re.compile(
r'^(?:http|ftp)s?://' r'^(?:http|ftp)s?://'
...@@ -40,3 +41,13 @@ def returnProperURL(url): ...@@ -40,3 +41,13 @@ def returnProperURL(url):
def createID(): def createID():
return ''.join([random.choice(string.ascii_letters + string.digits) for n in range(6)]) return ''.join([random.choice(string.ascii_letters + string.digits) for n in range(6)])
# https://stackoverflow.com/questions/2697039/python-equivalent-of-setinterval/14035296#14035296
def set_interval(func, sec):
def func_wrapper():
set_interval(func, sec)
func()
t = threading.Timer(sec, func_wrapper)
t.start()
return t
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment