Changeset - 2d83b9633ce7
[Not reviewed]
default
0 6 0
Branko Majic (branko) - 10 years ago 2013-11-08 23:55:00
branko@majic.rs
CONNT-20: Moved out the formatter function to utils. Updated docs. Styling fixes. Added tests for the new mixin. Added tests for the formatter function. Switched to different import of models in utils in order to avoid circular imports.
6 files changed with 182 insertions and 39 deletions:
0 comments (0 inline, 0 general)
conntrackt/models.py
Show inline comments
 
@@ -22,14 +22,15 @@
 
# Django imports.
 
from django.contrib.admin.util import NestedObjects
 
from django.core.exceptions import ValidationError
 
from django.core.urlresolvers import reverse
 
from django.db import models
 
from django.db.models.query_utils import Q
 
from django.utils.html import format_html
 
from django.utils.text import capfirst
 

	
 
# Application imports.
 
from .utils import list_formatter_callback
 

	
 

	
 
class SearchManager(models.Manager):
 
    """
 
    Custom model manager that implements search for model instances that contain
 
    a specific string (search term) in fields "name" or "description".
 
@@ -91,53 +92,25 @@ class RelatedCollectorMixin(object):
 
        objects that would get deleted in case the calling model object gets
 
        deleted.
 

	
 
        The resulting nested list can be shown to the user for
 
        warning/notification purposes using the unordered_list template tag.
 

	
 
        Each non-list element will be a string of format:
 

	
 
        MODEL_NAME: OBJECT_REPRESENTATION
 

	
 
        If object has a callable get_absolute_url method, the object
 
        representation will be surrouned by HTML anchor tag (<a></a>) where
 
        target (href) is set to the value of get_absolute_url() method call.
 
        Each non-list element will be a string generated using the
 
        conntrackt.utils.list_formatter_callback function.
 

	
 
        Returns:
 
          Nested list of representations of model objects that depend
 
          (reference) calling model object.
 
        """
 

	
 
        collector = NestedObjects(using='default')
 

	
 
        collector.collect([self])
 

	
 
        def formatter_callback(obj):
 
            """
 
            Creates model object representation in format:
 

	
 
            MODEL_NAME: OBJECT_REPRESENTATION
 
            
 
            If passed object has a callable get_absolute_url method, the
 
            instance representation will be surrouned by an HTML anchor
 
            (<a></a>) where target is set to value of the get_absolute_url()
 
            method call.
 

	
 
            Arguments:
 
              obj - Model object whose representation should be returned.
 

	
 
            Returns:
 
              String represenation of passed model object.
 
            """
 

	
 
            try:
 
                return format_html('<strong>{0}</strong>: <a href="{1}">{2}</a>', capfirst(obj._meta.verbose_name), obj.get_absolute_url(), str(obj))
 
            except AttributeError:
 
                return format_html('<strong>{0}</strong>: {1}', capfirst(obj._meta.verbose_name), str(obj))
 

	
 
        return collector.nested(formatter_callback)
 
        return collector.nested(list_formatter_callback)
 

	
 

	
 
