# -*- coding: utf-8 -*- # # Copyright (C) 2013 Branko Majic # # This file is part of Django Conntrackt. # # Django Conntrackt is free software: you can redistribute it and/or modify it # under the terms of the GNU General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. # # Django Conntrackt is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or # FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more # details. # # You should have received a copy of the GNU General Public License along with # Django Conntrackt. If not, see . # # Standard library imports. import json from StringIO import StringIO from zipfile import ZipFile, ZIP_DEFLATED # Python third-party library imports. import mock # Django imports. from django.core.exceptions import ValidationError from django.core.urlresolvers import reverse from django.http import Http404 from django.test import RequestFactory from django.test import TestCase from django.utils.http import urlquote # Application imports from conntrackt.models import Project, Location, Entity, Interface, Communication from conntrackt.views import IndexView, SearchView, APISearchView from conntrackt.views import entity_iptables, project_iptables, project_diagram from conntrackt.views import ProjectView, ProjectCreateView, ProjectUpdateView, ProjectDeleteView from conntrackt.views import LocationCreateView, LocationUpdateView, LocationDeleteView from conntrackt.views import EntityView, EntityCreateView, EntityUpdateView, EntityDeleteView from conntrackt.views import InterfaceCreateView, InterfaceUpdateView, InterfaceDeleteView from conntrackt.views import CommunicationCreateView, CommunicationUpdateView, CommunicationDeleteView # Test imports. from .forms import FormWithWidgetCSSClassFormMixin, FormWithPlaceholderFormMixin from .helpers import PermissionTestMixin, create_get_request, generate_get_response, FakeMessages from .views import RedirectToNextMixinView from .factories import setup_test_data class IndexViewTest(PermissionTestMixin, TestCase): sufficient_permissions = ("view",) view_class = IndexView def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context_no_projects(self): """ Verifies that the context is properly set-up when the view is called and no projects are available. """ Project.objects.all().delete() # Get the view. view = IndexView.as_view() # Get the response. response = generate_get_response(view) # Validate the response. self.assertQuerysetEqual(response.context_data["projects"], []) def test_context_no_locations(self): """ Verifies that the context is properly set-up when the view is called and no locations are available. """ Location.objects.all().delete() # Get the view. view = IndexView.as_view() # Get the response. response = generate_get_response(view) # Validate the response. self.assertQuerysetEqual(response.context_data["locations"], []) def test_context_projects(self): """ Verifies that the context is properly set-up when the view is called and there's multiple projects available. """ # Get the view. view = IndexView.as_view() # Get the response. response = generate_get_response(view) self.assertQuerysetEqual(response.context_data["projects"], ["", ""]) def test_locations_available(self): """ Verifies that the context is properly set-up when the view is called and there's multiple locationsg available. """ # Get the view. view = IndexView.as_view() # Get the response. response = generate_get_response(view) # Validate the response. self.assertQuerysetEqual(response.context_data["locations"], ["", ""]) class ProjectViewTest(PermissionTestMixin, TestCase): sufficient_permissions = ("view",) permission_test_view_kwargs = {"pk": "1"} view_class = ProjectView def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context(self): """ Verifies that the context is properly set-up when the view is called. """ # Get the view. view = ProjectView.as_view() # Get the response. response = generate_get_response(view, pk=1) # Fetch context data from response. location, entities = response.context_data["location_entities"][0] # Set-up expected context data values. expected_entities = ["", ""] # Validate context data. self.assertEqual(location.name, "Test Location 1") self.assertQuerysetEqual(entities, expected_entities) # Fetch context data from response. location, entities = response.context_data["location_entities"][1] # Set-up expected context data values. expected_entities = ["", ""] # Validate context data. self.assertEqual(location.name, "Test Location 2") self.assertQuerysetEqual(entities, expected_entities) # Validate context data. self.assertEqual(str(response.context_data["project"]), "Test Project 1") class EntityViewTest(PermissionTestMixin, TestCase): view_class = EntityView sufficient_permissions = ("view",) permission_test_view_kwargs = {"pk": "1"} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context(self): """ Tests if the form comes pre-populated with proper content. """ # Get the view. view = EntityView.as_view() # Get the response. response = generate_get_response(view, pk=1) # Set-up expected context data. expected_entity = Entity.objects.get(pk=1) expected_incoming_communications = [" Test Entity 1 (TCP:22)>", " Test Entity 1 (ICMP:8)>", " Test Entity 1 (TCP:3306)>", " Test Entity 1 (TCP:22)>"] expected_outgoing_communications = [" Test Entity 2 (UDP:123)>", " Test Entity 3 (UDP:53)>"] expected_interfaces = [""] # Validate the response. self.assertQuerysetEqual(response.context_data["interfaces"], expected_interfaces) self.assertQuerysetEqual(response.context_data["incoming_communications"], expected_incoming_communications) self.assertQuerysetEqual(response.context_data["outgoing_communications"], expected_outgoing_communications) self.assertEqual(response.context_data["entity"], expected_entity) self.assertTrue("entity_iptables" in response.context_data) class EntityIptablesTest(PermissionTestMixin, TestCase): view_function = staticmethod(entity_iptables) sufficient_permissions = ("view",) permission_test_view_kwargs = {"pk": "1"} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_invalid_entity(self): """ Tests if a 404 is returned if no entity was found (invalid ID). """ # Set-up a request. request = create_get_request() # Get the view. view = entity_iptables # Validate the response. self.assertRaises(Http404, view, request, pk=200) def test_content_type(self): """ Test if correct content type is being returned by the response. """ # Get the view. view = entity_iptables # Get the response. response = generate_get_response(view, pk=1) self.assertEqual(response['Content-Type'], "text/plain") def test_content_disposition(self): """ Test if the correct content disposition has been set. """ # Get the view. view = entity_iptables # Get the response. response = generate_get_response(view, pk=1) self.assertEqual(response['Content-Disposition'], "attachment; filename=test_entity_1-iptables.conf") def test_content(self): """ Tests content produced by the view. """ # Get the view. view = entity_iptables # Get the response. response = generate_get_response(view, pk=1) self.assertContains(response, ":INPUT") self.assertContains(response, ":OUTPUT") self.assertContains(response, ":FORWARD") class ProjectIptablesTest(PermissionTestMixin, TestCase): view_function = staticmethod(project_iptables) sufficient_permissions = ("view",) permission_test_view_kwargs = {"project_id": 1} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_invalid_project(self): """ Tests if a 404 is returned if no project was found (invalid ID). """ # Set-up a request. request = create_get_request() # Get the view. view = project_iptables # Request iptables for whole project. self.assertRaises(Http404, view, request, 200) # Request iptables for project location self.assertRaises(Http404, view, request, 200, 1) def test_invalid_location(self): """ Tests if a 404 is returned if no location was found (invalid ID). """ # Set-up a request. request = create_get_request() # Get the view. view = project_iptables # Request iptables for project location self.assertRaises(Http404, view, request, 1, 200) def test_content_type(self): """ Test if correct content type is being returned by the response. """ # Get the view. view = project_iptables # Get the response. response = generate_get_response(view, None, 1) # Validate the response. self.assertEqual(response['Content-Type'], "application/zip") def test_content_disposition(self): """ Test if the correct content disposition has been set. """ # Get the view. view = project_iptables # Get the response. response = generate_get_response(view, None, 1) self.assertEqual(response['Content-Disposition'], 'attachment; filename="test_project_1-iptables.zip"') response = generate_get_response(view, None, 1, 1) self.assertEqual(response['Content-Disposition'], 'attachment; filename="test_project_1-test_location_1-iptables.zip"') def test_content_project(self): """ Verifies that the content is properly generated when the view is called for an entire project. """ # Get the view. view = project_iptables # Get the response. response = generate_get_response(project_iptables, None, 1) buff = StringIO(response.content) zipped_iptables = ZipFile(buff, "r", ZIP_DEFLATED) expected_zip_files = ["test_entity_1-iptables.conf", "test_entity_2-iptables.conf", "test_entity_3-iptables.conf", "test_subnet_4-iptables.conf"] self.assertEqual(len(zipped_iptables.namelist()), 4) self.assertEqual(zipped_iptables.namelist(), expected_zip_files) for filename in expected_zip_files: iptables_file = zipped_iptables.read(filename) self.assertIn(":INPUT", iptables_file) self.assertIn(":OUTPUT", iptables_file) self.assertIn(":FORWARD", iptables_file) zipped_iptables.close() def test_content_location(self): """ Verifies that the content is properly generated when the view is called for an entire project. """ # Get the view. view = project_iptables # Get the response. response = generate_get_response(project_iptables, None, 1, 1) buff = StringIO(response.content) zipped_iptables = ZipFile(buff, "r", ZIP_DEFLATED) expected_zip_files = ["test_entity_1-iptables.conf", "test_entity_2-iptables.conf"] self.assertEqual(len(zipped_iptables.namelist()), 2) self.assertEqual(zipped_iptables.namelist(), expected_zip_files) for filename in expected_zip_files: iptables_file = zipped_iptables.read(filename) self.assertIn(":INPUT", iptables_file) self.assertIn(":OUTPUT", iptables_file) self.assertIn(":FORWARD", iptables_file) zipped_iptables.close() class ProjectCreateViewTest(PermissionTestMixin, TestCase): view_class = ProjectCreateView sufficient_permissions = ("add_project",) class ProjectUpdateViewTest(PermissionTestMixin, TestCase): view_class = ProjectUpdateView sufficient_permissions = ("change_project",) permission_test_view_kwargs = {"pk": 1} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context(self): """ Verifies that the context is properly set-up when the view is called for specific project. """ # Get the view. view = ProjectUpdateView.as_view() # Get the response. response = generate_get_response(view, None, pk=1) self.assertEqual(response.context_data["project"].name, "Test Project 1") self.assertEqual(response.context_data["headline"], "Update project Test Project 1") class ProjectDeleteViewTest(PermissionTestMixin, TestCase): view_class = ProjectDeleteView sufficient_permissions = ("delete_project",) permission_test_view_kwargs = {"pk": "1"} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context(self): """ Verifies that the context is properly set-up when the view is called for specific project. """ # Get the expected project. project = Project.objects.get(pk=1) # Get the view. view = ProjectDeleteView.as_view() # Get the response. response = generate_get_response(view, None, pk=1) self.assertEqual(response.context_data["project"], project) self.assertEqual(response.context_data["headline"], "Delete project Test Project 1") def test_message(self): """ Tests if the message gets added when the project is deleted. """ # Get the view. view = ProjectDeleteView.as_view() # Generate the request. request = RequestFactory().post("/fake-path/") request.user = mock.Mock() request._dont_enforce_csrf_checks = True request._messages = FakeMessages() # Get the response. response = view(request, pk=1) self.assertIn("Project Test Project 1 has been removed.", request._messages.messages) class LocationCreateViewTest(PermissionTestMixin, TestCase): view_class = LocationCreateView sufficient_permissions = ("add_location",) class LocationUpdateViewTest(PermissionTestMixin, TestCase): view_class = LocationUpdateView sufficient_permissions = ("change_location",) permission_test_view_kwargs = {"pk": 1} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context(self): """ Verifies that the context is properly set-up when the view is called for specific location. """ # Get the view. view = LocationUpdateView.as_view() # Get the response. response = generate_get_response(view, None, pk=1) self.assertEqual(response.context_data["location"].name, "Test Location 1") self.assertEqual(response.context_data["headline"], "Update location Test Location 1") class LocationDeleteViewTest(PermissionTestMixin, TestCase): view_class = LocationDeleteView sufficient_permissions = ("delete_location",) permission_test_view_kwargs = {"pk": "1"} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context(self): """ Verifies that the context is properly set-up when the view is called for specific location. """ # Get the expected location. location = Location.objects.get(pk=1) # Get the view. view = LocationDeleteView.as_view() # Get the response. response = generate_get_response(view, None, pk=1) self.assertEqual(response.context_data["location"], location) self.assertEqual(response.context_data["headline"], "Delete location Test Location 1") def test_message(self): """ Tests if the message gets added when the location is deleted. """ # Get the view. view = LocationDeleteView.as_view() # Generate the request. request = RequestFactory().post("/fake-path/") request.user = mock.Mock() request._dont_enforce_csrf_checks = True request._messages = FakeMessages() # Get the response. response = view(request, pk=1) self.assertIn("Location Test Location 1 has been removed.", request._messages.messages) class EntityCreateViewTest(PermissionTestMixin, TestCase): view_class = EntityCreateView sufficient_permissions = ("add_entity",) def setUp(self): """ Sets-up some data necessary for testing. """ # Set-up some data for testing. Project.objects.create(name="Test Project 1", description="This is test project 1.") Project.objects.create(name="Test Project 2", description="This is test project 2.") Location.objects.create(name="Test Location 1", description="This is test location 1.") Location.objects.create(name="Test Location 2", description="This is test location 2.") def test_form_project_limit(self): """ Tests if the queryset is properly limitted to specific project if GET parameters is passed. """ # Set-up the view. view = EntityCreateView() view.request = RequestFactory().get("/fake-path?project=1") view.object = None # Get the form. form = view.get_form(view.get_form_class()) self.assertQuerysetEqual(form.fields["project"].queryset, [""]) def test_form_location_limit(self): """ Tests if the queryset is properly limitted to specific location if GET parameters is passed. """ # Set-up the view. view = EntityCreateView() view.request = RequestFactory().get("/fake-path?location=1") view.object = None # Get the form. form = view.get_form(view.get_form_class()) self.assertQuerysetEqual(form.fields["location"].queryset, [""]) def test_initial_project(self): """ Tests if the choice field for project is defaulted to project passed as part of GET parameters. """ view = EntityCreateView() view.request = RequestFactory().get("/fake-path?project=1") view.object = None initial = view.get_initial() self.assertDictContainsSubset({"project": "1"}, initial) def test_initial_location(self): """ Tests if the choice field for location is defaulted to location passed as part of GET parameters. """ view = EntityCreateView() view.request = RequestFactory().get("/fake-path?location=1") view.object = None initial = view.get_initial() self.assertDictContainsSubset({"location": "1"}, initial) class EntityDeleteViewTest(PermissionTestMixin, TestCase): view_class = EntityDeleteView sufficient_permissions = ("delete_entity",) permission_test_view_kwargs = {"pk": 1} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context(self): """ Verifies that the context is properly set-up when the view is called for specific entity. """ # Get the expected entity. entity = Entity.objects.get(pk=1) # Get the view. view = EntityDeleteView.as_view() # Get the response. response = generate_get_response(view, None, pk=1) self.assertEqual(response.context_data["entity"], entity) self.assertEqual(response.context_data["headline"], "Delete entity Test Entity 1") def test_message(self): """ Tests if the message gets added when the entity is deleted. """ # Get the view. view = EntityDeleteView.as_view() # Generate the request. request = RequestFactory().post("/fake-path/") request.user = mock.Mock() request._dont_enforce_csrf_checks = True request._messages = FakeMessages() # Get the response. response = view(request, pk=1) self.assertIn("Entity Test Entity 1 has been removed.", request._messages.messages) def test_success_url(self): """ Validate that the success URL is set properly after delete. """ # Get the view. view = EntityDeleteView.as_view() # Generate the request request = RequestFactory().post("/fake-path/") request.user = mock.Mock() request._dont_enforce_csrf_checks = True request._messages = FakeMessages() # Get the response. response = view(request, pk=1) self.assertEqual(response["Location"], reverse("project", args=(1,))) class EntityUpdateViewTest(PermissionTestMixin, TestCase): view_class = EntityUpdateView sufficient_permissions = ("change_entity",) permission_test_view_kwargs = {"pk": 1} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context(self): """ Verifies that the context is properly set-up when the view is called for specific entity. """ # Get the view. view = EntityUpdateView.as_view() # Get the response. response = generate_get_response(view, None, pk=1) self.assertEqual(response.context_data["entity"].name, "Test Entity 1") self.assertEqual(response.context_data["headline"], "Update entity Test Entity 1") class InterfaceCreateViewTest(PermissionTestMixin, TestCase): view_class = InterfaceCreateView sufficient_permissions = ("add_interface",) def setUp(self): """ Sets-up some data necessary for testing. """ # Set-up some data for testing. project = Project.objects.create(name="Test Project", description="This is test project.") location = Location.objects.create(name="Test Location", description="This is test location.") Entity.objects.create(name="Test Entity 1", description="This is test entity 1.", project=project, location=location) Entity.objects.create(name="Test Entity 2", description="This is test entity 2.", project=project, location=location) def test_form_entity_limit(self): """ Tests if the queryset is properly limitted to specific entity if GET parameter is passed. """ # Set-up the view. view = InterfaceCreateView() view.request = RequestFactory().get("/fake-path?entity=1") view.object = None # Get the form. form = view.get_form(view.get_form_class()) self.assertQuerysetEqual(form.fields["entity"].queryset, [""]) def test_initial_project(self): """ Tests if the choice field for entity is defaulted to entity passed as part of GET parameters. """ view = InterfaceCreateView() view.request = RequestFactory().get("/fake-path?entity=1") view.object = None initial = view.get_initial() self.assertDictContainsSubset({"entity": "1"}, initial) def test_success_url(self): """ Validate that the success URL is set properly after interface is created. """ # Get the view. view = InterfaceCreateView.as_view() # Generate the request. post_data = {"name": "eth0", "description": "Main interface.", "entity": "1", "address": "192.168.1.1", "netmask": "255.255.255.255"} request = RequestFactory().post("/fake-path/", data=post_data) request.user = mock.Mock() request._dont_enforce_csrf_checks = True # Get the response. response = view(request, pk=1) self.assertEqual(response["Location"], reverse("entity", args=(1,))) self.assertEqual(response.status_code, 302) class InterfaceUpdateViewTest(PermissionTestMixin, TestCase): view_class = InterfaceUpdateView sufficient_permissions = ("change_interface",) permission_test_view_kwargs = {"pk": 1} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context(self): """ Verifies that the context is properly set-up when the view is called for specific entity. """ # Get the view. view = InterfaceUpdateView.as_view() # Get the response. response = generate_get_response(view, None, pk=1) # Set-up expected interface. interface = Interface.objects.get(pk=1) self.assertEqual(response.context_data["interface"], interface) self.assertEqual(response.context_data["headline"], "Update interface eth0") def test_form_entity_limit(self): """ Tests if the queryset is properly limitted to specific project's entities. """ # Set-up the view. view = InterfaceUpdateView() view.request = RequestFactory().get("/fake-path/1") view.object = Interface.objects.get(pk=1) # Get the form. form = view.get_form(view.get_form_class()) expected_entities = ["", "", "", ""] self.assertQuerysetEqual(form.fields["entity"].queryset, expected_entities) def test_success_url(self): """ Validate that the success URL is set properly after update. """ # Get the view. view = InterfaceUpdateView.as_view() # Get the interface object. interface = Interface.objects.get(pk=1) # Generate the request. post_data = {"name": interface.name, "description": interface.name, "entity": "1", "address": "192.168.1.1", "netmask": "255.255.255.255"} request = RequestFactory().post("/fake-path/", data=post_data) request.user = mock.Mock() request._dont_enforce_csrf_checks = True # Get the response. response = view(request, pk=1) self.assertEqual(response["Location"], reverse("entity", args=(1,))) self.assertEqual(response.status_code, 302) class InterfaceDeleteViewTest(PermissionTestMixin, TestCase): view_class = InterfaceDeleteView sufficient_permissions = ("delete_interface",) permission_test_view_kwargs = {"pk": 1} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context(self): """ Verifies that the context is properly set-up when the view is called for specific interface. """ # Get the expected entity. interface = Interface.objects.get(pk=1) # Get the view. view = InterfaceDeleteView.as_view() # Get the response. response = generate_get_response(view, None, pk=1) self.assertEqual(response.context_data["interface"], interface) self.assertEqual(response.context_data["headline"], "Delete interface eth0") def test_message(self): """ Tests if the message gets added when the interface is deleted. """ # Get the view. view = InterfaceDeleteView.as_view() # Generate the request. request = RequestFactory().post("/fake-path/") request.user = mock.Mock() request._dont_enforce_csrf_checks = True request._messages = FakeMessages() # Get the response. response = view(request, pk=1) self.assertIn("Interface eth0 has been removed.", request._messages.messages) def test_success_url(self): """ Validate that the success URL is set properly after delete. """ # Get the view. view = InterfaceDeleteView.as_view() # Generate the request request = RequestFactory().post("/fake-path/") request.user = mock.Mock() request._dont_enforce_csrf_checks = True request._messages = FakeMessages() # Get the response. response = view(request, pk=1) self.assertEqual(response["Location"], reverse("entity", args=(1,))) class CommunicationCreateViewTest(PermissionTestMixin, TestCase): view_class = CommunicationCreateView sufficient_permissions = ("add_communication",) def setUp(self): """ Sets-up some data necessary for testing. """ # Set-up some data for testing. project1 = Project.objects.create(name="Test Project 1", description="This is test project 1.") project2 = Project.objects.create(name="Test Project 2", description="This is test project 2.") location = Location.objects.create(name="Test Location", description="This is test location.") entity1 = Entity.objects.create(name="Test Entity 1", description="This is test entity 1.", project=project1, location=location) entity2 = Entity.objects.create(name="Test Entity 2", description="This is test entity 2.", project=project1, location=location) entity3 = Entity.objects.create(name="Test Entity 3", description="This is test entity 3.", project=project2, location=location) Interface.objects.create(name="eth0", description="Main interface", entity=entity1, address="192.168.1.1", netmask="255.255.255.255") Interface.objects.create(name="eth0", description="Main interface", entity=entity2, address="192.168.1.2", netmask="255.255.255.255") Interface.objects.create(name="eth0", description="Main interface", entity=entity3, address="192.168.1.3", netmask="255.255.255.255") def test_interface_limit_from_entity(self): """ Tests if the queryset is properly limitted if GET parameter is passed. """ # Set-up the view. view = CommunicationCreateView() view.request = RequestFactory().get("/fake-path?from_entity=1") view.object = None # Get the form. form = view.get_form(view.get_form_class()) # Set-up expected interfaces. expected_interfaces = ["", ""] self.assertQuerysetEqual(form.fields["source"].queryset, expected_interfaces) self.assertQuerysetEqual(form.fields["destination"].queryset, expected_interfaces) def test_interface_limit_to_entity(self): """ Tests if the queryset is properly limitted if GET parameter is passed. """ # Set-up the view. view = CommunicationCreateView() view.request = RequestFactory().get("/fake-path?to_entity=1") view.object = None # Get the form. form = view.get_form(view.get_form_class()) # Set-up expected interfaces. expected_interfaces = ["", ""] self.assertQuerysetEqual(form.fields["source"].queryset, expected_interfaces) self.assertQuerysetEqual(form.fields["destination"].queryset, expected_interfaces) def test_interface_limit_project(self): """ Tests if the queryset is properly limitted if GET parameter is passed. """ # Set-up the view. view = CommunicationCreateView() view.request = RequestFactory().get("/fake-path?project=1") view.object = None # Get the form. form = view.get_form(view.get_form_class()) # Set-up expected interfaces. expected_interfaces = ["", ""] self.assertQuerysetEqual(form.fields["source"].queryset, expected_interfaces) self.assertQuerysetEqual(form.fields["destination"].queryset, expected_interfaces) def test_initial_from_entity(self): """ Tests if the choice field for interface is defaulted to first interface of entity passed as part of GET parameters. """ # Set-up the view. view = CommunicationCreateView() view.request = RequestFactory().get("/fake-path?from_entity=1") view.object = None # Get the expected interface ID. interface = Entity.objects.get(pk=1).interface_set.all()[0] # Fetch the initial values. initial = view.get_initial() self.assertDictContainsSubset({"source": interface.pk}, initial) def test_initial_to_entity(self): """ Tests if the choice field for interface is defaulted to first interface of entity passed as part of GET parameters. """ # Set-up the view. view = CommunicationCreateView() view.request = RequestFactory().get("/fake-path?to_entity=1") view.object = None # Get the expected interface ID. interface = Entity.objects.get(pk=1).interface_set.all()[0] # Fetch the initial value. initial = view.get_initial() self.assertDictContainsSubset({"destination": interface.pk}, initial) def test_initial_invalid_from_entity(self): """ Tests if the choice fields for source and destination interfaces are not defaulted in case invalid entity ID is passed as GET parameter. """ # Set-up the view. view = CommunicationCreateView() view.request = RequestFactory().get("/fake-path?from_entity=10") view.object = None # Get the initial values. initial = view.get_initial() self.assertEqual(len(initial), 0) def test_initial_invalid_to_entity(self): """ Tests if the choice fields for source and destination interfaces are not defaulted in case invalid entity ID is passed as GET parameter. """ # Set-up the view. view = CommunicationCreateView() view.request = RequestFactory().get("/fake-path?to_entity=10") view.object = None # Get the initial values. initial = view.get_initial() self.assertEqual(len(initial), 0) def test_initial_invalid_project(self): """ Tests if the choice fields for source and destination interfaces are not defaulted in case invalid project ID is passed as GET parameter. """ # Set-up the view. view = CommunicationCreateView() view.request = RequestFactory().get("/fake-path?project=10") view.object = None # Get the initial values. initial = view.get_initial() self.assertEqual(len(initial), 0) def test_success_url_next(self): """ Validate that the success URL is set properly after communication is created if "next" GET parameter is provided. """ # Get the view. view = CommunicationCreateView.as_view() # Generate the request. source = Interface.objects.get(pk=1) destination = Interface.objects.get(pk=2) post_data = {"source": source.pk, "destination": destination.pk, "protocol": "TCP", "port": "22", "description": "SSH."} request = RequestFactory().post("/fake-path?next=/next-page", data=post_data) request.user = mock.Mock() request._dont_enforce_csrf_checks = True # Get the response. response = view(request) self.assertEqual(response["Location"], "/next-page") self.assertEqual(response.status_code, 302) def test_success_url_no_next(self): """ Validate that the success URL is set properly after communication is created if no "next" GET parameter is provided. """ # Get the view. view = CommunicationCreateView.as_view() # Generate the request. source = Interface.objects.get(pk=1) destination = Interface.objects.get(pk=2) post_data = {"source": source.pk, "destination": destination.pk, "protocol": "TCP", "port": "22", "description": "SSH."} request = RequestFactory().post("/fake-path", data=post_data) request.user = mock.Mock() request._dont_enforce_csrf_checks = True # Get the response. response = view(request) self.assertEqual(response["Location"], reverse("project", args=(1,))) self.assertEqual(response.status_code, 302) class CommunicationUpdateViewTest(PermissionTestMixin, TestCase): view_class = CommunicationUpdateView sufficient_permissions = ("change_communication",) permission_test_view_kwargs = {"pk": 1} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context(self): """ Verifies that the context is properly set-up. """ # Get the view. view = CommunicationUpdateView.as_view() # Get the response. response = generate_get_response(view, None, pk=1) # Set-up expected interface. communication = Communication.objects.get(pk=1) self.assertEqual(response.context_data["communication"], communication) self.assertEqual(response.context_data["headline"], "Update communication Test Entity 2 -> Test Entity 1 (TCP:22)") def test_form_interface_limit(self): """ Tests if the queryset is properly limitted to specific project's entity interfaces. """ # Set-up the view. view = CommunicationUpdateView() view.request = RequestFactory().get("/fake-path/1") view.object = Communication.objects.get(pk=1) # Get the form. form = view.get_form(view.get_form_class()) expected_interfaces = ["", "", "", ""] self.assertQuerysetEqual(form.fields["source"].queryset, expected_interfaces) self.assertQuerysetEqual(form.fields["destination"].queryset, expected_interfaces) def test_success_url_next(self): """ Validate that the success URL is set properly after update if GET parameter is passed. """ # Get the view. view = CommunicationUpdateView.as_view() # Get the communication object. communication = Communication.objects.get(pk=1) # Generate the request. post_data = {"source": communication.source.pk, "destination": communication.destination.pk, "protocol": communication.protocol, "port": communication.port} request = RequestFactory().post("/fake-path?next=/next-page", data=post_data) request.user = mock.Mock() request._dont_enforce_csrf_checks = True # Get the response. response = view(request, pk=1) self.assertEqual(response["Location"], "/next-page") self.assertEqual(response.status_code, 302) def test_success_url_no_next(self): """ Validate that the success URL is set properly after communication is created if no "next" GET parameter is provided. """ # Get the view. view = CommunicationUpdateView.as_view() # Get the communication object. communication = Communication.objects.get(pk=1) # Generate the request. post_data = {"source": communication.source.pk, "destination": communication.destination.pk, "protocol": communication.protocol, "port": communication.port} request = RequestFactory().post("/fake-path/", data=post_data) request.user = mock.Mock() request._dont_enforce_csrf_checks = True # Get the response. response = view(request, pk=1) self.assertEqual(response["Location"], reverse("project", args=(communication.source.entity.project.id,))) self.assertEqual(response.status_code, 302) class CommunicationDeleteViewTest(PermissionTestMixin, TestCase): view_class = CommunicationDeleteView sufficient_permissions = ("delete_communication",) permission_test_view_kwargs = {"pk": 1} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_context(self): """ Verifies that the context is properly set-up when the view is called for specific communication. """ # Get the expected entity. communication = Communication.objects.get(pk=1) # Get the view. view = CommunicationDeleteView.as_view() # Get the response. response = generate_get_response(view, None, pk=1) self.assertEqual(response.context_data["communication"], communication) self.assertEqual(response.context_data["headline"], "Delete communication Test Entity 2 -> Test Entity 1 (TCP:22)") def test_message(self): """ Tests if the message gets added when the communication is deleted. """ # Get the view. view = CommunicationDeleteView.as_view() # Generate the request. request = RequestFactory().post("/fake-path/") request.user = mock.Mock() request._dont_enforce_csrf_checks = True request._messages = FakeMessages() # Get the response. response = view(request, pk=1) self.assertIn("Communication Test Entity 2 -> Test Entity 1 (TCP:22) has been removed.", request._messages.messages) def test_success_url_from_entity(self): """ Validate that the success URL is set properly after delete. """ # Get the view. view = CommunicationDeleteView.as_view() # Generate the request request = RequestFactory().post("/fake-path?from_entity=1") request.user = mock.Mock() request._dont_enforce_csrf_checks = True request._messages = FakeMessages() # Get the response. response = view(request, pk=1) self.assertEqual(response["Location"], reverse("entity", args=(1,))) def test_success_url_to_entity(self): """ Validate that the success URL is set properly after delete. """ # Get the view. view = CommunicationDeleteView.as_view() # Generate the request request = RequestFactory().post("/fake-path?to_entity=1") request.user = mock.Mock() request._dont_enforce_csrf_checks = True request._messages = FakeMessages() # Get the response. response = view(request, pk=1) self.assertEqual(response["Location"], reverse("entity", args=(1,))) def test_success_url_no_entity(self): """ Validate that the success URL is set properly after delete. """ # Get the view. view = CommunicationDeleteView.as_view() # Get the communication object. communication = Communication.objects.get(pk=1) # Generate the request request = RequestFactory().post("/fake-path") request.user = mock.Mock() request._dont_enforce_csrf_checks = True request._messages = FakeMessages() # Get the response. response = view(request, pk=1) self.assertEqual(response["Location"], reverse("entity", args=(communication.source.entity.pk,))) class ProjectDiagramTest(PermissionTestMixin, TestCase): view_function = staticmethod(project_diagram) sufficient_permissions = ("view",) permission_test_view_kwargs = {"pk": "1"} def setUp(self): """ Set-up some test data. """ setup_test_data() def test_invalid_project(self): """ Tests if a 404 is returned if no project was found (invalid ID). """ # Set-up a request. request = create_get_request() # Get the view. view = project_diagram # Validate the response. self.assertRaises(Http404, view, request, pk=200) def test_content_type(self): """ Test if correct content type is being returned by the response. """ # Get the view. view = project_diagram # Get the response. response = generate_get_response(view, pk=1) self.assertEqual(response['Content-Type'], "image/svg+xml") def test_content(self): """ Tests content produced by the view. """ # Get the view. view = project_diagram # Get the response. response = generate_get_response(view, pk=1) self.assertContains(response, '"-//W3C//DTD SVG 1.1//EN"') self.assertContains(response, "Test Project 1") class RedirectToNextMixinTest(TestCase): def test_request_with_next(self): """ Test if the get_success_url returns correct URL if "next" is present in request's GET parameters. """ # Generate the request. request = RequestFactory().post("/fake-path?next=/next") # Initialise the pseudo-view. view = RedirectToNextMixinView(request) self.assertEqual("/next", view.get_success_url()) def test_request_without_next(self): """ Test if the get_success_url returns correct URL if "next" is not present in request's GET parameters. """ # Generate the request. request = RequestFactory().post("/fake-path") # Initialise the pseudo-view. view = RedirectToNextMixinView(request) self.assertEqual("/STATIC", view.get_success_url()) def test_request_custom_parameter_name(self): """ Test if the mixin honours the custom parameter name. """ # Generate the request. request = RequestFactory().post("/fake-path?custom=/next") # Initialise the pseudo-view. view = RedirectToNextMixinView(request) view.next_parameter = "custom" self.assertEqual("/next", view.get_success_url()) class SearchViewTest(PermissionTestMixin, TestCase): sufficient_permissions = ("view",) view_class = SearchView def setUp(self): """ Set-up some test data. """ setup_test_data() def test_empty_query_error_message(self): """ Verifies that an error is reported to the user in case an empty query is submitted. """ # Get the view. view = SearchView.as_view() # Generate the request request = RequestFactory().get("/fake-path?q=") request.user = mock.Mock() request._dont_enforce_csrf_checks = True request._messages = FakeMessages() # Get the response. response = view(request) self.assertIn("Search query is not allowed to be empty.", request._messages.messages) def test_strip_search_term(self): """ Verifies that the search term is stripped when search is performed. """ # Get the view. view = SearchView.as_view() # Set-up a request. search_term = " \t \t something with lots of tabs \t \t" stripped_search_term = "something with lots of tabs" request = RequestFactory().get("/fake-path?q=%s" % urlquote(search_term)) request.user = mock.Mock() request._dont_enforce_csrf_checks = True # Get the response. response = view(request) # Validate the response. self.assertEqual(stripped_search_term, response.context_data["search_term"]) def test_no_query_context(self): """ Tests that context is not set if no query was sent. """ # Get the view. view = SearchView.as_view() # Set-up a request. response = generate_get_response(view) self.assertNotIn("entities", response.context_data) self.assertNotIn("projects", response.context_data) self.assertNotIn("search_term", response.context_data) # Only the "view" context variable should be present. self.assertEqual(1, len(response.context_data)) class APISearchViewTest(PermissionTestMixin, TestCase): sufficient_permissions = ("view",) view_class = APISearchView def setUp(self): """ Set-up some test data. """ setup_test_data() def test_limit_negative(self): """ Test if an exception is raised in case a negative limit is requested. """ # Get the view. view = APISearchView.as_view() # Generate the request. request = RequestFactory().get("/fake-path?limit=-1") request.user = mock.Mock() request._dont_enforce_csrf_checks = True # Validate the response. self.assertRaisesRegexp(ValidationError, "Limit may not be a negative value.", view, request, search_term="test") def test_empty_query(self): """ Test that the response is empty if empty query was sent. """ # Get the view. view = APISearchView.as_view() # Get the response. response = generate_get_response(view, search_term="") self.assertEqual(response['Content-Type'], "application/json") self.assertEqual(response.content, "[]") def test_strip_search_term(self): """ Verifies that the search term is stripped when search is performed. """ # Get the view. view = APISearchView.as_view() # Get the response. response = generate_get_response(view, search_term="Test Entity 1") # Validate the response. expected_content = """[{"project": "Test Project 1", "url": "/conntrackt/entity/1/", "type": "entity", "name": "Test Entity 1"}]""" self.assertEqual(response.content, expected_content) def test_no_items(self): """ Test the response if no items are found. """ # Get the view. view = APISearchView.as_view() # Get the response. response = generate_get_response(view, search_term="string that does not exist") self.assertEqual(response['Content-Type'], "application/json") self.assertEqual(response.content, "[]") def test_entity_found(self): """ Test the response if a single entity is found. """ # Get the view. view = APISearchView.as_view() # Get the response. response = generate_get_response(view, search_term="Test Entity 1") expected_content = """[{"project": "Test Project 1", "url": "/conntrackt/entity/1/", "type": "entity", "name": "Test Entity 1"}]""" self.assertEqual(response['Content-Type'], "application/json") self.assertEqual(response.content, expected_content) def test_project_found(self): """ Test the response if a single project is found. """ # Get the view. view = APISearchView.as_view() # Get the response. response = generate_get_response(view, search_term="Test Project 1") expected_content = """[{"project": "Test Project 1", "url": "/conntrackt/project/1/", "type": "project", "name": "Test Project 1"}]""" self.assertEqual(response['Content-Type'], "application/json") self.assertEqual(response.content, expected_content) def test_multiple_items_found(self): """ Test the response if multiple items are found. """ # Get the view. view = APISearchView.as_view() # Get the response. response = generate_get_response(view, search_term="Test") # Verify that the JSON reply is valid. try: items = json.loads(response.content) except ValueError: self.fail("Parsing of resulting JSON has failed") # Verify that a list of items was returned. self.assertTrue(isinstance(items, list)) # Verify each item. for item in items: # Every item must be a dictionary. self.assertTrue(isinstance(item, dict)) keys = item.keys() # Verify that 4 specific keys are present in dictionary (project, # url, name, type). self.assertEqual(len(keys), 4) self.assertIn("project", keys) self.assertIn("name", keys) self.assertIn("url", keys) self.assertIn("type", keys) # Verify the type associated with item. self.assertIn(item["type"], ["project", "entity"]) def test_content_type(self): """ Test if correct content type is being returned by the response. """ # Get the view. view = APISearchView.as_view() # Get the response. response = generate_get_response(view, search_term="test") self.assertEqual(response['Content-Type'], "application/json")