File: //opt/imunify360/venv/lib64/python3.11/site-packages/im360/internals/core/__init__.py
"""Core module for rules and sets managing."""
import logging
import math
import time
from typing import Dict, Iterable, List, Optional, Set, Tuple
from defence360agent.utils import await_for, retry_on, timeit
from im360.contracts.config import NetworkInterface, UnifiedAccessLogger
from im360.internals.core.ipset.port_deny import (
InputPortBlockingDenyModeIPSet,
OutputPortBlockingDenyModeIPSet,
)
from im360.utils.validate import IPVersion
from . import ip_versions
from .firewall import (
FirewallRules,
RuleDef,
firewall_logging_enabled,
is_nat_available,
)
from .ipset import IP_SET_PREFIX, libipset
from .ipset.country import IPSetCountry
from .ipset.ip import IPSet
from .ipset.libipset import IPSetCmdBuilder, IPSetRestoreCmd
from .ipset.port import IPSetIgnoredByPort, IPSetPort
from .ipset.redirect import (
IPSetNoRedirectPort,
IPSetWebshieldPort,
)
from .ipset.sync import IPSetSyncIPListPurpose, IPSetSyncIPListRecords
# Module-level logger; RuleSet uses it to report ipset restore results and
# ipset destruction failures.
logger: logging.Logger = logging.getLogger(__name__)
class RuleSet:
    """Manage the iptables rules and ipsets used by Imunify360.

    Everything ipset-related is delegated to the entities created in
    ``__init__``. Each entity contributes its own create/flush/destroy
    operations, firewall rules and restore operations, and this class
    aggregates them per IP version.
    """

    # Filter-table chains created by create_commands() and flushed/deleted
    # by destroy_commands().
    _CHAINS = [
        FirewallRules.COUNTRY_WHITELIST_CHAIN,
        FirewallRules.COUNTRY_BLACKLIST_CHAIN,
        FirewallRules.BP_INPUT_CHAIN,
        FirewallRules.LOG_BLACKLIST_CHAIN,
        FirewallRules.LOG_GRAYLIST_CHAIN,
        FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN,
        FirewallRules.WEBSHIELD_PORTS_INPUT_CHAIN,
        FirewallRules.LOG_BLOCK_PORT_CHAIN,
    ]
    # Since DB and ipset are updated at different times,
    # check the relative difference instead of comparing absolute values.
    # Use a large enough relative number to avoid false positives;
    # a 20% difference looks reasonable for this.
    _IPSET_COUNT_TO_RECREATE_THRESHOLD = 0.2
    # Connection tracking rule in raw/PREROUTING. It is shared by
    # create_commands(), check_commands() and destroy_commands(), so the
    # rule that is inserted, checked and deleted is always the same.
    # fmt: off
    _CONNTRACK_RULE = (
        "-m", "comment",
        "--comment", '"Connection tracking for Imunify360."',
        "-j", "CT",
    )
    # fmt: on

    def __init__(self):
        self.entities = (
            InputPortBlockingDenyModeIPSet(),
            OutputPortBlockingDenyModeIPSet(),
            IPSetPort(),
            IPSet(),
            # Order is important here,
            # Ensure IPSetSyncIPListRecords is created before IPSetSyncIPListPurpose
            IPSetSyncIPListRecords(),
            IPSetSyncIPListPurpose(),
            IPSetCountry(),
            IPSetIgnoredByPort(),
            IPSetNoRedirectPort(),
            IPSetWebshieldPort(),
        )

    @staticmethod
    def targets(ip_version: IPVersion) -> List[Tuple]:
        """
        Returns tables & chains that Imunify360 will use in firewall management

        The PREROUTING hook lives in the NAT table when NAT is available for
        *ip_version*, otherwise in the MANGLE table.

        :param ip_version: IPv4 or IPv6
        :return: List[Tuple]: (table, chain) pairs
        """
        return [
            (FirewallRules.FILTER, "INPUT"),
            (
                (FirewallRules.NAT, "PREROUTING")
                if is_nat_available(ip_version)
                else (FirewallRules.MANGLE, "PREROUTING")
            ),
        ]

    @staticmethod
    def _apply_ignored_interfaces(action, interface_conf, *args, **kwargs):
        """Yield *action* applied to an ACCEPT rule for every skipped device.

        For each interface listed under ``NetworkInterface.DEVICE_SKIP`` in
        *interface_conf*, a rule accepting traffic on that interface is
        passed to *action*. The rule targets ``IMUNIFY_INPUT_CHAIN`` with
        priority 0.

        :param Callable action: firewall method to apply to each rule
            (for example ``insert_rule`` or ``has_rule``)
        :param interface_conf: interface configuration
        :param args: extra positional arguments forwarded to *action*
        :param kwargs: extra keyword arguments forwarded to *action*
        """
        for interface in interface_conf[NetworkInterface.DEVICE_SKIP]:
            yield action(
                FirewallRules.compose_rule(
                    FirewallRules.interface(interface),
                    action=FirewallRules.compose_action(FirewallRules.ACCEPT),
                ),
                chain=FirewallRules.IMUNIFY_INPUT_CHAIN,
                priority=0,  # max priority for firewalld
                *args,
                **kwargs,
            )

    @staticmethod
    def _compose_rule(ip_version: IPVersion, interface_conf: dict) -> RuleDef:
        """Compose the jump-to-IMUNIFY_INPUT_CHAIN rule.

        If *interface_conf* names a target interface for *ip_version*, the
        jump is restricted to that interface. Otherwise the bare jump action
        is returned.
        """
        target_interface = interface_conf[ip_version]
        action = FirewallRules.compose_action(
            FirewallRules.IMUNIFY_INPUT_CHAIN
        )
        if target_interface:
            rule = FirewallRules.compose_rule(
                FirewallRules.interface(target_interface), action=action
            )
        else:
            rule = action
        return rule

    async def ipset_create_commands(self, ip_version: IPVersion) -> List[str]:
        """Collect every entity's ipset create operations for *ip_version*."""
        names = []  # type: List[str]
        for entity in self.entities:
            names.extend(entity.gen_ipset_create_ops(ip_version))
        return names

    async def ipset_flush_commands(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ) -> Iterable[IPSetRestoreCmd]:
        """Generate ipset restore commands to flush *existing* ipsets.

        Only entities that implement ``gen_ipset_flush_cmds`` contribute
        commands.

        :param ip_version: IPv4 or IPv6
        :param existing: ipset names to consider; when None, all existing
            Imunify360 ipsets for *ip_version* are used
        """
        if existing is None:
            existing = await self.existing_ipsets(ip_version)
        cmds = []  # type: List[IPSetRestoreCmd]
        for entity in self.entities:
            # Flushing is an optional, entity-specific step.
            if hasattr(entity, "gen_ipset_flush_cmds"):
                cmds.extend(entity.gen_ipset_flush_cmds(ip_version, existing))
        return cmds

    async def ipset_destroy_commands(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ) -> Iterable[IPSetRestoreCmd]:
        """Generate ipset restore commands to destroy *existing* ipsets.

        Entity-specific destroy commands take precedence. Any remaining name
        in *existing* gets a generic destroy command.

        :param ip_version: IPv4 or IPv6
        :param existing: ipset names to destroy; when None, all existing
            Imunify360 ipsets for *ip_version* are used
        """
        if existing is None:
            existing = await self.existing_ipsets(ip_version)
        # get entity specific destroy commands
        cmds = {}  # type: Dict[str, IPSetRestoreCmd]
        for entity in self.entities:
            entity_cmds = entity.gen_ipset_destroy_ops(ip_version, existing)
            cmds.update(entity_cmds)
        # generic destroy
        for ipset_name in existing:
            if ipset_name not in cmds:
                # ipset is not special, remove using a generic destroy command
                cmds[ipset_name] = IPSetCmdBuilder.get_destroy_cmd(ipset_name)
        return cmds.values()

    async def create_commands(
        self, firewall, interface_conf: dict, ip_version: IPVersion
    ) -> list:
        """Return a list of firewall commands to create all required rules.

        In order, the commands:
        - create the input chains and hook them into the targets;
        - create the filter chains and the logging/drop rules;
        - create the output chain and hook it into OUTPUT;
        - append the entities' ipset rules;
        - insert the connection tracking rule.
        """
        actions = []
        # input chains
        for table, chain in self.targets(ip_version):
            actions.extend(
                [
                    firewall.create_chain(
                        table=table, chain=FirewallRules.IMUNIFY_INPUT_CHAIN
                    ),
                    firewall.insert_rule(
                        self._compose_rule(ip_version, interface_conf),
                        table=table,
                        chain=chain,
                    ),
                    *self._apply_ignored_interfaces(
                        firewall.insert_rule, interface_conf, table=table
                    ),
                ]
            )
        actions.extend(
            [
                firewall.create_chain(table=FirewallRules.FILTER, chain=chain)
                for chain in self._CHAINS
            ]
        )
        actions.extend(self._log_block_rules(firewall.append_rule, ip_version))
        # output chains
        actions.extend(
            [
                firewall.create_chain(
                    table=FirewallRules.FILTER,
                    chain=FirewallRules.IMUNIFY_OUTPUT_CHAIN,
                ),
                firewall.insert_rule(
                    FirewallRules.compose_action(
                        FirewallRules.IMUNIFY_OUTPUT_CHAIN
                    ),
                    chain="OUTPUT",
                ),
            ]
        )
        actions.extend(
            [
                firewall.create_chain(table=FirewallRules.FILTER, chain=chain)
                for chain in [FirewallRules.BP_OUTPUT_CHAIN]
            ]
        )
        actions.extend(
            [
                firewall.append_rule(**rule)
                for rule in await self._collect_ipset_rules(ip_version)
            ]
        )
        # Add connection tracking rule.
        actions.append(
            firewall.insert_rule(
                self._CONNTRACK_RULE, table="raw", chain="PREROUTING"
            )
        )
        return actions

    def destroy_commands(
        self, firewall, interface_conf: dict, ip_version: IPVersion
    ) -> Iterable[list]:
        """Returns an iterable over list of commands to destroy firewall rules.

        Each list should be executed as a separate firewall commit
        operation. Jump rules are deleted before the chains they reference
        are flushed and deleted.
        """
        # input chains
        for table, chain in self.targets(ip_version):
            yield [
                firewall.delete_rule(
                    self._compose_rule(ip_version, interface_conf),
                    table=table,
                    chain=chain,
                )
            ]
            yield [
                firewall.flush_chain(
                    FirewallRules.IMUNIFY_INPUT_CHAIN, table=table
                ),
                firewall.delete_chain(
                    FirewallRules.IMUNIFY_INPUT_CHAIN, table=table
                ),
            ]
        for chain in self._CHAINS:
            yield [
                firewall.flush_chain(chain, table=FirewallRules.FILTER),
                firewall.delete_chain(chain, table=FirewallRules.FILTER),
            ]
        # output chains
        yield [
            firewall.delete_rule(
                FirewallRules.compose_action(
                    FirewallRules.IMUNIFY_OUTPUT_CHAIN
                ),
                chain="OUTPUT",
            )
        ]
        yield [
            firewall.flush_chain(FirewallRules.IMUNIFY_OUTPUT_CHAIN),
            firewall.delete_chain(FirewallRules.IMUNIFY_OUTPUT_CHAIN),
        ]
        for chain in [FirewallRules.BP_OUTPUT_CHAIN]:
            yield [firewall.flush_chain(chain), firewall.delete_chain(chain)]
        # Delete connection tracking rule.
        yield [
            firewall.delete_rule(
                self._CONNTRACK_RULE,
                table="raw",
                chain="PREROUTING",
            )
        ]

    def required_ipsets(self, ip_version: IPVersion) -> Set[str]:
        """Return the names of all ipsets the entities need for *ip_version*."""
        names = set()  # type: Set[str]
        for entity in self.entities:
            names.update(entity.get_all_ipsets(ip_version))
        return names

    async def check_commands(
        self, firewall, interface_conf, ip_version: IPVersion
    ) -> list:
        """Returns a list of firewall commands to check for firewall rules.

        The checked rules mirror the rules added by create_commands().
        """
        actions = []
        for table, chain in self.targets(ip_version):
            actions.extend(
                [
                    firewall.has_rule(
                        self._compose_rule(ip_version, interface_conf),
                        table=table,
                        chain=chain,
                    ),
                    *self._apply_ignored_interfaces(
                        firewall.has_rule, interface_conf, table=table
                    ),
                ]
            )
        actions.extend(self._log_block_rules(firewall.has_rule, ip_version))
        actions.extend(
            [
                firewall.has_rule(
                    FirewallRules.compose_action(
                        FirewallRules.IMUNIFY_OUTPUT_CHAIN
                    ),
                    table=FirewallRules.FILTER,
                    chain="OUTPUT",
                ),
            ]
        )
        actions.extend(
            [
                firewall.has_rule(**rule)
                for rule in await self._collect_ipset_rules(ip_version)
            ]
        )
        actions.append(
            firewall.has_rule(
                self._CONNTRACK_RULE, table="raw", chain="PREROUTING"
            )
        )
        return actions

    def _log_block_rules(self, predicate, ip_version: IPVersion):
        """Apply *predicate* to the log-then-block rules of each LOG_* chain.

        :param predicate: firewall method called as
            ``predicate(rule, table=FILTER, chain=chain)``
        :return: list of the values returned by *predicate*
        """
        rules = []
        for chain, prefix, action in (
            (
                FirewallRules.LOG_BLACKLIST_CHAIN,
                UnifiedAccessLogger.BLACKLIST,
                FirewallRules.compose_action(FirewallRules.DROP),
            ),
            (
                FirewallRules.LOG_GRAYLIST_CHAIN,
                UnifiedAccessLogger.GRAYLIST,
                FirewallRules.compose_action(FirewallRules.DROP),
            ),
            (
                FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN,
                UnifiedAccessLogger.BLACKLIST_COUNTRY,
                FirewallRules.compose_action(FirewallRules.DROP),
            ),
            (
                FirewallRules.LOG_BLOCK_PORT_CHAIN,
                UnifiedAccessLogger.BLOCKED_BY_PORT,
                FirewallRules.compose_action(FirewallRules.REJECT),
            ),
        ):
            rules.extend(
                predicate(rule, table=FirewallRules.FILTER, chain=chain)
                for rule in self._log_drop_rules(ip_version, prefix, action)
            )
        return rules

    async def _collect_ipset_rules(self, ip_version: IPVersion) -> List[dict]:
        """Gather every entity's rules, sorted by (chain, priority)."""
        rules = []  # type: List[dict]
        for entity in self.entities:
            rules.extend(entity.get_rules(ip_version))
        rules.sort(key=lambda r: (r["chain"], r["priority"]))
        return rules

    async def fill_ipsets(
        self, ip_version: IPVersion, missing: Set[str]
    ) -> None:
        """Create the ipsets named in *missing* and restore their elements.

        All create and restore operations are sent to libipset in a single
        restore call.
        """
        create_and_restore_cmds = []
        for entity in self.entities:
            for ip_set in entity.get_all_ipset_instances(ip_version):
                if ip_set.gen_ipset_name_for_ip_version(ip_version) in missing:
                    create_and_restore_cmds.extend(
                        ip_set.gen_ipset_create_ops(ip_version)
                    )
                    create_and_restore_cmds.extend(
                        await ip_set.gen_ipset_restore_ops(ip_version)
                    )
        await libipset.restore(create_and_restore_cmds)
        logger.info("IP sets content restored from database")

    @staticmethod
    async def existing_ipsets(ip_version: IPVersion) -> Set[str]:
        """Return the names of existing ipsets prefixed with the
        Imunify360 prefix for *ip_version* (``"<IP_SET_PREFIX>.<ip_version>"``).
        """
        prefix = ".".join([IP_SET_PREFIX, ip_version])
        all_sets = await libipset.list_set()
        return {name for name in all_sets if name.startswith(prefix)}

    async def destroy_ipsets(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ) -> None:
        """Flush and destroy Imunify360 ipsets.

        :param ip_version: IPv4 or IPv6
        :param existing: names of the ipsets to destroy; when None, all
            existing Imunify360 ipsets for *ip_version* are destroyed. The
            passed set is not modified.

        Both restore steps are retried up to ``max_tries`` times when
        libipset reports a missing ipset or one that cannot be destroyed.
        Before each attempt, the target set is narrowed to ipsets that still
        exist. If some ipsets survive every attempt, an error is logged
        instead of raising.
        """
        if existing is None:
            to_destroy = await self.existing_ipsets(ip_version)
        else:
            to_destroy = existing.copy()
        max_tries = 3
        attempt = 0
        # Bound the retries: an ipset that persistently cannot be destroyed
        # must not make this loop spin forever.
        while to_destroy and attempt < max_tries:
            # remove absent ipsets
            to_destroy &= await self.existing_ipsets(ip_version)
            try:
                await libipset.restore(
                    await self.ipset_flush_commands(ip_version, to_destroy)
                )
                await libipset.restore(
                    await self.ipset_destroy_commands(ip_version, to_destroy)
                )
                return
            except (
                libipset.IPSetNotFoundError,
                libipset.IPSetCannotBeDestroyedError,
            ):
                attempt += 1
        if to_destroy:
            # A failed restore may have destroyed some ipsets before erroring;
            # report only the ones that are really left.
            to_destroy &= await self.existing_ipsets(ip_version)
            if to_destroy:
                logger.error(
                    "Failed to destroy ipsets: %s",
                    ", ".join(sorted(to_destroy)),
                )

    async def _recreate_ipsets(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ):
        """Reset all ipsets, create them again and fill with IPs
        for given ip version."""
        for entity in self.entities:
            await entity.reset(ip_version, existing)

    async def recreate_ipsets(
        self,
        ip_version: Optional[IPVersion] = None,
        existing: Optional[Set[str]] = None,
    ):
        """Recreate existing ipsets (or given).

        If *ip_version* is None, recreate ipsets for all enabled ip versions.
        """
        if ip_version:
            await self._recreate_ipsets(ip_version, existing)
        else:
            for version in ip_versions.enabled():
                await self._recreate_ipsets(version, existing)

    @staticmethod
    def _log_drop_rules(ip_version: IPVersion, prefix, action):
        """Return the rules for one LOG_* chain.

        The list holds an NFLOG rule with *prefix*, included only when
        firewall logging is enabled, followed by *action*.
        """
        rules = []
        if firewall_logging_enabled():
            rules.append(
                FirewallRules.compose_rule(
                    action=FirewallRules.nflog_action(
                        group=FirewallRules.nflog_group(ip_version),
                        prefix=prefix,
                    )
                )
            )
        rules.append(action)
        return rules

    async def get_outdated_ipsets(self, ip_version: IPVersion) -> list:
        """
        Return list of ipsets the contents of which do not match the database

        An ipset is outdated when its element count differs from the DB
        count by more than ``_IPSET_COUNT_TO_RECREATE_THRESHOLD``, measured
        as a relative tolerance by ``math.isclose``.
        """
        outdated: list = []
        for entity in self.entities:
            all_ipsets = await entity.get_ipsets_count(ip_version)
            outdated.extend(
                ipset
                for ipset in all_ipsets
                if not math.isclose(
                    ipset.ipset_count,
                    ipset.db_count,
                    rel_tol=self._IPSET_COUNT_TO_RECREATE_THRESHOLD,
                )
            )
        return outdated