xref: /openssl/doc/designs/ddd/ddd-06-mem-uv.c (revision ec36534c)
1 #include <sys/poll.h>
2 #include <openssl/ssl.h>
3 #include <uv.h>
4 #include <assert.h>
5 
6 typedef struct app_conn_st APP_CONN;
7 typedef struct upper_write_op_st UPPER_WRITE_OP;
8 typedef struct lower_write_op_st LOWER_WRITE_OP;
9 
10 typedef void (app_connect_cb)(APP_CONN *conn, int status, void *arg);
11 typedef void (app_write_cb)(APP_CONN *conn, int status, void *arg);
12 typedef void (app_read_cb)(APP_CONN *conn, void *buf, size_t buf_len, void *arg);
13 
14 static void tcp_connect_done(uv_connect_t *tcp_connect, int status);
15 static void net_connect_fail_close_done(uv_handle_t *handle);
16 static int handshake_ssl(APP_CONN *conn);
17 static void flush_write_buf(APP_CONN *conn);
18 static void set_rx(APP_CONN *conn);
19 static int try_write(APP_CONN *conn, UPPER_WRITE_OP *op);
20 static void handle_pending_writes(APP_CONN *conn);
21 static int write_deferred(APP_CONN *conn, const void *buf, size_t buf_len, app_write_cb *cb, void *arg);
22 static void teardown_continued(uv_handle_t *handle);
23 static int setup_ssl(APP_CONN *conn, const char *hostname);
24 
25 /*
26  * Structure to track an application-level write request. Only created
27  * if SSL_write does not accept the data immediately, typically because
28  * it is in WANT_READ.
29  */
30 struct upper_write_op_st {
31     struct upper_write_op_st   *prev, *next;
32     const uint8_t              *buf;
33     size_t                      buf_len, written;
34     APP_CONN                   *conn;
35     app_write_cb               *cb;
36     void                       *cb_arg;
37 };
38 
39 /*
40  * Structure to track a network-level write request.
41  */
42 struct lower_write_op_st {
43     uv_write_t      w;
44     uv_buf_t        b;
45     uint8_t        *buf;
46     APP_CONN       *conn;
47 };
48 
49 /*
50  * Application connection object.
51  */
52 struct app_conn_st {
53     SSL_CTX        *ctx;
54     SSL            *ssl;
55     BIO            *net_bio;
56     uv_stream_t    *stream;
57     uv_tcp_t        tcp;
58     uv_connect_t    tcp_connect;
59     app_connect_cb *app_connect_cb;   /* called once handshake is done */
60     void           *app_connect_arg;
61     app_read_cb    *app_read_cb;      /* application's on-RX callback */
62     void           *app_read_arg;
63     const char     *hostname;
64     char            init_handshake, done_handshake, closed;
65     char           *teardown_done;
66 
67     UPPER_WRITE_OP *pending_upper_write_head, *pending_upper_write_tail;
68 };
69 
70 /*
71  * The application is initializing and wants an SSL_CTX which it will use for
72  * some number of outgoing connections, which it creates in subsequent calls to
73  * new_conn. The application may also call this function multiple times to
74  * create multiple SSL_CTX.
75  */
create_ssl_ctx(void)76 SSL_CTX *create_ssl_ctx(void)
77 {
78     SSL_CTX *ctx;
79 
80     ctx = SSL_CTX_new(TLS_client_method());
81     if (ctx == NULL)
82         return NULL;
83 
84     /* Enable trust chain verification. */
85     SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL);
86 
87     /* Load default root CA store. */
88     if (SSL_CTX_set_default_verify_paths(ctx) == 0) {
89         SSL_CTX_free(ctx);
90         return NULL;
91     }
92 
93     return ctx;
94 }
95 
96 /*
97  * The application wants to create a new outgoing connection using a given
98  * SSL_CTX. An outgoing TCP connection is started and the callback is called
99  * asynchronously when the TLS handshake is complete.
100  *
101  * hostname is a string like "openssl.org" used for certificate validation.
102  */
103 
new_conn(SSL_CTX * ctx,const char * hostname,struct sockaddr * sa,socklen_t sa_len,app_connect_cb * cb,void * arg)104 APP_CONN *new_conn(SSL_CTX *ctx, const char *hostname,
105                    struct sockaddr *sa, socklen_t sa_len,
106                    app_connect_cb *cb, void *arg)
107 {
108     int rc;
109     APP_CONN *conn = NULL;
110 
111     conn = calloc(1, sizeof(APP_CONN));
112     if (!conn)
113         return NULL;
114 
115     uv_tcp_init(uv_default_loop(), &conn->tcp);
116     conn->tcp.data = conn;
117 
118     conn->stream            = (uv_stream_t *)&conn->tcp;
119     conn->app_connect_cb    = cb;
120     conn->app_connect_arg   = arg;
121     conn->tcp_connect.data  = conn;
122     rc = uv_tcp_connect(&conn->tcp_connect, &conn->tcp, sa, tcp_connect_done);
123     if (rc < 0) {
124         uv_close((uv_handle_t *)&conn->tcp, net_connect_fail_close_done);
125         return NULL;
126     }
127 
128     conn->ctx       = ctx;
129     conn->hostname  = hostname;
130     return conn;
131 }
132 
133 /*
134  * The application wants to start reading from the SSL stream.
135  * The callback is called whenever data is available.
136  */
app_read_start(APP_CONN * conn,app_read_cb * cb,void * arg)137 int app_read_start(APP_CONN *conn, app_read_cb *cb, void *arg)
138 {
139     conn->app_read_cb  = cb;
140     conn->app_read_arg = arg;
141     set_rx(conn);
142     return 0;
143 }
144 
145 /*
146  * The application wants to write. The callback is called once the
147  * write is complete. The callback should free the buffer.
148  */
app_write(APP_CONN * conn,const void * buf,size_t buf_len,app_write_cb * cb,void * arg)149 int app_write(APP_CONN *conn, const void *buf, size_t buf_len, app_write_cb *cb, void *arg)
150 {
151     write_deferred(conn, buf, buf_len, cb, arg);
152     handle_pending_writes(conn);
153     return buf_len;
154 }
155 
156 /*
157  * The application wants to close the connection and free bookkeeping
158  * structures.
159  */
teardown(APP_CONN * conn)160 void teardown(APP_CONN *conn)
161 {
162     char teardown_done = 0;
163 
164     if (conn == NULL)
165         return;
166 
167     BIO_free_all(conn->net_bio);
168     SSL_free(conn->ssl);
169 
170     uv_cancel((uv_req_t *)&conn->tcp_connect);
171 
172     conn->teardown_done = &teardown_done;
173     uv_close((uv_handle_t *)conn->stream, teardown_continued);
174 
175     /* Just wait synchronously until teardown completes. */
176     while (!teardown_done)
177         uv_run(uv_default_loop(), UV_RUN_DEFAULT);
178 }
179 
180 /*
181  * The application is shutting down and wants to free a previously
182  * created SSL_CTX.
183  */
teardown_ctx(SSL_CTX * ctx)184 void teardown_ctx(SSL_CTX *ctx)
185 {
186     SSL_CTX_free(ctx);
187 }
188 
189 /*
190  * ============================================================================
191  * Internal implementation functions.
192  */
/*
 * Append op to the tail of conn's pending write queue. op->next is not
 * touched, so the caller must pass an op whose next pointer is NULL.
 */
