xref: /curl/tests/http/testenv/certs.py (revision 57cc5233)
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#***************************************************************************
4#                                  _   _ ____  _
5#  Project                     ___| | | |  _ \| |
6#                             / __| | | | |_) | |
7#                            | (__| |_| |  _ <| |___
8#                             \___|\___/|_| \_\_____|
9#
10# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
11#
12# This software is licensed as described in the file COPYING, which
13# you should have received as part of this distribution. The terms
14# are also available at https://curl.se/docs/copyright.html.
15#
16# You may opt to use, copy, modify, merge, publish, distribute and/or sell
17# copies of the Software, and permit persons to whom the Software is
18# furnished to do so, under the terms of the COPYING file.
19#
20# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
21# KIND, either express or implied.
22#
23# SPDX-License-Identifier: curl
24#
25###########################################################################
26#
27import ipaddress
28import os
29import re
30from datetime import timedelta, datetime, timezone
31from typing import List, Any, Optional
32
33from cryptography import x509
34from cryptography.hazmat.backends import default_backend
35from cryptography.hazmat.primitives import hashes
36from cryptography.hazmat.primitives.asymmetric import ec, rsa
37from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
38from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
39from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
40from cryptography.x509 import ExtendedKeyUsageOID, NameOID
41
42
# Elliptic curves _private_key() accepts by name: maps the upper-cased curve
# name (e.g. "SECP256R1") to the cryptography curve class.
EC_SUPPORTED = {
    curve.name.upper(): curve
    for curve in (
        ec.SECP192R1,
        ec.SECP224R1,
        ec.SECP256R1,
        ec.SECP384R1,
    )
}
50
51
def _private_key(key_type):
    """Generate a new private key of the requested type.

    *key_type* may be:
      - an int, or a string such as "2048" or "rsa2048" (case-insensitive):
        an RSA key of that many bits with public exponent 65537;
      - a curve name listed in EC_SUPPORTED, e.g. "secp256r1"
        (case-insensitive);
      - an ``ec.EllipticCurve`` instance or curve class.

    :raises ValueError: if a string names neither an RSA key size nor a
        supported curve.
    """
    if isinstance(key_type, str):
        spec = key_type.upper()
        m = re.match(r'^(RSA)?(\d+)$', spec)
        if m:
            key_type = int(m.group(2))
        elif spec in EC_SUPPORTED:
            key_type = EC_SUPPORTED[spec]
        else:
            raise ValueError(f"unsupported key type: {key_type}")

    if isinstance(key_type, int):
        return rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_type,
            backend=default_backend()
        )
    if isinstance(key_type, type) and issubclass(key_type, ec.EllipticCurve):
        # EC_SUPPORTED holds curve classes, but generate_private_key() is
        # documented to take a curve *instance*.
        key_type = key_type()
    return ec.generate_private_key(
        curve=key_type,
        backend=default_backend()
    )
71
72
class CertificateSpec:
    """Declarative description of a certificate to create.

    The kind of certificate follows from which fields are set (see ``type``):
    a spec with ``domains`` describes a server certificate, one with
    ``client=True`` a client certificate, and one with only a ``name`` a CA.
    ``sub_specs`` lists further certificates to be issued by the certificate
    created from this spec.
    """

    def __init__(self, name: Optional[str] = None,
                 domains: Optional[List[str]] = None,
                 email: Optional[str] = None,
                 key_type: Optional[str] = None,
                 single_file: bool = False,
                 valid_from: timedelta = timedelta(days=-1),
                 valid_to: timedelta = timedelta(days=89),
                 client: bool = False,
                 check_valid: bool = True,
                 sub_specs: Optional[List['CertificateSpec']] = None):
        self._name = name
        # host names and/or IP address strings for the subject alternative name
        self.domains = domains
        self.client = client
        # e-mail address put into a client certificate's subject alternative name
        self.email = email
        # None means: use the issuer's key type
        self.key_type = key_type
        # store the private key inside the certificate file instead of separately
        self.single_file = single_file
        # validity period, as offsets from the time of creation
        self.valid_from = valid_from
        self.valid_to = valid_to
        self.sub_specs = sub_specs
        # reject (and re-create) stored certificates outside their validity period
        self.check_valid = check_valid

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(name={self.name!r}, "
                f"domains={self.domains!r}, client={self.client!r}, "
                f"key_type={self.key_type!r})")

    @property
    def name(self) -> Optional[str]:
        """The explicit name, else the first domain, else None."""
        if self._name:
            return self._name
        elif self.domains:
            return self.domains[0]
        return None

    @property
    def type(self) -> Optional[str]:
        """"server", "client" or "ca" depending on the fields set, else None."""
        if self.domains:
            return "server"
        elif self.client:
            return "client"
        elif self.name:
            return "ca"
        return None
