xref: /openssl/include/internal/safe_math.h (revision 3189e127)
1 /*
2  * Copyright 2021-2022 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #ifndef OSSL_INTERNAL_SAFE_MATH_H
11 # define OSSL_INTERNAL_SAFE_MATH_H
12 # pragma once
13 
# include <limits.h>                     /* For 'CHAR_BIT' */
# include <openssl/e_os2.h>              /* For 'ossl_inline' */
15 
/*
 * has(func) is used in #if tests below to decide whether the compiler
 * provides the overflow-checking builtin |func|.  It evaluates to 0 when
 * OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING is defined or when no detection
 * method is available, which selects the portable fallback implementations.
 *
 * NOTE(review): |has| is a very generic macro name and is never #undef'd;
 * a pre-existing definition of |has| is silently reused by the #ifndef below.
 */
# ifndef OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING
#  ifdef __has_builtin
#   define has(func) __has_builtin(func)
#  elif defined(__GNUC__) && __GNUC__ > 5
/*
 * No __has_builtin: trust GCC releases newer than 5.  The defined() test
 * keeps non-GCC compilers from evaluating an undefined identifier here
 * (-Wundef / MSVC C4668).
 */
#   define has(func) 1
#  endif
# endif /* OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING */

# ifndef has
#  define has(func) 0
# endif
27 
28 /*
29  * Safe addition helpers
30  */
31 # if has(__builtin_add_overflow)
32 #  define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \
33     static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \
34                                                                type b,       \
35                                                                int *err)     \
36     {                                                                        \
37         type r;                                                              \
38                                                                              \
39         if (!__builtin_add_overflow(a, b, &r))                               \
40             return r;                                                        \
41         *err |= 1;                                                           \
42         return a < 0 ? min : max;                                            \
43     }
44 
45 #  define OSSL_SAFE_MATH_ADDU(type_name, type, max) \
46     static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \
47                                                                type b,       \
48                                                                int *err)     \
49     {                                                                        \
50         type r;                                                              \
51                                                                              \
52         if (!__builtin_add_overflow(a, b, &r))                               \
53             return r;                                                        \
54         *err |= 1;                                                           \
55         return a + b;                                                            \
56     }
57 
58 # else  /* has(__builtin_add_overflow) */
59 #  define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \
60     static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \
61                                                                type b,       \
62                                                                int *err)     \
63     {                                                                        \
64         if ((a < 0) ^ (b < 0)                                                \
65                 || (a > 0 && b <= max - a)                                   \
66                 || (a < 0 && b >= min - a)                                   \
67                 || a == 0)                                                   \
68             return a + b;                                                    \
69         *err |= 1;                                                           \
70         return a < 0 ? min : max;                                            \
71     }
72 
73 #  define OSSL_SAFE_MATH_ADDU(type_name, type, max) \
74     static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \
75                                                                type b,       \
76                                                                int *err)     \
77     {                                                                        \
78         if (b > max - a)                                                     \
79             *err |= 1;                                                       \
80         return a + b;                                                        \
81     }
82 # endif /* has(__builtin_add_overflow) */
83 
84 /*
85  * Safe subtraction helpers
86  */
87 # if has(__builtin_sub_overflow)
88 #  define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \
89     static ossl_inline ossl_unused type safe_sub_ ## type_name(type a,       \
90                                                                type b,       \
91                                                                int *err)     \
92     {                                                                        \
93         type r;                                                              \
94                                                                              \
95         if (!__builtin_sub_overflow(a, b, &r))                               \
96             return r;                                                        \
97         *err |= 1;                                                           \
98         return a < 0 ? min : max;                                            \
99     }
100 
101 # else  /* has(__builtin_sub_overflow) */
102 #  define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \
103     static ossl_inline ossl_unused type safe_sub_ ## type_name(type a,       \
104                                                                type b,       \
105                                                                int *err)     \
106     {                                                                        \
107         if (!((a < 0) ^ (b < 0))                                             \
108                 || (b > 0 && a >= min + b)                                   \
109                 || (b < 0 && a <= max + b)                                   \
110                 || b == 0)                                                   \
111             return a - b;                                                    \
112         *err |= 1;                                                           \
113         return a < 0 ? min : max;                                            \
114     }
115 
116 # endif /* has(__builtin_sub_overflow) */
117 
/*
 * Unsigned subtraction: bit 0 of *err is set when b > a, and the difference
 * is returned reduced modulo 2^N either way.
 */
