diff --git a/lib/minikanren/clpfd.sx b/lib/minikanren/clpfd.sx index fedc87c7..ce2c36d8 100644 --- a/lib/minikanren/clpfd.sx +++ b/lib/minikanren/clpfd.sx @@ -484,3 +484,53 @@ (mk-conj (fd-distinct-from-head (first vars) (rest vars)) (fd-distinct (rest vars))))))) + +;; --- fd-plus (x + y = z, ground-cases propagator) --- + +(define + fd-bind-or-narrow + (fn + (w target s) + (cond + ((number? w) (cond ((= w target) s) (:else nil))) + ((is-var? w) + (let + ((wd (fd-domain-of s (var-name w)))) + (cond + ((and (not (= wd nil)) (not (fd-dom-member? target wd))) nil) + (:else + (let + ((s2 (mk-unify w target s))) + (cond ((= s2 nil) nil) (:else s2))))))) + (:else nil)))) + +(define + fd-plus-prop + (fn + (x y z s) + (let + ((wx (mk-walk x s)) (wy (mk-walk y s)) (wz (mk-walk z s))) + (cond + ((and (number? wx) (number? wy) (number? wz)) + (cond ((= (+ wx wy) wz) s) (:else nil))) + ((and (number? wx) (number? wy)) + (fd-bind-or-narrow wz (+ wx wy) s)) + ((and (number? wx) (number? wz)) + (fd-bind-or-narrow wy (- wz wx) s)) + ((and (number? wy) (number? wz)) + (fd-bind-or-narrow wx (- wz wy) s)) + (:else s))))) + +(define + fd-plus + (fn + (x y z) + (fn + (s) + (let + ((c (fn (sp) (fd-plus-prop x y z sp)))) + (let + ((s2 (fd-add-constraint s c))) + (let + ((s3 (c s2))) + (cond ((= s3 nil) mzero) (:else (unit s3))))))))) diff --git a/lib/minikanren/tests/clpfd-plus.sx b/lib/minikanren/tests/clpfd-plus.sx new file mode 100644 index 00000000..81b01d18 --- /dev/null +++ b/lib/minikanren/tests/clpfd-plus.sx @@ -0,0 +1,62 @@ +;; lib/minikanren/tests/clpfd-plus.sx — fd-plus (x + y = z). + +(mk-test + "fd-plus-all-ground" + (run* q (fresh (z) (fd-plus 2 3 z) (== q z))) + (list 5)) + +(mk-test + "fd-plus-recover-x" + (run* q (fresh (x) (fd-plus x 3 5) (== q x))) + (list 2)) + +(mk-test + "fd-plus-recover-y" + (run* q (fresh (y) (fd-plus 2 y 5) (== q y))) + (list 3)) + +(mk-test + "fd-plus-impossible-fails" + (run* + q + (fresh + (z) + (fd-plus 2 3 z) + (== z 99) + (== q z))) + (list)) + +(mk-test + "fd-plus-domain-check" + (run* + q + (fresh + (x) + (fd-in x (list 3 4 5)) + (fd-plus x 3 5) + (== q x))) + (list)) + +(mk-test + "fd-plus-pairs-summing-to-5" + (run* + q + (fresh + (x y) + (fd-in x (list 1 2 3 4)) + (fd-in y (list 1 2 3 4)) + (fd-plus x y 5) + (fd-label (list x y)) + (== q (list x y)))) + (list + (list 1 4) + (list 2 3) + (list 3 2) + (list 4 1))) + +(mk-test + "fd-plus-z-derived" + (run* q (fresh (z) (fd-plus 7 8 z) (== q z))) + (list 15)) + +(mk-tests-run!)