import ipaddress
import os
import re

import testinfra.utils.ansible_runner

import pytest


# Target hosts for the tests in this module: all hosts from the inventory file
# named by the MOLECULE_INVENTORY_FILE environment variable whose names match
# the 'parameters-*' pattern. A missing environment variable raises KeyError
# at import time.
testinfra_hosts = testinfra.utils.ansible_runner.AnsibleRunner(
    os.environ['MOLECULE_INVENTORY_FILE']).get_hosts('parameters-*')


def test_pam_umask(host):
    """
    Verifies that the pam-configs profile for the umask module is deployed
    with root ownership and mode 0644, and that both common session PAM
    files contain a required pam_umask.so session entry.
    """

    umask_profile = host.file('/usr/share/pam-configs/umask')

    assert umask_profile.exists
    assert umask_profile.user == 'root'
    assert umask_profile.group == 'root'
    assert umask_profile.mode == 0o644

    # Same pattern is expected in both the interactive and the
    # non-interactive session configuration.
    required_umask_entry = r'session[[:blank:]]\+required[[:blank:]]\+pam_umask.so'
    pam_session_files = (
        '/etc/pam.d/common-session',
        '/etc/pam.d/common-session-noninteractive',
    )

    for pam_session_file in pam_session_files:
        assert host.file(pam_session_file).contains(required_umask_entry)


def test_login_umask(host):
    """
    Verifies that /etc/login.defs contains an UMASK entry with value 027.
    """

    login_defs = host.file('/etc/login.defs')

    assert login_defs.contains(r'UMASK[[:blank:]]\+027')


def test_adduser_umask(host):
    """
    Verifies that /etc/adduser.conf contains the DIR_MODE=0750 setting
    (mode used for newly created user home directories).
    """

    adduser_conf = host.file('/etc/adduser.conf')

    assert adduser_conf.contains('DIR_MODE=0750')


def test_bash_prompt(host):
    """
    Verifies that the custom bash prompt script exists under
    /etc/profile.d, is owned by root:root, and has mode 0644.
    """

    prompt_script = host.file('/etc/profile.d/bash_prompt.sh')

    assert prompt_script.exists
    assert prompt_script.user == 'root'
    assert prompt_script.group == 'root'
    assert prompt_script.mode == 0o644


def test_home_profile_d(host):
    """
    Verifies deployment of the profile script that enables profile.d-like
    capability in user's home directory: it must be a regular file owned by
    root:root with mode 0644.
    """

    user_profile_d_script = host.file('/etc/profile.d/z99-user_profile_d.sh')

    assert user_profile_d_script.is_file
    assert user_profile_d_script.user == 'root'
    assert user_profile_d_script.group == 'root'
    assert user_profile_d_script.mode == 0o644


def test_home_skeleton_bashrc(host):
    """
    Verifies that the skeleton .bashrc is a regular file owned by root:root
    with mode 0644, and that its content matches the expected SHA-256
    checksum.
    """

    expected_sha256 = '4f946fb387a413c8d7633787d8e8a7785c256d77f7c6a692822ffdb439c78277'
    skeleton_bashrc = host.file('/etc/skel/.bashrc')

    assert skeleton_bashrc.is_file
    assert skeleton_bashrc.user == 'root'
    assert skeleton_bashrc.group == 'root'
    assert skeleton_bashrc.mode == 0o644
    assert skeleton_bashrc.sha256sum == expected_sha256


def test_default_bashrc(host):
    """
    Verifies that the system-wide /etc/bash.bashrc is a regular file owned
    by root:root with mode 0644.
    """

    system_bashrc = host.file('/etc/bash.bashrc')

    assert system_bashrc.is_file
    assert system_bashrc.user == 'root'
    assert system_bashrc.group == 'root'
    assert system_bashrc.mode == 0o644


