xref: /openssl/include/internal/safe_math.h (revision 53b34561)
1 /*
2  * Copyright 2021-2022 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #ifndef OSSL_INTERNAL_SAFE_MATH_H
11 # define OSSL_INTERNAL_SAFE_MATH_H
12 # pragma once
13 
14 # include <openssl/e_os2.h>              /* For 'ossl_inline' */
15 
/*
 * has(func) is only meaningful inside #if expressions.  It reports whether
 * the compiler provides the overflow-checking builtin 'func'.  Defining
 * OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING forces the portable implementations.
 * Compilers offering __has_builtin are queried directly, and GCC releases
 * newer than 5 without __has_builtin are assumed to provide the builtins.
 * In every other case, unless has is already defined, it expands to 0.
 */
# if defined(OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING)
/* Leave has() to the fallback definition below */
# elif defined(__has_builtin)
#  define has(func) __has_builtin(func)
# elif defined(__GNUC__) && __GNUC__ > 5
#  define has(func) 1
# endif

# if !defined(has)
#  define has(func) 0
# endif
29 
30 /*
31  * Safe addition helpers
32  */
33 # if has(__builtin_add_overflow)
34 #  define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \
35     static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \
36                                                                type b,       \
37                                                                int *err)     \
38     {                                                                        \
39         type r;                                                              \
40                                                                              \
41         if (!__builtin_add_overflow(a, b, &r))                               \
42             return r;                                                        \
43         *err |= 1;                                                           \
44         return a < 0 ? min : max;                                            \
45     }
46 
47 #  define OSSL_SAFE_MATH_ADDU(type_name, type, max) \
48     static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \
49                                                                type b,       \
50                                                                int *err)     \
51     {                                                                        \
52         type r;                                                              \
53                                                                              \
54         if (!__builtin_add_overflow(a, b, &r))                               \
55             return r;                                                        \
56         *err |= 1;                                                           \
57         return a + b;                                                            \
58     }
59 
60 # else  /* has(__builtin_add_overflow) */
61 #  define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \
62     static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \
63                                                                type b,       \
64                                                                int *err)     \
65     {                                                                        \
66         if ((a < 0) ^ (b < 0)                                                \
67                 || (a > 0 && b <= max - a)                                   \
68                 || (a < 0 && b >= min - a)                                   \
69                 || a == 0)                                                   \
70             return a + b;                                                    \
71         *err |= 1;                                                           \
72         return a < 0 ? min : max;                                            \
73     }
74 
75 #  define OSSL_SAFE_MATH_ADDU(type_name, type, max) \
76     static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \
77                                                                type b,       \
78                                                                int *err)     \
79     {                                                                        \
80         if (b > max - a)                                                     \
81             *err |= 1;                                                       \
82         return a + b;                                                        \
83     }
84 # endif /* has(__builtin_add_overflow) */
85 
86 /*
87  * Safe subtraction helpers
88  */
89 # if has(__builtin_sub_overflow)
90 #  define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \
91     static ossl_inline ossl_unused type safe_sub_ ## type_name(type a,       \
92                                                                type b,       \
93                                                                int *err)     \
94     {                                                                        \
95         type r;                                                              \
96                                                                              \
97         if (!__builtin_sub_overflow(a, b, &r))                               \
98             return r;                                                        \
99         *err |= 1;                                                           \
100         return a < 0 ? min : max;                                            \
101     }
102 
103 # else  /* has(__builtin_sub_overflow) */
104 #  define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \
105     static ossl_inline ossl_unused type safe_sub_ ## type_name(type a,       \
106                                                                type b,       \
107                                                                int *err)     \
108     {                                                                        \
109         if (!((a < 0) ^ (b < 0))                                             \
110                 || (b > 0 && a >= min + b)                                   \
111                 || (b < 0 && a <= max + b)                                   \
112                 || b == 0)                                                   \
113             return a - b;                                                    \
114         *err |= 1;                                                           \
115         return a < 0 ? min : max;                                            \
116     }
117 
118 # endif /* has(__builtin_sub_overflow) */
119 
/*
 * Unsigned subtraction: a borrow (b larger than a) is flagged, and the
 * wrapped difference is returned either way.
 */
