xref: /openssl/ssl/quic/quic_lcidm.c (revision 29fbdfaf)
1 /*
2  * Copyright 2023 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #include "internal/quic_lcidm.h"
11 #include "internal/quic_types.h"
12 #include "internal/quic_vlint.h"
13 #include "internal/common.h"
14 #include <openssl/lhash.h>
15 #include <openssl/rand.h>
16 #include <openssl/err.h>
17 
18 /*
19  * QUIC Local Connection ID Manager
20  * ================================
21  */
22 
/* Per-connection record; the structure body is defined further down. */
typedef struct quic_lcidm_conn_st QUIC_LCIDM_CONN;
24 
/* How an LCID came to be registered; stored in QUIC_LCID.type. */
enum {
    /* The ODCID received from the peer. */
    LCID_TYPE_ODCID   = 0,
    /* Our own Initial SCID. */
    LCID_TYPE_INITIAL = 1,
    /* An LCID issued to the peer via a NCID frame. */
    LCID_TYPE_NCID    = 2
};
30 
31 typedef struct quic_lcid_st {
32     QUIC_CONN_ID                cid;
33     uint64_t                    seq_num;
34 
35     /* Back-pointer to the owning QUIC_LCIDM_CONN structure. */
36     QUIC_LCIDM_CONN             *conn;
37 
38     /* LCID_TYPE_* */
39     unsigned int                type                : 2;
40 } QUIC_LCID;
41 
/* Typed LHASH wrappers (lh_QUIC_LCID_* and lh_QUIC_LCIDM_CONN_*) used below. */
DEFINE_LHASH_OF_EX(QUIC_LCID);
DEFINE_LHASH_OF_EX(QUIC_LCIDM_CONN);
44 
45 struct quic_lcidm_conn_st {
46     size_t              num_active_lcid;
47     LHASH_OF(QUIC_LCID) *lcids;
48     void                *opaque;
49     QUIC_LCID           *odcid_lcid_obj;
50     uint64_t            next_seq_num;
51 
52     /* Have we enrolled an ODCID? */
53     unsigned int        done_odcid          : 1;
54 };
55 
56 struct quic_lcidm_st {
57     OSSL_LIB_CTX                *libctx;
58     LHASH_OF(QUIC_LCID)         *lcids; /* (QUIC_CONN_ID) -> (QUIC_LCID *)  */
59     LHASH_OF(QUIC_LCIDM_CONN)   *conns; /* (void *opaque) -> (QUIC_LCIDM_CONN *) */
60     size_t                      lcid_len; /* Length in bytes for all LCIDs */
61 #ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
62     QUIC_CONN_ID                next_lcid;
63 #endif
64 };
65 
/*
 * Fold a byte string into an unsigned long. Byte i is XORed into byte lane
 * (i mod sizeof(unsigned long)) of the result, so input longer than one
 * unsigned long wraps around onto the low lanes again. An empty input
 * hashes to 0.
 */
