File diff 9b889983cd7b → 49f356923784
conntrackt/tests/test_utils.py
Show inline comments
 
# Django imports.
 
from django.test import TestCase
 

	
 
# Third-party Python library imports.
 
import palette
 
import pydot
 

	
 
# Application imports.
 
from conntrackt.models import Entity
 
from conntrackt.models import Entity, Project, Communication
 
from conntrackt import utils
 

	
 

	
 
@@ -47,3 +51,204 @@ COMMIT
 
COMMIT
 
"""
 
        self.assertEqual(generated, expected)
 

	
 

	
 
class GetDistinctColorsTest(TestCase):
 
    """
 
    Tests covering the get_distinct_colors function.
 
    """
 

	
 
    def test_count(self):
 
        """
 
        Tests if correct number of distinct colours are returned.
 
        """
 

	
 
        colors = utils.get_distinct_colors(13)
 

	
 
        self.assertEqual(len(colors), 13)
 

	
 
        colors = utils.get_distinct_colors(123)
 

	
 
        self.assertEqual(len(colors), 123)
 

	
 
    def test_start(self):
 
        """
 
        Tests if the passed start colour is returned as part of generated
 
        colours.
 
        """
 

	
 
        start = palette.Color("#AA3311")
 

	
 
        colors = utils.get_distinct_colors(10, start)
 

	
 
        self.assertEqual(start.hex, colors[0].hex)
 

	
 
    def test_color_distance(self):
 
        """
 
        Tests if the generated colous all have proper distance between
 
        each-other.
 
        """
 

	
 
        colors = utils.get_distinct_colors(13)
 

	
 
        # Set allowed margin of difference to 0.1%
 
        delta = (1 / 13.) * 0.001
 

	
 
        # Calculate diffs between colours.
 
        diffs = [colors[i + 1].hsl["h"] - colors[i].hsl["h"] for i in range(12)]
 

	
 
        # Take first diff as reference point.
 
        reference = diffs[0]
 

	
 
        # Create list that contains True/False for diffs depending on whether
 
        # they're in delta-surrounding of reference point.
 
        equal = [(abs(diff - reference) < delta) for diff in diffs]
 

	
 
        # There should be 12 True values.
 
        self.assertEqual(equal.count(True), 12)
 

	
 
        # Check the difference between first and last colour.
 
        equal = abs(colors[0].hsl["h"] + 1 - colors[12].hsl["h"] - reference) < delta
 
        self.assertEqual(True, equal)
 

	
 

	
 
class GenerateProjectDiagramTest(TestCase):
 
    """
 
    Tests the generate_project_diagram function.
 
    """
 

	
 
    fixtures = ["test-data.json"]
 

	
 
    def test_unique_entity_colors(self):
 
        """
 
        Tests if each node/entity in the graph will have a unique colour.
 
        """
 

	
 
        # Get diagram for project
 
        project = Project.objects.get(pk=1)
 
        diagram = utils.generate_project_diagram(project)
 

	
 
        # Extract all nodes
 
        clusters = diagram.get_subgraphs()
 
        nodes = []
 
        for cluster in clusters:
 
            nodes.extend(cluster.get_nodes())
 

	
 
        # Get the node colours.
 
        colors = [n.get_color() for n in nodes]
 

	
 
        # Verify they're all unique colours.
 
        self.assertEqual(len(colors), len(set(colors)))
 

	
 
    def test_edge_colours(self):
 
        """
 
        Tests if the edge colours match with source node/entity colour.
 
        """
 

	
 
        # Get diagram for project
 
        project = Project.objects.get(pk=1)
 
        diagram = utils.generate_project_diagram(project)
 

	
 
        # Extract all nodes and edges.
 
        clusters = diagram.get_subgraphs()
 
        nodes = {}
 
        for cluster in clusters:
 
            for node in cluster.get_nodes():
 
                nodes[node.get_name()] = node
 
        edges = diagram.get_edges()
 

	
 
        # Validate that edges have same colour as the source nodes.
 
        for edge in edges:
 
            self.assertEqual(nodes[edge.get_source()].get_color(), edge.get_color())
 

	
 
    def test_entities_present(self):
 
        """
 
        Tests if all (and only) specific project entities are in the graph.
 
        """
 

	
 
        # Get diagram for project
 
        project = Project.objects.get(pk=1)
 
        diagram = utils.generate_project_diagram(project)
 

	
 
        # Set-up expected node names.
 
        expected_node_names = ["Test Entity 1", "Test Entity 2", "Test Entity 3", "Test Subnet"]
 

	
 
        # Get all nodes from diagram.
 
        clusters = diagram.get_subgraphs()
 
        nodes = []
 
        for cluster in clusters:
 
            nodes.extend(cluster.get_nodes())
 

	
 
        # Get the node names, strip the quotes from them.
 
        node_names = [n.get_name().replace('"', '') for n in nodes]
 

	
 
        # Validate that the two lists contain same elements.
 
        self.assertEqual(sorted(expected_node_names), sorted(node_names))
 

	
 
    def test_communications_present(self):
 
        """
 
        Tests if all (and only) specific project communications are in the
 
        graph.
 
        """
 

	
 
        # Get diagram for project
 
        project = Project.objects.get(pk=1)
 
        diagram = utils.generate_project_diagram(project)
 

	
 
        # Get all edges from the diagram.
 
        edges = diagram.get_edges()
 

	
 
        # Create list of edge labels.
 
        edge_labels = ["%s -> %s (%s)" % (e.get_source().replace('"', ''),
 
                                           e.get_destination().replace('"', ''),
 
                                           e.get_label().replace('"', '')) for e in edges]
 

	
 
        # Create list of expected edge labels
 
        expected_edge_labels = ['Test Entity 1 -> Test Entity 2 (UDP:123)', 'Test Entity 1 -> Test Entity 3 (UDP:53)',
 
                          'Test Entity 2 -> Test Entity 1 (ICMP:8)', 'Test Entity 2 -> Test Entity 1 (TCP:22)',
 
                          'Test Entity 3 -> Test Entity 1 (TCP:3306)', 'Test Subnet -> Test Entity 1 (TCP:22)']
 

	
 
        self.assertEqual(sorted(expected_edge_labels), sorted(edge_labels))
 

	
 
    def test_locations_present(self):
 
        """
 
        Tests if all (and only) specific project locations are in the graph (as
 
        clusters).
 
        """
 

	
 
        # Get diagram for project.
 
        project = Project.objects.get(pk=1)
 
        diagram = utils.generate_project_diagram(project)
 

	
 
        # Set-up expected cluster names (based on locations).
 
        expected_cluster_names = ["cluster_test_location_1", "cluster_test_location_2"]
 

	
 
        # Get cluster names.
 
        cluster_names = [s.get_name() for s in diagram.get_subgraphs()]
 

	
 
        self.assertEqual(sorted(expected_cluster_names), sorted(cluster_names))
 

	
 
    def test_return_type(self):
 
        """
 
        Tests if a correct object type is returned.
 
        """
 

	
 
        # Get diagram for project.
 
        project = Project.objects.get(pk=1)
 
        diagram = utils.generate_project_diagram(project)
 

	
 
        self.assertEqual(type(diagram), pydot.Dot)
 

	
 
    def test_graph_properties(self):
 
        """
 
        Tests if graph properties have been set-up properly.
 
        """
 

	
 
        # Get diagram for project.
 
        project = Project.objects.get(pk=1)
 
        diagram = utils.generate_project_diagram(project)
 

	
 
        self.assertEqual("digraph", diagram.get_graph_type())
 
        self.assertEqual("transparent", diagram.get_bgcolor())
 
        self.assertEqual("1.5", diagram.get_nodesep())
 
        self.assertEqual([{"shape": "record"}], diagram.get_node_defaults())