diff --git a/conntrackt/templates/conntrackt/base.html b/conntrackt/templates/conntrackt/base.html
--- a/conntrackt/templates/conntrackt/base.html
+++ b/conntrackt/templates/conntrackt/base.html
@@ -35,7 +35,7 @@
{% block header %}
-
diff --git a/conntrackt/templatetags/conntrackt_tags.py b/conntrackt/templatetags/conntrackt_tags.py
--- a/conntrackt/templatetags/conntrackt_tags.py
+++ b/conntrackt/templatetags/conntrackt_tags.py
@@ -23,6 +23,7 @@
from django import template
from django import urls
from django.utils.html import format_html
+from django.urls import reverse
# Get an instance of Django's template library.
@@ -104,36 +105,34 @@ def active_link(context, url_name, retur
return return_value if matches else ''
-def current_url_equals(context, url_name, **kwargs):
+def current_url_equals(context, view_name, *args, **kwargs):
"""
- Helper function for checking if the specified URL corresponds to the current
- request path in the context.
+ Helper function for checking if the specified view with provided
+ arguments corresponds to the current request path in the context.
+
+ Passed-in positional and keyword arguments are used to resolve URL
+ for views that use them.
Arguments:
- - context - Context of the view being rendered.
+ context
+ Context of the view being rendered.
- - url_name - Name of the URL against which the context request path is
- being checked.
+ view_name
+ Name of the view against which the context request path is
+ being checked.
+
+ args
+ Positional parameters for the view.
+
+ kwargs
+ Keyword arguments for the view.
+
+ Returns:
+
+ True if the context holds a request whose path (query string not
+ included) equals the URL reversed for the view and arguments,
+ False otherwise. False is also returned if the context holds no
+ request.
+
+ Raises:
+
+ django.urls.NoReverseMatch
+ If the view cannot be reversed with the passed-in arguments.
"""
- # Assume that we have not been able to resolve the request path to an URL.
- resolved = False
- try:
- # Use the request path, and resolve it to a URL name.
- resolved = urls.resolve(context.get('request').path)
- except urls.Resolver404:
- # This means we haven't been able to resolve the path from request.
- pass
+ request = context.get('request')
+ # Reverse first so a misconfigured view name/arguments always surfaces
+ # NoReverseMatch, even when the context holds no request.
+ reversed_url = reverse(view_name, args=args, kwargs=kwargs)
- # If the request was resolved and URL names match, verify that the kwargs
- # match as well.
- matches = resolved and resolved.url_name == url_name
- if matches and kwargs:
- for key in kwargs:
- kwarg = kwargs.get(key)
- resolved_kwarg = resolved.kwargs.get(key)
- if not resolved_kwarg or kwarg != resolved_kwarg:
- return False
+ # Without a request in the context there is no current path to compare
+ # against (previously this raised AttributeError on None.path).
+ if request is None:
+ return False
- return matches
+ # request.path does not include the query string, so GET parameters do
+ # not affect the match.
+ return request.path == reversed_url
diff --git a/conntrackt/tests/test_tags.py b/conntrackt/tests/test_tags.py
--- a/conntrackt/tests/test_tags.py
+++ b/conntrackt/tests/test_tags.py
@@ -30,10 +30,14 @@ import mock
# Django imports.
from django.template import Context, Template, TemplateSyntaxError
from django.test import TestCase
+from django.urls import reverse
# Application imports
from conntrackt.templatetags.conntrackt_tags import html_link, active_link, current_url_equals
+# Test imports.
+from .helpers import create_get_request
+
@mock.patch('conntrackt.templatetags.conntrackt_tags.urls.reverse')
class HtmlLinkTest(TestCase):
@@ -187,3 +191,59 @@ class HtmlLinkTest(TestCase):
)
self.assertEqual(link, 'My </a> link')
+
+
+class CurrentUrlEqualsTest(TestCase):
+ """
+ Tests for the current_url_equals helper function.
+
+ Results are checked with assertIs so that the helper is required to
+ return actual bool values, not merely values equal to True/False.
+ """
+
+ def get_context_for_view(self, view, *args, **kwargs):
+ """
+ Returns a Context instance where the request path has been
+ constructed using the passed-in view (or view name), and view
+ positional/keyword arguments.
+
+ Arguments:
+
+ view
+ View function or name for request object.
+
+ args
+ Positional arguments to pass into the view.
+
+ kwargs
+ Keyword arguments to pass into the view.
+
+ Returns:
+
+ django.template.Context instance with request.
+ """
+ request = create_get_request(reverse(view, args=args, kwargs=kwargs))
+ context = Context({'request': request})
+
+ return context
+
+ def test_non_matching_url_returns_false(self):
+ request = create_get_request("/this/url/does/not/exist")
+ context = Context({'request': request})
+
+ self.assertIs(current_url_equals(context, 'index'), False)
+
+ def test_matching_url_returns_true(self):
+ context = self.get_context_for_view('project_create')
+
+ self.assertIs(current_url_equals(context, 'project_create'), True)
+
+ def test_matching_url_with_different_args_returns_false(self):
+ context = self.get_context_for_view('project', 1)
+
+ self.assertIs(current_url_equals(context, 'project', 2), False)
+
+ def test_matching_url_with_different_kwargs_returns_false(self):
+ context = self.get_context_for_view('project', pk=1)
+
+ self.assertIs(current_url_equals(context, 'project', pk=2), False)
+
+ def test_matching_url_with_GET_parameters_returns_true(self):
+ # The query string must be ignored when comparing against the view URL.
+ request = create_get_request(reverse('project', kwargs={'pk': 1}) + '?my_get_param=10')
+ context = Context({'request': request})
+
+ self.assertIs(current_url_equals(context, 'project', pk=1), True)