main.ss
;; mzsocket: BSD/POSIX sockets library for PLT-scheme
;; main library module
;;
;; (C) Copyright 2007,2008 Dimitris Vyzovitis <vyzo at media.mit.edu>
;;
;; mzsocket is free software: you can redistribute it and/or modify
;; it under the terms of the GNU Lesser General Public License as published
;; by the Free Software Foundation, either version 3 of the License, or
;; (at your option) any later version.
;;
;; mzsocket is distributed in the hope that it will be useful,
;; but WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
;; GNU Lesser General Public License for more details.
;;
;; You should have received a copy of the GNU Lesser General Public License
;; along with mzsocket.  If not, see <http://www.gnu.org/licenses/>.

#lang scheme/base

(require 
  scheme/port
  "_constants.ss"
  (except-in "_socket.ss"
    socket-accept
    socket-connect
    socket-recv
    socket-recvfrom
    socket-send
    socket-sendto
    socket-sendmsg
    socket-recvmsg)
  (prefix-in _
    (only-in "_socket.ss"
      socket-accept
      socket-connect
      socket-recv
      socket-recvfrom
      socket-send
      socket-sendto
      socket-sendmsg
      socket-recvmsg)))

(provide
  (all-from-out "_constants.ss")
  (except-out (all-from-out "_socket.ss")
    make-exn:fail:socket
    _socket-accept
    _socket-connect
    _socket-recv
    _socket-recvfrom
    _socket-send
    _socket-sendto
    _socket-recvmsg
    _socket-sendmsg)
  (rename-out
    (_socket-accept socket-accept*)
    (_socket-connect socket-connect*)
    (_socket-recv socket-recv*)
    (_socket-recvfrom socket-recvfrom*)
    (_socket-send socket-send*)
    (_socket-sendto socket-sendto*)
    (_socket-recvmsg socket-recvmsg*)
    (_socket-sendmsg socket-sendmsg*))
  ;; wrappers
  inet4-address
  inet4-address?
  inet6-address
  inet6-address?
  socket-accept
  socket-connect
  socket-recv
  socket-recvfrom
  socket-send
  socket-sendto
  socket-sendmsg
  socket-recvmsg
  socket-send-all
  socket-recv-all
  socket-send/port
  socket-recv/port
  socket-recv-evt
  socket-send-evt
  socket-connect-evt
  socket-accept-evt
  stream-socket?
  socket-input-port
  socket-output-port
  open-socket-stream
  exn:fail:socket:blocking?)

;; Coerce HOST to a byte string: a byte string is returned unchanged,
;; anything else is encoded with string->bytes/utf-8.
(define (host->bytes host)
  (cond
    ((bytes? host) host)
    (else (string->bytes/utf-8 host))))

;; Build an AF_INET inet-address from HOST (string or byte string) and PORT.
(define (inet4-address host port)
  (let ((host-bytes (host->bytes host)))
    (inet-address AF_INET host-bytes port)))

