File: //opt/imunify360/venv/lib64/python3.11/site-packages/im360/internals/core/ipset/country.py
import asyncio
import logging
from typing import FrozenSet, Iterable, List
from defence360agent.contracts.config import CountryInfo
from defence360agent.model.simplification import run_in_executor
from defence360agent.utils import timeit
from im360.contracts.config import UnifiedAccessLogger
from im360.model.country import CountryList
from im360.utils.validate import IP, IPVersion
from .. import ip_versions
from ..firewall import FirewallRules, firewall_logging_enabled, get_firewall
from . import (
IP_SET_PREFIX,
AbstractIPSet,
IPSetAtomicRestoreBase,
IPSetCount,
get_ipset_family,
libipset,
)
from .libipset import IPSetCmdBuilder
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
def ips_for_country(country_code):
    """Yield the entries listed in a country's subnets file.

    The path comes from ``CountryInfo.country_subnets_file``. The file is
    read as UTF-8 text, one entry per line. Surrounding whitespace is
    stripped and blank lines are skipped, so callers never receive an
    empty string.

    If the file does not exist, an error is logged and the generator
    yields nothing.

    :param country_code: ISO 3166-1 alpha-2 code
    """
    subnets_file = CountryInfo.country_subnets_file(country_code)
    # Only open() is expected to raise FileNotFoundError, so keep the try
    # around it alone rather than around the whole read loop.
    try:
        f = open(subnets_file, encoding="utf-8")
    except FileNotFoundError:
        logger.error("Can't find subnets file %s", subnets_file)
        return
    with f:
        for line in f:
            entry = line.strip()
            # A blank line is not a subnet; yielding "" would make callers
            # treat it as an invalid address.
            if entry:
                yield entry
class IPSetCountryBlack:
    """Chain, priority and per-set rules for the country blacklist.

    Each blacklisted country set gets one ipset rule whose target is
    ``FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN``.
    """

    CHAIN = FirewallRules.COUNTRY_BLACKLIST_CHAIN
    PRIORITY = FirewallRules.BLACKLIST_PRIORITY

    def single_entry_rules(self, set_name, _):
        """Build the firewall rules for one blacklisted country ipset.

        :param set_name: name of the ipset the rule matches against
        :param _: IP version; unused because the rule does not depend on it
        :return: a one-element list holding the ipset rule
        """
        target_chain = FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN
        jump_rule = FirewallRules.ipset_rule(set_name, target_chain)
        return [jump_rule]
class IPSetCountryWhite:
    """Chain, priority and per-set rules for the country whitelist.

    Each whitelisted country set gets an ACCEPT rule. When firewall logging
    is enabled, an NFLOG rule is placed in front of it.
    """

    CHAIN = FirewallRules.COUNTRY_WHITELIST_CHAIN
    PRIORITY = FirewallRules.WHITELIST_PRIORITY

    def single_entry_rules(self, set_name, ip_version: IPVersion):
        """Build the firewall rules for one whitelisted country ipset.

        :param set_name: name of the ipset the rules match against
        :param ip_version: selects the NFLOG group for the logging rule
        :return: ``[nflog_rule, accept_rule]`` if firewall logging is
            enabled, otherwise ``[accept_rule]``
        """
        log_rules = []
        if firewall_logging_enabled():
            match = FirewallRules.ipset(set_name)
            nflog = FirewallRules.nflog_action(
                group=FirewallRules.nflog_group(ip_version),
                prefix=UnifiedAccessLogger.WHITELIST_COUNTRY,
            )
            log_rules.append(FirewallRules.compose_rule(match, action=nflog))
        accept_rule = FirewallRules.ipset_rule(set_name, FirewallRules.ACCEPT)
        return log_rules + [accept_rule]