# define OSSL_SAFE_MATH_SUBU(type_name, type) \
    static ossl_inline ossl_unused type safe_sub_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        const type diff = a - b;                                             \
                                                                             \
        if (a < b)                                                           \
            *err |= 1;                                                       \
        return diff;                                                         \
    }
127 
128 /*
129  * Safe multiplication helpers
130  */
131 # if has(__builtin_mul_overflow)
132 #  define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \
133     static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \
134                                                                type b,       \
135                                                                int *err)     \
136     {                                                                        \
137         type r;                                                              \
138                                                                              \
139         if (!__builtin_mul_overflow(a, b, &r))                               \
140             return r;                                                        \
141         *err |= 1;                                                           \
142         return (a < 0) ^ (b < 0) ? min : max;                                \
143     }
144 
145 #  define OSSL_SAFE_MATH_MULU(type_name, type, max) \
146     static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \
147                                                                type b,       \
148                                                                int *err)     \
149     {                                                                        \
150         type r;                                                              \
151                                                                              \
152         if (!__builtin_mul_overflow(a, b, &r))                               \
153             return r;                                                        \
154         *err |= 1;                                                           \
155         return a * b;                                                          \
156     }
157 
158 # else  /* has(__builtin_mul_overflow) */
159 #  define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \
160     static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \
161                                                                type b,       \
162                                                                int *err)     \
163     {                                                                        \
164         if (a == 0 || b == 0)                                                \
165             return 0;                                                        \
166         if (a == 1)                                                          \
167             return b;                                                        \
168         if (b == 1)                                                          \
169             return a;                                                        \
170         if (a != min && b != min) {                                          \
171             const type x = a < 0 ? -a : a;                                   \
172             const type y = b < 0 ? -b : b;                                   \
173                                                                              \
174             if (x <= max / y)                                                \
175                 return a * b;                                                \
176         }                                                                    \
177         *err |= 1;                                                           \
178         return (a < 0) ^ (b < 0) ? min : max;                                \
179     }
180 
181 #  define OSSL_SAFE_MATH_MULU(type_name, type, max) \
182     static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \
183                                                                type b,       \
184                                                                int *err)     \
185     {                                                                        \
186         if (b != 0 && a > max / b)                                           \
187             *err |= 1;                                                       \
188         return a * b;                                                        \
189     }
190 # endif /* has(__builtin_mul_overflow) */
191 
/*
 * Safe division helpers
 *
 * Signed division: a zero divisor sets bit 0 of *err and yields min for a
 * negative dividend, max otherwise; min / -1 sets *err and yields max.
 */
# define OSSL_SAFE_MATH_DIVS(type_name, type, min, max) \
    static ossl_inline ossl_unused type safe_div_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        const int by_zero = b == 0;                                          \
                                                                             \
        if (!by_zero && !(b == -1 && a == min))                              \
            return a / b;                                                    \
        *err |= 1;                                                           \
        return by_zero && a < 0 ? min : max;                                 \
    }
210 
/* Unsigned division: a zero divisor sets bit 0 of *err and yields max. */
# define OSSL_SAFE_MATH_DIVU(type_name, type, max) \
    static ossl_inline ossl_unused type safe_div_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return max;                                                      \
        }                                                                    \
        return a / b;                                                        \
    }
221 
/*
 * Safe modulus helpers
 *
 * Signed modulus: a zero divisor sets bit 0 of *err and yields 0.  Every
 * other case has a representable remainder and succeeds.
 */
