import aiomysql
import tornado.ioloop
import tornado.web
import datetime
import random
import string
import time
import hashlib
import base64
import hmac
from urllib.parse import urlparse, urlunparse
import yaml

# Configuration is read once, at import time, from config.yaml in the
# current working directory. safe_load avoids arbitrary object construction.
with open("config.yaml") as f:
    config = yaml.safe_load(f)

host = config['database']['host']
port = int(config['database']['port'])
user = config['database']['user']
password = config['database']['password']
database = config['database']['database']
secret = config['sign']['secret']
server_port = config['server']['port']
# Public base URL under which generated short links are served.
server_url = f"{config['server']['protocol']}://{config['server']['host']}:{server_port}"


class UrlHandler(tornado.web.RequestHandler):
    """Resolve short codes (GET) and create signed short links (POST).

    Rows are stored in the MySQL table ``urlList`` and are read as
    ``(id, source, target, createTime, expireTime)`` tuples.
    """

    # Number of letters in a generated short code.
    SOURCE_LENGTH = 7
    # Maximum allowed distance, in seconds, between the request's ``ts``
    # argument and the server clock.
    MAX_TS_SKEW = 600

    @staticmethod
    async def _query(sql, args=None, commit=False):
        """Run one parameterized statement on a fresh connection.

        :param sql: SQL text using ``%s`` placeholders.
        :param args: placeholder values. The driver escapes them, so untrusted
            input never becomes part of the SQL text.
        :param commit: when true, commit and return ``None``; otherwise return
            the fetched rows.
        The connection is closed even if the statement raises.
        """
        conn = await aiomysql.connect(
            host=host, port=port, user=user, password=password, db=database
        )
        try:
            async with conn.cursor() as cursor:
                await cursor.execute(sql, args)
                if commit:
                    await conn.commit()
                    return None
                return await cursor.fetchall()
        finally:
            conn.close()

    @staticmethod
    def _row_to_info(row):
        """Map an ``(id, source, target, createTime, expireTime)`` row to a dict."""
        return {
            "id": row[0],
            "source": row[1],
            "target": row[2],
            "createTime": row[3],
            "expireTime": row[4],
        }

    @staticmethod
    async def get_sources():
        """Return every stored short code as a tuple of strings."""
        rows = await UrlHandler._query('SELECT source FROM urlList')
        return tuple(row[0] for row in rows)

    @staticmethod
    async def get_redirect_url(code):
        """Look up the row for short code ``code``.

        :returns: the first matching row as a dict (see ``_row_to_info``), or
            ``{}`` when no row matches.
        """
        rows = await UrlHandler._query(
            "SELECT id, source, target, createTime, expireTime "
            "FROM urlList WHERE source = %s",
            (code,),
        )
        if not rows:
            return {}
        return UrlHandler._row_to_info(rows[0])

    @staticmethod
    async def get_is_expired(url_info):
        """Report whether ``url_info`` is unusable, deleting it if it has expired.

        :param url_info: dict from ``get_redirect_url``/``_row_to_info``; an
            empty dict (or ``None``) means "not found".
        :returns: ``True`` if the link is missing or past its ``expireTime``
            (in which case its row is deleted); ``False`` if it has no expiry
            or has not expired yet.
        """
        if not url_info:
            return True
        expire_time = url_info.get("expireTime")
        # expireTime is compared with the server's naive local time.
        if expire_time is None or expire_time >= datetime.datetime.now():
            return False
        await UrlHandler._query(
            "DELETE FROM urlList WHERE id = %s", (url_info.get("id"),), commit=True
        )
        return True

    async def insert_url(self, url, sources, expire_time=None):
        """Store ``url`` under a new short code and return the code.

        If a live row already exists for the same ``url`` with exactly the
        requested expiry, its code is reused instead of creating a new row.

        :param url: absolute target URL.
        :param sources: codes already in use; a new code never collides with them.
        :param expire_time: optional Unix timestamp (int or numeric string).
            ``None`` stores the link without an expiry.
        :returns: the short code (``source``) for ``url``.
        """
        expire = (
            None if expire_time is None
            else datetime.datetime.fromtimestamp(int(expire_time))
        )

        rows = await self._query(
            "SELECT id, source, target, createTime, expireTime "
            "FROM urlList WHERE target = %s",
            (url,),
        )
        for row in rows:
            url_info = self._row_to_info(row)
            # get_is_expired runs first on purpose: it also purges expired rows,
            # so a deleted code can never be handed back to the caller.
            if (not await self.get_is_expired(url_info)
                    and expire is not None
                    and url_info["expireTime"] == expire):
                return url_info["source"]

        # NOTE(review): `sources` is a snapshot, so two concurrent requests
        # could still pick the same code. A UNIQUE index on urlList.source would
        # close that race; the schema is not visible here, so confirm.
        taken = set(sources)
        rng = random.SystemRandom()
        new_source = ""
        while not new_source or new_source in taken:
            new_source = "".join(rng.choices(string.ascii_letters, k=self.SOURCE_LENGTH))

        now = datetime.datetime.now()
        if expire is None:
            # Leave expireTime out so the column's default applies, as before.
            await self._query(
                "INSERT INTO urlList (`source`, `target`, `createTime`) "
                "VALUES (%s, %s, %s)",
                (new_source, url, now),
                commit=True,
            )
        else:
            await self._query(
                "INSERT INTO urlList (`source`, `target`, `createTime`, `expireTime`) "
                "VALUES (%s, %s, %s, %s)",
                (new_source, url, now, expire),
                commit=True,
            )
        return new_source

    @staticmethod
    async def gen_sign(timestamp):
        r"""Return the expected ``sign`` argument for ``timestamp``.

        HMAC-SHA256 is keyed with ``"{timestamp}\n{secret}"`` over an empty
        message, and the digest is base64-encoded. Clients compute the same
        value, so this scheme must not change.
        """
        string_to_sign = '{}\n{}'.format(timestamp, secret)
        hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest()
        sign = base64.b64encode(hmac_code).decode('utf-8')
        return sign

    @staticmethod
    def _parse_int(value):
        """Return ``int(value)``, or ``None`` if ``value`` is missing or not an integer."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _is_valid_expire(value):
        """Return True if ``value`` is an integer that ``fromtimestamp`` accepts."""
        try:
            datetime.datetime.fromtimestamp(int(value))
        except (TypeError, ValueError, OverflowError, OSError):
            return False
        return True

    @staticmethod
    def _normalize_target(target, proto):
        """Return ``target`` as an absolute URL, using ``proto`` when it has no scheme.

        A bare ``example.com/path`` parses with the host inside ``path``; the
        host is moved into ``netloc``, giving ``proto://example.com/path``.
        Only the scheme separator is affected, so slashes inside the path are
        never rewritten.
        """
        parsed = urlparse(target)
        if not parsed.scheme:
            parsed = parsed._replace(scheme=proto)
            if not parsed.netloc:
                netloc, sep, rest = parsed.path.lstrip("/").partition("/")
                parsed = parsed._replace(netloc=netloc, path=sep + rest)
        return urlunparse(parsed)

    def _reject(self, status):
        """Finish the response with ``status`` and an empty body."""
        self.set_status(status)
        self.write("")

    async def get(self):
        """Redirect ``/<code>`` to its target, or answer 404 if it is unknown or expired."""
        source = self.request.path[1:]
        url_info = await self.get_redirect_url(source)
        if not await self.get_is_expired(url_info):
            self.redirect(url_info["target"])
            return
        self._reject(404)

    async def post(self):
        """Create a short link.

        Arguments: ``ts`` (Unix time, within ``MAX_TS_SKEW`` seconds of now),
        ``sign`` (see ``gen_sign``), ``target`` (URL to shorten), optional
        ``https`` ("1" selects https when ``target`` has no scheme), and
        optional ``expire`` (Unix timestamp).

        Responds with 404 for a stale/invalid ``ts``, missing ``target`` or bad
        ``expire``; 403 for a bad ``sign``; otherwise writes the full short URL.
        Each rejection returns immediately, so nothing is stored for a request
        that fails validation.
        """
        timestamp = self.get_argument("ts", None)
        ts = self._parse_int(timestamp)
        if ts is None or abs(ts - round(time.time())) > self.MAX_TS_SKEW:
            self._reject(404)
            return

        sign = self.get_argument("sign", None)
        expected = await self.gen_sign(timestamp)
        # Constant-time comparison; bytes avoid compare_digest's ASCII-only str rule.
        if sign is None or not hmac.compare_digest(
            sign.encode("utf-8"), expected.encode("utf-8")
        ):
            self._reject(403)
            return

        target = self.get_argument("target", None)
        if target is None:
            self._reject(404)
            return

        expire_timestamp = self.get_argument("expire", None)
        if expire_timestamp is not None and not self._is_valid_expire(expire_timestamp):
            self._reject(404)
            return

        proto = "https" if self.get_argument("https", "0") == "1" else "http"
        final_url = self._normalize_target(target, proto)

        sources = await self.get_sources()
        new_source = await self.insert_url(final_url, sources, expire_timestamp)
        # Public short URL (the old code wrote the *database* host here).
        self.write(f"{server_url}/{new_source}")


class Application(tornado.web.Application):
    """Tornado application that routes every path to ``UrlHandler``."""

    def __init__(self):
        # Tornado matches route patterns against the request path only, so the
        # former `server_url + "/.*"` entry could never match and was removed.
        handlers = [
            ("/.*", UrlHandler),
        ]
        super().__init__(handlers)


if __name__ == '__main__':
    app = Application()
    app.listen(server_port)
    tornado.ioloop.IOLoop.current().start()