#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#***************************************************************************
#                                  _   _ ____  _
#  Project                     ___| | | |  _ \| |
#                             / __| | | | |_) | |
#                            | (__| |_| |  _ <| |___
#                             \___|\___/|_| \_\_____|
#
# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
#
# This software is licensed as described in the file COPYING, which
# you should have received as part of this distribution. The terms
# are also available at https://curl.se/docs/copyright.html.
#
# You may opt to use, copy, modify, merge, publish, distribute and/or sell
# copies of the Software, and permit persons to whom the Software is
# furnished to do so, under the terms of the COPYING file.
#
# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
# KIND, either express or implied.
#
# SPDX-License-Identifier: curl
#
###########################################################################
#
"""Helpers to build a small test PKI.

Creates root/intermediate CAs, server and client certificates, and
persists them as PEM files in a directory-backed :class:`CertStore` so they
can be reused across runs.
"""
import ipaddress
import os
import re
from datetime import timedelta, datetime, timezone
from typing import List, Any, Optional, Tuple

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
from cryptography.x509 import ExtendedKeyUsageOID, NameOID


# Named EC curves usable as key types, keyed by upper-cased curve name
# (e.g. "SECP256R1"). Values are curve *classes*; _private_key() instantiates
# them before generating a key.
EC_SUPPORTED = {curve.name.upper(): curve for curve in [
    ec.SECP192R1,
    ec.SECP224R1,
    ec.SECP256R1,
    ec.SECP384R1,
]}


def _private_key(key_type):
    """Generate a new private key described by ``key_type``.

    Accepted values:
      * an ``int``, or a string such as ``"2048"`` / ``"rsa2048"``
        (case-insensitive): an RSA key of that size, public exponent 65537;
      * a curve name present in :data:`EC_SUPPORTED` (case-insensitive,
        e.g. ``"secp256r1"``), an EC curve class, or an
        :class:`ec.EllipticCurve` instance: an EC key on that curve.

    :raises ValueError: if ``key_type`` is none of the above.
    """
    if isinstance(key_type, str):
        key_type = key_type.upper()
        m = re.match(r'^(RSA)?(\d+)$', key_type)
        if m:
            key_type = int(m.group(2))

    if isinstance(key_type, int):
        return rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_type,
            backend=default_backend()
        )

    curve = key_type
    if isinstance(curve, str):
        if curve not in EC_SUPPORTED:
            raise ValueError(f"unsupported key type: {key_type!r}")
        curve = EC_SUPPORTED[curve]
    # ec.generate_private_key() expects a curve *instance*; passing the class
    # is deprecated by cryptography.
    if isinstance(curve, type) and issubclass(curve, ec.EllipticCurve):
        curve = curve()
    if not isinstance(curve, ec.EllipticCurve):
        raise ValueError(f"unsupported key type: {key_type!r}")
    return ec.generate_private_key(
        curve=curve,
        backend=default_backend()
    )


def _cert_validity_utc(cert: x509.Certificate) -> Tuple[datetime, datetime]:
    """Return ``(not_valid_before, not_valid_after)`` of ``cert`` as aware UTC datetimes."""
    try:
        return cert.not_valid_before_utc, cert.not_valid_after_utc
    except AttributeError:
        # cryptography < 42 only offers naive datetimes, documented to be in UTC.
        return (cert.not_valid_before.replace(tzinfo=timezone.utc),
                cert.not_valid_after.replace(tzinfo=timezone.utc))