# define OSSL_SAFE_MATH_SUBU(type_name, type) \
    static ossl_inline ossl_unused type safe_sub_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (a < b)                                                           \
            *err |= 1;                                                       \
        return a - b;                                                        \
    }
129 
130 /*
131  * Safe multiplication helpers
132  */
133 # if has(__builtin_mul_overflow)
134 #  define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \
135     static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \
136                                                                type b,       \
137                                                                int *err)     \
138     {                                                                        \
139         type r;                                                              \
140                                                                              \
141         if (!__builtin_mul_overflow(a, b, &r))                               \
142             return r;                                                        \
143         *err |= 1;                                                           \
144         return (a < 0) ^ (b < 0) ? min : max;                                \
145     }
146 
147 #  define OSSL_SAFE_MATH_MULU(type_name, type, max) \
148     static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \
149                                                                type b,       \
150                                                                int *err)     \
151     {                                                                        \
152         type r;                                                              \
153                                                                              \
154         if (!__builtin_mul_overflow(a, b, &r))                               \
155             return r;                                                        \
156         *err |= 1;                                                           \
157         return a * b;                                                          \
158     }
159 
160 # else  /* has(__builtin_mul_overflow) */
161 #  define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \
162     static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \
163                                                                type b,       \
164                                                                int *err)     \
165     {                                                                        \
166         if (a == 0 || b == 0)                                                \
167             return 0;                                                        \
168         if (a == 1)                                                          \
169             return b;                                                        \
170         if (b == 1)                                                          \
171             return a;                                                        \
172         if (a != min && b != min) {                                          \
173             const type x = a < 0 ? -a : a;                                   \
174             const type y = b < 0 ? -b : b;                                   \
175                                                                              \
176             if (x <= max / y)                                                \
177                 return a * b;                                                \
178         }                                                                    \
179         *err |= 1;                                                           \
180         return (a < 0) ^ (b < 0) ? min : max;                                \
181     }
182 
183 #  define OSSL_SAFE_MATH_MULU(type_name, type, max) \
184     static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \
185                                                                type b,       \
186                                                                int *err)     \
187     {                                                                        \
188         if (b != 0 && a > max / b)                                           \
189             *err |= 1;                                                       \
190         return a * b;                                                        \
191     }
192 # endif /* has(__builtin_mul_overflow) */
193 
194 /*
195  * Safe division helpers
196  */
/*
 * Signed division.  Both failure cases set *err:
 *   - division by zero saturates toward the sign of a;
 *   - min / -1, whose true quotient max + 1 is unrepresentable, saturates
 *     to max.
 */
# define OSSL_SAFE_MATH_DIVS(type_name, type, min, max) \
    static ossl_inline ossl_unused type safe_div_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b != 0 && !(b == -1 && a == min))                                \
            return a / b;                                                    \
        *err |= 1;                                                           \
        if (b == 0)                                                          \
            return a < 0 ? min : max;                                        \
        return max;                                                          \
    }
212 
/* Unsigned division: dividing by zero flags an error and yields max */
# define OSSL_SAFE_MATH_DIVU(type_name, type, max) \
    static ossl_inline ossl_unused type safe_div_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return max;                                                      \
        }                                                                    \
        return a / b;                                                        \
    }
223 
224 /*
225  * Safe modulus helpers
226  */
/*
 * Signed remainder.  A zero divisor flags an error and yields 0.
 *
 * Every value modulo -1 is 0, so that divisor is answered directly.  In C,
 * evaluating min % -1 is undefined because the matching quotient overflows.
 * The remainder itself is representable, though, so it is not an error.
 *
 * The min and max parameters are no longer needed.  They are kept so that
 * existing four-argument invocations continue to work.
 */
