Handle Proxy Protocol v2 safely as well.
[users/heiko/exim.git] / test / src / server.c
index 0d6e5fe907f7fc44ac3b254749cc1449265a559b..ce55c5c377ac716be3f875b7c683eed0fb56d9b5 100644 (file)
@@ -28,7 +28,7 @@ on all interfaces, unless the option -noipv6 is given. */
 #include <netinet/ip.h>
 
 #ifdef HAVE_NETINET_IP_VAR_H
-#include <netinet/ip_var.h>
+# include <netinet/ip_var.h>
 #endif
 
 #include <netdb.h>
@@ -52,25 +52,32 @@ on all interfaces, unless the option -noipv6 is given. */
 
 #ifndef CS
 # define CS (char *)
+# define CCS (const char *)
 #endif
 
 
 typedef struct line {
   struct line *next;
+  unsigned len;
   char line[1];
 } line;
 
+typedef unsigned BOOL;
+#define FALSE 0
+#define TRUE  1
+
 
 /*************************************************
 *            SIGALRM handler - crash out         *
 *************************************************/
+int tmo_noerror = 0;
 
 static void
 sigalrm_handler(int sig)
 {
 sig = sig;    /* Keep picky compilers happy */
 printf("\nServer timed out\n");
-exit(99);
+exit(tmo_noerror ? 0 : 99);
 }
 
 
@@ -123,6 +130,25 @@ return buffer;
 }
 
 
+
+static void
+printit(char * s, int n)
+{
+while(n--)
+  {
+  unsigned char c = *s++;
+  if (c == '\\')
+    printf("\\\\");
+  else if (c >= ' ' && c <= '~')       /* assumes ascii */
+    putchar(c);
+  else
+    printf("\\x%02x", c);
+  }
+putchar('\n');
+}
+
+
+
 /*************************************************
 *                 Main Program                   *
 *************************************************/
@@ -143,6 +169,7 @@ int connection_count = 1;
 int count;
 int on = 1;
 int timeout = 5;
+int initial_pause = 0;
 int use_ipv4 = 1;
 int use_ipv6 = 1;
 int debug = 0;
@@ -151,6 +178,8 @@ line *script = NULL;
 line *last = NULL;
 line *s;
 FILE *in, *out;
+int linebuf = 1;
+char *pidfile = NULL;
 
 char *sockname = NULL;
 unsigned char buffer[10240];
@@ -175,16 +204,34 @@ int len = sizeof(accepted);
 
 
 /* Sort out the arguments */
