short_url/main.py

188 lines
6.3 KiB
Python
Raw Normal View History

2023-04-16 21:47:04 +08:00
import aiomysql
import tornado.ioloop
import tornado.web
import datetime
import random
import string
import time
import hashlib
import base64
import hmac
from urllib.parse import urlparse, urlunparse
import yaml
# Load runtime settings from config.yaml.  The file is decoded as UTF-8
# explicitly (YAML streams are Unicode) so non-ASCII values such as
# passwords do not depend on the platform's default locale encoding.
with open("config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# MySQL connection settings used by every UrlHandler database helper.
host = config['database']['host']
port = int(config['database']['port'])
user = config['database']['user']
password = config['database']['password']
database = config['database']['database']

# Shared secret mixed into the request signature checked on POST.
secret = config['sign']['secret']

# HTTP listen port (normalised to int, like the database port) and the
# public base URL that short codes are appended to in POST responses.
server_port = int(config['server']['port'])
server_url = f"{config['server']['protocol']}://{config['server']['host']}:{server_port}"
class UrlHandler(tornado.web.RequestHandler):
    """Serve short-link lookups (GET) and signed short-link creation (POST).

    Rows live in the ``urlList`` table with the columns
    ``id, source, target, createTime, expireTime``.  Every database helper
    opens its own aiomysql connection from the module-level settings and
    closes it in a ``finally`` block, so neither an early return nor a
    failing query leaks the connection.  All SQL is parameterised because
    ``source`` (from the request path) and ``target`` (from the POST body)
    are client-supplied.
    """

    # Short codes are CODE_LENGTH distinct characters drawn from CODE_ALPHABET.
    CODE_ALPHABET = string.ascii_letters
    CODE_LENGTH = 7
    # Maximum allowed difference, in seconds, between the client's ``ts``
    # argument and the server clock.
    MAX_CLOCK_SKEW = 600

    # Column order of every ``SELECT id, source, target, ...`` below.
    _COLUMNS = ("id", "source", "target", "createTime", "expireTime")

    @staticmethod
    async def _connect():
        """Open a new aiomysql connection using the module-level settings."""
        return await aiomysql.connect(
            host=host,
            port=port,
            user=user,
            password=password,
            db=database,
        )

    @staticmethod
    def _row_to_info(row):
        """Map a ``(id, source, target, createTime, expireTime)`` row to a dict."""
        return dict(zip(UrlHandler._COLUMNS, row))

    @staticmethod
    async def get_sources():
        """Return every short code currently stored in ``urlList`` as a tuple."""
        conn = await UrlHandler._connect()
        try:
            async with conn.cursor() as cursor:
                await cursor.execute('SELECT source FROM urlList')
                rows = await cursor.fetchall()
        finally:
            conn.close()
        return tuple(row[0] for row in rows)

    @staticmethod
    async def get_redirect_url(code):
        """Look up the row whose short code is *code*.

        :param code: Short code taken from the request path.
        :returns: A dict keyed by ``id``, ``source``, ``target``,
            ``createTime`` and ``expireTime`` for the first matching row,
            or ``{}`` when no row matches.
        """
        conn = await UrlHandler._connect()
        try:
            async with conn.cursor() as cursor:
                # Parameterised: *code* comes straight from the URL path.
                await cursor.execute(
                    "SELECT id, source, target, createTime, expireTime "
                    "FROM urlList WHERE source = %s",
                    (code,),
                )
                rows = await cursor.fetchall()
        finally:
            conn.close()
        if not rows:
            return {}
        return UrlHandler._row_to_info(rows[0])

    @staticmethod
    async def get_is_expired(url_info):
        """Return whether *url_info* is unusable, deleting it if it expired.

        An empty (or ``None``) *url_info* counts as expired.  A row whose
        ``expireTime`` is ``None`` or not yet in the past is live.  A row
        whose ``expireTime`` has passed is deleted from ``urlList`` by id
        and reported as expired.
        """
        if not url_info:
            return True
        expire_time = url_info.get("expireTime")
        if expire_time is None or expire_time >= datetime.datetime.now():
            return False
        conn = await UrlHandler._connect()
        try:
            async with conn.cursor() as cursor:
                await cursor.execute(
                    "DELETE FROM urlList WHERE id = %s",
                    (url_info.get("id"),),
                )
            await conn.commit()
        finally:
            conn.close()
        return True

    async def insert_url(self, url, sources, expire_time=None):
        """Return a short code for *url*, reusing an existing one when possible.

        :param url: Fully-qualified target URL.
        :param sources: Iterable of short codes already in use; a new code
            is drawn until it is not among them.
        :param expire_time: Optional expiry as a Unix timestamp (anything
            ``int()`` accepts); ``None`` creates a non-expiring link.
        :returns: The short code.
        :raises ValueError: If *expire_time* is not an integer timestamp.

        An existing row for the same target is reused only when it has not
        expired and its ``expireTime`` equals the requested one (two
        ``None`` values count as equal).  ``get_is_expired`` is consulted
        first, so an expired row for this target is deleted as a side
        effect.
        """
        expire = None
        if expire_time is not None:
            expire = datetime.datetime.fromtimestamp(int(expire_time))

        conn = await self._connect()
        try:
            async with conn.cursor() as cursor:
                await cursor.execute(
                    "SELECT id, source, target, createTime, expireTime "
                    "FROM urlList WHERE target = %s",
                    (url,),
                )
                rows = await cursor.fetchall()
            if rows:
                url_info = self._row_to_info(rows[0])
                if (not await self.get_is_expired(url_info)
                        and url_info["expireTime"] == expire):
                    return url_info["source"]

            # Set membership instead of a linear scan over the tuple.
            taken = set(sources)
            new_source = ""
            while not new_source or new_source in taken:
                new_source = "".join(
                    random.sample(self.CODE_ALPHABET, self.CODE_LENGTH))

            async with conn.cursor() as cursor:
                if expire is None:
                    # Leave expireTime to the column default.
                    await cursor.execute(
                        "INSERT INTO urlList (`source`, `target`, `createTime`) "
                        "VALUES (%s, %s, %s)",
                        (new_source, url, datetime.datetime.now()),
                    )
                else:
                    await cursor.execute(
                        "INSERT INTO urlList "
                        "(`source`, `target`, `createTime`, `expireTime`) "
                        "VALUES (%s, %s, %s, %s)",
                        (new_source, url, datetime.datetime.now(), expire),
                    )
            await conn.commit()
        finally:
            conn.close()
        return new_source

    @staticmethod
    async def gen_sign(timestamp):
        """Return the base64 HMAC-SHA256 signature expected for *timestamp*.

        NOTE(review): ``"<timestamp>\\n<secret>"`` is used as the HMAC *key*
        with an empty message.  That is unusual, but clients presumably
        replicate this exact construction, so it is kept unchanged.
        """
        string_to_sign = '{}\n{}'.format(timestamp, secret)
        hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest()
        return base64.b64encode(hmac_code).decode('utf-8')

    def _reject(self, status):
        """Respond with *status* and an empty body."""
        self.set_status(status)
        self.write("")

    async def get(self):
        """Redirect ``/<code>`` to its target, or respond 404 if missing/expired."""
        source = self.request.path[1:]
        url_info = await self.get_redirect_url(source)
        if not await self.get_is_expired(url_info):
            self.redirect(url_info["target"])
            return
        self._reject(404)

    async def post(self):
        """Create (or reuse) a short link for the ``target`` argument.

        Arguments:
          ``ts``     Unix timestamp, within MAX_CLOCK_SKEW seconds of now;
                     missing or non-integer -> 404 (same as a stale one).
          ``sign``   ``gen_sign(ts)``; missing or wrong -> 403.
          ``target`` URL to shorten; missing -> 404.  When it has no
                     scheme, ``https`` ("1" -> https, otherwise http)
                     picks one.
          ``expire`` optional Unix timestamp; invalid -> 400.

        On success the body is ``<server_url>/<code>``.
        """
        timestamp = self.get_argument("ts", None)
        try:
            request_time = int(timestamp)
        except (TypeError, ValueError):
            request_time = None
        if (request_time is None
                or abs(request_time - round(time.time())) > self.MAX_CLOCK_SKEW):
            self._reject(404)
            return

        sign = self.get_argument("sign", None)
        expected = await self.gen_sign(timestamp)
        # Constant-time comparison; bytes so non-ASCII input cannot raise.
        if sign is None or not hmac.compare_digest(
                sign.encode("utf-8"), expected.encode("utf-8")):
            self._reject(403)
            return

        target = self.get_argument("target", None)
        if target is None:
            self._reject(404)
            return

        expire_timestamp = self.get_argument("expire", None)
        if expire_timestamp is not None:
            try:
                datetime.datetime.fromtimestamp(int(expire_timestamp))
            except (ValueError, OverflowError, OSError):
                self._reject(400)
                return

        parsed_url = urlparse(target)
        proto = "https" if self.get_argument("https", "0") == "1" else "http"
        if not parsed_url.scheme:
            parsed_url = parsed_url._replace(scheme=proto)
        # A scheme-less target parses as a path, so urlunparse yields
        # "scheme:///host/..."; collapse the extra slash.
        final_url = urlunparse(parsed_url).replace("///", "//")

        sources = await self.get_sources()
        code = await self.insert_url(final_url, sources, expire_timestamp)
        self.write(f"{server_url}/{code}")
class Application(tornado.web.Application):
    """Tornado application that routes every request path to ``UrlHandler``."""

    def __init__(self):
        # "/<code>" lookups (GET) and link creation (POST) share one handler.
        # Tornado rules match against the request *path*, so the former
        # ``server_url + "/.*"`` rule (an absolute URL) could never match;
        # it was dead and has been dropped.
        handlers = [
            (r"/.*", UrlHandler),
        ]
        super().__init__(handlers)
if __name__ == '__main__':
    # Build the routes, bind the configured HTTP port and run the event
    # loop until the process is stopped.
    application = Application()
    application.listen(server_port)
    loop = tornado.ioloop.IOLoop.current()
    loop.start()