xref: /curl/tests/http/testenv/certs.py (revision c177e194)
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#***************************************************************************
4#                                  _   _ ____  _
5#  Project                     ___| | | |  _ \| |
6#                             / __| | | | |_) | |
7#                            | (__| |_| |  _ <| |___
8#                             \___|\___/|_| \_\_____|
9#
10# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
11#
12# This software is licensed as described in the file COPYING, which
13# you should have received as part of this distribution. The terms
14# are also available at https://curl.se/docs/copyright.html.
15#
16# You may opt to use, copy, modify, merge, publish, distribute and/or sell
17# copies of the Software, and permit persons to whom the Software is
18# furnished to do so, under the terms of the COPYING file.
19#
20# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
21# KIND, either express or implied.
22#
23# SPDX-License-Identifier: curl
24#
25###########################################################################
26#
27import ipaddress
28import os
29import re
30from datetime import timedelta, datetime
31from typing import List, Any, Optional
32
33from cryptography import x509
34from cryptography.hazmat.backends import default_backend
35from cryptography.hazmat.primitives import hashes
36from cryptography.hazmat.primitives.asymmetric import ec, rsa
37from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
38from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
39from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
40from cryptography.x509 import ExtendedKeyUsageOID, NameOID
41
42
# Elliptic curves accepted by name in _private_key(), keyed by the upper-cased
# curve name (e.g. "SECP256R1"). The values are the curve classes from
# cryptography's ``ec`` module, not instances.
EC_SUPPORTED = {
    curve_cls.name.upper(): curve_cls
    for curve_cls in (
        ec.SECP192R1,
        ec.SECP224R1,
        ec.SECP256R1,
        ec.SECP384R1,
    )
}
50
51
def _private_key(key_type):
    """Generate a fresh private key of the kind described by ``key_type``.

    :param key_type: what kind of key to create:
        - an ``int``, or a string of digits optionally prefixed with ``rsa``
          (case-insensitive, e.g. ``"rsa2048"``): an RSA key of that size
          with public exponent 65537;
        - a curve name listed in ``EC_SUPPORTED`` (case-insensitive, e.g.
          ``"secp256r1"``): an EC key on that curve;
        - an ``ec.EllipticCurve`` instance or subclass: an EC key on that curve.
    :returns: the newly generated RSA or EC private key.
    :raises ValueError: if ``key_type`` matches none of the forms above.
    """
    if isinstance(key_type, str):
        key_type = key_type.upper()
        m = re.match(r'^(RSA)?(\d+)$', key_type)
        if m:
            key_type = int(m.group(2))

    if isinstance(key_type, int):
        return rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_type,
            backend=default_backend()
        )

    curve = EC_SUPPORTED.get(key_type, key_type) if isinstance(key_type, str) else key_type
    # EC_SUPPORTED maps names to curve *classes*, but ec.generate_private_key()
    # expects a curve *instance* (passing a class is deprecated in newer
    # cryptography releases), so instantiate classes here.
    if isinstance(curve, type) and issubclass(curve, ec.EllipticCurve):
        curve = curve()
    if not isinstance(curve, ec.EllipticCurve):
        raise ValueError(f"unsupported key type: {key_type!r}")
    return ec.generate_private_key(
        curve=curve,
        backend=default_backend()
    )
71
72
class CertificateSpec:
    """Declarative description of a certificate to be issued.

    The kind of certificate is derived from the fields that are set (see
    ``type``): a server certificate when ``domains`` is non-empty, a client
    certificate when ``client`` is true, otherwise a CA certificate when a
    name is known.

    :param name: explicit name; defaults to the first entry of ``domains``.
    :param domains: DNS names and/or IP address strings for a server cert.
    :param email: e-mail address to put into a client certificate.
    :param key_type: key type such as ``"rsa2048"`` or ``"secp256r1"``; when
        None, the issuing credentials' own key type is used.
    :param single_file: store the private key inside the certificate file
        instead of a separate key file.
    :param valid_from: offset from creation time at which the cert becomes valid.
    :param valid_to: offset from creation time at which the cert expires.
    :param client: issue a client certificate (when no domains are given).
    :param check_valid: when reusing a stored certificate, discard it if the
        current time is outside its validity window.
    :param sub_specs: certificates to be issued by this one.
    """

    def __init__(self, name: Optional[str] = None,
                 domains: Optional[List[str]] = None,
                 email: Optional[str] = None,
                 key_type: Optional[str] = None,
                 single_file: bool = False,
                 valid_from: timedelta = timedelta(days=-1),
                 valid_to: timedelta = timedelta(days=89),
                 client: bool = False,
                 check_valid: bool = True,
                 sub_specs: Optional[List['CertificateSpec']] = None):
        self._name = name
        self.domains = domains
        self.client = client
        self.email = email
        self.key_type = key_type
        self.single_file = single_file
        self.valid_from = valid_from
        self.valid_to = valid_to
        self.sub_specs = sub_specs
        self.check_valid = check_valid

    def __repr__(self) -> str:
        # Used e.g. in error messages that format a spec; the default object
        # repr would only show a memory address.
        return (f"{self.__class__.__name__}(name={self.name!r}, type={self.type!r}, "
                f"domains={self.domains!r}, client={self.client!r}, "
                f"email={self.email!r}, key_type={self.key_type!r})")

    @property
    def name(self) -> Optional[str]:
        """The explicit name, else the first domain, else None."""
        if self._name:
            return self._name
        elif self.domains:
            return self.domains[0]
        return None

    @property
    def type(self) -> Optional[str]:
        """``"server"``, ``"client"``, ``"ca"``, or None if nothing identifies the cert."""
        if self.domains:
            return "server"
        if self.client:
            return "client"
        if self.name:
            return "ca"
        return None
