/* $Id: socket.c,v 1.13 1998/08/15 13:02:38 tonyg Exp $ */

#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#include "memory.h"
#include "class.h"
#include "prim.h"
#include "pair.h"
#include "symbol.h"
#include "string.h"
#include "buffer.h"
#include "scan.h"
#include "parse.h"
#include "stream.h"
#include "socket.h"
#include "vector.h"

#if WANT_SOCKETS

#include <unistd.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>

/* The <socket> class object; assigned by init_socket(). */
OBJECT socket_class;

/*
 * Create the <socket> class as a subclass of stream_class (instance size
 * SOCKET_SIZE -- presumably the slot count declared in socket.h, confirm)
 * and make it the value of the symbol <socket>.
 *
 * Keep newclass() ahead of newsym(): both calls allocate, and this order
 * is the one the rest of the interpreter has been exercised with.
 */
void init_socket(void) {
  socket_class = newclass(stream_class, SOCKET_SIZE, NULL);
  SET(newsym("<socket>"), SYM_VALUE, socket_class);
}

/* Methods */

/*
 * "print-string" method: render a socket as "#<socket FD>", where FD is the
 * descriptor stored in SOCKET_NUMBER (-1 once closed).  The second argument
 * is part of the method signature but is not used here.
 */
PRIVATE OBJECT socket_printstr(OBJECT sock, OBJECT w) {
  char text[128];
  int fd = (int) NUM(GET(sock, SOCKET_NUMBER));

  (void) w;
  snprintf(text, sizeof text, "#<socket %d>", fd);
  return newstring(text);
}

/*
 * "initialize" method: open an AF_INET socket and store its descriptor in
 * self's SOCKET_NUMBER slot.
 *
 *   kind  - the symbol `datagram` selects SOCK_DGRAM; anything else
 *           (normally `stream`) selects SOCK_STREAM.
 *   proto - optional protocol name as a symbol, looked up with
 *           getprotobyname(); non-symbols and unknown names use 0, which
 *           lets socket() choose the default protocol for the type.
 *
 * Returns self.  If socket() fails, the stored descriptor is -1.
 */
PRIVATE OBJECT socket_initialise(OBJECT self, OBJECT kind, OBJECT proto) {
  int sock_type;
  int proto_num = 0;
  int fd;

  sock_type = (kind == newsym("datagram")) ? SOCK_DGRAM : SOCK_STREAM;

  if (instance(proto, symbol_class)) {
    struct protoent *entry = getprotobyname(BIDX(proto, 0));

    if (entry != NULL)
      proto_num = entry->p_proto;
  }

  /* Keep the socket() call out of the MKNUM macro argument. */
  fd = socket(AF_INET, sock_type, proto_num);
  SET(self, SOCKET_NUMBER, MKNUM(fd));

  return self;
}

/*
 * "close" method: shut down and close the descriptor, then mark the socket
 * closed by storing -1.  A socket that already holds -1 is left alone.
 * Always returns true.
 */
PRIVATE OBJECT socket_close(OBJECT self) {
  int fd = NUM(GET(self, SOCKET_NUMBER));

  if (fd == -1)
    return true;

  shutdown(fd, 2);      /* 2: no further sends or receives (SHUT_RDWR) */
  close(fd);
  SET(self, SOCKET_NUMBER, MKNUM(-1));

  return true;
}

/*
 * "read-chars-from" method: receive at most `count` bytes with one recv()
 * call and return them as a new NUL-terminated string.
 *
 * Returns NULL when recv() reports end-of-stream (0) or an error (-1), and
 * also when `count` is negative.
 */
PRIVATE OBJECT socket_read_chars(OBJECT self, OBJECT count) {
  int num = NUM(GET(self, SOCKET_NUMBER));
  long want = NUM(count);
  char *buf;
  OBJECT result;
  int nread;

  /* A negative count must never reach getmem()/recv(): recv()'s length is a
     size_t, so it would become enormous while the buffer allocated from
     `(word) want + 1` could be tiny (even zero bytes) -- a heap overflow
     controlled by the caller. */
  if (want < 0)
    return NULL;

  buf = getmem((word) want + 1);
  nread = recv(num, buf, (word) want, 0);

  if (nread > 0) {
    result = NewObject(string_class, 0, nread + 1);
    memcpy(BIDX(result, 0), buf, nread);
    BSET(result, nread, '\0');
  } else
    result = NULL;

  freemem(buf);

  return result;
}

