xref: /openssl/ssl/quic/uint_set.c (revision da1c088f)
1 /*
2  * Copyright 2022-2023 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #include "internal/uint_set.h"
11 #include "internal/common.h"
12 #include <assert.h>
13 
14 /*
15  * uint64_t Integer Sets
16  * =====================
17  *
18  * This data structure supports the following operations:
19  *
20  *   Insert Range: Adds an inclusive range of integers [start, end]
21  *                 to the set. Equivalent to Insert for each number
22  *                 in the range.
23  *
24  *   Remove Range: Removes an inclusive range of integers [start, end]
25  *                 from the set. Not all of the range need already be in
26  *                 the set, but any part of the range in the set is removed.
27  *
28  *   Query:        Is an integer in the data structure?
29  *
30  * The data structure can be iterated.
31  *
32  * For greater efficiency in tracking large numbers of contiguous integers, we
33  * track integer ranges rather than individual integers. The data structure
34  * manages a list of integer ranges [[start, end]...]. Internally this is
35  * implemented as a doubly linked sorted list of range structures, which are
36  * automatically split and merged as necessary.
37  *
38  * This data structure requires O(n) traversal of the list for insertion,
39  * removal and query when we are not adding/removing ranges which are near the
40  * beginning or end of the set of ranges. For the applications for which this
41  * data structure is used (e.g. QUIC PN tracking for ACK generation), it is
42  * expected that the number of integer ranges needed at any given time will
43  * generally be small and that most operations will be close to the beginning or
44  * end of the range.
45  *
46  * Invariant: The data structure is always sorted in ascending order by value.
47  *
48  * Invariant: No two adjacent ranges ever 'border' one another (have no
49  *            numerical gap between them) as the data structure always ensures
50  *            such ranges are merged.
51  *
52  * Invariant: No two ranges ever overlap.
53  *
54  * Invariant: No range [a, b] ever has a > b.
55  *
56  * Invariant: Since ranges are represented using inclusive bounds, no range
57  *            item inside the data structure can represent a span of zero
58  *            integers.
59  */
/*
 * Prepares s for use as an integer set by initialising the underlying
 * list of ranges.
 */
void ossl_uint_set_init(UINT_SET *s)
{
    ossl_list_uint_set_init(s);
}
64 
/*
 * Frees every range item held by s.
 *
 * The items are freed without being unlinked from the list, so s itself is
 * not reset by this function.
 */
void ossl_uint_set_destroy(UINT_SET *s)
{
    UINT_SET_ITEM *cur = ossl_list_uint_set_head(s);

    while (cur != NULL) {
        /* Fetch the successor first; cur is invalid once it is freed. */
        UINT_SET_ITEM *following = ossl_list_uint_set_next(cur);

        OPENSSL_free(cur);
        cur = following;
    }
}
74 
75 /* Possible merge of x, prev(x) */
uint_set_merge_adjacent(UINT_SET * s,UINT_SET_ITEM * x)76 static void uint_set_merge_adjacent(UINT_SET *s, UINT_SET_ITEM *x)
77 {
78     UINT_SET_ITEM *xprev = ossl_list_uint_set_prev(x);
79 
80     if (xprev == NULL)
81         return;
82 
83     if (x->range.start - 1 != xprev->range.end)
84         return;
85 
86     x->range.start = xprev->range.start;
87     ossl_list_uint_set_remove(s, xprev);
88     OPENSSL_free(xprev);
89 }
90 
/* Returns the smaller of the two unsigned 64-bit values x and y. */
static uint64_t u64_min(uint64_t x, uint64_t y)
{
    if (y <= x)
        return y;

    return x;
}
95 
/* Returns the larger of the two unsigned 64-bit values x and y. */
static uint64_t u64_max(uint64_t x, uint64_t y)
{
    if (y >= x)
        return y;

    return x;
}
100 
101 /*
102  * Returns 1 if there exists an integer x which falls within both ranges a and
103  * b.
104  */
uint_range_overlaps(const UINT_RANGE * a,const UINT_RANGE * b)105 static int uint_range_overlaps(const UINT_RANGE *a,
106                                const UINT_RANGE *b)
107 {
108     return u64_min(a->end, b->end)
109         >= u64_max(a->start, b->start);
110 }
111 
create_set_item(uint64_t start,uint64_t end)112 static UINT_SET_ITEM *create_set_item(uint64_t start, uint64_t end)
113 {
114     UINT_SET_ITEM *x = OPENSSL_malloc(sizeof(UINT_SET_ITEM));
115 
116     if (x == NULL)
117         return NULL;
118 
119     ossl_list_uint_set_init_elem(x);
120     x->range.start = start;
121     x->range.end   = end;
122     return x;
123 }
124 
/*
 * Adds the inclusive range [range->start, range->end] to the set s. Existing
 * ranges which the new range overlaps or borders are merged with it so that
 * the list stays sorted, non-overlapping and free of bordering ranges.
 *
 * Returns 1 on success. Returns 0, leaving the set unmodified, if the
 * start <= end check fails or if a new list item cannot be allocated.
 */
