xref: /curl/tests/http/testenv/certs.py (revision 34555724)
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#***************************************************************************
4#                                  _   _ ____  _
5#  Project                     ___| | | |  _ \| |
6#                             / __| | | | |_) | |
7#                            | (__| |_| |  _ <| |___
8#                             \___|\___/|_| \_\_____|
9#
10# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
11#
12# This software is licensed as described in the file COPYING, which
13# you should have received as part of this distribution. The terms
14# are also available at https://curl.se/docs/copyright.html.
15#
16# You may opt to use, copy, modify, merge, publish, distribute and/or sell
17# copies of the Software, and permit persons to whom the Software is
18# furnished to do so, under the terms of the COPYING file.
19#
20# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
21# KIND, either express or implied.
22#
23# SPDX-License-Identifier: curl
24#
25###########################################################################
26#
27import ipaddress
28import os
29import re
30from datetime import timedelta, datetime, timezone
31from typing import List, Any, Optional
32
33from cryptography import x509
34from cryptography.hazmat.backends import default_backend
35from cryptography.hazmat.primitives import hashes
36from cryptography.hazmat.primitives.asymmetric import ec, rsa
37from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
38from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
39from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
40from cryptography.x509 import ExtendedKeyUsageOID, NameOID
41
42
# Elliptic curves that may be requested by name as a key type. Keys are the
# upper-cased curve names (e.g. "SECP256R1"), values the curve classes from
# cryptography's `ec` module.
EC_SUPPORTED = {
    curve.name.upper(): curve
    for curve in (
        ec.SECP192R1,
        ec.SECP224R1,
        ec.SECP256R1,
        ec.SECP384R1,
    )
}
50
51
def _private_key(key_type):
    """Generate a new private key of the requested type.

    :param key_type: one of
        - an ``int``, or a string like ``"2048"`` / ``"rsa2048"``
          (case-insensitive): an RSA key with that many bits and public
          exponent 65537;
        - a curve name listed in ``EC_SUPPORTED`` (case-insensitive,
          e.g. ``"secp256r1"``);
        - an ``ec.EllipticCurve`` instance or ``ec.EllipticCurve`` subclass.
    :returns: the generated RSA or elliptic-curve private key.
    :raises ValueError: if ``key_type`` does not name a supported key type.
    """
    if isinstance(key_type, str):
        key_type = key_type.upper()
        m = re.match(r'^(RSA)?(\d+)$', key_type)
        if m:
            key_type = int(m.group(2))

    if isinstance(key_type, int):
        return rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_type,
            backend=default_backend()
        )
    if isinstance(key_type, str):
        if key_type not in EC_SUPPORTED:
            raise ValueError(f"unsupported key type: {key_type}")
        key_type = EC_SUPPORTED[key_type]
    # EC_SUPPORTED holds curve *classes*, but ec.generate_private_key()
    # expects a curve instance (passing a class is deprecated in recent
    # cryptography releases), so instantiate it here.
    if isinstance(key_type, type) and issubclass(key_type, ec.EllipticCurve):
        key_type = key_type()
    if not isinstance(key_type, ec.EllipticCurve):
        raise ValueError(f"unsupported key type: {key_type!r}")
    return ec.generate_private_key(
        curve=key_type,
        backend=default_backend()
    )
71
72
class CertificateSpec:
    """Declarative description of a certificate to be issued.

    The kind of certificate follows from which fields are set (see
    :attr:`type`): a spec with ``domains`` describes a server certificate,
    one with ``client=True`` a client certificate, and one with only a
    ``name`` a CA certificate. ``sub_specs`` are issued by the credentials
    created from this spec.
    """

    def __init__(self, name: Optional[str] = None,
                 domains: Optional[List[str]] = None,
                 email: Optional[str] = None,
                 key_type: Optional[str] = None,
                 single_file: bool = False,
                 valid_from: timedelta = timedelta(days=-1),
                 valid_to: timedelta = timedelta(days=89),
                 client: bool = False,
                 check_valid: bool = True,
                 sub_specs: Optional[List['CertificateSpec']] = None):
        self._name = name
        self.domains = domains
        self.client = client
        # email address for client certificates
        self.email = email
        # key type of the generated key; when None, the issuer's key type is
        # used instead (see Credentials.issue_cert)
        self.key_type = key_type
        # store certificate and private key in one file
        self.single_file = single_file
        # validity window, relative to the time of issuance
        self.valid_from = valid_from
        self.valid_to = valid_to
        self.sub_specs = sub_specs
        # when existing credentials are found in a store, only reuse them if
        # they are valid right now
        self.check_valid = check_valid

    def __repr__(self) -> str:
        # Used e.g. in error messages about unrecognized specifications.
        return (f"{self.__class__.__name__}(name={self.name!r}, "
                f"domains={self.domains!r}, client={self.client!r}, "
                f"key_type={self.key_type!r})")

    @property
    def name(self) -> Optional[str]:
        """The explicit name, else the first domain, else None."""
        if self._name:
            return self._name
        elif self.domains:
            return self.domains[0]
        return None

    @property
    def type(self) -> Optional[str]:
        """Certificate kind: "server", "client", "ca", or None if unspecified."""
        if self.domains:
            return "server"
        elif self.client:
            return "client"
        elif self.name:
            return "ca"
        return None
