File diff 5193ae7fc0e1 → 2d83b9633ce7
conntrackt/tests/test_models.py
Show inline comments
 
@@ -18,15 +18,23 @@
 
# Django Conntrackt.  If not, see <http://www.gnu.org/licenses/>.
 
#
 

	
 
# Standard Python library import.
 
import collections
 

	
 
# Python third-party library imports.
 
import mock
 

	
 
# Django imports.
 
from django.core.exceptions import ValidationError
 
from django.db import IntegrityError
 
from django.db.models import Model
 
from django.test import TestCase
 

	
 
# Application imports.
 
from conntrackt.models import Project, Location, Entity, Interface, Communication
 
from conntrackt.models import SearchManager
 
from conntrackt.models import NestedObjects
 
from conntrackt.utils import list_formatter_callback
 

	
 
# Test imports.
 
from .factories import ProjectFactory, LocationFactory
 
@@ -36,6 +44,105 @@ from .factories import CommunicationFact
 
from .factories import setup_test_data
 

	
 

	
 
class RelatedCollectorMixinTest(TestCase):
 

	
 
    @mock.patch.object(NestedObjects, "collect")
 
    @mock.patch.object(NestedObjects, "nested")
 
    def test_get_dependant_objects_method_calls(self, nested_mock, collect_mock):
 
        """
 
        Tests if correct methods are being called with correct arguments during
 
        the invocation of get_dependant_objects method.
 
        """
 

	
 
        # Set-up some test data.
 
        project = ProjectFactory()
 

	
 
        # Call the method.
 
        project.get_dependant_objects()
 

	
 
        # Check if correct collector methods were called.
 
        collect_mock.assert_called_with([project])
 
        nested_mock.assert_called_with()
 

	
 
    def test_get_dependant_objects_return_value(self):
 
        """
 
        Tests the return value of get_dependant_objects method.
 
        """
 

	
 
        # Set-up some test data.
 
        project = ProjectFactory()
 
        location = LocationFactory()
 
        entity1 = ServerEntityFactory(pk=1, project=project, location=location)
 
        entity2 = ServerEntityFactory(pk=2, project=project, location=location)
 
        communication1 = CommunicationFactory(pk=1, source_id=1, destination_id=2, protocol="TCP", port="22")
 

	
 
        # Get the dependant objects.
 
        dependant_objects = project.get_dependant_objects()
 

	
 
        # Create a small local function for traversing the recursive list.
 
        def traverse(data):
 
            # If data is iterable, verify it is a list, and process its members
 
            # as well. If data is not iterable, make sure it is descendant of
 
            # Django Model class.
 
            if isinstance(data, collections.Iterable):
 
                self.assertIsInstance(data, list)
 
                for element in data:
 
                    traverse(element)
 
            else:
 
                self.assertIsInstance(data, Model)
 

	
 
        # Traverse the obtained dependant objects.
 
        traverse(dependant_objects)
 

	
 
    @mock.patch.object(NestedObjects, "collect")
 
    @mock.patch.object(NestedObjects, "nested")
 
    def test_get_dependant_objects_representation_method_calls(self, nested_mock, collect_mock):
 
        """
 
        Tests if correct methods are being called with correct arguments during
 
        the invocation of get_dependant_objects method.
 
        """
 

	
 
        # Set-up some test data.
 
        project = ProjectFactory()
 

	
 
        # Call the method.
 
        project.get_dependant_objects_representation()
 

	
 
        # Check if correct collector methods were called.
 
        collect_mock.assert_called_with([project])
 
        nested_mock.assert_called_with(list_formatter_callback)
 

	
 
    def test_get_dependant_objects_representation_return_value(self):
 
        """
 
        Tests the return value of get_dependant_objects_representation method.
 
        """
 

	
 
        # Set-up some test data.
 
        project = ProjectFactory()
 
        location = LocationFactory()
 
        entity1 = ServerEntityFactory(pk=1, project=project, location=location)
 
        entity2 = ServerEntityFactory(pk=2, project=project, location=location)
 
        communication1 = CommunicationFactory(pk=1, source_id=1, destination_id=2, protocol="TCP", port="22")
 

	
 
        # Get the dependant objects.
 
        dependant_objects = project.get_dependant_objects_representation()
 

	
 
        # Create a small local function for traversing the recursive list.
 
        def traverse(data):
 
            # If data is iterable, verify it is a list, and process its members
 
            # as well. If data is not iterable, make sure it is descendant of
 
            # Django Model class.
 
            if isinstance(data, collections.Iterable) and not isinstance(data, str):
 
                self.assertIsInstance(data, list)
 
                for element in data:
 
                    traverse(element)
 
            else:
 
                self.assertIsInstance(data, str)
 

	
 
        # Traverse the obtained dependant objects.
 
        traverse(dependant_objects)
 

	
 

	
 
class ProjectTest(TestCase):
 

	
 
    def test_unique_name(self):