static void enqueue_upper_write_op(APP_CONN *conn, UPPER_WRITE_OP *op)
{
    UPPER_WRITE_OP *old_tail = conn->pending_upper_write_tail;

    op->prev = old_tail;
    if (old_tail != NULL)
        old_tail->next = op;

    conn->pending_upper_write_tail = op;

    /* First element: it is the head as well as the tail. */
    if (!conn->pending_upper_write_head)
        conn->pending_upper_write_head = op;
}
203 
/*
 * Unlink the head of conn's pending write queue. The removed op is not
 * freed here; that is left to the caller. Does nothing on an empty queue.
 */
static void dequeue_upper_write_op(APP_CONN *conn)
{
    UPPER_WRITE_OP *head = conn->pending_upper_write_head;
    UPPER_WRITE_OP *successor;

    if (!head)
        return;

    successor = head->next;
    conn->pending_upper_write_head = successor;

    if (successor != NULL)
        successor->prev = NULL;
    else
        conn->pending_upper_write_tail = NULL; /* queue is now empty */
}
217 
/*
 * libuv allocation callback: supply a buffer of suggested_size bytes for
 * the next socket read. The buffer is freed in net_read_done().
 *
 * If the allocation fails the buffer is reported as empty (NULL base and
 * zero length) rather than as a NULL base with a non-zero length.
 */