def test_root_bashrc(host):
    """
    Verifies that root's .bashrc is a regular file owned by root:root with
    mode 0640, and that its content matches the expected SHA-256 checksum
    (the same checksum expected for the skeleton .bashrc).
    """

    expected_sha256 = '4f946fb387a413c8d7633787d8e8a7785c256d77f7c6a692822ffdb439c78277'

    # File is inspected with elevated privileges.
    with host.sudo():
        root_bashrc = host.file('/root/.bashrc')

        assert root_bashrc.is_file
        assert root_bashrc.user == 'root'
        assert root_bashrc.group == 'root'
        assert root_bashrc.mode == 0o640
        assert root_bashrc.sha256sum == expected_sha256


def test_installed_packages(host):
    """
    Verifies that every package required by the role is installed.
    """

    required_packages = (
        'sudo',
        'ssl-cert',
        'ferm',
        'apticron',
        'python3-setuptools',
        'virtualenv',
    )

    for package_name in required_packages:
        assert host.package(package_name).is_installed


def test_root_remote_login_disabled(host):
    """
    Tests if SSH server has been configured to prevent remote root logins.

    The directive must appear as an active (uncommented) setting at the
    start of a line in /etc/ssh/sshd_config. A plain substring check would
    also be satisfied by a commented-out '#PermitRootLogin no' line.
    """

    sshd_config = host.file('/etc/ssh/sshd_config').content_string

    # Anchor at start of line (optional indentation allowed), and require
    # that the value is exactly 'no' (not followed by further non-blank
    # characters).
    active_directive = re.compile(r'^[ \t]*PermitRootLogin[ \t]+no(?!\S)', re.MULTILINE)

    assert active_directive.search(sshd_config) is not None


def test_remote_login_via_password_disabled(host):
    """
    Tests if SSH server has been configured to disable password-based
    authentication.

    The directive must appear as an active (uncommented) setting at the
    start of a line in /etc/ssh/sshd_config. A plain substring check would
    also be satisfied by a commented-out '#PasswordAuthentication no' line.
    """

    sshd_config = host.file('/etc/ssh/sshd_config').content_string

    # Anchor at start of line (optional indentation allowed), and require
    # that the value is exactly 'no' (not followed by further non-blank
    # characters).
    active_directive = re.compile(r'^[ \t]*PasswordAuthentication[ \t]+no(?!\S)', re.MULTILINE)

    assert active_directive.search(sshd_config) is not None


def test_ferm_service_configuration(host):
    """
    Verifies that /etc/default/ferm is a regular file owned by root:root
    with mode 0644, and that it contains the expected FAST, CACHE, and
    ENABLED settings.
    """

    ferm_defaults = host.file('/etc/default/ferm')

    assert ferm_defaults.is_file
    assert ferm_defaults.user == 'root'
    assert ferm_defaults.group == 'root'
    assert ferm_defaults.mode == 0o644

    for expected_setting in ('FAST=yes', 'CACHE=no', 'ENABLED="yes"'):
        assert expected_setting in ferm_defaults.content_string


def test_ferm_configuration_directory(host):
    """
    Verifies that /etc/ferm/conf.d is a directory owned by root:root with
    mode 0750.
    """

    # Directory is inspected with elevated privileges.
    with host.sudo():
        conf_d = host.file('/etc/ferm/conf.d')

        assert conf_d.is_directory
        assert conf_d.user == 'root'
        assert conf_d.group == 'root'
        assert conf_d.mode == 0o750


def test_ferm_configuration(host):
    """
    Verifies deployment of the main ferm configuration file (which must
    include the conf.d directory) and of the base configuration file within
    conf.d. Both must be regular files owned by root:root with mode 0640.
    """

    with host.sudo():

        # Main configuration file.
        main_config = host.file('/etc/ferm/ferm.conf')
        assert main_config.is_file
        assert main_config.user == 'root'
        assert main_config.group == 'root'
        assert main_config.mode == 0o640
        assert "@include '/etc/ferm/conf.d/';" in main_config.content_string

        # Base rules dropped into the included directory.
        base_config = host.file('/etc/ferm/conf.d/00-base.conf')
        assert base_config.is_file
        assert base_config.user == 'root'
        assert base_config.group == 'root'
        assert base_config.mode == 0o640


def test_ferm_service(host):
    """
    Verifies that the ferm service is both running and enabled.
    """

    ferm_service = host.service('ferm')

    assert ferm_service.is_running
    assert ferm_service.is_enabled