113
114
class Credentials:
    """A certificate together with its private key and (optional) issuer.

    Credentials may be attached to a ``CertStore`` via ``set_store``, which
    enables lookups of stored credentials by name and issuing (and saving)
    new certificates signed by this one via ``issue_cert``/``issue_certs``.
    """

    def __init__(self,
                 name: str,
                 cert: Any,
                 pkey: Any,
                 issuer: Optional['Credentials'] = None):
        self._name = name
        self._cert = cert
        self._pkey = pkey
        # credentials whose key signed `cert`; None for the root CA
        self._issuer = issuer
        # filled in later through set_files() / set_store()
        self._cert_file = None
        self._pkey_file = None
        self._store = None
        self._combined_file = None

    @property
    def name(self) -> str:
        """Name the credentials were created or loaded under."""
        return self._name

    @property
    def subject(self) -> 'x509.Name':
        """Subject name of the certificate."""
        return self._cert.subject

    @property
    def key_type(self):
        """"rsa<bits>" for an RSA key, the curve name (e.g. "secp256r1") for EC.

        :raises TypeError: for any other kind of private key.
        """
        if isinstance(self._pkey, RSAPrivateKey):
            return f"rsa{self._pkey.key_size}"
        if isinstance(self._pkey, EllipticCurvePrivateKey):
            return self._pkey.curve.name
        raise TypeError(f"unknown key type: {self._pkey}")

    @property
    def private_key(self) -> Any:
        return self._pkey

    @property
    def certificate(self) -> Any:
        return self._cert

    @property
    def cert_pem(self) -> bytes:
        """The certificate, PEM encoded."""
        return self._cert.public_bytes(Encoding.PEM)

    @property
    def pkey_pem(self) -> bytes:
        """The unencrypted private key, PEM encoded: traditional (PKCS#1)
        format for RSA keys, PKCS#8 for all others."""
        return self._pkey.private_bytes(
            Encoding.PEM,
            PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8,
            NoEncryption())

    @property
    def issuer(self) -> Optional['Credentials']:
        return self._issuer

    def set_store(self, store: 'CertStore'):
        """Attach the store used for lookups and for saving issued certificates."""
        self._store = store

    def set_files(self, cert_file: str, pkey_file: Optional[str] = None,
                  combined_file: Optional[str] = None):
        """Record where the certificate, key and combined PEM files live."""
        self._cert_file = cert_file
        self._pkey_file = pkey_file
        self._combined_file = combined_file

    @property
    def cert_file(self) -> Optional[str]:
        """Path of the certificate file; None until set_files() was called."""
        return self._cert_file

    @property
    def pkey_file(self) -> Optional[str]:
        return self._pkey_file

    @property
    def combined_file(self) -> Optional[str]:
        return self._combined_file

    def get_first(self, name) -> Optional['Credentials']:
        """Return the first credentials stored under *name*, or None."""
        creds = self.get_credentials_for_name(name)
        return creds[0] if creds else None

    def get_credentials_for_name(self, name) -> List['Credentials']:
        """All credentials registered under *name* in the attached store
        (an empty list when no store is attached)."""
        return self._store.get_credentials_for_name(name) if self._store else []

    def issue_certs(self, specs: List['CertificateSpec'],
                    chain: Optional[List['Credentials']] = None) -> List['Credentials']:
        """Issue one set of credentials per spec; see issue_cert()."""
        return [self.issue_cert(spec=spec, chain=chain) for spec in specs]

    def issue_cert(self, spec: 'CertificateSpec',
                   chain: Optional[List['Credentials']] = None) -> 'Credentials':
        """Return credentials for *spec* issued by these credentials.

        Credentials already in the attached store are reused (subject to
        ``spec.check_valid``); otherwise new ones are created, signed by this
        CA and saved (a CA additionally gets its "<name>-ca.pem" chain file).
        Afterwards ``spec.sub_specs`` are issued by the result, stored in a
        sub-directory named after it.

        :param spec: description of the certificate.
        :param chain: the CAs above this one; extended with ``self`` for
            the sub-specs.
        """
        # a spec without explicit key type inherits the issuer's
        key_type = spec.key_type if spec.key_type else self.key_type
        creds = None
        if self._store:
            creds = self._store.load_credentials(
                name=spec.name, key_type=key_type, single_file=spec.single_file,
                issuer=self, check_valid=spec.check_valid)
        if creds is None:
            creds = TestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
                                              valid_from=spec.valid_from, valid_to=spec.valid_to)
            if self._store:
                self._store.save(creds, single_file=spec.single_file)
                if spec.type == "ca":
                    self._store.save_chain(creds, "ca", with_root=True)

        if spec.sub_specs:
            if self._store:
                sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
                creds.set_store(sub_store)
            subchain = chain.copy() if chain else []
            subchain.append(self)
            creds.issue_certs(spec.sub_specs, chain=subchain)
        return creds
