#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#***************************************************************************
#                                  _   _ ____  _
#  Project                     ___| | | |  _ \| |
#                             / __| | | | |_) | |
#                            | (__| |_| |  _ <| |___
#                             \___|\___/|_| \_\_____|
#
# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
#
# This software is licensed as described in the file COPYING, which
# you should have received as part of this distribution. The terms
# are also available at https://curl.se/docs/copyright.html.
#
# You may opt to use, copy, modify, merge, publish, distribute and/or sell
# copies of the Software, and permit persons to whom the Software is
# furnished to do so, under the terms of the COPYING file.
#
# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
# KIND, either express or implied.
#
# SPDX-License-Identifier: curl
#
###########################################################################
#
"""Create, store and reload X.509 test credentials (CAs, servers, clients)."""
import ipaddress
import os
import re
from datetime import timedelta, datetime, timezone
from typing import List, Any, Optional

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
from cryptography.x509 import ExtendedKeyUsageOID, NameOID


# Supported elliptic curves, keyed by the upper-cased curve name
# (e.g. "SECP256R1"). The values are curve *classes*, not instances.
EC_SUPPORTED = {curve.name.upper(): curve for curve in [
    ec.SECP192R1,
    ec.SECP224R1,
    ec.SECP256R1,
    ec.SECP384R1,
]}


def _private_key(key_type):
    """Generate a new private key described by *key_type*.

    :param key_type: an RSA key size as int or as a string like "2048" /
        "rsa2048" (case-insensitive), the name of a curve in EC_SUPPORTED
        (case-insensitive, e.g. "secp256r1"), or an EllipticCurve object
        which is passed to cryptography unchanged.
    :returns: an RSA or EC private key
    :raises ValueError: for a string that is neither an RSA size nor a
        supported curve name
    """
    if isinstance(key_type, str):
        key_type = key_type.upper()
        m = re.match(r'^(RSA)?(\d+)$', key_type)
        if m:
            key_type = int(m.group(2))

    if isinstance(key_type, int):
        return rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_type,
            backend=default_backend()
        )
    if isinstance(key_type, str):
        try:
            # EC_SUPPORTED stores classes; cryptography wants a curve instance
            key_type = EC_SUPPORTED[key_type]()
        except KeyError:
            raise ValueError(f"unsupported key type: {key_type}") from None
    return ec.generate_private_key(
        curve=key_type,
        backend=default_backend()
    )


class CertificateSpec:
    """Description of a certificate to issue.

    A spec with `domains` describes a server certificate, one with
    `client=True` a client certificate, and one with only a `name` an
    (intermediate) CA. `sub_specs` are issued by the certificate created
    from this spec.
    """

    def __init__(self, name: Optional[str] = None,
                 domains: Optional[List[str]] = None,
                 email: Optional[str] = None,
                 key_type: Optional[str] = None,
                 single_file: bool = False,
                 valid_from: timedelta = timedelta(days=-1),
                 valid_to: timedelta = timedelta(days=89),
                 client: bool = False,
                 check_valid: bool = True,
                 sub_specs: Optional[List['CertificateSpec']] = None):
        self._name = name
        self.domains = domains
        self.client = client
        self.email = email
        self.key_type = key_type
        self.single_file = single_file
        self.valid_from = valid_from
        self.valid_to = valid_to
        self.sub_specs = sub_specs
        self.check_valid = check_valid

    @property
    def name(self) -> Optional[str]:
        """The explicit name, else the first domain, else None."""
        if self._name:
            return self._name
        elif self.domains:
            return self.domains[0]
        return None

    @property
    def type(self) -> Optional[str]:
        """One of "server", "client", "ca", or None if undeterminable."""
        if self.domains:
            return "server"
        elif self.client:
            return "client"
        elif self.name:
            return "ca"
        return None

    def __repr__(self) -> str:
        return (f"CertificateSpec(name={self.name!r}, type={self.type!r}, "
                f"domains={self.domains!r}, key_type={self.key_type!r})")