def test_check_certificate_script(host):
    """
    Verifies that /usr/local/bin/check_certificate.sh is a regular file
    owned by root:root with mode 0755.
    """

    script = host.file('/usr/local/bin/check_certificate.sh')

    assert script.is_file
    assert script.user == 'root'
    assert script.group == 'root'
    assert script.mode == 0o755


def test_check_certificate_directory(host):
    """
    Verifies that /etc/check_certificate is a directory owned by root:root
    with mode 0755.
    """

    directory = host.file('/etc/check_certificate')

    assert directory.is_directory
    assert directory.user == 'root'
    assert directory.group == 'root'
    assert directory.mode == 0o755


def test_check_certificate_crontab(host):
    """
    Verifies that the cron.d entry for certificate checks is a regular file
    owned by root:root with mode 0644, and that it contains the expected
    job line running check_certificate.sh as user nobody.
    """

    expected_job = "0 0 * * * nobody /usr/local/bin/check_certificate.sh -q expiration"
    cron_file = host.file('/etc/cron.d/check_certificate')

    assert cron_file.is_file
    assert cron_file.user == 'root'
    assert cron_file.group == 'root'
    assert cron_file.mode == 0o644
    assert expected_job in cron_file.content_string


def test_pipreqcheck_virtualenv(host):
    """
    Verifies that the Python virtual environment used for pip requirements
    upgrade checks has been created, by inspecting its activate script: it
    must be a regular file owned by pipreqcheck:pipreqcheck with mode 0640.
    """

    with host.sudo():
        activate_script = host.file('/var/lib/pipreqcheck/virtualenv/bin/activate')

        assert activate_script.is_file
        assert activate_script.user == 'pipreqcheck'
        assert activate_script.group == 'pipreqcheck'
        assert activate_script.mode == 0o640


def test_pipreqcheck_virtualenv_prompt(host):
    """
    Verifies that activating the pipreqcheck virtual environment (as the
    pipreqcheck user) sets PS1 to the expected prompt prefix.
    """

    with host.sudo("pipreqcheck"):
        result = host.run('bash -c "source /var/lib/pipreqcheck/virtualenv/bin/activate; printenv PS1"')

        # Drop at most one trailing newline (added by the command output),
        # while preserving the trailing space that is part of the prompt.
        ps1 = result.stdout[:-1] if result.stdout.endswith("\n") else result.stdout

        assert ps1 == "(pipreqcheck) "


def test_pipreqcheck_directories(host):
    """
    Verifies that the pip requirements upgrade check configuration
    directory, and its pipreqcheck sub-directory, are directories owned by
    root:pipreqcheck with mode 0750.
    """

    base_directory = '/etc/pip_check_requirements_upgrades'
    expected_directories = (
        base_directory,
        os.path.join(base_directory, 'pipreqcheck'),
    )

    with host.sudo():
        for directory_path in expected_directories:
            directory = host.file(directory_path)

            assert directory.is_directory
            assert directory.user == 'root'
            assert directory.group == 'pipreqcheck'
            assert directory.mode == 0o750


def test_pipreqcheck_requirements(host):
    """
    Tests deployment of requirements input and text file used for virtual
    environment utilised by script that perform pip requirements upgrade checks.

    Both files must be regular files owned by root:pipreqcheck with mode
    0640.
    """

    requirements_paths = (
        '/etc/pip_check_requirements_upgrades/pipreqcheck/requirements.in',
        '/etc/pip_check_requirements_upgrades/pipreqcheck/requirements.txt',
    )

    with host.sudo():
        # Identical checks are applied to both files, so that no assertion
        # can be accidentally left out for one of them (the is_file check
        # for requirements.txt previously lacked the assert keyword and was
        # never enforced).
        for requirements_path in requirements_paths:
            requirements = host.file(requirements_path)

            assert requirements.is_file
            assert requirements.user == 'root'
            assert requirements.group == 'pipreqcheck'
            assert requirements.mode == 0o640


