diff --git a/conntrackt/tests/test_models.py b/conntrackt/tests/test_models.py --- a/conntrackt/tests/test_models.py +++ b/conntrackt/tests/test_models.py @@ -18,15 +18,23 @@ # Django Conntrackt. If not, see . # +# Standard Python library import. +import collections + +# Python third-party library imports. +import mock # Django imports. from django.core.exceptions import ValidationError from django.db import IntegrityError +from django.db.models import Model from django.test import TestCase # Application imports. from conntrackt.models import Project, Location, Entity, Interface, Communication from conntrackt.models import SearchManager +from conntrackt.models import NestedObjects +from conntrackt.utils import list_formatter_callback # Test imports. from .factories import ProjectFactory, LocationFactory @@ -36,6 +44,105 @@ from .factories import CommunicationFact from .factories import setup_test_data +class RelatedCollectorMixinTest(TestCase): + + @mock.patch.object(NestedObjects, "collect") + @mock.patch.object(NestedObjects, "nested") + def test_get_dependant_objects_method_calls(self, nested_mock, collect_mock): + """ + Tests if correct methods are being called with correct arguments during + the invocation of get_dependant_objects method. + """ + + # Set-up some test data. + project = ProjectFactory() + + # Call the method. + project.get_dependant_objects() + + # Check if correct collector methods were called. + collect_mock.assert_called_with([project]) + nested_mock.assert_called_with() + + def test_get_dependant_objects_return_value(self): + """ + Tests the return value of get_dependant_objects method. + """ + + # Set-up some test data. + project = ProjectFactory() + location = LocationFactory() + entity1 = ServerEntityFactory(pk=1, project=project, location=location) + entity2 = ServerEntityFactory(pk=2, project=project, location=location) + communication1 = CommunicationFactory(pk=1, source_id=1, destination_id=2, protocol="TCP", port="22") + + # Get the dependant objects. + dependant_objects = project.get_dependant_objects() + + # Create a small local function for traversing the recursive list. + def traverse(data): + # If data is iterable, verify it is a list, and process its members + # as well. If data is not iterable, make sure it is descendant of + # Django Model class. + if isinstance(data, collections.Iterable): + self.assertIsInstance(data, list) + for element in data: + traverse(element) + else: + self.assertIsInstance(data, Model) + + # Traverse the obtained dependant objects. + traverse(dependant_objects) + + @mock.patch.object(NestedObjects, "collect") + @mock.patch.object(NestedObjects, "nested") + def test_get_dependant_objects_representation_method_calls(self, nested_mock, collect_mock): + """ + Tests if correct methods are being called with correct arguments during + the invocation of get_dependant_objects method. + """ + + # Set-up some test data. + project = ProjectFactory() + + # Call the method. + project.get_dependant_objects_representation() + + # Check if correct collector methods were called. + collect_mock.assert_called_with([project]) + nested_mock.assert_called_with(list_formatter_callback) + + def test_get_dependant_objects_representation_return_value(self): + """ + Tests the return value of get_dependant_objects_representation method. + """ + + # Set-up some test data. + project = ProjectFactory() + location = LocationFactory() + entity1 = ServerEntityFactory(pk=1, project=project, location=location) + entity2 = ServerEntityFactory(pk=2, project=project, location=location) + communication1 = CommunicationFactory(pk=1, source_id=1, destination_id=2, protocol="TCP", port="22") + + # Get the dependant objects. + dependant_objects = project.get_dependant_objects_representation() + + # Create a small local function for traversing the recursive list. + def traverse(data): + # If data is iterable, verify it is a list, and process its members + # as well. If data is not iterable, make sure it is descendant of + # Django Model class. + if isinstance(data, collections.Iterable) and not isinstance(data, str): + self.assertIsInstance(data, list) + for element in data: + traverse(element) + else: + self.assertIsInstance(data, str) + + # Traverse the obtained dependant objects. + traverse(dependant_objects) + + class ProjectTest(TestCase): def test_unique_name(self):