diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..6e43647
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..7ef8e8d
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/short_url.iml b/.idea/short_url.iml
new file mode 100644
index 0000000..5dca121
--- /dev/null
+++ b/.idea/short_url.iml
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/client.py b/client.py
new file mode 100644
index 0000000..9ae5baa
--- /dev/null
+++ b/client.py
@@ -0,0 +1,28 @@
+import hashlib
+import base64
+import hmac
+import requests
+import time
+
+
+def gen_sign(timestamp, secret):
+ string_to_sign = '{}\n{}'.format(timestamp, secret)
+ hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest()
+ sign = base64.b64encode(hmac_code).decode('utf-8')
+ return sign
+
+
+if __name__ == '__main__':
+ secret = "test" # the same secret in the config.yaml
+ url = "http://localhost:8000/" # your server address
+ ts = round(time.time())
+ data = {
+ "ts": ts,
+ "sign": gen_sign(ts, secret),
+ "target": "",
+ "https": 1,
+ "expire": ""
+ }
+ r = requests.post(url, data=data)
+ print(r.text)
+
diff --git a/config.yaml b/config.yaml
new file mode 100644
index 0000000..a02ef6c
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,17 @@
+database:
+ host:
+ port:
+ user:
+ password:
+ database:
+
+sign:
+ secret:
+
+server:
+ host:
+ port:
+ protocol:
+ ssl:
+ cert:
+ key:
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..36a2d81
--- /dev/null
+++ b/main.py
@@ -0,0 +1,187 @@
+import aiomysql
+import tornado.ioloop
+import tornado.web
+import datetime
+import random
+import string
+import time
+import hashlib
+import base64
+import hmac
+from urllib.parse import urlparse, urlunparse
+import yaml
+
+with open("config.yaml") as f:
+ config = yaml.safe_load(f)
+host = config['database']['host']
+port = int(config['database']['port'])
+user = config['database']['user']
+password = config['database']['password']
+database = config['database']['database']
+secret = config['sign']['secret']
+server_port = config['server']['port']
+server_url = f"{config['server']['protocol']}://{config['server']['host']}:{server_port}"
+
+
+class UrlHandler(tornado.web.RequestHandler):
+ @staticmethod
+ async def get_sources():
+ conn = await aiomysql.connect(
+ host=host,
+ port=port,
+ user=user,
+ password=password,
+ db=database
+ )
+ async with conn.cursor() as cursor:
+ await cursor.execute('SELECT source FROM urlList')
+ ret = await cursor.fetchall()
+ conn.close()
+ return tuple([i[0] for i in ret])
+
+ @staticmethod
+ async def get_redirect_url(code):
+ conn = await aiomysql.connect(
+ host=host,
+ port=port,
+ user=user,
+ password=password,
+ db=database
+ )
+ sql = f"SELECT id, source, target, createTime, expireTime From urlList where source = '{code}';"
+ async with conn.cursor() as cursor:
+ await cursor.execute(sql)
+ ret = await cursor.fetchall()
+ conn.close()
+ if ret == ():
+ return {}
+ url_info = {
+ "id": ret[0][0],
+ "source": ret[0][1],
+ "target": ret[0][2],
+ "createTime": ret[0][3],
+ "expireTime": ret[0][4]
+ }
+ return url_info
+
+ @staticmethod
+ async def get_is_expired(url_info):
+ if url_info == {}:
+ return True
+ expire_time = url_info.get("expireTime")
+ if expire_time is None or expire_time >= datetime.datetime.now():
+ return False
+ sql = f"Delete From urlList where id = {url_info.get('id')};"
+ conn = await aiomysql.connect(
+ host=host,
+ port=port,
+ user=user,
+ password=password,
+ db=database
+ )
+ async with conn.cursor() as cursor:
+ await cursor.execute(sql)
+ await conn.commit()
+ conn.close()
+ return True
+
+ async def insert_url(self, url, sources, expire_time=None):
+ sql = f"SELECT id, source, target, createTime, expireTime From urlList where target = '{url}';"
+ conn = await aiomysql.connect(
+ host=host,
+ port=port,
+ user=user,
+ password=password,
+ db=database
+ )
+ async with conn.cursor() as cursor:
+ await cursor.execute(sql)
+ ret = await cursor.fetchall()
+ if ret != ():
+ url_info = {
+ "id": ret[0][0],
+ "source": ret[0][1],
+ "target": ret[0][2],
+ "createTime": ret[0][3],
+ "expireTime": ret[0][4]
+ }
+ if not await self.get_is_expired(url_info) and expire_time is not None\
+ and ret[0][4] == datetime.datetime.fromtimestamp(int(expire_time)):
+ return url_info.get("source")
+ new_source = ""
+ while new_source == "" or new_source in sources:
+ new_source = "".join(random.sample(string.ascii_letters, 7))
+ if expire_time is None:
+ sql = "INSERT INTO urlList (`source`, `target`, `createTime`) VALUES " \
+ f"('{new_source}', '{url}', '{datetime.datetime.now()}');"
+ async with conn.cursor() as cursor:
+ await cursor.execute(sql)
+ await conn.commit()
+ conn.close()
+ return new_source
+ expire = datetime.datetime.fromtimestamp(int(expire_time))
+ if ret != () and expire == ret[0][4]:
+ return ret[0][1]
+ sql = "INSERT INTO urlList (`source`, `target`, `createTime`, `expireTime`) VALUES " \
+ f"('{new_source}', '{url}', '{datetime.datetime.now()}', '{expire}');"
+ async with conn.cursor() as cursor:
+ await cursor.execute(sql)
+ await conn.commit()
+ conn.close()
+ return new_source
+
+ @staticmethod
+ async def gen_sign(timestamp):
+ string_to_sign = '{}\n{}'.format(timestamp, secret)
+ hmac_code = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest()
+ sign = base64.b64encode(hmac_code).decode('utf-8')
+ return sign
+
+ async def get(self):
+ source = self.request.path[1:]
+ url_info = await self.get_redirect_url(source)
+ if not await self.get_is_expired(url_info):
+ self.redirect(url_info["target"])
+ return
+ self.set_status(404)
+ self.write("")
+
+ async def post(self):
+ timestamp = self.get_argument("ts", None)
+ if abs(int(timestamp) - round(time.time())) > 600:
+ self.set_status(404)
+ self.write("")
+ sign = self.get_argument("sign", None)
+ if sign is None or sign != await self.gen_sign(timestamp):
+ self.set_status(403)
+ self.write("")
+ target = self.get_argument("target", None)
+ if target is None:
+ self.set_status(404)
+ self.write("")
+ parsed_url = urlparse(target)
+ https = self.get_argument("https", "0")
+ proto = "https" if https == "1" else "http"
+ if not parsed_url.scheme:
+ parsed_url = parsed_url._replace(scheme=proto)
+ final_url = urlunparse(parsed_url).replace("///", "//")
+ sources = await self.get_sources()
+ expire_timestamp = self.get_argument("expire", None)
+ url_info = await self.insert_url(final_url, sources, expire_timestamp)
+ self.write(host + url_info)
+
+
+class Application(tornado.web.Application):
+ def __init__(self):
+ handlers = [
+ ("/.*", UrlHandler),
+ (server_url + "/.*", UrlHandler)
+ ]
+ tornado.web.Application.__init__(self, handlers)
+
+
+if __name__ == '__main__':
+ app = Application()
+ app.listen(server_port)
+ tornado.ioloop.IOLoop.current().start()
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..0af264e
Binary files /dev/null and b/requirements.txt differ