class CertificateSpec:
    """Description of a certificate to issue.

    The kind of certificate is derived from the fields (see :attr:`type`):
    a spec with ``domains`` is a server certificate, one with ``client=True``
    is a client certificate, and one with only a ``name`` is a CA. A CA spec
    may carry ``sub_specs`` that are issued by the resulting CA.
    """

    def __init__(self, name: Optional[str] = None,
                 domains: Optional[List[str]] = None,
                 email: Optional[str] = None,
                 key_type: Optional[str] = None,
                 single_file: bool = False,
                 valid_from: timedelta = timedelta(days=-1),
                 valid_to: timedelta = timedelta(days=89),
                 client: bool = False,
                 check_valid: bool = True,
                 sub_specs: Optional[List['CertificateSpec']] = None):
        self._name = name
        self.domains = domains
        self.client = client
        self.email = email
        self.key_type = key_type
        self.single_file = single_file
        # validity window, relative to the time of issuance
        self.valid_from = valid_from
        self.valid_to = valid_to
        self.sub_specs = sub_specs
        # re-issue a stored certificate if it is outside its validity window
        self.check_valid = check_valid

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(name={self.name!r}, "
                f"domains={self.domains!r}, client={self.client!r}, "
                f"email={self.email!r}, key_type={self.key_type!r})")

    @property
    def name(self) -> Optional[str]:
        """The explicit name, else the first domain, else ``None``."""
        if self._name:
            return self._name
        if self.domains:
            return self.domains[0]
        return None

    @property
    def type(self) -> Optional[str]:
        """``"server"``, ``"client"``, ``"ca"``, or ``None`` if unrecognized."""
        if self.domains:
            return "server"
        if self.client:
            return "client"
        if self.name:
            return "ca"
        return None


class Credentials:
    """A certificate with its private key, its issuer and its on-disk files."""

    def __init__(self,
                 name: str,
                 cert: Any,
                 pkey: Any,
                 issuer: Optional['Credentials'] = None):
        self._name = name
        self._cert = cert
        self._pkey = pkey
        self._issuer = issuer
        self._cert_file = None
        self._pkey_file = None
        self._store = None
        self._combined_file = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def subject(self) -> x509.Name:
        return self._cert.subject

    @property
    def key_type(self):
        """``"rsa<bits>"`` for RSA keys, the curve name for EC keys.

        :raises Exception: for any other key type.
        """
        if isinstance(self._pkey, RSAPrivateKey):
            return f"rsa{self._pkey.key_size}"
        elif isinstance(self._pkey, EllipticCurvePrivateKey):
            return f"{self._pkey.curve.name}"
        else:
            raise Exception(f"unknown key type: {self._pkey}")

    @property
    def private_key(self) -> Any:
        return self._pkey

    @property
    def certificate(self) -> Any:
        return self._cert

    @property
    def cert_pem(self) -> bytes:
        return self._cert.public_bytes(Encoding.PEM)

    @property
    def pkey_pem(self) -> bytes:
        """Unencrypted PEM of the private key.

        RSA keys use the traditional OpenSSL format, other keys PKCS#8.
        """
        return self._pkey.private_bytes(
            Encoding.PEM,
            PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8,
            NoEncryption())

    @property
    def issuer(self) -> Optional['Credentials']:
        return self._issuer

    def set_store(self, store: 'CertStore'):
        self._store = store

    def set_files(self, cert_file: str, pkey_file: Optional[str] = None,
                  combined_file: Optional[str] = None):
        self._cert_file = cert_file
        self._pkey_file = pkey_file
        self._combined_file = combined_file

    @property
    def cert_file(self) -> Optional[str]:
        return self._cert_file

    @property
    def pkey_file(self) -> Optional[str]:
        return self._pkey_file

    @property
    def combined_file(self) -> Optional[str]:
        return self._combined_file

    def get_first(self, name) -> Optional['Credentials']:
        """The first credentials registered under ``name`` in our store, if any."""
        creds = self.get_credentials_for_name(name)
        return creds[0] if creds else None

    def get_credentials_for_name(self, name) -> List['Credentials']:
        return self._store.get_credentials_for_name(name) if self._store is not None else []

    def issue_certs(self, specs: List[CertificateSpec],
                    chain: Optional[List['Credentials']] = None) -> List['Credentials']:
        return [self.issue_cert(spec=spec, chain=chain) for spec in specs]

    def issue_cert(self, spec: CertificateSpec,
                   chain: Optional[List['Credentials']] = None) -> 'Credentials':
        """Issue (or load from our store) a certificate for ``spec`` signed by us.

        Without an explicit ``spec.key_type`` the key type of this issuer is
        used. Stored credentials are reused unless ``spec.check_valid`` is set
        and they are outside their validity window. Sub-specs are issued
        recursively by the new credentials, stored in a sub-directory named
        after them.
        """
        key_type = spec.key_type if spec.key_type else self.key_type
        creds = None
        if self._store is not None:
            creds = self._store.load_credentials(
                name=spec.name, key_type=key_type, single_file=spec.single_file,
                issuer=self, check_valid=spec.check_valid)
        if creds is None:
            creds = TestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
                                              valid_from=spec.valid_from, valid_to=spec.valid_to)
            if self._store is not None:
                self._store.save(creds, single_file=spec.single_file)
                if spec.type == "ca":
                    self._store.save_chain(creds, "ca", with_root=True)

        if spec.sub_specs:
            if self._store is not None:
                sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
                creds.set_store(sub_store)
            subchain = chain.copy() if chain else []
            subchain.append(self)
            creds.issue_certs(spec.sub_specs, chain=subchain)
        return creds


