#lang racket/base
(require racket/vector)
(provide (all-defined-out))
;; Create an m-by-n matrix: a vector of m freshly allocated row vectors, each
;; of length n, so mutating one row never affects another.
;; Every cell holds `initialize`. When it is omitted (or passed explicitly as
;; '()), cells hold 0, which is make-vector's default fill. As a consequence,
;; '() itself cannot be used as a fill value.
(define (make-matrix m n [initialize null])
  (define fill (if (null? initialize) 0 initialize))
  (build-vector m (lambda (_row-index) (make-vector n fill))))
;; Create an m-by-n matrix whose entry at row i, column j is (initialize i j).
;; `initialize` is called in row-major order: every column of row 0, then
;; every column of row 1, and so on.
(define (build-matrix m n initialize)
  (for/vector #:length m ([i (in-range m)])
    (for/vector #:length n ([j (in-range n)])
      (initialize i j))))
;; Literal-style constructor: (matrix r ...) expands to (vector r ...), so each
;; argument expression becomes one row of the resulting matrix.
(define-syntax-rule (matrix a-row ...)
  (vector a-row ...))
;; Row constructor for use inside `matrix`: (row x ...) expands to (vector x ...).
(define-syntax-rule (row element ...)
  (vector element ...))
;; Store v at row i, column j of matrix m, mutating that row vector in place.
(define (matrix-set! m i j v)
  (let ([target-row (vector-ref m i)])
    (vector-set! target-row j v)))
;; Return the entry at row i, column j of matrix m.
(define (matrix-ref m i j)
  (let ([selected-row (vector-ref m i)])
    (vector-ref selected-row j)))
;; Return the dimensions of matrix m as two values: row count, column count.
;; The column count is taken from the first row, on the assumption that the
;; matrix is rectangular.
;; An empty matrix #() (e.g. from (make-matrix 0 n)) has no first row to
;; inspect, so it is reported as 0x0 instead of raising an index error.
(define (matrix-size m)
  (define rows (vector-length m))
  (if (zero? rows)
      (values 0 0)
      (values rows (vector-length (vector-ref m 0)))))
;; Predicate: is m a rectangular matrix, i.e. a vector whose elements are all
;; vectors of the same length?
;; The original checked only the first row and raised an index error on #().
;; This version never raises. It treats #() as the 0x0 matrix, since
;; (make-matrix 0 n) and (build-matrix 0 n f) both produce #().
(define (matrix? m)
  (and (vector? m)
       (or (zero? (vector-length m))
           (let ([first-row (vector-ref m 0)])
             (and (vector? first-row)
                  (let ([width (vector-length first-row)])
                    (for/and ([r (in-vector m)])
                      (and (vector? r)
                           (= (vector-length r) width)))))))))
;; Apply f element-wise across one or more matrices and return a new matrix.
;; f receives one argument per matrix, taken from the same (row, column)
;; position in each. Differing row counts or row lengths make vector-map
;; raise its contract error.
(define (matrix-map f . ms)
  (define (map-over-rows . rows)
    (apply vector-map f rows))
  (apply vector-map map-over-rows ms))
;; Multiply matrix m by vector v, optionally over a custom semiring.
;; Entry i of the result is the left fold with `add`, starting from `zero`,
;; over (multiply m[i][k] v[k]) for each column k.
;;   multiply : element product (default *)
;;   add      : element sum     (default +)
;;   zero     : starting accumulator, which should be the identity of `add`.
;;              It defaults to 0, the previously hard-coded value. Pass e.g.
;;              #f for boolean (or/and) or +inf.0 for min-plus products.
;; If a row's length differs from v's, vector-map raises a contract error.
(define (matrix*vector m v [multiply *] [add +] [zero 0])
  (vector-map
   (lambda (row)
     ;; The products are computed for the whole row first, then folded.
     (for/fold ([sum zero])
               ([product (in-vector (vector-map multiply row v))])
       (add sum product)))
   m))
;; Multiply matrices a1 (m1 x n1) and a2 (m2 x n2), optionally over a custom
;; semiring, producing an m1 x n2 matrix.
;; Entry (r, c) is the left fold with `add`, starting from `zero`, over
;; (multiply a1[r][i] a2[i][c]) for i from 0 to n1 - 1.
;;   multiply : element product (default *)
;;   add      : element sum     (default +)
;;   zero     : starting accumulator, which should be the identity of `add`.
;;              It defaults to 0, the previously hard-coded value.
;; Raises an error "Incompatibly sized matrices" when n1 differs from m2.
(define (matrix*matrix a1 a2 [multiply *] [add +] [zero 0])
  (let-values ([(m1 n1) (matrix-size a1)]
               [(m2 n2) (matrix-size a2)])
    (if (= n1 m2)
        (build-matrix m1 n2
                      (lambda (r c)
                        (for/fold ([sum zero])
                                  ([i (in-range n1)])
                          (add sum (multiply (matrix-ref a1 r i)
                                             (matrix-ref a2 i c))))))
        (error "Incompatibly sized matrices"))))
;; Return a new matrix that is the transpose of m: an r x c input yields a
;; c x r result whose entry (i, j) is m's entry (j, i).
(define (transpose m)
  (define-values (row-count col-count) (matrix-size m))
  (build-matrix col-count row-count
                (lambda (i j) (matrix-ref m j i))))