# define OSSL_SAFE_MATH_MODS(type_name, type, min, max) \
    static ossl_inline ossl_unused type safe_mod_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return 0;                                                        \
        }                                                                    \
        if (b == -1)                                                         \
            return 0;                                                        \
        return a % b;                                                        \
    }
242 
/* Unsigned remainder: a zero divisor flags an error and yields 0 */
# define OSSL_SAFE_MATH_MODU(type_name, type) \
    static ossl_inline ossl_unused type safe_mod_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return 0;                                                        \
        }                                                                    \
        return a % b;                                                        \
    }
253 
254 /*
255  * Safe negation helpers
256  */
/* Signed negation: -min is not representable, so min is returned flagged */
# define OSSL_SAFE_MATH_NEGS(type_name, type, min) \
    static ossl_inline ossl_unused type safe_neg_ ## type_name(type a,       \
                                                               int *err)     \
    {                                                                        \
        if (a == min) {                                                      \
            *err |= 1;                                                       \
            return min;                                                      \
        }                                                                    \
        return -a;                                                           \
    }
266 
/*
 * Unsigned negation.  Only 0 negates cleanly, and any other value is
 * flagged.  The result is computed as 1 + ~a (the two's complement of a),
 * which is 0 when a is 0.
 */
# define OSSL_SAFE_MATH_NEGU(type_name, type) \
    static ossl_inline ossl_unused type safe_neg_ ## type_name(type a,       \
                                                               int *err)     \
    {                                                                        \
        if (a != 0)                                                          \
            *err |= 1;                                                       \
        return 1 + ~a;                                                       \
    }
276 
277 /*
278  * Safe absolute value helpers
279  */
/* Signed absolute value: |min| is not representable, so min is returned flagged */
# define OSSL_SAFE_MATH_ABSS(type_name, type, min) \
    static ossl_inline ossl_unused type safe_abs_ ## type_name(type a,       \
                                                               int *err)     \
    {                                                                        \
        if (a == min) {                                                      \
            *err |= 1;                                                       \
            return min;                                                      \
        }                                                                    \
        return a >= 0 ? a : -a;                                              \
    }
289 
/* Unsigned absolute value is the identity and never reports an error */
# define OSSL_SAFE_MATH_ABSU(type_name, type) \
    static ossl_inline ossl_unused type safe_abs_ ## type_name(type a,       \
                                                               int *err)     \
    {                                                                        \
        (void)err;                                                           \
        return a;                                                            \
    }
296 
297 /*
298  * Safe fused multiply divide helpers
299  *
300  * These are a bit obscure:
301  *    . They begin by checking the denominator for zero and getting rid of this
302  *      corner case.
303  *
304  *    . Second is an attempt to do the multiplication directly, if it doesn't
305  *      overflow, the quotient is returned (for signed values there is a
306  *      potential problem here which isn't present for unsigned).
307  *
308  *    . Finally, the multiplication/division is transformed so that the larger
309  *      of the numerators is divided first.  This requires a remainder
310  *      correction:
311  *
312  *          a b / c = (a / c) b + (a mod c) b / c, where a > b
313  *
314  *      The individual operations need to be overflow checked (again signed
315  *      being more problematic).
316  *
317  * The algorithm used is not perfect but it should be "good enough".
318  */
/*
 * Signed a * b / c.
 *
 * A zero c flags an error and yields 0 when a or b is 0, max otherwise.
 *
 * The product is first attempted directly.  Its overflow is recorded in a
 * local flag rather than *err, so overflow of this trial product alone is
 * not reported.  If it does overflow, the larger of a and b is split by c:
 *
 *     a * b / c = (a / c) * b + ((a % c) * b) / c
 *
 * Each step is overflow checked, with errors accumulating into *err.
 */