227
228
class CertStore:
    """A directory of PEM files plus an in-memory index of credentials by name.

    For credentials saved as *name* with key type *kt* the store writes:
      - ``<name>.<kt>.cert.pem``: certificate, followed by the optional chain
        (and by the private key for single-file credentials),
      - ``<name>.<kt>.pkey.pem``: private key (not for single-file credentials),
      - ``<name>.pem``: certificate, chain and private key combined.
    """

    def __init__(self, fpath: str):
        self._store_dir = fpath
        if not os.path.exists(self._store_dir):
            # exist_ok: tolerate the directory appearing concurrently
            os.makedirs(self._store_dir, exist_ok=True)
        self._creds_by_name = {}  # name -> list of Credentials saved/loaded as that name

    @property
    def path(self) -> str:
        return self._store_dir

    def save(self, creds: 'Credentials', name: Optional[str] = None,
             chain: Optional[List['Credentials']] = None,
             single_file: bool = False) -> None:
        """Write *creds* to the store and register them under *name*.

        :param creds: the credentials to write.
        :param name: file base name; defaults to ``creds.name``.
        :param chain: certificates written after the credentials' own one.
        :param single_file: append the private key to the certificate file
            instead of writing a separate key file.
        """
        name = name if name is not None else creds.name
        key_type = creds.key_type
        cert_file = self.get_cert_file(name=name, key_type=key_type)
        pkey_file = None if single_file else self.get_pkey_file(name=name, key_type=key_type)
        comb_file = self.get_combined_file(name=name, key_type=key_type)
        chain_pem = b"".join(c.cert_pem for c in chain) if chain else b""
        with open(cert_file, "wb") as fd:
            fd.write(creds.cert_pem)
            fd.write(chain_pem)
            if pkey_file is None:
                fd.write(creds.pkey_pem)
        if pkey_file is not None:
            with open(pkey_file, "wb") as fd:
                fd.write(creds.pkey_pem)
        with open(comb_file, "wb") as fd:
            fd.write(creds.cert_pem + chain_pem + creds.pkey_pem)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)

    def save_chain(self, creds: 'Credentials', infix: str, with_root=False):
        """Write ``<name>-<infix>.pem``: *creds*' certificate followed by those
        of its issuers. The topmost issuer (the root) is left out unless
        *with_root* is set or *creds* has no issuer at all."""
        name = creds.name
        chain = [creds]
        while creds.issuer is not None:
            creds = creds.issuer
            chain.append(creds)
        if not with_root and len(chain) > 1:
            chain = chain[:-1]
        chain_file = os.path.join(self._store_dir, f'{name}-{infix}.pem')
        with open(chain_file, "wb") as fd:
            for c in chain:
                fd.write(c.cert_pem)

    def _add_credentials(self, name: str, creds: 'Credentials'):
        self._creds_by_name.setdefault(name, []).append(creds)

    def get_credentials_for_name(self, name) -> List['Credentials']:
        """All credentials saved or loaded under *name* (possibly empty)."""
        return self._creds_by_name.get(name, [])

    def get_cert_file(self, name: str, key_type=None) -> str:
        key_infix = f".{key_type}" if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem')

    def get_pkey_file(self, name: str, key_type=None) -> str:
        key_infix = f".{key_type}" if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem')

    def get_combined_file(self, name: str, key_type=None) -> str:
        # NOTE(review): key_type is not part of this file name, so saving the
        # same name with two key types overwrites the combined file.
        return os.path.join(self._store_dir, f'{name}.pem')

    def load_pem_cert(self, fpath: str) -> 'x509.Certificate':
        """Load the first certificate found in the PEM file *fpath*."""
        with open(fpath, "rb") as fd:
            return x509.load_pem_x509_certificate(fd.read())

    def load_pem_pkey(self, fpath: str):
        """Load the unencrypted private key in the PEM file *fpath*."""
        with open(fpath, "rb") as fd:
            return load_pem_private_key(fd.read(), password=None)

    @staticmethod
    def _is_currently_valid(cert) -> bool:
        """Return True if the current time lies within *cert*'s validity period."""
        now = datetime.now(tz=timezone.utc)
        try:
            not_before = cert.not_valid_before_utc
            not_after = cert.not_valid_after_utc
        except AttributeError:
            # cryptography < 42 only has naive datetimes, expressed in UTC:
            # compare against a naive *UTC* now, not local time.
            not_before = cert.not_valid_before
            not_after = cert.not_valid_after
            now = now.replace(tzinfo=None)
        return not_before <= now <= not_after

    def load_credentials(self, name: str, key_type=None,
                         single_file: bool = False,
                         issuer: Optional['Credentials'] = None,
                         check_valid: bool = False):
        """Load credentials previously saved as *name*/*key_type*.

        :param single_file: the private key is in the certificate file.
        :param issuer: recorded as the issuer of the loaded credentials.
        :param check_valid: treat a certificate outside its validity period
            as absent.
        :returns: the loaded Credentials (registered in this store), or None
            if the files are missing or the certificate failed the check.
        """
        cert_file = self.get_cert_file(name=name, key_type=key_type)
        pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
        comb_file = self.get_combined_file(name=name, key_type=key_type)
        if not (os.path.isfile(cert_file) and os.path.isfile(pkey_file)):
            return None
        cert = self.load_pem_cert(cert_file)
        pkey = self.load_pem_pkey(pkey_file)
        if check_valid and not self._is_currently_valid(cert):
            return None
        creds = Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
        creds.set_store(self)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)
        return creds