static unsigned long bin_hash(const unsigned char *buf, size_t buf_len)
{
    unsigned long acc = 0;
    size_t pos, lane = 0;

    for (pos = 0; pos < buf_len; ++pos) {
        acc ^= ((unsigned long)buf[pos]) << (8 * lane);

        /* Advance to the next byte lane, wrapping at the word size. */
        if (++lane == sizeof(unsigned long))
            lane = 0;
    }

    return acc;
}
76 
lcid_hash(const QUIC_LCID * lcid_obj)77 static unsigned long lcid_hash(const QUIC_LCID *lcid_obj)
78 {
79     return bin_hash(lcid_obj->cid.id, lcid_obj->cid.id_len);
80 }
81 
lcid_comp(const QUIC_LCID * a,const QUIC_LCID * b)82 static int lcid_comp(const QUIC_LCID *a, const QUIC_LCID *b)
83 {
84     return !ossl_quic_conn_id_eq(&a->cid, &b->cid);
85 }
86 
lcidm_conn_hash(const QUIC_LCIDM_CONN * conn)87 static unsigned long lcidm_conn_hash(const QUIC_LCIDM_CONN *conn)
88 {
89     return (unsigned long)(uintptr_t)conn->opaque;
90 }
91 
lcidm_conn_comp(const QUIC_LCIDM_CONN * a,const QUIC_LCIDM_CONN * b)92 static int lcidm_conn_comp(const QUIC_LCIDM_CONN *a, const QUIC_LCIDM_CONN *b)
93 {
94     return a->opaque != b->opaque;
95 }
96 
ossl_quic_lcidm_new(OSSL_LIB_CTX * libctx,size_t lcid_len)97 QUIC_LCIDM *ossl_quic_lcidm_new(OSSL_LIB_CTX *libctx, size_t lcid_len)
98 {
99     QUIC_LCIDM *lcidm = NULL;
100 
101     if (lcid_len > QUIC_MAX_CONN_ID_LEN)
102         goto err;
103 
104     if ((lcidm = OPENSSL_zalloc(sizeof(*lcidm))) == NULL)
105         goto err;
106 
107     if ((lcidm->lcids = lh_QUIC_LCID_new(lcid_hash, lcid_comp)) == NULL)
108         goto err;
109 
110     if ((lcidm->conns = lh_QUIC_LCIDM_CONN_new(lcidm_conn_hash,
111                                                lcidm_conn_comp)) == NULL)
112         goto err;
113 
114     lcidm->libctx   = libctx;
115     lcidm->lcid_len = lcid_len;
116     return lcidm;
117 
118 err:
119     if (lcidm != NULL) {
120         lh_QUIC_LCID_free(lcidm->lcids);
121         lh_QUIC_LCIDM_CONN_free(lcidm->conns);
122         OPENSSL_free(lcidm);
123     }
124     return NULL;
125 }
126 
127 static void lcidm_delete_conn(QUIC_LCIDM *lcidm, QUIC_LCIDM_CONN *conn);
128 
/* doall_arg adapter: arg is the QUIC_LCIDM which owns conn. */
static void lcidm_delete_conn_(QUIC_LCIDM_CONN *conn, void *arg)
{
    QUIC_LCIDM *lcidm = arg;

    lcidm_delete_conn(lcidm, conn);
}
133 
/*
 * Destroy an LCID manager together with every connection record and LCID
 * it still holds. Passing NULL is a no-op.
 */
void ossl_quic_lcidm_free(QUIC_LCIDM *lcidm)
{
    if (lcidm != NULL) {
        /*
         * Deleting entries from inside a doall walk is not safe with the
         * current LHASH implementation:
         *
         * - a delete may contract the table, and the resulting rehash can
         *   move an item from a later bucket into an earlier one, so that
         *   doall never visits it and it is leaked;
         *
         * - doall caches the table size and pointers while it runs, so it
         *   is not safe across any change of table size.
         *
         * Setting the down-load to 0 turns contraction off, which guarantees
         * that no rehash happens provided we only delete and never insert.
         */
        lh_QUIC_LCIDM_CONN_set_down_load(lcidm->conns, 0);
        lh_QUIC_LCIDM_CONN_doall_arg(lcidm->conns, lcidm_delete_conn_, lcidm);

        lh_QUIC_LCID_free(lcidm->lcids);
        lh_QUIC_LCIDM_CONN_free(lcidm->conns);
        OPENSSL_free(lcidm);
    }
}
164 
lcidm_get0_lcid(const QUIC_LCIDM * lcidm,const QUIC_CONN_ID * lcid)165 static QUIC_LCID *lcidm_get0_lcid(const QUIC_LCIDM *lcidm, const QUIC_CONN_ID *lcid)
166 {
167     QUIC_LCID key;
168 
169     key.cid = *lcid;
170 
171     if (key.cid.id_len > QUIC_MAX_CONN_ID_LEN)
172         return NULL;
173 
174     return lh_QUIC_LCID_retrieve(lcidm->lcids, &key);
175 }
176 
lcidm_get0_conn(const QUIC_LCIDM * lcidm,void * opaque)177 static QUIC_LCIDM_CONN *lcidm_get0_conn(const QUIC_LCIDM *lcidm, void *opaque)
178 {
179     QUIC_LCIDM_CONN key;
180 
181     key.opaque = opaque;
182 
183     return lh_QUIC_LCIDM_CONN_retrieve(lcidm->conns, &key);
184 }
185 
lcidm_upsert_conn(const QUIC_LCIDM * lcidm,void * opaque)186 static QUIC_LCIDM_CONN *lcidm_upsert_conn(const QUIC_LCIDM *lcidm, void *opaque)
187 {
188     QUIC_LCIDM_CONN *conn = lcidm_get0_conn(lcidm, opaque);
189 
190     if (conn != NULL)
191         return conn;
192 
193     if ((conn = OPENSSL_zalloc(sizeof(*conn))) == NULL)
194         goto err;
195 
196     if ((conn->lcids = lh_QUIC_LCID_new(lcid_hash, lcid_comp)) == NULL)
197         goto err;
198 
199     conn->opaque = opaque;
200 
201     lh_QUIC_LCIDM_CONN_insert(lcidm->conns, conn);
202     if (lh_QUIC_LCIDM_CONN_error(lcidm->conns))
203         goto err;
204 
205     return conn;
206 
207 err:
208     if (conn != NULL) {
209         lh_QUIC_LCID_free(conn->lcids);
210         OPENSSL_free(conn);
211     }
212     return NULL;
213 }
214 
/*
 * Unregister an LCID object from both the manager-wide table and its owning
 * connection's table, then free it. lcid_obj must not be used afterwards.
 */