class Project(RelatedCollectorMixin, models.Model):
 
    """
 
    Implements a model with information about a project. A project has some
 
    basic settings, and mainly serves the purpose of grouping entities for
 
@@ -149,13 +122,13 @@ class Project(RelatedCollectorMixin, mod
 
      description - Free-form description of the project.
 
    """
 

	
 
    name = models.CharField(max_length=100, unique=True)
 
    description = models.TextField(blank=True)
 
    objects = SearchManager()
 
    deletion_collect_models = ["Entity", "Interface"] 
 
    deletion_collect_models = ["Entity", "Interface"]
 

	
 
    class Meta:
 
        permissions = (("view", "Can view information"),)
 

	
 
    def __unicode__(self):
 
        """
conntrackt/tests/test_models.py
Show inline comments
 
@@ -15,30 +15,137 @@
 
# details.
 
#
 
# You should have received a copy of the GNU General Public License along with
 
# Django Conntrackt.  If not, see <http://www.gnu.org/licenses/>.
 
#
 

	
 
# Standard Python library import.
 
import collections
 

	
 
# Python third-party library imports.
 
import mock
 

	
 
# Django imports.
 
from django.core.exceptions import ValidationError
 
from django.db import IntegrityError
 
from django.db.models import Model
 
from django.test import TestCase
 

	
 
# Application imports.
 
from conntrackt.models import Project, Location, Entity, Interface, Communication
 
from conntrackt.models import SearchManager
 
from conntrackt.models import NestedObjects
 
from conntrackt.utils import list_formatter_callback
 

	
 
# Test imports.
 
from .factories import ProjectFactory, LocationFactory
 
from .factories import ServerEntityFactory, ServerInterfaceFactory
 
from .factories import SubnetEntityFactory, SubnetInterfaceFactory
 
from .factories import CommunicationFactory
 
from .factories import setup_test_data
 

	
 

	
 
class RelatedCollectorMixinTest(TestCase):
 

	
 
    @mock.patch.object(NestedObjects, "collect")
 
    @mock.patch.object(NestedObjects, "nested")
 
    def test_get_dependant_objects_method_calls(self, nested_mock, collect_mock):
 
        """
 
        Tests if correct methods are being called with correct arguments during
 
        the invocation of get_dependant_objects method.
 
        """
 

	
 
        # Set-up some test data.
 
        project = ProjectFactory()
 

	
 
        # Call the method.
 
        project.get_dependant_objects()
 

	
 
        # Check if correct collector methods were called.
 
        collect_mock.assert_called_with([project])
 
        nested_mock.assert_called_with()
 

	
 
    def test_get_dependant_objects_return_value(self):
 
        """
 
        Tests the return value of get_dependant_objects method.
 
        """
 

	
 
        # Set-up some test data.
 
        project = ProjectFactory()
 
        location = LocationFactory()
 
        entity1 = ServerEntityFactory(pk=1, project=project, location=location)
 
        entity2 = ServerEntityFactory(pk=2, project=project, location=location)
 
        communication1 = CommunicationFactory(pk=1, source_id=1, destination_id=2, protocol="TCP", port="22")
 

	
 
        # Get the dependant objects.
 
        dependant_objects = project.get_dependant_objects()
 

	
 
        # Create a small local function for traversing the recursive list.
 
        def traverse(data):
 
            # If data is iterable, verify it is a list, and process its members
 
            # as well. If data is not iterable, make sure it is descendant of
 
            # Django Model class.
 
            if isinstance(data, collections.Iterable):
 
                self.assertIsInstance(data, list)
 
                for element in data:
 
                    traverse(element)
 
            else:
 
                self.assertIsInstance(data, Model)
 

	
 
        # Traverse the obtained dependant objects.
 
        traverse(dependant_objects)
 

	
 
    @mock.patch.object(NestedObjects, "collect")
 
    @mock.patch.object(NestedObjects, "nested")
 
    def test_get_dependant_objects_representation_method_calls(self, nested_mock, collect_mock):
 
        """
 
        Tests if correct methods are being called with correct arguments during
 
        the invocation of get_dependant_objects method.
 
        """
 

	
 
        # Set-up some test data.
 
        project = ProjectFactory()
 

	
 
        # Call the method.
 
        project.get_dependant_objects_representation()
 

	
 
        # Check if correct collector methods were called.
 
        collect_mock.assert_called_with([project])
 
        nested_mock.assert_called_with(list_formatter_callback)
 

	
 
    def test_get_dependant_objects_representation_return_value(self):
 
        """
 
        Tests the return value of get_dependant_objects_representation method.
 
        """
 

	
 
        # Set-up some test data.
 
        project = ProjectFactory()
 
        location = LocationFactory()
 
        entity1 = ServerEntityFactory(pk=1, project=project, location=location)
 
        entity2 = ServerEntityFactory(pk=2, project=project, location=location)
 
        communication1 = CommunicationFactory(pk=1, source_id=1, destination_id=2, protocol="TCP", port="22")
 

	
 
        # Get the dependant objects.
 
        dependant_objects = project.get_dependant_objects_representation()
 

	
 
        # Create a small local function for traversing the recursive list.
 
        def traverse(data):
 
            # If data is iterable, verify it is a list, and process its members
 
            # as well. If data is not iterable, make sure it is descendant of
 
            # Django Model class.
 
            if isinstance(data, collections.Iterable) and not isinstance(data, str):
 
                self.assertIsInstance(data, list)
 
                for element in data:
 
                    traverse(element)
 
            else:
 
                self.assertIsInstance(data, str)
 

	
 
        # Traverse the obtained dependant objects.
 
        traverse(dependant_objects)
 

	
 

	
 
class ProjectTest(TestCase):
 

	
 
    def test_unique_name(self):
 
        """
 
        Test if unique project name is enforced.
 
        """
conntrackt/tests/test_utils.py
Show inline comments
 
@@ -22,12 +22,13 @@
 
# Django imports.
 
from django.test import TestCase
 

	
 
# Third-party Python library imports.
 
import palette
 
import pydot
 
import mock
 

	
 
# Application imports.
 
from conntrackt.models import Entity, Project, Communication
 
from conntrackt import utils
 
from .factories import setup_test_data
 

	
 
@@ -284,6 +285,43 @@ class GenerateProjectDiagramTest(TestCas
 
        diagram = utils.generate_project_diagram(project)
 

	
 
        self.assertEqual("digraph", diagram.get_graph_type())
 
        self.assertEqual("transparent", diagram.get_bgcolor())
 
        self.assertEqual("1.5", diagram.get_nodesep())
 
        self.assertEqual([{"shape": "record"}], diagram.get_node_defaults())
 

	
 

	
 
class ListFormatterCallbackTest(TestCase):
 
    """
 
    Tests the list_formatter_callback function.
 
    """
 

	
 
    def test_get_absolute_url(self):
 
        """
 
        Test the return result in case the get_absolute_url is available on
 
        passed object instance.
 
        """
 

	
 
        # Set-up a minimal object mock.
 
        obj = mock.Mock(spec=["_meta", "__repr__", "get_absolute_url"])
 
        obj._meta = mock.Mock()
 
        obj._meta.verbose_name = "name"
 
        obj.__repr__ = mock.Mock()
 
        obj.__repr__.return_value = "representation"
 
        obj.get_absolute_url.return_value = "url"
 

	
 
        self.assertEqual(utils.list_formatter_callback(obj), '<strong>Name</strong>: <a href="url">representation</a>')
 

	
 
    def test_no_get_absolute_url(self):
 
        """
 
        Test the return result in case the get_absolute_url is not available on
 
        passed object instance.
 
        """
 

	
 
        # Set-up a minimal object mock.
 
        obj = mock.Mock(spec=["_meta", "__repr__"])
 
        obj._meta = mock.Mock()
 
        obj._meta.verbose_name = "name"
 
        obj.__repr__ = mock.Mock()
 
        obj.__repr__.return_value = "representation"
 

	
 
        self.assertEqual(utils.list_formatter_callback(obj), "<strong>Name</strong>: representation")
conntrackt/tests/test_views.py
Show inline comments
 
@@ -1755,7 +1755,6 @@ class APISearchViewTest(PermissionTestMi
 
        view = APISearchView.as_view()
 

	
 
        # Get the response.
 
        response = generate_get_response(view, search_term="test")
 

	
 
        self.assertEqual(response['Content-Type'], "application/json")
 

	
conntrackt/utils.py
Show inline comments
 
@@ -26,16 +26,18 @@ import itertools
 
# Third-party Python library imports.
 
import palette
 
import pydot
 

	
 
# Django imports.
 
from django.template import Context, loader
 
from django.utils.html import format_html
 
from django.utils.text import capfirst
 

	
 
# Application imports.
 
import iptables
 
from .models import Communication
 
import models
 

	
 

	
 
def generate_entity_iptables(entity):
 
    """
 
    Generates full iptables rules for the supplied entity. The generated rules
 
    can be fed directly to the iptables-restore utility.
 
