# File: //opt/imunify360/venv/lib64/python3.11/site-packages/im360/plugins/resident/client360.py
import asyncio
import datetime
import functools
import json
import os
import random
import ssl
import time
from contextlib import suppress
from logging import getLogger
from typing import Optional
from uuid import uuid4

import aiohttp
from humanize import naturaldelta

from defence360agent.api import health
from defence360agent.api.server import APIError, APIErrorTooManyRequests
from defence360agent.api.server.send_message import BaseSendMessageAPI
from defence360agent.contracts.config import Core as CoreConfig
from defence360agent.contracts.config import bool_from_envvar, int_from_envvar
from defence360agent.contracts.license import LicenseCLN
from defence360agent.contracts.messages import (
    Message,
    MessageType,
    MessageNotFoundError,
    ReportTarget,
)
from defence360agent.contracts.plugins import (
    MessageSink,
    MessageSource,
    expect,
)
from defence360agent.internals.persistent_message import (
    PersistentMessagesQueue,
)
from defence360agent.internals.logger import getNetworkLogger
from defence360agent.internals.logging_protocol import LoggingProtocol
from defence360agent.model.instance import db
from defence360agent.utils import recurring_check
from defence360agent.utils.buffer import LineBuffer
from defence360agent.utils.common import DAY, rate_limit
from defence360agent.utils.json import ServerJSONEncoder
from im360.model.messages_to_send import MessageToSend
logger, network_logger = getLogger(__name__), getNetworkLogger(__name__)
throttled_log_error = rate_limit(period=DAY)(logger.error)
class _ServerErrorBacklog:
def __init__(self):
self._backlog_length = 0
def log_server_reponse(self, decoded_json):
if decoded_json.get("method") == "ERROR":
self._backlog_length += 1
# check license and if failed, we should't reconnect immediately
# only after 3 hour
elif (
decoded_json.get("method") == "LICENSE_STATUS"
and decoded_json.get("success") is False
):
self._backlog_length += 20 # 3 hour
else:
self._backlog_length = 0
def log_protocol_error(self):
self._backlog_length += 1
def log_connection_error(self):
self._backlog_length += 1
def log_connection_lost(self):
self._backlog_length += 1
def length(self):
return self._backlog_length
def _suppress_cancelled_error(fun):
async def wrapper(*args, **kwargs):
with suppress(asyncio.CancelledError):
return await fun(*args, **kwargs)
return wrapper
class _Config:
HOST = os.environ.get(
"IMUNIFY360_CLIENT360_HOST", "imunify360.cloudlinux.com"
)
PORT = int_from_envvar("IMUNIFY360_CLIENT360_PORT", 443)
USE_SSL = bool_from_envvar("IMUNIFY360_CLIENT360_USE_SSL", True)
PROTOCOL_VERSION = "2.3"
PROTOCOL_SEND_DEBUGINFO = True
RECONNECT_WITH_TIMEOUT = 5
# exponentially grow RECONNECT_WITH_TIMEOUT
# and randomize it via multuplying by random.uniform(1-0.5, 1+0.5):
RECONNECT_FLASHMOB_PROTECTION_COEF = 0.5
MAX_RECONNECT_TIMEOUT = int_from_envvar(
"IMUNIFY360_CLIENT360_MAX_RECONNECT_TIMEOUT",
datetime.timedelta(minutes=5).total_seconds(),
)
class SendMessageAPI(BaseSendMessageAPI):
async def _send_request(self, message_method, headers, post_data):
url = self._BASE_URL + self.URL.format(method=message_method)
try:
timeout = aiohttp.ClientTimeout(
total=CoreConfig.DEFAULT_SOCKET_TIMEOUT
)
async with aiohttp.ClientSession() as session:
async with session.post(
url, data=post_data, headers=headers, timeout=timeout
) as response:
if response.status != 200:
raise APIError(f"status code is {response.status}")
return await response.json()
except (
UnicodeDecodeError,
json.JSONDecodeError,
aiohttp.ClientError,
asyncio.TimeoutError,
) as e:
status_code = getattr(e, "status", None)
if status_code == 429:
raise APIErrorTooManyRequests(
"request failed, reason: %s" % (e,), status_code
) from e
raise APIError(f"request failed {e}") from e
class Client360(MessageSink, MessageSource):
#
# Intentionally derive from MessageSink for @expect magic to work
#
class _ServerProtocol(asyncio.Protocol):
def __init__(self, loop, host: str, port: int):
self._loop = loop
self._host = host
self._port = port
self._sink_future = asyncio.Future(loop=loop)
self._transport = None
self._line_buffer = LineBuffer()
self._server_error_backlog = _ServerErrorBacklog()
self._connection_lost_event = asyncio.Event()
self._queue = PersistentMessagesQueue(model=MessageToSend)
self._last_processed_seq_number = None
self._api = SendMessageAPI()
async def shutdown(self):
if self.is_connected():
self._transport.close()
self._queue.push_buffer_to_storage()
def set_sink_to(self, sink):
self._sink_future.set_result(sink)
@staticmethod
def _ssl_config():
if not _Config.USE_SSL:
return False
return ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
@staticmethod
def _reconnect_timeout(num_failed_attempts):
if num_failed_attempts == 0:
return 0
next_reconnect_timeout = min(
_Config.MAX_RECONNECT_TIMEOUT,
_Config.RECONNECT_WITH_TIMEOUT
* 2 ** (num_failed_attempts - 1),
)
next_reconnect_timeout *= random.uniform(
1.0 - _Config.RECONNECT_FLASHMOB_PROTECTION_COEF,
1.0 + _Config.RECONNECT_FLASHMOB_PROTECTION_COEF,
)
return next_reconnect_timeout
@_suppress_cancelled_error
async def connect_and_keep_connected(self):
while True:
try:
self._connection_lost_event.clear()
# putting debug comment before each await...
logger.info(
"Connecting the Server "
"[await loop.create_connection...]"
)
await self._loop.create_connection(
lambda: LoggingProtocol(logger, network_logger, self),
self._host,
self._port,
ssl=self._ssl_config(),
)
logger.info(
"Connected the Server "
"[loop.create_connection() succeeded]"
)
except OSError as e:
logger.error(
"Cannot connect the Server (%s) [%s]", self._host, e
)
self._server_error_backlog.log_connection_error()
else:
await self._notify_server_connected()
logger.info("await _connection_lost_event...")
await self._connection_lost_event.wait()
logger.warning(
"Lost connection to the Server (%s).", self._host
)
self._server_error_backlog.log_connection_lost()
timeout = self._reconnect_timeout(
self._server_error_backlog.length()
)
logger.warning(
"Waiting %s before retry...", naturaldelta(timeout)
)
await asyncio.sleep(timeout)
async def _notify_server_connected(self):
# putting debug comment before each await...
logger.info("await _sink_future...")
sink = await self._sink_future
logger.info("put ServerConnected() to the bus...")
await sink.process_message(MessageType.ServerConnected())
def connection_made(self, transport):
logger.info("Connected the Server [connection_made]")
self._transport = transport
# clean buffer before new data receiving
self._line_buffer.clean()
def connection_lost(self, exc):
self._transport = None
if exc is not None:
if isinstance(
exc, ssl.SSLError
) and "APPLICATION_DATA_AFTER_CLOSE_NOTIFY" in str(exc):
# https://bugs.python.org/issue39951
log_error = logger.warning
else:
log_error = logger.error
log_error("Lost connection to server", exc_info=exc)
else:
logger.info("Server connection closed")
self._connection_lost_event.set()
def _transport_write(self, message: bytes):
self._transport.write(message)
health.sensor.server_data_sent(time.time())
network_logger.debug("transport.write: %r", message)
def is_connected(self):
# When a connection is closed the callback *connection_lost*
# is not called immediately
return (
self._transport is not None
and not self._transport.is_closing()
)
async def send_messages_from_queue(self):
"""
Sends messages stored in the persistent queue.
"""
messages_to_send = self._queue.pop_all()
failed_to_send_count = 0
failed_to_send_msgs = []
stop_attempt_to_send_API = False
try:
while messages_to_send:
timestamp, message_bstr = messages_to_send.pop()
data2send = json.loads(message_bstr)
if not data2send: # pragma: no cover
continue
message_class = (
MessageType.Reportable.get_subclass_with_method(
data2send["method"]
)
)
target = getattr(message_class, "TARGET", None)
if target is ReportTarget.API:
if stop_attempt_to_send_API:
sent = False
else:
sent = await self._send_to_api(data2send)
if sent is None:
stop_attempt_to_send_API = True
elif target is ReportTarget.PERSISTENT_CONNECTION:
sent = self._send_to_conn(data2send)
await asyncio.sleep(0)
else: # should never happen
# send to Sentry and ignore
logger.error(
"Unexpected target to send data: %s for method %s",
target,
data2send["method"],
)
continue
if not sent: # try later
failed_to_send_count += 1
self._queue.put(message_bstr, timestamp=timestamp)
failed_to_send_msgs.append(
{
"method": data2send.get("method"),
"message_id": data2send.get("message_id"),
}
)
else:
logger.info(
"message sent %s",
{
"method": data2send.get("method"),
"message_id": data2send.get("message_id"),
},
)
except Exception as exc:
failed_to_send_count += len(messages_to_send)
logger.error("Error occurs while sending message: %s", exc)
if failed_to_send_count:
logger.info(
"%s messages failed to send. Messages will send later: %s",
failed_to_send_count,
failed_to_send_msgs,
)
# don't lose messages, add them back to the queue
for timestamp, message_bstr in messages_to_send:
self._queue.put(message_bstr, timestamp=timestamp)
if not failed_to_send_count:
logger.info("All stored messages are sent.")
def _get_data_to_send(self, message):
server_id = LicenseCLN.get_server_id()
if not server_id:
logger.warning(
"message with server_id=%r "
"will not be sent to the Server.",
server_id,
)
return
# add message handling time if it does not exist, so that
# the server does not depend on the time it was received
if "timestamp" not in message:
message["timestamp"] = time.time()
# json.dumps: keys must be a string
data2send = {
"method": message.method,
"payload": message.payload,
"ver": _Config.PROTOCOL_VERSION,
"rpm_ver": CoreConfig.VERSION,
"message_id": uuid4().hex,
"server_id": server_id,
"name": LicenseCLN.get_product_name(),
"license": LicenseCLN.get_token(),
}
if _Config.PROTOCOL_SEND_DEBUGINFO:
data2send["_debug"] = {
"messageclass": message.__class__.__module__
+ "."
+ message.__class__.__name__
}
return data2send
def _encode_data_to_send(self, data: dict) -> bytes:
msg = json.dumps(data, cls=ServerJSONEncoder) + "\n"
return msg.encode()
def put_message_to_queue(self, message: Message):
data2send = self._get_data_to_send(message)
self._queue.put(self._encode_data_to_send(data2send))
def send_to_persist_connection(self, message):
if data2send := self._get_data_to_send(message):
if not self._send_to_conn(data2send):
self._queue.put(self._encode_data_to_send(data2send))
def _send_to_conn(self, data2send) -> bool:
if not self.is_connected():
return False
bstr = self._encode_data_to_send(data2send)
self._transport_write(bstr)
return True
async def _send_to_api(self, data2send) -> Optional[bool]:
method = data2send.pop("method")
post_data = self._encode_data_to_send(data2send)
try:
await self._api.send_data(method, post_data)
except APIErrorTooManyRequests as exc:
logger.error(
"Failed to send message %s to the correlation server: %s",
method,
exc,
)
return
except APIError as exc:
logger.error(
"Failed to send message %s to the correlation server: %s",
method,
exc,
)
return False
health.sensor.server_data_sent(time.time())
return True
@staticmethod
def _on_server_ack(decoded_json):
logger.debug(
"The Server method=%r is not implemented on agent",
decoded_json["method"],
)
@staticmethod
def _on_server_error(decoded_json):
log_error = (
throttled_log_error
if isinstance(decoded_json, dict)
and decoded_json.get("error") == "Invalid license or signature"
else logger.error
)
log_error(
"The Server responded with error: %s",
decoded_json.get("error", repr(decoded_json)),
)
def _handle_event(self, action, decoded_json):
"""
:raise ValueError: for malformed SynclistResponse
"""
def send(message):
return self._loop.create_task(
self._sink_future.result().process_message(message)
)
def process_message_v1():
message_class = MessageType.Received.get_subclass_with_action(
action
)
message = message_class(decoded_json)
return send(message)
# note: no changes in ACK, ERROR processing between v1/v2 protocol
# versions
if action == "ACK":
return self._on_server_ack(decoded_json)
elif action == "ERROR":
return self._on_server_error(decoded_json)
# else: detect protocol version
try:
seq_number = decoded_json["_meta"]["per_seq"]
except KeyError: # v1 old message processing
return process_message_v1()
# else: # v2 persistent message processing
# is it a duplicate?
if (
self._last_processed_seq_number
is not None # has received at least one message already
and seq_number <= self._last_processed_seq_number
): # already processed
# resend ACK
# note: assume Server can handle duplicate ACKs
return send(MessageType.Ack(seq_number))
# else: message seen first time
process_message_v1()
# acknowledge successfull processing
try:
return send(MessageType.Ack(seq_number))
finally:
# note: it is ok if the server loses the ACK
self._last_processed_seq_number = seq_number
def data_received(self, data):
health.sensor.server_data_received(time.time())
network_logger.debug("data_received: %r", data)
if not self._sink_future.done():
logger.error(
"some data received but MessageSink has not been "
+ "associated with the protocol yet"
)
else:
self._line_buffer.append(data.decode())
for token in self._line_buffer:
try:
decoded = json.loads(token)
except json.decoder.JSONDecodeError as e:
logger.error("JSON decode error [%s]", e)
self._server_error_backlog.log_protocol_error()
return
else:
if not isinstance(decoded, dict):
logger.error(
"JSON decode error: expecting dict, not %s",
type(decoded).__name__,
)
self._server_error_backlog.log_protocol_error()
return
if decoded.get("payload"):
# FIXME: potential weakness.
# https://cloudlinux.atlassian.net/browse/DEF-318
decoded.update(decoded.pop("payload"))
try:
self._handle_event(decoded["method"], decoded)
except MessageNotFoundError as e:
logger.error("%s in malformed %s", repr(e), decoded)
self._server_error_backlog.log_protocol_error()
except asyncio.CancelledError:
raise
except: # noqa
logger.exception("Something went wrong")
else:
self._server_error_backlog.log_server_reponse(decoded)
async def create_sink(self, loop):
logger.info(
"imunify360 connection server: <%s:%d>, ssl=%r",
_Config.HOST,
_Config.PORT,
_Config.USE_SSL,
)
self._protocol = self._ServerProtocol(loop, _Config.HOST, _Config.PORT)
self.try_send_messages_from_queue = asyncio.Event()
self._tasks = [
loop.create_task(self._protocol.connect_and_keep_connected()),
loop.create_task(self.send_queue_messages_to_server()),
]
@expect(MessageType.Reportable)
async def send_to_server(self, message):
if message.TARGET is ReportTarget.PERSISTENT_CONNECTION:
self._protocol.send_to_persist_connection(message)
else: # via API
# send messages via API in a separate task
# to avoid slowing down the message processing queue
# due to possible network issues
self._protocol.put_message_to_queue(message)
self.try_send_messages_from_queue.set()
@expect(MessageType.ServerConnected)
async def send_queue_messages_on_reconnect(self, message):
self.try_send_messages_from_queue.set()
@recurring_check(0)
async def send_queue_messages_to_server(self):
await self.try_send_messages_from_queue.wait()
self.try_send_messages_from_queue.clear()
await self._protocol.send_messages_from_queue()
async def create_source(self, loop, sink):
self._protocol.set_sink_to(sink)
async def shutdown(self):
for task in self._tasks:
task.cancel()
logger.info("Waiting for tasks to finish...")
await asyncio.gather(*self._tasks, return_exceptions=True)
logger.info("Shutdown connection to the Server")
await self._protocol.shutdown()
logger.info("Connection to the Server is closed")