static void lcidm_delete_conn_lcid(QUIC_LCIDM *lcidm, QUIC_LCID *lcid_obj)
{
    QUIC_LCIDM_CONN *conn = lcid_obj->conn;

    lh_QUIC_LCID_delete(lcidm->lcids, lcid_obj);
    lh_QUIC_LCID_delete(conn->lcids, lcid_obj);

    /*
     * The connection keeps a direct pointer to its ODCID entry. If that is
     * the entry being destroyed, drop the pointer so that it cannot be
     * dereferenced or freed a second time by a later ODCID retirement.
     */
    if (conn->odcid_lcid_obj == lcid_obj)
        conn->odcid_lcid_obj = NULL;

    assert(conn->num_active_lcid > 0);
    --conn->num_active_lcid;
    OPENSSL_free(lcid_obj);
}
223 
224 /* doall_arg wrapper */
lcidm_delete_conn_lcid_(QUIC_LCID * lcid_obj,void * arg)225 static void lcidm_delete_conn_lcid_(QUIC_LCID *lcid_obj, void *arg)
226 {
227     lcidm_delete_conn_lcid((QUIC_LCIDM *)arg, lcid_obj);
228 }
229 
/*
 * Destroy a connection record: remove and free every LCID it owns, remove
 * the record from the manager and free it.
 */