+if (argc > 1 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-h")))
+  {
+  printf("Usage: %s [options] port|socket [connection count]\n", argv[0]);
+  puts("Options"
+       "\n\t-d       debug"
+       "\n\t-i n     n seconds initial delay"
+       "\n\t-noipv4  disable ipv4"
+       "\n\t-noipv6  disable ipv6"
+       "\n\t-oP file write PID to file"
+       "\n\t-t n     n seconds timeout"
+  );
+  exit(0);
+  }
 
 while (na < argc && argv[na][0] == '-')
   {
   if (strcmp(argv[na], "-d") == 0) debug = 1;
-  else if (strcmp(argv[na], "-t") == 0) timeout = atoi(argv[++na]);
+  else if (strcmp(argv[na], "-t") == 0)
+    {
+    if (tmo_noerror = ((timeout = atoi(argv[++na])) < 0)) timeout = -timeout;
+    }
+  else if (strcmp(argv[na], "-i") == 0) initial_pause = atoi(argv[++na]);
   else if (strcmp(argv[na], "-noipv4") == 0) use_ipv4 = 0;
   else if (strcmp(argv[na], "-noipv6") == 0) use_ipv6 = 0;
+  else if (strcmp(argv[na], "-oP") == 0) pidfile = argv[++na];
   else
     {
-    printf("server: unknown option %s\n", argv[na]);
+    printf("server: unknown option %s, try -h or --help\n", argv[na]);
     exit(1);
     }
   na++;
@@ -213,11 +260,22 @@ na++;
 if (na < argc) connection_count = atoi(argv[na]);
 
 
+/* Initial pause (before creating listen sockets */
+if (initial_pause > 0)
+  {
+  if (debug)
+    printf("%d: Inital pause of %d seconds\n", time(NULL), initial_pause);
+  else
+    printf("Inital pause of %d seconds\n", initial_pause);
+  while (initial_pause > 0)
+    initial_pause = sleep(initial_pause);
+  }
+
 /* Create sockets */
 
 if (port == 0)  /* Unix domain */
   {
-  if (debug) printf("Creating Unix domain socket\n");
+  if (debug) printf("%d: Creating Unix domain socket\n", time(NULL));
   listen_socket[udn] = socket(PF_UNIX, SOCK_STREAM, 0);
   if (listen_socket[udn] < 0)
     {
@@ -340,7 +398,6 @@ else
       sin4.sin_addr.s_addr = (S_ADDR_TYPE)INADDR_ANY;
       sin4.sin_port = htons(port);
       if (bind(listen_socket[i], (struct sockaddr *)&sin4, sizeof(sin4)) < 0)
-        {
         if (listen_socket[v6n] < 0 || errno != EADDRINUSE)
           {
           printf("IPv4 socket bind() failed: %s\n", strerror(errno));
@@ -351,7 +408,6 @@ else
           close(listen_socket[i]);
           listen_socket[i] = -1;
           }
-        }
       }
     }
   }
@@ -375,24 +431,57 @@ for (i = 0; i <= skn; i++)
   }
 
 