static void net_read_alloc(uv_handle_t *handle,
                           size_t suggested_size, uv_buf_t *buf)
{
    (void)handle;

    buf->base = malloc(suggested_size);
    buf->len  = buf->base != NULL ? suggested_size : 0;
}
224 
/*
 * Deliver decrypted application data to the read callback.
 *
 * SSL_read is called repeatedly until it stops returning data, so every
 * record already sitting in the BIO pair is handed over, not just the
 * first one. Each chunk is passed to the callback in a freshly allocated
 * buffer which the callback then owns and must free.
 *
 * When SSL_read fails for any reason other than needing more I/O, the
 * callback is invoked once with a NULL buffer and a length of zero. The
 * buffer allocated for that attempt is freed here and is never passed on.
 *
 * Does nothing if no read callback is registered.
 */
static void on_rx_push(APP_CONN *conn)
{
    const int buf_len = 4096;
    int srd, err;
    void *buf;

    for (;;) {
        /* The callback may have been changed by the previous iteration. */
        if (conn->app_read_cb == NULL)
            return;

        buf = malloc(buf_len);
        if (buf == NULL)
            return;

        srd = SSL_read(conn->ssl, buf, buf_len);

        /*
         * Ask for the error reason immediately, before any other OpenSSL
         * call is made on this connection.
         */
        err = srd > 0 ? SSL_ERROR_NONE : SSL_get_error(conn->ssl, srd);

        /* Reading can make the TLS layer emit data (e.g. alerts). */
        flush_write_buf(conn);

        if (srd <= 0) {
            free(buf);

            /* Nothing more to read right now; wait for more input. */
            if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
                return;

            conn->app_read_cb(conn, NULL, 0, conn->app_read_arg);
            return;
        }

        conn->app_read_cb(conn, buf, (size_t)srd, conn->app_read_arg);
    }
}
250 
/*
 * Handle a failed socket read: mark the connection closed, let set_rx()
 * stop reading, and tell the application (if it registered a read
 * callback) by calling it with a NULL buffer and zero length.
 */
static void net_error(APP_CONN *conn)
{
    app_read_cb *notify;

    conn->closed = 1;
    set_rx(conn);

    notify = conn->app_read_cb;
    if (notify != NULL)
        notify(conn, NULL, 0, conn->app_read_arg);
}
259 
/*
 * Work through the pending write queue from the head. Each op that
 * try_write() reports as finished is unlinked and freed; the first op that
 * cannot finish yet stops the pass and stays at the head. RX state is then
 * re-evaluated. Returns immediately if the queue is empty.
 *
 * NOTE(review): try_write() invokes the op's completion callback while the
 * op is still at the head of the queue, so a callback that calls
 * app_write() re-enters this function with that same op still queued.
 */
static void handle_pending_writes(APP_CONN *conn)
{
    UPPER_WRITE_OP *op = conn->pending_upper_write_head;

    if (op == NULL)
        return;

    while (op != NULL) {
        if (try_write(conn, op) <= 0)
            break;

        dequeue_upper_write_op(conn);
        free(op);
        op = conn->pending_upper_write_head;
    }

    set_rx(conn);
}
279 
/*
 * libuv read callback. A negative nr is a read failure and is routed to
 * net_error(). Otherwise any received bytes are fed into the network side
 * of the BIO pair, the read buffer is freed, and the TLS state machine is
 * driven: the handshake first if it is not finished, then pending writes,
 * then delivery of received application data.
 *
 * NOTE(review): the BIO is expected to accept all nr bytes in one call; a
 * short write trips the assert. Confirm the BIO pair's capacity always
 * covers the largest read libuv can deliver.
 */
static void net_read_done(uv_stream_t *stream, ssize_t nr, const uv_buf_t *buf)
{
    APP_CONN *conn = (APP_CONN *)stream->data;
    char *data = buf->base;

    if (nr < 0) {
        free(data);
        net_error(conn);
        return;
    }

    if (nr != 0) {
        int pushed = BIO_write(conn->net_bio, data, nr);
        assert(pushed == nr);
    }

    free(data);

    if (!conn->done_handshake) {
        int hs = handshake_ssl(conn);

        if (hs < 0) {
            fprintf(stderr, "handshake error: %d\n", hs);
            return;
        }

        /* Still mid-handshake: no application traffic yet. */
        if (!conn->done_handshake)
            return;
    }

    handle_pending_writes(conn);
    on_rx_push(conn);
}
312 
/*
 * Start or stop reading from the socket to match the connection's state.
 * Reading is on when the connection is not closed and at least one of
 * these holds: the application registered a read callback, a handshake has
 * been started but not finished, or writes are queued.
 */
