1 /* 2 * Copyright 2021-2022 The OpenSSL Project Authors. All Rights Reserved. 3 * 4 * Licensed under the Apache License 2.0 (the "License"). You may not use 5 * this file except in compliance with the License. You can obtain a copy 6 * in the file LICENSE in the source distribution or at 7 * https://www.openssl.org/source/license.html 8 */ 9 10 #ifndef OSSL_INTERNAL_SAFE_MATH_H 11 # define OSSL_INTERNAL_SAFE_MATH_H 12 # pragma once 13 14 # include <openssl/e_os2.h> /* For 'ossl_inline' */ 15 16 # ifndef OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING 17 # ifdef __has_builtin 18 # define has(func) __has_builtin(func) 19 # elif defined(__GNUC__) 20 # if __GNUC__ > 5 21 # define has(func) 1 22 # endif 23 # endif 24 # endif /* OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING */ 25 26 # ifndef has 27 # define has(func) 0 28 # endif 29 30 /* 31 * Safe addition helpers 32 */ 33 # if has(__builtin_add_overflow) 34 # define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \ 35 static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \ 36 type b, \ 37 int *err) \ 38 { \ 39 type r; \ 40 \ 41 if (!__builtin_add_overflow(a, b, &r)) \ 42 return r; \ 43 *err |= 1; \ 44 return a < 0 ? min : max; \ 45 } 46 47 # define OSSL_SAFE_MATH_ADDU(type_name, type, max) \ 48 static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \ 49 type b, \ 50 int *err) \ 51 { \ 52 type r; \ 53 \ 54 if (!__builtin_add_overflow(a, b, &r)) \ 55 return r; \ 56 *err |= 1; \ 57 return a + b; \ 58 } 59 60 # else /* has(__builtin_add_overflow) */ 61 # define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \ 62 static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \ 63 type b, \ 64 int *err) \ 65 { \ 66 if ((a < 0) ^ (b < 0) \ 67 || (a > 0 && b <= max - a) \ 68 || (a < 0 && b >= min - a) \ 69 || a == 0) \ 70 return a + b; \ 71 *err |= 1; \ 72 return a < 0 ? min : max; \ 73 } 74 75 # define OSSL_SAFE_MATH_ADDU(type_name, type, max) \ 76 static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \ 77 type b, \ 78 int *err) \ 79 { \ 80 if (b > max - a) \ 81 *err |= 1; \ 82 return a + b; \ 83 } 84 # endif /* has(__builtin_add_overflow) */ 85 86 /* 87 * Safe subtraction helpers 88 */ 89 # if has(__builtin_sub_overflow) 90 # define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \ 91 static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \ 92 type b, \ 93 int *err) \ 94 { \ 95 type r; \ 96 \ 97 if (!__builtin_sub_overflow(a, b, &r)) \ 98 return r; \ 99 *err |= 1; \ 100 return a < 0 ? min : max; \ 101 } 102 103 # else /* has(__builtin_sub_overflow) */ 104 # define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \ 105 static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \ 106 type b, \ 107 int *err) \ 108 { \ 109 if (!((a < 0) ^ (b < 0)) \ 110 || (b > 0 && a >= min + b) \ 111 || (b < 0 && a <= max + b) \ 112 || b == 0) \ 113 return a - b; \ 114 *err |= 1; \ 115 return a < 0 ? min : max; \ 116 } 117 118 # endif /* has(__builtin_sub_overflow) */ 119 120 # define OSSL_SAFE_MATH_SUBU(type_name, type) \ 121 static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \ 122 type b, \ 123 int *err) \ 124 { \ 125 if (b > a) \ 126 *err |= 1; \ 127 return a - b; \ 128 } 129 130 /* 131 * Safe multiplication helpers 132 */ 133 # if has(__builtin_mul_overflow) 134 # define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \ 135 static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \ 136 type b, \ 137 int *err) \ 138 { \ 139 type r; \ 140 \ 141 if (!__builtin_mul_overflow(a, b, &r)) \ 142 return r; \ 143 *err |= 1; \ 144 return (a < 0) ^ (b < 0) ? min : max; \ 145 } 146 147 # define OSSL_SAFE_MATH_MULU(type_name, type, max) \ 148 static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \ 149 type b, \ 150 int *err) \ 151 { \ 152 type r; \ 153 \ 154 if (!__builtin_mul_overflow(a, b, &r)) \ 155 return r; \ 156 *err |= 1; \ 157 return a * b; \ 158 } 159 160 # else /* has(__builtin_mul_overflow) */ 161 # define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \ 162 static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \ 163 type b, \ 164 int *err) \ 165 { \ 166 if (a == 0 || b == 0) \ 167 return 0; \ 168 if (a == 1) \ 169 return b; \ 170 if (b == 1) \ 171 return a; \ 172 if (a != min && b != min) { \ 173 const type x = a < 0 ? -a : a; \ 174 const type y = b < 0 ? -b : b; \ 175 \ 176 if (x <= max / y) \ 177 return a * b; \ 178 } \ 179 *err |= 1; \ 180 return (a < 0) ^ (b < 0) ? min : max; \ 181 } 182 183 # define OSSL_SAFE_MATH_MULU(type_name, type, max) \ 184 static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \ 185 type b, \ 186 int *err) \ 187 { \ 188 if (b != 0 && a > max / b) \ 189 *err |= 1; \ 190 return a * b; \ 191 } 192 # endif /* has(__builtin_mul_overflow) */ 193 194 /* 195 * Safe division helpers 196 */ 197 # define OSSL_SAFE_MATH_DIVS(type_name, type, min, max) \ 198 static ossl_inline ossl_unused type safe_div_ ## type_name(type a, \ 199 type b, \ 200 int *err) \ 201 { \ 202 if (b == 0) { \ 203 *err |= 1; \ 204 return a < 0 ? min : max; \ 205 } \ 206 if (b == -1 && a == min) { \ 207 *err |= 1; \ 208 return max; \ 209 } \ 210 return a / b; \ 211 } 212 213 # define OSSL_SAFE_MATH_DIVU(type_name, type, max) \ 214 static ossl_inline ossl_unused type safe_div_ ## type_name(type a, \ 215 type b, \ 216 int *err) \ 217 { \ 218 if (b != 0) \ 219 return a / b; \ 220 *err |= 1; \ 221 return max; \ 222 } 223 224 /* 225 * Safe modulus helpers 226 */ 227 # define OSSL_SAFE_MATH_MODS(type_name, type, min, max) \ 228 static ossl_inline ossl_unused type safe_mod_ ## type_name(type a, \ 229 type b, \ 230 int *err) \ 231 { \ 232 if (b == 0) { \ 233 *err |= 1; \ 234 return 0; \ 235 } \ 236 if (b == -1 && a == min) { \ 237 *err |= 1; \ 238 return max; \ 239 } \ 240 return a % b; \ 241 } 242 243 # define OSSL_SAFE_MATH_MODU(type_name, type) \ 244 static ossl_inline ossl_unused type safe_mod_ ## type_name(type a, \ 245 type b, \ 246 int *err) \ 247 { \ 248 if (b != 0) \ 249 return a % b; \ 250 *err |= 1; \ 251 return 0; \ 252 } 253 254 /* 255 * Safe negation helpers 256 */ 257 # define OSSL_SAFE_MATH_NEGS(type_name, type, min) \ 258 static ossl_inline ossl_unused type safe_neg_ ## type_name(type a, \ 259 int *err) \ 260 { \ 261 if (a != min) \ 262 return -a; \ 263 *err |= 1; \ 264 return min; \ 265 } 266 267 # define OSSL_SAFE_MATH_NEGU(type_name, type) \ 268 static ossl_inline ossl_unused type safe_neg_ ## type_name(type a, \ 269 int *err) \ 270 { \ 271 if (a == 0) \ 272 return a; \ 273 *err |= 1; \ 274 return 1 + ~a; \ 275 } 276 277 /* 278 * Safe absolute value helpers 279 */ 280 # define OSSL_SAFE_MATH_ABSS(type_name, type, min) \ 281 static ossl_inline ossl_unused type safe_abs_ ## type_name(type a, \ 282 int *err) \ 283 { \ 284 if (a != min) \ 285 return a < 0 ? -a : a; \ 286 *err |= 1; \ 287 return min; \ 288 } 289 290 # define OSSL_SAFE_MATH_ABSU(type_name, type) \ 291 static ossl_inline ossl_unused type safe_abs_ ## type_name(type a, \ 292 int *err) \ 293 { \ 294 return a; \ 295 } 296 297 /* 298 * Safe fused multiply divide helpers 299 * 300 * These are a bit obscure: 301 * . They begin by checking the denominator for zero and getting rid of this 302 * corner case. 303 * 304 * . Second is an attempt to do the multiplication directly, if it doesn't 305 * overflow, the quotient is returned (for signed values there is a 306 * potential problem here which isn't present for unsigned). 307 * 308 * . Finally, the multiplication/division is transformed so that the larger 309 * of the numerators is divided first. This requires a remainder 310 * correction: 311 * 312 * a b / c = (a / c) b + (a mod c) b / c, where a > b 313 * 314 * The individual operations need to be overflow checked (again signed 315 * being more problematic). 316 * 317 * The algorithm used is not perfect but it should be "good enough". 318 */ 319 # define OSSL_SAFE_MATH_MULDIVS(type_name, type, max) \ 320 static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a, \ 321 type b, \ 322 type c, \ 323 int *err) \ 324 { \ 325 int e2 = 0; \ 326 type q, r, x, y; \ 327 \ 328 if (c == 0) { \ 329 *err |= 1; \ 330 return a == 0 || b == 0 ? 0 : max; \ 331 } \ 332 x = safe_mul_ ## type_name(a, b, &e2); \ 333 if (!e2) \ 334 return safe_div_ ## type_name(x, c, err); \ 335 if (b > a) { \ 336 x = b; \ 337 b = a; \ 338 a = x; \ 339 } \ 340 q = safe_div_ ## type_name(a, c, err); \ 341 r = safe_mod_ ## type_name(a, c, err); \ 342 x = safe_mul_ ## type_name(r, b, err); \ 343 y = safe_mul_ ## type_name(q, b, err); \ 344 q = safe_div_ ## type_name(x, c, err); \ 345 return safe_add_ ## type_name(y, q, err); \ 346 } 347 348 # define OSSL_SAFE_MATH_MULDIVU(type_name, type, max) \ 349 static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a, \ 350 type b, \ 351 type c, \ 352 int *err) \ 353 { \ 354 int e2 = 0; \ 355 type x, y; \ 356 \ 357 if (c == 0) { \ 358 *err |= 1; \ 359 return a == 0 || b == 0 ? 0 : max; \ 360 } \ 361 x = safe_mul_ ## type_name(a, b, &e2); \ 362 if (!e2) \ 363 return x / c; \ 364 if (b > a) { \ 365 x = b; \ 366 b = a; \ 367 a = x; \ 368 } \ 369 x = safe_mul_ ## type_name(a % c, b, err); \ 370 y = safe_mul_ ## type_name(a / c, b, err); \ 371 return safe_add_ ## type_name(y, x / c, err); \ 372 } 373 374 /* 375 * Calculate a / b rounding up: 376 * i.e. a / b + (a % b != 0) 377 * Which is usually (less safely) converted to (a + b - 1) / b 378 * If you *know* that b != 0, then it's safe to ignore err. 379 */ 380 #define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max) \ 381 static ossl_inline ossl_unused type safe_div_round_up_ ## type_name \ 382 (type a, type b, int *errp) \ 383 { \ 384 type x; \ 385 int *err, err_local = 0; \ 386 \ 387 /* Allow errors to be ignored by callers */ \ 388 err = errp != NULL ? errp : &err_local; \ 389 /* Fast path, both positive */ \ 390 if (b > 0 && a > 0) { \ 391 /* Faster path: no overflow concerns */ \ 392 if (a < max - b) \ 393 return (a + b - 1) / b; \ 394 return a / b + (a % b != 0); \ 395 } \ 396 if (b == 0) { \ 397 *err |= 1; \ 398 return a == 0 ? 0 : max; \ 399 } \ 400 if (a == 0) \ 401 return 0; \ 402 /* Rather slow path because there are negatives involved */ \ 403 x = safe_mod_ ## type_name(a, b, err); \ 404 return safe_add_ ## type_name(safe_div_ ## type_name(a, b, err), \ 405 x != 0, err); \ 406 } 407 408 /* Calculate ranges of types */ 409 # define OSSL_SAFE_MATH_MINS(type) ((type)1 << (sizeof(type) * 8 - 1)) 410 # define OSSL_SAFE_MATH_MAXS(type) (~OSSL_SAFE_MATH_MINS(type)) 411 # define OSSL_SAFE_MATH_MAXU(type) (~(type)0) 412 413 /* 414 * Wrapper macros to create all the functions of a given type 415 */ 416 # define OSSL_SAFE_MATH_SIGNED(type_name, type) \ 417 OSSL_SAFE_MATH_ADDS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ 418 OSSL_SAFE_MATH_MAXS(type)) \ 419 OSSL_SAFE_MATH_SUBS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ 420 OSSL_SAFE_MATH_MAXS(type)) \ 421 OSSL_SAFE_MATH_MULS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ 422 OSSL_SAFE_MATH_MAXS(type)) \ 423 OSSL_SAFE_MATH_DIVS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ 424 OSSL_SAFE_MATH_MAXS(type)) \ 425 OSSL_SAFE_MATH_MODS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ 426 OSSL_SAFE_MATH_MAXS(type)) \ 427 OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, \ 428 OSSL_SAFE_MATH_MAXS(type)) \ 429 OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type)) \ 430 OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type)) \ 431 OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type)) 432 433 # define OSSL_SAFE_MATH_UNSIGNED(type_name, type) \ 434 OSSL_SAFE_MATH_ADDU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \ 435 OSSL_SAFE_MATH_SUBU(type_name, type) \ 436 OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \ 437 OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \ 438 OSSL_SAFE_MATH_MODU(type_name, type) \ 439 OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, \ 440 OSSL_SAFE_MATH_MAXU(type)) \ 441 OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \ 442 OSSL_SAFE_MATH_NEGU(type_name, type) \ 443 OSSL_SAFE_MATH_ABSU(type_name, type) 444 445 #endif /* OSSL_INTERNAL_SAFE_MATH_H */ 446