apl: reduce f/ and f⌿ (last+first axis); 110/110 tests
Some checks failed
Test, Build, and Deploy / test-build-deploy (push) Has been cancelled
Some checks failed
Test, Build, and Deploy / test-build-deploy (push) Has been cancelled
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -795,3 +795,81 @@
|
||||
(let
|
||||
((result (filter (fn (x) (not (index-of b-ravel x))) a-ravel)))
|
||||
(make-array (list (len result)) result)))))
|
||||
|
||||
(define
|
||||
apl-reduce
|
||||
(fn
|
||||
(f arr)
|
||||
(let
|
||||
((shape (get arr :shape)) (ravel (get arr :ravel)))
|
||||
(if
|
||||
(= (len shape) 0)
|
||||
arr
|
||||
(if
|
||||
(= (len shape) 1)
|
||||
(let
|
||||
((n (first shape)))
|
||||
(if
|
||||
(= n 0)
|
||||
(apl-scalar 0)
|
||||
(apl-scalar
|
||||
(reduce
|
||||
(fn (a b) (disclose (f (apl-scalar a) (apl-scalar b))))
|
||||
(first ravel)
|
||||
(rest ravel)))))
|
||||
(let
|
||||
((last-dim (last shape))
|
||||
(pre-shape (take shape (- (len shape) 1)))
|
||||
(pre-size (reduce * 1 (take shape (- (len shape) 1)))))
|
||||
(make-array
|
||||
pre-shape
|
||||
(map
|
||||
(fn
|
||||
(i)
|
||||
(let
|
||||
((start (* i last-dim))
|
||||
(elems
|
||||
(map
|
||||
(fn (j) (nth ravel (+ start j)))
|
||||
(range 0 last-dim))))
|
||||
(if
|
||||
(= last-dim 0)
|
||||
0
|
||||
(reduce
|
||||
(fn
|
||||
(a b)
|
||||
(disclose (f (apl-scalar a) (apl-scalar b))))
|
||||
(first elems)
|
||||
(rest elems)))))
|
||||
(range 0 pre-size)))))))))
|
||||
|
||||
(define
|
||||
apl-reduce-first
|
||||
(fn
|
||||
(f arr)
|
||||
(let
|
||||
((shape (get arr :shape)) (ravel (get arr :ravel)))
|
||||
(if
|
||||
(< (len shape) 2)
|
||||
(apl-reduce f arr)
|
||||
(let
|
||||
((first-dim (first shape))
|
||||
(inner-shape (rest shape))
|
||||
(inner-size (reduce * 1 (rest shape))))
|
||||
(if
|
||||
(= first-dim 0)
|
||||
(make-array inner-shape (map (fn (i) 0) (range 0 inner-size)))
|
||||
(make-array
|
||||
inner-shape
|
||||
(map
|
||||
(fn
|
||||
(j)
|
||||
(let
|
||||
((col (map (fn (i) (nth ravel (+ j (* i inner-size)))) (range 0 first-dim))))
|
||||
(reduce
|
||||
(fn
|
||||
(a b)
|
||||
(disclose (f (apl-scalar a) (apl-scalar b))))
|
||||
(first col)
|
||||
(rest col))))
|
||||
(range 0 inner-size)))))))))
|
||||
|
||||
@@ -26,6 +26,7 @@ cat > "$TMPFILE" << 'EPOCHS'
|
||||
(eval "(define apl-test (fn (name got expected) (if (= got expected) (set! apl-test-pass (+ apl-test-pass 1)) (begin (set! apl-test-fail (+ apl-test-fail 1)) (set! apl-test-fails (append apl-test-fails (list {:name name :got got :expected expected})))))))")
|
||||
(epoch 3)
|
||||
(load "lib/apl/tests/structural.sx")
|
||||
(load "lib/apl/tests/operators.sx")
|
||||
(epoch 4)
|
||||
(eval "(list apl-test-pass apl-test-fail)")
|
||||
EPOCHS
|
||||
|
||||
85
lib/apl/tests/operators.sx
Normal file
85
lib/apl/tests/operators.sx
Normal file
@@ -0,0 +1,85 @@
|
||||
(define rv (fn (arr) (get arr :ravel)))
|
||||
(define sh (fn (arr) (get arr :shape)))
|
||||
|
||||
(apl-test
|
||||
"reduce +/ vector"
|
||||
(rv (apl-reduce apl-add (make-array (list 5) (list 1 2 3 4 5))))
|
||||
(list 15))
|
||||
|
||||
(apl-test
|
||||
"reduce x/ vector"
|
||||
(rv (apl-reduce apl-mul (make-array (list 4) (list 1 2 3 4))))
|
||||
(list 24))
|
||||
|
||||
(apl-test
|
||||
"reduce max/ vector"
|
||||
(rv (apl-reduce apl-max (make-array (list 5) (list 3 1 4 1 5))))
|
||||
(list 5))
|
||||
|
||||
(apl-test
|
||||
"reduce min/ vector"
|
||||
(rv (apl-reduce apl-min (make-array (list 3) (list 3 1 4))))
|
||||
(list 1))
|
||||
|
||||
(apl-test
|
||||
"reduce and/ all true"
|
||||
(rv (apl-reduce apl-and (make-array (list 3) (list 1 1 1))))
|
||||
(list 1))
|
||||
|
||||
(apl-test
|
||||
"reduce or/ with true"
|
||||
(rv (apl-reduce apl-or (make-array (list 3) (list 0 0 1))))
|
||||
(list 1))
|
||||
|
||||
(apl-test
|
||||
"reduce +/ single element"
|
||||
(rv (apl-reduce apl-add (make-array (list 1) (list 42))))
|
||||
(list 42))
|
||||
|
||||
(apl-test
|
||||
"reduce +/ scalar no-op"
|
||||
(rv (apl-reduce apl-add (apl-scalar 7)))
|
||||
(list 7))
|
||||
|
||||
(apl-test
|
||||
"reduce +/ shape is scalar"
|
||||
(sh (apl-reduce apl-add (make-array (list 4) (list 1 2 3 4))))
|
||||
(list))
|
||||
|
||||
(apl-test
|
||||
"reduce +/ matrix row sums shape"
|
||||
(sh (apl-reduce apl-add (make-array (list 2 3) (list 1 2 3 4 5 6))))
|
||||
(list 2))
|
||||
|
||||
(apl-test
|
||||
"reduce +/ matrix row sums values"
|
||||
(rv (apl-reduce apl-add (make-array (list 2 3) (list 1 2 3 4 5 6))))
|
||||
(list 6 15))
|
||||
|
||||
(apl-test
|
||||
"reduce max/ matrix row maxima"
|
||||
(rv (apl-reduce apl-max (make-array (list 2 3) (list 3 1 4 1 5 9))))
|
||||
(list 4 9))
|
||||
|
||||
(apl-test
|
||||
"reduce-first +/ vector same as reduce"
|
||||
(rv (apl-reduce-first apl-add (make-array (list 5) (list 1 2 3 4 5))))
|
||||
(list 15))
|
||||
|
||||
(apl-test
|
||||
"reduce-first +/ matrix col sums shape"
|
||||
(sh
|
||||
(apl-reduce-first apl-add (make-array (list 2 3) (list 1 2 3 4 5 6))))
|
||||
(list 3))
|
||||
|
||||
(apl-test
|
||||
"reduce-first +/ matrix col sums values"
|
||||
(rv
|
||||
(apl-reduce-first apl-add (make-array (list 2 3) (list 1 2 3 4 5 6))))
|
||||
(list 5 7 9))
|
||||
|
||||
(apl-test
|
||||
"reduce-first max/ matrix col maxima"
|
||||
(rv
|
||||
(apl-reduce-first apl-max (make-array (list 3 2) (list 1 9 2 8 3 7))))
|
||||
(list 3 9))
|
||||
Reference in New Issue
Block a user