class CertStore:
    """A directory of PEM files holding credentials, plus an in-memory index by name."""

    def __init__(self, fpath: str):
        self._store_dir = fpath
        # exist_ok avoids a race between an existence check and the creation
        os.makedirs(self._store_dir, exist_ok=True)
        self._creds_by_name = {}

    @property
    def path(self) -> str:
        return self._store_dir

    def save(self, creds: Credentials, name: Optional[str] = None,
             chain: Optional[List[Credentials]] = None,
             single_file: bool = False) -> None:
        """Write ``creds`` to disk and register them under ``name``.

        Writes the certificate (followed by the ``chain`` certificates) to the
        cert file and the key to the pkey file; with ``single_file`` the key
        is appended to the cert file instead. A combined file with
        certificate, chain and key is always written.
        """
        name = name if name is not None else creds.name
        cert_file = self.get_cert_file(name=name, key_type=creds.key_type)
        pkey_file = None if single_file else self.get_pkey_file(name=name, key_type=creds.key_type)
        comb_file = self.get_combined_file(name=name, key_type=creds.key_type)
        chain_pem = b"".join(c.cert_pem for c in chain) if chain else b""

        with open(cert_file, "wb") as fd:
            fd.write(creds.cert_pem)
            fd.write(chain_pem)
            if pkey_file is None:
                fd.write(creds.pkey_pem)
        if pkey_file is not None:
            with open(pkey_file, "wb") as fd:
                fd.write(creds.pkey_pem)
        with open(comb_file, "wb") as fd:
            fd.write(creds.cert_pem)
            fd.write(chain_pem)
            fd.write(creds.pkey_pem)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)

    def save_chain(self, creds: Credentials, infix: str, with_root=False):
        """Write ``creds`` and its issuer certificates to ``<name>-<infix>.pem``.

        The topmost issuer (the root) is left out unless ``with_root`` is set
        or ``creds`` has no issuer at all.
        """
        name = creds.name
        chain = [creds]
        while creds.issuer is not None:
            creds = creds.issuer
            chain.append(creds)
        if not with_root and len(chain) > 1:
            chain = chain[:-1]
        chain_file = os.path.join(self._store_dir, f'{name}-{infix}.pem')
        with open(chain_file, "wb") as fd:
            for c in chain:
                fd.write(c.cert_pem)

    def _add_credentials(self, name: str, creds: Credentials):
        self._creds_by_name.setdefault(name, []).append(creds)

    def get_credentials_for_name(self, name) -> List[Credentials]:
        return self._creds_by_name.get(name, [])

    @staticmethod
    def _key_infix(key_type) -> str:
        # file name part distinguishing credentials of different key types
        return f".{key_type}" if key_type is not None else ""

    def get_cert_file(self, name: str, key_type=None) -> str:
        return os.path.join(self._store_dir, f'{name}{self._key_infix(key_type)}.cert.pem')

    def get_pkey_file(self, name: str, key_type=None) -> str:
        return os.path.join(self._store_dir, f'{name}{self._key_infix(key_type)}.pkey.pem')

    def get_combined_file(self, name: str, key_type=None) -> str:
        # NOTE: key_type is deliberately not part of this file name, so
        # credentials of different key types under one name share it.
        return os.path.join(self._store_dir, f'{name}.pem')

    def load_pem_cert(self, fpath: str) -> x509.Certificate:
        """Load the first certificate from a PEM file."""
        with open(fpath, "rb") as fd:
            return x509.load_pem_x509_certificate(fd.read())

    def load_pem_pkey(self, fpath: str):
        """Load an unencrypted private key from a PEM file."""
        with open(fpath, "rb") as fd:
            return load_pem_private_key(fd.read(), password=None)

    def load_credentials(self, name: str, key_type=None,
                         single_file: bool = False,
                         issuer: Optional[Credentials] = None,
                         check_valid: bool = False):
        """Load previously saved credentials for ``name`` and ``key_type``.

        :returns: the credentials (registered in this store), or ``None`` if
            the files are missing or, with ``check_valid``, the certificate is
            not valid right now.
        """
        cert_file = self.get_cert_file(name=name, key_type=key_type)
        pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
        comb_file = self.get_combined_file(name=name, key_type=key_type)
        if not (os.path.isfile(cert_file) and os.path.isfile(pkey_file)):
            return None
        cert = self.load_pem_cert(cert_file)
        pkey = self.load_pem_pkey(pkey_file)
        if check_valid:
            not_before, not_after = _cert_validity_utc(cert)
            now = datetime.now(tz=timezone.utc)
            if not_after < now or not_before > now:
                return None
        creds = Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
        creds.set_store(self)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)
        return creds


