taco_threshold_automaton/lia_threshold_automaton/general_to_lia/split_pair.rs

//! This module contains the logic to split an integer expression into pairs of
//! a factor and an atom / parameter, or into a constant term.
//!
//! The general outline of the conversion is described in more detail in the
//! parent module [`super`] (`general_to_lia.rs`).
6
7use log::debug;
8use std::fmt::{Debug, Display};
9
10use crate::expressions::{Atomic, Parameter, fraction::Fraction};
11
12use super::{
13    ConstraintRewriteError,
14    remove_div::{NoDivIntegerExpr, NoDivIntegerOp},
15};
16
17/// Pair of an atom and a factor
18#[derive(Debug, Clone, PartialEq)]
19pub enum AtomFactorPair<T: Display + Debug + Clone> {
20    Atom(T, Fraction),
21    Param(Parameter, Fraction),
22    Const(Fraction),
23}
24
25#[derive(Debug, Clone, PartialEq)]
26/// Type to represent a linear arithmetic expression
27pub struct LinearIntegerExpr<T: Display + Debug + Clone> {
28    /// Pairs part of a sum
29    pairs: Vec<AtomFactorPair<T>>,
30}
31
32impl<T: Atomic> NoDivIntegerExpr<T> {
33    /// Split the expression into pairs of factors and an atom / parameter or constant
34    pub fn split_into_factor_pairs(self) -> Result<LinearIntegerExpr<T>, ConstraintRewriteError> {
35        let pairs = self.collect_pairs()?;
36
37        Ok(LinearIntegerExpr { pairs })
38    }
39
40    fn collect_pairs(self) -> Result<Vec<AtomFactorPair<T>>, ConstraintRewriteError> {
41        match self {
42            NoDivIntegerExpr::Atom(a) => Ok(vec![AtomFactorPair::Atom(a, 1.into())]),
43            NoDivIntegerExpr::Frac(f) => Ok(vec![AtomFactorPair::Const(f)]),
44            NoDivIntegerExpr::Param(p) => Ok(vec![AtomFactorPair::Param(p, 1.into())]),
45            NoDivIntegerExpr::BinaryExpr(lhs, op, rhs) => match op {
46                NoDivIntegerOp::Add => {
47                    let mut lhs = lhs.collect_pairs()?;
48                    let rhs = rhs.collect_pairs()?;
49
50                    lhs.extend(rhs);
51                    Ok(lhs)
52                }
53                NoDivIntegerOp::Mul => {
54                    let lhs_f = lhs.clone().try_to_fraction();
55                    let rhs_f = rhs.clone().try_to_fraction();
56
57                    // both sides of the multiplication contain parameters or
58                    // atoms, the expression is not linear arithmetic
59                    if lhs_f.is_none() && rhs_f.is_none() {
60                        debug!("Failed to split expression ({lhs} * {rhs}) into factor pairs");
61                        return Err(ConstraintRewriteError::NotLinearArithmetic);
62                    }
63
64                    // if both sides are constants, we can simplify the expression
65                    if let (Some(lhs_f), Some(rhs_f)) = (lhs_f, rhs_f) {
66                        return Ok(vec![AtomFactorPair::Const(lhs_f * rhs_f)]);
67                    }
68
69                    // one is constant, the other contains parameters or atoms
70                    let (const_, non_const_expr) = if let Some(lhs) = lhs_f {
71                        (lhs, rhs.collect_pairs()?)
72                    } else {
73                        (rhs_f.unwrap(), lhs.collect_pairs()?)
74                    };
75
76                    // add constant factor to all of the pairs
77                    Ok(non_const_expr
78                        .into_iter()
79                        .map(|atom| match atom {
80                            AtomFactorPair::Atom(a, f) => AtomFactorPair::Atom(a, f * const_),
81                            AtomFactorPair::Param(p, f) => AtomFactorPair::Param(p, f * const_),
82                            AtomFactorPair::Const(f) => AtomFactorPair::Const(f * const_), // unreachable
83                        })
84                        .collect())
85                }
86            },
87        }
88    }
89
90    /// Try to convert the expression to a fraction
91    pub fn try_to_fraction(self) -> Option<Fraction> {
92        match self {
93            NoDivIntegerExpr::Atom(_) | NoDivIntegerExpr::Param(_) => None,
94            NoDivIntegerExpr::Frac(f) => Some(f),
95            NoDivIntegerExpr::BinaryExpr(lhs, op, rhs) => {
96                let lhs = lhs.try_to_fraction()?;
97                let rhs = rhs.try_to_fraction()?;
98
99                match op {
100                    NoDivIntegerOp::Add => Some(lhs + rhs),
101                    NoDivIntegerOp::Mul => Some(lhs * rhs),
102                }
103            }
104        }
105    }
106}
107
108impl<T: Atomic> TryFrom<NoDivIntegerExpr<T>> for LinearIntegerExpr<T> {
109    type Error = ConstraintRewriteError;
110
111    fn try_from(value: NoDivIntegerExpr<T>) -> Result<Self, Self::Error> {
112        value.split_into_factor_pairs()
113    }
114}
115
116impl<T: Display + Debug + Clone> LinearIntegerExpr<T> {
117    /// Get the constant factor of the expression
118    ///
119    /// For an expression of the form `a * x + b * y + p * z + c`, where `a, b`
120    /// and `p` are coefficients and `x,y,z` are parameters and c a constant,
121    /// this function returns `c`.
122    pub fn get_const_factor(&self) -> Fraction {
123        self.pairs
124            .iter()
125            .filter_map(|pair| match pair {
126                AtomFactorPair::Const(f) => Some(*f),
127                _ => None,
128            })
129            .fold(Fraction::from(0), |acc, f| acc + f)
130    }
131
132    /// Get all atom factor pairs of the expression
133    ///
134    /// For an expression of the form `a * x + b * y + p * z + c`, where `x` and `y` are atoms,
135    /// `a` and `b` are their coefficients, `p` is a parameter, and `z` is its coefficient,
136    /// this function returns the pairs `(x, a), (y, b)`.
137    pub fn get_atom_factor_pairs(&self) -> impl Iterator<Item = (&T, &Fraction)> {
138        self.pairs
139            .iter()
140            .filter(|pair| matches!(pair, AtomFactorPair::Atom(_, _)))
141            .map(|pair| match pair {
142                AtomFactorPair::Atom(a, f) => (a, f),
143                _ => unreachable!(),
144            })
145    }
146
147    /// Get all parameter factor pairs of the expression
148    ///
149    /// For an expression of the form `a * x + b * y + p * z + c`, where `x, y, z` are parameters
150    /// and `a, b, p` are their coefficients, this function returns the pairs `(p, z)`, where
151    /// `p` is the parameter and `z` is its coefficient.
152    pub fn get_param_factor_pairs(&self) -> impl Iterator<Item = (&Parameter, &Fraction)> {
153        self.pairs
154            .iter()
155            .filter(|pair| matches!(pair, AtomFactorPair::Param(_, _)))
156            .map(|pair| match pair {
157                AtomFactorPair::Param(p, f) => (p, f),
158                _ => unreachable!(),
159            })
160    }
161}
162
163impl<T: Display + Debug + Clone> Display for AtomFactorPair<T> {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        match self {
166            AtomFactorPair::Atom(a, frac) => write!(f, "{a} * {frac}"),
167            AtomFactorPair::Param(p, frac) => write!(f, "{p} * {frac}"),
168            AtomFactorPair::Const(c) => write!(f, "{c}"),
169        }
170    }
171}
172
173impl<T: Display + Debug + Clone> Display for LinearIntegerExpr<T> {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        let mut it = self.pairs.iter();
176        if let Some(pair) = it.next() {
177            write!(f, "{pair}")?;
178        }
179
180        it.try_for_each(|pair| write!(f, " + {pair}"))
181    }
182}
183
#[cfg(test)]
mod test {
    use crate::expressions::Variable;