static void lcidm_delete_conn(QUIC_LCIDM *lcidm, QUIC_LCIDM_CONN *conn)
{
    LHASH_OF(QUIC_LCID) *owned = conn->lcids;

    /*
     * Entries are deleted from inside the doall walk, so table contraction
     * is disabled first; see the comment in ossl_quic_lcidm_free.
     */
    lh_QUIC_LCID_set_down_load(owned, 0);
    lh_QUIC_LCID_doall_arg(owned, lcidm_delete_conn_lcid_, lcidm);

    lh_QUIC_LCIDM_CONN_delete(lcidm->conns, conn);
    lh_QUIC_LCID_free(owned);
    OPENSSL_free(conn);
}
240 
lcidm_conn_new_lcid(QUIC_LCIDM * lcidm,QUIC_LCIDM_CONN * conn,const QUIC_CONN_ID * lcid)241 static QUIC_LCID *lcidm_conn_new_lcid(QUIC_LCIDM *lcidm, QUIC_LCIDM_CONN *conn,
242                                       const QUIC_CONN_ID *lcid)
243 {
244     QUIC_LCID *lcid_obj = NULL;
245 
246     if (lcid->id_len > QUIC_MAX_CONN_ID_LEN)
247         return NULL;
248 
249     if ((lcid_obj = OPENSSL_zalloc(sizeof(*lcid_obj))) == NULL)
250         goto err;
251 
252     lcid_obj->cid = *lcid;
253     lcid_obj->conn = conn;
254 
255     lh_QUIC_LCID_insert(conn->lcids, lcid_obj);
256     if (lh_QUIC_LCID_error(conn->lcids))
257         goto err;
258 
259     lh_QUIC_LCID_insert(lcidm->lcids, lcid_obj);
260     if (lh_QUIC_LCID_error(lcidm->lcids)) {
261         lh_QUIC_LCID_delete(conn->lcids, lcid_obj);
262         goto err;
263     }
264 
265     ++conn->num_active_lcid;
266     return lcid_obj;
267 
268 err:
269     OPENSSL_free(lcid_obj);
270     return NULL;
271 }
272 
ossl_quic_lcidm_get_lcid_len(const QUIC_LCIDM * lcidm)273 size_t ossl_quic_lcidm_get_lcid_len(const QUIC_LCIDM *lcidm)
274 {
275     return lcidm->lcid_len;
276 }
277 
ossl_quic_lcidm_get_num_active_lcid(const QUIC_LCIDM * lcidm,void * opaque)278 size_t ossl_quic_lcidm_get_num_active_lcid(const QUIC_LCIDM *lcidm,
279                                            void *opaque)
280 {
281     QUIC_LCIDM_CONN *conn;
282 
283     conn = lcidm_get0_conn(lcidm, opaque);
284     if (conn == NULL)
285         return 0;
286 
287     return conn->num_active_lcid;
288 }
289 
/*
 * Produce a candidate connection ID of lcidm->lcid_len bytes in *cid.
 * Returns 1 on success and 0 on failure.
 *
 * Normal builds ask ossl_quic_gen_rand_conn_id() for a random ID. Fuzzing
 * builds hand out the manager's next_lcid instead and then increment it, so
 * that the sequence of generated IDs is reproducible.
 */
static int lcidm_generate_cid(QUIC_LCIDM *lcidm,
                              QUIC_CONN_ID *cid)
{
#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
    size_t i;

    lcidm->next_lcid.id_len = (unsigned char)lcidm->lcid_len;
    *cid = lcidm->next_lcid;

    /*
     * Increment the ID as a big-endian counter: bump the last byte and
     * carry towards the front while a byte wraps to zero. The index counts
     * down as an unsigned value so that a zero lcid_len simply skips the
     * loop rather than relying on converting (size_t)-1 to int.
     */
    for (i = lcidm->lcid_len; i > 0; --i)
        if (++lcidm->next_lcid.id[i - 1] != 0)
            break;

    return 1;
#else
    return ossl_quic_gen_rand_conn_id(lcidm->libctx, lcidm->lcid_len, cid);
#endif
}
308 
/*
 * Generate a new LCID of the given LCID_TYPE_* for a connection, creating
 * the connection record if necessary, and register it under the
 * connection's next sequence number.
 *
 * On success returns 1, writes the new ID to *lcid_out and, if seq_num is
 * non-NULL, writes its sequence number to *seq_num. Returns 0 if:
 *
 * - the connection record cannot be obtained;
 * - type is LCID_TYPE_INITIAL but an LCID was already generated;
 * - the next sequence number exceeds OSSL_QUIC_VLINT_MAX;
 * - ID generation fails, or every one of MAX_RETRIES candidates collides
 *   with an already registered LCID;
 * - the new LCID cannot be registered.
 *
 * *lcid_out may have been overwritten even when 0 is returned.
 */