113
114
class Credentials:
    """A certificate, its private key and the credentials that issued it.

    Credentials can be attached to a :class:`CertStore` (``set_store``). The
    store persists newly issued credentials as PEM files and is consulted
    first when issuing, so that existing credentials are reused.
    """

    def __init__(self,
                 name: str,
                 cert: Any,
                 pkey: Any,
                 issuer: Optional['Credentials'] = None):
        self._name = name
        self._cert = cert
        self._pkey = pkey
        self._issuer = issuer
        # File locations are unknown until the credentials are saved to or
        # loaded from a store (see set_files()).
        self._cert_file = None
        self._pkey_file = None
        # Initialised here so that `combined_file` returns None, instead of
        # raising AttributeError, before set_files() has been called.
        self._combined_file = None
        self._store = None

    @property
    def name(self) -> str:
        """Name the credentials were created or loaded under."""
        return self._name

    @property
    def subject(self) -> x509.Name:
        """Subject name of the certificate."""
        return self._cert.subject

    @property
    def key_type(self):
        """Key type as used in store file names: ``rsa<bits>`` or the curve name.

        :raises Exception: if the key is neither an RSA nor an EC key.
        """
        if isinstance(self._pkey, RSAPrivateKey):
            return f"rsa{self._pkey.key_size}"
        elif isinstance(self._pkey, EllipticCurvePrivateKey):
            return f"{self._pkey.curve.name}"
        else:
            raise Exception(f"unknown key type: {self._pkey}")

    @property
    def private_key(self) -> Any:
        return self._pkey

    @property
    def certificate(self) -> Any:
        return self._cert

    @property
    def cert_pem(self) -> bytes:
        """The certificate, PEM encoded."""
        return self._cert.public_bytes(Encoding.PEM)

    @property
    def pkey_pem(self) -> bytes:
        """The unencrypted private key, PEM encoded.

        RSA keys use the traditional OpenSSL format, all others PKCS#8.
        """
        return self._pkey.private_bytes(
            Encoding.PEM,
            PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8,
            NoEncryption())

    @property
    def issuer(self) -> Optional['Credentials']:
        """The issuing credentials, or None (e.g. for the self-signed root CA)."""
        return self._issuer

    def set_store(self, store: 'CertStore'):
        """Attach the store used to look up and persist issued credentials."""
        self._store = store

    def set_files(self, cert_file: str, pkey_file: Optional[str] = None,
                  combined_file: Optional[str] = None):
        """Record the PEM file locations of these credentials."""
        self._cert_file = cert_file
        self._pkey_file = pkey_file
        self._combined_file = combined_file

    @property
    def cert_file(self) -> Optional[str]:
        """Path of the certificate file, or None if not saved/loaded yet."""
        return self._cert_file

    @property
    def pkey_file(self) -> Optional[str]:
        """Path of the private key file, or None (not stored, or single file)."""
        return self._pkey_file

    @property
    def combined_file(self) -> Optional[str]:
        """Path of the combined certificate+key file, or None if not stored."""
        return self._combined_file

    def get_first(self, name) -> Optional['Credentials']:
        """Return the first credentials stored under ``name``, or None."""
        creds = self.get_credentials_for_name(name)
        return creds[0] if creds else None

    def get_credentials_for_name(self, name) -> List['Credentials']:
        """Return all credentials stored under ``name`` (empty without a store)."""
        return self._store.get_credentials_for_name(name) if self._store else []

    def issue_certs(self, specs: List[CertificateSpec],
                    chain: Optional[List['Credentials']] = None) -> List['Credentials']:
        """Issue credentials for each spec in order; see :meth:`issue_cert`."""
        return [self.issue_cert(spec=spec, chain=chain) for spec in specs]

    def issue_cert(self, spec: CertificateSpec,
                   chain: Optional[List['Credentials']] = None) -> 'Credentials':
        """Issue credentials for ``spec``, signed by these credentials.

        With a store attached, matching credentials already in the store are
        reused (subject to ``spec.check_valid``). Otherwise new credentials
        are created and saved, plus a chain file for CA specs. The sub-specs
        of ``spec`` are then issued by the resulting credentials, into a
        sub-directory named after them when a store is attached.

        :param spec: what to issue.
        :param chain: issuers above this one; only passed on (extended by
            ``self``) when issuing ``spec.sub_specs``.
        :returns: the issued or reused credentials.
        """
        # Fall back to this issuer's own key type when the spec has none.
        key_type = spec.key_type if spec.key_type else self.key_type
        creds = None
        if self._store:
            creds = self._store.load_credentials(
                name=spec.name, key_type=key_type, single_file=spec.single_file,
                issuer=self, check_valid=spec.check_valid)
        if creds is None:
            creds = TestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
                                              valid_from=spec.valid_from, valid_to=spec.valid_to)
            if self._store:
                self._store.save(creds, single_file=spec.single_file)
                if spec.type == "ca":
                    self._store.save_chain(creds, "ca", with_root=True)

        if spec.sub_specs:
            if self._store:
                sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
                creds.set_store(sub_store)
            subchain = chain.copy() if chain else []
            subchain.append(self)
            creds.issue_certs(spec.sub_specs, chain=subchain)
        return creds