/*
 * "write-chars-to" method: send up to `count` bytes of `chars` with one
 * send() call.  `count` is clamped to the size of `chars`.
 *
 * Returns the value send() returned (bytes sent, or -1 on error).  A
 * negative `count` is rejected with -1 without sending anything.
 */
PRIVATE OBJECT socket_write(OBJECT self, OBJECT count, OBJECT chars) {
  long n = NUM(count);
  long avail = (long) NUMBIN(chars);
  int sock = NUM(GET(self, SOCKET_NUMBER));
  int sent;

  /* Previously a negative count either reached send() as a huge size_t
     length (reading past the end of `chars`) or, depending on NUMBIN's
     signedness, silently sent the whole buffer. */
  if (n < 0)
    return MKNUM(-1);

  if (n > avail)
    n = avail;

  sent = send(sock, BIDX(chars, 0), (size_t) n, 0);

  return MKNUM(sent);
}

/*
 * Scanner source used by socket_read(): handed to newscanner() and read
 * back by socket_peek()/socket_drop().
 *   socket - descriptor that bytes are pulled from with recv()
 *   cache  - one byte of lookahead, or -1 when nothing is buffered;
 *            EOF is stored here after an empty or failed recv()
 */
struct SockScan {
  int socket;
  int cache;
};

typedef struct SockScan SockScan;
typedef struct SockScan *SOCKSCAN;

/*
 * Scanner "peek" callback: return the next input byte without consuming it,
 * reading exactly one byte from the socket if nothing is cached yet.
 * An empty or failed recv() is reported (and cached) as EOF.
 *
 * NOTE(review): -1 is both the "nothing cached" marker and (normally) the
 * value of EOF, so after end-of-stream every peek issues another recv().
 * A 0xFF data byte also collapses to EOF through the `char` return type on
 * signed-char platforms.  Both are limits of the scanner interface.
 */
PRIVATE char socket_peek(SCANSTATE s) {
  SOCKSCAN so = (SOCKSCAN) s->source;

  if (so->cache == -1) {
    unsigned char byte;
    int nread;

    /* Receive into a real byte, not into the first byte of the int cache:
       on big-endian hosts that put the character in the high-order byte,
       so every character read back as 0 after truncation to char. */
    nread = recv(so->socket, (char *) &byte, 1, 0);

    so->cache = (nread == 1) ? (int) byte : EOF;
  }

  return so->cache;
}

/*
 * Scanner "drop" callback: consume the current lookahead byte.  If nothing
 * has been peeked yet, a byte is read first so that it is really taken off
 * the socket.  When the lookahead is EOF the cache is left as it is.
 * Otherwise it is emptied (-1) so the next peek reads a fresh byte.
 */
PRIVATE void socket_drop(SCANSTATE s) {
  SOCKSCAN src = (SOCKSCAN) s->source;

  if (src->cache == -1)
    socket_peek(s);

  if (src->cache == EOF)
    return;

  src->cache = -1;
}

/*
 * "read-from" method: parse one datum directly from the socket using a
 * scanner whose source is a SockScan over this socket's descriptor.
 * If parse() yields `undefined`, the value of %%the-eof-object is returned
 * instead.  The result variable is registered with temp_register() for the
 * duration of the parse.
 */
PRIVATE OBJECT socket_read(OBJECT self) {
  SockScan src;
  SCANSTATE scanner;
  PARSER parser;
  OBJECT datum = NULL;

  src.socket = NUM(GET(self, SOCKET_NUMBER));
  src.cache = -1;                     /* no lookahead byte yet */

  temp_register(&datum, 1);

  scanner = newscanner(&src, socket_peek, socket_drop);
  parser = newparser(scan, scanner);

  datum = parse(parser);
  if (datum == undefined)
    datum = GET(newsym("%%the-eof-object"), SYM_VALUE);

  killparser(parser);
  killscanner(scanner);

  deregister_root(1);
  return datum;
}

/*
 * Convert an IPv4 address in network byte order into a 4-element vector of
 * its dotted-quad octets, most significant first (e.g. #(127 0 0 1)).
 */
PRIVATE OBJECT encode_addr(long inet_addr) {
  /* Extract octets from an unsigned value.  With a 32-bit long, an address
     of 128.0.0.0 or above would otherwise become a negative long, and
     right-shifting a negative value is implementation-defined. */
  unsigned long host_addr = ntohl((unsigned long) inet_addr);
  int a = (int) ((host_addr >> 24) & 0xff);
  int b = (int) ((host_addr >> 16) & 0xff);
  int c = (int) ((host_addr >> 8) & 0xff);
  int d = (int) (host_addr & 0xff);

  return newvector(4, 1, MKNUM(a), MKNUM(b), MKNUM(c), MKNUM(d));
}

