diff --git a/.hgignore b/.hgignore --- a/.hgignore +++ b/.hgignore @@ -4,3 +4,4 @@ syntax: glob projtest/projtest.db projtest/south docs/_build +tmp/ diff --git a/conntrackt/templates/conntrackt/project_detail.html b/conntrackt/templates/conntrackt/project_detail.html --- a/conntrackt/templates/conntrackt/project_detail.html +++ b/conntrackt/templates/conntrackt/project_detail.html @@ -37,6 +37,12 @@ {% endfor %} +
+
+

Communications diagram

+ +
+
{% endif %} {% endblock %} diff --git a/conntrackt/tests/test_utils.py b/conntrackt/tests/test_utils.py --- a/conntrackt/tests/test_utils.py +++ b/conntrackt/tests/test_utils.py @@ -1,8 +1,12 @@ # Django imports. from django.test import TestCase +# Third-party Python library imports. +import palette +import pydot + # Application imports. -from conntrackt.models import Entity +from conntrackt.models import Entity, Project, Communication from conntrackt import utils @@ -47,3 +51,204 @@ COMMIT COMMIT """ self.assertEqual(generated, expected) + + +class GetDistinctColorsTest(TestCase): + """ + Tests covering the get_distinct_colors function. + """ + + def test_count(self): + """ + Tests if correct number of distinct colours are returned. + """ + + colors = utils.get_distinct_colors(13) + + self.assertEqual(len(colors), 13) + + colors = utils.get_distinct_colors(123) + + self.assertEqual(len(colors), 123) + + def test_start(self): + """ + Tests if the passed start colour is returned as part of generated + colours. + """ + + start = palette.Color("#AA3311") + + colors = utils.get_distinct_colors(10, start) + + self.assertEqual(start.hex, colors[0].hex) + + def test_color_distance(self): + """ + Tests if the generated colours all have proper distance between + each other. + """ + + colors = utils.get_distinct_colors(13) + + # Set allowed margin of difference to 0.1% of the expected hue step. + delta = (1 / 13.) * 0.001 + + # Calculate diffs between colours. + diffs = [colors[i + 1].hsl["h"] - colors[i].hsl["h"] for i in range(12)] + + # Take first diff as reference point. + reference = diffs[0] + + # Create list that contains True/False for diffs depending on whether + # they're within delta of the reference point. + equal = [(abs(diff - reference) < delta) for diff in diffs] + + # There should be 12 True values. + self.assertEqual(equal.count(True), 12) + + # Check the step from the last colour back to the first one (hue wraps around at 1).
+ equal = abs(colors[0].hsl["h"] + 1 - colors[12].hsl["h"] - reference) < delta + self.assertTrue(equal) + + +class GenerateProjectDiagramTest(TestCase): + """ + Tests the generate_project_diagram function. + """ + + fixtures = ["test-data.json"] + + def test_unique_entity_colors(self): + """ + Tests if each node/entity in the graph will have a unique colour. + """ + + # Get diagram for project + project = Project.objects.get(pk=1) + diagram = utils.generate_project_diagram(project) + + # Extract all nodes + clusters = diagram.get_subgraphs() + nodes = [] + for cluster in clusters: + nodes.extend(cluster.get_nodes()) + + # Get the node colours. + colors = [n.get_color() for n in nodes] + + # Verify they're all unique colours. + self.assertEqual(len(colors), len(set(colors))) + + def test_edge_colours(self): + """ + Tests if the edge colours match with source node/entity colour. + """ + + # Get diagram for project + project = Project.objects.get(pk=1) + diagram = utils.generate_project_diagram(project) + + # Extract all nodes and edges. + clusters = diagram.get_subgraphs() + nodes = {} + for cluster in clusters: + for node in cluster.get_nodes(): + nodes[node.get_name()] = node + edges = diagram.get_edges() + + # Validate that edges have same colour as the source nodes. + for edge in edges: + self.assertEqual(nodes[edge.get_source()].get_color(), edge.get_color()) + + def test_entities_present(self): + """ + Tests if all (and only) specific project entities are in the graph. + """ + + # Get diagram for project + project = Project.objects.get(pk=1) + diagram = utils.generate_project_diagram(project) + + # Set-up expected node names. + expected_node_names = ["Test Entity 1", "Test Entity 2", "Test Entity 3", "Test Subnet"] + + # Get all nodes from diagram. + clusters = diagram.get_subgraphs() + nodes = [] + for cluster in clusters: + nodes.extend(cluster.get_nodes()) + + # Get the node names, strip the quotes from them.
+ node_names = [n.get_name().replace('"', '') for n in nodes] + + # Validate that the two lists contain same elements. + self.assertEqual(sorted(expected_node_names), sorted(node_names)) + + def test_communications_present(self): + """ + Tests if all (and only) specific project communications are in the + graph. + """ + + # Get diagram for project + project = Project.objects.get(pk=1) + diagram = utils.generate_project_diagram(project) + + # Get all edges from the diagram. + edges = diagram.get_edges() + + # Create list of edge labels. + edge_labels = ["%s -> %s (%s)" % (e.get_source().replace('"', ''), + e.get_destination().replace('"', ''), + e.get_label().replace('"', '')) for e in edges] + + # Create list of expected edge labels + expected_edge_labels = ['Test Entity 1 -> Test Entity 2 (UDP:123)', 'Test Entity 1 -> Test Entity 3 (UDP:53)', + 'Test Entity 2 -> Test Entity 1 (ICMP:8)', 'Test Entity 2 -> Test Entity 1 (TCP:22)', + 'Test Entity 3 -> Test Entity 1 (TCP:3306)', 'Test Subnet -> Test Entity 1 (TCP:22)'] + + self.assertEqual(sorted(expected_edge_labels), sorted(edge_labels)) + + def test_locations_present(self): + """ + Tests if all (and only) specific project locations are in the graph (as + clusters). + """ + + # Get diagram for project. + project = Project.objects.get(pk=1) + diagram = utils.generate_project_diagram(project) + + # Set-up expected cluster names (based on locations). + expected_cluster_names = ["cluster_test_location_1", "cluster_test_location_2"] + + # Get cluster names. + cluster_names = [s.get_name() for s in diagram.get_subgraphs()] + + self.assertEqual(sorted(expected_cluster_names), sorted(cluster_names)) + + def test_return_type(self): + """ + Tests if a correct object type is returned. + """ + + # Get diagram for project.
+ project = Project.objects.get(pk=1) + diagram = utils.generate_project_diagram(project) + + self.assertIsInstance(diagram, pydot.Dot) + + def test_graph_properties(self): + """ + Tests if graph properties have been set-up properly. + """ + + # Get diagram for project. + project = Project.objects.get(pk=1) + diagram = utils.generate_project_diagram(project) + + self.assertEqual("digraph", diagram.get_graph_type()) + self.assertEqual("transparent", diagram.get_bgcolor()) + self.assertEqual("1.5", diagram.get_nodesep()) + self.assertEqual([{"shape": "record"}], diagram.get_node_defaults()) diff --git a/conntrackt/tests/test_views.py b/conntrackt/tests/test_views.py --- a/conntrackt/tests/test_views.py +++ b/conntrackt/tests/test_views.py @@ -15,7 +15,7 @@ from django.test import TestCase from conntrackt.models import Project, Location, Entity, Interface, Communication from conntrackt.views import IndexView -from conntrackt.views import entity_iptables, project_iptables +from conntrackt.views import entity_iptables, project_iptables, project_diagram from conntrackt.views import ProjectView, ProjectCreateView, ProjectUpdateView, ProjectDeleteView from conntrackt.views import LocationCreateView, LocationUpdateView, LocationDeleteView @@ -1332,3 +1332,56 @@ class CommunicationDeleteViewTest(Permis response = view(request, pk=1) self.assertEqual(response["Location"], reverse("entity", args=(communication.source.entity.pk,))) + + +class ProjectDiagramTest(PermissionTestMixin, TestCase): + """ + Tests the project_diagram view. + """ + + fixtures = ['test-data.json'] + + view_function = staticmethod(project_diagram) + sufficient_permissions = ("view",) + permission_test_view_kwargs = {"pk": "1"} + + def test_invalid_project(self): + """ + Tests if a 404 is returned if no project was found (invalid ID). + """ + + # Set-up a request. + request = create_get_request() + + # Get the view. + view = project_diagram + + # Validate the response.
+ self.assertRaises(Http404, view, request, pk=200) + + def test_content_type(self): + """ + Tests if correct content type is being returned by the response. + """ + + # Get the view. + view = project_diagram + + # Get the response. + response = generate_get_response(view, pk=1) + + self.assertEqual(response['Content-Type'], "image/svg+xml") + + def test_content(self): + """ + Tests content produced by the view. + """ + + # Get the view. + view = project_diagram + + # Get the response. + response = generate_get_response(view, pk=1) + + self.assertContains(response, '"-//W3C//DTD SVG 1.1//EN"') + self.assertContains(response, "Test Project 1") diff --git a/conntrackt/urls.py b/conntrackt/urls.py --- a/conntrackt/urls.py +++ b/conntrackt/urls.py @@ -3,7 +3,7 @@ from django.conf.urls import patterns, u from django.contrib.auth.views import login, logout # Application imports. -from .views import IndexView, EntityView, entity_iptables, project_iptables +from .views import IndexView, EntityView, entity_iptables, project_iptables, project_diagram from .views import ProjectView, ProjectCreateView, ProjectUpdateView, ProjectDeleteView from .views import LocationCreateView, LocationUpdateView, LocationDeleteView from .views import EntityCreateView, EntityUpdateView, EntityDeleteView @@ -64,6 +64,9 @@ urlpatterns = patterns( # View for rendering zip file with iptables rules for all entities in a project for a specific location. url(r'^project/(?P\d+)/location/(?P\d+)/iptables/$', project_iptables, name="project_location_iptables"), + # View for showing project communications in a diagram. + url(r'^project/(?P<pk>\d+)/diagram/$', project_diagram, name="project_diagram"), + # Views for logging-in/out the users.
url(r'^login/$', login, {'template_name': 'conntrackt/login.html'}, name="login"), url(r'^logout/$', logout, name="logout"), diff --git a/conntrackt/utils.py b/conntrackt/utils.py --- a/conntrackt/utils.py +++ b/conntrackt/utils.py @@ -1,11 +1,17 @@ # Standard library imports. import re +import itertools + +# Third-party Python library imports. +import palette +import pydot # Django imports. from django.template import Context, loader # Application imports. import iptables +from .models import Communication def generate_entity_iptables(entity): @@ -52,3 +58,131 @@ def generate_entity_iptables(entity): content = "%s%s" % (filter, nat) return content + + +def get_distinct_colors(count, start=palette.Color("#AE1111")): + """ + Generates a number of distinct colours, and returns them as a list. The + colours are generated using the HSL (hue, saturation, lightness) model, + where saturation and lightness is kept the same for all colours, with + differing hue. The hue difference between each subsequent color in the list + is kept the same. + + Arguments: + + count - Total number of colours that should be generated. If count is + zero or negative, an empty list is returned. + + start - First colour that should be taken as a start point. All colours + are generated relative to this colour by increasing the hue. Should be + an instance of palette.Color class. Defaults to RGB colour "#AE1111". + + Return: + + List of distinct palette.Color instances. + """ + + # Nothing to generate (e.g. a project without entities), so bail out + # early to avoid a division by zero when calculating the step below. + if count <= 0: + return [] + + # Read the HSL from provided Color. + hue, sat, lum = start.hsl["h"], start.hsl["s"], start.hsl["l"] + + # Calculate the step increase. + step = 1 / float(count) + + # Initiate an empty list that will store the generated colours. + colors = [] + + # Generate new colour by increasing the hue as long as we haven't generated + # the requested number of colours. Hue is cyclic, so keep it within [0, 1). + while len(colors) < count: + colors.append(palette.Color(hsl=(hue, sat, lum))) + hue = (hue + step) % 1.0 + + return colors + + +def generate_project_diagram(project): + """ + Generates communication diagram for provided project.
+ + Arguments: + + project - Project for which the diagram should be generated. Instance of + conntrackt.models.Project class. + + Returns: + + Dot diagram (digraph) representing all of the communications in a + project. + """ + + # Set-up the graph object. + graph = pydot.Dot(graph_name=project.name, graph_type="digraph", bgcolor="transparent", nodesep="1.5") + # Set-up defaults for the graph nodes. + graph.set_node_defaults(shape="record") + + # Obtain list of all entities in a project (evaluated once, so colour count always matches). + entities = list(project.entity_set.all()) + + # Set-up dictionary that will contain clusters of entities belonging to same + # location. + clusters = {} + + # Dictionary for storing mapping between entity IDs and node colours. + node_colors = {} + + # Get distinct colours, one for each node/entity. + colors = get_distinct_colors(len(entities)) + + # Create nodes based on entities, and put them into correct cluster. + for entity in entities: + + # Try to get the existing cluster based on location name. + location = entity.location + cluster_name = location.name.replace(" ", "_").lower() + cluster = clusters.get(cluster_name, None) + + # Set-up a new cluster for location encountered for the first time. + if cluster is None: + cluster = pydot.Cluster(graph_name=cluster_name, label=location.name) + clusters[cluster_name] = cluster + + # Fetch a colour that will be associated with the node/entity. + node_color = colors.pop() + node_colors[entity.id] = node_color.hex + + # Use black label text on light node colours and white on dark ones (darkness = 1 - weighted luma). + node_color_darkness = 1 - (node_color.rgb["r"] * 0.299 + node_color.rgb["g"] * 0.587 + node_color.rgb["b"] * 0.114) + + if node_color_darkness < 0.5: + font_color = "black" + else: + font_color = "white" + + # Finally create the node, and add it to location cluster. + node = pydot.Node(entity.name, style="filled", color=node_color.hex, fontcolor=font_color) + cluster.add_node(node) + + # Add clusters to the graph.
+ for cluster in clusters.values(): + graph.add_subgraph(cluster) + + # Get all project communications (with related entities, to avoid a query per edge). + communications = Communication.objects.filter(source__entity__project=project).select_related("source__entity", "destination__entity") + + # Add the edges (lines) representing communications, drawing them with same + # colour as the source node/entity. + for comm in communications: + edge_color = node_colors[comm.source.entity.id] + + label = '"%s:%s"' % (comm.protocol, str(comm.port)) + + edge = pydot.Edge(comm.source.entity.name, comm.destination.entity.name, label=label, color=edge_color) + + graph.add_edge(edge) + + return graph diff --git a/conntrackt/views.py b/conntrackt/views.py --- a/conntrackt/views.py +++ b/conntrackt/views.py @@ -16,7 +16,7 @@ from braces.views import MultiplePermiss # Application imports. from .forms import ProjectForm, LocationForm, EntityForm, InterfaceForm, CommunicationForm from .models import Project, Entity, Location, Interface, Communication -from .utils import generate_entity_iptables +from .utils import generate_entity_iptables, generate_project_diagram class IndexView(MultiplePermissionsRequiredMixin, TemplateView): @@ -880,3 +880,38 @@ class CommunicationDeleteView(SetHeadlin """ return "Delete communication %s" % self.object + + +@permission_required("conntrackt.view", raise_exception=True) +def project_diagram(request, pk): + """ + Custom view that returns response containing diagram of project + communications. + + The diagram will include coloured entities, with directional lines + connecting the source and destination end entities. + + The output format is SVG. + + Arguments: + + request - Request object. + + pk - Project ID for which the diagram should be generated. + + Returns: + + Response object that contains the project diagram rendered as SVG. + """ + + # Fetch the project. + project = get_object_or_404(Project, pk=pk) + + # Generate the diagram. + content = generate_project_diagram(project).create_svg() + + # Set the content type (MIME type) of the response.
+ response = HttpResponse(content, content_type='image/svg+xml') + + # Return the response object. + return response diff --git a/requirements/base.txt b/requirements/base.txt --- a/requirements/base.txt +++ b/requirements/base.txt @@ -2,3 +2,6 @@ django>=1.5 South django-braces django-crispy-forms +pyparsing==1.5.7 +pydot +palette