# HG changeset patch # User Branko Majic # Date 2013-11-08 23:55:00 # Node ID 2d83b9633ce75206ae26a2a0b36b0c5b9b9552b2 # Parent 5193ae7fc0e165998028b4e506b8717dbb915b70 CONNT-20: Moved out the formatter function to utils. Updated docs. Styling fixes. Added tests for the new mixin. Added tests for the formatter function. Switched to different import of models in utils in order to avoid circular imports. diff --git a/conntrackt/models.py b/conntrackt/models.py --- a/conntrackt/models.py +++ b/conntrackt/models.py @@ -25,8 +25,9 @@ from django.core.exceptions import Valid from django.core.urlresolvers import reverse from django.db import models from django.db.models.query_utils import Q -from django.utils.html import format_html -from django.utils.text import capfirst + +# Application imports. +from .utils import list_formatter_callback class SearchManager(models.Manager): @@ -94,13 +95,8 @@ class RelatedCollectorMixin(object): The resulting nested list can be shown to the user for warning/notification purposes using the unordered_list template tag. - Each non-list element will be a string of format: - - MODEL_NAME: OBJECT_REPRESENTATION - - If object has a callable get_absolute_url method, the object - representation will be surrouned by HTML anchor tag () where - target (href) is set to the value of get_absolute_url() method call. + Each non-list element will be a string generated using the + conntrackt.utils.list_formatter_callback function. Returns: Nested list of representations of model objects that depend @@ -111,30 +107,7 @@ class RelatedCollectorMixin(object): collector.collect([self]) - def formatter_callback(obj): - """ - Creates model object representation in format: - - MODEL_NAME: OBJECT_REPRESENTATION - - If passed object has a callable get_absolute_url method, the - instance representation will be surrouned by an HTML anchor - () where target is set to value of the get_absolute_url() - method call. 
- - Arguments: - obj - Model object whose representation should be returned. - - Returns: - String represenation of passed model object. - """ - - try: - return format_html('{0}: {2}', capfirst(obj._meta.verbose_name), obj.get_absolute_url(), str(obj)) - except AttributeError: - return format_html('{0}: {1}', capfirst(obj._meta.verbose_name), str(obj)) - - return collector.nested(formatter_callback) + return collector.nested(list_formatter_callback) class Project(RelatedCollectorMixin, models.Model): @@ -152,7 +125,7 @@ class Project(RelatedCollectorMixin, mod name = models.CharField(max_length=100, unique=True) description = models.TextField(blank=True) objects = SearchManager() - deletion_collect_models = ["Entity", "Interface"] + deletion_collect_models = ["Entity", "Interface"] class Meta: permissions = (("view", "Can view information"),) diff --git a/conntrackt/tests/test_models.py b/conntrackt/tests/test_models.py --- a/conntrackt/tests/test_models.py +++ b/conntrackt/tests/test_models.py @@ -18,15 +18,23 @@ # Django Conntrackt. If not, see . # +# Standard Python library import. +import collections + +# Python third-party library imports. +import mock # Django imports. from django.core.exceptions import ValidationError from django.db import IntegrityError +from django.db.models import Model from django.test import TestCase # Application imports. from conntrackt.models import Project, Location, Entity, Interface, Communication from conntrackt.models import SearchManager +from conntrackt.models import NestedObjects +from conntrackt.utils import list_formatter_callback # Test imports. 
from .factories import ProjectFactory, LocationFactory @@ -36,6 +44,105 @@ from .factories import CommunicationFact from .factories import setup_test_data +class RelatedCollectorMixinTest(TestCase): + + @mock.patch.object(NestedObjects, "collect") + @mock.patch.object(NestedObjects, "nested") + def test_get_dependant_objects_method_calls(self, nested_mock, collect_mock): + """ + Tests if correct methods are being called with correct arguments during + the invocation of get_dependant_objects method. + """ + + # Set-up some test data. + project = ProjectFactory() + + # Call the method. + project.get_dependant_objects() + + # Check if correct collector methods were called. + collect_mock.assert_called_with([project]) + nested_mock.assert_called_with() + + def test_get_dependant_objects_return_value(self): + """ + Tests the return value of get_dependant_objects method. + """ + + # Set-up some test data. + project = ProjectFactory() + location = LocationFactory() + entity1 = ServerEntityFactory(pk=1, project=project, location=location) + entity2 = ServerEntityFactory(pk=2, project=project, location=location) + communication1 = CommunicationFactory(pk=1, source_id=1, destination_id=2, protocol="TCP", port="22") + + # Get the dependant objects. + dependant_objects = project.get_dependant_objects() + + # Create a small local function for traversing the recursive list. + def traverse(data): + # If data is iterable, verify it is a list, and process its members + # as well. If data is not iterable, make sure it is descendant of + # Django Model class. + if isinstance(data, collections.Iterable): + self.assertIsInstance(data, list) + for element in data: + traverse(element) + else: + self.assertIsInstance(data, Model) + + # Traverse the obtained dependant objects. 
+ traverse(dependant_objects) + + @mock.patch.object(NestedObjects, "collect") + @mock.patch.object(NestedObjects, "nested") + def test_get_dependant_objects_representation_method_calls(self, nested_mock, collect_mock): + """ + Tests if correct methods are being called with correct arguments during + the invocation of get_dependant_objects_representation method. + """ + + # Set-up some test data. + project = ProjectFactory() + + # Call the method. + project.get_dependant_objects_representation() + + # Check if correct collector methods were called. + collect_mock.assert_called_with([project]) + nested_mock.assert_called_with(list_formatter_callback) + + def test_get_dependant_objects_representation_return_value(self): + """ + Tests the return value of get_dependant_objects_representation method. + """ + + # Set-up some test data. + project = ProjectFactory() + location = LocationFactory() + entity1 = ServerEntityFactory(pk=1, project=project, location=location) + entity2 = ServerEntityFactory(pk=2, project=project, location=location) + communication1 = CommunicationFactory(pk=1, source_id=1, destination_id=2, protocol="TCP", port="22") + + # Get the dependant objects. + dependant_objects = project.get_dependant_objects_representation() + + # Create a small local function for traversing the recursive list. + def traverse(data): + # If data is a non-string iterable, verify it is a list, and process + # its members as well. Otherwise, make sure it is a string (the + # formatted representation of a single object). + if isinstance(data, collections.Iterable) and not isinstance(data, str): + self.assertIsInstance(data, list) + for element in data: + traverse(element) + else: + self.assertIsInstance(data, str) + + # Traverse the obtained dependant objects.
+ traverse(dependant_objects) + + class ProjectTest(TestCase): def test_unique_name(self): diff --git a/conntrackt/tests/test_utils.py b/conntrackt/tests/test_utils.py --- a/conntrackt/tests/test_utils.py +++ b/conntrackt/tests/test_utils.py @@ -25,6 +25,7 @@ from django.test import TestCase # Third-party Python library imports. import palette import pydot +import mock # Application imports. from conntrackt.models import Entity, Project, Communication @@ -287,3 +288,40 @@ class GenerateProjectDiagramTest(TestCas self.assertEqual("transparent", diagram.get_bgcolor()) self.assertEqual("1.5", diagram.get_nodesep()) self.assertEqual([{"shape": "record"}], diagram.get_node_defaults()) + + +class ListFormatterCallbackTest(TestCase): + """ + Tests the list_formatter_callback function. + """ + + def test_get_absolute_url(self): + """ + Test the return result in case the get_absolute_url is available on + passed object instance. + """ + + # Set-up a minimal object mock. + obj = mock.Mock(spec=["_meta", "__repr__", "get_absolute_url"]) + obj._meta = mock.Mock() + obj._meta.verbose_name = "name" + obj.__repr__ = mock.Mock() + obj.__repr__.return_value = "representation" + obj.get_absolute_url.return_value = "url" + + self.assertEqual(utils.list_formatter_callback(obj), 'Name: representation') + + def test_no_get_absolute_url(self): + """ + Test the return result in case the get_absolute_url is not available on + passed object instance. + """ + + # Set-up a minimal object mock. 
+ obj = mock.Mock(spec=["_meta", "__repr__"]) + obj._meta = mock.Mock() + obj._meta.verbose_name = "name" + obj.__repr__ = mock.Mock() + obj.__repr__.return_value = "representation" + + self.assertEqual(utils.list_formatter_callback(obj), "Name: representation") diff --git a/conntrackt/tests/test_views.py b/conntrackt/tests/test_views.py --- a/conntrackt/tests/test_views.py +++ b/conntrackt/tests/test_views.py @@ -1758,4 +1758,3 @@ class APISearchViewTest(PermissionTestMi response = generate_get_response(view, search_term="test") self.assertEqual(response['Content-Type'], "application/json") - diff --git a/conntrackt/utils.py b/conntrackt/utils.py --- a/conntrackt/utils.py +++ b/conntrackt/utils.py @@ -29,10 +29,12 @@ import pydot # Django imports. from django.template import Context, loader +from django.utils.html import format_html +from django.utils.text import capfirst # Application imports. import iptables -from .models import Communication +import models def generate_entity_iptables(entity): @@ -187,7 +189,7 @@ def generate_project_diagram(project): graph.add_subgraph(cluster) # Get all project communications. - communications = Communication.objects.filter(source__entity__project=project) + communications = models.Communication.objects.filter(source__entity__project=project) # Add the edges (lines) representing communications, drawing them with same # colour as the source node/entity. @@ -201,3 +203,27 @@ def generate_project_diagram(project): graph.add_edge(edge) return graph + + +def list_formatter_callback(obj): + """ + Creates model object representation in format: + + MODEL_NAME: OBJECT_REPRESENTATION + + If passed object has a callable get_absolute_url method, the + instance representation will be surrounded by an HTML anchor + (<a>) where target is set to the value of the get_absolute_url() + method call. + + Arguments: + obj - Model object whose representation should be returned. + + Returns: + String representation of passed model object.
+ """ + + try: + return format_html('{0}: {2}', capfirst(obj._meta.verbose_name), obj.get_absolute_url(), str(obj)) + except AttributeError: + return format_html('{0}: {1}', capfirst(obj._meta.verbose_name), str(obj)) diff --git a/conntrackt/views.py b/conntrackt/views.py --- a/conntrackt/views.py +++ b/conntrackt/views.py @@ -1115,14 +1115,14 @@ class APISearchView(MultiplePermissionsR items.append({"name": entity.name, "project": entity.project.name, "type": "entity", - "url": entity.get_absolute_url(),}) + "url": entity.get_absolute_url()}) # Add found projects. for project in projects: items.append({"name": project.name, "project": project.name, "type": "project", - "url": project.get_absolute_url(),}) + "url": project.get_absolute_url()}) # Generate the JSON response. content = json.dumps(items)