int ossl_uint_set_insert(UINT_SET *s, const UINT_RANGE *range)
{
    UINT_SET_ITEM *x, *xnext, *z, *zprev, *f;
    uint64_t start = range->start, end = range->end;

    if (!ossl_assert(start <= end))
        return 0;

    if (ossl_list_uint_set_is_empty(s)) {
        /* Nothing in the set yet, so just add this range. */
        x = create_set_item(start, end);
        if (x == NULL)
            return 0;
        ossl_list_uint_set_insert_head(s, x);
        return 1;
    }

    z = ossl_list_uint_set_tail(s);
    if (start > z->range.end) {
        /*
         * Range is after the latest range in the set, so append.
         *
         * Note: The case where the range is before the earliest range in the
         * set is handled as a degenerate case of the final case below. See
         * optimization note (*) below.
         */
        if (z->range.end + 1 == start) {
            /* Borders the last range, so just grow that range. */
            z->range.end = end;
            return 1;
        }

        x = create_set_item(start, end);
        if (x == NULL)
            return 0;
        ossl_list_uint_set_insert_tail(s, x);
        return 1;
    }

    f = ossl_list_uint_set_head(s);
    if (start <= f->range.start && end >= z->range.end) {
        /*
         * New range dwarfs all ranges in our set.
         *
         * Free everything except the first range in the set, which we scavenge
         * and reuse.
         */
        f->range.start = start;
        f->range.end = end;
        for (x = ossl_list_uint_set_next(f); x != NULL; x = xnext) {
            xnext = ossl_list_uint_set_next(x);
            ossl_list_uint_set_remove(s, x);
            /* Unlinking alone would leak the item; release it too. */
            OPENSSL_free(x);
        }
        return 1;
    }

    /*
     * Walk backwards since we will most often be inserting at the end. As an
     * optimization, test the head node first and skip iterating over the
     * entire list if we are inserting at the start. The assumption is that
     * insertion at the start and end of the space will be the most common
     * operations. (*)
     */
    z = end < f->range.start ? f : z;

    for (; z != NULL; z = zprev) {
        zprev = ossl_list_uint_set_prev(z);

        /* An existing range dwarfs our new range (optimisation). */
        if (z->range.start <= start && z->range.end >= end)
            return 1;

        if (uint_range_overlaps(&z->range, range)) {
            /*
             * Our new range overlaps an existing range, or possibly several
             * existing ranges. z is the latest such range, since we are
             * walking backwards.
             */
            UINT_SET_ITEM *ovend = z;

            ovend->range.end = u64_max(end, z->range.end);

            /* Get earliest overlapping range. */
            while (zprev != NULL && uint_range_overlaps(&zprev->range, range)) {
                z = zprev;
                zprev = ossl_list_uint_set_prev(z);
            }

            ovend->range.start = u64_min(start, z->range.start);

            /* Replace sequence of nodes z..ovend with updated ovend only. */
            while (z != ovend) {
                z = ossl_list_uint_set_next(x = z);
                ossl_list_uint_set_remove(s, x);
                OPENSSL_free(x);
            }

            /*
             * The widened ovend cannot overlap its neighbours but may now
             * border either of them (e.g. inserting [3, 5] into
             * {[1, 2], [5, 6]}). Merge on both sides so that no two ranges in
             * the list are left without a gap between them. The successor is
             * fetched first; merging it afterwards absorbs and frees ovend.
             */
            xnext = ossl_list_uint_set_next(ovend);
            uint_set_merge_adjacent(s, ovend);
            if (xnext != NULL)
                uint_set_merge_adjacent(s, xnext);
            break;
        } else if (end < z->range.start
                    && (zprev == NULL || start > zprev->range.end)) {
            /* The new range falls entirely in the gap before z. */
            if (z->range.start == end + 1) {
                /* We can extend the following range backwards. */
                z->range.start = start;

                /*
                 * If this closes a gap we now need to merge
                 * consecutive nodes.
                 */
                uint_set_merge_adjacent(s, z);
            } else if (zprev != NULL && zprev->range.end + 1 == start) {
                /*
                 * We can extend the preceding range forwards. This cannot
                 * close the gap to z, as that case is handled by the branch
                 * above, so no merge is needed here.
                 */
                zprev->range.end = end;
            } else {
                /*
                 * The new interval is between intervals without overlapping or
                 * touching them, so insert between, preserving sort.
                 */
                x = create_set_item(start, end);
                if (x == NULL)
                    return 0;
                ossl_list_uint_set_insert_before(s, z, x);
            }
            break;
        }
    }

    return 1;
}
257 
/*
 * Removes every integer in the inclusive range [range->start, range->end]
 * from the set s. Integers in the range which are not in the set are ignored.
 * Existing ranges are deleted, shortened or split in two as required.
 *
 * Returns 1 on success. Returns 0 if the start <= end check fails, or if the
 * list item needed to split an existing range in two cannot be allocated. In
 * both failure cases the set is left unmodified: a split only happens when the
 * removed range lies strictly inside a single existing range, so no other
 * range has been touched by then.
 */
