;; socket.ss
;; mzsocket: BSD/POSIX sockets library for mzscheme
;; Copyright (C) 2007 Dimitris Vyzovitis <[email protected]>
;;
;; This library is free software; you can redistribute it and/or
;; modify it under the terms of the GNU Lesser General Public
;; License as published by the Free Software Foundation; either
;; version 2.1 of the License, or (at your option) any later version.
;;
;; This library is distributed in the hope that it will be useful,
;; but WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
;; Lesser General Public License for more details.
;;
;; You should have received a copy of the GNU Lesser General Public
;; License along with this library; if not, write to the Free Software
;; Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301,
;; USA

(module socket mzscheme
  ;; (require/prefix mod pre symbol ...)
  ;; Expands to (begin (require (rename mod <pre><symbol> symbol)) ...), i.e.
  ;; every listed SYMBOL of MOD is imported under a local name formed by
  ;; prepending PRE to it.  The local names get the lexical context of the
  ;; macro use (STX).
  (define-syntax (require/prefix stx)
    ;; Build the identifier <pre><id> in the context of the macro use.
    (define (prefixed-id pre id)
      (datum->syntax-object
       stx
       (string->symbol (format "~a~a" pre (syntax-object->datum id)))))
    (syntax-case stx ()
      ((_ mod pre symbol ...)
       (let ((pre-datum (syntax-object->datum #'pre)))
         (with-syntax (((lsym ...)
                        (map (lambda (id) (prefixed-id pre-datum id))
                             (syntax->list #'(symbol ...)))))
           #'(begin (require (rename mod lsym symbol)) ...))))))
  
  (require (lib "kw.ss")
           (lib "port.ss"))
  (require "_constants.ss"
           (all-except "_socket.ss"
                       make-exn:socket make-socket make-inet-address
                       socket-connect socket-accept 
                       socket-send socket-recv
                       socket-sendto socket-recvfrom
                       socket-sendmsg socket-recvmsg))
  (require/prefix "_socket.ss" _
                  make-exn:socket make-socket make-inet-address
                  socket-connect socket-accept 
                  socket-send socket-recv
                  socket-sendto socket-recvfrom
                  socket-sendmsg socket-recvmsg)

  ;; #t when E is an exn:socket whose errno field equals errno:EWOULDBLOCK;
  ;; #f for anything that is not an exn:socket.
  (define (exn:socket:blocking? e)
    (if (exn:socket? e)
        (= errno:EWOULDBLOCK (exn:socket-errno e))
        #f))
  ;; Build an AF_INET address from host BS and PORT with _make-inet-address.
  ;; BS is used as-is when it is a byte string; otherwise it is converted
  ;; with string->bytes/utf-8.
  (define (inet4-address bs port)
    (let ((host (if (bytes? bs) bs (string->bytes/utf-8 bs))))
      (_make-inet-address AF_INET host port)))
  ;; #t when O is an inet-address whose family is AF_INET, #f otherwise.
  (define (inet4-address? o)
    (cond
      ((inet-address? o) (= AF_INET (inet-address-family o)))
      (else #f)))
  ;; IPv6 constructor and predicate, chosen once from SOCKET_HAVE_IPV6.
  ;;  - Without IPv6: the constructor always raises "IPv6 not supported" and
  ;;    the predicate always answers #f.
  ;;  - With IPv6: the constructor passes AF_INET6, the host (converted with
  ;;    string->bytes/utf-8 unless already bytes) and all remaining arguments
  ;;    to _make-inet-address; the predicate checks for family AF_INET6.
  (define-values (inet6-address inet6-address?)
    (cond
      ((not SOCKET_HAVE_IPV6)
       (values
         (lambda args (error 'inet6-address "IPv6 not supported"))
         (lambda args #f)))
      (else
       (values
         (lambda (bs . rest)
           (let ((host (if (bytes? bs) bs (string->bytes/utf-8 bs))))
             (apply _make-inet-address AF_INET6 host rest)))
         (lambda (o)
           (if (inet-address? o)
               (= AF_INET6 (inet-address-family o))
               #f))))))
  ;; Generic address constructor: a direct alias of _make-inet-address, so it
  ;; takes that procedure's arguments unchanged (inet4-address above passes
  ;; the family first, then host bytes, then port).
  (define inet-address _make-inet-address)

  ;; Create a socket through _make-socket.  All three arguments are optional:
  ;; DOMAIN defaults to PF_INET, TYPE to SOCK_STREAM and PROTO to 0.
  (define/kw (socket #:optional (domain PF_INET) (type SOCK_STREAM) (proto 0))
    (_make-socket domain type proto))
  
  ;; Event mask waited on by socket-connect: the bitwise OR of the write and
  ;; except masks.
  (define socket-evt:connect (bitwise-ior socket-evt:write socket-evt:except))
  ;; Event mask waited on by socket-accept: identical to the read mask.
  (define socket-evt:accept socket-evt:read)

  ;; Raw connect: a direct alias of _socket-connect (no waiting).
  (define socket-connect* _socket-connect)
  ;; Connect SOCK to address SA.  When socket-connect* returns #f, wait until
  ;; SOCK signals socket-evt:connect, then read SO_ERROR from the socket and,
  ;; if it is nonzero, raise the value built by
  ;; (_make-exn:socket "socket-connect" err).  Returns void otherwise.
  (define (socket-connect sock sa)
    ;; Called with the ready event once the connect attempt has finished.
    (define (check-pending-error ready)
      (let ((so-err (socket-getsockopt sock SOL_SOCKET SO_ERROR)))
        (if (= so-err 0)
            (void)
            (raise (_make-exn:socket "socket-connect" so-err)))))
    (if (socket-connect* sock sa)
        (void)
        (sync (handle-evt (socket-evt sock socket-evt:connect)
                          check-pending-error))))

  ;; (blocking sock evt expr)
  ;; Wait until (socket-evt sock evt) is ready, then evaluate EXPR and return
  ;; its value.  If EXPR raises an exception accepted by exn:socket:blocking?
  ;; (a spurious wakeup), wait again and retry.  SOCK is evaluated once; EVT
  ;; and EXPR are evaluated on every attempt.
  (define-syntax blocking
    (syntax-rules ()
      ((_ sock evt expr)
       (let ((s sock))
         (let retry ()
           (sync
            (handle-evt (socket-evt s evt)
              (lambda (ready)
                (with-handlers*
                    ((exn:socket:blocking? ; spurious select
                      (lambda (exn) (retry))))
                  expr)))))))))


  ;; Raw accept: a direct alias of _socket-accept (no waiting).
  (define socket-accept* _socket-accept)
  ;; Waiting accept: waits for socket-evt:accept on SOCK, then returns the
  ;; result of (_socket-accept sock), retrying via `blocking` if that call
  ;; raises an exn:socket:blocking? exception.
  (define (socket-accept sock)
    (blocking sock socket-evt:accept
      (_socket-accept sock)))
  
  ;; Raw send: one call to _socket-send on the range [START, END) of BS, with
  ;; no waiting for the socket to become writable.
  ;; START defaults to 0, END to (bytes-length bs), #:flags to 0.
  (define/kw (socket-send* sock bs 
                           #:optional (start 0) (end (bytes-length bs))
                           #:key (flags 0))
    (_socket-send sock bs start end flags))
  
  ;; Waiting send: waits for socket-evt:write on SOCK, then performs one
  ;; _socket-send on the range [START, END) of BS and returns its result
  ;; (retrying via `blocking` on exn:socket:blocking?).
  ;; START defaults to 0, END to (bytes-length bs), #:flags to 0.
  (define/kw (socket-send sock bs 
                          #:optional (start 0) (end (bytes-length bs))
                          #:key (flags 0))
    (blocking sock socket-evt:write
      (_socket-send sock bs start end flags)))

  ;; Send the whole range [START, END) of BS, looping over partial sends.
  ;; Each round waits for socket-evt:write and then calls _socket-send on the
  ;; part not yet sent; the result of that call is added to the offset, so it
  ;; is treated as the number of bytes accepted.
  ;; START defaults to 0 and END to (bytes-length bs).  #:flags (default 0)
  ;; is handed to _socket-send on every round, mirroring socket-send and
  ;; socket-send*.
  ;; Returns (- end start).
  (define/kw (socket-send-all sock bs
                              #:optional (start 0) (end (bytes-length bs))
                              #:key (flags 0))
    (let lp ((off start))
      (if (< off end)
          (let ((olen (blocking sock socket-evt:write
                        (_socket-send sock bs off end flags))))
            (lp (+ off olen)))
          (- end start))))

  ;; Copy everything readable from input port INP to SOCK.
  ;; Reads whatever is available into a BUFSZ-byte buffer (default 4096) with
  ;; read-bytes-avail!, pushes each chunk out with socket-send-all, and stops
  ;; at end-of-file.  Returns the running total of the chunk lengths.
  (define/kw (socket-send/port sock inp 
                               #:optional (bufsz 4096))
    (let ((chunk (make-bytes bufsz)))
      (let pump ((total 0))
        (let ((n (read-bytes-avail! chunk inp)))
          (cond
            ((eof-object? n) total)
            (else
             (socket-send-all sock chunk 0 n)
             (pump (+ n total))))))))
  
  ;; Raw sendto: one call to _socket-sendto on the range [START, END) of BS,
  ;; with no waiting.  Note the argument order: callers pass DEST second,
  ;; while _socket-sendto receives it last.
  ;; START defaults to 0, END to (bytes-length bs), #:flags to 0.
  (define/kw (socket-sendto* sock dest bs 
                             #:optional (start 0) (end (bytes-length bs))
                             #:key (flags 0))
      (_socket-sendto sock bs start end flags dest))

  ;; Waiting sendto: waits for socket-evt:write on SOCK, then performs one
  ;; _socket-sendto of [START, END) of BS to DEST and returns its result
  ;; (retrying via `blocking` on exn:socket:blocking?).
  ;; START defaults to 0, END to (bytes-length bs), #:flags to 0.
  (define/kw (socket-sendto sock dest bs 
                            #:optional (start 0) (end (bytes-length bs))
                            #:key (flags 0))
    (blocking sock socket-evt:write
      (_socket-sendto sock bs start end flags dest)))


  ;; Raw receive: one call to _socket-recv targeting the range [START, END)
  ;; of BS, with no waiting for the socket to become readable.
  ;; START defaults to 0, END to (bytes-length bs), #:flags to 0.
  (define/kw (socket-recv* sock bs 
                           #:optional (start 0) (end (bytes-length bs))
                           #:key (flags 0))
    (_socket-recv sock bs start end flags))

  ;; Waiting receive: waits for socket-evt:read on SOCK, then performs one
  ;; _socket-recv into [START, END) of BS and returns its result (retrying
  ;; via `blocking` on exn:socket:blocking?).
  ;; START defaults to 0, END to (bytes-length bs), #:flags to 0.
  (define/kw (socket-recv sock bs 
                          #:optional (start 0) (end (bytes-length bs))
                          #:key (flags 0))
    (blocking sock socket-evt:read
      (_socket-recv sock bs start end flags)))

  ;; Keep receiving into BS until the range [START, END) is full or a
  ;; receive yields 0.  Each round waits for socket-evt:read and calls
  ;; socket-recv* on the still-unfilled part of the range.
  ;; START defaults to 0 and END to (bytes-length bs).
  ;; Returns the sum of the per-round results.
  (define/kw (socket-recv-all sock bs
                              #:optional (start 0) (end (bytes-length bs)))
    (let fill ((pos start) (total 0))
      (cond
        ((< pos end)
         (let ((n (blocking sock socket-evt:read 
                    (socket-recv* sock bs pos end))))
           (cond
             ((= n 0) total)
             (else (fill (+ pos n) (+ total n))))))
        (else total))))

  ;; Copy data received on SOCK to output port OUTP.
  ;; Repeatedly waits for socket-evt:read, receives into a BUFSZ-byte buffer
  ;; (default 4096) and writes the received prefix to OUTP; stops at the
  ;; first receive whose result is not positive.  Returns the running total
  ;; of the positive results.
  (define/kw (socket-recv/port sock outp 
                               #:optional (bufsz 4096))
    (let ((chunk (make-bytes bufsz)))
      (let drain ((total 0))
        (let ((n (blocking sock socket-evt:read (socket-recv* sock chunk))))
          (cond
            ((> n 0)
             (write-bytes chunk outp 0 n)
             (drain (+ total n)))
            (else total))))))

  ;; Raw recvfrom: one call to _socket-recvfrom targeting the range
  ;; [START, END) of BS, with no waiting; returns whatever that call returns.
  ;; START defaults to 0, END to (bytes-length bs), #:flags to 0.
  (define/kw (socket-recvfrom* sock bs 
                               #:optional (start 0) (end (bytes-length bs))
                               #:key (flags 0))
    (_socket-recvfrom sock bs start end flags))

  ;; Waiting recvfrom: waits for socket-evt:read on SOCK, then performs one
  ;; _socket-recvfrom into [START, END) of BS and returns its result
  ;; (retrying via `blocking` on exn:socket:blocking?).
  ;; START defaults to 0, END to (bytes-length bs), #:flags to 0.
  (define/kw (socket-recvfrom sock bs 
                              #:optional (start 0) (end (bytes-length bs))
                              #:key (flags 0))
    (blocking sock socket-evt:read
      (_socket-recvfrom sock bs start end flags)))
  
  ;; Raw sendmsg: one call to _socket-sendmsg, with no waiting.
  ;; Keyword arguments #:name, #:data and #:control each default to #f and
  ;; #:flags to 0; they are passed through in that order.
  ;; NOTE(review): the original comment here reads "bytes or false" --
  ;; presumably name/data/control must each be a byte string or #f; confirm
  ;; against _socket.ss.
  (define/kw (socket-sendmsg* sock 
                              #:key (name #f) (data #f) (control #f) (flags 0))
    (_socket-sendmsg sock name data control flags))

  ;; Waiting sendmsg: waits for socket-evt:write on SOCK, then performs one
  ;; _socket-sendmsg and returns its result (retrying via `blocking` on
  ;; exn:socket:blocking?).
  ;; #:name, #:data and #:control default to #f, #:flags to 0.
  (define/kw (socket-sendmsg sock 
                             #:key (name #f) (data #f) (control #f) (flags 0))
    (blocking sock socket-evt:write
      (_socket-sendmsg sock name data control flags)))

  ;; Raw recvmsg: one call to _socket-recvmsg, with no waiting.
  ;; #:name, #:data and #:control default to #f, #:flags to 0; they are
  ;; passed through in that order.
  (define/kw (socket-recvmsg* sock 
                              #:key (name #f) (data #f) (control #f) (flags 0))
    (_socket-recvmsg sock name data control flags))

  ;; Waiting recvmsg: waits for socket-evt:read on SOCK, then performs one
  ;; _socket-recvmsg and returns its result (retrying via `blocking` on
  ;; exn:socket:blocking?).
  ;; #:name, #:data and #:control default to #f, #:flags to 0.
  (define/kw (socket-recvmsg sock 
                             #:key (name #f) (data #f) (control #f) (flags 0))
    (blocking sock socket-evt:read
      (_socket-recvmsg sock name data control flags)))
  
  ;; ports
  ;; #t when SOCK is a socket whose SO_TYPE option equals SOCK_STREAM;
  ;; #f for non-sockets.
  (define (stream-socket? sock)
    (if (socket? sock)
        (= SOCK_STREAM (socket-getsockopt sock SOL_SOCKET SO_TYPE))
        #f))

  ;; Wrap stream socket SOCK in an input port named 'socket, built with
  ;; make-input-port/read-to-peek.  Closing the port shuts down the read side
  ;; of the socket (SHUT_RD).  Raises an error if SOCK is not a stream socket.
  (define (socket->input-port sock)
    ;; Would-block handler: hand back an event that becomes ready when the
    ;; socket is readable and whose value is 0.
    (define (wait-for-data . ignored)
      (wrap-evt (socket-evt sock socket-evt:read)
        (lambda ready 0)))

    ;; Port read procedure: one raw receive into BUF.  A positive result is
    ;; returned as the byte count; any other result is reported as eof.
    (define (->read buf)
      (with-handlers* ((exn:socket:blocking? wait-for-data))
        (let ((ilen (socket-recv* sock buf)))
          (cond
            ((> ilen 0) ilen)
            (else eof)))))

    (unless (stream-socket? sock)
      (error 'socket->input-port "not a stream socket"))
    (make-input-port/read-to-peek
     'socket
     ->read
     #f
     (lambda () (socket-shutdown sock SHUT_RD))))
  
  ;; Wrap stream socket SOCK in an output port named 'socket, built with
  ;; make-output-port.  The socket's write event serves as the port's
  ;; readiness event; closing the port shuts down the write side of the
  ;; socket (SHUT_WR).  Raises an error if SOCK is not a stream socket.
  (define (socket->output-port sock)
    ;; Event-producing writer, also installed as the port's write-evt
    ;; procedure.  For a non-empty range [B, E) of BUF the event becomes
    ;; ready when the socket is writable and its value is the result of one
    ;; raw send.  For an empty range (a flush request) the event is
    ;; immediately ready with value 0.
    ;; NOTE(review): on exn:socket:blocking? the handler returns a fresh
    ;; event as the wrap-evt result rather than a byte count -- confirm the
    ;; port machinery re-syncs on such a value.
    (define (->write-avail buf b e)
      (cond
        ((> e b)
         (wrap-evt (socket-evt sock socket-evt:write)
           (lambda ready
             (with-handlers*
                 ((exn:socket:blocking?
                   (lambda exn (->write-avail buf b e))))
               (socket-send* sock buf b e)))))
        (else
         (wrap-evt always-evt (lambda ready 0))))) ; flush

    ;; Port write procedure: any arguments after the range are ignored.
    (define (->write buf b e . rest)
      (->write-avail buf b e))

    (unless (stream-socket? sock)
      (error 'socket->output-port "not a stream socket"))
    (make-output-port
     'socket
     (socket-evt sock socket-evt:write)
     ->write
     (lambda () (socket-shutdown sock SHUT_WR))
     #f
     ->write-avail))

  ;; Return two values for SOCK: its input port, then its output port.
  (define (socket->ports sock)
    (let* ((in (socket->input-port sock))
           (out (socket->output-port sock)))
      (values in out)))
  
  ;; Open a stream connection to ADDR and return its input and output ports
  ;; (two values).  The protocol family is chosen from the kind of ADDR:
  ;; inet4 address -> PF_INET, inet6 address -> PF_INET6, path -> PF_UNIX;
  ;; anything else raises "bad address".
  (define (open-socket-stream addr)
    (define (family-for a)
      (cond
        ((inet4-address? a) PF_INET)
        ((inet6-address? a) PF_INET6)
        ((path? a) PF_UNIX)
        (else (error 'open-socket-stream "bad address"))))
    (let ((sock (socket (family-for addr) SOCK_STREAM)))
      (socket-connect sock addr)
      (socket->ports sock)))

  (provide (all-defined-except require/prefix blocking)
           (all-from "_constants.ss")
           (all-from-except "_socket.ss"
                            _make-exn:socket
                            _make-inet-address
                            _make-socket 
                            _socket-connect _socket-accept 
                            _socket-send _socket-recv
                            _socket-sendto _socket-recvfrom
                            _socket-sendmsg _socket-recvmsg))

)