226
227
class CertStore:
    """Directory-backed storage of credentials as PEM files.

    For credentials stored under ``name`` with key type ``kt``, the store
    uses ``<dir>/<name>.<kt>.cert.pem`` (certificate),
    ``<dir>/<name>.<kt>.pkey.pem`` (private key) and ``<dir>/<name>.pem``
    (certificate, chain and key combined). Credentials that are saved or
    loaded are also indexed by name in memory.
    """

    def __init__(self, fpath: str):
        self._store_dir = fpath
        os.makedirs(self._store_dir, exist_ok=True)
        self._creds_by_name = {}

    @property
    def path(self) -> str:
        """Directory the store writes its files to."""
        return self._store_dir

    def save(self, creds: Credentials, name: Optional[str] = None,
             chain: Optional[List[Credentials]] = None,
             single_file: bool = False) -> None:
        """Write ``creds`` as PEM files and register them under ``name``.

        Writes the certificate file (certificate followed by the ``chain``
        certificates, plus the private key when ``single_file``), a separate
        private key file unless ``single_file``, and the combined file
        (certificate, chain and private key).

        :param name: name to store under; defaults to ``creds.name``.
        """
        name = name if name is not None else creds.name
        cert_file = self.get_cert_file(name=name, key_type=creds.key_type)
        # With single_file the key goes into the certificate file instead.
        pkey_file = None if single_file else self.get_pkey_file(name=name, key_type=creds.key_type)
        comb_file = self.get_combined_file(name=name, key_type=creds.key_type)
        chain_pem = b"".join(c.cert_pem for c in chain) if chain else b""
        with open(cert_file, "wb") as fd:
            fd.write(creds.cert_pem)
            fd.write(chain_pem)
            if pkey_file is None:
                fd.write(creds.pkey_pem)
        if pkey_file is not None:
            with open(pkey_file, "wb") as fd:
                fd.write(creds.pkey_pem)
        with open(comb_file, "wb") as fd:
            fd.write(creds.cert_pem)
            fd.write(chain_pem)
            fd.write(creds.pkey_pem)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)

    def save_chain(self, creds: Credentials, infix: str, with_root=False):
        """Write ``<dir>/<creds.name>-<infix>.pem`` with the issuer chain.

        The file holds the certificate of ``creds`` followed by those of its
        issuers, walking up ``issuer`` links. The last (root) certificate is
        left out unless ``with_root`` is set or it is the only one.
        """
        name = creds.name
        chain = [creds]
        while creds.issuer is not None:
            creds = creds.issuer
            chain.append(creds)
        if not with_root and len(chain) > 1:
            chain = chain[:-1]
        chain_file = os.path.join(self._store_dir, f'{name}-{infix}.pem')
        with open(chain_file, "wb") as fd:
            for c in chain:
                fd.write(c.cert_pem)

    def _add_credentials(self, name: str, creds: Credentials):
        """Index ``creds`` under ``name`` in memory."""
        self._creds_by_name.setdefault(name, []).append(creds)

    def get_credentials_for_name(self, name) -> List[Credentials]:
        """Return the credentials indexed under ``name`` (empty list if none)."""
        return self._creds_by_name.get(name, [])

    def get_cert_file(self, name: str, key_type=None) -> str:
        """Path of the certificate file for ``name`` and ``key_type``."""
        key_infix = f".{key_type}" if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem')

    def get_pkey_file(self, name: str, key_type=None) -> str:
        """Path of the private key file for ``name`` and ``key_type``."""
        key_infix = f".{key_type}" if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem')

    def get_combined_file(self, name: str, key_type=None) -> str:
        """Path of the combined certificate/chain/key file for ``name``."""
        # NOTE(review): key_type is ignored here, so credentials with the same
        # name but different key types share (and overwrite) one combined
        # file. Kept as-is because callers may rely on the `<name>.pem` path.
        return os.path.join(self._store_dir, f'{name}.pem')

    def load_pem_cert(self, fpath: str) -> x509.Certificate:
        """Load a PEM encoded certificate from ``fpath``."""
        with open(fpath, "rb") as fd:
            return x509.load_pem_x509_certificate(fd.read())

    def load_pem_pkey(self, fpath: str):
        """Load an unencrypted PEM encoded private key from ``fpath``."""
        with open(fpath, "rb") as fd:
            return load_pem_private_key(fd.read(), password=None)

    def load_credentials(self, name: str, key_type=None,
                         single_file: bool = False,
                         issuer: Optional[Credentials] = None,
                         check_valid: bool = False):
        """Load stored credentials for ``name`` and ``key_type``.

        :param single_file: the private key is in the certificate file.
        :param issuer: issuer to record on the loaded credentials.
        :param check_valid: treat credentials not valid right now as absent.
        :returns: the loaded credentials (attached to and indexed in this
            store), or None if the files are missing or the validity check
            fails.
        """
        cert_file = self.get_cert_file(name=name, key_type=key_type)
        pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
        comb_file = self.get_combined_file(name=name, key_type=key_type)
        if not (os.path.isfile(cert_file) and os.path.isfile(pkey_file)):
            return None
        cert = self.load_pem_cert(cert_file)
        pkey = self.load_pem_pkey(pkey_file)
        if check_valid and not self._is_valid_now(cert):
            return None
        creds = Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
        creds.set_store(self)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)
        return creds

    @staticmethod
    def _is_valid_now(cert: x509.Certificate) -> bool:
        """Return True if the current time lies within the validity of ``cert``."""
        try:
            # cryptography >= 42 offers timezone-aware UTC datetimes
            not_before = cert.not_valid_before_utc
            not_after = cert.not_valid_after_utc
            now = datetime.now(tz=timezone.utc)
        except AttributeError:  # older cryptography
            # The naive datetimes represent UTC, so compare against the
            # current UTC time, not the local time.
            not_before = cert.not_valid_before
            not_after = cert.not_valid_after
            now = datetime.now(tz=timezone.utc).replace(tzinfo=None)
        return not_before <= now <= not_after