# define OSSL_SAFE_MATH_MODS(type_name, type, min, max) \
    static ossl_inline ossl_unused type safe_mod_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return 0;                                                        \
        }                                                                    \
        /*                                                                   \
         * Any value modulo -1 is 0.  Short-circuit it because min % -1 is   \
         * undefined behaviour in C (the matching quotient overflows) even   \
         * though the remainder itself is representable.                     \
         */                                                                  \
        if (b == -1)                                                         \
            return 0;                                                        \
        return a % b;                                                        \
    }
240 
/* Unsigned modulus: a zero divisor sets bit 0 of *err and yields 0. */
# define OSSL_SAFE_MATH_MODU(type_name, type) \
    static ossl_inline ossl_unused type safe_mod_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return 0;                                                        \
        }                                                                    \
        return a % b;                                                        \
    }
251 
/*
 * Safe negation helpers
 *
 * Signed negation: -min is not representable, so negating min sets bit 0
 * of *err and returns min unchanged.
 */
# define OSSL_SAFE_MATH_NEGS(type_name, type, min) \
    static ossl_inline ossl_unused type safe_neg_ ## type_name(type a,       \
                                                               int *err)     \
    {                                                                        \
        if (a == min) {                                                      \
            *err |= 1;                                                       \
            return min;                                                      \
        }                                                                    \
        return -a;                                                           \
    }
264 
/*
 * Unsigned negation: only 0 negates without error.  Any other value sets
 * bit 0 of *err; the two's complement 1 + ~a (i.e. 2^N - a) is returned,
 * which for a == 0 is 0 again.
 */
# define OSSL_SAFE_MATH_NEGU(type_name, type) \
    static ossl_inline ossl_unused type safe_neg_ ## type_name(type a,       \
                                                               int *err)     \
    {                                                                        \
        if (a != 0)                                                          \
            *err |= 1;                                                       \
        return 1 + ~a;                                                       \
    }
274 
/*
 * Safe absolute value helpers
 *
 * Signed absolute value: |min| is not representable, so min sets bit 0 of
 * *err and is returned unchanged.
 */
# define OSSL_SAFE_MATH_ABSS(type_name, type, min) \
    static ossl_inline ossl_unused type safe_abs_ ## type_name(type a,       \
                                                               int *err)     \
    {                                                                        \
        if (a == min) {                                                      \
            *err |= 1;                                                       \
            return min;                                                      \
        }                                                                    \
        return a >= 0 ? a : -a;                                              \
    }
287 
/* Unsigned absolute value is the identity and never reports an error. */
# define OSSL_SAFE_MATH_ABSU(type_name, type) \
    static ossl_inline ossl_unused type safe_abs_ ## type_name(type a,       \
                                                               int *err)     \
    {                                                                        \
        /* err exists only for signature parity with the signed variant */  \
        (void)err;                                                           \
        return a;                                                            \
    }
294 
/*
 * Safe fused multiply divide helpers
 *
 * safe_muldiv_<type_name>(a, b, c, err) computes a * b / c while trying to
 * avoid a spurious overflow of the intermediate product:
 *
 *    . A zero denominator is reported through *err straight away; the
 *      result is then 0 if either factor is 0 and max otherwise.
 *
 *    . If a * b fits in the type, the (checked) quotient of the product is
 *      returned.  For signed values this step can still overflow in the
 *      division, which unsigned values cannot.
 *
 *    . Otherwise the factor that is larger (by value, not magnitude) is
 *      divided first, using the remainder correction
 *
 *          a b / c = (a / c) b + ((a mod c) b) / c,   where a >= b
 *
 *      Each individual operation is overflow checked and reports through
 *      *err (signed values again being more problematic).
 *
 * The algorithm is not perfect, but it should be "good enough".
 */