113
114
class Credentials:
    """A certificate, its private key and (optionally) the credentials that issued it.

    Credentials may be attached to a ``CertStore`` with ``set_store``; name
    lookups and persistence of newly issued certificates then go through that
    store. File locations are recorded with ``set_files``.
    """

    def __init__(self,
                 name: str,
                 cert: Any,
                 pkey: Any,
                 issuer: Optional['Credentials'] = None):
        self._name = name
        self._cert = cert
        self._pkey = pkey
        self._issuer = issuer
        self._cert_file = None
        self._pkey_file = None
        # Previously only assigned in set_files(), so reading combined_file
        # before that raised AttributeError; default it like the other paths.
        self._combined_file = None
        self._store = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def subject(self) -> 'x509.Name':
        """Subject name of the certificate."""
        return self._cert.subject

    @property
    def key_type(self):
        """Key type label: ``rsa<bits>`` for RSA keys, the curve name for EC keys.

        :raises Exception: for any other kind of private key.
        """
        if isinstance(self._pkey, RSAPrivateKey):
            return f"rsa{self._pkey.key_size}"
        elif isinstance(self._pkey, EllipticCurvePrivateKey):
            return f"{self._pkey.curve.name}"
        else:
            raise Exception(f"unknown key type: {self._pkey}")

    @property
    def private_key(self) -> Any:
        return self._pkey

    @property
    def certificate(self) -> Any:
        return self._cert

    @property
    def cert_pem(self) -> bytes:
        """The certificate encoded as PEM."""
        return self._cert.public_bytes(Encoding.PEM)

    @property
    def pkey_pem(self) -> bytes:
        """The private key as unencrypted PEM.

        RSA keys use the traditional OpenSSL format, all other keys PKCS#8.
        """
        return self._pkey.private_bytes(
            Encoding.PEM,
            PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8,
            NoEncryption())

    @property
    def issuer(self) -> Optional['Credentials']:
        return self._issuer

    def set_store(self, store: 'CertStore'):
        self._store = store

    def set_files(self, cert_file: str, pkey_file: Optional[str] = None,
                  combined_file: Optional[str] = None):
        """Record where the certificate, key and combined PEM files live."""
        self._cert_file = cert_file
        self._pkey_file = pkey_file
        self._combined_file = combined_file

    @property
    def cert_file(self) -> Optional[str]:
        return self._cert_file

    @property
    def pkey_file(self) -> Optional[str]:
        return self._pkey_file

    @property
    def combined_file(self) -> Optional[str]:
        return self._combined_file

    def get_first(self, name) -> Optional['Credentials']:
        """First credentials registered under ``name`` in the attached store, or None."""
        creds = self.get_credentials_for_name(name)
        return creds[0] if creds else None

    def get_credentials_for_name(self, name) -> List['Credentials']:
        """All credentials registered under ``name`` in the attached store ([] without a store)."""
        return self._store.get_credentials_for_name(name) if self._store else []

    def issue_certs(self, specs: List['CertificateSpec'],
                    chain: Optional[List['Credentials']] = None) -> List['Credentials']:
        """Issue one certificate per spec; see ``issue_cert``."""
        return [self.issue_cert(spec=spec, chain=chain) for spec in specs]

    def issue_cert(self, spec: 'CertificateSpec',
                   chain: Optional[List['Credentials']] = None) -> 'Credentials':
        """Issue (or reuse) a certificate for ``spec``, signed by these credentials.

        With a store attached, a stored certificate of the same name and key
        type is reused (discarded if ``spec.check_valid`` and it is outside its
        validity window). Otherwise new credentials are created and, with a
        store, saved; CA specs also get a chain file including the root.
        Sub-specs are issued by the resulting credentials, into a
        sub-directory named after them when a store is attached.

        :param spec: what to issue; ``spec.key_type`` defaults to our key type.
        :param chain: issuers above these credentials, used to build the chain
            passed down to sub-specs.
        :returns: the issued credentials.
        """
        key_type = spec.key_type if spec.key_type else self.key_type
        creds = None
        if self._store:
            creds = self._store.load_credentials(
                name=spec.name, key_type=key_type, single_file=spec.single_file,
                issuer=self, check_valid=spec.check_valid)
        if creds is None:
            creds = TestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
                                              valid_from=spec.valid_from, valid_to=spec.valid_to)
            if self._store:
                self._store.save(creds, single_file=spec.single_file)
                if spec.type == "ca":
                    self._store.save_chain(creds, "ca", with_root=True)

        if spec.sub_specs:
            if self._store:
                sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
                creds.set_store(sub_store)
            subchain = chain.copy() if chain else []
            subchain.append(self)
            creds.issue_certs(spec.sub_specs, chain=subchain)
        return creds