# define OSSL_SAFE_MATH_MULDIVS(type_name, type, max) \
    static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a,    \
                                                                  type b,    \
                                                                  type c,    \
                                                                  int *err)  \
    {                                                                        \
        int trial_overflow = 0;                                              \
        type product, larger, smaller, quot, rem, whole, frac;               \
                                                                             \
        if (c == 0) {                                                        \
            *err |= 1;                                                       \
            return (a == 0 || b == 0) ? 0 : max;                             \
        }                                                                    \
        product = safe_mul_ ## type_name(a, b, &trial_overflow);             \
        if (trial_overflow == 0)                                             \
            return safe_div_ ## type_name(product, c, err);                  \
        if (b > a) {                                                         \
            larger = b;                                                      \
            smaller = a;                                                     \
        } else {                                                             \
            larger = a;                                                      \
            smaller = b;                                                     \
        }                                                                    \
        quot = safe_div_ ## type_name(larger, c, err);                       \
        rem = safe_mod_ ## type_name(larger, c, err);                        \
        whole = safe_mul_ ## type_name(quot, smaller, err);                  \
        frac = safe_div_ ## type_name(safe_mul_ ## type_name(rem, smaller,   \
                                                             err),           \
                                      c, err);                               \
        return safe_add_ ## type_name(whole, frac, err);                     \
    }
347 
/*
 * Unsigned a * b / c.
 *
 * A zero c flags an error and yields 0 when a or b is 0, max otherwise.
 *
 * The product is first attempted directly, with its overflow kept in a
 * local flag.  If it overflows, the larger factor is split by c:
 *
 *     a * b / c = (a / c) * b + ((a % c) * b) / c
 *
 * The multiplications and the final addition are overflow checked into *err.
 */
# define OSSL_SAFE_MATH_MULDIVU(type_name, type, max) \
    static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a,    \
                                                                  type b,    \
                                                                  type c,    \
                                                                  int *err)  \
    {                                                                        \
        int trial_overflow = 0;                                              \
        type product, larger, smaller, whole, part;                          \
                                                                             \
        if (c == 0) {                                                        \
            *err |= 1;                                                       \
            return (a == 0 || b == 0) ? 0 : max;                             \
        }                                                                    \
        product = safe_mul_ ## type_name(a, b, &trial_overflow);             \
        if (trial_overflow == 0)                                             \
            return product / c;                                              \
        if (b > a) {                                                         \
            larger = b;                                                      \
            smaller = a;                                                     \
        } else {                                                             \
            larger = a;                                                      \
            smaller = b;                                                     \
        }                                                                    \
        part = safe_mul_ ## type_name(larger % c, smaller, err);             \
        whole = safe_mul_ ## type_name(larger / c, smaller, err);            \
        return safe_add_ ## type_name(whole, part / c, err);                 \
    }
373 
/*
 * Calculate a / b rounded up, i.e. the ceiling of the exact quotient.
 *
 * When a and b have the same sign, the exact quotient is non-negative and C's
 * truncating division rounds it down.  The result is then
 *     a / b + (a % b != 0)
 * which is usually (less safely) written as (a + b - 1) / b.
 *
 * When a and b have opposite signs, the exact quotient is negative.
 * Truncation toward zero already rounds it up, so a / b is returned as is.
 * For example, -7 / 2 gives -3.
 *
 * errp may be NULL when the caller does not want error reporting.  If you
 * *know* that b != 0 (and, for signed types, that (a, b) is not (min, -1)),
 * then it's safe to ignore err.
 */
