/* * Copyright 2021-2022 The OpenSSL Project Authors. All Rights Reserved. * * Licensed under the Apache License 2.0 (the "License"). You may not use * this file except in compliance with the License. You can obtain a copy * in the file LICENSE in the source distribution or at * https://www.openssl.org/source/license.html */ #ifndef OSSL_INTERNAL_SAFE_MATH_H # define OSSL_INTERNAL_SAFE_MATH_H # pragma once # include /* For 'ossl_inline' */ # ifndef OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING # ifdef __has_builtin # define has(func) __has_builtin(func) # elif __GNUC__ > 5 # define has(func) 1 # endif # endif /* OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING */ # ifndef has # define has(func) 0 # endif /* * Safe addition helpers */ # if has(__builtin_add_overflow) # define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \ static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \ type b, \ int *err) \ { \ type r; \ \ if (!__builtin_add_overflow(a, b, &r)) \ return r; \ *err |= 1; \ return a < 0 ? min : max; \ } # define OSSL_SAFE_MATH_ADDU(type_name, type, max) \ static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \ type b, \ int *err) \ { \ type r; \ \ if (!__builtin_add_overflow(a, b, &r)) \ return r; \ *err |= 1; \ return a + b; \ } # else /* has(__builtin_add_overflow) */ # define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \ static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \ type b, \ int *err) \ { \ if ((a < 0) ^ (b < 0) \ || (a > 0 && b <= max - a) \ || (a < 0 && b >= min - a) \ || a == 0) \ return a + b; \ *err |= 1; \ return a < 0 ? min : max; \ } # define OSSL_SAFE_MATH_ADDU(type_name, type, max) \ static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \ type b, \ int *err) \ { \ if (b > max - a) \ *err |= 1; \ return a + b; \ } # endif /* has(__builtin_add_overflow) */ /* * Safe subtraction helpers */ # if has(__builtin_sub_overflow) # define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \ static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \ type b, \ int *err) \ { \ type r; \ \ if (!__builtin_sub_overflow(a, b, &r)) \ return r; \ *err |= 1; \ return a < 0 ? min : max; \ } # else /* has(__builtin_sub_overflow) */ # define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \ static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \ type b, \ int *err) \ { \ if (!((a < 0) ^ (b < 0)) \ || (b > 0 && a >= min + b) \ || (b < 0 && a <= max + b) \ || b == 0) \ return a - b; \ *err |= 1; \ return a < 0 ? min : max; \ } # endif /* has(__builtin_sub_overflow) */ # define OSSL_SAFE_MATH_SUBU(type_name, type) \ static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \ type b, \ int *err) \ { \ if (b > a) \ *err |= 1; \ return a - b; \ } /* * Safe multiplication helpers */ # if has(__builtin_mul_overflow) # define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \ static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \ type b, \ int *err) \ { \ type r; \ \ if (!__builtin_mul_overflow(a, b, &r)) \ return r; \ *err |= 1; \ return (a < 0) ^ (b < 0) ? min : max; \ } # define OSSL_SAFE_MATH_MULU(type_name, type, max) \ static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \ type b, \ int *err) \ { \ type r; \ \ if (!__builtin_mul_overflow(a, b, &r)) \ return r; \ *err |= 1; \ return a * b; \ } # else /* has(__builtin_mul_overflow) */ # define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \ static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \ type b, \ int *err) \ { \ if (a == 0 || b == 0) \ return 0; \ if (a == 1) \ return b; \ if (b == 1) \ return a; \ if (a != min && b != min) { \ const type x = a < 0 ? -a : a; \ const type y = b < 0 ? -b : b; \ \ if (x <= max / y) \ return a * b; \ } \ *err |= 1; \ return (a < 0) ^ (b < 0) ? min : max; \ } # define OSSL_SAFE_MATH_MULU(type_name, type, max) \ static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \ type b, \ int *err) \ { \ if (b != 0 && a > max / b) \ *err |= 1; \ return a * b; \ } # endif /* has(__builtin_mul_overflow) */ /* * Safe division helpers */ # define OSSL_SAFE_MATH_DIVS(type_name, type, min, max) \ static ossl_inline ossl_unused type safe_div_ ## type_name(type a, \ type b, \ int *err) \ { \ if (b == 0) { \ *err |= 1; \ return a < 0 ? min : max; \ } \ if (b == -1 && a == min) { \ *err |= 1; \ return max; \ } \ return a / b; \ } # define OSSL_SAFE_MATH_DIVU(type_name, type, max) \ static ossl_inline ossl_unused type safe_div_ ## type_name(type a, \ type b, \ int *err) \ { \ if (b != 0) \ return a / b; \ *err |= 1; \ return max; \ } /* * Safe modulus helpers */ # define OSSL_SAFE_MATH_MODS(type_name, type, min, max) \ static ossl_inline ossl_unused type safe_mod_ ## type_name(type a, \ type b, \ int *err) \ { \ if (b == 0) { \ *err |= 1; \ return 0; \ } \ if (b == -1 && a == min) { \ *err |= 1; \ return max; \ } \ return a % b; \ } # define OSSL_SAFE_MATH_MODU(type_name, type) \ static ossl_inline ossl_unused type safe_mod_ ## type_name(type a, \ type b, \ int *err) \ { \ if (b != 0) \ return a % b; \ *err |= 1; \ return 0; \ } /* * Safe negation helpers */ # define OSSL_SAFE_MATH_NEGS(type_name, type, min) \ static ossl_inline ossl_unused type safe_neg_ ## type_name(type a, \ int *err) \ { \ if (a != min) \ return -a; \ *err |= 1; \ return min; \ } # define OSSL_SAFE_MATH_NEGU(type_name, type) \ static ossl_inline ossl_unused type safe_neg_ ## type_name(type a, \ int *err) \ { \ if (a == 0) \ return a; \ *err |= 1; \ return 1 + ~a; \ } /* * Safe absolute value helpers */ # define OSSL_SAFE_MATH_ABSS(type_name, type, min) \ static ossl_inline ossl_unused type safe_abs_ ## type_name(type a, \ int *err) \ { \ if (a != min) \ return a < 0 ? -a : a; \ *err |= 1; \ return min; \ } # define OSSL_SAFE_MATH_ABSU(type_name, type) \ static ossl_inline ossl_unused type safe_abs_ ## type_name(type a, \ int *err) \ { \ return a; \ } /* * Safe fused multiply divide helpers * * These are a bit obscure: * . They begin by checking the denominator for zero and getting rid of this * corner case. * * . Second is an attempt to do the multiplication directly, if it doesn't * overflow, the quotient is returned (for signed values there is a * potential problem here which isn't present for unsigned). * * . Finally, the multiplication/division is transformed so that the larger * of the numerators is divided first. This requires a remainder * correction: * * a b / c = (a / c) b + (a mod c) b / c, where a > b * * The individual operations need to be overflow checked (again signed * being more problematic). * * The algorithm used is not perfect but it should be "good enough". */ # define OSSL_SAFE_MATH_MULDIVS(type_name, type, max) \ static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a, \ type b, \ type c, \ int *err) \ { \ int e2 = 0; \ type q, r, x, y; \ \ if (c == 0) { \ *err |= 1; \ return a == 0 || b == 0 ? 0 : max; \ } \ x = safe_mul_ ## type_name(a, b, &e2); \ if (!e2) \ return safe_div_ ## type_name(x, c, err); \ if (b > a) { \ x = b; \ b = a; \ a = x; \ } \ q = safe_div_ ## type_name(a, c, err); \ r = safe_mod_ ## type_name(a, c, err); \ x = safe_mul_ ## type_name(r, b, err); \ y = safe_mul_ ## type_name(q, b, err); \ q = safe_div_ ## type_name(x, c, err); \ return safe_add_ ## type_name(y, q, err); \ } # define OSSL_SAFE_MATH_MULDIVU(type_name, type, max) \ static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a, \ type b, \ type c, \ int *err) \ { \ int e2 = 0; \ type x, y; \ \ if (c == 0) { \ *err |= 1; \ return a == 0 || b == 0 ? 0 : max; \ } \ x = safe_mul_ ## type_name(a, b, &e2); \ if (!e2) \ return x / c; \ if (b > a) { \ x = b; \ b = a; \ a = x; \ } \ x = safe_mul_ ## type_name(a % c, b, err); \ y = safe_mul_ ## type_name(a / c, b, err); \ return safe_add_ ## type_name(y, x / c, err); \ } /* * Calculate a / b rounding up: * i.e. a / b + (a % b != 0) * Which is usually (less safely) converted to (a + b - 1) / b * If you *know* that b != 0, then it's safe to ignore err. */ #define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max) \ static ossl_inline ossl_unused type safe_div_round_up_ ## type_name \ (type a, type b, int *errp) \ { \ type x; \ int *err, err_local = 0; \ \ /* Allow errors to be ignored by callers */ \ err = errp != NULL ? errp : &err_local; \ /* Fast path, both positive */ \ if (b > 0 && a > 0) { \ /* Faster path: no overflow concerns */ \ if (a < max - b) \ return (a + b - 1) / b; \ return a / b + (a % b != 0); \ } \ if (b == 0) { \ *err |= 1; \ return a == 0 ? 0 : max; \ } \ if (a == 0) \ return 0; \ /* Rather slow path because there are negatives involved */ \ x = safe_mod_ ## type_name(a, b, err); \ return safe_add_ ## type_name(safe_div_ ## type_name(a, b, err), \ x != 0, err); \ } /* Calculate ranges of types */ # define OSSL_SAFE_MATH_MINS(type) ((type)1 << (sizeof(type) * 8 - 1)) # define OSSL_SAFE_MATH_MAXS(type) (~OSSL_SAFE_MATH_MINS(type)) # define OSSL_SAFE_MATH_MAXU(type) (~(type)0) /* * Wrapper macros to create all the functions of a given type */ # define OSSL_SAFE_MATH_SIGNED(type_name, type) \ OSSL_SAFE_MATH_ADDS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ OSSL_SAFE_MATH_MAXS(type)) \ OSSL_SAFE_MATH_SUBS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ OSSL_SAFE_MATH_MAXS(type)) \ OSSL_SAFE_MATH_MULS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ OSSL_SAFE_MATH_MAXS(type)) \ OSSL_SAFE_MATH_DIVS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ OSSL_SAFE_MATH_MAXS(type)) \ OSSL_SAFE_MATH_MODS(type_name, type, OSSL_SAFE_MATH_MINS(type), \ OSSL_SAFE_MATH_MAXS(type)) \ OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, \ OSSL_SAFE_MATH_MAXS(type)) \ OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type)) \ OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type)) \ OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type)) # define OSSL_SAFE_MATH_UNSIGNED(type_name, type) \ OSSL_SAFE_MATH_ADDU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \ OSSL_SAFE_MATH_SUBU(type_name, type) \ OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \ OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \ OSSL_SAFE_MATH_MODU(type_name, type) \ OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, \ OSSL_SAFE_MATH_MAXU(type)) \ OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \ OSSL_SAFE_MATH_NEGU(type_name, type) \ OSSL_SAFE_MATH_ABSU(type_name, type) #endif /* OSSL_INTERNAL_SAFE_MATH_H */