diff --git a/conntrackt/models.py b/conntrackt/models.py --- a/conntrackt/models.py +++ b/conntrackt/models.py @@ -25,8 +25,9 @@ from django.core.exceptions import Valid from django.core.urlresolvers import reverse from django.db import models from django.db.models.query_utils import Q -from django.utils.html import format_html -from django.utils.text import capfirst + +# Application imports. +from .utils import list_formatter_callback class SearchManager(models.Manager): @@ -94,13 +95,8 @@ class RelatedCollectorMixin(object): The resulting nested list can be shown to the user for warning/notification purposes using the unordered_list template tag. - Each non-list element will be a string of format: - - MODEL_NAME: OBJECT_REPRESENTATION - - If object has a callable get_absolute_url method, the object - representation will be surrouned by HTML anchor tag () where - target (href) is set to the value of get_absolute_url() method call. + Each non-list element will be a string generated using the + conntrackt.utils.list_formatter_callback function. Returns: Nested list of representations of model objects that depend @@ -111,30 +107,7 @@ class RelatedCollectorMixin(object): collector.collect([self]) - def formatter_callback(obj): - """ - Creates model object representation in format: - - MODEL_NAME: OBJECT_REPRESENTATION - - If passed object has a callable get_absolute_url method, the - instance representation will be surrouned by an HTML anchor - () where target is set to value of the get_absolute_url() - method call. - - Arguments: - obj - Model object whose representation should be returned. - - Returns: - String represenation of passed model object. 
- """ - - try: - return format_html('{0}: {2}', capfirst(obj._meta.verbose_name), obj.get_absolute_url(), str(obj)) - except AttributeError: - return format_html('{0}: {1}', capfirst(obj._meta.verbose_name), str(obj)) - - return collector.nested(formatter_callback) + return collector.nested(list_formatter_callback) class Project(RelatedCollectorMixin, models.Model): @@ -152,7 +125,7 @@ class Project(RelatedCollectorMixin, mod name = models.CharField(max_length=100, unique=True) description = models.TextField(blank=True) objects = SearchManager() - deletion_collect_models = ["Entity", "Interface"] + deletion_collect_models = ["Entity", "Interface"] class Meta: permissions = (("view", "Can view information"),) diff --git a/conntrackt/tests/test_models.py b/conntrackt/tests/test_models.py --- a/conntrackt/tests/test_models.py +++ b/conntrackt/tests/test_models.py @@ -18,15 +18,23 @@ # Django Conntrackt. If not, see . # +# Standard Python library import. +import collections + +# Python third-party library imports. +import mock # Django imports. from django.core.exceptions import ValidationError from django.db import IntegrityError +from django.db.models import Model from django.test import TestCase # Application imports. from conntrackt.models import Project, Location, Entity, Interface, Communication from conntrackt.models import SearchManager +from conntrackt.models import NestedObjects +from conntrackt.utils import list_formatter_callback # Test imports. from .factories import ProjectFactory, LocationFactory @@ -36,6 +44,105 @@ from .factories import CommunicationFact from .factories import setup_test_data +class RelatedCollectorMixinTest(TestCase): + + @mock.patch.object(NestedObjects, "collect") + @mock.patch.object(NestedObjects, "nested") + def test_get_dependant_objects_method_calls(self, nested_mock, collect_mock): + """ + Tests if correct methods are being called with correct arguments during + the invocation of get_dependant_objects method. 
+ """ + + # Set-up some test data. + project = ProjectFactory() + + # Call the method. + project.get_dependant_objects() + + # Check if correct collector methods were called. + collect_mock.assert_called_with([project]) + nested_mock.assert_called_with() + + def test_get_dependant_objects_return_value(self): + """ + Tests the return value of get_dependant_objects method. + """ + + # Set-up some test data. + project = ProjectFactory() + location = LocationFactory() + entity1 = ServerEntityFactory(pk=1, project=project, location=location) + entity2 = ServerEntityFactory(pk=2, project=project, location=location) + communication1 = CommunicationFactory(pk=1, source_id=1, destination_id=2, protocol="TCP", port="22") + + # Get the dependant objects. + dependant_objects = project.get_dependant_objects() + + # Create a small local function for traversing the recursive list. + def traverse(data): + # If data is iterable, verify it is a list, and process its members + # as well. If data is not iterable, make sure it is an instance of a + # Django Model class. + if isinstance(data, collections.Iterable): + self.assertIsInstance(data, list) + for element in data: + traverse(element) + else: + self.assertIsInstance(data, Model) + + # Traverse the obtained dependant objects. + traverse(dependant_objects) + + @mock.patch.object(NestedObjects, "collect") + @mock.patch.object(NestedObjects, "nested") + def test_get_dependant_objects_representation_method_calls(self, nested_mock, collect_mock): + """ + Tests if correct methods are being called with correct arguments during + the invocation of get_dependant_objects_representation method. + """ + + # Set-up some test data. + project = ProjectFactory() + + # Call the method. + project.get_dependant_objects_representation() + + # Check if correct collector methods were called.
+ collect_mock.assert_called_with([project]) + nested_mock.assert_called_with(list_formatter_callback) + + def test_get_dependant_objects_representation_return_value(self): + """ + Tests the return value of get_dependant_objects_representation method. + """ + + # Set-up some test data. + project = ProjectFactory() + location = LocationFactory() + entity1 = ServerEntityFactory(pk=1, project=project, location=location) + entity2 = ServerEntityFactory(pk=2, project=project, location=location) + communication1 = CommunicationFactory(pk=1, source_id=1, destination_id=2, protocol="TCP", port="22") + + # Get the dependant objects. + dependant_objects = project.get_dependant_objects_representation() + + # Create a small local function for traversing the recursive list. + def traverse(data): + # If data is a non-string iterable, verify it is a list, and process + # its members as well. Otherwise make sure it is a string (the + # formatted representation of a single model object). + if isinstance(data, collections.Iterable) and not isinstance(data, str): + self.assertIsInstance(data, list) + for element in data: + traverse(element) + else: + self.assertIsInstance(data, str) + + # Traverse the obtained dependant objects. + traverse(dependant_objects) + + class ProjectTest(TestCase): def test_unique_name(self): diff --git a/conntrackt/tests/test_utils.py b/conntrackt/tests/test_utils.py --- a/conntrackt/tests/test_utils.py +++ b/conntrackt/tests/test_utils.py @@ -25,6 +25,7 @@ from django.test import TestCase # Third-party Python library imports. import palette import pydot +import mock # Application imports. from conntrackt.models import Entity, Project, Communication @@ -287,3 +288,40 @@ class GenerateProjectDiagramTest(TestCas self.assertEqual("transparent", diagram.get_bgcolor()) self.assertEqual("1.5", diagram.get_nodesep()) self.assertEqual([{"shape": "record"}], diagram.get_node_defaults()) + + +class ListFormatterCallbackTest(TestCase): + """ + Tests the list_formatter_callback function.
+ """ + + def test_get_absolute_url(self): + """ + Test the return result in case the get_absolute_url is available on + passed object instance. + """ + + # Set-up a minimal object mock. + obj = mock.Mock(spec=["_meta", "__repr__", "get_absolute_url"]) + obj._meta = mock.Mock() + obj._meta.verbose_name = "name" + obj.__repr__ = mock.Mock() + obj.__repr__.return_value = "representation" + obj.get_absolute_url.return_value = "url" + + self.assertEqual(utils.list_formatter_callback(obj), 'Name: representation') + + def test_no_get_absolute_url(self): + """ + Test the return result in case the get_absolute_url is not available on + passed object instance. + """ + + # Set-up a minimal object mock. + obj = mock.Mock(spec=["_meta", "__repr__"]) + obj._meta = mock.Mock() + obj._meta.verbose_name = "name" + obj.__repr__ = mock.Mock() + obj.__repr__.return_value = "representation" + + self.assertEqual(utils.list_formatter_callback(obj), "Name: representation") diff --git a/conntrackt/tests/test_views.py b/conntrackt/tests/test_views.py --- a/conntrackt/tests/test_views.py +++ b/conntrackt/tests/test_views.py @@ -1758,4 +1758,3 @@ class APISearchViewTest(PermissionTestMi response = generate_get_response(view, search_term="test") self.assertEqual(response['Content-Type'], "application/json") - diff --git a/conntrackt/utils.py b/conntrackt/utils.py --- a/conntrackt/utils.py +++ b/conntrackt/utils.py @@ -29,10 +29,12 @@ import pydot # Django imports. from django.template import Context, loader +from django.utils.html import format_html +from django.utils.text import capfirst # Application imports. import iptables -from .models import Communication +import models def generate_entity_iptables(entity): @@ -187,7 +189,7 @@ def generate_project_diagram(project): graph.add_subgraph(cluster) # Get all project communications. 
- communications = Communication.objects.filter(source__entity__project=project) + communications = models.Communication.objects.filter(source__entity__project=project) # Add the edges (lines) representing communications, drawing them with same # colour as the source node/entity. @@ -201,3 +203,27 @@ def generate_project_diagram(project): graph.add_edge(edge) return graph + + +def list_formatter_callback(obj): + """ + Creates model object representation in format: + + MODEL_NAME: OBJECT_REPRESENTATION + + If passed object has a callable get_absolute_url method, the + instance representation will be surrounded by an HTML anchor + (<a>) whose target (href) is set to the value of the get_absolute_url() + method call. + + Arguments: + obj - Model object whose representation should be returned. + + Returns: + String representation of the passed model object. + """ + + try: + return format_html('{0}: {2}', capfirst(obj._meta.verbose_name), obj.get_absolute_url(), str(obj)) + except AttributeError: + return format_html('{0}: {1}', capfirst(obj._meta.verbose_name), str(obj)) diff --git a/conntrackt/views.py b/conntrackt/views.py --- a/conntrackt/views.py +++ b/conntrackt/views.py @@ -1115,14 +1115,14 @@ class APISearchView(MultiplePermissionsR items.append({"name": entity.name, "project": entity.project.name, "type": "entity", - "url": entity.get_absolute_url(),}) + "url": entity.get_absolute_url()}) # Add found projects. for project in projects: items.append({"name": project.name, "project": project.name, "type": "project", - "url": project.get_absolute_url(),}) + "url": project.get_absolute_url()}) # Generate the JSON response. content = json.dumps(items)