+if (pidfile)
+  {
+  FILE * p;
+  if (!(p = fopen(pidfile, "w")))
+    {
+    fprintf(stderr, "pidfile create failed: %s\n", strerror(errno));
+    exit(1);
+    }
+  fprintf(p, "%ld\n", (long)getpid());
+  fclose(p);
+  }
+
 /* This program handles only a fixed number of connections, in sequence. Before
 waiting for the first connection, read the standard input, which contains the
 script of things to do. A line containing "++++" is treated as end of file.
 This is so that the Perl driving script doesn't have to close the pipe -
 because that would cause it to wait for this process, which it doesn't yet want
 to do. The driving script adds the "++++" automatically - it doesn't actually
-appear in the test script. */
+appear in the test script. Within lines we interpret \xNN and \\ groups */
 
 while (fgets(CS buffer, sizeof(buffer), stdin) != NULL)
   {
   line *next;
+  char * d;
   int n = (int)strlen(CS buffer);
+
+  if (n > 1 && buffer[0] == '>' && buffer[1] == '>')
+    linebuf = 0;
   while (n > 0 && isspace(buffer[n-1])) n--;
   buffer[n] = 0;
   if (strcmp(CS buffer, "++++") == 0) break;
   next = malloc(sizeof(line) + n);
   next->next = NULL;
-  strcpy(next->line, CS buffer);
+  d = next->line;
+    {
+    char * s = CS buffer;
+    do
+      {
+      char ch;
+      char cl = *s;
+      if (cl == '\\' && (cl = *++s) == 'x')
+       {
+       if ((ch = *++s - '0') > 9 && (ch -= 'A'-'9'-1) > 15) ch -= 'a'-'A';
+       if ((cl = *++s - '0') > 9 && (cl -= 'A'-'9'-1) > 15) cl -= 'a'-'A';
+       cl |= ch << 4;
+       }
+      *d++ = cl;
+      }
+    while (*s++);
+    }
+  next->len = d - next->line - 1;
   if (last == NULL) script = last = next;
     else last->next = next;
   last = next;
@@ -410,6 +499,11 @@ s = script;
 
 for (count = 0; count < connection_count; count++)
   {
+  struct {
+    int left;
+    BOOL in_use;
+  } content_length = { 0, FALSE };
+
   alarm(timeout);
   if (port <= 0)
     {
@@ -435,8 +529,7 @@ for (count = 0; count < connection_count; count++)
       if (listen_socket[i] > max_socket) max_socket = listen_socket[i];
       }
 
-    lcount = select(max_socket + 1, &select_listen, NULL, NULL, NULL);
-    if (lcount < 0)
+    if ((lcount = select(max_socket + 1, &select_listen, NULL, NULL, NULL)) < 0)
       {
       printf("Select failed\n");
       fflush(stdout);
@@ -445,7 +538,6 @@ for (count = 0; count < connection_count; count++)
 
     accept_socket = -1;
     for (i = 0; i < skn; i++)
-      {
       if (listen_socket[i] > 0 && FD_ISSET(listen_socket[i], &select_listen))
         {
         accept_socket = accept(listen_socket[i],
@@ -453,7 +545,6 @@ for (count = 0; count < connection_count; count++)
         FD_CLR(listen_socket[i], &select_listen);
         break;
         }
-      }
     }
   alarm(0);
 
@@ -486,6 +577,7 @@ for (count = 0; count < connection_count; count++)
               cr.pid, cr.uid, cr.gid);
     --------------*****************/
     }
+  fflush(stdout);
 
   if (dup_accept_socket < 0)
     {
@@ -504,7 +596,7 @@ for (count = 0; count < connection_count; count++)
   doesn't work for other tests (e.g. ident tests) so we have explicit '<' and
   '>' flags for input and output as well as the defaults. */
 
-  for (; s != NULL; s = s->next)
+  for (; s; s = s->next)
     {
     char *ss = s->line;
 
@@ -516,7 +608,8 @@ for (count = 0; count < connection_count; count++)
     if (ss[0] == '>')
       {
       char *end = "\r\n";
-      printf("%s\n", ss++);
+      unsigned len = s->len;
+      printit(ss++, len--);
 
       if (strncmp(ss, "*eof", 4) == 0)
         {
@@ -525,13 +618,14 @@ for (count = 0; count < connection_count; count++)
         }
 
       if (*ss == '>')
-        { end = ""; ss++; }
+        { end = ""; ss++; len--; }
       else if (strncmp(ss, "CR>", 3) == 0)
-        { end = "\r"; ss += 3; }
+        { end = "\r"; ss += 3; len -= 3; }
       else if (strncmp(ss, "LF>", 3) == 0)
-        { end = "\n"; ss += 3; }
+        { end = "\n"; ss += 3; len -= 3; }
 
-      fprintf(out, "%s%s", ss, end);
+      fwrite(ss, 1, len, out);
+      if (*end) fprintf(out, end);
       }
 
     else if (isdigit((unsigned char)ss[0]))
@@ -551,52 +645,138 @@ for (count = 0; count < connection_count; count++)
       sleep(sleepfor);
       }
 
+    /* If the script line starts with "*data " we expect a numeric argument,
+    and we expect to read (and discard) that many data bytes from the input. */
+
+    else if (strncmp(ss, "*data ", 6) == 0)
+      {
+      int dlen = atoi(ss+6);
+      int n;
+
+      alarm(timeout);
+
+      if (!linebuf)
+       while (dlen > 0)
+         {
+         n = dlen < sizeof(buffer) ? dlen : sizeof(buffer);
+         if ((n = read(dup_accept_socket, CS buffer, n)) == 0)
+           {
+           printf("Unexpected EOF read from client\n");
+           s = s->next;
+           goto END_OFF;
+           }
+         dlen -= n;
+         }
+      else
+       while (dlen-- > 0)
+         if (fgetc(in) == EOF)
+           {
+           printf("Unexpected EOF read from client\n");
+           s = s->next;
+           goto END_OFF;
+           }
+      }
+
     /* Otherwise the script line is the start of an input line we are expecting
     from the client, or "*eof" indicating we expect the client to close the
     connection. Read command line or data lines; the latter are indicated
     by the expected line being just ".". If the line starts with '<', that
     doesn't form part of the expected input. (This allows for incoming data
-    starting with a digit.) */
+    starting with a digit.) If the line starts with '<<' we operate in
+    unbuffered rather than line mode and assume that a single read gets the
+    entire message. */
 
     else
       {
       int offset;
       int data = strcmp(ss, ".") == 0;
 
-      if (ss[0] == '<')
+      if (ss[0] != '<')
+       offset = 0;
+      else
         {
         buffer[0] = '<';
-        offset = 1;
+       if (ss[1] != '<')
+         offset = 1;
+       else
+         {
+         buffer[1] = '<';
+         offset = 2;
+         }
         }
-      else offset = 0;
 
       fflush(out);
 
-      for (;;)
-        {
-        int n;
-        alarm(timeout);
-        if (fgets(CS buffer+offset, sizeof(buffer)-offset, in) == NULL)
-          {
-          printf("%sxpected EOF read from client\n",
-            (strncmp(ss, "*eof", 4) == 0)? "E" : "Une");
-          s = s->next;
-          goto END_OFF;
-          }
-        alarm(0);
-        n = (int)strlen(CS buffer);
-        while (n > 0 && isspace(buffer[n-1])) n--;
-        buffer[n] = 0;
-        printf("%s\n", buffer);
-        if (!data || strcmp(CS buffer, ".") == 0) break;
-        }
-
-      if (strncmp(ss, CS buffer, (int)strlen(ss)) != 0)
-        {
-        printf("Comparison failed - bailing out\n");
-        printf("Expected: %s\n", ss);
-        break;
-        }
+      if (!linebuf)
+       {
+       int n;
+       char c;
+
+       alarm(timeout);
+       n = read(dup_accept_socket, CS buffer+offset, s->len - offset);
+       if (content_length.in_use) content_length.left -= n;
+       if (n == 0)
+         {
+         printf("%sxpected EOF read from client\n",
+           (strncmp(ss, "*eof", 4) == 0)? "E" : "Une");
+         s = s->next;
+         goto END_OFF;
+         }
+       if (offset != 2)
+         while (read(dup_accept_socket, &c, 1) == 1 && c != '\n') ;
+       alarm(0);
+       n += offset;
+
+       printit(CS buffer, n);
+
+       if (data) do
+         {
+         n = (read(dup_accept_socket, &c, 1) == 1 && c == '.');
+         if (content_length.in_use) content_length.left--;
+         while (c != '\n' && read(dup_accept_socket, &c, 1) == 1)
+            if (content_length.in_use) content_length.left--;
+         } while (!n);
+       else if (memcmp(ss, buffer, n) != 0)
+         {
+         printf("Comparison failed - bailing out\nExpected: ");
+         printit(ss, n);
+         break;
+         }
+       }
+      else
+       {
+       for (;;)
+         {
+         int n;
+         alarm(timeout);
+         if (fgets(CS buffer+offset, sizeof(buffer)-offset, in) == NULL)
+           {
+           printf("%sxpected EOF read from client\n",
+             (strncmp(ss, "*eof", 4) == 0)? "E" : "Une");
+           s = s->next;
+           goto END_OFF;
+           }
+         alarm(0);
+         n = strlen(CS buffer);
+         if (content_length.in_use) content_length.left -= (n - offset);
+         while (n > 0 && isspace(buffer[n-1])) n--;
+         buffer[n] = 0;
+         printf("%s\n", buffer);
+         if (!data || strcmp(CS buffer, ".") == 0) break;
+         }
+
+       if (strncmp(ss, CS buffer, (int)strlen(ss)) != 0)
+         {
+         printf("Comparison failed - bailing out\n");
+         printf("Expected: %s\n", ss);
+         break;
+         }
+       }
+
+       if (sscanf(CCS buffer, "<Content-length: %d", &content_length.left))
+                 content_length.in_use = TRUE;
+       if (content_length.in_use && content_length.left <= 0)
+         shutdown(dup_accept_socket, SHUT_RD);
       }
     }
 
@@ -607,7 +787,7 @@ for (count = 0; count < connection_count; count++)
 
 if (s == NULL) printf("End of script\n");
 
-if (sockname != NULL) unlink(sockname);
+if (sockname) unlink(sockname);
 exit(0);
 }