static int lcidm_generate(QUIC_LCIDM *lcidm,
                          void *opaque,
                          unsigned int type,
                          QUIC_CONN_ID *lcid_out,
                          uint64_t *seq_num)
{
    QUIC_LCIDM_CONN *conn;
    QUIC_LCID probe, *entry;
    size_t attempt;
#define MAX_RETRIES 8

    conn = lcidm_upsert_conn(lcidm, opaque);
    if (conn == NULL)
        return 0;

    /* The Initial LCID must be the first one generated for a connection. */
    if (type == LCID_TYPE_INITIAL && conn->next_seq_num > 0)
        return 0;

    if (conn->next_seq_num > OSSL_QUIC_VLINT_MAX)
        return 0;

    for (attempt = 0; ; ++attempt) {
        /*
         * Too many retries; should not happen but if it does, don't loop
         * endlessly.
         */
        if (attempt >= MAX_RETRIES)
            return 0;

        if (!lcidm_generate_cid(lcidm, lcid_out))
            return 0;

        /* Accept the candidate only if no registered LCID already uses it. */
        probe.cid = *lcid_out;
        if (lh_QUIC_LCID_retrieve(lcidm->lcids, &probe) == NULL)
            break;
    }

    entry = lcidm_conn_new_lcid(lcidm, conn, lcid_out);
    if (entry == NULL)
        return 0;

    entry->type    = type;
    entry->seq_num = conn->next_seq_num++;

    if (seq_num != NULL)
        *seq_num = entry->seq_num;

    return 1;
}
355 
/*
 * Register the peer's ODCID as an LCID of the connection identified by
 * opaque, creating the connection record if necessary.
 *
 * Returns 1 on success. Returns 0 if initial_odcid is NULL or its length is
 * outside [QUIC_MIN_ODCID_LEN, QUIC_MAX_CONN_ID_LEN], if the connection has
 * already had an ODCID enrolled, if the ID is already registered with the
 * manager, or if allocation fails.
 */
int ossl_quic_lcidm_enrol_odcid(QUIC_LCIDM *lcidm,
                                void *opaque,
                                const QUIC_CONN_ID *initial_odcid)
{
    QUIC_LCIDM_CONN *conn;
    QUIC_LCID *entry;

    if (initial_odcid == NULL)
        return 0;

    if (initial_odcid->id_len < QUIC_MIN_ODCID_LEN
        || initial_odcid->id_len > QUIC_MAX_CONN_ID_LEN)
        return 0;

    conn = lcidm_upsert_conn(lcidm, opaque);
    if (conn == NULL || conn->done_odcid)
        return 0;

    /* Refuse an ID which any connection has already registered. */
    if (lcidm_get0_lcid(lcidm, initial_odcid) != NULL)
        return 0;

    entry = lcidm_conn_new_lcid(lcidm, conn, initial_odcid);
    if (entry == NULL)
        return 0;

    entry->type     = LCID_TYPE_ODCID;
    entry->seq_num  = LCIDM_ODCID_SEQ_NUM;

    conn->done_odcid     = 1;
    conn->odcid_lcid_obj = entry;
    return 1;
}
387 
/*
 * Generate the Initial LCID for the connection identified by opaque and
 * write it to *initial_lcid. Returns 1 on success and 0 on failure; this
 * fails if an LCID has already been generated for the connection.
 */
int ossl_quic_lcidm_generate_initial(QUIC_LCIDM *lcidm,
                                     void *opaque,
                                     QUIC_CONN_ID *initial_lcid)
{
    /* The caller has no use for the sequence number here. */
    uint64_t *no_seq_num = NULL;

    return lcidm_generate(lcidm, opaque, LCID_TYPE_INITIAL,
                          initial_lcid, no_seq_num);
}
395 
/*
 * Generate a further LCID for the connection identified by opaque and
 * describe it in *ncid_frame: conn_id and seq_num are filled in from the
 * new LCID and retire_prior_to is set to 0. No other field of the frame is
 * written. Returns 1 on success and 0 on failure.
 */
