diff --git a/lib/apl/runtime.sx b/lib/apl/runtime.sx index 8f53df38..165234db 100644 --- a/lib/apl/runtime.sx +++ b/lib/apl/runtime.sx @@ -661,3 +661,30 @@ (make-array (cons (+ (first a-s) (first b-s)) (rest a-s)) (append a-r b-r))))) + +(define + apl-squad + (fn + (idx-arr data-arr) + (let + ((shape (get data-arr :shape)) + (ravel (get data-arr :ravel)) + (strides (apl-strides (get data-arr :shape)))) + (let + ((idxs (if (scalar? idx-arr) (list (disclose idx-arr)) (get idx-arr :ravel)))) + (let + ((k (len idxs)) (rank (len shape))) + (let + ((adj (map (fn (i) (- i apl-io)) idxs))) + (if + (= k rank) + (apl-scalar (nth ravel (apl-multi->flat adj strides))) + (let + ((remaining-shape (drop shape k)) + (start (apl-multi->flat adj (take strides k))) + (slice-size (reduce * 1 (drop shape k)))) + (make-array + remaining-shape + (map + (fn (j) (nth ravel (+ start j))) + (range 0 slice-size))))))))))) diff --git a/lib/apl/tests/structural.sx b/lib/apl/tests/structural.sx index 4a6dddd8..204905dc 100644 --- a/lib/apl/tests/structural.sx +++ b/lib/apl/tests/structural.sx @@ -388,4 +388,50 @@ (apl-catenate-first (make-array (list 2 3) (list 1 2 3 4 5 6)) (make-array (list 3 3) (list 11 12 13 14 15 16 17 18 19)))) - (list 1 2 3 4 5 6 11 12 13 14 15 16 17 18 19)) \ No newline at end of file + (list 1 2 3 4 5 6 11 12 13 14 15 16 17 18 19)) + +(apl-test + "squad scalar into vector" + (rv + (apl-squad (apl-scalar 2) (make-array (list 5) (list 10 20 30 40 50)))) + (list 20)) + +(apl-test + "squad first element" + (rv (apl-squad (apl-scalar 1) (make-array (list 3) (list 10 20 30)))) + (list 10)) + +(apl-test + "squad last element" + (rv + (apl-squad (apl-scalar 5) (make-array (list 5) (list 10 20 30 40 50)))) + (list 50)) + +(apl-test + "squad fully specified matrix element" + (rv + (apl-squad + (make-array (list 2) (list 2 3)) + (make-array (list 3 4) (list 1 2 3 4 5 6 7 8 9 10 11 12)))) + (list 7)) + +(apl-test + "squad partial row of matrix shape" + (sh + (apl-squad + (apl-scalar 2) + (make-array (list 3 4) (list 1 2 3 4 5 6 7 8 9 10 11 12)))) + (list 4)) + +(apl-test + "squad partial row of matrix ravel" + (rv + (apl-squad + (apl-scalar 2) + (make-array (list 3 4) (list 1 2 3 4 5 6 7 8 9 10 11 12)))) + (list 5 6 7 8)) + +(apl-test + "squad partial 3d slice shape" + (sh (apl-squad (apl-scalar 1) (make-array (list 2 3 4) (range 1 25)))) + (list 3 4)) \ No newline at end of file