/*
 * Inverse of encode_addr(): build an IPv4 address from a 4-element vector
 * of octets (most significant first) and store it in *inet_addr in network
 * byte order.  Each element is reduced to its low 8 bits.
 *
 * NOTE: the result is held in a long.  Callers that need a struct in_addr
 * must assign the value into one rather than aliasing it through a long *
 * (a long is 8 bytes on LP64 systems).
 */
PRIVATE void decode_addr(OBJECT addr, long *inet_addr) {
  unsigned long host_addr = 0;
  int i;

  /* Work in unsigned arithmetic and mask each octet.  If NUM() yields an
     int, `192 << 24` overflows it (undefined behaviour), and an
     out-of-range or negative component would corrupt its neighbours. */
  for (i = 0; i < 4; i++)
    host_addr = (host_addr << 8) | ((unsigned long) NUM(IGET(addr, i)) & 0xff);

  *inet_addr = htonl(host_addr);
}

/*
 * get-host-by-name primitive: resolve the host name held in string `name`
 * with gethostbyname().
 *
 * Returns (official-name #(a b c d) ...): the official host name followed by
 * one octet vector per entry of h_addr_list.  Returns false when the lookup
 * fails.
 */
PRIVATE OBJECT socket_gethost_n(OBJECT name) {
  struct hostent *he = gethostbyname(BIDX(name, 0));
  OBJECT result = NULL;   /* list under construction (registered) */
  OBJECT item = NULL;     /* newest element (registered) */
  OBJECT tail = NULL;     /* last pair of `result` */
  char **entry;

  temp_register(&result, 1);
  temp_register(&item, 1);

  if (he != NULL) {
    for (entry = he->h_addr_list; *entry != NULL; entry++) {
      unsigned char *octet = (unsigned char *) *entry;

      item = newvector(4, 1, MKNUM(octet[0]), MKNUM(octet[1]),
                             MKNUM(octet[2]), MKNUM(octet[3]));

      if (result != NULL) {
        SETCDR(tail, cons(item, NULL));
        tail = CDR(tail);
      } else
        result = tail = cons(item, NULL);
    }

    item = newstring((char *) he->h_name);
    result = cons(item, result);
  } else
    result = false;

  deregister_root(2);
  return result;
}

/*
 * get-host-by-addr primitive: reverse-resolve the IPv4 address given as a
 * 4-element octet vector (see decode_addr()).
 *
 * Returns (official-name #(a b c d) ...) like get-host-by-name, or false if
 * the lookup fails.
 */
PRIVATE OBJECT socket_gethost_a(OBJECT addr) {
  long haddr;
  struct in_addr in;
  struct hostent *h;
  OBJECT result = NULL;
  OBJECT item = NULL;

  /* decode_addr() delivers the network-order address in a long, which is
     8 bytes on LP64 systems.  gethostbyaddr() needs a pointer to a real
     AF_INET address of sizeof(struct in_addr) bytes.  Assigning the value
     keeps its low 32 bits, which hold the address, whereas passing
     &haddr with sizeof(haddr) gave the wrong length and, on big-endian
     LP64, the wrong bytes. */
  decode_addr(addr, &haddr);
  in.s_addr = haddr;
  h = gethostbyaddr((char *) &in, sizeof(in), AF_INET);

  temp_register(&result, 1);
  temp_register(&item, 1);

  if (h == NULL)
    result = false;
  else {
    int i = 0;
    OBJECT prev = NULL;

    while (h->h_addr_list[i] != NULL) {
      item = newvector(4, 1, MKNUM((unsigned char) h->h_addr_list[i][0]),
                             MKNUM((unsigned char) h->h_addr_list[i][1]),
                             MKNUM((unsigned char) h->h_addr_list[i][2]),
                             MKNUM((unsigned char) h->h_addr_list[i][3]));

      /* Append to the address list, tracking its last pair in prev. */
      if (result == NULL)
        result = prev = cons(item, NULL);
      else {
        SETCDR(prev, cons(item, NULL));
        prev = CDR(prev);
      }

      i++;
    }

    item = newstring((char *) h->h_name);
    result = cons(item, result);
  }

  deregister_root(2);

  return result;
}