int ossl_quic_lcidm_generate(QUIC_LCIDM *lcidm,
                             void *opaque,
                             OSSL_QUIC_FRAME_NEW_CONN_ID *ncid_frame)
{
    QUIC_CONN_ID *cid_out = &ncid_frame->conn_id;
    uint64_t *seq_out = &ncid_frame->seq_num;

    /* Start from known values so that a failed call leaves these zeroed. */
    ncid_frame->retire_prior_to = 0;
    *seq_out = 0;

    return lcidm_generate(lcidm, opaque, LCID_TYPE_NCID, cid_out, seq_out);
}
407 
/*
 * Remove the ODCID previously enrolled for the connection identified by
 * opaque.
 *
 * Returns 1 if an ODCID entry was removed. Returns 0 if the manager has no
 * record of the connection or the connection has no ODCID currently
 * registered. A connection record is never created by this call.
 */
int ossl_quic_lcidm_retire_odcid(QUIC_LCIDM *lcidm, void *opaque)
{
    QUIC_LCIDM_CONN *conn;

    /*
     * Look up only: there is nothing to retire for an unknown connection,
     * so there is no reason to allocate and register a record for it.
     */
    conn = lcidm_get0_conn(lcidm, opaque);
    if (conn == NULL || conn->odcid_lcid_obj == NULL)
        return 0;

    lcidm_delete_conn_lcid(lcidm, conn->odcid_lcid_obj);
    conn->odcid_lcid_obj = NULL;
    return 1;
}
422 
423 struct retire_args {
424     QUIC_LCID           *earliest_seq_num_lcid_obj;
425     uint64_t            earliest_seq_num, retire_prior_to;
426 };
427 
/*
 * doall_arg callback; arg is a struct retire_args. Records lcid_obj as the
 * best candidate if it is eligible for retirement and its sequence number
 * is lower than that of the best candidate seen so far.
 */
static void retire_for_conn(QUIC_LCID *lcid_obj, void *arg)
{
    struct retire_args *best = arg;
    int eligible;

    /* ODCID LCID cannot be retired via this API */
    eligible = lcid_obj->type != LCID_TYPE_ODCID
               && lcid_obj->seq_num < best->retire_prior_to;

    if (eligible && lcid_obj->seq_num < best->earliest_seq_num) {
        best->earliest_seq_num_lcid_obj = lcid_obj;
        best->earliest_seq_num          = lcid_obj->seq_num;
    }
}
442 
/*
 * Retire at most one LCID of the connection identified by opaque: the
 * non-ODCID LCID with the lowest sequence number below retire_prior_to.
 *
 * did_retire is mandatory; 0 is returned at once if it is NULL. Otherwise
 * *did_retire is set to 1 if an LCID was removed and to 0 if not. When an
 * LCID is removed, its ID and sequence number are written to *retired_lcid
 * and *retired_seq_num, each of which may be NULL.
 *
 * Returns 1 on success, including when the connection is unknown or no
 * LCID is eligible. Returns 0, removing nothing, if containing_pkt_dcid is
 * non-NULL and equals the LCID which would have been retired.
 */
int ossl_quic_lcidm_retire(QUIC_LCIDM *lcidm,
                           void *opaque,
                           uint64_t retire_prior_to,
                           const QUIC_CONN_ID *containing_pkt_dcid,
                           QUIC_CONN_ID *retired_lcid,
                           uint64_t *retired_seq_num,
                           int *did_retire)
{
    QUIC_LCIDM_CONN *conn;
    QUIC_LCID *victim;
    struct retire_args scan = {0};

    if (did_retire == NULL)
        return 0;

    *did_retire = 0;

    conn = lcidm_get0_conn(lcidm, opaque);
    if (conn == NULL)
        return 1;

    /* Find the eligible LCID with the lowest sequence number. */
    scan.earliest_seq_num   = UINT64_MAX;
    scan.retire_prior_to    = retire_prior_to;
    lh_QUIC_LCID_doall_arg(conn->lcids, retire_for_conn, &scan);

    victim = scan.earliest_seq_num_lcid_obj;
    if (victim == NULL)
        return 1;

    /* Refuse to retire the ID named by containing_pkt_dcid. */
    if (containing_pkt_dcid != NULL
        && ossl_quic_conn_id_eq(&victim->cid, containing_pkt_dcid))
        return 0;

    /* Report the details before the object is freed. */
    if (retired_seq_num != NULL)
        *retired_seq_num = victim->seq_num;
    if (retired_lcid != NULL)
        *retired_lcid = victim->cid;
    *did_retire = 1;

    lcidm_delete_conn_lcid(lcidm, victim);
    return 1;
}
484 
/*
 * Forget the connection identified by opaque, removing and freeing all of
 * its LCIDs. Returns 1 if a record existed and was removed, 0 if the
 * manager had no record of the connection.
 */
