(module bh-tree mzscheme
(require (planet "nbody-ics.scm" ("wmfarr" "nbody-ics.plt"))
(lib "42.ss" "srfi")
(lib "math.ss")
(lib "contract.ss"))
(provide (struct cell (m q bounds sub-trees))
(struct bounds (low high)))
;; cell: an interior tree node.  bodies->tree fills M with the total mass of
;; the SUB-TREES, Q with their center of mass, BOUNDS with the box the node
;; covers, and SUB-TREES with a vector of child trees.
;; (The #f inspector argument makes instances transparent.)
(define-struct cell (m q bounds sub-trees) #f)
;; bounds: an axis-aligned box given by corner vectors LOW and HIGH; in-bounds?
;; treats it as half-open (LOW inclusive, HIGH exclusive).
(define-struct bounds (low high) #f)
;; The empty tree is represented by the empty list.
(define empty-tree? null?)
;; tree? : any -> boolean
;; A tree is the empty tree, a single body, or a cell.
(define (tree? obj)
  ;; A cond clause with no body yields the value of its test, so this
  ;; returns exactly what the corresponding predicate returned.
  (cond ((empty-tree? obj))
        ((body? obj))
        (else (cell? obj))))
(provide/contract (empty-tree? (-> tree? boolean?))
(make-empty (-> empty-tree?))
(tree? (-> any/c boolean?))
(tree-m (-> tree? (>=/c 0.0)))
(tree-q (-> tree? 3vector/c))
(tree-size-squared (-> tree? (>=/c 0.0)))
(in-bounds? (-> bounds? 3vector/c boolean?))
(bodies->tree (-> nbody-system/c tree?))
(tree-fold/sc (-> (-> tree? any) (-> tree? any/c any) any/c tree? any))
(tree-fold (-> (-> tree? any/c any) any/c tree? any)))
;; make-empty : -> tree
;; Constructs the empty tree, which is the empty list.
(define (make-empty) (list))
;; tree-m : tree -> real
;; Mass of T: 0.0 for the empty tree, the body's mass for a body, and the m
;; field stored in a cell.  Signals an error for anything else.
(define (tree-m t)
  (if (empty-tree? t)
      0.0
      (let ((mass-of (cond ((body? t) body-m)
                           ((cell? t) cell-m)
                           (else #f))))
        (if mass-of
            (mass-of t)
            (error 'tree-m "argument not a tree: ~a" t)))))
;; tree-q : tree -> 3-vector
;; Position of T: a fresh zero 3-vector for the empty tree, the body's
;; position for a body, and the q field (center of mass) stored in a cell.
;; Signals an error for anything else.
(define (tree-q t)
  (if (empty-tree? t)
      (make-vector 3 0.0)
      (let ((position-of (cond ((body? t) body-q)
                               ((cell? t) cell-q)
                               (else #f))))
        (if position-of
            (position-of t)
            (error 'tree-q "argument not a tree: ~a" t)))))
;; in-bounds? : bounds 3-vector -> boolean
;; True when V lies in the half-open box BDS, i.e. low <= v < high in every
;; coordinate (lower faces inclusive, upper faces exclusive).
(define (in-bounds? bds v)
  (let ((within? (lambda (x lo hi) (and (<= lo x) (< x hi)))))
    (every?-ec (:parallel (:vector x v)
                          (:vector lo (bounds-low bds))
                          (:vector hi (bounds-high bds)))
               (within? x lo hi))))
;; Relative widening applied by expand-bounds to each side of a box.
(define *epsilon* 1e-6)
;; expand-bounds : bounds -> bounds
;; Returns a box with the same LOW corner and a HIGH corner pushed strictly
;; above the original one in every finite dimension.  in-bounds? treats the
;; upper face as exclusive, so without this push a point lying exactly on the
;; upper face (a body at the maximum coordinate) would be outside the box.
;;
;; Normally a side of length delta becomes delta * (1 + *epsilon*).  That
;; relative stretch cannot move the face in two situations:
;;   - delta = 0, i.e. every point shares that coordinate (planar
;;     configurations, or just two bodies with equal x);
;;   - delta is so small relative to the coordinate that the stretch is lost
;;     to rounding.
;; Left alone, such a side keeps HIGH = LOW (or HIGH = old max), in-bounds?
;; rejects the bodies, and they silently vanish from the tree.  In those
;; cases the face is instead moved up by a positive pad that is doubled until
;; it actually changes the value:
;;   - the pad is delta itself for a positive side;
;;   - it is *epsilon* * max(1, |high|) for a zero-length side.
;;
;; NOTE(review): bodies at exactly identical positions still cannot be
;; separated by splitting; with this change bodies->tree presumably recurses
;; without bound on them (before, they were silently dropped).  Confirm that
;; callers never pass coincident bodies.
(define expand-bounds
  (let ((efactor (+ 1.0 *epsilon*)))
    ;; #t for ordinary numbers, #f for +inf.0, -inf.0 and +nan.0.
    (define (finite? x) (< -inf.0 x +inf.0))
    ;; First value ux + pad * 2^k (k = 0, 1, ...) that is strictly greater
    ;; than ux.  Terminates because callers pass a finite ux and a pad > 0.
    (define (push-above ux pad)
      (let ((h (+ ux pad)))
        (if (> h ux)
            h
            (push-above ux (* 2 pad)))))
    (lambda (bds)
      (let ((low (bounds-low bds))
            (high (bounds-high bds)))
        (make-bounds
         low
         (vector-of-length-ec (vector-length high)
           (:parallel (:vector lx low)
                      (:vector ux high))
           (let* ((delta (- ux lx))
                  (stretched (+ lx (* delta efactor))))
             (cond
               ;; Usual case: the plain relative stretch.
               ((> stretched ux) stretched)
               ;; Infinite or NaN face: no finite push is meaningful.
               ((not (finite? ux)) stretched)
               ;; Positive side whose relative stretch was lost to rounding.
               ((> delta 0) (push-above ux delta))
               ;; Zero-length side: all points share this coordinate.
               (else
                (push-above ux (* *epsilon* (max 1.0 (abs ux)))))))))))))
;; bodies->bounds : (vectorof body) -> bounds
;; Takes the componentwise minimum and maximum of all body positions, then
;; widens the resulting box with expand-bounds.  The accumulators have length
;; 3, so only the first three coordinates of each position are considered.
;; For an empty vector the corners stay at +inf.0 / -inf.0.
(define (bodies->bounds bs)
  (let ((lo (make-vector 3 +inf.0))
        (hi (make-vector 3 -inf.0)))
    (do-ec (:vector b bs)
           (:let q (body-q b))
           (:parallel (:vector x (index k) q)
                      (:vector lx lo)
                      (:vector hx hi))
           (begin
             (when (< x lx) (vector-set! lo k x))
             (when (> x hx) (vector-set! hi k x))))
    (expand-bounds (make-bounds lo hi))))
;; high? : integer integer -> boolean
;; True when bit J of I is set.  sub-bounds uses this on an octant index I to
;; decide whether octant I takes the upper half along axis J.
(define (high? i j)
  (not (zero? (bitwise-and i (arithmetic-shift 1 j)))))
;; sub-bounds : bounds -> (vectorof bounds)
;; Splits BDS at its midpoint into 8 octant boxes.  Along axis J, octant I
;; takes the upper half [mid, high) when bit J of I is set (see high?), and
;; the lower half [low, mid) otherwise.
(define (sub-bounds bds)
  (let* ((lo (bounds-low bds))
         (hi (bounds-high bds))
         (mid (vector-of-length-ec 3
                (:parallel (:vector a lo)
                           (:vector b hi))
                (/ (+ a b) 2))))
    ;; Box for octant I.
    (define (octant i)
      (make-bounds
       (vector-of-length-ec 3
         (:parallel (:vector a (index j) lo)
                    (:vector m mid))
         (if (high? i j) m a))
       (vector-of-length-ec 3
         (:parallel (:vector m (index j) mid)
                    (:vector b hi))
         (if (high? i j) b m))))
    (vector-of-length-ec 8 (:range i 8) (octant i))))
;; split-bodies : (vectorof body) bounds -> (vectorof (vectorof body))
;; Distributes BS among the 8 octants of BDS (as produced by sub-bounds).
;; Entry K holds, in their original order, the bodies whose positions satisfy
;; in-bounds? for octant K.  A body outside every octant appears in no entry.
(define (split-bodies bs bds)
  (vector-of-length-ec 8
    (:vector octant-box (sub-bounds bds))
    (vector-ec (:vector b bs)
               (if (in-bounds? octant-box (body-q b)))
               b)))
;; trees-total-mass : (vectorof tree) -> real
;; Sum of tree-m over the trees in TS; exact 0 for an empty vector.
(define (trees-total-mass ts)
  (fold-ec 0 (:vector sub ts) (tree-m sub) +))
;; trees-center-of-mass : (vectorof tree) -> 3-vector
;; Mass-weighted mean of the trees' positions (tree-q weighted by tree-m).
;; Trees whose mass is not positive are skipped.  If the total mass is 0.0,
;; a fresh zero vector is returned.
(define (trees-center-of-mass ts)
  (let ((total (trees-total-mass ts))
        (acc (make-vector 3 0.0)))
    (unless (= total 0.0)
      (do-ec (:vector sub ts)
             (:let w (tree-m sub))
             (if (> w 0.0))
             (:let scale (/ w total))
             (:parallel (:vector cx (index k) acc)
                        (:vector px (tree-q sub)))
             (vector-set! acc k (+ cx (* px scale)))))
    acc))
;; bodies->tree : (vectorof body) -> tree
;; Returns:
;;   - the empty tree for no bodies;
;;   - the body itself for exactly one body;
;;   - otherwise a cell covering (bodies->bounds BS).  Its 8 sub-trees are
;;     built recursively from the octant split, and its mass and center of
;;     mass are computed from those sub-trees.
(define (bodies->tree bs)
  (let ((n (vector-length bs)))
    (cond
      ((zero? n) (make-empty))
      ((= n 1) (vector-ref bs 0))
      (else
       (let* ((box (bodies->bounds bs))
              (groups (split-bodies bs box))
              (kids (vector-of-length-ec 8
                      (:vector group groups)
                      (bodies->tree group))))
         (make-cell (trees-total-mass kids)
                    (trees-center-of-mass kids)
                    box
                    kids))))))
;; tree-fold/sc : (tree -> any) (tree acc -> acc) acc tree -> acc
;; Depth-first fold with short-circuit:
;;   - the empty tree returns START;
;;   - a body is passed to F;
;;   - a cell for which CUT? is true is passed to F as a whole, and its
;;     children are not visited;
;;   - any other cell threads the accumulator through its sub-trees in
;;     vector order.
;; CUT? is never called on bodies or on the empty tree.
(define (tree-fold/sc cut? f start t)
  (cond
    ((empty-tree? t) start)
    ((or (body? t) (cut? t)) (f t start))
    (else
     (let* ((kids (cell-sub-trees t))
            (n (vector-length kids)))
       (let walk ((k 0) (acc start))
         (if (= k n)
             acc
             (walk (+ k 1)
                   (tree-fold/sc cut? f acc (vector-ref kids k)))))))))
;; tree-fold : (tree acc -> acc) acc tree -> acc
;; Folds F over every body in T.  No cell is ever cut short, so F sees only
;; bodies.
(define (tree-fold f start t)
  (let ((never-cut? (lambda (node) #f)))
    (tree-fold/sc never-cut? f start t)))
;; tree-size-squared : tree -> real
;; Squared diagonal length of a cell's bounding box, i.e. the sum over axes of
;; (high - low)^2.  The empty tree and single bodies have size 0.0.
(define (tree-size-squared t)
  (if (or (empty-tree? t) (body? t))
      0.0
      (let* ((box (cell-bounds t))
             (lo (bounds-low box))
             (hi (bounds-high box)))
        (sum-ec (:parallel (:vector a lo)
                           (:vector b hi))
                (sqr (- b a))))))
) ; closes (module bh-tree mzscheme ...)