;; True when O is an inet-address whose family is AF_INET.
(define (inet4-address? o)
  (if (inet-address? o)
      (eq? (inet-address-family o) AF_INET)
      #f))

;; Build an AF_INET6 inet-address from HOST (string or byte string) and
;; PORT; any further arguments are passed through to inet-address as-is.
;; When SOCKET_HAVE_IPV6 is false, every call raises an error instead.
(define inet6-address
  (if (not SOCKET_HAVE_IPV6)
    (lambda unused-args
      (error 'inet6-address "IPv6 not supported in this system"))
    (lambda (host port . rest)
      (apply inet-address AF_INET6 (host->bytes host) port rest))))

;; True when O is an inet-address whose family is AF_INET6.
(define (inet6-address? o)
  (if (inet-address? o)
      (eq? (inet-address-family o) AF_INET6)
      #f))

;; True when E is a socket exception carrying errno EWOULDBLOCK, i.e. the
;; raw operation reported that it could not proceed without blocking.
(define (exn:fail:socket:blocking? e)
  (if (exn:fail:socket? e)
      (eq? EWOULDBLOCK (exn:fail:socket-errno e))
      #f))

;; events
;; Combined write|except readiness mask; socket-connect and
;; socket-connect-evt wait on it before inspecting SO_ERROR.
(define socket-evt:complete
  (bitwise-ior socket-evt:except socket-evt:write))

;; (defevt name mask) defines NAME as a procedure mapping a socket to the
;; socket-evt that waits for the readiness condition(s) in MASK.
(define-syntax-rule (defevt name mask)
  (define name
    (lambda (sock) (socket-evt sock mask))))

(defevt socket-recv-evt socket-evt:read)      ; readable
(defevt socket-send-evt socket-evt:write)     ; writable
(defevt socket-complete-evt socket-evt:complete)

;; Event that becomes ready once a pending connect on SOCK completes.
;; Its result is SOCK when SO_ERROR is zero; otherwise it is an
;; exn:fail:socket value for that error (returned, not raised).
(define (socket-connect-evt sock)
  (define (connect-outcome ready)
    (define err (socket-getsockopt sock SOL_SOCKET SO_ERROR))
    (if (zero? err)
        sock
        (make-exn:fail:socket "socket-connect" err)))
  (handle-evt (socket-complete-evt sock) connect-outcome))

;; Event that becomes ready when SOCK is readable and then performs a raw
;; accept. Its result is the accepted client socket (the peer address is
;; dropped), or the exn:fail:socket value if the accept failed.
(define (socket-accept-evt sock)
  (define (try-accept ready)
    (with-handlers* ((exn:fail:socket? (lambda (e) e)))
      (call-with-values
        (lambda () (_socket-accept sock))
        (lambda (client peer) client))))
  (handle-evt (socket-recv-evt sock) try-accept))

;; blocking wrappers
;; (blocking evt body ...) waits on EVT and then evaluates BODY, returning
;; its value(s). The raw socket calls raise an exn:fail:socket with errno
;; EWOULDBLOCK when they cannot proceed yet. Because EVT can fire spuriously,
;; such a failure goes back to waiting on EVT instead of escaping from a
;; wrapper that promises to block.
;; EVT is evaluated only once and re-synced on each retry. BODY may be
;; evaluated several times, so it must be safe to repeat after an
;; EWOULDBLOCK failure.
(define-syntax-rule (blocking evt body ...)
  (let ((ready evt))
    (let retry ()
      (sync ready)
      (with-handlers* ((exn:fail:socket:blocking? (lambda (e) (retry))))
        body ...))))

;; (defwrapper (name sock arg ... . rest) evt-maker raw) defines NAME as a
;; blocking version of the raw operation RAW. NAME waits on
;; (evt-maker sock), then applies RAW to SOCK, the ARGs and any REST
;; arguments.
(define-syntax-rule (defwrapper (wrapper sock arg ... . tl) evt raw)
  (define (wrapper sock arg ... . tl)
    (blocking (evt sock) (apply raw sock arg ... tl))))

(defwrapper (socket-send sock bs . rest)
  socket-send-evt _socket-send)
(defwrapper (socket-recv sock bs . rest)
  socket-recv-evt _socket-recv)
(defwrapper (socket-sendto sock peer bs . rest)
  socket-send-evt _socket-sendto)
(defwrapper (socket-recvfrom sock bs . rest)
  socket-recv-evt _socket-recvfrom)
(defwrapper (socket-sendmsg sock name data ctl . rest)
  socket-send-evt _socket-sendmsg)
(defwrapper (socket-recvmsg sock name data ctl . rest)
  socket-recv-evt _socket-recvmsg)

;; Connect SOCK to PEER, blocking until the connection attempt finishes.
;; A true result from the raw connect means it completed immediately.
;; Otherwise, wait for the completion event and raise an exn:fail:socket
;; if SO_ERROR then reports a non-zero error.
(define (socket-connect sock peer)
  (define (check-pending-error ready)
    (define err (socket-getsockopt sock SOL_SOCKET SO_ERROR))
    (unless (zero? err)
      (raise (make-exn:fail:socket "socket-connect" err))))
  (unless (_socket-connect sock peer)
    (sync (handle-evt (socket-complete-evt sock) check-pending-error))))

;; (one-of? e v ...) evaluates E exactly once and is #t when the result is
;; eq? to any of the V expressions (tried left to right), #f otherwise.
(define-syntax-rule (one-of? e v ...)
  (let ((x e))
    (or (eq? x v) ...)))

;; Accept a connection on SOCK, waiting for it to become readable first.
;; Returns the two values of the raw accept (the new socket and the peer).
;; An accept that fails with EWOULDBLOCK (guards against spurious
;; wake-ups) or ECONNABORTED is retried; any other socket error is
;; re-raised.
(define (socket-accept sock)
  (define (retry-or-raise e)
    (if (one-of? (exn:fail:socket-errno e) EWOULDBLOCK ECONNABORTED)
        (socket-accept sock)
        (raise e)))
  (blocking (socket-recv-evt sock)
    (with-handlers* ((exn:fail:socket? retry-or-raise))
      (_socket-accept sock))))

;; Send bytes START (inclusive) to END (exclusive) of BS on SOCK. Waits for
;; writability before each raw send and keeps sending until the whole range
;; has gone out. Returns the size of the range, (- END START).
(define (socket-send-all sock bs (start 0) (end (bytes-length bs)))
  (define writable (socket-send-evt sock))
  (define (send-from pos)
    (if (>= pos end)
        (- end start)
        (let ((sent (blocking writable (_socket-send sock bs pos end))))
          (send-from (+ pos sent)))))
  (send-from start))

;; Receive into BS from START (inclusive) to END (exclusive), waiting for
;; readability before each raw receive. Stops early when a receive returns
;; 0 bytes. Returns the number of bytes stored.
(define (socket-recv-all sock bs (start 0) (end (bytes-length bs)))
  (define readable (socket-recv-evt sock))
  (define (recv-from pos)
    (if (>= pos end)
        (- pos start)
        (let ((got (blocking readable (_socket-recv sock bs pos end))))
          (if (zero? got)
              (- pos start)
              (recv-from (+ pos got))))))
  (recv-from start))

;; Copy everything readable from the input port INP to SOCK until INP
;; reaches end-of-file. Reads up to BUFSZ bytes at a time and sends each
;; chunk fully with socket-send-all. Returns the total number of bytes sent.
;; BUFSZ must be an exact positive integer. With an empty buffer,
;; read-bytes-avail! returns 0 without consuming input, so the copy loop
;; would never terminate.
(define (socket-send/port sock inp (bufsz 4096))
  (unless (exact-positive-integer? bufsz)
    (raise-type-error 'socket-send/port "exact positive integer" bufsz))
  (let ((buf (make-bytes bufsz)))
    (let lp ((cnt 0))
      (let ((n (read-bytes-avail! buf inp)))
        (if (eof-object? n)
          cnt
          (begin
            (socket-send-all sock buf 0 n)
            (lp (+ n cnt))))))))

;; Copy data received on SOCK to the output port OUTP until a receive
;; returns 0 bytes. Receives up to BUFSZ bytes at a time with the blocking
;; socket-recv. Returns the total number of bytes copied.
;; BUFSZ must be an exact positive integer. With an empty buffer, each
;; receive would presumably report 0 bytes, and the copy would stop at once
;; as if the peer had closed.
(define (socket-recv/port sock outp (bufsz 4096))
  (unless (exact-positive-integer? bufsz)
    (raise-type-error 'socket-recv/port "exact positive integer" bufsz))
  (let ((buf (make-bytes bufsz)))
    (let lp ((cnt 0))
      (let ((n (socket-recv sock buf)))
        (if (zero? n)
          cnt
          (begin
            (write-bytes buf outp 0 n)
            (lp (+ cnt n))))))))

;; socket ports
;; True when SOCK is a socket whose SO_TYPE option is SOCK_STREAM.
(define (stream-socket? sock)
  (if (socket? sock)
      (eq? SOCK_STREAM (socket-getsockopt sock SOL_SOCKET SO_TYPE))
      #f))

;; Wrap the stream socket SOCK as an input port named 'socket.
;; Each read performs a raw receive; a 0-byte receive is reported as eof.
;; When the receive would block, the port is handed an event that waits for
;; readability and yields 0, which asks the port machinery to try again.
;; Closing the port only shuts down the read half (SHUT_RD) of SOCK.
;; Raises a mismatch error if SOCK is not a stream socket.
(define (socket-input-port sock)
  (unless (stream-socket? sock)
    (raise-mismatch-error 'socket-input-port
      "not a stream socket: " sock))
  (let* ((wait-evt (handle-evt (socket-recv-evt sock) (lambda (ignored) 0)))
         (read-in
          (lambda (buf)
            (with-handlers* ((exn:fail:socket:blocking? (lambda (e) wait-evt)))
              (let ((got (_socket-recv sock buf)))
                (if (zero? got) eof got)))))
         (close-in
          (lambda () (socket-shutdown sock SHUT_RD))))
    (make-input-port/read-to-peek 'socket read-in #f close-in)))

;; Wrap the stream socket SOCK as an output port named 'socket.
;; This code keeps no buffer of its own: each write is handed directly to
;; the raw _socket-send. Closing the port only shuts down the write half
;; (SHUT_WR) of SOCK; it does not close SOCK itself.
;; Raises a mismatch error if SOCK is not a stream socket.
(define (socket-output-port sock)
  (if (stream-socket? sock)
    ;; Writability event for SOCK; also passed to make-output-port as the
    ;; port's ready event.
    (let ((evt (socket-send-evt sock)))
      ;; write-out procedure for make-output-port: write BUF[B, E).
      ;; An empty range (B = E) is a flush request; with nothing buffered
      ;; here, it reports 0. The extra arguments (presumably make-output-port's
      ;; non-block?/enable-break? flags) are ignored.
      ;; On EWOULDBLOCK, returns an event that retries the write once SOCK
      ;; is writable.
      ;; NOTE(review): that event is returned even when non-block? is true;
      ;; confirm this is acceptable under the make-output-port protocol.
      (define (write-some buf b e . xs)
        (if (< b e)
          (with-handlers* 
              ((exn:fail:socket:blocking? (lambda (x) (retry buf b e))))
            (_socket-send sock buf b e))
          0))
      ;; Event that waits for writability, then re-attempts write-some.
      ;; Also passed in the get-write-evt position of make-output-port.
      ;; NOTE(review): if the retried send hits EWOULDBLOCK again, this
      ;; event's result is another event rather than a byte count — confirm
      ;; callers of get-write-evt tolerate that.
      (define (retry buf b e)
        (handle-evt evt
          (lambda (x) (write-some buf b e))))
      ;; Closing the port shuts down only the sending direction of SOCK.
      (define (close)
        (socket-shutdown sock SHUT_WR))
      (make-output-port 'socket evt write-some close #f retry))
    (raise-mismatch-error 'socket-output-port
      "not a stream socket: " sock)))

;; Open a stream socket connected to ADDR and return two values: an input
;; port and an output port over it.
;; The protocol family is picked from ADDR:
;;   - an inet4 address gives PF_INET;
;;   - an inet6 address gives PF_INET6 (when IPv6 is available);
;;   - a path gives PF_UNIX (when Unix sockets are available).
;; Any other ADDR raises a mismatch error. The connect itself blocks via
;; socket-connect.
;; NOTE(review): if socket-connect raises, the freshly created socket is
;; not closed here.
(define (open-socket-stream addr)
  (define domain
    (cond
      ((inet4-address? addr) PF_INET)
      ((and SOCKET_HAVE_IPV6 (inet6-address? addr)) PF_INET6)
      ((and SOCKET_HAVE_UNIX (path? addr)) PF_UNIX)
      (else
       (raise-mismatch-error 'open-socket-stream
         "bad address: " addr))))
  (define sock (socket domain SOCK_STREAM))
  (socket-connect sock addr)
  (values (socket-input-port sock)
          (socket-output-port sock)))