226
227
class CertStore:
    """A directory of PEM files holding credentials, plus an in-memory index by name."""

    def __init__(self, fpath: str):
        self._store_dir = fpath
        # exist_ok avoids the check-then-create race of os.path.exists() + makedirs().
        os.makedirs(self._store_dir, exist_ok=True)
        self._creds_by_name = {}

    @property
    def path(self) -> str:
        return self._store_dir

    def save(self, creds: 'Credentials', name: Optional[str] = None,
             chain: Optional[List['Credentials']] = None,
             single_file: bool = False) -> None:
        """Write ``creds`` to disk and register them under ``name``.

        Writes ``<name>.<key_type>.cert.pem`` (certificate followed by the
        ``chain`` certificates) and ``<name>.<key_type>.pkey.pem`` (private
        key), or with ``single_file`` the key appended to the cert file and no
        key file. A combined ``<name>.pem`` (cert, chain, key) is always written.

        :param name: file name prefix and index key; defaults to ``creds.name``.
        """
        name = name if name is not None else creds.name
        key_type = creds.key_type
        cert_file = self.get_cert_file(name=name, key_type=key_type)
        pkey_file = None if single_file else self.get_pkey_file(name=name, key_type=key_type)
        comb_file = self.get_combined_file(name=name, key_type=key_type)
        cert_pem = creds.cert_pem
        pkey_pem = creds.pkey_pem
        chain_pem = b"".join(c.cert_pem for c in chain) if chain else b""
        with open(cert_file, "wb") as fd:
            fd.write(cert_pem + chain_pem)
            if pkey_file is None:
                fd.write(pkey_pem)
        if pkey_file is not None:
            with open(pkey_file, "wb") as fd:
                fd.write(pkey_pem)
        with open(comb_file, "wb") as fd:
            fd.write(cert_pem + chain_pem + pkey_pem)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)

    def save_chain(self, creds: 'Credentials', infix: str, with_root=False):
        """Write ``<name>-<infix>.pem`` with ``creds`` followed by its issuers.

        The topmost issuer (the root) is left out unless ``with_root`` is set
        or ``creds`` has no issuer at all.
        """
        name = creds.name
        chain = [creds]
        while creds.issuer is not None:
            creds = creds.issuer
            chain.append(creds)
        if not with_root and len(chain) > 1:
            chain = chain[:-1]
        chain_file = os.path.join(self._store_dir, f'{name}-{infix}.pem')
        with open(chain_file, "wb") as fd:
            for c in chain:
                fd.write(c.cert_pem)

    def _add_credentials(self, name: str, creds: 'Credentials'):
        self._creds_by_name.setdefault(name, []).append(creds)

    def get_credentials_for_name(self, name) -> List['Credentials']:
        """Credentials registered under ``name`` (in registration order), or []."""
        return self._creds_by_name.get(name, [])

    def get_cert_file(self, name: str, key_type=None) -> str:
        key_infix = ".{0}".format(key_type) if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem')

    def get_pkey_file(self, name: str, key_type=None) -> str:
        key_infix = ".{0}".format(key_type) if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem')

    def get_combined_file(self, name: str, key_type=None) -> str:
        # NOTE(review): key_type is ignored, so credentials with the same name
        # but different key types share (and overwrite) one combined file.
        # Kept as-is because callers may rely on the '<name>.pem' path.
        return os.path.join(self._store_dir, f'{name}.pem')

    def load_pem_cert(self, fpath: str) -> 'x509.Certificate':
        # Read bytes directly: PEM is ASCII, no locale-dependent text decoding needed.
        with open(fpath, "rb") as fd:
            return x509.load_pem_x509_certificate(fd.read())

    def load_pem_pkey(self, fpath: str):
        with open(fpath, "rb") as fd:
            return load_pem_private_key(fd.read(), password=None)

    @staticmethod
    def _validity_window(cert) -> tuple:
        """Return (not_before, not_after) of ``cert`` as naive UTC datetimes.

        Prefers the ``*_utc`` properties (newer cryptography, where the naive
        ones are deprecated) and strips their tzinfo, which yields the same
        values as the older naive properties.
        """
        try:
            return (cert.not_valid_before_utc.replace(tzinfo=None),
                    cert.not_valid_after_utc.replace(tzinfo=None))
        except AttributeError:
            return cert.not_valid_before, cert.not_valid_after

    def load_credentials(self, name: str, key_type=None,
                         single_file: bool = False,
                         issuer: Optional['Credentials'] = None,
                         check_valid: bool = False):
        """Load previously saved credentials for ``name``/``key_type``.

        :param single_file: the private key is stored in the cert file.
        :param issuer: issuer to attach to the loaded credentials.
        :param check_valid: return None if now is outside the validity window.
        :returns: the loaded credentials (registered and attached to this
            store), or None if the files are missing or the check fails.
        """
        cert_file = self.get_cert_file(name=name, key_type=key_type)
        pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
        comb_file = self.get_combined_file(name=name, key_type=key_type)
        if not (os.path.isfile(cert_file) and os.path.isfile(pkey_file)):
            return None
        cert = self.load_pem_cert(cert_file)
        pkey = self.load_pem_pkey(pkey_file)
        if check_valid:
            not_before, not_after = self._validity_window(cert)
            # NOTE(review): cert times are naive UTC while datetime.now() is
            # local time. This presumably matches how TestCA stamps validity
            # (naive datetime.now() values), so it is left unchanged here.
            now = datetime.now()
            if not_after < now or not_before > now:
                return None
        creds = Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
        creds.set_store(self)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)
        return creds
