(module socket
    (call/ssl-socket-ports
     call/plain-socket-ports)

  (import s2j)
  
  ;; Generic Java method proxies used by this module. Each name is invoked
  ;; below with the receiver object as its first argument (e.g.
  ;; (close raw-socket), (get-input-stream raw-socket)).
  (define-generic-java-methods
    get-default
    create-socket
    set-so-timeout
    get-input-stream
    get-output-stream
    close)

  ;; Java classes referenced by this module, bound to Scheme identifiers.
  ;; <socket> is instantiated directly for plain connections;
  ;; <socket-factory> is only used as the (null) receiver for get-default;
  ;; the two sisc.io port classes wrap raw Java streams.
  (define-java-classes
    (<socket> |java.net.Socket|)
    (<socket-factory> |javax.net.ssl.SSLSocketFactory|)
    (<stream-input-port> |sisc.io.StreamInputPort|)
    (<stream-output-port> |sisc.io.StreamOutputPort|))

  ;; Java boolean `true`, passed as the second constructor argument when
  ;; building the <stream-output-port>.
  ;; NOTE(review): presumably the auto-flush flag -- confirm against the
  ;; sisc.io.StreamOutputPort constructor.
  (define jtrue (->jboolean #t))

  ;; Wrap the streams of an already-created Java socket in SISC ports and
  ;; call K with them, guaranteeing the socket is closed afterwards.
  ;;
  ;;   raw-socket  Java socket object (as produced by java-new / create-socket).
  ;;   timeout-ms  #f to leave the socket timeout untouched; otherwise a
  ;;               Scheme integer handed to set-so-timeout.
  ;;   k           procedure of two arguments: (k in out), where IN and OUT
  ;;               are the unwrapped Scheme input and output ports.
  ;;
  ;; Returns whatever K returns.
  ;;
  ;; All setup (timeout, stream lookup, port construction) happens inside
  ;; the outer dynamic-wind so that a failure in any of those steps still
  ;; closes RAW-SOCKET instead of leaking it. The streams are closed by an
  ;; inner dynamic-wind; should one of those closes fail, the outer
  ;; after-thunk still closes the socket itself.
  (define (call/raw-socket raw-socket timeout-ms k)
    (dynamic-wind
        (lambda () 'nothing)
        (lambda ()
          (when timeout-ms
            (set-so-timeout raw-socket (->jint timeout-ms)))
          (let* ((ins (get-input-stream raw-socket))
                 (inw (java-new <stream-input-port> ins))
                 (in (java-unwrap inw))
                 (outs (get-output-stream raw-socket))
                 (outw (java-new <stream-output-port> outs jtrue))
                 (out (java-unwrap outw)))
            (dynamic-wind
                (lambda () 'nothing)
                (lambda () (k in out))
                (lambda ()
                  ;; Same order as before: input stream, then output stream.
                  (close ins)
                  (close outs)))))
        (lambda ()
          ;; Always runs, even when setup or a stream close failed.
          (close raw-socket))))

  ;; Open an SSL socket to HOST:PORT using the factory returned by
  ;; get-default on <socket-factory>, then hand it to call/raw-socket
  ;; together with TIMEOUT-MS and K.
  ;;
  ;; get-default is invoked on a null <socket-factory> receiver.
  ;; The socket is created before TIMEOUT-MS is applied.
  (define (call/ssl-socket-ports host port timeout-ms k)
    (let ((ssl-factory (get-default (java-null <socket-factory>))))
      (call/raw-socket
       (create-socket ssl-factory (->jstring host) (->jint port))
       timeout-ms
       k)))

  ;; Open an unencrypted java.net.Socket to HOST:PORT and hand it to
  ;; call/raw-socket together with TIMEOUT-MS and K.
  ;;
  ;; The socket is constructed before TIMEOUT-MS is applied.
  (define (call/plain-socket-ports host port timeout-ms k)
    (call/raw-socket (java-new <socket> (->jstring host) (->jint port))
                     timeout-ms
                     k))
  )
