1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
| from quart import Quart as Flask, websocket, Websocket
import quart as flask
import asyncio
# `Flask` is Quart here (aliased by `from quart import Quart as Flask`).
app = Flask(import_name=__name__)
class Manager:
    """In-memory broadcast hub for websocket connections.

    Every joined connection owns an ``asyncio.Queue``. ``handle_recv`` reads
    from one connection and broadcasts each message to every registered
    queue (the sender's own included); ``handle_send`` drains a connection's
    queue back out to that connection.
    """

    def __init__(self):
        # Instance state. These used to be class attributes, which made the
        # lock and the registry silently shared by every Manager instance.
        self.map_lock = asyncio.Lock()
        # websocket -> outbound queue drained by handle_send for that socket.
        self.conn_map: dict["Websocket", asyncio.Queue] = {}

    async def broadcast(self, data):
        """Put *data* on the outbound queue of every registered connection.

        A failure for one queue is printed and does not stop delivery to the
        remaining queues.
        """
        async with self.map_lock:
            for queue in self.conn_map.values():
                try:
                    await queue.put(data)
                except Exception as e:
                    print(f"广播数据到队列时出错: {e}")

    async def join(self, key: "Websocket"):
        """Register *key* with a fresh, unbounded outbound queue."""
        buf = asyncio.Queue()
        async with self.map_lock:
            self.conn_map[key] = buf

    async def leave(self, key: "Websocket"):
        """Unregister *key*. This is a no-op if *key* is not registered."""
        async with self.map_lock:
            self.conn_map.pop(key, None)

    async def handle_recv(self, websocket: "Websocket"):
        """Receive from *websocket* forever, broadcasting every message.

        Returns, after printing the error, if receiving or broadcasting
        raises. Cancellation is re-raised to the caller.
        """
        try:
            while True:
                data = await websocket.receive()
                await self.broadcast(data)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            # Exception rather than BaseException, so KeyboardInterrupt and
            # SystemExit are not swallowed here.
            print(f"handle_recv err: {e}")

    async def handle_send(self, websocket: "Websocket"):
        """Forward messages queued for *websocket* to it, forever.

        Returns immediately if *websocket* has not joined. Returns, after
        printing the error, if sending fails. Cancellation is re-raised.
        """
        queue = self.conn_map.get(websocket)
        if queue is None:
            # Previously this fell through to `None.get()` and crashed.
            print("handle_send: websocket has not joined")
            return
        try:
            while True:
                data = await queue.get()
                await websocket.send(data)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            print(f"handle_send 出错: {e}")
# Single hub instance used by the /ws handler for every connection.
manager: Manager = Manager()
@app.websocket('/ws')
async def ws():
    """Serve one /ws connection through the shared broadcast hub.

    The connection joins ``manager``. Two tasks then pump data: one
    broadcasts what this client sends (``manager.handle_recv``), the other
    delivers this connection's queue back to the client
    (``manager.handle_send``). Both tasks are stopped and the connection
    leaves the hub however the handler exits.
    """
    # `websocket` is one module-level context-local proxy shared by every
    # connection; register the concrete per-connection object instead.
    key = websocket._get_current_object()
    await manager.join(key)
    tasks = [
        asyncio.create_task(manager.handle_recv(key)),
        asyncio.create_task(manager.handle_send(key)),
    ]
    try:
        # Return as soon as either direction stops. Waiting for both (as
        # gather does) would leave this handler stuck on the surviving pump.
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    except asyncio.CancelledError:
        # Treated as a client disconnect; log it but let cancellation propagate.
        print('disconnect')
        raise
    except Exception as e:
        print(f"WebSocket 处理出错: {e}", type(e))
    finally:
        # Unlike gather, wait() does not cancel its tasks when this coroutine
        # is cancelled, so stop both pumps explicitly.
        for task in tasks:
            task.cancel()
        await manager.leave(key)
        await asyncio.gather(*tasks, return_exceptions=True)
def _serve():
    """Start the app via ``app.run`` with debug mode disabled."""
    app.run(debug=False)


if __name__ == '__main__':
    _serve()
|