class SingleIpSetCountry(IPSetAtomicRestoreBase):
    """The ipset holding one country's subnets for a given IP version.

    The set name is built from ``IP_SET_PREFIX``, the IP version and the
    lower-cased country code, unless ``custom_ipset_name`` is set. That
    attribute is not assigned in this class; presumably it is provided by
    ``IPSetAtomicRestoreBase``.
    """

    _NAME = "{prefix}.{ip_version}.country-{country_code}"
    # Passed as ``maxelem`` in the options used to create the set.
    MAX_ELEM = 524288

    def __init__(self, country_code: str):
        super().__init__(country_code)
        self.country_code = country_code

    def gen_ipset_name_for_ip_version(self, ip_version: IPVersion) -> str:
        """Return the ipset name for this country and *ip_version*."""
        return self.custom_ipset_name or self._NAME.format(
            prefix=IP_SET_PREFIX,
            ip_version=ip_version,
            country_code=self.country_code.lower(),
        )

    def gen_ipset_create_ops(self, ip_version: IPVersion) -> List[str]:
        """Return the ipset command(s) that create this set."""
        ipset_options = self._get_ipset_create_options(ip_version)
        return [
            IPSetCmdBuilder.get_create_cmd(
                self.gen_ipset_name_for_ip_version(ip_version), **ipset_options
            )
        ]

    def gen_ipset_destroy_ops(self, ip_version: IPVersion) -> List[str]:
        """Return the ipset command(s) that destroy this set."""
        ipset_name = self.gen_ipset_name_for_ip_version(ip_version)
        return [IPSetCmdBuilder.get_destroy_cmd(ipset_name)]

    def gen_ipset_flush_ops(self, ip_version: IPVersion) -> List[str]:
        """Return the ipset command(s) that flush this set."""
        return [
            IPSetCmdBuilder.get_flush_cmd(
                self.gen_ipset_name_for_ip_version(ip_version)
            )
        ]

    async def gen_ipset_restore_ops(self, ip_version: IPVersion) -> List[str]:
        """Return ``ipset restore`` lines that add this country's subnets.

        Only entries whose IP version equals *ip_version* are included.
        Entries that ``IP.type_of`` rejects are logged and skipped.
        ``-exist`` is appended to each ``add`` line.

        :return: list of ipset commands to use with ipset restore
        """
        # Only entries of ip_version are kept, so every line targets the
        # same set: compute its name once instead of once per subnet.
        set_name = self.gen_ipset_name_for_ip_version(ip_version=ip_version)
        add_template = "add {set_name} {ip_net} -exist"
        commands = []
        for ip in ips_for_country(self.country_code):
            try:
                version = IP.type_of(ip)
            except ValueError:
                logger.error(
                    "%s is neither IPv4 nor IPv6 valid address", ip
                )
                continue
            if version != ip_version:
                continue
            commands.append(add_template.format(set_name=set_name, ip_net=ip))
        return commands

    def _get_ipset_create_options(self, ip_version: IPVersion):
        """Return keyword options (``family``, ``maxelem``) for set creation."""
        return dict(
            family=get_ipset_family(ip_version),
            maxelem=self.MAX_ELEM,
        )