def test_pipreqcheck_virtualenv_packages(host):
    """
    Tests if correct packages are installed in virtualenv used for pip
    requirements checks.

    Output of 'pip freeze --all' (run as the pipreqcheck user) is compared
    against the pinned list of expected packages, ignoring case and order.
    """

    pip_path = '/var/lib/pipreqcheck/virtualenv/bin/pip'

    expected_packages = [
        "build==1.0.3",
        "click==8.1.7",
        "importlib-metadata==6.7.0",
        "packaging==23.2",
        "pip-tools==6.14.0",
        "pip==23.1.2",
        "pyproject_hooks==1.0.0",
        "setuptools==68.0.0",
        "tomli==2.0.1",
        "typing_extensions==4.7.1",
        "wheel==0.41.3",
        "zipp==3.15.0",
    ]

    packages = host.run("sudo -u pipreqcheck %s freeze --all", pip_path)

    # Fail early with a clear reason if pip itself could not be run,
    # instead of reporting a confusing package list mismatch.
    assert packages.rc == 0

    # Normalise package names and order.
    expected_packages = sorted([p.lower() for p in expected_packages])
    actual_packages = sorted(packages.stdout.lower().strip().split("\n"))

    # This is a dummy distro-provided package ignored by the
    # pip-tools. Two variants of this name have been previously known.
    for dummy_package in ("pkg-resources==0.0.0", "pkg_resources==0.0.0"):
        if dummy_package in actual_packages:
            actual_packages.remove(dummy_package)

    assert actual_packages == expected_packages


def test_pipreqcheck_script(host):
    """
    Verifies that the pip requirements upgrade check script is a regular
    file owned by root:root with mode 0755.
    """

    script = host.file('/usr/local/bin/pip_check_requirements_upgrades.sh')

    assert script.is_file
    assert script.user == 'root'
    assert script.group == 'root'
    assert script.mode == 0o755


def test_pipreqcheck_crontab(host):
    """
    Verifies that the cron.d entry for pip requirements upgrade checks is a
    regular file owned by root:root with mode 0644, that it sets
    MAILTO=root, and that the virtual environment path appears in it as a
    separate space-delimited token.
    """

    cron_file = host.file('/etc/cron.d/check_pip_requirements')

    assert cron_file.is_file
    assert cron_file.user == 'root'
    assert cron_file.group == 'root'
    assert cron_file.mode == 0o644
    assert "MAILTO=root" in cron_file.content_string

    space_delimited_tokens = cron_file.content_string.split(" ")
    assert '/var/lib/pipreqcheck/virtualenv' in space_delimited_tokens


def test_pipreqcheck_virtualenv_python_version(host):
    """
    Verifies that the Python interpreter within the pipreqcheck virtual
    environment runs successfully (as the pipreqcheck user) and reports
    major version 3.
    """

    interpreter = '/var/lib/pipreqcheck/virtualenv/bin/python'
    print_major_version = "import sys; print(sys.version_info.major)"

    with host.sudo('pipreqcheck'):
        result = host.run("%s -c %s", interpreter, print_major_version)

    assert result.rc == 0
    assert result.stdout.strip() == '3'


def test_pipreqcheck_script_output(host):
    """
    Runs pip_check_requirements_upgrades.sh (as the pipreqcheck user)
    against a test configuration directory, and verifies that it produces
    no errors, the expected number of output lines, the warning about
    available upgrades, and the expected package diff.
    """

    virtualenv = '/var/lib/pipreqcheck/virtualenv'
    config_directory = '/tmp/pip_check_requirements_upgrades'

    with host.sudo("pipreqcheck"):
        report = host.run("/usr/local/bin/pip_check_requirements_upgrades.sh -q -V %s %s", virtualenv, config_directory)

    # An SSH warning may show up at the very beginning of stderr - it is
    # not produced by the script itself, so strip it out before checking.
    stderr = re.sub("^Warning: Permanently added.*?\r\n", "", report.stderr)
    assert stderr == ""

    stdout_lines = report.stdout.split("\n")
    assert len(stdout_lines) == 9

    expected_warning_message = "[WARN]  Upgrades available for: %s/with_updates/requirements.txt" % config_directory
    assert expected_warning_message in report.stdout

    expected_package_diff = "@@ -1 +1 @@\n-urllib3==1.24.2\n+urllib3==1.24.3"
    assert expected_package_diff in report.stdout


