【通用消息通知服务】0x3 - 发送我们第一条消息
项目地址: A generic message notification system [GitHub]
实现接收/发送 WebSocket 消息
WebSocket Connection Pool
import asyncio
from asyncio.queues import Queue
from asyncio.queues import QueueEmpty
from contextlib import suppress
from typing import Any

import async_timeout
import orjson
from sanic.log import logger
from ulid import ULID

from common.depend import Dependency

# Heartbeat frames: the pool sends PING and expects the client to answer PONG.
PING = "#ping"
PONG = "#pong"

# Suffixes of the per-connection task names registered by ``add_connection``.
_CONNECTION_TASKS = ("send_task", "recv_task", "notify_task", "is_alive_task")


class WebsocketConnectionPoolDependency(
    Dependency, dependency_name="WebsocketPool", dependency_alias="ws_pool"
):
    """Pool of live websocket connections, registered as the ``ws_pool`` dependency.

    Each connection added to the pool is tagged with a ULID string id (stored
    on the connection object as ``_id``) and gets a send queue, a receive
    queue and four app tasks: sender, receiver, listener notifier and
    ping/pong heartbeat. ``is_alive``, ``remove_connection``, ``send`` and
    ``wait_closed`` accept either the id or the tagged connection object.
    """

    def __init__(self, app) -> None:
        super().__init__(app)
        self.lock = asyncio.Lock()
        self.connections = {}  # connection id -> websocket connection
        self.send_queues = {}  # connection id -> Queue of outgoing messages
        self.recv_queues = {}  # connection id -> Queue of incoming messages
        self.close_callbacks = {}  # connection id -> callbacks run on removal
        self.listeners = {}  # connection id -> {listener id: handler}
        self.closed_events = {}  # connection id -> Event set once removed

    @staticmethod
    def _resolve_id(connection: Any) -> Any:
        """Return ``connection._id`` if the object has one, else ``connection``."""
        return getattr(connection, "_id", connection)

    def _gen_id(self) -> str:
        """Return a fresh ULID string used for connection and listener ids."""
        return str(ULID())

    async def add_connection(self, connection) -> str:
        """Register ``connection``, start its background tasks and return its id."""
        async with self.lock:
            connection_id = self._gen_id()
            # Tag the connection before any task runs so they can resolve its id.
            setattr(connection, "_id", connection_id)
            self.connections[connection_id] = connection
            self.send_queues[connection_id] = Queue()
            self.recv_queues[connection_id] = Queue()
            self.closed_events[connection_id] = asyncio.Event()
            self.app.add_task(
                self.send_task(self.send_queues[connection_id], connection),
                name=f"websocket_{connection_id}_send_task",
            )
            self.app.add_task(
                self.recv_task(self.recv_queues[connection_id], connection),
                name=f"websocket_{connection_id}_recv_task",
            )
            self.app.add_task(
                self.notify_task(connection_id),
                name=f"websocket_{connection_id}_notify_task",
            )
            self.app.add_task(
                self.is_alive_task(connection_id),
                name=f"websocket_{connection_id}_is_alive_task",
            )
            return connection_id

    def get_connection(self, connection_id: str):
        """Return the connection registered under ``connection_id``, or None."""
        return self.connections.get(connection_id)

    async def add_listener(self, connection_id, handler) -> str:
        """Register ``await handler(connection_id, data)`` for incoming messages.

        Returns the listener id to pass to ``remove_listener``.
        """
        async with self.lock:
            listener_id = self._gen_id()
            self.listeners.setdefault(connection_id, {})[listener_id] = handler
            return listener_id

    async def remove_listener(self, connection_id, listener_id):
        """Unregister a listener; unknown ids are ignored."""
        async with self.lock:
            self.listeners.get(connection_id, {}).pop(listener_id, None)

    async def add_close_callback(self, connection_id, callback):
        """Schedule ``callback(connection_id)`` to run when the connection is removed."""
        async with self.lock:
            self.close_callbacks.setdefault(connection_id, []).append(callback)

    def is_alive(self, connection_id: str) -> bool:
        """Return True while the connection (id or tagged object) is in the pool."""
        return self._resolve_id(connection_id) in self.connections

    async def remove_connection(self, connection: Any):
        """Cancel the connection's tasks, drop its state and run close callbacks.

        Safe to call repeatedly; calls for an unknown/removed id are no-ops.
        """
        connection_id = self._resolve_id(connection)
        if connection_id not in self.connections:
            # removed already
            return
        async with self.lock:
            # Re-check under the lock: a concurrent caller may have finished
            # the removal while this one was waiting.
            if connection_id not in self.connections:
                return
            logger.info(f"remove connection: {connection_id}")
            for suffix in _CONNECTION_TASKS:
                with suppress(Exception):
                    await self.app.cancel_task(f"websocket_{connection_id}_{suffix}")
            self.send_queues.pop(connection_id, None)
            self.recv_queues.pop(connection_id, None)
            self.listeners.pop(connection_id, None)
            if connection_id in self.close_callbacks:
                await self.do_close_callbacks(connection_id)
                del self.close_callbacks[connection_id]
            self.connections.pop(connection_id, None)
            # Wake up anyone blocked in wait_closed().
            closed = self.closed_events.pop(connection_id, None)
            if closed is not None:
                closed.set()

    async def do_close_callbacks(self, connection_id):
        """Schedule every registered close callback as an app task."""
        for cb in self.close_callbacks.get(connection_id, []):
            self.app.add_task(cb(connection_id))

    async def prepare(self):
        self.is_prepared = True
        logger.info("dependency:WebsocketPool is prepared")
        return self.is_prepared

    async def check(self):
        return True

    async def send_task(self, queue, connection):
        """Forward queued messages to ``connection`` until it fails or is removed.

        bytes/str/int are sent as-is; anything else is JSON-encoded with orjson.
        """
        while self.is_alive(connection):
            # Block on the queue instead of polling; removal cancels this task.
            data = await queue.get()
            try:
                if isinstance(data, (bytes, str, int)):
                    await connection.send(data)
                else:
                    await connection.send(orjson.dumps(data).decode())
            except Exception as err:
                logger.warning(
                    f"stop sending to connection: {self._resolve_id(connection)}: {err!r}"
                )
                break
            finally:
                queue.task_done()

    async def recv_task(self, queue, connection):
        """Push every message received from ``connection`` onto ``queue``."""
        while self.is_alive(connection):
            try:
                data = await connection.recv()
            except Exception as err:
                logger.info(
                    f"stop receiving from connection: {self._resolve_id(connection)}: {err!r}"
                )
                break
            await queue.put(data)
            logger.info(f"recv message: {data} from connection: {connection._id}")

    async def notify_task(self, connection_id):
        """Dispatch each received message to all listeners of the connection."""
        while self.is_alive(connection_id):
            queue = self.recv_queues.get(connection_id)
            if queue is None:
                break
            data = await queue.get()
            logger.info(f"notify connection: {connection_id}'s listeners")
            # Iterate over a snapshot: listeners are added/removed concurrently
            # (e.g. by is_alive_task) while we await each handler.
            for listener in list(self.listeners.get(connection_id, {}).values()):
                try:
                    await listener(connection_id, data)
                except Exception as err:
                    # One failing listener must not starve the others.
                    logger.warning(
                        f"listener of connection: {connection_id} failed: {err!r}"
                    )

    async def is_alive_task(self, connection_id: str):
        """Heartbeat: ping the client and drop the connection on a missed pong.

        Each round sends PING and waits up to ``WEBSOCKET_PING_TIMEOUT``
        seconds for PONG. On success it sleeps ``WEBSOCKET_PING_INTERVAL``
        seconds; otherwise it schedules ``remove_connection`` and stops.
        """
        connection_id = self._resolve_id(connection_id)
        got_pong = asyncio.Event()

        async def wait_pong(_connection_id, data):
            if data == PONG:
                got_pong.set()

        while self.is_alive(connection_id):
            got_pong.clear()
            # Listen before pinging so an immediate pong cannot be missed.
            listener_id = await self.add_listener(connection_id, wait_pong)
            await self.send(connection_id, PING)
            with suppress(asyncio.TimeoutError):
                async with async_timeout.timeout(self.app.config.WEBSOCKET_PING_TIMEOUT):
                    await got_pong.wait()
            await self.remove_listener(connection_id, listener_id)
            if not got_pong.is_set():
                # No pong in time: treat the connection as dead. Removal runs in
                # its own task because remove_connection cancels this task.
                self.app.add_task(self.remove_connection(connection_id))
                return
            # Pong received: the connection is alive; wait before the next ping.
            await asyncio.sleep(self.app.config.WEBSOCKET_PING_INTERVAL)

    async def wait_closed(self, connection_id: str):
        """Block until the connection has been removed from the pool.

        Returns immediately if it is not (or no longer) alive. Always returns False.
        """
        connection_id = self._resolve_id(connection_id)
        closed = self.closed_events.get(connection_id)
        if closed is not None and self.is_alive(connection_id):
            await closed.wait()
        return False

    async def send(self, connection_id: str, data: Any) -> bool:
        """Queue ``data`` for the connection; return False if it is not alive."""
        connection_id = self._resolve_id(connection_id)
        queue = self.send_queues.get(connection_id)
        if queue is None or not self.is_alive(connection_id):
            return False
        await queue.put(data)
        return True