int ossl_uint_set_remove(UINT_SET *s, const UINT_RANGE *range)
{
    UINT_SET_ITEM *z, *zprev, *y;
    uint64_t start = range->start, end = range->end;

    if (!ossl_assert(start <= end))
        return 0;

    /* Walk backwards since we will most often be removing at the end. */
    for (z = ossl_list_uint_set_tail(s); z != NULL; z = zprev) {
        /* Fetch the predecessor up front, as z may be freed below. */
        zprev = ossl_list_uint_set_prev(z);

        if (start > z->range.end)
            /* No overlapping ranges can exist beyond this point, so stop. */
            break;

        if (start <= z->range.start && end >= z->range.end) {
            /*
             * The range being removed dwarfs this range, so it should be
             * removed.
             */
            ossl_list_uint_set_remove(s, z);
            OPENSSL_free(z);
        } else if (start <= z->range.start && end >= z->range.start) {
            /*
             * The range being removed includes start of this range, but does
             * not cover the entire range (as this would be caught by the case
             * above). Shorten the range.
             */
            assert(end < z->range.end);
            z->range.start = end + 1;
        } else if (end >= z->range.end) {
            /*
             * The range being removed includes the end of this range, but does
             * not cover the entire range (as this would be caught by the case
             * above). Shorten the range. We can also stop iterating.
             */
            assert(start > z->range.start);
            assert(start > 0);
            z->range.end = start - 1;
            break;
        } else if (start > z->range.start && end < z->range.end) {
            /*
             * The range being removed falls entirely in this range, so cut it
             * into two. Cases where a zero-length range would be created are
             * handled by the above cases.
             */
            y = create_set_item(end + 1, z->range.end);
            if (y == NULL)
                /* Allocation failed; z has not been modified yet. */
                return 0;
            ossl_list_uint_set_insert_after(s, z, y);
            z->range.end = start - 1;
            break;
        } else {
            /* Assert no partial overlap; all cases should be covered above. */
            assert(!uint_range_overlaps(&z->range, range));
        }
    }

    return 1;
}
317 
ossl_uint_set_query(const UINT_SET * s,uint64_t v)318 int ossl_uint_set_query(const UINT_SET *s, uint64_t v)
319 {
320     UINT_SET_ITEM *x;
321 
322     if (ossl_list_uint_set_is_empty(s))
323         return 0;
324 
325     for (x = ossl_list_uint_set_tail(s); x != NULL; x = ossl_list_uint_set_prev(x))
326         if (x->range.start <= v && x->range.end >= v)
327             return 1;
328         else if (x->range.end < v)
329             return 0;
330 
331     return 0;
332 }
333