From de1cc2505a565c368cfaa7ee29cc0d1b0386823d 2020-06-07 22:17:44 From: Branko Majic Date: 2020-06-07 22:17:44 Subject: [PATCH] GC-37: Refactor key specification handling: - Perform the key specification parsing within the CLI module itself, instead of via the crypto module. - Pass a tuple consisting of the algorithm and its associated parameters into the init command instead of a key generator. - Update all tests to accommodate the change in the init function signature. - Simplify the KeyGenerator class. - Do not test whether the KeyGenerator class sets the properties via the constructor - it is sufficient to test the string representation and key generation. --- diff --git a/gimmecert/cli.py b/gimmecert/cli.py index 413edf5f08fbbfd2ea90c3487d78529ff76cb1c0..483d11dced41929312891326902c2727044331ea 100644 --- a/gimmecert/cli.py +++ b/gimmecert/cli.py @@ -25,7 +25,6 @@ import sys from .decorators import subcommand_parser, get_subcommand_parser_setup_functions from .commands import client, help_, init, renew, server, status, usage, ExitCode -from .crypto import KeyGenerator ERROR_ARGUMENTS = 2 @@ -78,12 +77,40 @@ Examples: """ +def key_specification(specification): + """ + Verifies and parses the passed-in key specification. This is a + small utility function for use with the Python argument parser. + + :param specification: Key specification. Currently supported formats are: "rsa:KEY_SIZE". + :type specification: str + + :returns: Parsed key algorithm and parameter(s) for the algorithm. For RSA, parameter is the RSA key size. + :rtype: tuple(str, int) + + :raises ValueError: If passed-in specification is invalid.
+ """ + + try: + algorithm, parameters = specification.split(":", 2) + + if algorithm == "rsa" and parameters.isnumeric():  # isnumeric() rejects signs/whitespace/underscores that int() alone would accept, e.g. "rsa:-2048". + parameters = int(parameters) + else: + raise ValueError() + + except ValueError: + raise ValueError("Invalid key specification: '%s'" % specification) + + return algorithm, parameters + + @subcommand_parser def setup_init_subcommand_parser(parser, subparsers): subparser = subparsers.add_parser('init', description='Initialise CA hierarchy.') subparser.add_argument('--ca-base-name', '-b', help="Base name to use for CA naming. Default is to use the working directory base name.") subparser.add_argument('--ca-hierarchy-depth', '-d', type=int, help="Depth of CA hierarchy to generate. Default is 1", default=1) - subparser.add_argument('--key-specification', '-k', type=KeyGenerator, + subparser.add_argument('--key-specification', '-k', type=key_specification, help='''Default specification/parameters to use for private key generation. \ For RSA keys, use format rsa:BIT_LENGTH. Default is rsa:2048.''', default="rsa:2048") diff --git a/gimmecert/commands.py b/gimmecert/commands.py index 319fed20c5f9cd7505690df37239945416d8813d..495274720ffc054528022ceaaf2646f6e5b623b9 100644 --- a/gimmecert/commands.py +++ b/gimmecert/commands.py @@ -47,7 +47,7 @@ class InvalidCommandInvocation(Exception): pass -def init(stdout, stderr, project_directory, ca_base_name, ca_hierarchy_depth, key_generator): +def init(stdout, stderr, project_directory, ca_base_name, ca_hierarchy_depth, key_specification): """ Initialises the necessary directory and CA hierarchies for use in the specified directory. @@ -67,8 +67,8 @@ def init(stdout, stderr, project_directory, ca_base_name, ca_hierarchy_depth, ke :param ca_hierarchy_depth: Length/depths of CA hierarchy that should be initialised. E.g. total number of CAs in chain. :type ca_hierarchy_depth: int - :param key_generator: Callable for generating private keys.
- :type key_generator: callable[[], cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey] + :param key_specification: Key specification to use when generating private keys for the hierarchy. + :type key_specification: tuple(str, int) :returns: Status code, one from gimmecert.commands.ExitCode. :rtype: int @@ -86,6 +86,7 @@ def init(stdout, stderr, project_directory, ca_base_name, ca_hierarchy_depth, ke gimmecert.storage.initialise_storage(project_directory) # Generate the CA hierarchy. + key_generator = gimmecert.crypto.KeyGenerator(key_specification[0], key_specification[1]) ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy(ca_base_name, ca_hierarchy_depth, key_generator) # Output the CA private keys and certificates. diff --git a/gimmecert/crypto.py b/gimmecert/crypto.py index d630119bef6cf940ca7c337081e80b23a48e573c..71343f81e7ad58aafcf0bc4d492dabad800a0530 100644 --- a/gimmecert/crypto.py +++ b/gimmecert/crypto.py @@ -36,35 +36,24 @@ class KeyGenerator: instance initialisation. """ - def __init__(self, specification): + def __init__(self, algorithm, parameters): """ Initialises an instance. - :param specification: Specification describing the private keys that that instance should be generating. - For RSA keys, use syntax "rsa:BIT_LENGTH". - :type specification: str + :param algorithm: Algorithm to use. Supported algorithms: 'rsa'. + :type algorithm: str - :raises ValueError: If passed-in specification is invalid. + :param parameters: Parameters for generating the keys using the specified algorithm. For RSA keys this is key size. + :type parameters: int """ - try: - # This will throw ValueError if we can't get two values - # assigned via split. 
- key_type, key_parameters = specification.split(":", 2) - - if key_type == "rsa" and key_parameters.isnumeric(): - self._algorithm = "rsa" - self._parameters = int(key_parameters) - else: - raise ValueError() - - except ValueError: - raise ValueError("Invalid key specification: '%s'" % specification) + self._algorithm = algorithm + self._parameters = parameters def __str__(self): """ - Returns string (human-readable) representation of stored key - algorithm and parameters. + Returns string (human-readable) representation of stored algorithm + and parameters. :returns: String representation of object. :rtype: str diff --git a/tests/conftest.py b/tests/conftest.py index 618868ca1ccb594bdc55284ab2d4fecc9fb21b81..c8e8f8b431496906d129cf1093a847d37f5fe82d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -143,7 +143,7 @@ def sample_project_directory(tmpdir): gimmecert.storage.write_csr(csr, custom_csr_dir.join("%s.csr.pem" % name).strpath) # Initialise one-level deep hierarchy. - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, ("rsa", 2048)) # Issue a bunch of certificates. for i in range(1, per_type_count + 1): @@ -186,6 +186,6 @@ def gctmpdir(tmpdir): """ # Initialise one-level deep hierarchy. 
- gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, ("rsa", 2048)) return tmpdir diff --git a/tests/test_cli.py b/tests/test_cli.py index 44ce1fa5c4cc011d6a5ba6b2ce1d042808a0d825..d97903205f9e8dd5e1c6b92c69b0bb148e54d127 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -376,8 +376,7 @@ def test_command_exists_and_accepts_help_flag(tmpdir, command, help_option): @mock.patch('sys.argv', ['gimmecert', 'init']) @mock.patch('gimmecert.cli.init') -@mock.patch('gimmecert.cli.KeyGenerator') -def test_init_command_invoked_with_correct_parameters_no_options(mock_key_generator, mock_init, tmpdir): +def test_init_command_invoked_with_correct_parameters_no_options(mock_init, tmpdir): # This should ensure we don't accidentally create artifacts # outside of test directory. tmpdir.chdir() @@ -386,18 +385,14 @@ def test_init_command_invoked_with_correct_parameters_no_options(mock_key_genera default_depth = 1 - mock_key_generator.return_value = mock.Mock() - gimmecert.cli.main() - mock_key_generator.assert_called_once_with("rsa:2048") - mock_init.assert_called_once_with(sys.stdout, sys.stderr, tmpdir.strpath, tmpdir.basename, default_depth, mock_key_generator.return_value) + mock_init.assert_called_once_with(sys.stdout, sys.stderr, tmpdir.strpath, tmpdir.basename, default_depth, ('rsa', 2048)) -@mock.patch('sys.argv', ['gimmecert', 'init', '-b', 'My Project']) +@mock.patch('sys.argv', ['gimmecert', 'init', '-b', 'My Project', '-k', 'rsa:4096']) @mock.patch('gimmecert.cli.init') -@mock.patch('gimmecert.cli.KeyGenerator') -def test_init_command_invoked_with_correct_parameters_with_options(mock_key_generator, mock_init, tmpdir): +def test_init_command_invoked_with_correct_parameters_with_options(mock_init, tmpdir): # This should ensure we don't accidentally create artifacts # outside of test directory. 
tmpdir.chdir() @@ -406,11 +401,9 @@ def test_init_command_invoked_with_correct_parameters_with_options(mock_key_gene default_depth = 1 - mock_key_generator.return_value = mock.Mock() - gimmecert.cli.main() - mock_init.assert_called_once_with(sys.stdout, sys.stderr, tmpdir.strpath, 'My Project', default_depth, mock_key_generator.return_value) + mock_init.assert_called_once_with(sys.stdout, sys.stderr, tmpdir.strpath, 'My Project', default_depth, ('rsa', 4096)) @mock.patch('sys.argv', ['gimmecert', 'server']) @@ -709,3 +702,29 @@ def test_renew_command_fails_if_both_new_private_key_and_csr_options_are_specifi assert mock_renew.called is False assert e_info.value.code != 0 + + +@pytest.mark.parametrize("key_specification", [ + "", + "rsa", + "rsa:not_a_number", + "unsupported:algorithm", +]) +def test_key_specification_raises_exception_for_invalid_specification(key_specification): + + with pytest.raises(ValueError) as e_info: + gimmecert.cli.key_specification(key_specification) + + assert str(e_info.value) == "Invalid key specification: '%s'" % key_specification + + +@pytest.mark.parametrize("key_specification, expected_return_value", [ + ("rsa:1024", ("rsa", 1024)), + ("rsa:2048", ("rsa", 2048)), + ("rsa:4096", ("rsa", 4096)), +]) +def test_key_specification_returns_algorithm_and_parameters_for_valid_specification(key_specification, expected_return_value): + + algorithm, parameters = gimmecert.cli.key_specification(key_specification) # should not raise + + assert (algorithm, parameters) == expected_return_value diff --git a/tests/test_commands.py b/tests/test_commands.py index 737367ca03d4d3f392697a2d389e3d5bac18ea58..2adcb34e0a42ffaadd56fe08675c77ddbf36fae4 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -38,7 +38,7 @@ def test_init_sets_up_directory_structure(tmpdir): ca_dir = tmpdir.join('.gimmecert', 'ca') server_dir = tmpdir.join('.gimmecert', 'server') - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 
1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, ("rsa", 2048)) assert os.path.exists(base_dir.strpath) assert os.path.exists(ca_dir.strpath) @@ -46,7 +46,7 @@ def test_init_sets_up_directory_structure(tmpdir): def test_init_generates_single_ca_artifact_for_depth_1(tmpdir): - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, ("rsa", 2048)) assert os.path.exists(tmpdir.join('.gimmecert', 'ca', 'level1.key.pem').strpath) assert os.path.exists(tmpdir.join('.gimmecert', 'ca', 'level1.cert.pem').strpath) @@ -54,7 +54,7 @@ def test_init_generates_single_ca_artifact_for_depth_1(tmpdir): def test_init_generates_three_ca_artifacts_for_depth_3(tmpdir): - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 3, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 3, ("rsa", 2048)) assert os.path.exists(tmpdir.join('.gimmecert', 'ca', 'level1.key.pem').strpath) assert os.path.exists(tmpdir.join('.gimmecert', 'ca', 'level1.cert.pem').strpath) @@ -66,7 +66,7 @@ def test_init_generates_three_ca_artifacts_for_depth_3(tmpdir): def test_init_outputs_full_chain_for_depth_1(tmpdir): - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, ("rsa", 2048)) level1_certificate = tmpdir.join('.gimmecert', 'ca', 'level1.cert.pem').read() full_chain = tmpdir.join('.gimmecert', 'ca', 'chain-full.cert.pem').read() @@ -75,7 +75,7 @@ def test_init_outputs_full_chain_for_depth_1(tmpdir): def test_init_outputs_full_chain_for_depth_3(tmpdir): - 
gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 3, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 3, ("rsa", 2048)) level1_certificate = tmpdir.join('.gimmecert', 'ca', 'level1.cert.pem').read() level2_certificate = tmpdir.join('.gimmecert', 'ca', 'level2.cert.pem').read() @@ -88,26 +88,26 @@ def test_init_outputs_full_chain_for_depth_3(tmpdir): def test_init_returns_success_if_directory_has_not_been_previously_initialised(tmpdir): - status_code = gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + status_code = gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, ("rsa", 2048)) assert status_code == gimmecert.commands.ExitCode.SUCCESS def test_init_returns_error_code_if_directory_has_been_previously_initialised(tmpdir): - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, gimmecert.crypto.KeyGenerator("rsa:2048")) - status_code = gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, ("rsa", 2048)) + status_code = gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, ("rsa", 2048)) assert status_code == gimmecert.commands.ExitCode.ERROR_ALREADY_INITIALISED def test_init_does_not_overwrite_artifcats_if_already_initialised(tmpdir): - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, ("rsa", 2048)) level1_private_key_before = tmpdir.join('.gimmecert', 'ca', 'level1.key.pem').read() level1_certificate_before = 
tmpdir.join('.gimmecert', 'ca', 'level1.cert.pem').read() full_chain_before = tmpdir.join('.gimmecert', 'ca', 'chain-full.cert.pem').read() - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, ("rsa", 2048)) level1_private_key_after = tmpdir.join('.gimmecert', 'ca', 'level1.key.pem').read() level1_certificate_after = tmpdir.join('.gimmecert', 'ca', 'level1.cert.pem').read() @@ -250,7 +250,7 @@ def test_init_command_stdout_and_stderr_for_single_ca(tmpdir): stdout_stream = io.StringIO() stderr_stream = io.StringIO() - gimmecert.commands.init(stdout_stream, stderr_stream, tmpdir.strpath, "myproject", 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(stdout_stream, stderr_stream, tmpdir.strpath, "myproject", 1, ("rsa", 2048)) stdout = stdout_stream.getvalue() stderr = stderr_stream.getvalue() @@ -266,7 +266,7 @@ def test_init_command_stdout_and_stderr_for_multiple_cas_with_rsa_1024(tmpdir): stdout_stream = io.StringIO() stderr_stream = io.StringIO() - gimmecert.commands.init(stdout_stream, stderr_stream, tmpdir.strpath, "myproject", 3, gimmecert.crypto.KeyGenerator("rsa:1024")) + gimmecert.commands.init(stdout_stream, stderr_stream, tmpdir.strpath, "myproject", 3, ("rsa", 1024)) stdout = stdout_stream.getvalue() stderr = stderr_stream.getvalue() @@ -286,9 +286,9 @@ def test_init_command_stdout_and_stderr_if_hierarchy_already_initialised(tmpdir) stdout_stream = io.StringIO() stderr_stream = io.StringIO() - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, "myproject", 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, "myproject", 1, ("rsa", 2048)) - gimmecert.commands.init(stdout_stream, stderr_stream, tmpdir.strpath, "myproject", 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + 
gimmecert.commands.init(stdout_stream, stderr_stream, tmpdir.strpath, "myproject", 1, ("rsa", 2048)) stdout = stdout_stream.getvalue() stderr = stderr_stream.getvalue() @@ -628,7 +628,7 @@ def test_status_reports_ca_hierarchy_information(tmpdir): stderr_stream = io.StringIO() with freeze_time('2018-01-01 00:15:00'): - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 3, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 3, ("rsa", 2048)) with freeze_time('2018-06-01 00:15:00'): status_code = gimmecert.commands.status(stdout_stream, stderr_stream, tmpdir.strpath) @@ -677,7 +677,7 @@ def test_status_reports_server_certificate_information(tmpdir): gimmecert.storage.write_csr(myserver3_csr, myserver3_csr_file.strpath) with freeze_time('2018-01-01 00:15:00'): - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 3, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 3, ("rsa", 2048)) with freeze_time('2018-02-01 00:15:00'): gimmecert.commands.server(io.StringIO(), io.StringIO(), tmpdir.strpath, 'myserver1', None, None) @@ -745,7 +745,7 @@ def test_status_reports_client_certificate_information(tmpdir): gimmecert.storage.write_csr(myclient3_csr, myclient3_csr_file.strpath) with freeze_time('2018-01-01 00:15:00'): - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 3, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 3, ("rsa", 2048)) with freeze_time('2018-02-01 00:15:00'): gimmecert.commands.client(io.StringIO(), io.StringIO(), tmpdir.strpath, 'myclient1', None) @@ -803,7 +803,7 @@ def test_status_reports_no_server_certificates_were_issued(tmpdir): # Just create some sample data, but no server certificates. 
with freeze_time('2018-01-01 00:15:00'): - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, ("rsa", 2048)) gimmecert.commands.client(io.StringIO(), io.StringIO(), tmpdir.strpath, 'myclient1', None) gimmecert.commands.client(io.StringIO(), io.StringIO(), tmpdir.strpath, 'myclient2', None) @@ -824,7 +824,7 @@ def test_status_reports_no_client_certificates_were_issued(tmpdir): # Just create some sample data, but no client certificates. with freeze_time('2018-01-01 00:15:00'): - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, tmpdir.basename, 1, ("rsa", 2048)) gimmecert.commands.server(io.StringIO(), io.StringIO(), tmpdir.strpath, 'myserver1', None, None) gimmecert.commands.server(io.StringIO(), io.StringIO(), tmpdir.strpath, 'myserver2', None, None) @@ -863,7 +863,7 @@ def test_certificate_marked_as_not_valid_or_expired_as_appropriate(tmpdir, subje # Perform action on our fixed issuance date. 
with freeze_time(issuance_date): - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, "My Project", 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, "My Project", 1, ("rsa", 2048)) gimmecert.commands.server(io.StringIO(), io.StringIO(), tmpdir.strpath, 'myserver', None, None) gimmecert.commands.client(io.StringIO(), io.StringIO(), tmpdir.strpath, 'myclient', None) diff --git a/tests/test_crypto.py b/tests/test_crypto.py index f5062aaa319ac66d9020c0453c4b3dfb755cf98a..28eeeaec21772e09c26a8d5c18ea9e1aed43bd72 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -108,7 +108,7 @@ def test_generate_ca_hierarchy_returns_list_with_3_elements_for_depth_3(): base_name = 'My Project' depth = 3 - hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa:2048")) + hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa", 2048)) assert isinstance(hierarchy, list) assert len(hierarchy) == depth @@ -118,7 +118,7 @@ def test_generate_ca_hierarchy_returns_list_with_1_element_for_depth_1(): base_name = 'My Project' depth = 1 - hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa:2048")) + hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa", 2048)) assert isinstance(hierarchy, list) assert len(hierarchy) == depth @@ -128,7 +128,7 @@ def test_generate_ca_hierarchy_returns_list_of_private_key_certificate_pairs(): base_name = 'My Project' depth = 3 - hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa:2048")) + hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa", 2048)) for private_key, certificate in hierarchy: assert isinstance(private_key, 
cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey) @@ -138,7 +138,7 @@ def test_generate_ca_hierarchy_returns_list_of_private_key_certificate_pairs(): def test_generate_ca_hierarchy_subject_dns_have_correct_value(): base_name = 'My Project' depth = 3 - key_generator = gimmecert.crypto.KeyGenerator("rsa:2048") + key_generator = gimmecert.crypto.KeyGenerator("rsa", 2048) level1, level2, level3 = [certificate for _, certificate in gimmecert.crypto.generate_ca_hierarchy(base_name, depth, key_generator)] @@ -151,7 +151,7 @@ def test_generate_ca_hierarchy_issuer_dns_have_correct_value(): base_name = 'My Project' depth = 3 - hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa:2048")) + hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa", 2048)) level1_key, level1_certificate = hierarchy[0] level2_key, level2_certificate = hierarchy[1] @@ -166,7 +166,7 @@ def test_generate_ca_hierarchy_private_keys_match_with_public_keys_in_certificat base_name = 'My Project' depth = 3 - hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa:2048")) + hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa", 2048)) level1_private_key, level1_certificate = hierarchy[0] level2_private_key, level2_certificate = hierarchy[1] @@ -181,7 +181,7 @@ def test_generate_ca_hierarchy_cas_have_differing_keys(): base_name = 'My Project' depth = 3 - hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa:2048")) + hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa", 2048)) level1_private_key, _ = hierarchy[0] level2_private_key, _ = hierarchy[1] @@ -200,7 +200,7 @@ def test_generate_ca_hierarchy_certificates_have_same_validity(): base_name = 'My Project' depth = 3 - hierarchy = 
gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa:2048")) + hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa", 2048)) _, level1_certificate = hierarchy[0] _, level2_certificate = hierarchy[1] @@ -250,7 +250,7 @@ def test_generate_ca_hierarchy_produces_certificates_with_ca_basic_constraints() base_name = 'My test' depth = 3 - hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa:2048")) + hierarchy = gimmecert.crypto.generate_ca_hierarchy(base_name, depth, gimmecert.crypto.KeyGenerator("rsa", 2048)) for _, certificate in hierarchy: stored_extension = certificate.extensions.get_extension_for_class(cryptography.x509.BasicConstraints) @@ -263,7 +263,7 @@ def test_generate_ca_hierarchy_produces_certificates_with_ca_basic_constraints() def test_issue_server_certificate_returns_certificate(): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -274,7 +274,7 @@ def test_issue_server_certificate_returns_certificate(): def test_issue_server_certificate_sets_correct_extensions(): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -319,7 +319,7 @@ def test_issue_server_certificate_sets_correct_extensions(): def test_issue_server_certificate_has_correct_issuer_and_subject(): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 4, 
gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 4, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[3] private_key = gimmecert.crypto.generate_private_key() @@ -331,7 +331,7 @@ def test_issue_server_certificate_has_correct_issuer_and_subject(): def test_issue_server_certificate_has_correct_public_key(): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -343,7 +343,7 @@ def test_issue_server_certificate_has_correct_public_key(): @freeze_time('2018-01-01 00:15:00') def test_issue_server_certificate_not_before_is_15_minutes_in_past(): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -355,7 +355,7 @@ def test_issue_server_certificate_not_before_is_15_minutes_in_past(): def test_issue_server_certificate_not_before_does_not_exceed_ca_validity(): with freeze_time('2018-01-01 00:15:00'): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] @@ -369,7 +369,7 @@ def test_issue_server_certificate_not_before_does_not_exceed_ca_validity(): def test_issue_server_certificate_not_after_does_not_exceed_ca_validity(): with freeze_time('2018-01-01 
00:15:00'): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] @@ -382,7 +382,7 @@ def test_issue_server_certificate_not_after_does_not_exceed_ca_validity(): def test_issue_server_certificate_incorporates_additional_dns_subject_alternative_names(): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -403,7 +403,7 @@ def test_issue_server_certificate_incorporates_additional_dns_subject_alternativ def test_issue_client_certificate_returns_certificate(): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -414,7 +414,7 @@ def test_issue_client_certificate_returns_certificate(): def test_issue_client_certificate_has_correct_issuer_and_subject(): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 4, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 4, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[3] private_key = gimmecert.crypto.generate_private_key() @@ -426,7 +426,7 @@ def test_issue_client_certificate_has_correct_issuer_and_subject(): def test_issue_client_certificate_sets_correct_extensions(): - ca_hierarchy = 
gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -463,7 +463,7 @@ def test_issue_client_certificate_sets_correct_extensions(): def test_issue_client_certificate_has_correct_public_key(): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -475,7 +475,7 @@ def test_issue_client_certificate_has_correct_public_key(): @freeze_time('2018-01-01 00:15:00') def test_issue_client_certificate_not_before_is_15_minutes_in_past(): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -487,7 +487,7 @@ def test_issue_client_certificate_not_before_is_15_minutes_in_past(): def test_issue_client_certificate_not_before_does_not_exceed_ca_validity(): with freeze_time('2018-01-01 00:15:00'): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] @@ -501,7 +501,7 @@ def test_issue_client_certificate_not_before_does_not_exceed_ca_validity(): def 
test_issue_client_certificate_not_after_does_not_exceed_ca_validity(): with freeze_time('2018-01-01 00:15:00'): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] @@ -514,7 +514,7 @@ def test_issue_client_certificate_not_after_does_not_exceed_ca_validity(): def test_renew_certificate_returns_certificate(): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -526,7 +526,7 @@ def test_renew_certificate_returns_certificate(): def test_renew_certificate_has_correct_content(): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -546,7 +546,7 @@ def test_renew_certificate_not_before_is_15_minutes_in_past(): # Initial server certificate. with freeze_time('2018-01-01 00:15:00'): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -563,7 +563,7 @@ def test_renew_certificate_not_before_does_not_exceed_ca_validity(): # Initial server certificate. 
with freeze_time('2018-01-01 00:15:00'): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -580,7 +580,7 @@ def test_renew_certificate_not_after_does_not_exceed_ca_validity(): # Initial server certificate. with freeze_time('2018-01-01 00:15:00'): - ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + ca_hierarchy = gimmecert.crypto.generate_ca_hierarchy('My Project', 1, gimmecert.crypto.KeyGenerator("rsa", 2048)) issuer_private_key, issuer_certificate = ca_hierarchy[0] private_key = gimmecert.crypto.generate_private_key() @@ -618,67 +618,31 @@ def test_generate_csr_returns_csr_with_passed_in_name(): assert csr.subject == expected_subject_dn -@pytest.mark.parametrize("key_specification", [ - "", - "rsa", - "rsa:not_a_number", - "unsupported:algorithm", +@pytest.mark.parametrize("algorithm, parameters, string_representation", [ + ("rsa", 1024, "1024-bit RSA"), + ("rsa", 2048, "2048-bit RSA"), + ("rsa", 4096, "4096-bit RSA"), ]) -def test_KeyGenerator_raises_exception_for_invalid_specification(key_specification): +def test_KeyGenerator_string_representation(algorithm, parameters, string_representation): - with pytest.raises(ValueError) as e_info: - gimmecert.crypto.KeyGenerator(key_specification) - - assert str(e_info.value) == "Invalid key specification: '%s'" % key_specification - - -@pytest.mark.parametrize("key_specification", [ - "rsa:1024", - "rsa:2048", - "rsa:4096", -]) -def test_KeyGenerator_accepts_valid_specifications(key_specification): - - gimmecert.crypto.KeyGenerator(key_specification) # should not raise - - -def test_KeyGenerator_stores_specification(): - - key_generator = 
gimmecert.crypto.KeyGenerator("rsa:2048") - - assert key_generator._algorithm == "rsa" - assert key_generator._parameters == 2048 - - -@pytest.mark.parametrize("key_specification, string_representation", [ - ("rsa:1024", "1024-bit RSA"), - ("rsa:2048", "2048-bit RSA"), - ("rsa:4096", "4096-bit RSA"), -]) -def test_KeyGenerator_string_representation(key_specification, string_representation): - - key_generator = gimmecert.crypto.KeyGenerator(key_specification) + key_generator = gimmecert.crypto.KeyGenerator(algorithm, parameters) assert str(key_generator) == string_representation -def test_KeyGenerator_instance_returns_rsa_private_key(): - - key_generator_1 = gimmecert.crypto.KeyGenerator("rsa:1024") - key_generator_2 = gimmecert.crypto.KeyGenerator("rsa:2048") +@pytest.mark.parametrize("key_size", [1024, 2048, 4096]) +def test_KeyGenerator_instance_returns_rsa_private_key_of_correct_size(key_size): - private_key_1 = key_generator_1() - private_key_2 = key_generator_2() + key_generator = gimmecert.crypto.KeyGenerator("rsa", key_size) - assert isinstance(private_key_1, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey) - assert isinstance(private_key_2, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey) + private_key = key_generator() - assert private_key_1.key_size == 1024 - assert private_key_2.key_size == 2048 + assert isinstance(private_key, cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey) + assert private_key.key_size == key_size @pytest.mark.parametrize("key_generator, expected_bit_size", [ - (gimmecert.crypto.KeyGenerator("rsa:1024"), 1024), - (gimmecert.crypto.KeyGenerator("rsa:2048"), 2048), + (gimmecert.crypto.KeyGenerator("rsa", 1024), 1024), + (gimmecert.crypto.KeyGenerator("rsa", 2048), 2048), ]) def test_generate_ca_hierarchy_uses_correct_rsa_bit_size(key_generator, expected_bit_size): base_name = "My Test" diff --git a/tests/test_storage.py b/tests/test_storage.py index 
cec4aeb16bb9e8df2d46e26e18dd159f6be9a13d..beea3a5942a8f4c69103d058f41c06433355114d 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -80,7 +80,7 @@ def test_write_certificate(tmpdir): def test_write_certificate_chain(tmpdir): output_file = tmpdir.join('chain.cert.pem') - certificate_chain = [certificate for _, certificate in gimmecert.crypto.generate_ca_hierarchy('My Project', 3, gimmecert.crypto.KeyGenerator("rsa:2048"))] + certificate_chain = [certificate for _, certificate in gimmecert.crypto.generate_ca_hierarchy('My Project', 3, gimmecert.crypto.KeyGenerator("rsa", 2048))] level1_pem, level2_pem, level3_pem = [gimmecert.utils.certificate_to_pem(certificate) for certificate in certificate_chain] gimmecert.storage.write_certificate_chain(certificate_chain, output_file.strpath) @@ -106,7 +106,7 @@ def test_is_initialised_returns_false_if_directory_is_not_initialised(tmpdir): def test_read_ca_hierarchy_returns_list_of_ca_private_key_and_certificate_pairs_for_single_ca(tmpdir): tmpdir.chdir() - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, 'My Project', 1, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, 'My Project', 1, ("rsa", 2048)) ca_hierarchy = gimmecert.storage.read_ca_hierarchy(tmpdir.join('.gimmecert', 'ca').strpath) @@ -146,7 +146,7 @@ def test_read_certificate_returns_certificate(tmpdir): def test_read_ca_hierarchy_returns_list_of_ca_private_key_and_certificate_pairs_in_hierarchy_order_for_multiple_cas(tmpdir): tmpdir.chdir() - gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, 'My Project', 4, gimmecert.crypto.KeyGenerator("rsa:2048")) + gimmecert.commands.init(io.StringIO(), io.StringIO(), tmpdir.strpath, 'My Project', 4, ("rsa", 2048)) ca_hierarchy = gimmecert.storage.read_ca_hierarchy(tmpdir.join('.gimmecert', 'ca').strpath) diff --git a/tests/test_utils.py b/tests/test_utils.py index 
76a2b03b2e3e06d11af6465a88eb969be38c1cdb..13f3e0ab265e8f668fa9ab848965a3949e3c3968 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -75,7 +75,7 @@ def test_date_range_to_str(): def test_get_dns_names_returns_empty_list_if_no_dns_names(): - issuer_private_key, issuer_certificate = gimmecert.crypto.generate_ca_hierarchy('My Test', 1, gimmecert.crypto.KeyGenerator("rsa:2048"))[0] + issuer_private_key, issuer_certificate = gimmecert.crypto.generate_ca_hierarchy('My Test', 1, gimmecert.crypto.KeyGenerator("rsa", 2048))[0] private_key = gimmecert.crypto.generate_private_key() certificate = gimmecert.crypto.issue_client_certificate( @@ -91,7 +91,7 @@ def test_get_dns_names_returns_empty_list_if_no_dns_names(): def test_get_dns_names_returns_list_of_dns_names(): - issuer_private_key, issuer_certificate = gimmecert.crypto.generate_ca_hierarchy('My Test', 1, gimmecert.crypto.KeyGenerator("rsa:2048"))[0] + issuer_private_key, issuer_certificate = gimmecert.crypto.generate_ca_hierarchy('My Test', 1, gimmecert.crypto.KeyGenerator("rsa", 2048))[0] private_key = gimmecert.crypto.generate_private_key() certificate = gimmecert.crypto.issue_server_certificate(