# define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max) \
    static ossl_inline ossl_unused type safe_div_round_up_ ## type_name      \
        (type a, type b, int *errp)                                          \
    {                                                                        \
        type q, r;                                                           \
        int *err, err_local = 0;                                             \
                                                                             \
        /* Allow errors to be ignored by callers */                          \
        err = errp != NULL ? errp : &err_local;                              \
        /* Fast path, both positive */                                       \
        if (b > 0 && a > 0) {                                                \
            /* Faster path: no overflow concerns */                          \
            if (a < max - b)                                                 \
                return (a + b - 1) / b;                                      \
            return a / b + (a % b != 0);                                     \
        }                                                                    \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return a == 0 ? 0 : max;                                         \
        }                                                                    \
        if (a == 0)                                                          \
            return 0;                                                        \
        /* Slow path: at least one operand is negative (signed types) */     \
        q = safe_div_ ## type_name(a, b, err);                               \
        /* Opposite signs: truncation toward zero is already the ceiling */  \
        if ((a > 0) != (b > 0))                                              \
            return q;                                                        \
        /* Both negative: positive quotient, so round up any remainder */    \
        r = safe_mod_ ## type_name(a, b, err);                               \
        return safe_add_ ## type_name(q, r != 0, err);                       \
    }
407 
/*
 * Calculate ranges of types.
 *
 * The limits are built so that no intermediate step overflows or shifts a
 * 1 into the sign bit, which is undefined for signed types.  Each result is
 * also cast back to 'type'.  Without that cast, types narrower than int
 * (whose operands are promoted to int) would produce wrong values, e.g. 128
 * for the minimum of an 8-bit type, or -1 for the maximum of an 8-bit
 * unsigned type.
 *
 * As before, 8-bit bytes and a two's complement representation without
 * padding bits are assumed.
 */
# define OSSL_SAFE_MATH_MAXS(type) \
    ((type)((((type)1 << (sizeof(type) * 8 - 2)) - 1) * 2 + 1))
# define OSSL_SAFE_MATH_MINS(type) ((type)(-OSSL_SAFE_MATH_MAXS(type) - 1))
# define OSSL_SAFE_MATH_MAXU(type) ((type)~(type)0)
412 
/*
 * Wrapper macros that instantiate every helper for one integer type:
 *
 *     OSSL_SAFE_MATH_SIGNED(type_name, type)    for signed types
 *     OSSL_SAFE_MATH_UNSIGNED(type_name, type)  for unsigned types
 *
 * Each one defines the safe_add_, safe_sub_, safe_mul_, safe_div_,
 * safe_mod_, safe_div_round_up_, safe_muldiv_, safe_neg_ and safe_abs_
 * functions suffixed with type_name.
 *
 * The expansion order matters.  safe_div_round_up_* and safe_muldiv_* call
 * the add/mul/div/mod helpers, so those are emitted first.
 */
# define OSSL_SAFE_MATH_SIGNED(type_name, type)                              \
    /* Arithmetic primitives */                                              \
    OSSL_SAFE_MATH_ADDS(type_name, type,                                     \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_SUBS(type_name, type,                                     \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_MULS(type_name, type,                                     \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_DIVS(type_name, type,                                     \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_MODS(type_name, type,                                     \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    /* Composite helpers built on the primitives above */                    \
    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, OSSL_SAFE_MATH_MAXS(type))  \
    OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type))       \
    /* Unary helpers */                                                      \
    OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type))          \
    OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type))

# define OSSL_SAFE_MATH_UNSIGNED(type_name, type)                            \
    /* Arithmetic primitives */                                              \
    OSSL_SAFE_MATH_ADDU(type_name, type, OSSL_SAFE_MATH_MAXU(type))          \
    OSSL_SAFE_MATH_SUBU(type_name, type)                                     \
    OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type))          \
    OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))          \
    OSSL_SAFE_MATH_MODU(type_name, type)                                     \
    /* Composite helpers built on the primitives above */                    \
    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, OSSL_SAFE_MATH_MAXU(type))  \
    OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))       \
    /* Unary helpers */                                                      \
    OSSL_SAFE_MATH_NEGU(type_name, type)                                     \
    OSSL_SAFE_MATH_ABSU(type_name, type)
444 
445 #endif                          /* OSSL_INTERNAL_SAFE_MATH_H */
446