;; union-find: a disjoint-set forest using union by rank and path
;; compression.  Each forest maps elements to tree nodes through a hash
;; table created with the flags given to make-forest.
(module union-find mzscheme
  (require (lib "contract.ss")
           (lib "etc.ss"))

  ;; A forest holds the hash table mapping each element to its node.
  (define-struct forest (ht))

  ;; A node of the forest.
  ;;   elt  : the element this node stands for
  ;;   p    : parent node; a root node is its own parent
  ;;   rank : bound on the height of the subtree rooted here, consulted
  ;;          by union-set to decide which root goes beneath the other
  (define-struct node (elt p rank))

  ;; -make-forest : symbol ... -> forest  (exported as make-forest)
  ;; Creates an empty forest.  Any flags (the contract allows 'equal and
  ;; 'weak) are forwarded unchanged to make-hash-table.
  (define -make-forest
    (case-lambda
      [()
       (make-forest (make-hash-table))]
      [flags
       (make-forest (apply make-hash-table flags))]))

  ;; lookup-node : forest any -> node
  ;; Returns the node registered for an-elt.  hash-table-get is called
  ;; without a failure thunk, so an element never added via make-set
  ;; raises an exception here.
  (define (lookup-node a-forest an-elt)
    (hash-table-get (forest-ht a-forest) an-elt))

  ;; make-set : forest any -> void
  ;; Adds an-elt to the forest as a new singleton set.  Does nothing if
  ;; an-elt is already present.
  (define (make-set a-forest an-elt)
    (let ([ht (forest-ht a-forest)])
      ;; Use a failure thunk for the "absent" default: that is the form of
      ;; hash-table-get's third argument every MzScheme version accepts.
      (unless (hash-table-get ht an-elt (lambda () #f))
        (let ([a-node (make-node an-elt #f 0)])
          ;; A fresh node is the root of its own one-element tree.
          (set-node-p! a-node a-node)
          (hash-table-put! ht an-elt a-node)))))

  ;; find-set : forest any -> any
  ;; Returns the element that currently represents an-elt's set.  Two
  ;; elements are in the same set exactly when find-set returns the same
  ;; representative for both.
  (define (find-set a-forest an-elt)
    (node-elt (get-representative-node (lookup-node a-forest an-elt))))

  ;; get-representative-node : node -> node
  ;; Follows parent links up to the root of a-node's tree and returns it.
  ;; Path compression: each node visited is re-pointed directly at the
  ;; root so later lookups are shorter.
  (define (get-representative-node a-node)
    (let ([p (node-p a-node)])
      (if (eq? a-node p)
          a-node
          (let ([root (get-representative-node p)])
            (set-node-p! a-node root)
            root))))

  ;; union-set : forest any any -> void
  ;; Merges the sets containing elt1 and elt2 using union by rank: the
  ;; lower-ranked root is placed beneath the higher-ranked one.  On a tie,
  ;; rep1 goes beneath rep2 and rep2's rank grows by one.  Does nothing
  ;; when both elements are already in the same set.
  (define (union-set a-forest elt1 elt2)
    (let ([rep1 (get-representative-node (lookup-node a-forest elt1))]
          [rep2 (get-representative-node (lookup-node a-forest elt2))])
      (cond
        ;; Same set already: "linking" a root to itself would only
        ;; inflate its rank.
        [(eq? rep1 rep2)
         (void)]
        [(< (node-rank rep1) (node-rank rep2))
         (set-node-p! rep1 rep2)]
        [(> (node-rank rep1) (node-rank rep2))
         (set-node-p! rep2 rep1)]
        [else
         ;; Equal ranks: rep2 stays the root and its tree just got taller,
         ;; so rep2 (not rep1, which is no longer a root) gets the bump.
         (set-node-p! rep1 rep2)
         (set-node-rank! rep2 (add1 (node-rank rep2)))])))

  (provide/contract
   [forest? (any/c . -> . boolean?)]
   [rename -make-forest make-forest
           (case->
            (-> forest?)
            (() (listof (symbols 'equal 'weak)) . ->* . (forest?)))]
   [make-set (forest? any/c . -> . any)]
   [find-set (forest? any/c . -> . any)]
   ;; void? rather than void: the procedure void accepts anything and
   ;; returns a true value, so as a flat contract it checked nothing.
   [union-set (forest? any/c any/c . -> . void?)]))