int ossl_quic_lcidm_cull(QUIC_LCIDM *lcidm, void *opaque)
{
    QUIC_LCIDM_CONN *rec = lcidm_get0_conn(lcidm, opaque);

    if (rec == NULL)
        return 0;

    lcidm_delete_conn(lcidm, rec);
    return 1;
}
497 
/*
 * Find which connection a connection ID is registered to.
 *
 * Returns 1 if lcid is registered and, for each of seq_num and opaque which
 * is non-NULL, writes the LCID's sequence number and its connection's
 * opaque pointer respectively. Returns 0, writing nothing, if lcid is NULL,
 * longer than QUIC_MAX_CONN_ID_LEN, or not registered.
 */
int ossl_quic_lcidm_lookup(QUIC_LCIDM *lcidm,
                           const QUIC_CONN_ID *lcid,
                           uint64_t *seq_num,
                           void **opaque)
{
    QUIC_LCID *hit;

    if (lcid == NULL)
        return 0;

    hit = lcidm_get0_lcid(lcidm, lcid);
    if (hit == NULL)
        return 0;

    if (opaque != NULL)
        *opaque = hit->conn->opaque;

    if (seq_num != NULL)
        *seq_num = hit->seq_num;

    return 1;
}
519 
/*
 * Debug facility: forcibly unregister an LCID of any type, whichever
 * connection owns it.
 *
 * Returns 1 if the LCID was registered and has been removed. Returns 0 if
 * lcid is NULL, longer than QUIC_MAX_CONN_ID_LEN, or not registered.
 */
int ossl_quic_lcidm_debug_remove(QUIC_LCIDM *lcidm,
                                 const QUIC_CONN_ID *lcid)
{
    QUIC_LCID *lcid_obj;

    if (lcid == NULL)
        return 0;

    /*
     * Go through the validating lookup helper: it rejects an id_len beyond
     * QUIC_MAX_CONN_ID_LEN, which would otherwise make the hash function
     * read past the end of the ID buffer.
     */
    if ((lcid_obj = lcidm_get0_lcid(lcidm, lcid)) == NULL)
        return 0;

    lcidm_delete_conn_lcid(lcidm, lcid_obj);
    return 1;
}
532 
/*
 * Debug facility: register a caller-chosen connection ID, with a
 * caller-chosen sequence number, as an LCID_TYPE_NCID LCID of the connection
 * identified by opaque. The connection record is created if necessary and
 * its next_seq_num is not touched.
 *
 * Returns 1 on success. Returns 0 if lcid is NULL or longer than
 * QUIC_MAX_CONN_ID_LEN, if the ID is already registered, or if allocation
 * fails.
 */
int ossl_quic_lcidm_debug_add(QUIC_LCIDM *lcidm, void *opaque,
                              const QUIC_CONN_ID *lcid,
                              uint64_t seq_num)
{
    QUIC_LCIDM_CONN *conn;
    QUIC_LCID *entry;

    if (lcid == NULL)
        return 0;

    if (lcid->id_len > QUIC_MAX_CONN_ID_LEN)
        return 0;

    conn = lcidm_upsert_conn(lcidm, opaque);
    if (conn == NULL)
        return 0;

    /* Refuse an ID which any connection has already registered. */
    if (lcidm_get0_lcid(lcidm, lcid) != NULL)
        return 0;

    entry = lcidm_conn_new_lcid(lcidm, conn, lcid);
    if (entry == NULL)
        return 0;

    entry->type    = LCID_TYPE_NCID;
    entry->seq_num = seq_num;
    return 1;
}
557