diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..6e43647 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..7ef8e8d --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/short_url.iml b/.idea/short_url.iml new file mode 100644 index 0000000..5dca121 --- /dev/null +++ b/.idea/short_url.iml @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/client.py b/client.py new file mode 100644 index 0000000..9ae5baa --- /dev/null +++ b/client.py @@ -0,0 +1,28 @@ +import hashlib +import base64 +import hmac +import requests +import time + + +def gen_sign(timestamp, secret): + string_to_sign = '{}\n{}'.format(timestamp, secret) + hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest() + sign = base64.b64encode(hmac_code).decode('utf-8') + return sign + + +if __name__ == '__main__': + secret = "test" # the same secret in the config.yaml + url = "http://localhost:8000/" # your server address + ts = round(time.time()) + data = { + "ts": ts, + "sign": gen_sign(ts, secret), + "target": "", + "https": 1, + "expire": "" + } + r = requests.post(url, data=data) + print(r.text) + diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..a02ef6c --- /dev/null +++ b/config.yaml @@ -0,0 +1,17 @@ +database: + host: + port: + user: + password: + database: + +sign: + secret: + +server: + host: + port: + protocol: + ssl: + cert: + key: diff --git a/main.py b/main.py new file mode 100644 index 0000000..36a2d81 --- /dev/null +++ b/main.py @@ -0,0 +1,187 @@ +import aiomysql +import tornado.ioloop +import tornado.web +import datetime +import random +import string +import time +import hashlib +import base64 +import hmac +from urllib.parse import urlparse, urlunparse +import yaml + +with open("config.yaml") as f: + config = yaml.safe_load(f) +host = config['database']['host'] +port = int(config['database']['port']) +user = config['database']['user'] +password = config['database']['password'] +database = config['database']['database'] +secret = config['sign']['secret'] +server_port = config['server']['port'] +server_url = f"{config['server']['protocol']}://{config['server']['host']}:{server_port}" + + +class UrlHandler(tornado.web.RequestHandler): + @staticmethod + async def get_sources(): + conn = await aiomysql.connect( + host=host, + port=port, + user=user, + password=password, + db=database + ) + async with conn.cursor() as cursor: + await cursor.execute('SELECT source FROM urlList') + ret = await cursor.fetchall() + conn.close() + return tuple([i[0] for i in ret]) + + @staticmethod + async def get_redirect_url(code): + conn = await aiomysql.connect( + host=host, + port=port, + user=user, + password=password, + db=database + ) + sql = f"SELECT id, source, target, createTime, expireTime From urlList where source = '{code}';" + async with conn.cursor() as cursor: + await cursor.execute(sql) + ret = await cursor.fetchall() + conn.close() + if ret == (): + return {} + url_info = { + "id": ret[0][0], + "source": ret[0][1], + "target": ret[0][2], + "createTime": ret[0][3], + "expireTime": ret[0][4] + } + return url_info + + @staticmethod + async def get_is_expired(url_info): + if url_info == {}: + return True + expire_time = url_info.get("expireTime") + if expire_time is None or expire_time >= datetime.datetime.now(): + return False + sql = f"Delete From urlList where id = {url_info.get('id')};" + conn = await aiomysql.connect( + host=host, + port=port, + user=user, + password=password, + db=database + ) + async with conn.cursor() as cursor: + await cursor.execute(sql) + await conn.commit() + conn.close() + return True + + async def insert_url(self, url, sources, expire_time=None): + sql = f"SELECT id, source, target, createTime, expireTime From urlList where target = '{url}';" + conn = await aiomysql.connect( + host=host, + port=port, + user=user, + password=password, + db=database + ) + async with conn.cursor() as cursor: + await cursor.execute(sql) + ret = await cursor.fetchall() + if ret != (): + url_info = { + "id": ret[0][0], + "source": ret[0][1], + "target": ret[0][2], + "createTime": ret[0][3], + "expireTime": ret[0][4] + } + if not await self.get_is_expired(url_info) and expire_time is not None\ + and ret[0][4] == datetime.datetime.fromtimestamp(int(expire_time)): + return url_info.get("source") + new_source = "" + while new_source == "" or new_source in sources: + new_source = "".join(random.sample(string.ascii_letters, 7)) + if expire_time is None: + sql = "INSERT INTO urlList (`source`, `target`, `createTime`) VALUES " \ + f"('{new_source}', '{url}', '{datetime.datetime.now()}');" + async with conn.cursor() as cursor: + await cursor.execute(sql) + await conn.commit() + conn.close() + return new_source + expire = datetime.datetime.fromtimestamp(int(expire_time)) + if ret != () and expire == ret[0][4]: + return ret[0][1] + sql = "INSERT INTO urlList (`source`, `target`, `createTime`, `expireTime`) VALUES " \ + f"('{new_source}', '{url}', '{datetime.datetime.now()}', '{expire}');" + async with conn.cursor() as cursor: + await cursor.execute(sql) + await conn.commit() + conn.close() + return new_source + + @staticmethod + async def gen_sign(timestamp): + string_to_sign = '{}\n{}'.format(timestamp, secret) + hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest() + sign = base64.b64encode(hmac_code).decode('utf-8') + return sign + + async def get(self): + source = self.request.path[1:] + url_info = await self.get_redirect_url(source) + if not await self.get_is_expired(url_info): + self.redirect(url_info["target"]) + return + self.set_status(404) + self.write("") + + async def post(self): + timestamp = self.get_argument("ts", None) + if abs(int(timestamp) - round(time.time())) > 600: + self.set_status(404) + self.write("") + sign = self.get_argument("sign", None) + if sign is None or sign != await self.gen_sign(timestamp): + self.set_status(403) + self.write("") + target = self.get_argument("target", None) + if target is None: + self.set_status(404) + self.write("") + parsed_url = urlparse(target) + https = self.get_argument("https", "0") + proto = "https" if https == "1" else "http" + if not parsed_url.scheme: + parsed_url = parsed_url._replace(scheme=proto) + final_url = urlunparse(parsed_url).replace("///", "//") + sources = await self.get_sources() + expire_timestamp = self.get_argument("expire", None) + url_info = await self.insert_url(final_url, sources, expire_timestamp) + self.write(host + url_info) + + +class Application(tornado.web.Application): + def __init__(self): + handlers = [ + ("/.*", UrlHandler), + (server_url + "/.*", UrlHandler) + ] + tornado.web.Application.__init__(self, handlers) + + +if __name__ == '__main__': + app = Application() + app.listen(server_port) + tornado.ioloop.IOLoop.current().start() + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0af264e Binary files /dev/null and b/requirements.txt differ