/*
 * get-inaddr primitive: map one of the symbols any, broadcast, loopback or
 * localhost to its well-known IPv4 address as an octet vector.  Any other
 * value yields false.  Names are tried in table order.
 */
PRIVATE OBJECT socket_getaddr(OBJECT sym) {
  static const struct {
    char *name;
    unsigned long host_addr;      /* host byte order */
  } well_known[] = {
    { "any",       INADDR_ANY },
    { "broadcast", INADDR_BROADCAST },
    { "loopback",  INADDR_LOOPBACK },
    { "localhost", INADDR_LOOPBACK }
  };
  int i;

  for (i = 0; i < (int) (sizeof(well_known) / sizeof(well_known[0])); i++) {
    if (sym == newsym(well_known[i].name))
      return encode_addr(htonl(well_known[i].host_addr));
  }

  return false;
}

/*
 * "socket-connect" method: connect to IPv4 address `addr` (octet vector)
 * on TCP/UDP port `port`.
 * Returns true on success.  Returns false if connect() fails or the port is
 * outside 0..65535.
 */
PRIVATE OBJECT socket_connect(OBJECT self, OBJECT addr, OBJECT port) {
  struct sockaddr_in s;
  long inaddr;
  long portnum = NUM(port);

  if (portnum < 0 || portnum > 65535)
    return false;

  /* Zero the whole structure so sin_zero (and sin_len on BSD) are clean. */
  memset(&s, 0, sizeof(s));
  s.sin_family = AF_INET;

  /* Never alias s_addr (32 bits) through a long *: on LP64 that wrote 8
     bytes, and on big-endian LP64 left 0.0.0.0 in s_addr.  The assignment
     keeps the low 32 bits, which hold the network-order address. */
  decode_addr(addr, &inaddr);
  s.sin_addr.s_addr = inaddr;
  s.sin_port = htons((unsigned short) portnum);

  if (connect(NUM(GET(self, SOCKET_NUMBER)), (struct sockaddr *) &s, sizeof(s)) == -1)
    return false;
  else
    return true;
}

/*
 * "socket-bind" method: bind the socket to IPv4 address `addr` (octet
 * vector, e.g. the result of (get-inaddr 'any)) and port `port`.
 * Returns true on success.  Returns false if bind() fails or the port is
 * outside 0..65535.
 */
PRIVATE OBJECT socket_bind(OBJECT self, OBJECT addr, OBJECT port) {
  struct sockaddr_in s;
  long inaddr;
  long portnum = NUM(port);

  if (portnum < 0 || portnum > 65535)
    return false;

  /* Zero the whole structure: some systems reject bind() when sin_zero is
     not cleared, and the original left it uninitialised. */
  memset(&s, 0, sizeof(s));
  s.sin_family = AF_INET;

  /* Assign via a long rather than casting &s_addr to long *: on LP64 the
     cast wrote 8 bytes into a 4-byte field (and stored 0.0.0.0 on
     big-endian LP64).  The assignment keeps the low 32 bits. */
  decode_addr(addr, &inaddr);
  s.sin_addr.s_addr = inaddr;
  s.sin_port = htons((unsigned short) portnum);

  if (bind(NUM(GET(self, SOCKET_NUMBER)), (struct sockaddr *) &s, sizeof(s)) == -1)
    return false;
  else
    return true;
}

/*
 * "socket-listen" method: mark the socket as accepting connections with a
 * queue of `backlog` pending connections.  Returns true unless listen()
 * reports failure.
 */
PRIVATE OBJECT socket_listen(OBJECT self, OBJECT backlog) {
  int fd = NUM(GET(self, SOCKET_NUMBER));
  int rc = listen(fd, NUM(backlog));

  return (rc == -1) ? false : true;
}

/*
 * "socket-accept" method: wait for and accept one incoming connection.
 * Returns a new <socket> object wrapping the connected descriptor, or false
 * if accept() fails.  The peer address is discarded; use get-peer-name on
 * the result to obtain it.
 */
PRIVATE OBJECT socket_accept(OBJECT self) {
  struct sockaddr_in peer;
  socklen_t peer_len = sizeof(peer);   /* POSIX: socklen_t *, not int * */
  OBJECT conn;
  int fd;

  fd = accept(NUM(GET(self, SOCKET_NUMBER)), (struct sockaddr *) &peer, &peer_len);
  if (fd < 0)
    return false;

  conn = NewObject(socket_class, 0, 0);
  SET(conn, SOCKET_NUMBER, MKNUM(fd));

  return conn;
}