class Credentials:
    """A certificate, its private key, its issuer and where it is stored."""

    def __init__(self,
                 name: str,
                 cert: Any,
                 pkey: Any,
                 issuer: Optional['Credentials'] = None):
        self._name = name
        self._cert = cert
        self._pkey = pkey
        self._issuer = issuer
        self._cert_file = None
        self._pkey_file = None
        # initialised here so `combined_file` works before set_files()
        self._combined_file = None
        self._store = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def subject(self) -> x509.Name:
        return self._cert.subject

    @property
    def key_type(self) -> str:
        """"rsa<bits>" for RSA keys, the curve name for EC keys."""
        if isinstance(self._pkey, RSAPrivateKey):
            return f"rsa{self._pkey.key_size}"
        elif isinstance(self._pkey, EllipticCurvePrivateKey):
            return f"{self._pkey.curve.name}"
        else:
            raise TypeError(f"unknown key type: {self._pkey}")

    @property
    def private_key(self) -> Any:
        return self._pkey

    @property
    def certificate(self) -> Any:
        return self._cert

    @property
    def cert_pem(self) -> bytes:
        return self._cert.public_bytes(Encoding.PEM)

    @property
    def pkey_pem(self) -> bytes:
        """Unencrypted PEM of the key: traditional format for RSA, PKCS#8 otherwise."""
        return self._pkey.private_bytes(
            Encoding.PEM,
            PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8,
            NoEncryption())

    @property
    def issuer(self) -> Optional['Credentials']:
        return self._issuer

    def set_store(self, store: 'CertStore'):
        self._store = store

    def set_files(self, cert_file: str, pkey_file: Optional[str] = None,
                  combined_file: Optional[str] = None):
        self._cert_file = cert_file
        self._pkey_file = pkey_file
        self._combined_file = combined_file

    @property
    def cert_file(self) -> Optional[str]:
        return self._cert_file

    @property
    def pkey_file(self) -> Optional[str]:
        return self._pkey_file

    @property
    def combined_file(self) -> Optional[str]:
        return self._combined_file

    def get_first(self, name) -> Optional['Credentials']:
        """First credentials registered under *name* in our store, or None."""
        creds = self.get_credentials_for_name(name)
        return creds[0] if creds else None

    def get_credentials_for_name(self, name) -> List['Credentials']:
        """All credentials registered under *name* in our store ([] without store)."""
        return self._store.get_credentials_for_name(name) if self._store else []

    def issue_certs(self, specs: List[CertificateSpec],
                    chain: Optional[List['Credentials']] = None) -> List['Credentials']:
        """Issue one certificate per spec, see issue_cert()."""
        return [self.issue_cert(spec=spec, chain=chain) for spec in specs]

    def issue_cert(self, spec: CertificateSpec,
                   chain: Optional[List['Credentials']] = None) -> 'Credentials':
        """Issue (or reload from our store) a certificate for *spec*.

        Without an explicit spec key type, our own key type is used. When
        a store is attached, valid stored credentials are reused, otherwise
        new ones are created and saved (CAs additionally get a
        "<name>-ca.pem" chain file). Sub-specs are issued by the new
        credentials, stored in a sub-directory named after them.
        """
        key_type = spec.key_type if spec.key_type else self.key_type
        creds = None
        if self._store:
            creds = self._store.load_credentials(
                name=spec.name, key_type=key_type, single_file=spec.single_file,
                issuer=self, check_valid=spec.check_valid)
        if creds is None:
            creds = TestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
                                              valid_from=spec.valid_from, valid_to=spec.valid_to)
            if self._store:
                self._store.save(creds, single_file=spec.single_file)
                # match load_credentials(), which attaches the store as well
                creds.set_store(self._store)
                if spec.type == "ca":
                    self._store.save_chain(creds, "ca", with_root=True)

        if spec.sub_specs:
            if self._store:
                sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
                creds.set_store(sub_store)
            subchain = chain.copy() if chain else []
            subchain.append(self)
            creds.issue_certs(spec.sub_specs, chain=subchain)
        return creds


