ea35ba041c31bafd9ecff834f576961ca9f82af9
[gpgme.git] / assuan / assuan-domain-connect.c
1 /* assuan-domain-connect.c - Assuan unix domain socket based client
2  *      Copyright (C) 2002, 2003 Free Software Foundation, Inc.
3  *
4  * This file is part of Assuan.
5  *
6  * Assuan is free software; you can redistribute it and/or modify it
7  * under the terms of the GNU Lesser General Public License as
8  * published by the Free Software Foundation; either version 2.1 of
9  * the License, or (at your option) any later version.
10  *
11  * Assuan is distributed in the hope that it will be useful, but
12  * WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this program; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA 
19  */
20
21 #ifdef HAVE_CONFIG_H
22 #include <config.h>
23 #endif
24
25 #include <stdlib.h>
26 #include <stddef.h>
27 #include <stdio.h>
28 #include <errno.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33 #include <fcntl.h>
34 #include <string.h>
35 #include <assert.h>
36
37 #include "assuan-defs.h"
38
39 #define LOG(format, args...) \
40         fprintf (assuan_get_assuan_log_stream (), \
41                  assuan_get_assuan_log_prefix (), \
42                  "%s" format , ## args)
43
44
45 static void
46 do_deinit (ASSUAN_CONTEXT ctx)
47 {
48   if (ctx->inbound.fd != -1)
49     close (ctx->inbound.fd);
50   ctx->inbound.fd = -1;
51   ctx->outbound.fd = -1;
52
53   if (ctx->domainbuffer)
54     {
55       assert (ctx->domainbufferallocated);
56       free (ctx->domainbuffer);
57     }
58
59   if (ctx->pendingfds)
60     {
61       int i;
62
63       assert (ctx->pendingfdscount > 0);
64       for (i = 0; i < ctx->pendingfdscount; i ++)
65         close (ctx->pendingfds[i]);
66
67       free (ctx->pendingfds);
68     }
69
70   unlink (ctx->myaddr.sun_path);
71 }
72
73
74 /* Read from the socket server.  */
75 static ssize_t
76 domain_reader (ASSUAN_CONTEXT ctx, void *buf, size_t buflen)
77 {
78   int len = ctx->domainbuffersize;
79
80  start:
81   if (len == 0)
82     /* No data is buffered.  */
83     {
84       struct msghdr msg;
85       struct iovec iovec;
86       struct sockaddr_un sender;
87       struct
88       {
89         struct cmsghdr hdr;
90         int fd;
91       }
92       cmsg;
93
94       memset (&msg, 0, sizeof (msg));
95
96       for (;;)
97         {
98           msg.msg_name = &sender;
99           msg.msg_namelen = sizeof (struct sockaddr_un);
100           msg.msg_iov = &iovec;
101           msg.msg_iovlen = 1;
102           iovec.iov_base = ctx->domainbuffer;
103           iovec.iov_len = ctx->domainbufferallocated;
104           msg.msg_control = &cmsg;
105           msg.msg_controllen = sizeof cmsg;
106
107           /* Peek first: if the buffer we have is too small then it
108              will be truncated.  */
109           len = recvmsg (ctx->inbound.fd, &msg, MSG_PEEK);
110           if (len < 0)
111             {
112               printf ("domain_reader: %m\n");
113               return -1;
114             }
115
116           if (strcmp (ctx->serveraddr.sun_path,
117                       ((struct sockaddr_un *) msg.msg_name)->sun_path) != 0)
118             {
119               /* XXX: Arg.  Not from whom we expected!  What do we
120                  want to do?  Should we just ignore it?  Either way,
121                  we still need to consume the message.  */
122               break;
123             }
124
125           if (msg.msg_flags & MSG_TRUNC)
126             /* Enlarge the buffer and try again.  */
127             {
128               int size = ctx->domainbufferallocated;
129               void *tmp;
130
131               if (size == 0)
132                 size = 4 * 1024;
133               else
134                 size *= 2;
135
136               tmp = malloc (size);
137               if (! tmp)
138                 return -1;
139
140               free (ctx->domainbuffer);
141               ctx->domainbuffer = tmp;
142               ctx->domainbufferallocated = size;
143             }
144           else
145             /* We have enough space!  */
146             break;
147         }
148
149       /* Now we have to actually consume it (remember, we only
150          peeked).  */
151       msg.msg_name = &sender;
152       msg.msg_namelen = sizeof (struct sockaddr_un);
153       msg.msg_iov = &iovec;
154       msg.msg_iovlen = 1;
155       iovec.iov_base = ctx->domainbuffer;
156       iovec.iov_len = ctx->domainbufferallocated;
157       msg.msg_control = &cmsg;
158       msg.msg_controllen = sizeof cmsg;
159
160       if (strcmp (ctx->serveraddr.sun_path,
161                   ((struct sockaddr_un *) msg.msg_name)->sun_path) != 0)
162         {
163           /* XXX: Arg.  Not from whom we expected!  What do we want to
164              do?  Should we just ignore it?  We shall do the latter
165              for the moment.  */
166           LOG ("Not setup to receive messages from: `%s'.",
167                ((struct sockaddr_un *) msg.msg_name)->sun_path);
168           goto start;
169         }
170
171       len = recvmsg (ctx->inbound.fd, &msg, 0);
172       if (len < 0)
173         {
174           LOG ("domain_reader: %s\n", strerror (errno));
175           return -1;
176         }
177
178       ctx->domainbuffersize = len;
179       ctx->domainbufferoffset = 0;
180
181       if (sizeof (cmsg) == msg.msg_controllen)
182         /* We received a file descriptor.  */
183         {
184           void *tmp;
185
186           tmp = realloc (ctx->pendingfds,
187                          sizeof (int) * (ctx->pendingfdscount + 1));
188           if (! tmp)
189             {
190               LOG ("domain_reader: %s\n", strerror (errno));
191               return -1;
192             }
193
194           ctx->pendingfds = tmp;
195           ctx->pendingfds[ctx->pendingfdscount++]
196             = *(int *) CMSG_DATA (&cmsg.hdr);
197
198           LOG ("Received file descriptor %d from peer.\n",
199                ctx->pendingfds[ctx->pendingfdscount - 1]);
200         }
201
202       if (len == 0)
203         goto start;
204     }
205
206   /* Return some data to the user.  */
207
208   if (len > buflen)
209     /* We have more than the user requested.  */
210     len = buflen;
211
212   memcpy (buf, ctx->domainbuffer + ctx->domainbufferoffset, len);
213   ctx->domainbuffersize -= len;
214   assert (ctx->domainbuffersize >= 0);
215   ctx->domainbufferoffset += len;
216   assert (ctx->domainbufferoffset <= ctx->domainbufferallocated);
217
218   return len;
219 }
220
221 /* Write to the domain server.  */
222 static ssize_t
223 domain_writer (ASSUAN_CONTEXT ctx, const void *buf, size_t buflen)
224 {
225   struct msghdr msg;
226   struct iovec iovec;
227   ssize_t len;
228
229   memset (&msg, 0, sizeof (msg));
230
231   msg.msg_name = &ctx->serveraddr;
232   msg.msg_namelen = offsetof (struct sockaddr_un, sun_path)
233     + strlen (ctx->serveraddr.sun_path) + 1;
234
235   msg.msg_iovlen = 1;
236   msg.msg_iov = &iovec;
237   iovec.iov_base = (void *) buf;
238   iovec.iov_len = buflen;
239   msg.msg_control = 0;
240   msg.msg_controllen = 0;
241
242   len = sendmsg (ctx->outbound.fd, &msg, 0);
243   if (len < 0)
244     LOG ("domain_writer: %s\n", strerror (errno));
245
246   return len;
247 }
248
249 static AssuanError
250 domain_sendfd (ASSUAN_CONTEXT ctx, int fd)
251 {
252   struct msghdr msg;
253   struct
254   {
255     struct cmsghdr hdr;
256     int fd;
257   }
258   cmsg;
259   int len;
260
261   memset (&msg, 0, sizeof (msg));
262
263   msg.msg_name = &ctx->serveraddr;
264   msg.msg_namelen = offsetof (struct sockaddr_un, sun_path)
265     + strlen (ctx->serveraddr.sun_path) + 1;
266
267   msg.msg_iovlen = 0;
268   msg.msg_iov = 0;
269
270   cmsg.hdr.cmsg_level = SOL_SOCKET;
271   cmsg.hdr.cmsg_type = SCM_RIGHTS;
272   cmsg.hdr.cmsg_len = sizeof (cmsg);
273
274   msg.msg_control = &cmsg;
275   msg.msg_controllen = sizeof (cmsg);
276
277   *(int *) CMSG_DATA (&cmsg.hdr) = fd;
278
279   len = sendmsg (ctx->outbound.fd, &msg, 0);
280   if (len < 0)
281     {
282       LOG ("domain_sendfd: %s\n", strerror (errno));
283       return ASSUAN_General_Error;
284     }
285   else
286     return 0;
287 }
288
289 static AssuanError
290 domain_receivefd (ASSUAN_CONTEXT ctx, int *fd)
291 {
292   if (ctx->pendingfds == 0)
293     {
294       LOG ("No pending file descriptors!\n");
295       return ASSUAN_General_Error;
296     }
297
298   *fd = ctx->pendingfds[0];
299   if (-- ctx->pendingfdscount == 0)
300     {
301       free (ctx->pendingfds);
302       ctx->pendingfds = 0;
303     }
304   else
305     /* Fix the array.  */
306     {
307       memmove (ctx->pendingfds, ctx->pendingfds + 1,
308                ctx->pendingfdscount * sizeof (int));
309       ctx->pendingfds = realloc (ctx->pendingfds,
310                                  ctx->pendingfdscount * sizeof (int));
311     }
312
313   return 0;
314 }
315
316
317
318 /* Make a connection to the Unix domain socket NAME and return a new
319    Assuan context in CTX.  SERVER_PID is currently not used but may
320    become handy in the future.  */
321 AssuanError
322 _assuan_domain_init (ASSUAN_CONTEXT *r_ctx, int rendezvousfd, pid_t peer)
323 {
324   static struct assuan_io io = { domain_reader, domain_writer,
325                                  domain_sendfd, domain_receivefd };
326
327   AssuanError err;
328   ASSUAN_CONTEXT ctx;
329   int fd;
330   size_t len;
331   int tries;
332
333   if (!r_ctx)
334     return ASSUAN_Invalid_Value;
335   *r_ctx = NULL;
336
337   err = _assuan_new_context (&ctx); 
338   if (err)
339     return err;
340
341   /* Save it in case we need it later.  */
342   ctx->pid = peer;
343
344   /* Override the default (NOP) handlers.  */
345   ctx->deinit_handler = do_deinit;
346
347   /* Setup the socket.  */
348
349   fd = socket (PF_LOCAL, SOCK_DGRAM, 0);
350   if (fd == -1)
351     {
352       LOG ("can't create socket: %s\n", strerror (errno));
353       _assuan_release_context (ctx);
354       return ASSUAN_General_Error;
355     }
356
357   ctx->inbound.fd = fd;
358   ctx->outbound.fd = fd;
359
360   /* And the io buffers.  */
361
362   ctx->io = &io;
363   ctx->domainbuffer = 0;
364   ctx->domainbufferoffset = 0;
365   ctx->domainbuffersize = 0;
366   ctx->domainbufferallocated = 0;
367   ctx->pendingfds = 0;
368   ctx->pendingfdscount = 0;
369
370   /* Get usable name and bind to it.  */
371
372   for (tries = 0; tries < TMP_MAX; tries ++)
373     {
374       char *p;
375       char buf[L_tmpnam];
376
377       /* XXX: L_tmpnam must be shorter than sizeof (sun_path)!  */
378       assert (L_tmpnam < sizeof (ctx->myaddr.sun_path));
379
380       p = tmpnam (buf);
381       if (! p)
382         {
383           LOG ("cannot determine an appropriate temporary file "
384                "name.  DOS in progress?\n");
385           _assuan_release_context (ctx);
386           close (fd);
387           return ASSUAN_General_Error;
388         }
389
390       memset (&ctx->myaddr, 0, sizeof ctx->myaddr);
391       ctx->myaddr.sun_family = AF_LOCAL;
392       len = strlen (buf) + 1;
393       memcpy (ctx->myaddr.sun_path, buf, len);
394       len += offsetof (struct sockaddr_un, sun_path);
395
396       err = bind (fd, (struct sockaddr *) &ctx->myaddr, len);
397       if (! err)
398         break;
399     }
400
401   if (err)
402     {
403       LOG ("can't bind to `%s': %s\n", ctx->myaddr.sun_path,
404            strerror (errno));
405       _assuan_release_context (ctx);
406       close (fd);
407       return ASSUAN_Connect_Failed;
408     }
409
410   /* Rendezvous with our peer.  */
411   {
412     FILE *fp;
413     char *p;
414
415     fp = fdopen (rendezvousfd, "w+");
416     if (! fp)
417       {
418         LOG ("can't open rendezvous port: %s\n", strerror (errno));
419         return ASSUAN_Connect_Failed;
420       }
421
422     /* Send our address.  */
423     fprintf (fp, "%s\n", ctx->myaddr.sun_path);
424     fflush (fp);
425
426     /* And receive our peer's.  */
427     memset (&ctx->serveraddr, 0, sizeof ctx->serveraddr);
428     for (p = ctx->serveraddr.sun_path;
429          p < (ctx->serveraddr.sun_path
430               + sizeof ctx->serveraddr.sun_path - 1);
431          p ++)
432       {
433         *p = fgetc (fp);
434         if (*p == '\n')
435           break;
436       }
437     *p = '\0';
438     fclose (fp);
439
440     ctx->serveraddr.sun_family = AF_LOCAL;
441   }
442
443   *r_ctx = ctx;
444   return 0;
445 }
446
447 AssuanError
448 assuan_domain_connect (ASSUAN_CONTEXT * r_ctx, int rendezvousfd, pid_t peer)
449 {
450   AssuanError aerr;
451   int okay, off;
452
453   aerr = _assuan_domain_init (r_ctx, rendezvousfd, peer);
454   if (aerr)
455     return aerr;
456
457   /* Initial handshake.  */
458   aerr = _assuan_read_from_server (*r_ctx, &okay, &off);
459   if (aerr)
460     LOG ("can't connect to server: %s\n", assuan_strerror (aerr));
461   else if (okay != 1)
462     {
463       LOG ("can't connect to server: `");
464       _assuan_log_sanitized_string ((*r_ctx)->inbound.line);
465       fprintf (assuan_get_assuan_log_stream (), "'\n");
466       aerr = ASSUAN_Connect_Failed;
467     }
468
469   if (aerr)
470     assuan_disconnect (*r_ctx);
471
472   return aerr;
473 }