2005-10-01 Marcus Brinkmann <marcus@g10code.de>
[gpgme.git] / assuan / assuan-domain-connect.c
1 /* assuan-domain-connect.c - Assuan unix domain socket based client
2  *      Copyright (C) 2002, 2003 Free Software Foundation, Inc.
3  *
4  * This file is part of Assuan.
5  *
6  * Assuan is free software; you can redistribute it and/or modify it
7  * under the terms of the GNU Lesser General Public License as
8  * published by the Free Software Foundation; either version 2.1 of
9  * the License, or (at your option) any later version.
10  *
11  * Assuan is distributed in the hope that it will be useful, but
12  * WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this program; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA 
19  */
20
21 #ifdef HAVE_CONFIG_H
22 #include <config.h>
23 #endif
24
25 #include <stdlib.h>
26 #include <stddef.h>
27 #include <stdio.h>
28 #include <errno.h>
29 #include <sys/types.h>
30 #ifndef HAVE_W32_SYSTEM
31 #include <sys/socket.h>
32 #include <sys/un.h>
33 #else
34 #include <windows.h>
35 #endif
36 #if HAVE_SYS_UIO_H
37 #include <sys/uio.h>
38 #endif
39 #include <unistd.h>
40 #include <fcntl.h>
41 #include <string.h>
42 #include <assert.h>
43
44 #include "assuan-defs.h"
45
46 #ifndef PF_LOCAL
47 # ifdef PF_UNIX
48 #  define PF_LOCAL PF_UNIX
49 # else
50 #  define PF_LOCAL AF_UNIX
51 # endif
52 # ifndef AF_LOCAL
53 #  define AF_LOCAL AF_UNIX
54 # endif
55 #endif
56
57
58 static void
59 do_deinit (assuan_context_t ctx)
60 {
61   if (ctx->inbound.fd != -1)
62     _assuan_close (ctx->inbound.fd);
63   ctx->inbound.fd = -1;
64   ctx->outbound.fd = -1;
65
66   if (ctx->domainbuffer)
67     {
68       assert (ctx->domainbufferallocated);
69       free (ctx->domainbuffer);
70     }
71
72   if (ctx->pendingfds)
73     {
74       int i;
75
76       assert (ctx->pendingfdscount > 0);
77       for (i = 0; i < ctx->pendingfdscount; i ++)
78         _assuan_close (ctx->pendingfds[i]);
79
80       free (ctx->pendingfds);
81     }
82
83   unlink (ctx->myaddr.sun_path);
84 }
85
86
87 /* Read from the socket server.  */
88 static ssize_t
89 domain_reader (assuan_context_t ctx, void *buf, size_t buflen)
90 {
91   int len = ctx->domainbuffersize;
92
93 #ifndef HAVE_W32_SYSTEM
94  start:
95   if (len == 0)
96     /* No data is buffered.  */
97     {
98       struct msghdr msg;
99       struct iovec iovec;
100       struct sockaddr_un sender;
101       struct
102       {
103         struct cmsghdr hdr;
104         int fd;
105       }
106       cmsg;
107
108       memset (&msg, 0, sizeof (msg));
109
110       for (;;)
111         {
112           msg.msg_name = &sender;
113           msg.msg_namelen = sizeof (struct sockaddr_un);
114           msg.msg_iov = &iovec;
115           msg.msg_iovlen = 1;
116           iovec.iov_base = ctx->domainbuffer;
117           iovec.iov_len = ctx->domainbufferallocated;
118           msg.msg_control = &cmsg;
119           msg.msg_controllen = sizeof cmsg;
120
121           /* Peek first: if the buffer we have is too small then it
122              will be truncated.  */
123           len = recvmsg (ctx->inbound.fd, &msg, MSG_PEEK);
124           if (len < 0)
125             {
126               printf ("domain_reader: %m\n");
127               return -1;
128             }
129
130           if (strcmp (ctx->serveraddr.sun_path,
131                       ((struct sockaddr_un *) msg.msg_name)->sun_path) != 0)
132             {
133               /* XXX: Arg.  Not from whom we expected!  What do we
134                  want to do?  Should we just ignore it?  Either way,
135                  we still need to consume the message.  */
136               break;
137             }
138
139           if (msg.msg_flags & MSG_TRUNC)
140             /* Enlarge the buffer and try again.  */
141             {
142               int size = ctx->domainbufferallocated;
143               void *tmp;
144
145               if (size == 0)
146                 size = 4 * 1024;
147               else
148                 size *= 2;
149
150               tmp = malloc (size);
151               if (! tmp)
152                 return -1;
153
154               free (ctx->domainbuffer);
155               ctx->domainbuffer = tmp;
156               ctx->domainbufferallocated = size;
157             }
158           else
159             /* We have enough space!  */
160             break;
161         }
162
163       /* Now we have to actually consume it (remember, we only
164          peeked).  */
165       msg.msg_name = &sender;
166       msg.msg_namelen = sizeof (struct sockaddr_un);
167       msg.msg_iov = &iovec;
168       msg.msg_iovlen = 1;
169       iovec.iov_base = ctx->domainbuffer;
170       iovec.iov_len = ctx->domainbufferallocated;
171       msg.msg_control = &cmsg;
172       msg.msg_controllen = sizeof cmsg;
173
174       if (strcmp (ctx->serveraddr.sun_path,
175                   ((struct sockaddr_un *) msg.msg_name)->sun_path) != 0)
176         {
177           /* XXX: Arg.  Not from whom we expected!  What do we want to
178              do?  Should we just ignore it?  We shall do the latter
179              for the moment.  */
180           _assuan_log_printf ("not setup to receive messages from `%s'\n",
181                               ((struct sockaddr_un *) msg.msg_name)->sun_path);
182           goto start;
183         }
184
185       len = recvmsg (ctx->inbound.fd, &msg, 0);
186       if (len < 0)
187         {
188           _assuan_log_printf ("domain_reader: %s\n", strerror (errno));
189           return -1;
190         }
191
192       ctx->domainbuffersize = len;
193       ctx->domainbufferoffset = 0;
194
195       if (sizeof (cmsg) == msg.msg_controllen)
196         /* We received a file descriptor.  */
197         {
198           void *tmp;
199
200           tmp = realloc (ctx->pendingfds,
201                          sizeof (int) * (ctx->pendingfdscount + 1));
202           if (! tmp)
203             {
204               _assuan_log_printf ("domain_reader: %s\n", strerror (errno));
205               return -1;
206             }
207
208           ctx->pendingfds = tmp;
209           ctx->pendingfds[ctx->pendingfdscount++]
210             = *(int *) CMSG_DATA (&cmsg.hdr);
211
212           _assuan_log_printf ("received file descriptor %d from peer\n",
213                ctx->pendingfds[ctx->pendingfdscount - 1]);
214         }
215
216       if (len == 0)
217         goto start;
218     }
219 #else
220   len = recvfrom (ctx->inbound.fd, buf, buflen, 0, NULL, NULL);
221 #endif
222
223   /* Return some data to the user.  */
224
225   if (len > buflen)
226     /* We have more than the user requested.  */
227     len = buflen;
228
229   memcpy (buf, ctx->domainbuffer + ctx->domainbufferoffset, len);
230   ctx->domainbuffersize -= len;
231   assert (ctx->domainbuffersize >= 0);
232   ctx->domainbufferoffset += len;
233   assert (ctx->domainbufferoffset <= ctx->domainbufferallocated);
234
235   return len;
236 }
237
238 /* Write to the domain server.  */
239 static ssize_t
240 domain_writer (assuan_context_t ctx, const void *buf, size_t buflen)
241 {
242 #ifndef HAVE_W32_SYSTEM
243   struct msghdr msg;
244   struct iovec iovec;
245   ssize_t len;
246
247   memset (&msg, 0, sizeof (msg));
248
249   msg.msg_name = &ctx->serveraddr;
250   msg.msg_namelen = offsetof (struct sockaddr_un, sun_path)
251     + strlen (ctx->serveraddr.sun_path) + 1;
252
253   msg.msg_iovlen = 1;
254   msg.msg_iov = &iovec;
255   iovec.iov_base = (void *) buf;
256   iovec.iov_len = buflen;
257   msg.msg_control = 0;
258   msg.msg_controllen = 0;
259
260   len = sendmsg (ctx->outbound.fd, &msg, 0);
261   if (len < 0)
262     _assuan_log_printf ("domain_writer: %s\n", strerror (errno));
263 #else
264   int len;
265   
266   len = sendto (ctx->outbound.fd, buf, buflen, 0,
267                 (struct sockaddr *)&ctx->serveraddr,
268                 sizeof (struct sockaddr_in));
269 #endif  
270   return len;
271 }
272
273 static assuan_error_t
274 domain_sendfd (assuan_context_t ctx, int fd)
275 {
276 #ifndef HAVE_W32_SYSTEM
277   struct msghdr msg;
278   struct
279   {
280     struct cmsghdr hdr;
281     int fd;
282   }
283   cmsg;
284   int len;
285
286   memset (&msg, 0, sizeof (msg));
287
288   msg.msg_name = &ctx->serveraddr;
289   msg.msg_namelen = offsetof (struct sockaddr_un, sun_path)
290     + strlen (ctx->serveraddr.sun_path) + 1;
291
292   msg.msg_iovlen = 0;
293   msg.msg_iov = 0;
294
295   cmsg.hdr.cmsg_level = SOL_SOCKET;
296   cmsg.hdr.cmsg_type = SCM_RIGHTS;
297   cmsg.hdr.cmsg_len = sizeof (cmsg);
298
299   msg.msg_control = &cmsg;
300   msg.msg_controllen = sizeof (cmsg);
301
302   *(int *) CMSG_DATA (&cmsg.hdr) = fd;
303
304   len = sendmsg (ctx->outbound.fd, &msg, 0);
305   if (len < 0)
306     {
307       _assuan_log_printf ("domain_sendfd: %s\n", strerror (errno));
308       return ASSUAN_General_Error;
309     }
310   else
311     return 0;
312 #else
313   return 0;
314 #endif
315 }
316
317 static assuan_error_t
318 domain_receivefd (assuan_context_t ctx, int *fd)
319 {
320 #ifndef HAVE_W32_SYSTEM
321   if (ctx->pendingfds == 0)
322     {
323       _assuan_log_printf ("no pending file descriptors!\n");
324       return ASSUAN_General_Error;
325     }
326
327   *fd = ctx->pendingfds[0];
328   if (-- ctx->pendingfdscount == 0)
329     {
330       free (ctx->pendingfds);
331       ctx->pendingfds = 0;
332     }
333   else
334     /* Fix the array.  */
335     {
336       memmove (ctx->pendingfds, ctx->pendingfds + 1,
337                ctx->pendingfdscount * sizeof (int));
338       ctx->pendingfds = realloc (ctx->pendingfds,
339                                  ctx->pendingfdscount * sizeof (int));
340     }
341 #endif
342   return 0;
343 }
344
345
346
347 /* Make a connection to the Unix domain socket NAME and return a new
348    Assuan context in CTX.  SERVER_PID is currently not used but may
349    become handy in the future.  */
350 assuan_error_t
351 _assuan_domain_init (assuan_context_t *r_ctx, int rendezvousfd, pid_t peer)
352 {
353   static struct assuan_io io = { domain_reader, domain_writer,
354                                  domain_sendfd, domain_receivefd };
355
356   assuan_error_t err;
357   assuan_context_t ctx;
358   int fd;
359   size_t len;
360   int tries;
361
362   if (!r_ctx)
363     return ASSUAN_Invalid_Value;
364   *r_ctx = NULL;
365
366   err = _assuan_new_context (&ctx); 
367   if (err)
368     return err;
369
370   /* Save it in case we need it later.  */
371   ctx->pid = peer;
372
373   /* Override the default (NOP) handlers.  */
374   ctx->deinit_handler = do_deinit;
375
376   /* Setup the socket.  */
377
378   fd = _assuan_sock_new (PF_LOCAL, SOCK_DGRAM, 0);
379   if (fd == -1)
380     {
381       _assuan_log_printf ("can't create socket: %s\n", strerror (errno));
382       _assuan_release_context (ctx);
383       return ASSUAN_General_Error;
384     }
385   
386   ctx->inbound.fd = fd;
387   ctx->outbound.fd = fd;
388   
389   /* And the io buffers.  */
390
391   ctx->io = &io;
392   ctx->domainbuffer = 0;
393   ctx->domainbufferoffset = 0;
394   ctx->domainbuffersize = 0;
395   ctx->domainbufferallocated = 0;
396   ctx->pendingfds = 0;
397   ctx->pendingfdscount = 0;
398
399   /* Get usable name and bind to it.  */
400
401   for (tries = 0; tries < TMP_MAX; tries ++)
402     {
403       char *p;
404       char buf[L_tmpnam];
405
406       /* XXX: L_tmpnam must be shorter than sizeof (sun_path)!  */
407       assert (L_tmpnam < sizeof (ctx->myaddr.sun_path));
408
409       /* XXX: W32 tmpnam is broken */
410       p = tmpnam (buf);
411       if (! p)
412         {
413           _assuan_log_printf ("cannot determine an appropriate temporary file "
414             "name.  DoS in progress?\n");
415           _assuan_release_context (ctx);
416           _assuan_close (fd);
417           return ASSUAN_General_Error;
418         }
419
420       memset (&ctx->myaddr, 0, sizeof ctx->myaddr);
421       ctx->myaddr.sun_family = AF_LOCAL;
422       len = strlen (buf) + 1;
423       memcpy (ctx->myaddr.sun_path, buf, len);
424       len += offsetof (struct sockaddr_un, sun_path);
425
426       err = _assuan_sock_bind (fd, (struct sockaddr *) &ctx->myaddr, len);
427       if (! err)
428         break;
429     }
430
431   if (err)
432     {
433       _assuan_log_printf ("can't bind to `%s': %s\n", ctx->myaddr.sun_path,
434            strerror (errno));
435       _assuan_release_context (ctx);
436       _assuan_close (fd);
437       return ASSUAN_Connect_Failed;
438     }
439
440   /* Rendezvous with our peer.  */
441   {
442     FILE *fp;
443     char *p;
444
445     fp = fdopen (rendezvousfd, "w+");
446     if (! fp)
447       {
448         _assuan_log_printf ("can't open rendezvous port: %s\n", strerror (errno));
449         return ASSUAN_Connect_Failed;
450       }
451
452     /* Send our address.  */
453     fprintf (fp, "%s\n", ctx->myaddr.sun_path);
454     fflush (fp);
455
456     /* And receive our peer's.  */
457     memset (&ctx->serveraddr, 0, sizeof ctx->serveraddr);
458     for (p = ctx->serveraddr.sun_path;
459          p < (ctx->serveraddr.sun_path
460               + sizeof ctx->serveraddr.sun_path - 1);
461          p ++)
462       {
463         *p = fgetc (fp);
464         if (*p == '\n')
465           break;
466       }
467     *p = '\0';
468     fclose (fp);
469
470     ctx->serveraddr.sun_family = AF_LOCAL;
471   }
472
473   *r_ctx = ctx;
474   return 0;
475 }
476
477 assuan_error_t
478 assuan_domain_connect (assuan_context_t * r_ctx, int rendezvousfd, pid_t peer)
479 {
480   assuan_error_t aerr;
481   int okay, off;
482
483   aerr = _assuan_domain_init (r_ctx, rendezvousfd, peer);
484   if (aerr)
485     return aerr;
486
487   /* Initial handshake.  */
488   aerr = _assuan_read_from_server (*r_ctx, &okay, &off);
489   if (aerr)
490     _assuan_log_printf ("can't connect to server: %s\n",
491                         assuan_strerror (aerr));
492   else if (okay != 1)
493     {
494       _assuan_log_printf ("can't connect to server: `");
495       _assuan_log_sanitized_string ((*r_ctx)->inbound.line);
496       fprintf (assuan_get_assuan_log_stream (), "'\n");
497       aerr = ASSUAN_Connect_Failed;
498     }
499
500   if (aerr)
501     assuan_disconnect (*r_ctx);
502
503   return aerr;
504 }