File size: 6,053 Bytes
7630510
bf292d9
a3a1b05
ee5adc7
7630510
 
 
 
 
bf292d9
66c4f69
7630510
a9000ae
66c4f69
a9000ae
 
 
 
 
 
 
 
 
 
bbfbcdd
66c4f69
a3a1b05
 
 
7630510
a3a1b05
 
 
 
 
7630510
a3a1b05
 
 
bf292d9
 
 
 
 
a3a1b05
 
0bf22fe
bf292d9
7630510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bf22fe
66c4f69
 
05be9a1
a3a1b05
05be9a1
7630510
 
 
 
 
 
 
 
 
 
01785f3
05be9a1
01785f3
 
 
7630510
01785f3
7630510
 
 
 
 
 
 
 
 
66c4f69
 
 
7630510
 
66c4f69
7630510
 
66c4f69
7630510
 
 
 
 
bf292d9
 
 
 
66c4f69
 
 
 
 
7630510
66c4f69
a9000ae
7630510
 
 
 
 
 
 
 
 
bf292d9
0bf22fe
 
 
 
 
 
 
bf292d9
 
7630510
0bf22fe
bf292d9
 
7630510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# rabbit_base.py
from typing import Callable, Dict, List, Optional
from urllib.parse import urlsplit, unquote
import ssl
import json
import logging
import aio_pika

from config import settings

# Module-scoped logger used by the helpers and RabbitBase below.
logger = logging.getLogger(__name__)

# Signature of a callable that maps an exchange name to its exchange-type string.
ExchangeResolver = Callable[[str], str]


def _normalize_exchange_type(val: str) -> aio_pika.ExchangeType:
    """Map a configured exchange-type value onto ``aio_pika.ExchangeType``.

    Accepted inputs:
    - an ``ExchangeType`` member, which is returned as-is;
    - a member name such as ``"TOPIC"`` or ``"x_delayed_message"``;
    - a wire value such as ``"topic"`` or ``"x-delayed-message"``.

    Strings are matched case-insensitively, ignoring surrounding whitespace.
    Anything else, including non-string input such as ``None``, falls back to
    ``ExchangeType.TOPIC``. A warning is logged so that a configuration typo
    does not silently change routing. Never raises.
    """
    exchange_types = aio_pika.ExchangeType
    if isinstance(val, exchange_types):
        return val
    if isinstance(val, str):
        text = val.strip()
        # Look up by name through __members__ (not hasattr/getattr), so only
        # real enum members can match, never other attributes of the class.
        # Mapping "-" to "_" lets wire values like "x-delayed-message" match too.
        member = exchange_types.__members__.get(text.upper().replace("-", "_"))
        if member is not None:
            return member
        try:
            return exchange_types(text.lower())
        except ValueError:
            pass
    logger.warning(
        "Unknown AMQP exchange type %r; falling back to %s",
        val,
        exchange_types.TOPIC.value,
    )
    return exchange_types.TOPIC


def _parse_amqp_url(url: str) -> dict:
    parts = urlsplit(url)
    return {
        "scheme": parts.scheme or "amqp",
        "host": parts.hostname or "localhost",
        "port": parts.port or (5671 if parts.scheme == "amqps" else 5672),
        "login": parts.username or "guest",
        "password": parts.password or "guest",
        "virtualhost": unquote(parts.path[1:] or "/"),
        "ssl": (parts.scheme == "amqps"),
    }


