1 /* 2 * Copyright 2021-2022 The OpenSSL Project Authors. All Rights Reserved. 3 * 4 * Licensed under the Apache License 2.0 (the "License"). You may not use 5 * this file except in compliance with the License. You can obtain a copy 6 * in the file LICENSE in the source distribution or at 7 * https://www.openssl.org/source/license.html 8 */ 9 10 #ifndef OSSL_INTERNAL_SAFE_MATH_H 11 # define OSSL_INTERNAL_SAFE_MATH_H 12 # pragma once 13 14 # include <openssl/e_os2.h> /* For 'ossl_inline' */ 15 16 # ifndef OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING 17 # ifdef __has_builtin 18 # define has(func) __has_builtin(func) 19 # elif __GNUC__ > 5 20 # define has(func) 1 21 # endif 22 # endif /* OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING */ 23 24 # ifndef has 25 # define has(func) 0 26 # endif 27 28 /* 29 * Safe addition helpers 30 */ 31 # if has(__builtin_add_overflow) 32 # define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \ 33 static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \ 34 type b, \ 35 int *err) \ 36 { \ 37 type r; \ 38 \ 39 if (!__builtin_add_overflow(a, b, &r)) \ 40 return r; \ 41 *err |= 1; \ 42 return a < 0 ? min : max; \ 43 } 44 45 # define OSSL_SAFE_MATH_ADDU(type_name, type, max) \ 46 static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \ 47 type b, \ 48 int *err) \ 49 { \ 50 type r; \ 51 \ 52 if (!__builtin_add_overflow(a, b, &r)) \ 53 return r; \ 54 *err |= 1; \ 55 return a + b; \ 56 } 57 58 # else /* has(__builtin_add_overflow) */ 59 # define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \ 60 static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \ 61 type b, \ 62 int *err) \ 63 { \ 64 if ((a < 0) ^ (b < 0) \ 65 || (a > 0 && b <= max - a) \ 66 || (a < 0 && b >= min - a) \ 67 || a == 0) \ 68 return a + b; \ 69 *err |= 1; \ 70 return a < 0 ? min : max; \ 71 } 72 73 # define OSSL_SAFE_MATH_ADDU(type_name, type, max) \ 74 static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \ 75 type b, \ 76 int *err) \ 77 { \ 78 if (b > max - a) \ 79 *err |= 1; \ 80 return a + b; \ 81 } 82 # endif /* has(__builtin_add_overflow) */ 83 84 /* 85 * Safe subtraction helpers 86 */ 87 # if has(__builtin_sub_overflow) 88 # define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \ 89 static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \ 90 type b, \ 91 int *err) \ 92 { \ 93 type r; \ 94 \ 95 if (!__builtin_sub_overflow(a, b, &r)) \ 96 return r; \ 97 *err |= 1; \ 98 return a < 0 ? min : max; \ 99 } 100 101 # else /* has(__builtin_sub_overflow) */ 102 # define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \ 103 static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \ 104 type b, \ 105 int *err) \ 106 { \ 107 if (!((a < 0) ^ (b < 0)) \ 108 || (b > 0 && a >= min + b) \ 109 || (b < 0 && a <= max + b) \ 110 || b == 0) \ 111 return a - b; \ 112 *err |= 1; \ 113 return a < 0 ? min : max; \ 114 } 115 116 # endif /* has(__builtin_sub_overflow) */ 117 118 # define OSSL_SAFE_MATH_SUBU(type_name, type) \ 119 static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \ 120 type b, \ 121 int *err) \ 122 { \ 123 if (b > a) \ 124 *err |= 1; \ 125 return a - b; \ 126 } 127 128 /* 129 * Safe multiplication helpers 130 */ 131 # if has(__builtin_mul_overflow) 132 # define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \ 133 static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \ 134 type b, \ 135 int *err) \ 136 { \ 137 type r; \ 138 \ 139 if (!__builtin_mul_overflow(a, b, &r)) \ 140 return r; \ 141 *err |= 1; \ 142 return (a < 0) ^ (b < 0) ? min : max; \ 143 } 144 145 # define OSSL_SAFE_MATH_MULU(type_name, type, max) \ 146 static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \ 147 type b, \ 148 int *err) \ 149 { \ 150 type r; \ 151 \ 152 if (!__builtin_mul_overflow(a, b, &r)) \ 153 return r; \ 154 *err |= 1; \ 155 return a * b; \ 156 } 157 158 # else /* has(__builtin_mul_overflow) */ 159 # define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \ 160 static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \ 161 type b, \ 162 int *err) \ 163 { \ 164 if (a == 0 || b == 0) \ 165 return 0; \ 166 if (a == 1) \ 167 return b; \ 168 if (b == 1) \ 169 return a; \ 170 if (a != min && b != min) { \ 171 const type x = a < 0 ? -a : a; \ 172 const type y = b < 0 ? -b : b; \ 173 \ 174 if (x <= max / y) \ 175 return a * b; \ 176 } \ 177 *err |= 1; \ 178 return (a < 0) ^ (b < 0) ? min : max; \ 179 } 180 181 # define OSSL_SAFE_MATH_MULU(type_name, type, max) \ 182 static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \ 183 type b, \ 184 int *err) \ 185 { \ 186 if (b != 0 && a > max / b) \ 187 *err |= 1; \ 188 return a * b; \ 189 } 190 # endif /* has(__builtin_mul_overflow) */ 191 192 /* 193 * Safe division helpers 194 */ 195 # define OSSL_SAFE_MATH_DIVS(type_name, type, min, max) \ 196 static ossl_inline ossl_unused type safe_div_ ## type_name(type a, \ 197 type b, \ 198 int *err) \ 199 { \ 200 if (b == 0) { \ 201 *err |= 1; \ 202 return a < 0 ? min : max; \ 203 } \ 204 if (b == -1 && a == min) { \ 205 *err |= 1; \ 206 return max; \ 207 } \ 208 return a / b; \ 209 } 210 211 # define OSSL_SAFE_MATH_DIVU(type_name, type, max) \ 212 static ossl_inline ossl_unused type safe_div_ ## type_name(type a, \ 213 type b, \ 214 int *err) \ 215 { \ 216 if (b != 0) \ 217 return a / b; \ 218 *err |= 1; \ 219 return max; \ 220 } 221 222 /* 223 * Safe modulus helpers 224 */ 225 # define OSSL_SAFE_MATH_MODS(type_name, type, min, max) \ 226 static ossl_inline ossl_unused type safe_mod_ ## type_name(type a, \ 227 type b, \ 228 int *err) \ 229 { \ 230 if (b == 0) { \ 231 *err |= 1; \ 232 return 0; \ 233 } \ 234 if (b == -1 && a == min) { \ 235 *err |= 1; \ 236 return max; \ 237 } \ 238 return a % b; \ 239 } 240 241 # define OSSL_SAFE_MATH_MODU(type_name, type) \ 242 static ossl_inline ossl_unused type safe_mod_ ## type_name(type a, \ 243 type b, \ 244 int *err) \ 245 { \ 246 if (b != 0) \ 247 return a % b; \ 248 *err |= 1; \ 249 return 0; \ 250 } 251 252 /* 253 * Safe negation helpers 254 */ 255 # define OSSL_SAFE_MATH_NEGS(type_name, type, min) \ 256 static ossl_inline ossl_unused type safe_neg_ ## type_name(type a, \ 257 int *err) \ 258 { \ 259 if (a != min) \ 260 return -a; \ 261 *err |= 1; \ 262 return min; \ 263 } 264 265 # define OSSL_SAFE_MATH_NEGU(type_name, type) \ 266 static ossl_inline ossl_unused type safe_neg_ ## type_name(type a, \ 267 int *err) \ 268 { \ 269 if (a == 0) \ 270 return a; \ 271 *err |= 1; \ 272 return 1 + ~a; \ 273 } 274 275 /* 276 * Safe absolute value helpers 277 */ 278 # define OSSL_SAFE_MATH_ABSS(type_name, type, min) \ 279 static ossl_inline ossl_unused type safe_abs_ ## type_name(type a, \ 280 int *err) \ 281 { \ 282 if (a != min) \ 283 return a < 0 ? -a : a; \ 284 *err |= 1; \ 285 return min; \ 286 } 287 288 # define OSSL_SAFE_MATH_ABSU(type_name, type) \ 289 static ossl_inline ossl_unused type safe_abs_ ## type_name(type a, \ 290 int *err) \ 291 { \ 292 return a; \ 293 } 294 295 /* 296 * Safe fused multiply divide helpers 297 * 298 * These are a bit obscure: 299 * . They begin by checking the denominator for zero and getting rid of this 300 * corner case. 301 * 302 * . Second is an attempt to do the multiplication directly, if it doesn't 303 * overflow, the quotient is returned (for signed values there is a 304 * potential problem here which isn't present for unsigned). 305 * 306 * . Finally, the multiplication/division is transformed so that the larger 307 * of the numerators is divided first. This requires a remainder 308 * correction: 309 * 310 * a b / c = (a / c) b + (a mod c) b / c, where a > b 311 * 312 * The individual operations need to be overflow checked (again signed 313 * being more problematic). 314 * 315 * The algorithm used is not perfect but it should be "good enough". 316 */ 317 # define OSSL_SAFE_MATH_MULDIVS(type_name, type, max) \ 318 static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a, \ 319 type b, \ 320 type c, \ 321 int *err) \ 322 { \ 323 int e2 = 0; \ 324 type q, r, x, y; \ 325 \ 326 if (c == 0) { \ 327 *err |= 1; \ 328 return a == 0 || b == 0 ? 0 : max; \ 329 } \ 330 x = safe_mul_ ## type_name(a, b, &e2); \ 331 if (!e2) \ 332 return safe_div_ ## type_name(x, c, err); \ 333 if (b > a) { \ 334 x = b; \ 335 b = a; \ 336 a = x; \ 337 } \ 338 q = safe_div_ ## type_name(a, c, err); \ 339 r = safe_mod_ ## type_name(a, c, err); \ 340 x = safe_mul_ ## type_name(r, b, err); \ 341 y = safe_mul_ ## type_name(q, b, err); \ 342 q = safe_div_ ## type_name(x, c, err); \ 343 return safe_add_ ## type_name(y, q, err); \ 344 } 345 346 # define OSSL_SAFE_MATH_MULDIVU(type_name, type, max) \ 347 static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a, \ 348 type b, \ 349 type c, \ 350 int *err) \ 351 { \ 352 int e2 = 0; \ 353 type x, y; \ 354 \ 355 if (c == 0) { \ 356 *err |= 1; \ 357 return a == 0 || b == 0 ? 0 : max; \ 358 } \ 359 x = safe_mul_ ## type_name(a, b, &e2); \ 360 if (!e2) \ 361 return x / c; \ 362 if (b > a) { \ 363 x = b; \ 364 b = a; \ 365 a = x; \ 366 } \ 367 x = safe_mul_ ## type_name(a % c, b, err); \ 368 y = safe_mul_ ## type_name(a / c, b, err); \ 369 return safe_add_ ## type_name(y, x / c, err); \ 370 } 371 372 /* 373 * Calculate a / b rounding up: 374 * i.e. a / b + (a % b != 0) 375 * Which is usually (less safely) converted to (a + b - 1) / b 376 * If you *know* that b != 0, then it's safe to ignore err. 377 */ 378 #define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max) \ 379 static ossl_inline ossl_unused type safe_div_round_up_ ## type_name \ 380 (type a, type b, int *errp) \ 381 { \ 382 type x; \ 383 int *err, err_local = 0; \ 384 \ 385 /* Allow errors to be ignored by callers */ \ 386 err = errp != NULL ? errp : &err_local; \ 387 /* Fast path, both positive */ \ 388 if (b > 0 && a > 0) { \ 389 /* Faster path: no overflow concerns */ \ 390 if (a < max - b) \ 391 return (a + b - 1) / b; \ 392 return a / b + (a % b != 0); \ 393 } \ 394 if (b == 0) { \ 395 *err |= 1; \ 396 return a == 0 ? 0 : max; \ 397 } \ 398 if (a == 0) \ 399 return 0; \ 400 /* Rather slow path because there are negatives involved */ \ 401 x = safe_mod_ ## type_name(a, b, err); \ 402 return safe_add_ ## type_name(safe_div_ ## type_name(a, b, err), \ 403 x != 0, err); \ 404 } 405 406 /* Calculate ranges of types */ 407 # define OSSL_SAFE_MATH_MINS(type) ((type)1 << (sizeof(type) * 8 - 1)) 408 # define OSSL_SAFE_MATH_MAXS(type) (~OSSL_SAFE_MATH_MINS(type)) 409 # define OSSL_SAFE_MATH_MAXU(type) (~(type)0) 410 411 /* 412 * Wrapper macros to create all the functions of a given type 413 */ 414 # define OSSL_SAFE_MATH_SIGNED(type_name, type) \ 415 OSSL_SAFE_MATH_ADDS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ 416 OSSL_SAFE_MATH_MAXS(type)) \ 417 OSSL_SAFE_MATH_SUBS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ 418 OSSL_SAFE_MATH_MAXS(type)) \ 419 OSSL_SAFE_MATH_MULS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ 420 OSSL_SAFE_MATH_MAXS(type)) \ 421 OSSL_SAFE_MATH_DIVS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ 422 OSSL_SAFE_MATH_MAXS(type)) \ 423 OSSL_SAFE_MATH_MODS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ 424 OSSL_SAFE_MATH_MAXS(type)) \ 425 OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, \ 426 OSSL_SAFE_MATH_MAXS(type)) \ 427 OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type)) \ 428 OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type)) \ 429 OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type)) 430 431 # define OSSL_SAFE_MATH_UNSIGNED(type_name, type) \ 432 OSSL_SAFE_MATH_ADDU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \ 433 OSSL_SAFE_MATH_SUBU(type_name, type) \ 434 OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \ 435 OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \ 436 OSSL_SAFE_MATH_MODU(type_name, type) \ 437 OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, \ 438 OSSL_SAFE_MATH_MAXU(type)) \ 439 OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \ 440 OSSL_SAFE_MATH_NEGU(type_name, type) \ 441 OSSL_SAFE_MATH_ABSU(type_name, type) 442 443 #endif /* OSSL_INTERNAL_SAFE_MATH_H */ 444