#!/usr/bin/env python3
# Tests for the samba.key_credential_link module
#
# Copyright (C) Douglas Bagnall
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import os
import re
import sys
import time
from itertools import permutations

sys.path.insert(0, "bin/python")
os.environ["PYTHONUNBUFFERED"] = "1"

from samba import param
from samba.auth import system_session
from samba.credentials import Credentials
from samba.samdb import SamDB, BinaryDn
from samba import key_credential_link as kcl
from samba.tests import TestCase, env_get_var_value


class KeyCredentialLinkWrapperTests(TestCase):
    """Testing the samba.key_credential_link module.

    The tests build key credential links from a fixed RSA public key
    (``_der_encoded_key``) and from a pre-computed binary DN prefix
    (``_kcl_prefix``), then check how the module describes, exports,
    rejects and compares them.  A SamDB connection to the server named
    by the DC_SERVER environment variable is opened once per class.
    """
    maxDiff = 9999

    # A bytestring that contains a DER encoded 2048 bit RSA public key
    # borrowed from the key_credential_link.test_unpack_der_key_material
    # test.
    _der_encoded_key = bytes.fromhex(
        "30 82 01 22"  # Sequence 290 bytes, 2 elements
        "30 0d"  # Sequence 13 bytes, 2 elements
        # OID 9 bytes, 1.2.840.113549.1.1.1
        "06 09 2a 86 48 86 f7 0d 01 01 01"
        "05 00"  # Null
        "03 82 01 0f 00"  # Bit string, 2160 bits, 0 unused bits
        "30 82 01 0a"  # Sequence 266 bytes, 2 elements
        "02 82 01 01"  # Integer 2048 bit, 257 bytes
        # MODULUS is 257 bytes as its most significant byte
        # is 0xbd 0b10111101 and has bit 8 set,
        # which DER Integer encoding uses as the sign bit,
        # so need the leading 00 byte to prevent the value
        # being interpreted as a negative integer
        "00 bd ae 45 8b 17 cd 3e 62 71 66 67 7f a2 46 c4"
        "47 78 79 f2 8c d4 2e 0c a0 90 1c f6 33 e1 94 89"
        "b9 44 15 e3 29 e7 b6 91 ca ab 7e c6 25 60 e3 7a"
        "c4 09 97 8a 4e 79 cb a6 1f f8 29 3f 8a 0d 45 58"
        "9b 0e bf a5 fa 1c a2 5e 31 a1 e7 ba 7e 17 62 03"
        "79 c0 07 48 11 8b fa 58 17 56 1a a1 62 d2 02 02"
        "2a 64 8d 8c 53 fa 28 7c 89 18 34 70 64 a7 08 10"
        "c9 3b 1b 2c 23 88 9c 35 50 78 d1 89 33 ce 82 b2"
        "84 f4 99 d8 3e 67 11 a1 5c 1a 64 b8 6a 3e e6 95"
        "2e 47 33 51 7e b7 62 b4 08 2c c4 87 52 00 9e 28"
        "f2 16 9f 1b c1 3a 93 6d a3 38 9b 34 39 88 85 ea"
        "38 ad c2 2b c3 7c 15 cb 8f 15 37 ed 88 62 5c 34"
        "75 6f b0 eb 5c 42 6a cd 03 cc 49 bc b4 78 14 e1"
        "5e 98 83 6f e7 19 a8 43 cb ca 07 b2 4e a4 36 60"
        "95 ac 6f e2 1d 3a 33 f6 0e 94 ae fb d2 ac 9f c2"
        "9f 5b 77 8f 46 3c ee 13 27 19 8e 68 71 27 3f 50"
        "59"
        "02 03 01 00 01"  # Integer 3 bytes EXPONENT
    )

    # The "B:<length>:<hex>:" part of a binary DN string; _good_kcl()
    # appends the domain DN to it to form a complete link value.
    _kcl_prefix = ("B:772:000200002000012548213C16B0B6CEEB2A5C67B30744"
                   "01BCBA6394A7C310713AF7314FEBCDF082200002B12C51275D"
                   "E353EBD3117BA405F3A00131740B8938C572127DD6F045D63D"
                   "43F326010330820122300D06092A864886F70D010101050003"
                   "82010F003082010A0282010100BDAE458B17CD3E627166677F"
                   "A246C4477879F28CD42E0CA0901CF633E19489B94415E329E7"
                   "B691CAAB7EC62560E37AC409978A4E79CBA61FF8293F8A0D45"
                   "589B0EBFA5FA1CA25E31A1E7BA7E17620379C00748118BFA58"
                   "17561AA162D202022A648D8C53FA287C8918347064A70810C9"
                   "3B1B2C23889C355078D18933CE82B284F499D83E6711A15C1A"
                   "64B86A3EE6952E4733517EB762B4082CC48752009E28F2169F"
                   "1BC13A936DA3389B34398885EA38ADC22BC37C15CB8F1537ED"
                   "88625C34756FB0EB5C426ACD03CC49BCB47814E15E98836FE7"
                   "19A843CBCA07B24EA4366095AC6FE21D3A33F60E94AEFBD2AC"
                   "9FC29F5B778F463CEE1327198E6871273F5059020301000101"
                   "000401080009805811AFEF07DC01:")

    # The fingerprint the tests expect for _der_encoded_key.
    _kcl_fingerprint = ("25:48:21:3C:16:B0:B6:CE:EB:2A:5C:67:B3:07:44:01:"
                        "BC:BA:63:94:A7:C3:10:71:3A:F7:31:4F:EB:CD:F0:82")

    @classmethod
    def setUpClass(cls):
        """Open one SamDB connection and cache the DNs the tests need.

        The server name comes from DC_SERVER, the smb.conf path from
        SMB_CONF_PATH, and the credentials from DC_USERNAME and
        DC_PASSWORD.
        """
        super().setUpClass()
        server = os.environ['DC_SERVER']
        host = f'ldap://{server}'

        lp = param.LoadParm()
        lp.load(os.environ['SMB_CONF_PATH'])

        creds = Credentials()
        creds.guess(lp)
        creds.set_username(env_get_var_value('DC_USERNAME'))
        creds.set_password(env_get_var_value('DC_PASSWORD'))

        cls.ldb = SamDB(host, credentials=creds,
                        session_info=system_session(lp), lp=lp)
        cls.base_dn = cls.ldb.domain_dn()
        cls.schema_dn = cls.ldb.get_schema_basedn().get_linearized()
        cls.domain_sid = cls.ldb.get_domain_sid()

    def _parse_description_fields(self, lines):
        """Turn indented "name: value" description lines into a dict.

        Every line must match the pattern below; a line that does not
        fails the test, and the message quotes the offending line.

        NOTE(review): the pattern expects exactly one leading space
        before the field name -- confirm against description() output.
        """
        fields = {}
        for line in lines:
            m = re.match(r'^ ([^:]+):\s+(.+)$', line)
            self.assertIsNotNone(
                m, f"unparseable description line: {line!r}")
            name, value = m.groups()
            fields[name] = value
        return fields

    def test_key_credential_link_description(self):
        """description() grows with verbosity and reports known values."""
        # NOTE(review): the string comparisons against before/after
        # below assume description() prints the creation time in this
        # same local-time format -- confirm.
        before = time.strftime('%Y-%m-%d %H:%M:%S')
        link = kcl.create_key_credential_link(self.ldb,
                                              self.base_dn,
                                              self._der_encoded_key)
        after = time.strftime('%Y-%m-%d %H:%M:%S')

        level_0 = link.description(0)
        level_1 = link.description(1)
        level_2 = link.description(2)
        level_3 = link.description(3)
        level_4 = link.description(4)

        # we should get more text with increasing verbosity
        self.assertLessEqual(len(level_0), len(level_1))
        self.assertLessEqual(len(level_1), len(level_2))
        self.assertLessEqual(len(level_2), len(level_3))
        # but verbosity maxes out at 3, so 4 and 3 are identical text
        self.assertEqual(level_4, level_3)

        lines = level_4.split('\n')
        self.assertEqual(lines[0], f'Link target: {self.base_dn}')
        self.assertRegex(lines[1],
                         r'^Binary Dn: B:772:0002000020000125482[\dA-F]+:'
                         f'{self.base_dn}$')

        key_entries = lines.index("Key entries:")
        key_properties = lines.index("RSA public key properties:")
        self.assertGreater(key_entries, 2)
        self.assertGreater(key_properties, key_entries + 1)

        # The lines between the two headings are the key entries; the
        # lines after the second heading are the RSA key properties.
        entries = self._parse_description_fields(
            lines[key_entries + 1:key_properties])

        self.assertEqual(entries["Device GUID (DeviceId)"], "not found")
        self.assertEqual(
            entries["last logon (KeyApproximateLastLogonTimeStamp)"],
            "not found")
        self.assertLessEqual(before,
                             entries["creation time (KeyCreationTime)"])
        self.assertLessEqual(entries["creation time (KeyCreationTime)"],
                             after)

        properties = self._parse_description_fields(
            lines[key_properties + 1:])

        self.assertEqual(properties["key size"], "2048")
        # fingerprint should be the known constant, and the same as the entry.
        self.assertEqual(properties["fingerprint"], self._kcl_fingerprint)
        self.assertEqual(properties["fingerprint"],
                         entries["key material fingerprint (KeyID)"])

    def test_key_credential_link_fingerprint(self):
        """fingerprint() of the known key equals the known constant."""
        k = kcl.create_key_credential_link(self.ldb,
                                           self.base_dn,
                                           self._der_encoded_key)
        self.assertEqual(k.fingerprint(), self._kcl_fingerprint)

    def test_key_credential_link_as_pem(self):
        """A link rebuilt from its own PEM output gives the same PEM."""
        k1 = kcl.create_key_credential_link(self.ldb,
                                            self.base_dn,
                                            self._der_encoded_key)
        pem1 = k1.as_pem()
        self.assertTrue(pem1.startswith('-----BEGIN PUBLIC KEY-----'))

        k2 = kcl.create_key_credential_link(self.ldb,
                                            self.base_dn,
                                            pem1.encode())
        pem2 = k2.as_pem()
        self.assertEqual(pem1, pem2)

        # we can't quite assert that the binary part of k1 and k2 is
        # the same, because the creation date could have changed, but
        # we can exclude just the bits that might be affected by that.
        dnstr1 = str(k1)
        dnstr2 = str(k2)
        self.assertEqual(dnstr1[:90], dnstr2[:90])
        # 90 to 154 is the hash of various fields, including time
        self.assertEqual(dnstr1[154:762], dnstr2[154:762])
        # 762-778 is creation time in nttime
        self.assertEqual(dnstr1[778:], dnstr2[778:])

    def test_create_key_credential_link_damaged(self):
        """self._der_encoded_key is a valid key, but if we make
        slightly altered versions we should see failures from
        kcl.create_key_credential_link.

        Not all changes will cause trouble (e.g. we could just be
        changing the modulus to match a different private key) but we
        try some that should.
        """
        orig = self._der_encoded_key
        cases = [
            (0, 1, b'a', "bad start"),
            (-2, -1, b'', "eat a byte"),
            (7, 8, b'x', "change OID interpretation"),
            (31, 32, b'\x88', "make the modulus negative"),
            (0, len(orig), b' ' * len(orig), "replace everything"),
            # what we don't catch is invalid values qua key,
            # like a zero exponent
            #     (-3, 9999, b'\x00\x00\x00'),
            # or adding extra bytes at the end
            #     (-1, -1, b'xxxx')
        ]
        for start, end, replacement, why in cases:
            der = orig[:start] + replacement + orig[end:]
            # msg is only shown when no ValueError is raised; it names
            # the case that was wrongly accepted.
            with self.assertRaises(
                    ValueError,
                    msg=f"damaged key accepted: {why} [{start}:{end}]"):
                kcl.create_key_credential_link(self.ldb,
                                               self.base_dn,
                                               der)

    def _good_kcl(self):
        """Return the known-good link string: _kcl_prefix + base DN."""
        return f"{self._kcl_prefix}{self.base_dn}"

    def _bad_kcl(self, start, replacement, end=None):
        """Return the good link string with part of its binary replaced.

        We don't use create_key_credential_link(), because it tries to
        check the key is properly encoded.

        The binary part of the good link becomes
        binary[:start] + replacement + binary[end:].  When end is None
        it defaults to start + len(replacement), i.e. the replacement
        overwrites the same number of bytes in place.
        """
        good = self._good_kcl()
        if end is None:
            end = start + len(replacement)

        bdn = BinaryDn(self.ldb, good)
        bdn.binary = bdn.binary[:start] + replacement + bdn.binary[end:]
        return str(bdn)

    def test_bad_key_credential_links(self):
        """Structurally broken link values are rejected with ValueError."""
        cases = [
            (0, b'1234', None, "bad version 0x34333231"),
            (100, b'', 1000, "truncated"),
            (10, b'', 1000, "truncated"),
            (1000, b'\x01' * 20, None, "extra bytes"),
        ]
        for start, replacement, end, why in cases:
            bad_dn = self._bad_kcl(start, replacement, end)
            with self.assertRaises(
                    ValueError,
                    msg=f"bad link accepted: {why} (start={start})"):
                kcl.KeyCredentialLinkDn(self.ldb, bad_dn)

    def test_bad_key_credential_links_one_byte_damage(self):
        """If we poke the wrong byte in certain places, the ndr pull
        should fail."""
        for i in [3, 4, 5, 6, 39, 40, 41, 75, 76,
                  371, 372, 373, 375, 376, 377]:
            bad = self._bad_kcl(i, b'*')
            with self.assertRaises(
                    ValueError,
                    msg=f"link with damaged byte {i} was accepted"):
                kcl.KeyCredentialLinkDn(self.ldb, bad)

    def test_bad_key_credential_link_keys(self):
        """Parsing as a KeyCredentialLink is OK, but the resultant RSA
        key is broken."""
        for i in [77, 78, 79, 80, 81, 82, 83, 84, 86, 87, 88, 89, 90,
                  91, 93, 94, 95, 96, 98, 99, 101, 102, 105, 106, 107,
                  108, 366, 367]:
            bad = self._bad_kcl(i, b'*')
            # Construction must succeed; only the key export may fail.
            k = kcl.KeyCredentialLinkDn(self.ldb, bad)
            with self.assertRaises(
                    ValueError,
                    msg=f"as_pem() accepted key with damaged byte {i}"):
                k.as_pem()

    def test_good_key_credential_link_case_sensitivity(self):
        """Do kcl DNs compare and normalise as expected, in the same
        way as binary DNs?"""
        mixed = self._good_kcl()
        lc = mixed[0] + mixed[1:].lower()  # we need to keep initial 'B:'
        uc = mixed.upper()

        kcldn_mixed = kcl.KeyCredentialLinkDn(self.ldb, mixed)
        kcldn_lower = kcl.KeyCredentialLinkDn(self.ldb, lc)
        kcldn_upper = kcl.KeyCredentialLinkDn(self.ldb, uc)

        bdn_mixed = BinaryDn(self.ldb, mixed)
        bdn_lower = BinaryDn(self.ldb, lc)
        bdn_upper = BinaryDn(self.ldb, uc)

        dns = [kcldn_mixed, kcldn_lower, kcldn_upper,
               bdn_mixed, bdn_lower, bdn_upper]

        # Every ordered pair is compared, so equality is exercised in
        # both directions and across the two classes.
        for a, b in permutations(dns, 2):
            self.assertEqual(a, b)
            self.assertEqual(a.binary, b.binary)
            self.assertEqual(a.prefix, b.prefix)
            self.assertEqual(a.dn, b.dn)

        strings = [str(x).upper() for x in dns]
        for s in strings:
            self.assertEqual(uc, s)

        # prefixes are normalised by str()
        prefixes = [str(x).rsplit(':', 1)[0] for x in dns]
        op = mixed.rsplit(':', 1)[0]
        for p in prefixes:
            self.assertEqual(p, op)


if __name__ == "__main__":
    import unittest
    unittest.main()