336
337
class TestCA:
    """Factory for test CAs and the server/client credentials they issue."""

    @classmethod
    def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
        """Load the root CA stored as "ca" in ``store_dir``, or create one.

        A newly created root is a self-signed CA whose organization is
        ``name``. It is saved in the store under the name "ca".
        """
        store = CertStore(fpath=store_dir)
        creds = store.load_credentials(name="ca", key_type=key_type, issuer=None)
        if creds is None:
            creds = TestCA._make_ca_credentials(name=name, key_type=key_type)
            store.save(creds, name="ca")
            creds.set_store(store)
        return creds

    @staticmethod
    def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
                           valid_from: timedelta = timedelta(days=-1),
                           valid_to: timedelta = timedelta(days=89),
                           ) -> Credentials:
        """Create new credentials for ``spec``, signed by ``issuer``.

        A spec with domains yields a server certificate, one with ``client``
        set a client certificate, and one with only a name a CA certificate.

        :param valid_from: start of validity, relative to now.
        :param valid_to: end of validity, relative to now.
        :returns: the new credentials (not saved anywhere).
        :raises Exception: if the spec describes none of these kinds.
        """
        if spec.domains and len(spec.domains):
            creds = TestCA._make_server_credentials(name=spec.name, domains=spec.domains,
                                                    issuer=issuer, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.client:
            creds = TestCA._make_client_credentials(name=spec.name, issuer=issuer,
                                                    email=spec.email, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.name:
            creds = TestCA._make_ca_credentials(name=spec.name, issuer=issuer,
                                                valid_from=valid_from, valid_to=valid_to,
                                                key_type=key_type)
        else:
            raise Exception(f"unrecognized certificate specification: {spec}")
        return creds

    @staticmethod
    def _make_x509_name(org_name: Optional[str] = None, common_name: Optional[str] = None,
                        parent: Optional[x509.Name] = None) -> x509.Name:
        """Build a subject name below ``parent``.

        ``org_name`` becomes the organization (or, below a parent, the
        organizational unit); otherwise ``common_name`` becomes the CN. The
        attributes of ``parent`` are appended after it.
        """
        name_pieces = []
        if org_name:
            oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
            name_pieces.append(x509.NameAttribute(oid, org_name))
        elif common_name:
            name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
        if parent:
            name_pieces.extend([rdn for rdn in parent])
        return x509.Name(name_pieces)

    @staticmethod
    def _make_csr(
            subject: x509.Name,
            pkey: Any,
            issuer_subject: Optional[x509.Name],
            valid_from_delta: Optional[timedelta] = None,
            valid_until_delta: Optional[timedelta] = None
    ):
        """Start a certificate builder (not an actual CSR) for ``subject``.

        :param pkey: private key whose public key is certified.
        :param issuer_subject: issuer name; None means self-issued.
        :param valid_from_delta: start of validity relative to now (None: now).
        :param valid_until_delta: end of validity relative to now (None: now).
        :returns: a ``x509.CertificateBuilder`` with subject, issuer, key,
            validity, a random serial and the subject key identifier set.
        """
        pubkey = pkey.public_key()
        issuer_subject = issuer_subject if issuer_subject is not None else subject

        # Use one timezone-aware UTC timestamp. cryptography interprets naive
        # datetimes as UTC, so a naive local time would shift the validity
        # window by the local UTC offset.
        now = datetime.now(tz=timezone.utc)
        valid_from = now + valid_from_delta if valid_from_delta is not None else now
        valid_until = now + valid_until_delta if valid_until_delta is not None else now

        return (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer_subject)
            .public_key(pubkey)
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            .serial_number(x509.random_serial_number())
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(pubkey),
                critical=False,
            )
        )

    @staticmethod
    def _add_ca_usages(csr: Any) -> Any:
        """Add the CA extensions: basic constraints, key and extended key usage."""
        return csr.add_extension(
            x509.BasicConstraints(ca=True, path_length=9),
            critical=True,
        ).add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False),
            critical=True
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
                ExtendedKeyUsageOID.SERVER_AUTH,
                ExtendedKeyUsageOID.CODE_SIGNING,
            ]),
            critical=True
        )

    @staticmethod
    def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
        """Add server extensions; ``domains`` become IP or DNS subject alt names."""
        names = []
        for name in domains:
            try:
                names.append(x509.IPAddress(ipaddress.ip_address(name)))
            except ValueError:  # not an IP address, so a DNS name
                names.append(x509.DNSName(name))

        return csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        ).add_extension(
            x509.SubjectAlternativeName(names), critical=True,
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.SERVER_AUTH,
            ]),
            critical=False
        )

    @staticmethod
    def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: Optional[str] = None) -> Any:
        """Add client extensions, with ``rfc82name`` as email subject alt name."""
        # CertificateBuilder is immutable: add_extension() returns a new
        # builder, so each result must be kept or the extension is lost.
        cert = csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        )
        if rfc82name:
            cert = cert.add_extension(
                x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
                critical=True,
            )
        return cert.add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
            ]),
            critical=True
        )

    @staticmethod
    def _make_ca_credentials(name, key_type: Any,
                             issuer: Optional[Credentials] = None,
                             valid_from: timedelta = timedelta(days=-1),
                             valid_to: timedelta = timedelta(days=89),
                             ) -> Credentials:
        """Create CA credentials signed by ``issuer``, or self-signed if None."""
        pkey = _private_key(key_type=key_type)
        if issuer is not None:
            issuer_subject = issuer.certificate.subject
            issuer_key = issuer.private_key
        else:
            issuer_subject = None
            issuer_key = pkey
        subject = TestCA._make_x509_name(org_name=name, parent=issuer.subject if issuer else None)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer_subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_ca_usages(csr)
        cert = csr.sign(private_key=issuer_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create server credentials for ``domains``, signed by ``issuer``."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_client_credentials(name: str,
                                 issuer: Credentials, email: Optional[str],
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create client credentials (optionally for ``email``), signed by ``issuer``."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
552