class TestCA:
    """Factory for self-signed roots and certificates issued by them."""

    @classmethod
    def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
        """Load the root CA from ``store_dir`` (saved as "ca"), or create and save it."""
        store = CertStore(fpath=store_dir)
        creds = store.load_credentials(name="ca", key_type=key_type, issuer=None)
        if creds is None:
            creds = TestCA._make_ca_credentials(name=name, key_type=key_type)
            store.save(creds, name="ca")
            creds.set_store(store)
        return creds

    @staticmethod
    def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
                           valid_from: timedelta = timedelta(days=-1),
                           valid_to: timedelta = timedelta(days=89),
                           ) -> Credentials:
        """
        Create credentials signed by ``issuer`` as described by ``spec``.

        Server credentials are made when ``spec.domains`` is set, client
        credentials when ``spec.client`` is set, and an intermediate CA when
        only ``spec.name`` is set.

        :returns: the new credentials (not yet saved to any store)
        :raises Exception: if the spec matches none of these kinds
        """
        if spec.domains:
            creds = TestCA._make_server_credentials(name=spec.name, domains=spec.domains,
                                                    issuer=issuer, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.client:
            creds = TestCA._make_client_credentials(name=spec.name, issuer=issuer,
                                                    email=spec.email, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.name:
            creds = TestCA._make_ca_credentials(name=spec.name, issuer=issuer,
                                                valid_from=valid_from, valid_to=valid_to,
                                                key_type=key_type)
        else:
            raise Exception(f"unrecognized certificate specification: {spec}")
        return creds

    @staticmethod
    def _make_x509_name(org_name: Optional[str] = None, common_name: Optional[str] = None,
                        parent: Optional[x509.Name] = None) -> x509.Name:
        """Build a subject name below ``parent``.

        ``org_name`` becomes an O (top level) or OU (below a parent) attribute;
        otherwise ``common_name`` becomes a CN. The parent's attributes follow.
        """
        name_pieces = []
        if org_name:
            oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
            name_pieces.append(x509.NameAttribute(oid, org_name))
        elif common_name:
            name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
        if parent:
            name_pieces.extend(list(parent))
        return x509.Name(name_pieces)

    @staticmethod
    def _make_csr(
            subject: x509.Name,
            pkey: Any,
            issuer_subject: Optional[x509.Name],
            valid_from_delta: Optional[timedelta] = None,
            valid_until_delta: Optional[timedelta] = None
    ):
        """Start a certificate builder for ``pkey`` with subject, issuer and validity.

        Without ``issuer_subject`` the certificate is self-issued. Validity
        bounds are "now" (UTC) shifted by the given deltas.
        """
        pubkey = pkey.public_key()
        issuer_subject = issuer_subject if issuer_subject is not None else subject

        # cryptography treats naive datetimes as UTC, so use an aware UTC "now"
        # instead of naive local time.
        now = datetime.now(tz=timezone.utc)
        valid_from = now + valid_from_delta if valid_from_delta is not None else now
        valid_until = now + valid_until_delta if valid_until_delta is not None else now

        return (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer_subject)
            .public_key(pubkey)
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            .serial_number(x509.random_serial_number())
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(pubkey),
                critical=False,
            )
        )

    @staticmethod
    def _add_ca_usages(csr: Any) -> Any:
        """Add CA basic constraints, cert/CRL signing key usage and EKUs."""
        return csr.add_extension(
            x509.BasicConstraints(ca=True, path_length=9),
            critical=True,
        ).add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False),
            critical=True
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
                ExtendedKeyUsageOID.SERVER_AUTH,
                ExtendedKeyUsageOID.CODE_SIGNING,
            ]),
            critical=True
        )

    @staticmethod
    def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
        """Add server leaf extensions; ``domains`` become IP or DNS SAN entries."""
        names = []
        for name in domains:
            try:
                names.append(x509.IPAddress(ipaddress.ip_address(name)))
            except ValueError:
                # not an IP address literal: treat it as a DNS name
                names.append(x509.DNSName(name))

        return csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        ).add_extension(
            x509.SubjectAlternativeName(names), critical=True,
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.SERVER_AUTH,
            ]),
            critical=False
        )

    @staticmethod
    def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: Optional[str] = None) -> Any:
        """Add client leaf extensions, with an email SAN when ``rfc82name`` is given."""
        cert = csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        )
        # CertificateBuilder is immutable: add_extension() returns a new
        # builder, so the result must be kept or the extension is lost.
        if rfc82name:
            cert = cert.add_extension(
                x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
                critical=True,
            )
        cert = cert.add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
            ]),
            critical=True
        )
        return cert

    @staticmethod
    def _make_ca_credentials(name, key_type: Any,
                             issuer: Optional[Credentials] = None,
                             valid_from: timedelta = timedelta(days=-1),
                             valid_to: timedelta = timedelta(days=89),
                             ) -> Credentials:
        """Create CA credentials; self-signed when ``issuer`` is ``None``."""
        pkey = _private_key(key_type=key_type)
        if issuer is not None:
            issuer_subject = issuer.certificate.subject
            issuer_key = issuer.private_key
        else:
            issuer_subject = None
            issuer_key = pkey
        subject = TestCA._make_x509_name(org_name=name, parent=issuer.subject if issuer else None)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer_subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_ca_usages(csr)
        cert = csr.sign(private_key=issuer_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create server credentials for ``domains`` signed by ``issuer``."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_client_credentials(name: str,
                                 issuer: Credentials, email: Optional[str],
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create client credentials (optionally with an email SAN) signed by ``issuer``."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)