static void set_rx(APP_CONN *conn)
{
    int handshaking, wanted;

    if (conn->closed) {
        uv_read_stop(conn->stream);
        return;
    }

    handshaking = conn->init_handshake && !conn->done_handshake;
    wanted = conn->app_read_cb != NULL
             || handshaking
             || conn->pending_upper_write_head != NULL;

    if (wanted)
        uv_read_start(conn->stream, net_read_alloc, net_read_done);
    else
        uv_read_stop(conn->stream);
}
320 
/*
 * libuv write-completion callback for a LOWER_WRITE_OP. The op and its
 * buffer are released whether or not the write succeeded. After a
 * successful write, the next chunk of pending TLS output (if any) is sent.
 */
static void net_write_done(uv_write_t *req, int status)
{
    LOWER_WRITE_OP *op = (LOWER_WRITE_OP *)req->data;
    APP_CONN *conn = op->conn;

    /* libuv is finished with the request on every path, so free it first. */
    free(op->buf);
    free(op);

    if (status < 0) {
        fprintf(stderr, "UV write failed %d\n", status);
        return;
    }

    flush_write_buf(conn);
}
336 
/*
 * Move up to one chunk of TLS output from the network side of the BIO pair
 * to the socket. Anything beyond one chunk is picked up when the write
 * completes, because net_write_done() calls back in here.
 *
 * Both allocations are made before anything is taken out of the BIO, so an
 * out-of-memory condition cannot discard bytes that were already removed
 * from the TLS output stream. Every failure path releases both the buffer
 * and the op.
 */
static void flush_write_buf(APP_CONN *conn)
{
    enum { FLUSH_CHUNK = 4096 };
    int rc, rd;
    LOWER_WRITE_OP *op;
    uint8_t *buf;

    buf = malloc(FLUSH_CHUNK);
    op  = calloc(1, sizeof(*op));
    if (buf == NULL || op == NULL)
        goto discard;

    rd = BIO_read(conn->net_bio, buf, FLUSH_CHUNK);
    if (rd <= 0)
        goto discard;   /* nothing waiting to be sent */

    op->buf     = buf;
    op->conn    = conn;
    op->w.data  = op;
    op->b.base  = (char *)buf;
    op->b.len   = rd;

    rc = uv_write(&op->w, conn->stream, &op->b, 1, net_write_done);
    if (rc < 0) {
        fprintf(stderr, "UV write failed\n");
        goto discard;
    }

    /* Ownership of op and buf has passed to net_write_done(). */
    return;

discard:
    free(op);
    free(buf);
}
371 
/*
 * Tell the application that the TLS handshake succeeded by calling its
 * connect callback with status 0. A connection created without a callback
 * is left alone, matching how the read and write callbacks are treated.
 */
static void handshake_done_ssl(APP_CONN *conn)
{
    if (conn->app_connect_cb != NULL)
        conn->app_connect_cb(conn, 0, conn->app_connect_arg);
}
376 
/*
 * Advance the TLS handshake by one step.
 *
 * Whatever the handshake step produced is pushed to the socket before
 * anything else happens. This includes the final flight on the step that
 * completes the handshake, so the peer is never left waiting on the
 * application's first write.
 *
 * Returns 0 if the handshake completed (the connect callback has then been
 * called) or if it needs more input from the peer. Any other outcome is
 * logged and returned as the negated SSL error code.
 */
static int handshake_ssl(APP_CONN *conn)
{
    int rc, rcx;

    conn->init_handshake = 1;

    rc = SSL_do_handshake(conn->ssl);

    /*
     * Fetch the error reason straight away, before flush_write_buf() makes
     * further OpenSSL calls.
     */
    rcx = rc > 0 ? SSL_ERROR_NONE : SSL_get_error(conn->ssl, rc);

    flush_write_buf(conn);

    if (rc > 0) {
        conn->done_handshake = 1;
        handshake_done_ssl(conn);
        set_rx(conn);
        return 0;
    }

    if (rcx == SSL_ERROR_WANT_READ) {
        set_rx(conn);
        return 0;
    }

    fprintf(stderr, "Handshake error: %d\n", rcx);
    return -rcx;
}
401 
/*
 * Create the client SSL object for conn, attach it to one half of a new
 * BIO pair, configure hostname for certificate checking and SNI, and start
 * the handshake. The other half of the pair is kept in conn->net_bio for
 * moving bytes to and from the socket.
 *
 * Returns the result of handshake_ssl() on success, or -1 if any setup
 * step fails. On failure nothing is stored in conn and both the SSL object
 * and the BIOs are released.
 */