@pytest.mark.parametrize('default_path', [
    '/usr/sbin/ferm',
])
def test_dpkg_diversions(host, default_path):
    """
    Verifies that dpkg-divert lists a diversion for the given path, and
    that both the path itself and its '.original' counterpart exist.
    """

    diverted_path = default_path + '.original'

    listing = host.run('dpkg-divert --list %s', default_path)

    assert listing.rc == 0
    assert default_path in listing.stdout

    assert host.file(default_path).exists
    assert host.file(diverted_path).exists


@pytest.mark.parametrize('path,owner,group,mode,checksum', [
    ('/usr/sbin/ferm', 'root', 'root', 0o755, "13765317d7068005dac18757abe03762f79b6285ce7d078d33826d53801ee6b3"),
])
def test_file_overrides(host, path, owner, group, mode, checksum):
    """
    Verifies that a file overriding a package-provided default is a regular
    file with the expected owner, group, mode, and SHA-256 checksum.
    """

    with host.sudo():
        overridden = host.file(path)

        assert overridden.is_file
        assert overridden.user == owner
        assert overridden.group == group
        assert overridden.mode == mode
        assert overridden.sha256sum == checksum


@pytest.mark.parametrize('iptables_family', [
    'ip',
    'ip6',
])
def test_legacy_iptables_not_present(host, iptables_family):
    """
    Tests that no legacy iptables tables are present (they should not be
    if ferm binary was patched/replaced).

    The check relies on the warning that the (ip|ip6)tables-save command
    prints to stderr when legacy tables are detected.
    """

    iptables_save = host.run("sudo /usr/sbin/%stables-save", iptables_family)
    warning_message = "Warning: %stables-legacy tables present" % iptables_family

    # Absence of the warning is only meaningful if the command actually
    # ran - otherwise the test would pass without checking anything.
    assert iptables_save.rc == 0
    assert warning_message not in iptables_save.stderr


def test_legacy_iptables_removal_script(host):
    """
    Verifies that the script for dropping legacy iptables rules is a
    regular file owned by root:root with mode 0755.
    """

    removal_script = host.file("/usr/local/sbin/drop_legacy_iptables_rules.sh")

    assert removal_script.is_file
    assert removal_script.user == "root"
    assert removal_script.group == "root"
    assert removal_script.mode == 0o755


@pytest.mark.parametrize('ip_protocol', [4, 6])
def test_tcp_rate_limit_dropped_packet_message_tagging(host, ip_protocol):
    """
    Verifies that incoming TCP packets dropped due to rate-limiting get
    logged by the kernel with the RATELIMIT tag.

    A client host opens a series of connections towards port 22 of the
    tested host, after which the kernel log of the tested host is searched
    for a matching message.
    """

    runner = testinfra.utils.ansible_runner.AnsibleRunner(os.environ['MOLECULE_INVENTORY_FILE'])
    client = runner.get_host(runner.get_hosts('client-allowed')[0])

    server_name = host.run("hostname").stdout.strip()

    # Taken before any traffic is generated, and used further down to limit
    # the log search to entries produced during this test.
    started_at = host.run("date '+%b %d %H:%M:%S'").stdout.strip()

    # Address of the tested host as resolved by the client - first field
    # of the last line in getent output.
    resolved = client.run("getent ahostsv%s %s", str(ip_protocol), server_name)
    last_resolved_line = resolved.stdout.strip().split("\n")[-1]
    server_address = ipaddress.ip_address(last_resolved_line.split()[0]).exploded

    pattern = r"%s kernel: RATELIMIT .* DST=%s .* PROTO=TCP .* DPT=22" % (re.escape(server_name), re.escape(server_address))
    expected_message = re.compile(pattern)

    with host.sudo():

        # This should trigger the firewall rules and produce log entries.
        client.run("for i in $(seq 12); do nc.openbsd -%s -w 1 -z %s 22; done", str(ip_protocol), server_name)

        kernel_log = host.run("journalctl --dmesg --since %s", started_at)

        assert kernel_log.rc == 0
        assert expected_message.search(kernel_log.stdout) is not None