class CertStore:
    """A directory of PEM files plus an in-memory index of credentials by name."""

    def __init__(self, fpath: str):
        self._store_dir = fpath
        os.makedirs(self._store_dir, exist_ok=True)
        self._creds_by_name = {}

    @property
    def path(self) -> str:
        return self._store_dir

    def save(self, creds: Credentials, name: Optional[str] = None,
             chain: Optional[List[Credentials]] = None,
             single_file: bool = False) -> None:
        """Write *creds* as PEM files and register them under *name*.

        Writes the certificate (followed by *chain*) to the cert file, the
        key to the pkey file, and certificate + chain + key to the combined
        file. With *single_file*, the key is appended to the cert file and
        no separate pkey file is written.
        """
        name = name if name is not None else creds.name
        cert_file = self.get_cert_file(name=name, key_type=creds.key_type)
        pkey_file = None if single_file else self.get_pkey_file(name=name, key_type=creds.key_type)
        comb_file = self.get_combined_file(name=name, key_type=creds.key_type)
        chain_pem = b"".join(c.cert_pem for c in (chain or []))
        with open(cert_file, "wb") as fd:
            fd.write(creds.cert_pem)
            fd.write(chain_pem)
            if pkey_file is None:
                fd.write(creds.pkey_pem)
        if pkey_file is not None:
            with open(pkey_file, "wb") as fd:
                fd.write(creds.pkey_pem)
        with open(comb_file, "wb") as fd:
            fd.write(creds.cert_pem)
            fd.write(chain_pem)
            fd.write(creds.pkey_pem)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)

    def save_chain(self, creds: Credentials, infix: str, with_root=False):
        """Write *creds* followed by its issuers to "<name>-<infix>.pem".

        The top-most issuer is left out unless *with_root* is set (or the
        chain consists of *creds* alone).
        """
        name = creds.name
        chain = [creds]
        while chain[-1].issuer is not None:
            chain.append(chain[-1].issuer)
        if not with_root and len(chain) > 1:
            chain = chain[:-1]
        chain_file = os.path.join(self._store_dir, f'{name}-{infix}.pem')
        with open(chain_file, "wb") as fd:
            for c in chain:
                fd.write(c.cert_pem)

    def _add_credentials(self, name: str, creds: Credentials):
        self._creds_by_name.setdefault(name, []).append(creds)

    def get_credentials_for_name(self, name) -> List[Credentials]:
        return self._creds_by_name.get(name, [])

    def get_cert_file(self, name: str, key_type=None) -> str:
        key_infix = f".{key_type}" if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem')

    def get_pkey_file(self, name: str, key_type=None) -> str:
        key_infix = f".{key_type}" if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem')

    def get_combined_file(self, name: str, key_type=None) -> str:
        # NOTE: key_type is accepted for symmetry but is not part of the name
        return os.path.join(self._store_dir, f'{name}.pem')

    def load_pem_cert(self, fpath: str) -> x509.Certificate:
        with open(fpath, "rb") as fd:
            return x509.load_pem_x509_certificate(fd.read())

    def load_pem_pkey(self, fpath: str):
        with open(fpath, "rb") as fd:
            return load_pem_private_key(fd.read(), password=None)

    @staticmethod
    def _is_currently_valid(cert: x509.Certificate) -> bool:
        """True if the current time lies within *cert*'s validity period."""
        now = datetime.now(tz=timezone.utc)
        try:
            not_before = cert.not_valid_before_utc
            not_after = cert.not_valid_after_utc
        except AttributeError:
            # older cryptography: only naive datetimes, which are in UTC
            now = now.replace(tzinfo=None)
            not_before = cert.not_valid_before
            not_after = cert.not_valid_after
        return not_before <= now <= not_after

    def load_credentials(self, name: str, key_type=None,
                         single_file: bool = False,
                         issuer: Optional[Credentials] = None,
                         check_valid: bool = False) -> Optional[Credentials]:
        """Load stored credentials for *name*/*key_type*, or return None.

        None is returned when the files do not exist or, with
        *check_valid*, when the certificate is not valid right now. Loaded
        credentials are attached to and registered in this store.
        """
        cert_file = self.get_cert_file(name=name, key_type=key_type)
        pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
        comb_file = self.get_combined_file(name=name, key_type=key_type)
        if not (os.path.isfile(cert_file) and os.path.isfile(pkey_file)):
            return None
        cert = self.load_pem_cert(cert_file)
        pkey = self.load_pem_pkey(pkey_file)
        if check_valid and not self._is_currently_valid(cert):
            return None
        creds = Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
        creds.set_store(self)
        creds.set_files(cert_file, pkey_file, comb_file)
        self._add_credentials(name, creds)
        return creds