    use super::*;

    #[test]
    fn test_simple_frac() {
        // 2
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::Frac(2.into());

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Const(2.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();

        assert_eq!(got, expected);
    }

    #[test]
    fn simple_addition_no_mul() {
        // 1 + n
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(1.into())),
            NoDivIntegerOp::Add,
            Box::new(NoDivIntegerExpr::Param(Parameter::new("n"))),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Const(1.into()),
                AtomFactorPair::Param(Parameter::new("n"), 1.into()),
            ],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_simple_mul_lhs_const() {
        // 1 * x
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(1.into())),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Atom(Variable::new("x"), 1.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_simple_mul_rhs_const() {
        // x * 1
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Frac(1.into())),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Atom(Variable::new("x"), 1.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_simple_mul_const_param() {
        // 1 * n
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(1.into())),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Param(Parameter::new("n"))),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Param(Parameter::new("n"), 1.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_simple_mul_both_const() {
        // 1 * 5
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(1.into())),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Frac(5.into())),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Const(5.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_simple_mul_both_non_const() {
        // x * n
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Param(Parameter::new("n"))),
        );

        let got = LinearIntegerExpr::try_from(expr);
        assert!(matches!(
            got,
            Err(ConstraintRewriteError::NotLinearArithmetic)
        ));
    }