/*
 * "get-sock-name" method: return the socket's local address as
 * (#(a b c d) . port).  Returns false if getsockname() fails or the address
 * is not AF_INET.
 */
PRIVATE OBJECT socket_getname(OBJECT self) {
  struct sockaddr_in s;
  socklen_t s_size = sizeof(s);        /* POSIX: socklen_t *, not int * */
  OBJECT retval = NULL;

  if (getsockname(NUM(GET(self, SOCKET_NUMBER)), (struct sockaddr *) &s, &s_size) < 0)
    return false;

  if (s.sin_family != AF_INET)
    return false;

  /* Keep the address vector registered while cons() allocates the pair. */
  temp_register(&retval, 1);
  retval = encode_addr(s.sin_addr.s_addr);
  retval = cons(retval, MKNUM(ntohs(s.sin_port)));
  deregister_root(1);

  return retval;
}

/*
 * "get-peer-name" method: return the connected peer's address as
 * (#(a b c d) . port).  Returns false if getpeername() fails (e.g. the
 * socket is not connected) or the address is not AF_INET.
 */
PRIVATE OBJECT socket_getpeer(OBJECT self) {
  struct sockaddr_in s;
  socklen_t s_size = sizeof(s);        /* POSIX: socklen_t *, not int * */
  OBJECT retval = NULL;

  if (getpeername(NUM(GET(self, SOCKET_NUMBER)), (struct sockaddr *) &s, &s_size) < 0)
    return false;

  if (s.sin_family != AF_INET)
    return false;

  /* Keep the address vector registered while cons() allocates the pair. */
  temp_register(&retval, 1);
  retval = encode_addr(s.sin_addr.s_addr);
  retval = cons(retval, MKNUM(ntohs(s.sin_port)));
  deregister_root(1);

  return retval;
}

/* Register method `n`, implemented by C function `f` taking `a` arguments,
   for the class list held in the local `cl` of init_meth_socket(). */
#define AM(n,f,a)	addmeth(n,f,a,cl)

/*
 * Install the <socket> methods and the socket primitives, then push the
 * symbol `sockets` onto the list stored in *features*.
 */
void init_meth_socket(void) {
  OBJECT cl = NULL;

  /* cl is registered with temp_register() for the whole function (released
     by the final deregister_root()).  It first holds the one-element list
     (socket_class) that AM() passes to addmeth(), and is reused as scratch
     when building the new *features* list below. */
  temp_register(&cl, 1);
  cl = cons(socket_class, NULL);

  /* Disabled: a Win32 build would need WinSock initialised before any
     socket call. */
#if 0
  {
    /* This code is to set up winsock on Win32. */
    WSADATA d;
    WSAStartup(0x0101, &d);
  }
#endif

  /* Printing, initialisation and stream I/O methods for <socket>. */
  AM("print-string", socket_printstr, 2);
  AM("initialize", socket_initialise, 3);
  AM("close", socket_close, 1);
  AM("read-chars-from", socket_read_chars, 2);
  AM("write-chars-to", socket_write, 3);
  AM("read-from", socket_read, 1);

  /* Socket-specific methods. */
  AM("socket-connect", socket_connect, 3);
  AM("socket-bind", socket_bind, 3);
  AM("socket-listen", socket_listen, 2);
  AM("socket-accept", socket_accept, 1);
  AM("get-sock-name", socket_getname, 1);
  AM("get-peer-name", socket_getpeer, 1);

  /* Plain primitives, not tied to the <socket> class. */
  addprim("get-host-by-name", socket_gethost_n, 1);
  addprim("get-host-by-addr", socket_gethost_a, 1);
  addprim("get-inaddr", socket_getaddr, 1);

  {
    OBJECT fsym = NULL;

    temp_register(&fsym, 1);

    /* *features* := (sockets . *features*).  Every intermediate value is
       stored in a registered variable (fsym or cl) before the next
       allocating call, so keep this statement order. */
    fsym = newsym("*features*");
    cl = newsym("sockets");
    cl = cons(cl, GET(fsym, SYM_VALUE));
    SET(fsym, SYM_VALUE, cl);

    deregister_root(1);
  }

  deregister_root(1);
}

#else /* don't WANT_SOCKETS */

/* Built without WANT_SOCKETS: there is no <socket> class to create. */
void init_socket(void)
{
  return;
}

/* Built without WANT_SOCKETS: no socket methods, primitives or
   *features* entry are registered. */
void init_meth_socket(void)
{
  return;
}

#endif /* WANT_SOCKETS */