class TestCA:
    """Factory for a self-signed root CA and the certificates it issues."""

    @classmethod
    def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
        """Load the root CA stored as "ca" in *store_dir*, or create and save it."""
        store = CertStore(fpath=store_dir)
        creds = store.load_credentials(name="ca", key_type=key_type, issuer=None)
        if creds is None:
            creds = TestCA._make_ca_credentials(name=name, key_type=key_type)
            store.save(creds, name="ca")
            creds.set_store(store)
        return creds

    @staticmethod
    def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
                           valid_from: timedelta = timedelta(days=-1),
                           valid_to: timedelta = timedelta(days=89),
                           ) -> Credentials:
        """Create credentials for *spec*, signed by *issuer*.

        Server, client or CA credentials are created depending on whether
        the spec has domains, is a client spec, or only has a name.
        :returns: the new Credentials (nothing is written to disk)
        :raises Exception: if the spec matches none of these kinds
        """
        if spec.domains:
            creds = TestCA._make_server_credentials(name=spec.name, domains=spec.domains,
                                                    issuer=issuer, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.client:
            creds = TestCA._make_client_credentials(name=spec.name, issuer=issuer,
                                                    email=spec.email, valid_from=valid_from,
                                                    valid_to=valid_to, key_type=key_type)
        elif spec.name:
            creds = TestCA._make_ca_credentials(name=spec.name, issuer=issuer,
                                                valid_from=valid_from, valid_to=valid_to,
                                                key_type=key_type)
        else:
            raise Exception(f"unrecognized certificate specification: {spec}")
        return creds

    @staticmethod
    def _make_x509_name(org_name: Optional[str] = None, common_name: Optional[str] = None,
                        parent: Optional[x509.Name] = None) -> x509.Name:
        """Build a name: an O (OU below a parent) or a CN, then the parent's attributes.

        *org_name* takes precedence; *common_name* is only used without it.
        """
        name_pieces = []
        if org_name:
            oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
            name_pieces.append(x509.NameAttribute(oid, org_name))
        elif common_name:
            name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
        if parent:
            name_pieces.extend(parent)
        return x509.Name(name_pieces)

    @staticmethod
    def _make_csr(
            subject: x509.Name,
            pkey: Any,
            issuer_subject: Optional[x509.Name],
            valid_from_delta: Optional[timedelta] = None,
            valid_until_delta: Optional[timedelta] = None
    ):
        """Start a CertificateBuilder for *subject* and *pkey*'s public key.

        The issuer name defaults to *subject* (self-signed). Validity is
        now + *valid_from_delta* until now + *valid_until_delta*; a missing
        delta means "now". A random serial and a SubjectKeyIdentifier are set.
        """
        pubkey = pkey.public_key()
        issuer_subject = issuer_subject if issuer_subject is not None else subject

        # cryptography treats naive datetimes as UTC; a naive local
        # datetime.now() would be shifted by the local UTC offset.
        now = datetime.now(tz=timezone.utc)
        valid_from = now + valid_from_delta if valid_from_delta is not None else now
        valid_until = now + valid_until_delta if valid_until_delta is not None else now

        return (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer_subject)
            .public_key(pubkey)
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            .serial_number(x509.random_serial_number())
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(pubkey),
                critical=False,
            )
        )

    @staticmethod
    def _add_ca_usages(csr: Any) -> Any:
        """Add CA basic constraints, cert/CRL signing and CA extended key usages."""
        return csr.add_extension(
            x509.BasicConstraints(ca=True, path_length=9),
            critical=True,
        ).add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False),
            critical=True
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
                ExtendedKeyUsageOID.SERVER_AUTH,
                ExtendedKeyUsageOID.CODE_SIGNING,
            ]),
            critical=True
        )

    @staticmethod
    def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
        """Add server-certificate extensions; *domains* become IP or DNS SANs."""
        names = []
        for name in domains:
            try:
                names.append(x509.IPAddress(ipaddress.ip_address(name)))
            except ValueError:  # not an IP address literal
                names.append(x509.DNSName(name))

        return csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        ).add_extension(
            x509.SubjectAlternativeName(names), critical=True,
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.SERVER_AUTH,
            ]),
            critical=False
        )

    @staticmethod
    def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: Optional[str] = None) -> Any:
        """Add client-certificate extensions to *csr*.

        *rfc82name* (an RFC 822 e-mail address; the spelling is kept for
        callers) becomes a subjectAltName when given.
        """
        # The builder is immutable: add_extension() returns a new builder,
        # so every result has to be kept.
        builder = csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        )
        if rfc82name:
            builder = builder.add_extension(
                x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
                critical=True,
            )
        return builder.add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
            ]),
            critical=True
        )

    @staticmethod
    def _make_ca_credentials(name, key_type: Any,
                             issuer: Optional[Credentials] = None,
                             valid_from: timedelta = timedelta(days=-1),
                             valid_to: timedelta = timedelta(days=89),
                             ) -> Credentials:
        """Create CA credentials, self-signed when *issuer* is None."""
        pkey = _private_key(key_type=key_type)
        if issuer is not None:
            issuer_subject = issuer.certificate.subject
            issuer_key = issuer.private_key
        else:
            issuer_subject = None
            issuer_key = pkey
        subject = TestCA._make_x509_name(org_name=name, parent=issuer.subject if issuer else None)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer_subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_ca_usages(csr)
        cert = csr.sign(private_key=issuer_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create server credentials for *domains*, signed by *issuer*."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_client_credentials(name: str,
                                 issuer: Credentials, email: Optional[str],
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create client credentials (optionally with an e-mail SAN), signed by *issuer*."""
        pkey = _private_key(key_type=key_type)
        subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = TestCA._make_csr(subject=subject,
                               issuer_subject=issuer.certificate.subject, pkey=pkey,
                               valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = TestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)