2004-06-23 Marcus Brinkmann <marcus@g10code.de>
[gpgme.git] / assuan / assuan-domain-connect.c
1 /* assuan-domain-connect.c - Assuan unix domain socket based client
2  *      Copyright (C) 2002, 2003 Free Software Foundation, Inc.
3  *
4  * This file is part of Assuan.
5  *
6  * Assuan is free software; you can redistribute it and/or modify it
7  * under the terms of the GNU Lesser General Public License as
8  * published by the Free Software Foundation; either version 2.1 of
9  * the License, or (at your option) any later version.
10  *
11  * Assuan is distributed in the hope that it will be useful, but
12  * WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this program; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA 
19  */
20
21 #ifdef HAVE_CONFIG_H
22 #include <config.h>
23 #endif
24
25 #include <stdlib.h>
26 #include <stddef.h>
27 #include <stdio.h>
28 #include <errno.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #if HAVE_SYS_UIO_H
33 #include <sys/uio.h>
34 #endif
35 #include <unistd.h>
36 #include <fcntl.h>
37 #include <string.h>
38 #include <assert.h>
39
40 #include "assuan-defs.h"
41
42 #define LOG(format, args...) \
43         fprintf (assuan_get_assuan_log_stream (), \
44                  assuan_get_assuan_log_prefix (), \
45                  "%s" format , ## args)
46
47
48 static void
49 do_deinit (ASSUAN_CONTEXT ctx)
50 {
51   if (ctx->inbound.fd != -1)
52     close (ctx->inbound.fd);
53   ctx->inbound.fd = -1;
54   ctx->outbound.fd = -1;
55
56   if (ctx->domainbuffer)
57     {
58       assert (ctx->domainbufferallocated);
59       free (ctx->domainbuffer);
60     }
61
62   if (ctx->pendingfds)
63     {
64       int i;
65
66       assert (ctx->pendingfdscount > 0);
67       for (i = 0; i < ctx->pendingfdscount; i ++)
68         close (ctx->pendingfds[i]);
69
70       free (ctx->pendingfds);
71     }
72
73   unlink (ctx->myaddr.sun_path);
74 }
75
76
77 /* Read from the socket server.  */
78 static ssize_t
79 domain_reader (ASSUAN_CONTEXT ctx, void *buf, size_t buflen)
80 {
81   int len = ctx->domainbuffersize;
82
83  start:
84   if (len == 0)
85     /* No data is buffered.  */
86     {
87       struct msghdr msg;
88       struct iovec iovec;
89       struct sockaddr_un sender;
90       struct
91       {
92         struct cmsghdr hdr;
93         int fd;
94       }
95       cmsg;
96
97       memset (&msg, 0, sizeof (msg));
98
99       for (;;)
100         {
101           msg.msg_name = &sender;
102           msg.msg_namelen = sizeof (struct sockaddr_un);
103           msg.msg_iov = &iovec;
104           msg.msg_iovlen = 1;
105           iovec.iov_base = ctx->domainbuffer;
106           iovec.iov_len = ctx->domainbufferallocated;
107           msg.msg_control = &cmsg;
108           msg.msg_controllen = sizeof cmsg;
109
110           /* Peek first: if the buffer we have is too small then it
111              will be truncated.  */
112           len = recvmsg (ctx->inbound.fd, &msg, MSG_PEEK);
113           if (len < 0)
114             {
115               printf ("domain_reader: %m\n");
116               return -1;
117             }
118
119           if (strcmp (ctx->serveraddr.sun_path,
120                       ((struct sockaddr_un *) msg.msg_name)->sun_path) != 0)
121             {
122               /* XXX: Arg.  Not from whom we expected!  What do we
123                  want to do?  Should we just ignore it?  Either way,
124                  we still need to consume the message.  */
125               break;
126             }
127
128           if (msg.msg_flags & MSG_TRUNC)
129             /* Enlarge the buffer and try again.  */
130             {
131               int size = ctx->domainbufferallocated;
132               void *tmp;
133
134               if (size == 0)
135                 size = 4 * 1024;
136               else
137                 size *= 2;
138
139               tmp = malloc (size);
140               if (! tmp)
141                 return -1;
142
143               free (ctx->domainbuffer);
144               ctx->domainbuffer = tmp;
145               ctx->domainbufferallocated = size;
146             }
147           else
148             /* We have enough space!  */
149             break;
150         }
151
152       /* Now we have to actually consume it (remember, we only
153          peeked).  */
154       msg.msg_name = &sender;
155       msg.msg_namelen = sizeof (struct sockaddr_un);
156       msg.msg_iov = &iovec;
157       msg.msg_iovlen = 1;
158       iovec.iov_base = ctx->domainbuffer;
159       iovec.iov_len = ctx->domainbufferallocated;
160       msg.msg_control = &cmsg;
161       msg.msg_controllen = sizeof cmsg;
162
163       if (strcmp (ctx->serveraddr.sun_path,
164                   ((struct sockaddr_un *) msg.msg_name)->sun_path) != 0)
165         {
166           /* XXX: Arg.  Not from whom we expected!  What do we want to
167              do?  Should we just ignore it?  We shall do the latter
168              for the moment.  */
169           LOG ("Not setup to receive messages from: `%s'.",
170                ((struct sockaddr_un *) msg.msg_name)->sun_path);
171           goto start;
172         }
173
174       len = recvmsg (ctx->inbound.fd, &msg, 0);
175       if (len < 0)
176         {
177           LOG ("domain_reader: %s\n", strerror (errno));
178           return -1;
179         }
180
181       ctx->domainbuffersize = len;
182       ctx->domainbufferoffset = 0;
183
184       if (sizeof (cmsg) == msg.msg_controllen)
185         /* We received a file descriptor.  */
186         {
187           void *tmp;
188
189           tmp = realloc (ctx->pendingfds,
190                          sizeof (int) * (ctx->pendingfdscount + 1));
191           if (! tmp)
192             {
193               LOG ("domain_reader: %s\n", strerror (errno));
194               return -1;
195             }
196
197           ctx->pendingfds = tmp;
198           ctx->pendingfds[ctx->pendingfdscount++]
199             = *(int *) CMSG_DATA (&cmsg.hdr);
200
201           LOG ("Received file descriptor %d from peer.\n",
202                ctx->pendingfds[ctx->pendingfdscount - 1]);
203         }
204
205       if (len == 0)
206         goto start;
207     }
208
209   /* Return some data to the user.  */
210
211   if (len > buflen)
212     /* We have more than the user requested.  */
213     len = buflen;
214
215   memcpy (buf, ctx->domainbuffer + ctx->domainbufferoffset, len);
216   ctx->domainbuffersize -= len;
217   assert (ctx->domainbuffersize >= 0);
218   ctx->domainbufferoffset += len;
219   assert (ctx->domainbufferoffset <= ctx->domainbufferallocated);
220
221   return len;
222 }
223
224 /* Write to the domain server.  */
225 static ssize_t
226 domain_writer (ASSUAN_CONTEXT ctx, const void *buf, size_t buflen)
227 {
228   struct msghdr msg;
229   struct iovec iovec;
230   ssize_t len;
231
232   memset (&msg, 0, sizeof (msg));
233
234   msg.msg_name = &ctx->serveraddr;
235   msg.msg_namelen = offsetof (struct sockaddr_un, sun_path)
236     + strlen (ctx->serveraddr.sun_path) + 1;
237
238   msg.msg_iovlen = 1;
239   msg.msg_iov = &iovec;
240   iovec.iov_base = (void *) buf;
241   iovec.iov_len = buflen;
242   msg.msg_control = 0;
243   msg.msg_controllen = 0;
244
245   len = sendmsg (ctx->outbound.fd, &msg, 0);
246   if (len < 0)
247     LOG ("domain_writer: %s\n", strerror (errno));
248
249   return len;
250 }
251
252 static AssuanError
253 domain_sendfd (ASSUAN_CONTEXT ctx, int fd)
254 {
255   struct msghdr msg;
256   struct
257   {
258     struct cmsghdr hdr;
259     int fd;
260   }
261   cmsg;
262   int len;
263
264   memset (&msg, 0, sizeof (msg));
265
266   msg.msg_name = &ctx->serveraddr;
267   msg.msg_namelen = offsetof (struct sockaddr_un, sun_path)
268     + strlen (ctx->serveraddr.sun_path) + 1;
269
270   msg.msg_iovlen = 0;
271   msg.msg_iov = 0;
272
273   cmsg.hdr.cmsg_level = SOL_SOCKET;
274   cmsg.hdr.cmsg_type = SCM_RIGHTS;
275   cmsg.hdr.cmsg_len = sizeof (cmsg);
276
277   msg.msg_control = &cmsg;
278   msg.msg_controllen = sizeof (cmsg);
279
280   *(int *) CMSG_DATA (&cmsg.hdr) = fd;
281
282   len = sendmsg (ctx->outbound.fd, &msg, 0);
283   if (len < 0)
284     {
285       LOG ("domain_sendfd: %s\n", strerror (errno));
286       return ASSUAN_General_Error;
287     }
288   else
289     return 0;
290 }
291
292 static AssuanError
293 domain_receivefd (ASSUAN_CONTEXT ctx, int *fd)
294 {
295   if (ctx->pendingfds == 0)
296     {
297       LOG ("No pending file descriptors!\n");
298       return ASSUAN_General_Error;
299     }
300
301   *fd = ctx->pendingfds[0];
302   if (-- ctx->pendingfdscount == 0)
303     {
304       free (ctx->pendingfds);
305       ctx->pendingfds = 0;
306     }
307   else
308     /* Fix the array.  */
309     {
310       memmove (ctx->pendingfds, ctx->pendingfds + 1,
311                ctx->pendingfdscount * sizeof (int));
312       ctx->pendingfds = realloc (ctx->pendingfds,
313                                  ctx->pendingfdscount * sizeof (int));
314     }
315
316   return 0;
317 }
318
319
320
321 /* Make a connection to the Unix domain socket NAME and return a new
322    Assuan context in CTX.  SERVER_PID is currently not used but may
323    become handy in the future.  */
324 AssuanError
325 _assuan_domain_init (ASSUAN_CONTEXT *r_ctx, int rendezvousfd, pid_t peer)
326 {
327   static struct assuan_io io = { domain_reader, domain_writer,
328                                  domain_sendfd, domain_receivefd };
329
330   AssuanError err;
331   ASSUAN_CONTEXT ctx;
332   int fd;
333   size_t len;
334   int tries;
335
336   if (!r_ctx)
337     return ASSUAN_Invalid_Value;
338   *r_ctx = NULL;
339
340   err = _assuan_new_context (&ctx); 
341   if (err)
342     return err;
343
344   /* Save it in case we need it later.  */
345   ctx->pid = peer;
346
347   /* Override the default (NOP) handlers.  */
348   ctx->deinit_handler = do_deinit;
349
350   /* Setup the socket.  */
351
352   fd = socket (PF_LOCAL, SOCK_DGRAM, 0);
353   if (fd == -1)
354     {
355       LOG ("can't create socket: %s\n", strerror (errno));
356       _assuan_release_context (ctx);
357       return ASSUAN_General_Error;
358     }
359
360   ctx->inbound.fd = fd;
361   ctx->outbound.fd = fd;
362
363   /* And the io buffers.  */
364
365   ctx->io = &io;
366   ctx->domainbuffer = 0;
367   ctx->domainbufferoffset = 0;
368   ctx->domainbuffersize = 0;
369   ctx->domainbufferallocated = 0;
370   ctx->pendingfds = 0;
371   ctx->pendingfdscount = 0;
372
373   /* Get usable name and bind to it.  */
374
375   for (tries = 0; tries < TMP_MAX; tries ++)
376     {
377       char *p;
378       char buf[L_tmpnam];
379
380       /* XXX: L_tmpnam must be shorter than sizeof (sun_path)!  */
381       assert (L_tmpnam < sizeof (ctx->myaddr.sun_path));
382
383       p = tmpnam (buf);
384       if (! p)
385         {
386           LOG ("cannot determine an appropriate temporary file "
387                "name.  DOS in progress?\n");
388           _assuan_release_context (ctx);
389           close (fd);
390           return ASSUAN_General_Error;
391         }
392
393       memset (&ctx->myaddr, 0, sizeof ctx->myaddr);
394       ctx->myaddr.sun_family = AF_LOCAL;
395       len = strlen (buf) + 1;
396       memcpy (ctx->myaddr.sun_path, buf, len);
397       len += offsetof (struct sockaddr_un, sun_path);
398
399       err = bind (fd, (struct sockaddr *) &ctx->myaddr, len);
400       if (! err)
401         break;
402     }
403
404   if (err)
405     {
406       LOG ("can't bind to `%s': %s\n", ctx->myaddr.sun_path,
407            strerror (errno));
408       _assuan_release_context (ctx);
409       close (fd);
410       return ASSUAN_Connect_Failed;
411     }
412
413   /* Rendezvous with our peer.  */
414   {
415     FILE *fp;
416     char *p;
417
418     fp = fdopen (rendezvousfd, "w+");
419     if (! fp)
420       {
421         LOG ("can't open rendezvous port: %s\n", strerror (errno));
422         return ASSUAN_Connect_Failed;
423       }
424
425     /* Send our address.  */
426     fprintf (fp, "%s\n", ctx->myaddr.sun_path);
427     fflush (fp);
428
429     /* And receive our peer's.  */
430     memset (&ctx->serveraddr, 0, sizeof ctx->serveraddr);
431     for (p = ctx->serveraddr.sun_path;
432          p < (ctx->serveraddr.sun_path
433               + sizeof ctx->serveraddr.sun_path - 1);
434          p ++)
435       {
436         *p = fgetc (fp);
437         if (*p == '\n')
438           break;
439       }
440     *p = '\0';
441     fclose (fp);
442
443     ctx->serveraddr.sun_family = AF_LOCAL;
444   }
445
446   *r_ctx = ctx;
447   return 0;
448 }
449
450 AssuanError
451 assuan_domain_connect (ASSUAN_CONTEXT * r_ctx, int rendezvousfd, pid_t peer)
452 {
453   AssuanError aerr;
454   int okay, off;
455
456   aerr = _assuan_domain_init (r_ctx, rendezvousfd, peer);
457   if (aerr)
458     return aerr;
459
460   /* Initial handshake.  */
461   aerr = _assuan_read_from_server (*r_ctx, &okay, &off);
462   if (aerr)
463     LOG ("can't connect to server: %s\n", assuan_strerror (aerr));
464   else if (okay != 1)
465     {
466       LOG ("can't connect to server: `");
467       _assuan_log_sanitized_string ((*r_ctx)->inbound.line);
468       fprintf (assuan_get_assuan_log_stream (), "'\n");
469       aerr = ASSUAN_Connect_Failed;
470     }
471
472   if (aerr)
473     assuan_disconnect (*r_ctx);
474
475   return aerr;
476 }