class RabbitBase:
    """Shared AMQP client state.

    Holds one robust connection, one channel, and a cache of exchanges
    declared on that channel. :meth:`ensure_exchange` and
    :meth:`declare_queue_bind` call :meth:`connect` themselves, so callers do
    not need to connect explicitly first.
    """

    def __init__(
        self,
        exchange_type_resolver: Optional[ExchangeResolver] = None,
        *,
        connection_name: str = "hf_backend_publisher",
        verify_tls: bool = False,
    ):
        """Create an unconnected client. No network I/O happens here.

        Args:
            exchange_type_resolver: maps an exchange name to its type string
                (normalized by ``_normalize_exchange_type``). The default
                returns ``settings.RABBIT_EXCHANGE_TYPE`` for every exchange.
            connection_name: value sent as the ``connection_name`` client
                property when connecting.
            verify_tls: for ``amqps`` URLs, verify the broker certificate and
                host name. Defaults to False, which keeps the historical
                behaviour (verification disabled).
        """
        self._conn: Optional[aio_pika.RobustConnection] = None
        self._chan: Optional[aio_pika.RobustChannel] = None
        # Exchanges declared on the *current* channel, keyed by name. Cleared
        # whenever the channel is replaced or closed.
        self._exchanges: Dict[str, aio_pika.Exchange] = {}
        self._exchange_type_resolver = exchange_type_resolver or (
            lambda _: settings.RABBIT_EXCHANGE_TYPE
        )
        self._connection_name = connection_name
        self._verify_tls = verify_tls

    def is_connected(self) -> bool:
        """Return True when both the connection and the channel exist and are open."""
        return bool(
            self._conn and not self._conn.is_closed and
            self._chan and not self._chan.is_closed
        )

    async def close(self) -> None:
        """Close the channel, then the connection, and reset all cached state.

        The connection is closed even if closing the channel raises; in that
        case the error propagates after cleanup. References are always
        cleared, so a later :meth:`connect` starts from scratch.
        """
        try:
            if self._chan and not self._chan.is_closed:
                logger.info("Closing AMQP channel")
                await self._chan.close()
        finally:
            self._chan = None
            # Exchange objects are bound to the channel just dropped.
            self._exchanges.clear()
            # Nested so a failing channel close cannot skip the connection close.
            try:
                if self._conn and not self._conn.is_closed:
                    logger.info("Closing AMQP connection")
                    await self._conn.close()
            finally:
                self._conn = None
        logger.info("AMQP connection closed")

    async def connect(self) -> None:
        """Ensure an open connection and an open, QoS-configured channel.

        Does nothing when both are already open. If the connection is still
        open but the channel is missing or closed (for example after a
        channel-level error), only a new channel is opened on the existing
        connection. This avoids dialing a second connection and orphaning the
        first.

        Raises:
            Whatever ``aio_pika`` raises while connecting, opening the channel
            or setting QoS. The error is logged first.
        """
        if self.is_connected():
            return
        try:
            if self._conn is None or self._conn.is_closed:
                await self._open_connection()
            await self._open_channel()
        except Exception:
            logger.exception("AMQP connection/channel setup failed")
            raise

    async def _open_connection(self) -> None:
        """Dial the broker described by ``settings.AMQP_URL`` and store the connection."""
        conn_kwargs = _parse_amqp_url(str(settings.AMQP_URL))

        # Log the target without the password.
        safe_target = {
            key: conn_kwargs[key]
            for key in ("scheme", "host", "port", "virtualhost", "ssl", "login")
        }
        logger.info("AMQP connect -> %s", json.dumps(safe_target))

        ssl_ctx = None
        if conn_kwargs["ssl"]:
            ssl_ctx = ssl.create_default_context()
            if not self._verify_tls:
                # NOTE(review): verification is off by default, which allows MITM;
                # pass verify_tls=True once broker certificates are trusted.
                ssl_ctx.check_hostname = False
                ssl_ctx.verify_mode = ssl.CERT_NONE
                logger.warning("AMQP TLS verification is DISABLED (CERT_NONE)")

        self._conn = await aio_pika.connect_robust(
            host=conn_kwargs["host"],
            port=conn_kwargs["port"],
            login=conn_kwargs["login"],
            password=conn_kwargs["password"],
            virtualhost=conn_kwargs["virtualhost"],
            ssl=conn_kwargs["ssl"],
            ssl_context=ssl_ctx,
            heartbeat=60,              # keepalive during long CPU work
            timeout=30,
            client_properties={"connection_name": self._connection_name},
        )
        logger.info("AMQP connection established")

    async def _open_channel(self) -> None:
        """Open a channel on the current connection, apply QoS, then adopt it."""
        chan = await self._conn.channel()
        logger.info("AMQP channel created")
        try:
            await chan.set_qos(prefetch_count=settings.RABBIT_PREFETCH)
        except Exception:
            # Never keep a channel that lacks the configured prefetch. Close it
            # best-effort so the original error is the one that surfaces.
            try:
                await chan.close()
            except Exception:
                logger.debug(
                    "Ignoring error while closing half-configured AMQP channel",
                    exc_info=True,
                )
            raise
        logger.info("AMQP QoS set (prefetch=%s)", settings.RABBIT_PREFETCH)
        self._chan = chan
        # Any cached exchanges belonged to the previous channel.
        self._exchanges.clear()

    async def ensure_exchange(self, name: str) -> aio_pika.Exchange:
        """Return the durable exchange ``name``, declaring it on first use.

        The exchange type comes from the resolver and is normalized with
        ``_normalize_exchange_type``. Declared exchanges are cached for the
        lifetime of the current channel.

        Raises:
            Whatever :meth:`connect` or ``declare_exchange`` raises. Declaration
            errors are logged first.
        """
        await self.connect()
        cached = self._exchanges.get(name)
        if cached is not None:
            if cached.channel and not cached.channel.is_closed:
                return cached
            # Stale entry whose channel is gone: drop it and declare again.
            self._exchanges.pop(name, None)

        ex_type_str = self._exchange_type_resolver(name)
        ex_type = _normalize_exchange_type(ex_type_str)

        try:
            ex = await self._chan.declare_exchange(name, ex_type, durable=True)
            self._exchanges[name] = ex
            logger.info("Exchange declared: name=%s type=%s durable=true", name, ex_type.value)
            return ex
        except Exception:
            logger.exception("Failed declaring exchange: %s (%s)", name, ex_type_str)
            raise

    async def declare_queue_bind(
        self,
        exchange: str,
        queue_name: str,
        routing_keys: List[str],
        ttl_ms: Optional[int],
    ):
        """Declare ``queue_name`` and bind it to ``exchange``.

        The queue is durable, non-exclusive and auto-delete. It is bound once
        per routing key, or with the empty key when ``routing_keys`` is empty
        or None.

        Args:
            exchange: exchange name; declared via :meth:`ensure_exchange` if needed.
            queue_name: name of the queue to declare.
            routing_keys: binding keys.
            ttl_ms: per-queue ``x-message-ttl`` in milliseconds. Falsy values
                (None or 0) mean no TTL argument is set.

        Returns:
            The declared queue object returned by ``declare_queue``.

        Raises:
            Whatever declaration or binding raises, after logging.
        """
        # ensure_exchange() calls connect() itself.
        ex = await self.ensure_exchange(exchange)

        args: Dict[str, int] = {}
        if ttl_ms:
            args["x-message-ttl"] = ttl_ms

        try:
            q = await self._chan.declare_queue(
                queue_name,
                durable=True,
                exclusive=False,
                auto_delete=True,
                arguments=args,
            )
            logger.info(
                "Queue declared: name=%s durable=true auto_delete=true args=%s",
                queue_name, args or {}
            )
            for rk in routing_keys or [""]:
                await q.bind(ex, rk)
                logger.info("Queue bound: queue=%s exchange=%s rk='%s'", queue_name, exchange, rk)
            return q
        except Exception:
            logger.exception(
                "Failed declare/bind queue: queue=%s exchange=%s rks=%s args=%s",
                queue_name, exchange, routing_keys, args or {}
            )
            raise