#lang scheme
;; A dense matrix.
;;   rows, cols : the dimensions
;;   elts       : vector holding the entries; `matrix-ref` / `matrix-set!`
;;                address entry (i, j) at index (+ (* cols i) j)
;; A matrix can be used directly as a sequence (e.g. `(for ([x m]) ...)`).
;; It yields elts[0] ... elts[rows*cols - 1] in storage order.  The sequence
;; is produced by `*in-matrix`, which is also the expression form of
;; `in-matrix`, so the struct property and `in-matrix` share one
;; implementation instead of two hand-maintained copies.
(define-struct matrix
  (rows cols elts)
  #:transparent
  ;; Wrapped in a lambda because `*in-matrix` is defined later in this
  ;; module; the reference is only resolved when a matrix is iterated.
  #:property prop:sequence (lambda (m) (*in-matrix m)))
;; matrix-of-dimensions/c : natural natural -> flat-contract
;; Accepts exactly those matrix structs that have r rows and c columns.
(define (matrix-of-dimensions/c r c)
  (define label (format "<~a by ~a matrix>" r c))
  (define (shape-matches? candidate)
    (and (matrix? candidate)
         (= r (matrix-rows candidate))
         (= c (matrix-cols candidate))))
  (flat-named-contract label shape-matches?))
;; matrix-of-rows/c : natural -> flat-contract
;; Accepts any matrix with exactly r rows; the column count is unconstrained.
(define (matrix-of-rows/c r)
  (define label (format "<~a by <any> matrix>" r))
  (define (rows-match? candidate)
    (and (matrix? candidate)
         (= r (matrix-rows candidate))))
  (flat-named-contract label rows-match?))
;; matrix-of-cols/c : natural -> flat-contract
;; Accepts any matrix with exactly c columns; the row count is unconstrained.
(define (matrix-of-cols/c c)
  (define label (format "<<any> by ~a matrix>" c))
  (define (cols-match? candidate)
    (and (matrix? candidate)
         (= c (matrix-cols candidate))))
  (flat-named-contract label cols-match?))
;; vector-of-length/c : natural -> flat-contract
;; Accepts any vector whose length is exactly l (element types unchecked).
(define (vector-of-length/c l)
  (define label (format "<vector of length ~a>" l))
  (define (length-matches? candidate)
    (cond ((vector? candidate) (= l (vector-length candidate)))
          (else #f)))
  (flat-named-contract label length-matches?))
;; matrix-same-dimensions/c : matrix -> flat-contract
;; Accepts matrices with the same row and column counts as m.
(define (matrix-same-dimensions/c m)
  (let ((r (matrix-rows m))
        (c (matrix-cols m)))
    (matrix-of-dimensions/c r c)))
;; matrix-mul-compatible/c : matrix -> flat-contract
;; Accepts matrices x that may appear on the right of m in the product m*x,
;; i.e. whose row count equals m's column count.
(define (matrix-mul-compatible/c m)
  (define needed-rows (matrix-cols m))
  (matrix-of-rows/c needed-rows))
;; matrix-mul-result/c : matrix matrix -> flat-contract
;; Accepts matrices shaped like the product m1*m2: m1's rows by m2's columns.
(define (matrix-mul-result/c m1 m2)
  (let* ((result-rows (matrix-rows m1))
         (result-cols (matrix-cols m2)))
    (matrix-of-dimensions/c result-rows result-cols)))
;; list-of-length/c : natural -> flat-contract
;; Accepts any proper list with exactly l elements.
(define (list-of-length/c l)
  (define label (format "<list-of-length ~a>" l))
  (define (length-matches? candidate)
    (cond ((list? candidate) (= l (length candidate)))
          (else #f)))
  (flat-named-contract label length-matches?))
;; Uncontracted exports: the iteration syntax, plus the contract
;; constructors that the `provide/contract` form uses.
(provide
 ;; sequence / comprehension syntax
 in-matrix
 for/vector for*/vector
 for/matrix for*/matrix
 ;; contract constructors
 matrix-of-dimensions/c matrix-of-rows/c matrix-of-cols/c
 matrix-same-dimensions/c matrix-mul-compatible/c matrix-mul-result/c
 vector-of-length/c
 list-of-length/c)
;; Contracted exports.  Each `->d` contract names its arguments so that later
;; argument and result contracts can depend on earlier argument values
;; (e.g. `m2` in `matrix-add` must have the same dimensions as `m1`).
;; Every vector-returning operation checks both element type and length.
(provide/contract
 ;; Entries live in `elts`; matrix-ref / matrix-set! index it as
 ;; (+ (* cols i) j).
 (struct matrix
   ((rows natural-number/c)
    (cols natural-number/c)
    (elts (vectorof number?))))
 ;; (matrix* rows cols n ...) takes exactly rows*cols numbers.
 (matrix* (->d ((rows natural-number/c)
                (cols natural-number/c))
               ()
               #:rest nums (and/c (listof number?) (list-of-length/c (* rows cols)))
               (_ matrix?)))
 (matrix-ref (->d ((m matrix?)
                   (i (and/c natural-number/c
                             (</c (matrix-rows m))))
                   (j (and/c natural-number/c
                             (</c (matrix-cols m)))))
                  ()
                  (_ number?)))
 (matrix-set! (->d ((m matrix?)
                    (i (and/c natural-number/c
                              (</c (matrix-rows m))))
                    (j (and/c natural-number/c
                              (</c (matrix-cols m))))
                    (x number?))
                   ()
                   any))
 (matrix-add (->d ((m1 matrix?)
                   (m2 (matrix-same-dimensions/c m1)))
                  ()
                  (_ (matrix-same-dimensions/c m1))))
 (matrix-sub (->d ((m1 matrix?)
                   (m2 (matrix-same-dimensions/c m1)))
                  ()
                  (_ (matrix-same-dimensions/c m1))))
 (matrix-scale (->d ((m matrix?)
                     (s number?))
                    ()
                    (_ (matrix-same-dimensions/c m))))
 (matrix-mul (->d ((m1 matrix?)
                   (m2 (matrix-mul-compatible/c m1)))
                  ()
                  (_ (matrix-mul-result/c m1 m2))))
 (matrix-vector-mul (->d ((m matrix?)
                          (v (and/c (vectorof number?)
                                    (vector-of-length/c (matrix-cols m)))))
                         ()
                         (_ (and/c (vectorof number?)
                                   (vector-of-length/c (matrix-rows m))))))
 (vector-matrix-mul (->d ((v (vectorof number?))
                          (m (matrix-of-rows/c (vector-length v))))
                         ()
                         (_ (and/c (vectorof number?)
                                   (vector-of-length/c (matrix-cols m))))))
 ;; The element-wise vector operations now promise numeric results too,
 ;; matching matrix-vector-mul / vector-matrix-mul above.
 (vector-add (->d ((v1 (vectorof number?))
                   (v2 (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v1)))))
                  ()
                  (_ (and/c (vectorof number?)
                            (vector-of-length/c (vector-length v1))))))
 (vector-sub (->d ((v1 (vectorof number?))
                   (v2 (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v1)))))
                  ()
                  (_ (and/c (vectorof number?)
                            (vector-of-length/c (vector-length v1))))))
 (vector-scale (->d ((v (vectorof number?))
                     (s number?))
                    ()
                    (_ (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v))))))
 (vector-dot (->d ((v1 (vectorof number?))
                   (v2 (and/c (vectorof number?)
                              (vector-of-length/c (vector-length v1)))))
                  ()
                  (_ number?)))
 (matrix-transpose (->d ((m matrix?))
                        ()
                        (_ (matrix-of-dimensions/c (matrix-cols m) (matrix-rows m)))))
 (matrix-identity (->d ((n natural-number/c))
                       ()
                       (_ (matrix-of-dimensions/c n n)))))
;; matrix* : natural natural number ... -> matrix
;; Builds an m-by-n matrix whose entries, in storage order, are the
;; remaining arguments.
(define (matrix* m n . nums)
  (let ((store (apply vector nums)))
    (make-matrix m n store)))
;; matrix-ref : matrix natural natural -> number
;; Reads entry (i, j); entries are laid out row by row in the elts vector.
(define (matrix-ref m i j)
  (let* ((store (matrix-elts m))
         (offset (+ (* (matrix-cols m) i) j)))
    (vector-ref store offset)))
;; matrix-set! : matrix natural natural number -> void
;; Overwrites entry (i, j) in place, using the same row-by-row layout as
;; matrix-ref.
(define (matrix-set! m i j x)
  (let* ((store (matrix-elts m))
         (offset (+ (* (matrix-cols m) i) j)))
    (vector-set! store offset x)))
;; *in-matrix : matrix -> sequence
;; Procedural sequence over the first (* rows cols) entries of the matrix's
;; element vector, in storage order.  This is the expression form of
;; `in-matrix`.
(define (*in-matrix m)
  (make-do-sequence
   (lambda ()
     ;; The struct fields are immutable, so they can be read once up front.
     (let* ((count (* (matrix-rows m) (matrix-cols m)))
            (store (matrix-elts m)))
       (values (lambda (pos) (vector-ref store pos)) ; element at a position
               (lambda (pos) (+ pos 1))              ; next position
               0                                     ; first position
               (lambda (pos) (< pos count))          ; continue with position?
               (lambda (val) #t)                     ; continue with value?
               (lambda (pos val) #t))))))            ; continue afterwards?
;; in-matrix : sequence syntax over a matrix.
;;  - As an expression, `in-matrix` is `*in-matrix`.
;;  - `[x (in-matrix m)]` in a `for` clause iterates the backing element
;;    vector directly.
;;  - `[(i j x) (in-matrix m)]` binds the row index, the column index and the
;;    entry, advancing the column fastest (row by row).
;;  - Any other clause shape returns #f, so `for` falls back to the
;;    expression form.
(define-sequence-syntax in-matrix
  (lambda () (syntax *in-matrix))
  (lambda (stx)
    (syntax-case stx ()
      ;; Single element: iterate the backing vector.
      (((id) (_ matrix-expr))
       (syntax/loc stx
         ((id) (in-vector (matrix-elts matrix-expr)))))
      ;; Row index, column index and element.
      (((i-id j-id elt-id) (_ matrix-expr))
       (syntax/loc stx
         ((i-id j-id elt-id)
          (:do-in
           ;; outer binding: evaluate the matrix expression once
           (((m) matrix-expr))
           ;; outer check
           #t
           ;; loop variables
           ((i-id 0) (j-id 0) (rows (matrix-rows m)) (cols (matrix-cols m)))
           ;; position guard.  Checking j as well as i stops immediately on a
           ;; rows x 0 matrix; checking i alone would call matrix-ref on an
           ;; empty element vector.
           (and (< i-id rows) (< j-id cols))
           ;; inner bindings
           (((ip1) (add1 i-id))
            ((jp1) (add1 j-id))
            ((elt-id) (matrix-ref m i-id j-id)))
           ;; pre-guard, post-guard
           #t
           #t
           ;; next loop values: after the last column, wrap to column 0 of
           ;; the next row
           ((if (>= jp1 cols) ip1 i-id)
            (if (>= jp1 cols) 0 jp1)
            rows
            cols)))))
      ;; Unsupported clause shape: defer to the expression form.
      (_ #f))))
;; (for/vector length-expr (for-clause ...) body ...+)
;; Allocates a vector of length `length-expr` and stores the body's value
;; for the k-th iteration in slot k.  The for-clauses run in parallel, as in
;; `for`.  Slots the loop never reaches keep `make-vector`'s default fill.
;; Producing more than `length-expr` values makes `vector-set!` raise an
;; index error.  The body forms run in an internal-definition context, so
;; they may begin with `define`s.
(define-syntax (for/vector stx)
  (syntax-case stx ()
    ((_ length-expr (for-clause ...) body0 body ...)
     ;; `original` carries the user's whole form so for/fold/derived reports
     ;; clause errors against it, not against a bare `stx` identifier.
     (with-syntax ((original stx))
       (syntax/loc stx
         (let ((length length-expr))
           (for/fold/derived original ((result (make-vector length)))
                             ((i (in-naturals))
                              for-clause ...)
             (vector-set! result i (let () body0 body ...))
             result)))))))
;; (for*/vector length-expr (for-clause ...) body ...+)
;; Like `for/vector`, but the for-clauses are nested, as in `for*`.  A
;; counter tracks the next slot to fill.  Slots never reached keep
;; `make-vector`'s default fill.  The body forms run in an
;; internal-definition context.
(define-syntax (for*/vector stx)
  (syntax-case stx ()
    ((_ length-expr (for-clause ...) body0 body ...)
     ;; Pass the user's form (not the identifier `stx`) as the orig-datum so
     ;; for*/fold/derived reports clause errors against it.
     (with-syntax ((original stx))
       (syntax/loc stx
         (let ((length length-expr)
               (i 0))
           (for*/fold/derived original ((result (make-vector length)))
                              (for-clause ...)
             (vector-set! result i (let () body0 body ...))
             (set! i (add1 i))
             result)))))))
;; (for/matrix rows-expr cols-expr (for-clause ...) body ...+)
;; Builds a rows-by-cols matrix whose entries, in storage order, are the
;; body values from a parallel (`for`-style) iteration.  The body forms run
;; in an internal-definition context.
(define-syntax (for/matrix stx)
  (syntax-case stx ()
    ((_ rows-expr cols-expr (for-clause ...) body0 body ...)
     (syntax/loc stx
       (let ((rows rows-expr)
             (cols cols-expr))
         (make-matrix rows cols
                      ;; Group the body forms here so for/vector always
                      ;; receives a single body expression.
                      (for/vector (* rows cols) (for-clause ...)
                        (let () body0 body ...))))))))
;; (for*/matrix rows-expr cols-expr (for-clause ...) body ...+)
;; Builds a rows-by-cols matrix whose entries, in storage order, are the
;; body values from a nested (`for*`-style) iteration.  The body forms run
;; in an internal-definition context.
(define-syntax (for*/matrix stx)
  (syntax-case stx ()
    ((_ rows-expr cols-expr (for-clause ...) body0 body ...)
     (syntax/loc stx
       (let ((rows rows-expr)
             (cols cols-expr))
         (make-matrix rows cols
                      ;; Group the body forms here so for*/vector always
                      ;; receives a single body expression.
                      (for*/vector (* rows cols) (for-clause ...)
                        (let () body0 body ...))))))))
;; matrix-add : matrix matrix -> matrix
;; Entry-wise sum.  Both matrices are walked in storage order in parallel;
;; the result has m1's dimensions.
(define (matrix-add m1 m2)
  (define r (matrix-rows m1))
  (define c (matrix-cols m1))
  (for/matrix r c
              ((a (in-matrix m1))
               (b (in-matrix m2)))
    (+ a b)))
;; matrix-sub : matrix matrix -> matrix
;; Entry-wise difference m1 - m2; the result has m1's dimensions.
(define (matrix-sub m1 m2)
  (define r (matrix-rows m1))
  (define c (matrix-cols m1))
  (for/matrix r c
              ((a (in-matrix m1))
               (b (in-matrix m2)))
    (- a b)))
;; matrix-scale : matrix number -> matrix
;; Multiplies every entry of m by the scalar s.
(define (matrix-scale m s)
  (define r (matrix-rows m))
  (define c (matrix-cols m))
  (for/matrix r c ((entry (in-matrix m)))
    (* s entry)))
;; matrix-mul : matrix matrix -> matrix
;; Standard matrix product m1 * m2.  Result entry (i, j) is the sum over k of
;; m1(i, k) * m2(k, j), accumulated left to right from 0.
(define (matrix-mul m1 m2)
  (define out-rows (matrix-rows m1))
  (define out-cols (matrix-cols m2))
  (define inner (matrix-cols m1))
  (define (cell i j)
    (for/fold ((acc 0))
              ((k (in-range inner)))
      (+ acc (* (matrix-ref m1 i k) (matrix-ref m2 k j)))))
  (for*/matrix out-rows out-cols
               ((i (in-range out-rows))
                (j (in-range out-cols)))
    (cell i j)))
;; matrix-vector-mul : matrix vector -> vector
;; Product m * v, treating v as a column vector.  Result slot i is the dot
;; product of row i of m with v.
(define (matrix-vector-mul m v)
  (define (row-dot i)
    (for/fold ((acc 0))
              ((j (in-range (vector-length v))))
      (+ acc (* (matrix-ref m i j) (vector-ref v j)))))
  (define out-len (matrix-rows m))
  (for/vector out-len ((i (in-range out-len)))
    (row-dot i)))
;; vector-matrix-mul : vector matrix -> vector
;; Product v * m, treating v as a row vector.  Result slot j is the dot
;; product of v with column j of m.
(define (vector-matrix-mul v m)
  (define (col-dot j)
    (for/fold ((acc 0))
              ((i (in-range (vector-length v))))
      (+ acc (* (matrix-ref m i j) (vector-ref v i)))))
  (define out-len (matrix-cols m))
  (for/vector out-len ((j (in-range out-len)))
    (col-dot j)))
;; vector-add : vector vector -> vector
;; Element-wise sum; the result has v1's length.
(define (vector-add v1 v2)
  (let ((n (vector-length v1)))
    (for/vector n ((a (in-vector v1))
                   (b (in-vector v2)))
      (+ a b))))
;; vector-sub : vector vector -> vector
;; Element-wise difference v1 - v2; the result has v1's length.
(define (vector-sub v1 v2)
  (let ((n (vector-length v1)))
    (for/vector n ((a (in-vector v1))
                   (b (in-vector v2)))
      (- a b))))
;; vector-scale : vector number -> vector
;; Multiplies every element of v by the scalar s.
(define (vector-scale v s)
  (let ((n (vector-length v)))
    (for/vector n ((elem (in-vector v)))
      (* elem s))))
;; vector-dot : vector vector -> number
;; Dot product, accumulated left to right from exact 0.  Stops at the end of
;; the shorter vector, as a parallel iteration over both would.
(define (vector-dot v1 v2)
  (let ((n (min (vector-length v1) (vector-length v2))))
    (let loop ((k 0) (total 0))
      (if (< k n)
          (loop (add1 k)
                (+ total (* (vector-ref v1 k) (vector-ref v2 k))))
          total))))
;; matrix-transpose : matrix -> matrix
;; Returns the cols-by-rows matrix whose entry (i, j) is entry (j, i) of m.
(define (matrix-transpose m)
  (define src-rows (matrix-rows m))
  (define src-cols (matrix-cols m))
  (for*/matrix src-cols src-rows
               ((i (in-range src-cols))
                (j (in-range src-rows)))
    (matrix-ref m j i)))
;; matrix-identity : natural -> matrix
;; The n-by-n identity matrix: exact 1 on the diagonal, exact 0 elsewhere.
(define (matrix-identity n)
  (for*/matrix n n
               ((row (in-range n))
                (col (in-range n)))
    (cond ((= row col) 1)
          (else 0))))