Make ntbtls-cli work for W32.
[ntbtls.git] / src / ntbtls-cli.c
1 /* ntbtls-cli.h - NTBTLS client test program
2  * Copyright (C) 2014 g10 Code GmbH
3  *
4  * This file is part of NTBTLS
5  *
6  * NTBTLS is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * NTBTLS is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program; if not, see <http://www.gnu.org/licenses/>.
18  */
19
20 #include <config.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string.h>
24 #include <stdarg.h>
25
26 #include <unistd.h>
27 #include <errno.h>
28 #include <sys/types.h>
29 #ifdef HAVE_W32_SYSTEM
30 # define WIN32_LEAN_AND_MEAN
31 # ifdef HAVE_WINSOCK2_H
32 #  include <winsock2.h>
33 # endif
34 # include <windows.h>
35 #else
36 # include <sys/socket.h>
37 # include <netinet/in.h>
38 # include <arpa/inet.h>
39 # include <netdb.h>
40 #endif
41
42 #include "ntbtls.h"
43
44 #define PGMNAME "ntbtls-cli"
45
46 static int verbose;
47 static int errorcount;
48 static char *opt_hostname;
49 static int opt_head;
50
51
52 \f
53 /*
54  * Reporting functions.
55  */
56 static void
57 die (const char *format, ...)
58 {
59   va_list arg_ptr ;
60
61   fflush (stdout);
62 #ifdef HAVE_FLOCKFILE
63   flockfile (stderr);
64 #endif
65   fprintf (stderr, "%s: ", PGMNAME);
66   va_start (arg_ptr, format) ;
67   vfprintf (stderr, format, arg_ptr);
68   va_end (arg_ptr);
69   if (*format && format[strlen(format)-1] != '\n')
70     putc ('\n', stderr);
71 #ifdef HAVE_FLOCKFILE
72   funlockfile (stderr);
73 #endif
74   exit (1);
75 }
76
77
78 static void
79 fail (const char *format, ...)
80 {
81   va_list arg_ptr;
82
83   fflush (stdout);
84 #ifdef HAVE_FLOCKFILE
85   flockfile (stderr);
86 #endif
87   fprintf (stderr, "%s: ", PGMNAME);
88   va_start (arg_ptr, format);
89   vfprintf (stderr, format, arg_ptr);
90   va_end (arg_ptr);
91   if (*format && format[strlen(format)-1] != '\n')
92     putc ('\n', stderr);
93 #ifdef HAVE_FLOCKFILE
94   funlockfile (stderr);
95 #endif
96   errorcount++;
97   if (errorcount >= 50)
98     die ("stopped after 50 errors.");
99 }
100
101
102 static void
103 info (const char *format, ...)
104 {
105   va_list arg_ptr;
106
107   if (!verbose)
108     return;
109 #ifdef HAVE_FLOCKFILE
110   flockfile (stderr);
111 #endif
112   fprintf (stderr, "%s: ", PGMNAME);
113   va_start (arg_ptr, format);
114   vfprintf (stderr, format, arg_ptr);
115   if (*format && format[strlen(format)-1] != '\n')
116     putc ('\n', stderr);
117   va_end (arg_ptr);
118 #ifdef HAVE_FLOCKFILE
119   funlockfile (stderr);
120 #endif
121 }
122
123
124 \f
125 /* Until we support send/recv in estream we need to use es_fopencookie
126  * under Windows.  */
127 #ifdef HAVE_W32_SYSTEM
128 static gpgrt_ssize_t
129 w32_cookie_read (void *cookie, void *buffer, size_t size)
130 {
131   int sock = (int)cookie;
132   int nread;
133
134   do
135     {
136       /* Under Windows we need to use recv for a socket.  */
137       nread = recv (sock, buffer, size, 0);
138     }
139   while (nread == -1 && errno == EINTR);
140
141   return (gpgrt_ssize_t)nread;
142 }
143
144 static gpg_error_t
145 w32_write_server (int sock, const char *data, size_t length)
146 {
147   int nleft;
148   int nwritten;
149
150   nleft = length;
151   while (nleft > 0)
152     {
153       nwritten = send (sock, data, nleft, 0);
154       if ( nwritten == SOCKET_ERROR )
155         {
156           info ("network write failed: ec=%d\n", (int)WSAGetLastError ());
157           return gpg_error (GPG_ERR_NETWORK);
158         }
159       nleft -= nwritten;
160       data += nwritten;
161     }
162
163   return 0;
164 }
165
166 /* Write handler for estream.  */
167 static gpgrt_ssize_t
168 w32_cookie_write (void *cookie, const void *buffer_arg, size_t size)
169 {
170   int sock = (int)cookie;
171   const char *buffer = buffer_arg;
172   int nwritten = 0;
173
174   if (w32_write_server (sock, buffer, size))
175     {
176       gpg_err_set_errno (EIO);
177       nwritten = -1;
178     }
179   else
180     nwritten = size;
181
182   return (gpgrt_ssize_t)nwritten;
183 }
184
185 static es_cookie_io_functions_t w32_cookie_functions =
186   {
187     w32_cookie_read,
188     w32_cookie_write,
189     NULL,
190     NULL
191   };
192 #endif /*HAVE_W32_SYSTEM*/
193
194
195 \f
196 static int
197 connect_server (const char *server, unsigned short port)
198 {
199   gpg_error_t err;
200   int sock = -1;
201   struct sockaddr_in addr;
202   struct hostent *host;
203
204   addr.sin_family = AF_INET;
205   addr.sin_port = htons (port);
206   host = gethostbyname ((char*)server);
207   if (!host)
208     {
209       err = gpg_error_from_syserror ();
210       fail ("host '%s' not found: %s\n", server, gpg_strerror (err));
211       return -1;
212     }
213
214   addr.sin_addr = *(struct in_addr*)host->h_addr;
215
216   sock = socket (AF_INET, SOCK_STREAM, 0);
217   if (sock == -1)
218     {
219       err = gpg_error_from_syserror ();
220       die ("error creating socket: %s\n", gpg_strerror (err));
221       return -1;
222     }
223
224   if (connect (sock, (struct sockaddr *)&addr, sizeof addr) == -1)
225     {
226       err = gpg_error_from_syserror ();
227       fail ("error connecting '%s': %s\n", server, gpg_strerror (err));
228       close (sock);
229       return -1;
230     }
231
232   info ("connected to '%s' port %hu\n", server, port);
233
234   return sock;
235 }
236
237
238 static int
239 connect_estreams (const char *server, int port,
240                   estream_t *r_in, estream_t *r_out)
241 {
242   gpg_error_t err;
243   int sock;
244
245   *r_in = *r_out = NULL;
246
247   sock = connect_server (server, port);
248   if (sock == -1)
249     return gpg_error (GPG_ERR_GENERAL);
250
251 #ifdef HAVE_W32_SYSTEM
252   *r_in = es_fopencookie ((void*)(unsigned int)sock, "rb",
253                           w32_cookie_functions);
254 #else
255   *r_in = es_fdopen (sock, "rb");
256 #endif
257   if (!*r_in)
258     {
259       err = gpg_error_from_syserror ();
260       close (sock);
261       return err;
262     }
263 #ifdef HAVE_W32_SYSTEM
264   *r_out = es_fopencookie ((void*)(unsigned int)sock, "wb",
265                            w32_cookie_functions);
266 #else
267   *r_out = es_fdopen (sock, "wb");
268 #endif
269   if (!*r_out)
270     {
271       err = gpg_error_from_syserror ();
272       es_fclose (*r_in);
273       *r_in = NULL;
274       close (sock);
275       return err;
276     }
277
278   return 0;
279 }
280
281
282 \f
283 static void
284 simple_client (const char *server, int port)
285 {
286   gpg_error_t err;
287   ntbtls_t tls;
288   estream_t inbound, outbound;
289   estream_t readfp, writefp;
290   int c;
291
292   err = ntbtls_new (&tls, NTBTLS_CLIENT);
293   if (err)
294     die ("ntbtls_init failed: %s <%s>\n",
295          gpg_strerror (err), gpg_strsource (err));
296
297   err = connect_estreams (server, port, &inbound, &outbound);
298   if (err)
299     die ("error connecting server: %s <%s>\n",
300          gpg_strerror (err), gpg_strsource (err));
301
302   err = ntbtls_set_transport (tls, inbound, outbound);
303   if (err)
304     die ("ntbtls_set_transport failed: %s <%s>\n",
305          gpg_strerror (err), gpg_strsource (err));
306
307   err = ntbtls_get_stream (tls, &readfp, &writefp);
308   if (err)
309     die ("ntbtls_get_stream failed: %s <%s>\n",
310          gpg_strerror (err), gpg_strsource (err));
311
312   if (opt_hostname)
313     {
314       err = ntbtls_set_hostname (tls, opt_hostname);
315       if (err)
316         die ("ntbtls_set_hostname failed: %s <%s>\n",
317              gpg_strerror (err), gpg_strsource (err));
318     }
319
320   info ("starting handshake");
321   while ((err = ntbtls_handshake (tls)))
322     {
323       info ("handshake error: %s <%s>", gpg_strerror (err),gpg_strsource (err));
324       switch (gpg_err_code (err))
325         {
326         default:
327           break;
328         }
329       die ("handshake failed");
330     }
331   info ("handshake done");
332
333   do
334     {
335       es_fprintf (writefp, "%s / HTTP/1.0\r\n", opt_head? "HEAD":"GET");
336       if (opt_hostname)
337         es_fprintf (writefp, "Host: %s\r\n", opt_hostname);
338       es_fprintf (writefp, "X-ntbtls: %s\r\n",
339                   ntbtls_check_version (PACKAGE_VERSION));
340       es_fputs ("\r\n", writefp);
341       es_fflush (writefp);
342       while (/*es_pending (readfp) &&*/ (c = es_fgetc (readfp)) != EOF)
343         putchar (c);
344     }
345   while (c != EOF);
346
347   ntbtls_release (tls);
348   es_fclose (inbound);
349   es_fclose (outbound);
350 }
351
352
353
354 int
355 main (int argc, char **argv)
356 {
357   int last_argc = -1;
358   int debug_level = 0;
359   int port = 443;
360   char *host;
361
362   if (argc)
363     { argc--; argv++; }
364   while (argc && last_argc != argc )
365     {
366       last_argc = argc;
367       if (!strcmp (*argv, "--"))
368         {
369           argc--; argv++;
370           break;
371         }
372       else if (!strcmp (*argv, "--help"))
373         {
374           fputs ("Usage: " PGMNAME " [OPTIONS] HOST\n"
375                  "Connect via TLS to HOST\n"
376                  "Options:\n"
377                  "  --version       print the library version\n"
378                  "  --verbose       show more diagnostics\n"
379                  "  --debug LEVEL   enable debugging at LEVEL\n"
380                  "  --port N        connect to port N (default is 443)\n"
381                  "  --hostname NAME use NAME instead of HOST for SNI\n"
382                  "  --head          send a HEAD and not a GET request\n"
383                  "\n", stdout);
384           return 0;
385         }
386       else if (!strcmp (*argv, "--version"))
387         {
388           printf ("%s\n", ntbtls_check_version (NULL));
389           if (verbose)
390             printf ("%s", ntbtls_check_version ("\001\001"));
391           return 0;
392         }
393       else if (!strcmp (*argv, "--verbose"))
394         {
395           verbose = 1;
396           argc--; argv++;
397         }
398       else if (!strcmp (*argv, "--debug"))
399         {
400           verbose = 1;
401           argc--; argv++;
402           if (argc)
403             {
404               debug_level = atoi (*argv);
405               argc--; argv++;
406             }
407           else
408             debug_level = 1;
409         }
410       else if (!strcmp (*argv, "--port"))
411         {
412           argc--; argv++;
413           if (argc)
414             {
415               port = atoi (*argv);
416               argc--; argv++;
417             }
418           else
419             port = 8443;
420         }
421       else if (!strcmp (*argv, "--hostname"))
422         {
423           if (argc < 2)
424             die ("argument missing for option '%s'\n", *argv);
425           argc--; argv++;
426           opt_hostname = *argv;
427           argc--; argv++;
428         }
429       else if (!strcmp (*argv, "--head"))
430         {
431           opt_head = 1;
432           argc--; argv++;
433         }
434       else if (!strncmp (*argv, "--", 2) && (*argv)[2])
435         die ("Invalid option '%s'\n", *argv);
436     }
437
438   host = argc? *argv : "localhost";
439   if (!opt_hostname)
440     opt_hostname = host;
441   if (!*opt_hostname)
442     opt_hostname = NULL;
443
444 #ifdef HAVE_W32_SYSTEM
445   {
446     WSADATA wsadat;
447     WSAStartup (0x202, &wsadat);
448   }
449 #endif
450
451   if (!ntbtls_check_version (PACKAGE_VERSION))
452     die ("NTBTLS library too old (need %s, have %s)\n",
453          PACKAGE_VERSION, ntbtls_check_version (NULL));
454
455   if (debug_level)
456     ntbtls_set_debug (debug_level, NULL, NULL);
457
458   simple_client (host, port);
459   return 0;
460 }