337
338
class TestCA:
    """Creates the credentials of the test PKI: a root CA and the CA, server
    and client certificates issued below it."""

    @classmethod
    def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
        """Return the root CA kept in *store_dir*, creating and saving it if absent.

        The root's files always use the base name "ca" in the store.

        :param name: name (and subject organization) of a newly created root.
        :param store_dir: directory of the root's CertStore.
        :param key_type: key type of a newly created root, e.g. "rsa2048".
        """
        store = CertStore(fpath=store_dir)
        # NOTE(review): load_credentials() defaults to check_valid=False, so an
        # expired root on disk is reused; also, loaded credentials are named
        # "ca" while freshly created ones are named *name* - confirm callers
        # do not depend on creds.name here.
        creds = store.load_credentials(name="ca", key_type=key_type, issuer=None)
        if creds is None:
            creds = TestCA._make_ca_credentials(name=name, key_type=key_type)
            store.save(creds, name="ca")
            creds.set_store(store)
        return creds

    @staticmethod
    def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
                           valid_from: timedelta = timedelta(days=-1),
                           valid_to: timedelta = timedelta(days=89),
                           ) -> Credentials:
        """
        Create new credentials described by *spec*, signed by *issuer*.

        A spec with ``domains`` yields a server certificate, one with
        ``client`` set a client certificate, and one with only a ``name`` a CA.

        :param spec: what to create.
        :param issuer: the CA credentials that sign the new certificate.
        :param key_type: key type of the new private key (see _private_key).
        :param valid_from: offset from now at which validity starts.
        :param valid_to: offset from now at which validity ends.
        :returns: the new credentials (not saved anywhere yet).
        :raises Exception: if *spec* matches none of the kinds above.
        """
        if spec.domains:
            creds = TestCA._make_server_credentials(name=spec.name, domains=spec.domains,
                                                    issuer=issuer, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.client:
            creds = TestCA._make_client_credentials(name=spec.name, issuer=issuer,
                                                    email=spec.email, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.name:
            creds = TestCA._make_ca_credentials(name=spec.name, issuer=issuer,
                                                valid_from=valid_from, valid_to=valid_to,
                                                key_type=key_type)
        else:
            raise Exception(f"unrecognized certificate specification: {spec}")
        return creds

    @staticmethod
    def _make_x509_name(org_name: Optional[str] = None, common_name: Optional[str] = None,
                        parent: Optional[x509.Name] = None) -> x509.Name:
        """Build a name from one new attribute followed by *parent*'s attributes.

        *org_name* becomes O= without a parent and OU= below one; otherwise
        *common_name* becomes CN= (it is ignored when *org_name* is given).
        """
        name_pieces = []
        if org_name:
            oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
            name_pieces.append(x509.NameAttribute(oid, org_name))
        elif common_name:
            name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
        if parent:
            name_pieces.extend(list(parent))
        return x509.Name(name_pieces)

    @staticmethod
    def _make_csr(
            subject: x509.Name,
            pkey: Any,
            issuer_subject: Optional[x509.Name],
            valid_from_delta: Optional[timedelta] = None,
            valid_until_delta: Optional[timedelta] = None
    ):
        """Return an unsigned ``x509.CertificateBuilder`` (despite the name,
        not a CSR) for *subject* and *pkey*'s public key.

        :param issuer_subject: issuer name; None makes it self-issued.
        :param valid_from_delta: offset from now for not-valid-before
            (None means now).
        :param valid_until_delta: offset from now for not-valid-after
            (None means now).
        """
        pubkey = pkey.public_key()
        issuer_subject = issuer_subject if issuer_subject is not None else subject

        # One timezone-aware UTC instant for both bounds: cryptography treats
        # naive datetimes as UTC, so a naive local now() would shift the
        # validity window by the host's UTC offset.
        now = datetime.now(tz=timezone.utc)
        valid_from = now + valid_from_delta if valid_from_delta is not None else now
        valid_until = now + valid_until_delta if valid_until_delta is not None else now

        return (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer_subject)
            .public_key(pubkey)
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            .serial_number(x509.random_serial_number())
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(pubkey),
                critical=False,
            )
        )

    @staticmethod
    def _add_ca_usages(csr: Any) -> Any:
        """Return builder *csr* extended with CA basic constraints (path
        length 9), cert/CRL signing key usage and extended key usages."""
        return csr.add_extension(
            x509.BasicConstraints(ca=True, path_length=9),
            critical=True,
        ).add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False),
            critical=True
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
                ExtendedKeyUsageOID.SERVER_AUTH,
                ExtendedKeyUsageOID.CODE_SIGNING,
            ]),
            critical=True
        )

    @staticmethod
    def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
        """Return builder *csr* extended for a server certificate.

        Each entry of *domains* that parses as an IP address becomes an
        IPAddress SAN entry, all others a DNSName entry.
        """
        names = []
        for name in domains:
            try:
                names.append(x509.IPAddress(ipaddress.ip_address(name)))
            except ValueError:  # not an IP address literal
                names.append(x509.DNSName(name))

        return csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        ).add_extension(
            x509.SubjectAlternativeName(names), critical=True,
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.SERVER_AUTH,
            ]),
            critical=False
        )

    @staticmethod
    def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: Optional[str] = None) -> Any:
        """Return builder *csr* extended for a client certificate: no CA,
        issuer key identifier, optional e-mail (RFC 822) SAN and the
        clientAuth extended key usage."""
        # CertificateBuilder is immutable - add_extension() returns a new
        # builder, so every result has to be kept.
        cert = csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        )
        if rfc82name:
            cert = cert.add_extension(
                x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
                critical=True,
            )
        return cert.add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
            ]),
            critical=True
        )

    @staticmethod
    def _make_ca_credentials(name, key_type: Any,
                             issuer: Optional[Credentials] = None,
                             valid_from: timedelta = timedelta(days=-1),
                             valid_to: timedelta = timedelta(days=89),
                             ) -> Credentials:
        """Create CA credentials with a new key, signed by *issuer*, or
        self-signed when *issuer* is None."""
        pkey = _private_key(key_type=key_type)
        if issuer is not None:
            issuer_subject = issuer.certificate.subject
            issuer_key = issuer.private_key
        else:
            issuer_subject = None
            issuer_key = pkey
        subject = TestCA._make_x509_name(org_name=name, parent=issuer.subject if issuer else None)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer_subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_ca_usages(csr)
        cert = csr.sign(private_key=issuer_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create server credentials for *domains* with a new key, signed by *issuer*."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_client_credentials(name: str,
                                 issuer: Credentials, email: Optional[str],
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create client credentials with a new key, signed by *issuer*;
        *email*, if given, goes into the subject alternative name."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
555