# define OSSL_SAFE_MATH_MULDIVS(type_name, type, max) \
    static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a,    \
                                                                  type b,    \
                                                                  type c,    \
                                                                  int *err)  \
    {                                                                        \
        int mul_err = 0;                                                     \
        type product, big, small, quot, rem, rem_prod, head, tail;           \
                                                                             \
        if (c == 0) {                                                        \
            *err |= 1;                                                       \
            return a == 0 || b == 0 ? 0 : max;                               \
        }                                                                    \
        product = safe_mul_ ## type_name(a, b, &mul_err);                    \
        if (mul_err == 0)                                                    \
            return safe_div_ ## type_name(product, c, err);                  \
        if (b > a) {                                                         \
            big = b;                                                         \
            small = a;                                                       \
        } else {                                                             \
            big = a;                                                         \
            small = b;                                                       \
        }                                                                    \
        quot = safe_div_ ## type_name(big, c, err);                          \
        rem = safe_mod_ ## type_name(big, c, err);                           \
        rem_prod = safe_mul_ ## type_name(rem, small, err);                  \
        head = safe_mul_ ## type_name(quot, small, err);                     \
        tail = safe_div_ ## type_name(rem_prod, c, err);                     \
        return safe_add_ ## type_name(head, tail, err);                      \
    }
345 
/*
 * Unsigned a * b / c.  Same strategy as the signed variant: a direct
 * product when it fits, otherwise the larger factor is split into quotient
 * and remainder by c so the partial products stay small.  A zero c sets
 * bit 0 of *err and yields 0 if either factor is 0, max otherwise.
 */
# define OSSL_SAFE_MATH_MULDIVU(type_name, type, max) \
    static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a,    \
                                                                  type b,    \
                                                                  type c,    \
                                                                  int *err)  \
    {                                                                        \
        int mul_err = 0;                                                     \
        type product, big, small, tail;                                      \
                                                                             \
        if (c == 0) {                                                        \
            *err |= 1;                                                       \
            return a == 0 || b == 0 ? 0 : max;                               \
        }                                                                    \
        product = safe_mul_ ## type_name(a, b, &mul_err);                    \
        if (mul_err == 0)                                                    \
            return product / c;                                              \
        big = b > a ? b : a;                                                 \
        small = b > a ? a : b;                                               \
        tail = safe_mul_ ## type_name(big % c, small, err);                  \
        product = safe_mul_ ## type_name(big / c, small, err);               \
        return safe_add_ ## type_name(product, tail / c, err);               \
    }
371 
/*
 * Calculate a / b rounded up, i.e. towards positive infinity:
 *     a / b + (a % b != 0) when the exact quotient is positive,
 *     a / b                when it is negative (truncation already
 *                          rounds towards +infinity there).
 * The positive case is usually (less safely) written as (a + b - 1) / b.
 * A zero divisor sets bit 0 of *errp and yields 0 for a == 0, max otherwise.
 * If you *know* that b != 0, then it's safe to ignore err: errp may be NULL.
 */
# define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max) \
    static ossl_inline ossl_unused type safe_div_round_up_ ## type_name      \
        (type a, type b, int *errp)                                          \
    {                                                                        \
        type x;                                                              \
        int *err, err_local = 0;                                             \
                                                                             \
        /* Allow errors to be ignored by callers */                          \
        err = errp != NULL ? errp : &err_local;                              \
        /* Fast path, both positive */                                       \
        if (b > 0 && a > 0) {                                                \
            /* Faster path: no overflow concerns */                          \
            if (a < max - b)                                                 \
                return (a + b - 1) / b;                                      \
            return a / b + (a % b != 0);                                     \
        }                                                                    \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return a == 0 ? 0 : max;                                         \
        }                                                                    \
        if (a == 0)                                                          \
            return 0;                                                        \
        /*                                                                   \
         * Rather slow path because there are negatives involved.  Here a    \
         * and b are non-zero and not both positive, so (a > 0) == (b > 0)   \
         * means both are negative and the quotient is positive: only then   \
         * does a non-zero remainder require bumping the truncated quotient. \
         * Testing "> 0" rather than "< 0" avoids always-false comparisons   \
         * when this macro is instantiated for unsigned types.               \
         */                                                                  \
        x = safe_mod_ ## type_name(a, b, err);                               \
        return safe_add_ ## type_name(safe_div_ ## type_name(a, b, err),     \
                                      x != 0 && (a > 0) == (b > 0), err);    \
    }