static int setup_ssl(APP_CONN *conn, const char *hostname)
{
    BIO *internal_bio = NULL, *net_bio = NULL;
    SSL *ssl;

    ssl = SSL_new(conn->ctx);
    if (ssl == NULL)
        return -1;

    SSL_set_connect_state(ssl);

    if (BIO_new_bio_pair(&internal_bio, 0, &net_bio, 0) <= 0)
        goto err;

    /* From here on internal_bio belongs to ssl and goes away with it. */
    SSL_set_bio(ssl, internal_bio, internal_bio);

    if (SSL_set1_host(ssl, hostname) <= 0)
        goto err;

    if (SSL_set_tlsext_host_name(ssl, hostname) <= 0)
        goto err;

    conn->net_bio             = net_bio;
    conn->ssl                 = ssl;
    return handshake_ssl(conn);

err:
    /* net_bio is not owned by ssl, so it has to be released separately. */
    BIO_free_all(net_bio);
    SSL_free(ssl);
    return -1;
}
434 
/*
 * libuv connect callback. Once the TCP connection is up, TLS is set up and
 * the handshake started. If the connect failed, or TLS could not be set
 * up, the default loop is stopped.
 */
static void tcp_connect_done(uv_connect_t *tcp_connect, int status)
{
    APP_CONN *conn = (APP_CONN *)tcp_connect->data;

    if (status >= 0) {
        if (setup_ssl(conn, conn->hostname) >= 0)
            return;

        fprintf(stderr, "cannot init SSL\n");
    }

    uv_stop(uv_default_loop());
}
452 
/*
 * Close callback used when new_conn() could not start the TCP connect:
 * frees the connection object stored in the handle's data pointer.
 */
static void net_connect_fail_close_done(uv_handle_t *handle)
{
    free(handle->data);
}
459 
/*
 * Try to push the remainder of op through SSL_write.
 *
 * Returns 0 if the write has to wait for data from the peer; progress is
 * saved in op->written and op stays queued. Returns 1 once the op is
 * finished, either completely written or failed, after its callback (if
 * any) has been called with 0 or the negated SSL error code. The caller
 * then dequeues and frees op.
 *
 * SSL_ERROR_WANT_WRITE means the BIO pair has no room left for TLS output.
 * That is not a failure: some of the pending output is moved to the socket
 * and the same SSL_write call is repeated. It only counts as an error if
 * flushing released no space, which would otherwise loop forever.
 */
static int try_write(APP_CONN *conn, UPPER_WRITE_OP *op)
{
    int rc, rcx;
    size_t written = op->written;
    size_t queued;

    while (written < op->buf_len) {
        rc = SSL_write(conn->ssl, op->buf + written, op->buf_len - written);
        if (rc > 0) {
            written += rc;
            continue;
        }

        rcx = SSL_get_error(conn->ssl, rc);
        if (rcx == SSL_ERROR_WANT_READ) {
            op->written = written;
            return 0;
        }

        if (rcx == SSL_ERROR_WANT_WRITE) {
            queued = BIO_ctrl_pending(conn->net_bio);
            flush_write_buf(conn);

            /* Retry with identical arguments now that there is room. */
            if (BIO_ctrl_pending(conn->net_bio) < queued)
                continue;
        }

        if (op->cb != NULL)
            op->cb(conn, -rcx, op->cb_arg);
        return 1; /* op should be freed */
    }

    if (op->cb != NULL)
        op->cb(conn, 0, op->cb_arg);

    flush_write_buf(conn);
    return 1; /* op should be freed */
}
488 
/*
 * Record an application write as an UPPER_WRITE_OP at the tail of conn's
 * pending queue, then refresh RX state and flush any TLS output that is
 * already waiting. buf is referenced, not copied.
 *
 * Returns buf_len, or -1 if the op could not be allocated.
 */
