diff --git a/conntrackt/models.py b/conntrackt/models.py
--- a/conntrackt/models.py
+++ b/conntrackt/models.py
@@ -25,8 +25,9 @@ from django.core.exceptions import Valid
from django.core.urlresolvers import reverse
from django.db import models
from django.db.models.query_utils import Q
-from django.utils.html import format_html
-from django.utils.text import capfirst
+
+# Application imports.
+from .utils import list_formatter_callback
class SearchManager(models.Manager):
@@ -94,13 +95,8 @@ class RelatedCollectorMixin(object):
The resulting nested list can be shown to the user for
warning/notification purposes using the unordered_list template tag.
- Each non-list element will be a string of format:
-
- MODEL_NAME: OBJECT_REPRESENTATION
-
- If object has a callable get_absolute_url method, the object
- representation will be surrouned by HTML anchor tag () where
- target (href) is set to the value of get_absolute_url() method call.
+ Each non-list element will be a string generated using the
+ conntrackt.utils.list_formatter_callback function.
Returns:
Nested list of representations of model objects that depend
@@ -111,30 +107,7 @@ class RelatedCollectorMixin(object):
collector.collect([self])
- def formatter_callback(obj):
- """
- Creates model object representation in format:
-
- MODEL_NAME: OBJECT_REPRESENTATION
-
- If passed object has a callable get_absolute_url method, the
- instance representation will be surrouned by an HTML anchor
- () where target is set to value of the get_absolute_url()
- method call.
-
- Arguments:
- obj - Model object whose representation should be returned.
-
- Returns:
- String represenation of passed model object.
- """
-
- try:
- return format_html('{0}: {2}', capfirst(obj._meta.verbose_name), obj.get_absolute_url(), str(obj))
- except AttributeError:
- return format_html('{0}: {1}', capfirst(obj._meta.verbose_name), str(obj))
-
- return collector.nested(formatter_callback)
+ return collector.nested(list_formatter_callback)
class Project(RelatedCollectorMixin, models.Model):
@@ -152,7 +125,7 @@ class Project(RelatedCollectorMixin, mod
name = models.CharField(max_length=100, unique=True)
description = models.TextField(blank=True)
objects = SearchManager()
- deletion_collect_models = ["Entity", "Interface"]
+ deletion_collect_models = ["Entity", "Interface"]
class Meta:
permissions = (("view", "Can view information"),)
diff --git a/conntrackt/tests/test_models.py b/conntrackt/tests/test_models.py
--- a/conntrackt/tests/test_models.py
+++ b/conntrackt/tests/test_models.py
@@ -18,15 +18,23 @@
# Django Conntrackt. If not, see <http://www.gnu.org/licenses/>.
#
+# Standard Python library import.
+import collections
+
+# Python third-party library imports.
+import mock
# Django imports.
from django.core.exceptions import ValidationError
from django.db import IntegrityError
+from django.db.models import Model
from django.test import TestCase
# Application imports.
from conntrackt.models import Project, Location, Entity, Interface, Communication
from conntrackt.models import SearchManager
+from conntrackt.models import NestedObjects
+from conntrackt.utils import list_formatter_callback
# Test imports.
from .factories import ProjectFactory, LocationFactory
@@ -36,6 +44,105 @@ from .factories import CommunicationFact
from .factories import setup_test_data
+class RelatedCollectorMixinTest(TestCase):
+    """
+    Tests the RelatedCollectorMixin methods, using the Project model (which
+    includes the mixin) as the concrete model under test.
+    """
+
+    def _create_project_with_dependants(self):
+        """
+        Creates a project with two server entities and a communication, so
+        that the project has some dependant objects.
+
+        Returns:
+          Created project.
+        """
+
+        project = ProjectFactory()
+        location = LocationFactory()
+        ServerEntityFactory(pk=1, project=project, location=location)
+        ServerEntityFactory(pk=2, project=project, location=location)
+        # NOTE(review): source_id/destination_id presumably point at the
+        # interfaces created along with the entities above - confirm against
+        # the factories.
+        CommunicationFactory(pk=1, source_id=1, destination_id=2, protocol="TCP", port="22")
+
+        return project
+
+    def _assert_nested_list(self, data, leaf_type):
+        """
+        Recursively asserts that passed data is either an instance of
+        leaf_type, or a (possibly nested) list whose non-list elements are
+        all instances of leaf_type.
+
+        Checking for list explicitly (instead of for any iterable) means that
+        strings are treated as leaves. It also avoids the collections.Iterable
+        alias, which was removed in Python 3.10.
+        """
+
+        if isinstance(data, list):
+            for element in data:
+                self._assert_nested_list(element, leaf_type)
+        else:
+            self.assertIsInstance(data, leaf_type)
+
+    @mock.patch.object(NestedObjects, "collect")
+    @mock.patch.object(NestedObjects, "nested")
+    def test_get_dependant_objects_method_calls(self, nested_mock, collect_mock):
+        """
+        Tests if correct methods are being called with correct arguments during
+        the invocation of get_dependant_objects method.
+        """
+
+        # Set-up some test data.
+        project = ProjectFactory()
+
+        # Call the method.
+        project.get_dependant_objects()
+
+        # Check if correct collector methods were called.
+        collect_mock.assert_called_with([project])
+        nested_mock.assert_called_with()
+
+    def test_get_dependant_objects_return_value(self):
+        """
+        Tests the return value of get_dependant_objects method.
+        """
+
+        # Set-up some test data.
+        project = self._create_project_with_dependants()
+
+        # Get the dependant objects.
+        dependant_objects = project.get_dependant_objects()
+
+        # Result must be a nested list of Django model instances.
+        self.assertIsInstance(dependant_objects, list)
+        self._assert_nested_list(dependant_objects, Model)
+
+    @mock.patch.object(NestedObjects, "collect")
+    @mock.patch.object(NestedObjects, "nested")
+    def test_get_dependant_objects_representation_method_calls(self, nested_mock, collect_mock):
+        """
+        Tests if correct methods are being called with correct arguments during
+        the invocation of get_dependant_objects_representation method.
+        """
+
+        # Set-up some test data.
+        project = ProjectFactory()
+
+        # Call the method.
+        project.get_dependant_objects_representation()
+
+        # Check if correct collector methods were called, with the shared
+        # formatter from conntrackt.utils.
+        collect_mock.assert_called_with([project])
+        nested_mock.assert_called_with(list_formatter_callback)
+
+    def test_get_dependant_objects_representation_return_value(self):
+        """
+        Tests the return value of get_dependant_objects_representation method.
+        """
+
+        # Set-up some test data.
+        project = self._create_project_with_dependants()
+
+        # Get the dependant objects representation.
+        dependant_objects = project.get_dependant_objects_representation()
+
+        # Result must be a nested list of strings.
+        self.assertIsInstance(dependant_objects, list)
+        self._assert_nested_list(dependant_objects, str)
+
+
class ProjectTest(TestCase):
def test_unique_name(self):
diff --git a/conntrackt/tests/test_utils.py b/conntrackt/tests/test_utils.py
--- a/conntrackt/tests/test_utils.py
+++ b/conntrackt/tests/test_utils.py
@@ -25,6 +25,7 @@ from django.test import TestCase
# Third-party Python library imports.
import palette
import pydot
+import mock
# Application imports.
from conntrackt.models import Entity, Project, Communication
@@ -287,3 +288,40 @@ class GenerateProjectDiagramTest(TestCas
self.assertEqual("transparent", diagram.get_bgcolor())
self.assertEqual("1.5", diagram.get_nodesep())
self.assertEqual([{"shape": "record"}], diagram.get_node_defaults())
+
+
+class ListFormatterCallbackTest(TestCase):
+    """
+    Tests the list_formatter_callback function.
+    """
+
+    def _create_object_mock(self, spec):
+        """
+        Creates a minimal model object mock restricted to passed spec.
+
+        The mock's _meta.verbose_name is set to "name", and its
+        representation is set to "representation".
+
+        Arguments:
+          spec - List of attribute names the mock should expose.
+
+        Returns:
+          Configured mock object.
+        """
+
+        obj = mock.Mock(spec=spec)
+        obj._meta = mock.Mock()
+        obj._meta.verbose_name = "name"
+        obj.__repr__ = mock.Mock(return_value="representation")
+
+        return obj
+
+    def test_get_absolute_url(self):
+        """
+        Test the return result in case the get_absolute_url is available on
+        passed object instance.
+        """
+
+        # Set-up a minimal object mock.
+        obj = self._create_object_mock(["_meta", "__repr__", "get_absolute_url"])
+        obj.get_absolute_url.return_value = "url"
+
+        self.assertEqual(utils.list_formatter_callback(obj), 'Name: representation')
+
+        # Make sure the URL-aware branch was taken. Without this check the
+        # test would not distinguish it from the no-URL fallback.
+        obj.get_absolute_url.assert_called_once_with()
+
+    def test_no_get_absolute_url(self):
+        """
+        Test the return result in case the get_absolute_url is not available on
+        passed object instance.
+        """
+
+        # Set-up a minimal object mock.
+        obj = self._create_object_mock(["_meta", "__repr__"])
+
+        self.assertEqual(utils.list_formatter_callback(obj), "Name: representation")
diff --git a/conntrackt/tests/test_views.py b/conntrackt/tests/test_views.py
--- a/conntrackt/tests/test_views.py
+++ b/conntrackt/tests/test_views.py
@@ -1758,4 +1758,3 @@ class APISearchViewTest(PermissionTestMi
response = generate_get_response(view, search_term="test")
self.assertEqual(response['Content-Type'], "application/json")
-
diff --git a/conntrackt/utils.py b/conntrackt/utils.py
--- a/conntrackt/utils.py
+++ b/conntrackt/utils.py
@@ -29,10 +29,12 @@ import pydot
# Django imports.
from django.template import Context, loader
+from django.utils.html import format_html
+from django.utils.text import capfirst
# Application imports.
import iptables
-from .models import Communication
+import models
def generate_entity_iptables(entity):
@@ -187,7 +189,7 @@ def generate_project_diagram(project):
graph.add_subgraph(cluster)
# Get all project communications.
- communications = Communication.objects.filter(source__entity__project=project)
+ communications = models.Communication.objects.filter(source__entity__project=project)
# Add the edges (lines) representing communications, drawing them with same
# colour as the source node/entity.
@@ -201,3 +203,27 @@ def generate_project_diagram(project):
graph.add_edge(edge)
return graph
+
+
+def list_formatter_callback(obj):
+    """
+    Creates model object representation in format:
+
+    MODEL_NAME: OBJECT_REPRESENTATION
+
+    If passed object has a callable get_absolute_url method, the
+    instance representation is meant to be surrounded by an HTML anchor
+    whose target (href) is set to the value of the get_absolute_url()
+    method call. Objects without such a method, or with a non-callable
+    attribute of that name, get the plain representation.
+
+    NOTE(review): the format string used for the URL case below does not
+    reference argument {1} (the URL), so the URL is currently not rendered.
+    Confirm the anchor markup is intact.
+
+    Arguments:
+      obj - Model object whose representation should be returned. It must
+      provide _meta.verbose_name.
+
+    Returns:
+      String representation of passed model object, as produced by
+      format_html.
+    """
+
+    model_name = capfirst(obj._meta.verbose_name)
+
+    # Look the method up explicitly instead of wrapping the call in
+    # "except AttributeError". That approach would also swallow
+    # AttributeErrors raised *inside* get_absolute_url (hiding real bugs).
+    # It would also raise TypeError for a non-callable attribute.
+    get_absolute_url = getattr(obj, "get_absolute_url", None)
+
+    if callable(get_absolute_url):
+        return format_html('{0}: {2}', model_name, get_absolute_url(), str(obj))
+
+    return format_html('{0}: {1}', model_name, str(obj))
diff --git a/conntrackt/views.py b/conntrackt/views.py
--- a/conntrackt/views.py
+++ b/conntrackt/views.py
@@ -1115,14 +1115,14 @@ class APISearchView(MultiplePermissionsR
items.append({"name": entity.name,
"project": entity.project.name,
"type": "entity",
- "url": entity.get_absolute_url(),})
+ "url": entity.get_absolute_url()})
# Add found projects.
for project in projects:
items.append({"name": project.name,
"project": project.name,
"type": "project",
- "url": project.get_absolute_url(),})
+ "url": project.get_absolute_url()})
# Generate the JSON response.
content = json.dumps(items)