diff --git a/lib/apl/runtime.sx b/lib/apl/runtime.sx index 4b93ebfb..30081175 100644 --- a/lib/apl/runtime.sx +++ b/lib/apl/runtime.sx @@ -436,3 +436,173 @@ ((old-coords (map (fn (i) (nth new-coords (nth inv-perm i))) (range 0 (len shape))))) (nth ravel (apl-multi->flat old-coords strides))))) (range 0 new-size)))))))) + +(define apl-safe-mod (fn (a m) (mod (+ (mod a m) m) m))) + +(define + apl-take + (fn + (n-arr data-arr) + (let + ((old-shape (get data-arr :shape)) + (old-ravel (get data-arr :ravel)) + (ns + (if (scalar? n-arr) (list (disclose n-arr)) (get n-arr :ravel)))) + (let + ((new-shape (map abs ns)) (old-strides (apl-strides old-shape))) + (let + ((new-size (reduce * 1 new-shape)) + (new-strides (apl-strides new-shape))) + (make-array + new-shape + (map + (fn + (new-flat) + (let + ((new-coords (apl-flat->multi new-flat new-shape new-strides))) + (let + ((old-coords (map (fn (i) (let ((ni (nth ns i)) (nc (nth new-coords i)) (od (nth old-shape i))) (if (>= ni 0) nc (+ (- od (- ni)) nc)))) (range 0 (len ns))))) + (if + (every? + (fn + (i) + (and + (>= (nth old-coords i) 0) + (< (nth old-coords i) (nth old-shape i)))) + (range 0 (len old-coords))) + (nth old-ravel (apl-multi->flat old-coords old-strides)) + 0)))) + (range 0 new-size)))))))) + +(define + apl-drop + (fn + (n-arr data-arr) + (let + ((old-shape (get data-arr :shape)) + (old-ravel (get data-arr :ravel)) + (ns + (if (scalar? n-arr) (list (disclose n-arr)) (get n-arr :ravel)))) + (let + ((new-shape (map (fn (i) (let ((ni (nth ns i)) (od (nth old-shape i))) (let ((d (if (>= ni 0) (- od ni) (+ od ni)))) (if (> d 0) d 0)))) (range 0 (len ns)))) + (offsets + (map + (fn (i) (let ((ni (nth ns i))) (if (>= ni 0) ni 0))) + (range 0 (len ns)))) + (old-strides (apl-strides old-shape))) + (let + ((new-size (reduce * 1 new-shape)) + (new-strides (apl-strides new-shape))) + (make-array + new-shape + (map + (fn + (new-flat) + (let + ((new-coords (apl-flat->multi new-flat new-shape new-strides))) + (let + ((old-coords (map (fn (i) (+ (nth new-coords i) (nth offsets i))) (range 0 (len ns))))) + (nth old-ravel (apl-multi->flat old-coords old-strides))))) + (range 0 new-size)))))))) + +(define + apl-reverse + (fn + (arr) + (let + ((shape (get arr :shape)) (ravel (get arr :ravel))) + (if + (= (len shape) 0) + arr + (let + ((last-dim (last shape)) (n (len ravel))) + (make-array + shape + (map + (fn + (flat) + (let + ((c-last (mod flat last-dim))) + (nth ravel (+ flat (- last-dim 1) (* -2 c-last))))) + (range 0 n)))))))) + +(define + apl-reverse-first + (fn + (arr) + (let + ((shape (get arr :shape)) (ravel (get arr :ravel))) + (if + (= (len shape) 0) + arr + (let + ((first-dim (first shape)) + (first-stride (reduce * 1 (rest shape))) + (n (len ravel))) + (make-array + shape + (map + (fn + (flat) + (let + ((row (floor (/ flat first-stride)))) + (let + ((old-row (- first-dim 1 row))) + (nth + ravel + (+ (* old-row first-stride) (mod flat first-stride)))))) + (range 0 n)))))))) + +(define + apl-rotate-first + (fn + (n-arr data-arr) + (let + ((shape (get data-arr :shape)) + (ravel (get data-arr :ravel)) + (rot (disclose n-arr))) + (if + (= (len shape) 0) + data-arr + (let + ((first-dim (first shape)) + (first-stride (reduce * 1 (rest shape))) + (n (len ravel))) + (make-array + shape + (map + (fn + (flat) + (let + ((row (floor (/ flat first-stride)))) + (let + ((old-row (apl-safe-mod (+ row rot) first-dim))) + (nth + ravel + (+ (* old-row first-stride) (mod flat first-stride)))))) + (range 0 n)))))))) + +(define + apl-rotate + (fn + (n-arr data-arr) + (let + ((shape (get data-arr :shape)) + (ravel (get data-arr :ravel)) + (rot (disclose n-arr))) + (if + (= (len shape) 0) + data-arr + (let + ((last-dim (last shape)) (n (len ravel))) + (make-array + shape + (map + (fn + (flat) + (let + ((c-last (mod flat last-dim))) + (let + ((old-c-last (apl-safe-mod (+ c-last rot) last-dim))) + (nth ravel (+ flat (- old-c-last c-last)))))) + (range 0 n)))))))) diff --git a/lib/apl/tests/structural.sx b/lib/apl/tests/structural.sx index 10cb18e2..72b1e961 100644 --- a/lib/apl/tests/structural.sx +++ b/lib/apl/tests/structural.sx @@ -188,4 +188,138 @@ (apl-transpose-dyadic (make-array (list 3) (list 2 1 3)) (make-array (list 2 3 4) (range 0 24)))) - (list 3 2 4)) \ No newline at end of file + (list 3 2 4)) + +(apl-test + "take 3 from front" + (rv (apl-take (apl-scalar 3) (make-array (list 5) (list 1 2 3 4 5)))) + (list 1 2 3)) + +(apl-test + "take 0" + (rv (apl-take (apl-scalar 0) (make-array (list 5) (list 1 2 3 4 5)))) + (list)) + +(apl-test + "take -2 from back" + (rv (apl-take (apl-scalar -2) (make-array (list 5) (list 1 2 3 4 5)))) + (list 4 5)) + +(apl-test + "take over-take pads with 0" + (rv (apl-take (apl-scalar 7) (make-array (list 5) (list 1 2 3 4 5)))) + (list 1 2 3 4 5 0 0)) + +(apl-test + "take matrix 1 row 2 cols shape" + (sh + (apl-take + (make-array (list 2) (list 1 2)) + (make-array (list 2 3) (list 1 2 3 4 5 6)))) + (list 1 2)) + +(apl-test + "take matrix 1 row 2 cols ravel" + (rv + (apl-take + (make-array (list 2) (list 1 2)) + (make-array (list 2 3) (list 1 2 3 4 5 6)))) + (list 1 2)) + +(apl-test + "take matrix negative row" + (rv + (apl-take + (make-array (list 2) (list -1 3)) + (make-array (list 2 3) (list 1 2 3 4 5 6)))) + (list 4 5 6)) + +(apl-test + "drop 2 from front" + (rv (apl-drop (apl-scalar 2) (make-array (list 5) (list 1 2 3 4 5)))) + (list 3 4 5)) + +(apl-test + "drop -2 from back" + (rv (apl-drop (apl-scalar -2) (make-array (list 5) (list 1 2 3 4 5)))) + (list 1 2 3)) + +(apl-test + "drop all" + (rv (apl-drop (apl-scalar 5) (make-array (list 5) (list 1 2 3 4 5)))) + (list)) + +(apl-test + "drop 0" + (rv (apl-drop (apl-scalar 0) (make-array (list 5) (list 1 2 3 4 5)))) + (list 1 2 3 4 5)) + +(apl-test + "drop matrix 1 row shape" + (sh + (apl-drop + (make-array (list 2) (list 1 0)) + (make-array (list 2 3) (list 1 2 3 4 5 6)))) + (list 1 3)) + +(apl-test + "drop matrix 1 row ravel" + (rv + (apl-drop + (make-array (list 2) (list 1 0)) + (make-array (list 2 3) (list 1 2 3 4 5 6)))) + (list 4 5 6)) + +(apl-test + "reverse vector" + (rv (apl-reverse (make-array (list 5) (list 1 2 3 4 5)))) + (list 5 4 3 2 1)) + +(apl-test + "reverse scalar identity" + (rv (apl-reverse (apl-scalar 42))) + (list 42)) + +(apl-test + "reverse matrix last axis" + (rv (apl-reverse (make-array (list 2 3) (list 1 2 3 4 5 6)))) + (list 3 2 1 6 5 4)) + +(apl-test + "reverse-first matrix" + (rv (apl-reverse-first (make-array (list 2 3) (list 1 2 3 4 5 6)))) + (list 4 5 6 1 2 3)) + +(apl-test + "reverse-first vector identity" + (rv (apl-reverse-first (make-array (list 4) (list 1 2 3 4)))) + (list 4 3 2 1)) + +(apl-test + "rotate vector left by 2" + (rv (apl-rotate (apl-scalar 2) (make-array (list 5) (list 1 2 3 4 5)))) + (list 3 4 5 1 2)) + +(apl-test + "rotate vector right by 1 (negative)" + (rv (apl-rotate (apl-scalar -1) (make-array (list 5) (list 1 2 3 4 5)))) + (list 5 1 2 3 4)) + +(apl-test + "rotate by 0 is identity" + (rv (apl-rotate (apl-scalar 0) (make-array (list 5) (list 1 2 3 4 5)))) + (list 1 2 3 4 5)) + +(apl-test + "rotate matrix last axis" + (rv + (apl-rotate (apl-scalar 1) (make-array (list 2 3) (list 1 2 3 4 5 6)))) + (list 2 3 1 5 6 4)) + +(apl-test + "rotate-first matrix" + (rv + (apl-rotate-first + (apl-scalar 1) + (make-array (list 2 3) (list 1 2 3 4 5 6)))) + (list 4 5 6 1 2 3)) \ No newline at end of file