405 
/*
 * Calculate ranges of types
 *
 * Every result is cast back to |type| so that types narrower than int,
 * which are promoted before ~ and << are applied, still get their own
 * limits.  The signed maximum is assembled as 2 * (2^(N-2) - 1) + 1 so that
 * no intermediate value exceeds it, rather than shifting 1 into the sign
 * bit (undefined behaviour for signed types).  The signed minimum assumes
 * two's complement; all three assume |type| has no padding bits.
 */
# define OSSL_SAFE_MATH_MAXS(type) \
    ((type)((((type)1 << (sizeof(type) * CHAR_BIT - 2)) - 1) * 2 + 1))
# define OSSL_SAFE_MATH_MINS(type) ((type)(-OSSL_SAFE_MATH_MAXS(type) - 1))
# define OSSL_SAFE_MATH_MAXU(type) ((type)~(type)0)
410 
/*
 * Wrapper macros to create all the functions of a given type
 *
 * OSSL_SAFE_MATH_SIGNED(type_name, type) and
 * OSSL_SAFE_MATH_UNSIGNED(type_name, type) expand to the static inline
 * helpers safe_{add,sub,mul,div,mod,div_round_up,muldiv,neg,abs}_<type_name>.
 * The order of the expansions matters: safe_div_round_up_ and safe_muldiv_
 * call the add/mul/div/mod helpers, so those must be defined first.
 */
# define OSSL_SAFE_MATH_SIGNED(type_name, type)                          \
    OSSL_SAFE_MATH_ADDS(type_name, type,                                \
        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))           \
    OSSL_SAFE_MATH_SUBS(type_name, type,                                \
        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))           \
    OSSL_SAFE_MATH_MULS(type_name, type,                                \
        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))           \
    OSSL_SAFE_MATH_DIVS(type_name, type,                                \
        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))           \
    OSSL_SAFE_MATH_MODS(type_name, type,                                \
        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))           \
    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type,                        \
        OSSL_SAFE_MATH_MAXS(type))                                      \
    OSSL_SAFE_MATH_MULDIVS(type_name, type,                             \
        OSSL_SAFE_MATH_MAXS(type))                                      \
    OSSL_SAFE_MATH_NEGS(type_name, type,                                \
        OSSL_SAFE_MATH_MINS(type))                                      \
    OSSL_SAFE_MATH_ABSS(type_name, type,                                \
        OSSL_SAFE_MATH_MINS(type))

# define OSSL_SAFE_MATH_UNSIGNED(type_name, type)                        \
    OSSL_SAFE_MATH_ADDU(type_name, type,                                \
        OSSL_SAFE_MATH_MAXU(type))                                      \
    OSSL_SAFE_MATH_SUBU(type_name, type)                                \
    OSSL_SAFE_MATH_MULU(type_name, type,                                \
        OSSL_SAFE_MATH_MAXU(type))                                      \
    OSSL_SAFE_MATH_DIVU(type_name, type,                                \
        OSSL_SAFE_MATH_MAXU(type))                                      \
    OSSL_SAFE_MATH_MODU(type_name, type)                                \
    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type,                        \
        OSSL_SAFE_MATH_MAXU(type))                                      \
    OSSL_SAFE_MATH_MULDIVU(type_name, type,                             \
        OSSL_SAFE_MATH_MAXU(type))                                      \
    OSSL_SAFE_MATH_NEGU(type_name, type)                                \
    OSSL_SAFE_MATH_ABSU(type_name, type)
442 
443 #endif                          /* OSSL_INTERNAL_SAFE_MATH_H */
444