matrix.ss
#|  matrix.ss: Matrices and matrix operations using BLAS and LAPACK
    Copyright (C) 2007 Will M. Farr <[email protected]>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along
    with this program; if not, write to the Free Software Foundation, Inc.,
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|#

(module matrix mzscheme
  (require (lib "foreign.ss")
           (lib "etc.ss")
           (all-except (lib "contract.ss") ->)
           (rename (lib "contract.ss") ->/c ->)
           (all-except (lib "42.ss" "srfi") :)
           (all-except (planet "srfi-4-comprehensions.ss" ("wmfarr" "srfi-4-comprehensions.plt")) :)
           "blas-lapack.ss")
  
  ;; (list/length/c n) : flat contract
  ;; Accepts exactly those lists whose length is N.
  (define (list/length/c n)
    (let ((label (format "list of length ~a" n))
          (has-length-n? (lambda (lst) (= n (length lst)))))
      (flat-named-contract label has-length-n?)))
  
  ;; (matrix-multiplication-compatible/c m) : flat contract
  ;; Accepts any matrix M2 for which the product M * M2 is defined, i.e.
  ;; whose row count equals the column count of M.
  (define (matrix-multiplication-compatible/c m)
    (let ((label (format "compatible for multiplication by a ~a by ~a matrix"
                         (matrix-rows m) (matrix-cols m))))
      (flat-named-contract
       label
       (lambda (other)
         (= (matrix-rows other) (matrix-cols m))))))
  
  ;; (matrix-same-dimensions/c m) : flat contract
  ;; Accepts any matrix with exactly the same number of rows and columns as M.
  (define (matrix-same-dimensions/c m)
    (define (same-shape? other)
      (and (= (matrix-rows m) (matrix-rows other))
           (= (matrix-cols m) (matrix-cols other))))
    (flat-named-contract (format "~a by ~a matrix" (matrix-rows m) (matrix-cols m))
                         same-shape?))
  
  ;; (matrix-valid-row-index/c m) : flat contract
  ;; Accepts row indices I with 0 <= I < (matrix-rows M).  The row count is
  ;; read once, when the contract is built.
  (define (matrix-valid-row-index/c m)
    (let* ((row-count (matrix-rows m))
           (in-range? (lambda (i) (and (<= 0 i) (< i row-count)))))
      (flat-named-contract
       (format "valid row index for a ~a by ~a matrix" row-count (matrix-cols m))
       in-range?)))
  
  ;; (matrix-valid-col-index/c m) : flat contract
  ;; Accepts column indices J with 0 <= J < (matrix-cols M).  The column
  ;; count is read once, when the contract is built.
  (define (matrix-valid-col-index/c m)
    (let* ((col-count (matrix-cols m))
           (in-range? (lambda (j) (and (<= 0 j) (< j col-count)))))
      (flat-named-contract
       (format "valid column index for ~a by ~a matrix" (matrix-rows m) col-count)
       in-range?)))
  
  ;; (matrix-col-vector-compatible/c m) : flat contract
  ;; Accepts f64vectors that can be right-multiplied by M (M * v), i.e.
  ;; whose length equals the column count of M.
  (define (matrix-col-vector-compatible/c m)
    (let* ((len (matrix-cols m))
           (label (format "column vector of length ~a" len)))
      (flat-named-contract label
                           (lambda (vec) (= len (f64vector-length vec))))))
  
  ;; (matrix-row-vector-compatible/c m) : flat contract
  ;; Accepts f64vectors that can left-multiply M (v * M), i.e. whose length
  ;; equals the row count of M.
  (define (matrix-row-vector-compatible/c m)
    (let* ((len (matrix-rows m))
           (label (format "row vector of length ~a" len)))
      (flat-named-contract label
                           (lambda (vec) (= len (f64vector-length vec))))))
  
  ;; matrix-square/c : flat contract
  ;; Accepts matrices with as many rows as columns.  The matrix? test comes
  ;; first so the contract is safe to use on its own, without a preceding
  ;; matrix? check.  A non-matrix value is then reported as a contract
  ;; violation instead of raising an error inside matrix-rows.
  (define matrix-square/c
    (flat-named-contract
     "square matrix"
     (lambda (m) (and (matrix? m)
                      (= (matrix-rows m) (matrix-cols m))))))
  
  ;; A matrix pairs a C pointer PTR to rows*cols doubles with its dimensions
  ;; ROWS and COLS.  Entries are stored in column-major order (see
  ;; matrix-ptr-index).  The trailing #f is the define-struct inspector
  ;; argument; #f makes the struct transparent.
  (define-struct matrix
    (ptr rows cols) #f)
  
  ;; Exports without contracts:
  ;;  - the matrix predicate;
  ;;  - the dimension-checking contract constructors defined above, so that
  ;;    clients can reuse them in their own contracts;
  ;;  - the matrix-ec / :matrix SRFI-42 comprehension forms;
  ;;  - the _matrix FFI type;
  ;;  - the struct type descriptor.
  (provide matrix? matrix-multiplication-compatible/c matrix-valid-row-index/c
           matrix-valid-col-index/c matrix-square/c matrix-same-dimensions/c
           matrix-col-vector-compatible/c matrix-row-vector-compatible/c
           matrix-ec :matrix
           _matrix
           struct:matrix)
  
  ;; Contracted exports.  Several entries use ->r so that one argument's
  ;; contract can depend on the values of the other arguments.  For example,
  ;; matrix-ref's indices are checked against its matrix argument, and
  ;; f64vector-matrix-mul's vector is checked against the matrix that
  ;; follows it.
  ;; Only the three-argument form of my-make-matrix is exported, as
  ;; make-matrix.
  ;; In matrix-solve and matrix-solve-many the square-matrix check is paired
  ;; with matrix?, as for matrix-inverse.  A non-matrix argument is then
  ;; rejected by the contract instead of failing inside matrix-rows.
  (provide/contract
   (rename my-make-matrix make-matrix
           (->/c natural-number/c natural-number/c number? matrix?))
   (rename my-matrix matrix
           (->r ((i natural-number/c)
                 (j natural-number/c))
                elts (and/c (listof number?)
                            (list/length/c (* i j)))
                matrix?))
   (matrix-rows (->/c matrix? natural-number/c))
   (matrix-cols (->/c matrix? natural-number/c))
   (matrix-ref (->r ((m matrix?)
                     (i (and/c natural-number/c
                               (matrix-valid-row-index/c m)))
                     (j (and/c natural-number/c
                               (matrix-valid-col-index/c m))))
                    number?))
   (matrix-set! (->r ((m matrix?)
                      (i (and/c natural-number/c
                                (matrix-valid-row-index/c m)))
                      (j (and/c natural-number/c
                                (matrix-valid-col-index/c m)))
                      (x number?))
                     any))
   (matrix-add (->r ((m1 matrix?)
                     (m2 (and/c matrix?
                                (matrix-same-dimensions/c m1))))
                    matrix?))
   (matrix-sub (->r ((m1 matrix?)
                     (m2 (and/c matrix?
                                (matrix-same-dimensions/c m1))))
                    matrix?))
   (matrix-scale (->/c matrix? number? matrix?))
   (matrix-mul (->r ((m1 matrix?)
                     (m2 (and/c matrix?
                                (matrix-multiplication-compatible/c m1))))
                    matrix?))
   (matrix-f64vector-mul (->r ((m matrix?)
                               (v (and/c f64vector?
                                         (matrix-col-vector-compatible/c m))))
                              f64vector?))
   (f64vector-matrix-mul (->r ((v (and/c f64vector?
                                         (matrix-row-vector-compatible/c m)))
                               (m matrix?))
                              f64vector?))
   (matrix-inverse (->/c (and/c matrix? matrix-square/c) matrix?))
   (matrix-norm (->/c matrix? number?))
   (matrix-identity (->/c natural-number/c matrix?))
   (matrix-transpose (->/c matrix? matrix?))
   (matrix-solve (->r ((m (and/c matrix? matrix-square/c))
                       (v (and/c f64vector?
                                 (let ((r (matrix-rows m)))
                                   (flat-named-contract
                                    (format "column vector of length ~a" r)
                                    (lambda (v) (= (f64vector-length v) r)))))))
                      f64vector?))
   (matrix-solve-many (->r ((m1 (and/c matrix? matrix-square/c))
                            (m2 (and/c matrix?
                                       (matrix-multiplication-compatible/c m1))))
                           matrix?)))
  
  ;; Enable the unsafe part of foreign.ss (malloc, memset, ptr-ref/ptr-set!,
  ;; get-ffi-obj, ...) for the remainder of this module.
  (unsafe!)
  
  ;; my-make-matrix (exported as make-matrix)
  ;;   (my-make-matrix rows cols)      -> rows x cols matrix with all entries 0
  ;;   (my-make-matrix rows cols elt)  -> rows x cols matrix with every entry ELT
  ;; Storage is an 'atomic malloc block of rows*cols doubles.  It is zeroed
  ;; with memset first, then filled with ELT when one is given.
  (define my-make-matrix
    (case-lambda
      ((rows cols)
       (let* ((count (* rows cols))
              (storage (malloc count _double 'atomic)))
         (memset storage 0 count _double)
         (make-matrix storage rows cols)))
      ((rows cols elt)
       (let* ((result (my-make-matrix rows cols))
              (storage (matrix-ptr result))
              (count (* rows cols)))
         (let fill ((k 0))
           (when (< k count)
             (ptr-set! storage _double* k elt)
             (fill (+ k 1))))
         result))))
  
  ;; my-matrix (exported as matrix) : rows cols elt ... -> matrix
  ;; Builds an I x J matrix from ELTS, given in column-major storage order.
  ;; Filling stops at whichever runs out first, the I*J slots or the
  ;; elements.  Unfilled slots keep the 0 written by my-make-matrix.
  (define (my-matrix i j . elts)
    (let* ((result (my-make-matrix i j))
           (storage (matrix-ptr result))
           (count (* i j)))
      (let fill ((k 0) (remaining elts))
        (when (and (< k count) (pair? remaining))
          (ptr-set! storage _double* k (car remaining))
          (fill (+ k 1) (cdr remaining))))
      result))
  
  ;; _matrix* : ctype
  ;; Input-only FFI type: a matrix is passed to C as the pointer to its
  ;; storage.  Converting a C value back into a matrix is refused, because a
  ;; bare pointer carries no dimensions.
  (define _matrix*
    (let ((reject-c->scheme
           (lambda (c-value)
             (error '_matrix
                    "cannot convert C output to _matrix"))))
      (make-ctype _pointer matrix-ptr reject-c->scheme)))
  
  ;; _matrix : custom function-type syntax for use inside _fun.
  ;;   _matrix, (_matrix i)  -- pass an existing matrix to C as the pointer to
  ;;                            its column-major double storage (_matrix*).
  ;;   (_matrix o rows cols) -- output argument: allocate a zero-filled block
  ;;                            of rows*cols doubles, pass its pointer, and
  ;;                            after the call wrap it as a new rows x cols
  ;;                            matrix.
  ;;   (_matrix io)          -- pass the given matrix's storage pointer and
  ;;                            produce that same matrix object after the
  ;;                            call, so in-place C updates are visible.
  (define-fun-syntax _matrix
    (syntax-id-rules (i o io)
      ((_matrix i)
       _matrix*)
      ((_matrix o rows cols)
       (type: _pointer
              pre: (let* ((n (* rows cols))
                          (p (malloc n _double 'atomic)))
                     (memset p 0 n _double)
                     p)
              post: (p => (make-matrix p rows cols))))
      ((_matrix io)
       (type: _pointer
              bind: m
              pre: (m => (matrix-ptr m))
              post: m))
      ;; Bare identifier use: same as (_matrix i).
      (_matrix _matrix*)))
  
  ;; matrix-ptr-index : matrix row col -> offset
  ;; Offset (in doubles) of entry (I, J) in M's column-major storage: each
  ;; column occupies (matrix-rows M) consecutive slots.
  (define (matrix-ptr-index m i j)
    (let ((column-stride (matrix-rows m)))
      (+ (* column-stride j) i)))
  
  ;; matrix-ref : matrix row col -> real
  ;; Reads entry (I, J) of M.  No bounds check here; the exported binding is
  ;; guarded by its contract.
  (define (matrix-ref m i j)
    (let ((offset (matrix-ptr-index m i j)))
      (ptr-ref (matrix-ptr m) _double offset)))
  
  ;; matrix-set! : matrix row col real -> void
  ;; Stores ELT into entry (I, J) of M.  _double* accepts any real and
  ;; converts it to a double.  No bounds check here; the exported binding is
  ;; guarded by its contract.
  (define (matrix-set! m i j elt)
    (let ((offset (matrix-ptr-index m i j)))
      (ptr-set! (matrix-ptr m) _double* offset elt)))
  
  ;; matrix-length : matrix -> natural
  ;; Total number of stored entries (rows * cols).
  (define (matrix-length m)
    (* (matrix-rows m) (matrix-cols m)))
  
  ;; matrix-copy : matrix -> matrix
  ;; Returns a fresh matrix with M's dimensions and entries.  All rows*cols
  ;; doubles are copied with cblas_dcopy(N, X, incX, Y, incY) using unit
  ;; strides.  Elsewhere in this module it protects the caller's matrices
  ;; from BLAS/LAPACK routines that overwrite their arguments.
  (define matrix-copy
    (get-ffi-obj 'cblas_dcopy *blas*
                  (_fun (m) ::
                        (_int = (matrix-length m))           ; N
                        (m : _matrix) (_int = 1)             ; X, incX
                        (m-out : (_matrix o (matrix-rows m) (matrix-cols m))) (_int = 1) -> ; Y, incY
                        _void ->
                        m-out)))
  
  ;; matrix-add : matrix matrix -> matrix
  ;; Returns M1 + M2 as a new matrix, via cblas_daxpy (Y <- alpha*X + Y)
  ;; with alpha = 1, X = M1 and Y = a copy of M2.  Neither argument is
  ;; modified; both are treated as flat vectors of (matrix-length M1) doubles.
  (define matrix-add
    (get-ffi-obj 'cblas_daxpy *blas*
                  (_fun (m1 m2) ::
                        (_int = (matrix-length m1))           ; N
                        (_double = 1.0)                       ; alpha
                       (_matrix = m1) (_int = 1)              ; X, incX
                       (m-out : _matrix = (matrix-copy m2)) (_int = 1) -> ; Y, incY
                       _void ->
                       m-out)))
  
  ;; matrix-sub : matrix matrix -> matrix
  ;; Returns M1 - M2 as a new matrix, via cblas_daxpy (Y <- alpha*X + Y)
  ;; with alpha = -1, X = M2 and Y = a copy of M1.  Neither argument is
  ;; modified.
  (define matrix-sub
    (get-ffi-obj 'cblas_daxpy *blas*
                 (_fun (m1 m2) ::
                       (_int = (matrix-length m1))            ; N
                       (_double = -1.0)                       ; alpha
                       (_matrix = m2) (_int = 1)              ; X, incX
                       (m-out : _matrix = (matrix-copy m1)) (_int = 1) -> ; Y, incY
                       _void ->
                       m-out)))
  
  ;; matrix-scale : matrix real -> matrix
  ;; Returns S * M as a new matrix.  A copy of M is scaled in place with
  ;; cblas_dscal(N, alpha, X, incX).  _double* accepts any real S and
  ;; converts it to a double.
  (define matrix-scale
    (get-ffi-obj 'cblas_dscal *blas*
                 (_fun (m s) ::
                       (_int = (matrix-length m))             ; N
                       (_double* = s)                         ; alpha
                       (m-out : _matrix = (matrix-copy m)) (_int = 1) -> ; X, incX
                       _void ->
                       m-out)))
  
  ;; matrix-mul : matrix matrix -> matrix
  ;; Returns the product M1 * M2 in a freshly allocated
  ;; (rows M1) x (cols M2) matrix.  It uses cblas_dgemm in column-major
  ;; order, with no transposes, alpha = 1 and beta = 0.
  ;; BLAS requires every leading dimension to be at least 1, even when the
  ;; matching dimension is 0.  lda, ldb and ldc are therefore clamped with
  ;; (max 1 ...).  Passing 0 for an empty operand would be reported by BLAS
  ;; as an illegal parameter.
  (define matrix-mul
    (get-ffi-obj 'cblas_dgemm *blas*
                 (_fun (m1 m2) ::
                       (_cblas-order = 'col-major)
                       (_cblas-transpose = 'no-trans)            ; op(A) = A
                       (_cblas-transpose = 'no-trans)            ; op(B) = B
                       (_int = (matrix-rows m1))                 ; M
                       (_int = (matrix-cols m2))                 ; N
                       (_int = (matrix-rows m2))                 ; K
                       (_double = 1.0)                           ; alpha
                       (_matrix = m1)                            ; A
                       (_int = (max 1 (matrix-rows m1)))         ; lda
                       (_matrix = m2)                            ; B
                       (_int = (max 1 (matrix-rows m2)))         ; ldb
                       (_double = 0.0)                           ; beta
                       (m-out : (_matrix o (matrix-rows m1) (matrix-cols m2))) ; C
                       (_int = (max 1 (matrix-rows m1))) ->      ; ldc
                       _void ->
                       m-out)))
  
  ;; matrix-f64vector-mul : matrix f64vector -> f64vector
  ;; Returns M * V, a vector of length (rows M).  It uses cblas_dgemv in
  ;; column-major order, with no transpose, alpha = 1 and beta = 0.
  ;; - lda is clamped to at least 1, as BLAS requires even for 0-row M.
  ;; - The output vector is allocated filled with 0.0.  dgemv returns early
  ;;   without touching Y when M has no columns, and the result must then
  ;;   still be the zero vector.
  (define matrix-f64vector-mul
    (get-ffi-obj 'cblas_dgemv *blas*
                 (_fun (m v) ::
                       (_cblas-order = 'col-major)
                       (_cblas-transpose = 'no-trans)
                       (_int = (matrix-rows m))                   ; M
                       (_int = (matrix-cols m))                   ; N
                       (_double = 1.0)                            ; alpha
                       (_matrix = m)                              ; A
                       (_int = (max 1 (matrix-rows m)))           ; lda
                       (_f64vector = v)                           ; X
                       (_int = 1)                                 ; incX
                       (_double = 0.0)                            ; beta
                       (v-out : _f64vector = (make-f64vector (matrix-rows m) 0.0)) ; Y
                       (_int = 1) ->                              ; incY
                       _void ->
                       v-out)))
  
  ;; f64vector-matrix-mul : f64vector matrix -> f64vector
  ;; Returns V * M (equivalently M^T * V), a vector of length (cols M).  It
  ;; uses cblas_dgemv with 'trans in column-major order, alpha = 1 and
  ;; beta = 0.
  ;; - lda is clamped to at least 1, as BLAS requires even for 0-row M.
  ;; - The output vector is allocated filled with 0.0.  dgemv returns early
  ;;   without touching Y when M has no rows, and the result must then
  ;;   still be the zero vector.
  (define f64vector-matrix-mul
    (get-ffi-obj 'cblas_dgemv *blas*
                 (_fun (v m) ::
                       (_cblas-order = 'col-major)
                       (_cblas-transpose = 'trans)
                       (_int = (matrix-rows m))                   ; M
                       (_int = (matrix-cols m))                   ; N
                       (_double = 1.0)                            ; alpha
                       (_matrix = m)                              ; A
                       (_int = (max 1 (matrix-rows m)))           ; lda
                       (_f64vector = v)                           ; X
                       (_int = 1)                                 ; incX
                       (_double = 0.0)                            ; beta
                       (v-out : _f64vector = (make-f64vector (matrix-cols m) 0.0)) ; Y
                       (_int = 1) ->                              ; incY
                       _void ->
                       v-out)))
  
  ;; matrix-lu-decomp : matrix -> (values matrix u32vector)
  ;; LU-factorises a copy of M with LAPACK dgetrf (partial pivoting).
  ;; Returns the packed L\U factors and the pivot vector IPIV, in the form
  ;; dgetri expects.  M itself is not modified.
  ;; dgetrf reports problems through INFO:
  ;; - INFO > 0: U has an exact zero on its diagonal, so M is singular.
  ;; - INFO < 0: an argument was illegal.
  ;; The factors are unusable for inversion in either case, so an error is
  ;; raised instead of returning them.
  ;; LDA is clamped to at least 1, as LAPACK requires even for 0x0 input.
  (define matrix-lu-decomp
    (get-ffi-obj 'dgetrf_ *lapack*
                 (_fun (m) ::
                       ((_ptr i _int) = (matrix-rows m))          ; M
                       ((_ptr i _int) = (matrix-cols m))          ; N
                       (m-out : _matrix = (matrix-copy m))        ; A, overwritten by L\U
                       ((_ptr i _int) = (max 1 (matrix-rows m)))  ; LDA
                       (ipiv : (_u32vector o (matrix-rows m)))    ; IPIV
                       (info : (_ptr o _int)) ->                  ; INFO
                       _void ->
                       (cond ((zero? info) (values m-out ipiv))
                             ((positive? info)
                              (error 'matrix-lu-decomp
                                     "matrix is singular: U(~a,~a) is exactly zero"
                                     info info))
                             (else
                              (error 'matrix-lu-decomp
                                     "illegal value in argument ~a to dgetrf"
                                     (- info)))))))
  
  ;; dgetri-lwork : matrix u32vector -> (values integer integer)
  ;; Workspace-size query for dgetri, made by passing LWORK = -1.  It takes
  ;; the LU factors M and pivots IPIV from matrix-lu-decomp.  It returns:
  ;; - the optimal work-array length, rounded to an exact integer;
  ;; - dgetri's INFO code.
  ;; M is passed without copying: in a workspace query LAPACK only reports
  ;; the size, and does not invert A.
  ;; LDA is clamped to at least 1, as LAPACK requires LDA >= max(1, N).
  (define dgetri-lwork
    (get-ffi-obj 'dgetri_ *lapack*
                 (_fun (m ipiv) ::
                       (n : (_ptr i _int) = (matrix-rows m))      ; N
                       (_matrix = m)                              ; A
                       ((_ptr i _int) = (max 1 n))                ; LDA
                       (_u32vector = ipiv)                        ; IPIV
                       (lwork : (_ptr o _double))                 ; WORK(1) = optimal LWORK
                       ((_ptr i _int) = -1)                       ; LWORK = -1: query only
                       (res : (_ptr o _int)) ->                   ; INFO
                       _void ->
                       (values (inexact->exact (round lwork)) res))))
  
  ;; dgetri/lwork : matrix u32vector natural -> matrix
  ;; Computes the inverse from the LU factors M and pivots IPIV produced by
  ;; matrix-lu-decomp.  It uses LAPACK dgetri with a work array of LWORK
  ;; doubles and works on a copy, so M is not modified.
  ;; - LDA and LWORK are clamped to at least 1, as LAPACK requires
  ;;   LDA >= max(1, N) and LWORK >= max(1, N).  The workspace query may
  ;;   report 0 when N = 0.
  ;; - A nonzero INFO is raised as an error instead of returning a
  ;;   meaningless matrix: INFO > 0 means singular, INFO < 0 means an
  ;;   illegal argument.
  (define dgetri/lwork
    (get-ffi-obj 'dgetri_ *lapack*
                 (_fun (m ipiv lwork) ::
                       (n : (_ptr i _int) = (matrix-rows m))      ; N
                       (m-out : _matrix = (matrix-copy m))        ; A: LU in, inverse out
                       ((_ptr i _int) = (max 1 n))                ; LDA
                       (_u32vector = ipiv)                        ; IPIV
                       (_f64vector o (max 1 lwork))               ; WORK
                       ((_ptr i _int) = (max 1 lwork))            ; LWORK
                       (info : (_ptr o _int)) ->                  ; INFO
                       _void ->
                       (cond ((zero? info) m-out)
                             ((positive? info)
                              (error 'matrix-inverse
                                     "matrix is singular: U(~a,~a) is exactly zero"
                                     info info))
                             (else
                              (error 'matrix-inverse
                                     "illegal value in argument ~a to dgetri"
                                     (- info)))))))
  
  ;; matrix-inverse : matrix -> matrix
  ;; Returns the inverse of the square matrix M; M is not modified.  The
  ;; computation uses LAPACK in three steps:
  ;;   1. LU factorisation (matrix-lu-decomp / dgetrf);
  ;;   2. a workspace-size query (dgetri-lwork);
  ;;   3. the inversion itself (dgetri/lwork).
  ;; A 0x0 matrix is its own inverse.  It is returned directly, which avoids
  ;; LAPACK's argument checks for the degenerate N = 0 case.
  ;; A nonzero INFO from the workspace query is reported as an error, rather
  ;; than silently going on with an unusable work size.
  (define (matrix-inverse m)
    (if (zero? (matrix-rows m))
        (my-make-matrix 0 0)
        (let-values (((m-lu ipiv)
                      (matrix-lu-decomp m)))
          (let-values (((lwork res)
                        (dgetri-lwork m-lu ipiv)))
            (unless (zero? res)
              (error 'matrix-inverse
                     "dgetri workspace query failed (info = ~a)" res))
            (dgetri/lwork m-lu ipiv lwork)))))
  
  ;; matrix-norm : matrix -> real
  ;; Euclidean norm of all rows*cols entries taken as one flat vector (the
  ;; Frobenius norm of M), computed by cblas_dnrm2 with unit stride.
  (define matrix-norm
    (get-ffi-obj 'cblas_dnrm2 *blas*
                 (_fun (m) ::
                       (_int = (matrix-length m))             ; N
                       (_matrix = m)                          ; X
                       (_int = 1) ->                          ; incX
                       _double)))
  
  ;; matrix-transpose : matrix -> matrix
  ;; Returns a new (cols M) x (rows M) matrix T with T(j, i) = M(i, j).
  (define (matrix-transpose m)
    (let* ((n-rows (matrix-rows m))
           (n-cols (matrix-cols m))
           (result (my-make-matrix n-cols n-rows)))
      (do-ec (:range i n-rows)
             (:range j n-cols)
             (matrix-set! result j i (matrix-ref m i j)))
      result))
  
  ;; matrix-solve-many : matrix matrix -> matrix
  ;; Solves M X = B for X, one system per column of B, with LAPACK dgesv.
  ;; Both M and B are copied first, so neither is modified.
  ;; - LDA and LDB are clamped to at least 1, as LAPACK requires.
  ;; - A nonzero INFO is raised as an error.  INFO > 0 means M is singular
  ;;   and dgesv did not compute a solution, so the "result" would only be
  ;;   partially overwritten data.
  (define matrix-solve-many
    (get-ffi-obj 'dgesv_ *lapack*
                 (_fun (m b) ::
                       ((_ptr i _int) = (matrix-rows m))          ; N
                       ((_ptr i _int) = (matrix-cols b))          ; NRHS
                       (_matrix = (matrix-copy m))                ; A, overwritten by L\U
                       ((_ptr i _int) = (max 1 (matrix-rows m)))  ; LDA
                       (_u32vector o (matrix-rows m))             ; IPIV (scratch)
                       (x : _matrix = (matrix-copy b))            ; B in, X out
                       ((_ptr i _int) = (max 1 (matrix-rows b)))  ; LDB
                       (info : (_ptr o _int)) ->                  ; INFO
                       _void ->
                       (cond ((zero? info) x)
                             ((positive? info)
                              (error 'matrix-solve-many
                                     "matrix is singular: U(~a,~a) is exactly zero"
                                     info info))
                             (else
                              (error 'matrix-solve-many
                                     "illegal value in argument ~a to dgesv"
                                     (- info)))))))
  
  ;; matrix-solve : matrix f64vector -> f64vector
  ;; Solves M x = V for x with LAPACK dgesv, using a single right-hand side.
  ;; M and V are copied first, so neither is modified.
  ;; - LDA and LDB are clamped to at least 1, as LAPACK requires.
  ;; - A nonzero INFO is raised as an error.  INFO > 0 means M is singular
  ;;   and dgesv did not compute a solution.
  (define matrix-solve
    (get-ffi-obj 'dgesv_ *lapack*
                 (_fun (m v) ::
                       ((_ptr i _int) = (matrix-rows m))          ; N
                       ((_ptr i _int) = 1)                        ; NRHS
                       (_matrix = (matrix-copy m))                ; A, overwritten by L\U
                       ((_ptr i _int) = (max 1 (matrix-rows m)))  ; LDA
                       (_u32vector o (matrix-rows m))             ; IPIV (scratch)
                       (x : _f64vector = (f64vector-of-length-ec (f64vector-length v) (:f64vector x v) x)) ; B in, x out
                       ((_ptr i _int) = (max 1 (f64vector-length v))) ; LDB
                       (info : (_ptr o _int)) ->                  ; INFO
                       _void ->
                       (cond ((zero? info) x)
                             ((positive? info)
                              (error 'matrix-solve
                                     "matrix is singular: U(~a,~a) is exactly zero"
                                     info info))
                             (else
                              (error 'matrix-solve
                                     "illegal value in argument ~a to dgesv"
                                     (- info)))))))
  
  ;; matrix-identity : natural -> matrix
  ;; Returns the N x N identity matrix: 1.0 on the diagonal, 0.0 elsewhere.
  (define (matrix-identity n)
    (let ((id (my-make-matrix n n 0.0)))
      (let set-diagonal ((k 0))
        (when (< k n)
          (matrix-set! id k k 1.0)
          (set-diagonal (+ k 1))))
      id))
  
  ;; (matrix-ec rows cols qualifier ... expr) : SRFI-42 eager comprehension
  ;; Builds a rows x cols matrix whose entries, in column-major storage
  ;; order, are the values produced by (list-ec qualifier ... expr).  ROWS
  ;; and COLS are evaluated once, before the comprehension runs.  The
  ;; entries are handed to my-matrix, which leaves unfilled slots at 0.0
  ;; and ignores surplus values.
  (define-syntax matrix-ec
    (syntax-rules ()
      ((matrix-ec row-count-expr col-count-expr clause ...)
       (let ((n-rows row-count-expr)
             (n-cols col-count-expr))
         (apply my-matrix n-rows n-cols (list-ec clause ...))))))
  
  ;; (:matrix var [(index i j)] matrix-expr) : SRFI-42 generator
  ;; Binds VAR to each entry of the matrix in storage (column-major) order:
  ;; the row index i varies fastest, then the column index j.  With the
  ;; optional (index i j) clause, I and J are bound to the current row and
  ;; column.
  ;; The outer bindings of :do are parallel (let), so ROWS and COLS cannot
  ;; refer to M directly and are filled in with set!.
  ;; The loop condition requires both i < rows and j < cols.  When
  ;; rows > 0, i wraps to 0 before reaching rows, so the extra test changes
  ;; nothing there.  When a matrix has zero rows but some columns, it now
  ;; yields no entries instead of reading entry (0, 0) from empty storage.
  (define-syntax :matrix
    (syntax-rules (index)
      ((:matrix cc var arg)
       (:matrix cc var (index i j) arg))
      ((:matrix cc var (index i j) arg)
       (:do cc
            (let ((m arg)
                  (rows #f)
                  (cols #f))
              (set! rows (matrix-rows m))
              (set! cols (matrix-cols m)))
            ((i 0) (j 0))
            (and (< i rows) (< j cols))
            (let ((i+1 (+ i 1))
                  (j+1 (+ j 1))
                  (wrapping? #f)
                  (var (matrix-ref m i j)))
              ;; Past the last row: move to the top of the next column.
              (set! wrapping? (>= i+1 rows)))
            #t
            ((if wrapping? 0 i+1)
             (if wrapping? j+1 j))))))
  
  ;; ptr->matrix : cpointer rows cols -> matrix
  ;; The raw struct constructor: it wraps an existing C pointer without
  ;; copying it or checking its size.  It is exported only as an "unsafe"
  ;; binding via provide*.  Presumably importers must invoke matrix-unsafe!
  ;; (created by define-unsafer) to gain access to it; confirm against the
  ;; foreign.ss provide*/define-unsafer documentation.
  (define ptr->matrix make-matrix)
  (provide* (unsafe ptr->matrix))
  (define-unsafer matrix-unsafe!)) ; closes (module matrix ...)