@@ -184,13 +186,13 @@ def generate_project_diagram(project):
 

	
 
    # Add clusters to the graph.
 
    for cluster in clusters.values():
 
        graph.add_subgraph(cluster)
 

	
 
    # Get all project communications.
 
    communications = Communication.objects.filter(source__entity__project=project)
 
    communications = models.Communication.objects.filter(source__entity__project=project)
 

	
 
    # Add the edges (lines) representing communications, drawing them with same
 
    # colour as the source node/entity.
 
    for comm in communications:
 
        edge_color = node_colors[comm.source.entity.id]
 

	
 
@@ -198,6 +200,30 @@ def generate_project_diagram(project):
 

	
 
        edge = pydot.Edge(comm.source.entity.name, comm.destination.entity.name, label=label, color=edge_color)
 

	
 
        graph.add_edge(edge)
 

	
 
    return graph
 

	
 

	
 
def list_formatter_callback(obj):
 
    """
 
    Creates model object representation in format:
 

	
 
    MODEL_NAME: OBJECT_REPRESENTATION
 

	
 
    If passed object has a callable get_absolute_url method, the
 
    instance representation will be surrouned by an HTML anchor
 
    (<a></a>) where target is set to value of the get_absolute_url()
 
    method call.
 

	
 
    Arguments:
 
      obj - Model object whose representation should be returned.
 

	
 
    Returns:
 
      String represenation of passed model object.
 
    """
 

	
 
    try:
 
        return format_html('<strong>{0}</strong>: <a href="{1}">{2}</a>', capfirst(obj._meta.verbose_name), obj.get_absolute_url(), str(obj))
 
    except AttributeError:
 
        return format_html('<strong>{0}</strong>: {1}', capfirst(obj._meta.verbose_name), str(obj))
conntrackt/views.py
Show inline comments
 
@@ -1112,20 +1112,20 @@ class APISearchView(MultiplePermissionsR
 

	
 
            # Add found entities.
 
            for entity in entities:
 
                items.append({"name": entity.name,
 
                              "project": entity.project.name,
 
                              "type": "entity",
 
                              "url": entity.get_absolute_url(),})
 
                              "url": entity.get_absolute_url()})
 

	
 
            # Add found projects.
 
            for project in projects:
 
                items.append({"name": project.name,
 
                              "project": project.name,
 
                              "type": "project",
 
                              "url": project.get_absolute_url(),})
 
                              "url": project.get_absolute_url()})
 

	
 
        # Generate the JSON response.
 
        content = json.dumps(items)
 
        response = HttpResponse(content, mimetype="application/json")
 

	
 
        # Return the response.
0 comments (0 inline, 0 general)