class IPSetCountry(AbstractIPSet):
    """Country black/white lists implemented as one ipset per country.

    Countries and the list each belongs to come from ``CountryList``. Every
    country uses a :class:`SingleIpSetCountry` per IP version. The list name
    selects the helper in ``_IP_SETS``, which supplies the chain, the
    priority and the per-set firewall rules.
    """

    _LISTNAME = _CHAIN = _PRIORITY = None
    # CountryList list name -> helper providing CHAIN, PRIORITY and
    # single_entry_rules() for the sets on that list.
    _IP_SETS = {
        CountryList.BLACK: IPSetCountryBlack(),
        CountryList.WHITE: IPSetCountryWhite(),
    }

    async def block(self, country_code, *args, **kwargs):
        """Create the country's ip sets and their rules, then fill the sets.

        For every enabled IP version, the set is created and its rules are
        appended to the chain of the list named by ``kwargs["listname"]``.
        The sets are then populated with one ``ipset restore`` call.

        :param country_code: ISO 3166-1 alpha-2 code
        :param kwargs: must contain ``listname`` (a ``CountryList`` list name)
        """
        ipset = self._IP_SETS[kwargs["listname"]]
        ip_set = SingleIpSetCountry(country_code)
        commands = []  # type: List[str]
        for ip_version in ip_versions.enabled():
            async with await get_firewall(ip_version) as fw:
                set_name = ip_set.gen_ipset_name_for_ip_version(ip_version)
                await libipset.create_hash_set(
                    set_name, **ip_set._get_ipset_create_options(ip_version)
                )
                await fw.commit(
                    [
                        fw.append_rule(rule, chain=ipset.CHAIN)
                        for rule in ipset.single_entry_rules(
                            set_name, ip_version
                        )
                    ]
                )
                commands.extend(await ip_set.gen_ipset_restore_ops(ip_version))
        await libipset.restore(commands)

    async def unblock(self, country_code, *args, **kwargs):
        """Drop the country's firewall rules, then delete its ip sets.

        :param country_code: ISO 3166-1 alpha-2 code
        :param kwargs: must contain ``listname`` (a ``CountryList`` list name)
        """
        ipset = self._IP_SETS[kwargs["listname"]]
        ip_set = SingleIpSetCountry(country_code)
        for ip_version in ip_versions.enabled():
            async with await get_firewall(ip_version) as fw:
                set_name = ip_set.gen_ipset_name_for_ip_version(ip_version)
                # Remove the rules first: ipset refuses to destroy a set
                # that is still referenced by iptables rules.
                await fw.commit(
                    [
                        fw.delete_rule(
                            rule, chain=ipset.CHAIN, ip_version=ip_version
                        )
                        for rule in ipset.single_entry_rules(
                            set_name, ip_version
                        )
                    ]
                )
                await libipset.delete_set(set_name)

    def gen_ipset_create_ops(self, ip_version: IPVersion) -> List[str]:
        """
        Generate list of commands to create all ip sets
        :return: list of ipset commands to use with ipset restore
        """
        result = []
        for ip_set in self.get_all_ipset_instances(ip_version):
            result.extend(ip_set.gen_ipset_create_ops(ip_version))
        return result

    async def gen_ipset_restore_ops(self, ip_version: IPVersion) -> List[str]:
        """
        Generate list of commands to fill all ip sets.

        Each set is flushed before its add commands, so the restore replaces
        the set's contents rather than merging into them.
        :return: list of ipset commands to use with ipset restore
        """
        commands = []  # type: List[str]
        for ipset in self.get_all_ipset_instances(ip_version):
            commands.append(
                IPSetCmdBuilder.get_flush_cmd(
                    ipset.gen_ipset_name_for_ip_version(ip_version)
                )
            )
            commands.extend(await ipset.gen_ipset_restore_ops(ip_version))
        return commands

    def _fetch(self):
        """Return ``(country_code, listname)`` pairs stored in CountryList."""
        return [
            (row["country"]["code"], row["listname"])
            for row in CountryList.fetch()
        ]

    def get_all_ipsets(self, ip_version: IPVersion) -> FrozenSet[str]:
        """Return the names of all country ip sets for *ip_version*."""
        return frozenset(
            ipset.gen_ipset_name_for_ip_version(ip_version)
            for ipset in self.get_all_ipset_instances(ip_version)
        )

    def get_all_ipset_instances(
        self, ip_version: IPVersion
    ) -> List[IPSetAtomicRestoreBase]:
        """Return one SingleIpSetCountry per stored country.

        *ip_version* is not used for the selection: the same countries are
        returned for every IP version.
        """
        return [
            SingleIpSetCountry(country_code)
            for country_code, _ in self._fetch()
        ]

    def get_rules(self, ip_version: IPVersion, **kwargs) -> Iterable[dict]:
        """Return the firewall rule descriptions for country lists.

        The result has two parts. First, a jump from
        ``IMUNIFY_INPUT_CHAIN`` into each list's chain, using that list's
        PRIORITY. Second, the per-country set rules inside those chains.
        Their priority is a running counter over all emitted per-country
        rules, in the order returned by ``_fetch()``.
        """
        result = [
            dict(
                rule=FirewallRules.compose_action(ipset.CHAIN),
                chain=FirewallRules.IMUNIFY_INPUT_CHAIN,
                table=FirewallRules.FILTER,
                priority=ipset.PRIORITY,
            )
            for ipset in self._IP_SETS.values()
        ]
        priority = 0
        for country, listname in self._fetch():
            list_ipset = self._IP_SETS[listname]
            set_name = SingleIpSetCountry(
                country
            ).gen_ipset_name_for_ip_version(ip_version)
            for rule in list_ipset.single_entry_rules(set_name, ip_version):
                result.append(
                    dict(
                        rule=rule,
                        chain=list_ipset.CHAIN,
                        table=FirewallRules.FILTER,
                        priority=priority,
                    )
                )
                priority += 1
        return result

    async def restore(self, ip_version: IPVersion) -> None:
        """Flush and refill every country ip set of *ip_version*."""
        with timeit("ipset_restore", logger):
            await libipset.restore(
                await self.gen_ipset_restore_ops(ip_version)
            )

    @staticmethod
    def _count_ips_for_country(country_code, ip_version: IPVersion) -> int:
        """Count the valid *ip_version* entries in a country's subnets file.

        Entries rejected by ``IP.type_of`` are skipped, not counted. This
        mirrors ``SingleIpSetCountry.gen_ipset_restore_ops``, which logs and
        skips such entries instead of adding them to the set.
        """
        count = 0
        for ip in ips_for_country(country_code):
            try:
                version = IP.type_of(ip)
            except ValueError:
                continue
            if version == ip_version:
                count += 1
        return count

    async def get_ipsets_count(
        self, ip_version: IPVersion
    ) -> List[IPSetCount]:
        """Compare expected and actual element counts of country ip sets.

        For each stored country, ``db_count`` is the number of valid
        *ip_version* entries in its subnets file. ``ipset_count`` is what
        ``libipset.get_ipset_count`` reports for the live set.
        """
        ipsets = []
        for country, _ in self._fetch():
            # One malformed line must not abort the whole report.
            expected_count = self._count_ips_for_country(country, ip_version)
            ip_set = SingleIpSetCountry(country)
            set_name = ip_set.gen_ipset_name_for_ip_version(ip_version)
            ipset_count = await libipset.get_ipset_count(set_name)
            ipsets.append(
                IPSetCount(
                    name=set_name,
                    db_count=expected_count,
                    ipset_count=ipset_count,
                )
            )
        return ipsets