static int write_deferred(APP_CONN *conn, const void *buf, size_t buf_len, app_write_cb *cb, void *arg)
{
    UPPER_WRITE_OP *pending = calloc(1, sizeof(UPPER_WRITE_OP));

    if (pending == NULL)
        return -1;

    /* prev, next and written start out zeroed thanks to calloc. */
    pending->conn    = conn;
    pending->buf     = buf;
    pending->buf_len = buf_len;
    pending->cb      = cb;
    pending->cb_arg  = arg;

    enqueue_upper_write_op(conn, pending);
    set_rx(conn);
    flush_write_buf(conn);
    return buf_len;
}
506 
/*
 * Close callback installed by teardown(): frees every write op still
 * queued, frees the connection object, and finally sets the flag that
 * teardown() is waiting on.
 */
static void teardown_continued(uv_handle_t *handle)
{
    APP_CONN *conn = (APP_CONN *)handle->data;
    /* Copy the pointer out first; conn is gone by the time it is used. */
    char *done_flag = conn->teardown_done;
    UPPER_WRITE_OP *cur = conn->pending_upper_write_head;

    while (cur != NULL) {
        UPPER_WRITE_OP *following = cur->next;

        free(cur);
        cur = following;
    }

    free(conn);
    *done_flag = 1;
}
521 
522 /*
523  * ============================================================================
524  * Example driver for the above code. This is just to demonstrate that the code
525  * works and is not intended to be representative of a real application.
526  */
/*
 * Demo read callback: copies received data to stdout and frees the buffer.
 * A zero-length delivery ends the demo by stopping the loop.
 */
static void post_read(APP_CONN *conn, void *buf, size_t buf_len, void *arg)
{
    if (buf_len > 0) {
        fwrite(buf, 1, buf_len, stdout);
        free(buf);
        return;
    }

    free(buf);
    uv_stop(uv_default_loop());
}
538 
/*
 * Demo write callback: once the request has gone out, start reading the
 * response. A failed write is only reported on stderr.
 */
static void post_write_get(APP_CONN *conn, int status, void *arg)
{
    if (status >= 0) {
        app_read_start(conn, post_read, NULL);
        return;
    }

    fprintf(stderr, "write failed: %d\n", status);
}
548 
/*
 * Demo connect callback: on a successful handshake, send a fixed HTTP
 * request. A failed connect is reported and stops the loop.
 *
 * The request text has static storage duration because app_write() keeps
 * the pointer, not a copy, until the write has completed, which can be
 * after this function has returned.
 */
static void post_connect(APP_CONN *conn, int status, void *arg)
{
    static const char tx_msg[] = "GET / HTTP/1.0\r\nHost: www.openssl.org\r\n\r\n";
    const size_t tx_len = sizeof(tx_msg) - 1;
    int wr;

    (void)arg;

    if (status < 0) {
        fprintf(stderr, "failed to connect: %d\n", status);
        uv_stop(uv_default_loop());
        return;
    }

    wr = app_write(conn, tx_msg, tx_len, post_write_get, NULL);

    /*
     * Test the sign first: comparing a negative int against a size_t
     * converts it to a huge unsigned value and hides the failure.
     */
    if (wr < 0 || (size_t)wr < tx_len) {
        fprintf(stderr, "error writing request");
        return;
    }
}
566 
/*
 * Demo driver: resolve www.openssl.org, connect over TLS, issue one HTTP
 * request, print the response, then tear everything down.
 *
 * Returns 0 if the event loop ran, and 1 if setup failed before that.
 */
int main(int argc, char **argv)
{
    int rc = 1, gai;
    SSL_CTX *ctx;
    APP_CONN *conn = NULL;
    struct addrinfo hints = {0}, *result = NULL;

    (void)argc;
    (void)argv;

    ctx = create_ssl_ctx();
    if (!ctx)
        goto fail;

    hints.ai_family     = AF_INET;
    hints.ai_socktype   = SOCK_STREAM;
    hints.ai_flags      = AI_PASSIVE;

    /* getaddrinfo reports failure with any non-zero value. */
    gai = getaddrinfo("www.openssl.org", "443", &hints, &result);
    if (gai != 0 || result == NULL) {
        fprintf(stderr, "cannot resolve\n");
        goto fail;
    }

    conn = new_conn(ctx, "www.openssl.org", result->ai_addr, result->ai_addrlen, post_connect, NULL);
    if (!conn)
        goto fail;

    uv_run(uv_default_loop(), UV_RUN_DEFAULT);

    rc = 0;
fail:
    teardown(conn);

    /* result is still NULL if we bailed out before or during resolution. */
    if (result != NULL)
        freeaddrinfo(result);

    uv_loop_close(uv_default_loop());
    teardown_ctx(ctx);
    return rc;
}
600