328
329
class TestCA:
    """Helpers that create a test certificate authority and issue certificates.

    All methods are class or static methods; the class keeps no state.
    Certificates are signed with SHA-256.
    """

    @classmethod
    def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
        """Return the root CA credentials kept in ``store_dir``.

        Root files stored under the name ``ca`` are reused without a validity
        check; otherwise a new self-signed CA called ``name`` is created and
        saved there.

        :param name: organization name for a newly created root certificate.
        :param store_dir: directory of the root's ``CertStore`` (created if missing).
        :param key_type: key type of the root key, e.g. ``"rsa2048"``.
        :returns: the root credentials, attached to the store.
        """
        store = CertStore(fpath=store_dir)
        creds = store.load_credentials(name="ca", key_type=key_type, issuer=None)
        if creds is None:
            creds = TestCA._make_ca_credentials(name=name, key_type=key_type)
            store.save(creds, name="ca")
            creds.set_store(store)
        return creds

    @staticmethod
    def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
                           valid_from: timedelta = timedelta(days=-1),
                           valid_to: timedelta = timedelta(days=89),
                           ) -> Credentials:
        """Create a new certificate and key for ``spec``, signed by ``issuer``.

        The kind of certificate follows ``spec``: a server certificate when it
        has domains, a client certificate when ``spec.client`` is set,
        otherwise an intermediate CA named ``spec.name``.

        :param key_type: key type for the new key (see ``_private_key``).
        :param valid_from: offset from now at which the certificate becomes valid.
        :param valid_to: offset from now at which the certificate expires.
        :returns: the new Credentials (not saved to any store).
        :raises Exception: if ``spec`` has no domains, no client flag and no name.
        """
        if spec.domains and len(spec.domains):
            creds = TestCA._make_server_credentials(name=spec.name, domains=spec.domains,
                                                    issuer=issuer, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.client:
            creds = TestCA._make_client_credentials(name=spec.name, issuer=issuer,
                                                    email=spec.email, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.name:
            creds = TestCA._make_ca_credentials(name=spec.name, issuer=issuer,
                                                valid_from=valid_from, valid_to=valid_to,
                                                key_type=key_type)
        else:
            raise Exception(f"unrecognized certificate specification: {spec}")
        return creds

    @staticmethod
    def _make_x509_name(org_name: Optional[str] = None, common_name: Optional[str] = None,
                        parent: Optional[x509.Name] = None) -> x509.Name:
        """Build a subject name, optionally nested below ``parent``.

        ``org_name`` becomes an O attribute for a top-level name, or an OU
        attribute when ``parent`` is given; otherwise ``common_name`` becomes
        a CN (if both are given, only ``org_name`` is used). The parent's
        attributes follow the new one.
        """
        name_pieces = []
        if org_name:
            oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
            name_pieces.append(x509.NameAttribute(oid, org_name))
        elif common_name:
            name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
        if parent:
            name_pieces.extend(parent)
        return x509.Name(name_pieces)

    @staticmethod
    def _make_csr(
            subject: x509.Name,
            pkey: Any,
            issuer_subject: Optional[x509.Name],
            valid_from_delta: Optional[timedelta] = None,
            valid_until_delta: Optional[timedelta] = None
    ):
        """Start a certificate builder (despite the name, not a CSR) for ``pkey``.

        Sets subject, issuer (``subject`` itself when ``issuer_subject`` is
        None, i.e. self-issued), public key, a random serial number and a
        SubjectKeyIdentifier. The validity window is now plus the given
        offsets (now itself when an offset is None), using naive
        ``datetime.now()`` values.
        """
        pubkey = pkey.public_key()
        issuer_subject = issuer_subject if issuer_subject is not None else subject

        now = datetime.now()
        # Each bound checks its own delta (previously valid_from checked
        # valid_until_delta and crashed when only valid_from_delta was None).
        valid_from = now + valid_from_delta if valid_from_delta is not None else now
        valid_until = now + valid_until_delta if valid_until_delta is not None else now

        return (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer_subject)
            .public_key(pubkey)
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            .serial_number(x509.random_serial_number())
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(pubkey),
                critical=False,
            )
        )

    @staticmethod
    def _add_ca_usages(csr: Any) -> Any:
        """Return ``csr`` with CA extensions: basic constraints, key and extended key usage."""
        return csr.add_extension(
            x509.BasicConstraints(ca=True, path_length=9),
            critical=True,
        ).add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False),
            critical=True
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
                ExtendedKeyUsageOID.SERVER_AUTH,
                ExtendedKeyUsageOID.CODE_SIGNING,
            ]),
            critical=True
        )

    @staticmethod
    def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
        """Return ``csr`` with server-leaf extensions.

        Each entry of ``domains`` that parses as an IP address becomes an
        IPAddress SAN, every other entry a DNSName SAN.
        """
        names = []
        for name in domains:
            try:
                names.append(x509.IPAddress(ipaddress.ip_address(name)))
            except ValueError:
                # not an IP address literal: treat it as a DNS name
                names.append(x509.DNSName(name))

        return csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        ).add_extension(
            x509.SubjectAlternativeName(names), critical=True,
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.SERVER_AUTH,
            ]),
            critical=False
        )

    @staticmethod
    def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: Optional[str] = None) -> Any:
        """Return ``csr`` with client-leaf extensions.

        :param rfc82name: optional e-mail address, added as an RFC822Name SAN.
        """
        # CertificateBuilder is immutable: add_extension() returns a new
        # builder, so every result must be reassigned. Previously the SAN and
        # EKU results were discarded and never made it into the certificate.
        cert = csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        )
        if rfc82name:
            cert = cert.add_extension(
                x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
                critical=True,
            )
        cert = cert.add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
            ]),
            critical=True
        )
        return cert

    @staticmethod
    def _make_ca_credentials(name, key_type: Any,
                             issuer: Optional[Credentials] = None,
                             valid_from: timedelta = timedelta(days=-1),
                             valid_to: timedelta = timedelta(days=89),
                             ) -> Credentials:
        """Create CA credentials named ``name``.

        Without ``issuer`` the certificate is self-signed (a root); otherwise
        it is signed by ``issuer`` and its subject is nested below the issuer's.
        """
        pkey = _private_key(key_type=key_type)
        if issuer is not None:
            issuer_subject = issuer.certificate.subject
            issuer_key = issuer.private_key
        else:
            issuer_subject = None
            issuer_key = pkey
        subject = TestCA._make_x509_name(org_name=name, parent=issuer.subject if issuer else None)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer_subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_ca_usages(csr)
        cert = csr.sign(private_key=issuer_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create server credentials for ``domains``, signed by ``issuer``."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_client_credentials(name: str,
                                 issuer: Credentials, email: Optional[str],
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create client credentials (optionally carrying ``email``), signed by ``issuer``."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
544