    #[test]
    fn test_mul_distributes_over_sum() {
        // 2 * (x + 3): the constant summand of the non-constant side must be
        // scaled together with the atom
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(2.into())),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::BinaryExpr(
                Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
                NoDivIntegerOp::Add,
                Box::new(NoDivIntegerExpr::Frac(3.into())),
            )),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Atom(Variable::new("x"), 2.into()),
                AtomFactorPair::Const(6.into()),
            ],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_mul_const_subexpression() {
        // (1 + 2) * n
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::BinaryExpr(
                Box::new(NoDivIntegerExpr::Frac(1.into())),
                NoDivIntegerOp::Add,
                Box::new(NoDivIntegerExpr::Frac(2.into())),
            )),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Param(Parameter::new("n"))),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Param(Parameter::new("n"), 3.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_nested_mul_non_linear() {
        // 2 * (x * n)
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(2.into())),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::BinaryExpr(
                Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
                NoDivIntegerOp::Mul,
                Box::new(NoDivIntegerExpr::Param(Parameter::new("n"))),
            )),
        );

        let got = LinearIntegerExpr::try_from(expr);
        assert!(matches!(
            got,
            Err(ConstraintRewriteError::NotLinearArithmetic)
        ));
    }

    #[test]
    fn test_try_to_fraction() {
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::Frac(5.into());
        assert_eq!(expr.try_to_fraction(), Some(5.into()));

        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::Atom(Variable::new("x"));
        assert_eq!(expr.try_to_fraction(), None);

        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::Param(Parameter::new("n"));
        assert_eq!(expr.try_to_fraction(), None);

        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(5.into())),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Frac(3.into())),
        );
        assert_eq!(expr.try_to_fraction(), Some(15.into()));

        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(2.into())),
            NoDivIntegerOp::Add,
            Box::new(NoDivIntegerExpr::Frac(5.into())),
        );
        assert_eq!(expr.try_to_fraction(), Some(7.into()));

        // 5 + x
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(5.into())),
            NoDivIntegerOp::Add,
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
        );
        assert_eq!(expr.try_to_fraction(), None);
    }

    #[test]
    fn test_get_atom_factor_pairs() {
        let expr: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Atom(Variable::new("x"), 5.into()),
                AtomFactorPair::Param(Parameter::new("p"), 3.into()),
                AtomFactorPair::Const(7.into()),
            ],
        };

        let pairs: Vec<(&Variable, &Fraction)> = expr.get_atom_factor_pairs().collect();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0], (&Variable::new("x"), &5.into()));
    }

    #[test]
    fn test_get_param_factor_pairs() {
        let expr: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Atom(Variable::new("x"), 5.into()),
                AtomFactorPair::Param(Parameter::new("p"), 3.into()),
                AtomFactorPair::Const(7.into()),
            ],
        };

        let pairs: Vec<(&Parameter, &Fraction)> = expr.get_param_factor_pairs().collect();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0], (&Parameter::new("p"), &3.into()));
    }

    #[test]
    fn test_get_const_factor() {
        let expr: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Atom(Variable::new("x"), 5.into()),
                AtomFactorPair::Param(Parameter::new("p"), 3.into()),
                AtomFactorPair::Const(7.into()),
                AtomFactorPair::Const(5.into()),
            ],
        };

        assert_eq!(expr.get_const_factor(), 12.into());
    }

    #[test]
    fn test_display_linear_integer_arithmetic_expr() {
        let expr: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Atom(Variable::new("x"), 5.into()),
                AtomFactorPair::Param(Parameter::new("p"), 3.into()),
                AtomFactorPair::Const(7.into()),
            ],
        };
        assert_eq!(expr.to_string(), "x * 5 + p * 3 + 7");

        let expr: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Const(7.into())],
        };
        assert_eq!(expr.to_string(), "7");
    }
}
392}