WebSocket Provider
from typing import Dict
from typing import List
from typing import Union

from pydantic import BaseModel
from pydantic import field_serializer
from sanic.log import logger

from apps.message.common.constants import MessageProviderType
from apps.message.common.constants import MessageStatus
from apps.message.common.interfaces import SendResult
from apps.message.providers.base import MessageProviderModel
from apps.message.validators.types import EndpointExID
from apps.message.validators.types import EndpointTag
from apps.message.validators.types import ETag
from apps.message.validators.types import ExID
from utils import get_app


class WebsocketMessageProviderModel(MessageProviderModel):
    """Message provider that delivers messages through the app's websocket pool."""

    class Info:
        name = "websocket"
        description = "Bio-Channel Communication"
        type = MessageProviderType.WEBSOCKET

    class Capability:
        is_enabled = True
        can_send = True

    class Message(BaseModel):
        # Targets: endpoint tags, endpoint external ids, or raw connection ids.
        connections: List[Union[EndpointTag, EndpointExID, str]]
        action: str
        payload: Union[List, Dict, str, bytes]

        @field_serializer("connections")
        def serialize_connections(self, connections):
            # Stringify and de-duplicate targets when the model is dumped.
            return list(set(map(str, connections)))

    async def send(self, provider_id, message: Message) -> SendResult:
        """Send ``message`` to every alive websocket its targets resolve to.

        ETag targets expand to the ``websockets`` of every decoded endpoint,
        ExID targets to the ``websockets`` of the single decoded endpoint, and
        anything else is used as a connection id directly.

        Returns a SUCCEEDED result if at least one connection accepted the
        message, otherwise FAILED.
        """
        app = get_app()
        websocket_pool = app.ctx.ws_pool

        connections = []
        for connection in message.connections:
            if isinstance(connection, ETag):
                connections.extend(
                    [w for c in await connection.decode() for w in c.get("websockets", [])]
                )
            elif isinstance(connection, ExID):
                endpoint = await connection.decode()
                if endpoint:
                    connections.extend(endpoint.get("websockets", []))
            else:
                connections.append(connection)

        # De-duplicate and keep only connections that are still open.
        # (Each candidate must be checked itself, not the last loop variable.)
        alive = {c for c in connections if websocket_pool.is_alive(c)}

        # The payload is identical for every recipient: serialize it once.
        data = message.model_dump_json(exclude={"connections"})
        sent_list = set()
        for connection in alive:
            if await websocket_pool.send(connection, data=data):
                sent_list.add(connection)

        status = MessageStatus.SUCCEEDED if sent_list else MessageStatus.FAILED
        return SendResult(provider_id=provider_id, message=message, status=status)
WebSocket 接口
@app.websocket("/websocket")
async def handle_websocket(request, ws):
    """Serve one websocket client for the lifetime of its connection.

    Registers ``ws`` with the pool and attaches the endpoint-registration
    listener and the endpoint-unregistration close callback. It then greets
    the client with its connection id and blocks until the pool reports the
    connection closed.
    """
    from apps.endpoint.listeners import register_websocket_endpoint
    from apps.endpoint.listeners import unregister_websocket_endpoint

    con_id = None
    try:
        pool = request.app.ctx.ws_pool
        con_id = await pool.add_connection(ws)
        logger.info(f"new connection connected -> {con_id}")
        await pool.add_listener(con_id, register_websocket_endpoint)
        await pool.add_close_callback(con_id, unregister_websocket_endpoint)
        greeting = {"action": "on.connect", "payload": {"connection_id": con_id}}
        await pool.send(con_id, data=greeting)
        # Wait until the connection is closed.
        await pool.wait_closed(con_id)
    finally:
        # If the client drops the connection, handle_websocket is torn down
        # immediately, so cleanup has to live in `finally`; it is scheduled
        # as a separate app task rather than awaited here.
        request.app.add_task(request.app.ctx.ws_pool.remove_connection(con_id))