taco_threshold_automaton/lia_threshold_automaton/general_to_lia/
remove_div.rs

//! Module for removing divisions by replacing them with multiplications by
//! fractions.
//!
//! This module contains the logic to remove all division operators from an
//! integer expression. Most importantly, it contains the conversion from
//! [`NonMinusIntegerExpr`] to [`NoDivIntegerExpr`].
//!
//! Integer expressions without divisions (and without subtraction) are
//! represented by the type [`NoDivIntegerExpr`]. This type is a subset of the
//! [`NonMinusIntegerExpr`] type, which in turn is a subset of the
//! [`crate::expressions::IntegerExpression`] type.
12
13use crate::expressions::{Atomic, Parameter, fraction::Fraction};
14use std::fmt::{Debug, Display};
15
16use super::{
17    ConstraintRewriteError,
18    remove_minus::{NonMinusIntegerExpr, NonMinusIntegerOp},
19};
20
/// Arithmetic operator of a [`NoDivIntegerExpr`].
///
/// Only addition and multiplication are representable; division and
/// subtraction have been eliminated before an expression reaches this form.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum NoDivIntegerOp {
    /// Addition, rendered as `+`
    Add,
    /// Multiplication, rendered as `*`
    Mul,
}

impl Display for NoDivIntegerOp {
    /// Writes the operator's symbol (`+` or `*`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Self::Add => "+",
            Self::Mul => "*",
        };
        f.write_str(symbol)
    }
}
38
39impl From<NonMinusIntegerOp> for NoDivIntegerOp {
40    /// Convert a [`NonMinusIntegerOp`] to a [`NoDivIntegerOp`]
41    ///
42    /// Panics if the [`NonMinusIntegerOp`] is a [`NonMinusIntegerOp::Div`]
43    fn from(op: NonMinusIntegerOp) -> Self {
44        match op {
45            NonMinusIntegerOp::Add => NoDivIntegerOp::Add,
46            NonMinusIntegerOp::Mul => NoDivIntegerOp::Mul,
47            NonMinusIntegerOp::Div => panic!("Division operator not allowed"),
48        }
49    }
50}
51
52#[derive(Debug, PartialEq, Clone)]
53/// Integer expression without division operators
54pub enum NoDivIntegerExpr<T: Atomic> {
55    /// Atom of type T
56    Atom(T),
57    /// Constant fraction
58    Frac(Fraction),
59    /// Parameter
60    Param(Parameter),
61    /// Integer expression combining two integer expressions through an
62    /// arithmetic operator
63    BinaryExpr(
64        Box<NoDivIntegerExpr<T>>,
65        NoDivIntegerOp,
66        Box<NoDivIntegerExpr<T>>,
67    ),
68}
69
70impl<T> NonMinusIntegerExpr<T>
71where
72    T: Atomic,
73{
74    /// Remove all division operators from the expression
75    ///
76    /// This function recursively evaluates the expression and tries to simplify
77    /// it by removing all division operators. If the formula is not linear
78    /// arithmetic, an error is returned. This is the case if a parameter or
79    /// atom appears in the denominator of a division.
80    ///
81    /// Note that we do not implement full symbolic math, therefore it could be
82    /// the case that the expression is linear arithmetic but requires
83    /// simplification (e.g. that violating atom also appears in the denominator
84    /// and could be removed by a simplification step).
85    pub fn remove_div(self) -> Result<NoDivIntegerExpr<T>, ConstraintRewriteError> {
86        match self {
87            NonMinusIntegerExpr::Atom(a) => Ok(NoDivIntegerExpr::Atom(a)),
88            NonMinusIntegerExpr::Const(c) => Ok(NoDivIntegerExpr::Frac(c.into())),
89            NonMinusIntegerExpr::NegConst(c) => {
90                Ok(NoDivIntegerExpr::Frac(-Into::<Fraction>::into(c)))
91            }
92            NonMinusIntegerExpr::Param(parameter) => Ok(NoDivIntegerExpr::Param(parameter)),
93            NonMinusIntegerExpr::BinaryExpr(numerator, NonMinusIntegerOp::Div, denominator) => {
94                // try to parse the denominator into a fraction
95                // This must be possible for a linear arithmetic formula,
96                // otherwise we have an expression of the form 1/param or 1/var
97                let denominator = denominator
98                    .try_to_fraction()
99                    .ok_or(ConstraintRewriteError::NotLinearArithmetic)?;
100
101                // check if the numerator is also a constant and simplify
102                if let Some(numerator) = numerator.try_to_fraction() {
103                    return Ok(NoDivIntegerExpr::Frac(numerator / denominator));
104                }
105
106                // if the numerator is not a constant, we need to recursively
107                // evaluate the expression and add the fraction as a factor
108                Ok(NoDivIntegerExpr::BinaryExpr(
109                    Box::new(NoDivIntegerExpr::Frac(Fraction::from(1) / denominator)),
110                    NoDivIntegerOp::Mul,
111                    Box::new(numerator.remove_div()?),
112                ))
113            }
114            NonMinusIntegerExpr::BinaryExpr(lhs, op, rhs) => {
115                let lhs = lhs.remove_div()?;
116                let rhs = rhs.remove_div()?;
117                Ok(NoDivIntegerExpr::BinaryExpr(
118                    Box::new(lhs),
119                    op.into(),
120                    Box::new(rhs),
121                ))
122            }
123        }
124    }
125
126    /// Attempt to parse the expression into a fraction by recursively
127    /// evaluating the expression
128    ///
129    /// This function returns `None` if the expression contains any atoms or
130    /// parameters. Otherwise, it returns a fraction equivalent to the original
131    /// expression
132    pub fn try_to_fraction(&self) -> Option<Fraction> {
133        match self {
134            NonMinusIntegerExpr::Atom(_) | NonMinusIntegerExpr::Param(_) => None,
135            NonMinusIntegerExpr::Const(c) => Some(Fraction::new(*c, 1, false)),
136            NonMinusIntegerExpr::NegConst(c) => Some(Fraction::new(*c, 1, true)),
137            NonMinusIntegerExpr::BinaryExpr(lhs, op, rhs) => {
138                let lhs = lhs.try_to_fraction()?;
139                let rhs = rhs.try_to_fraction()?;
140
141                match op {
142                    NonMinusIntegerOp::Add => Some(lhs + rhs),
143                    NonMinusIntegerOp::Mul => Some(lhs * rhs),
144                    NonMinusIntegerOp::Div => Some(lhs / rhs),
145                }
146            }
147        }
148    }
149}
150
151impl<T: Atomic> TryFrom<NonMinusIntegerExpr<T>> for NoDivIntegerExpr<T> {
152    type Error = ConstraintRewriteError;
153
154    fn try_from(value: NonMinusIntegerExpr<T>) -> Result<Self, Self::Error> {
155        value.remove_div()
156    }
157}
158
159impl<T> Display for NoDivIntegerExpr<T>
160where
161    T: Atomic,
162{
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        match self {
165            NoDivIntegerExpr::Atom(a) => write!(f, "{a}"),
166            NoDivIntegerExpr::Frac(fraction) => write!(f, "{fraction}"),
167            NoDivIntegerExpr::Param(parameter) => write!(f, "{parameter}"),
168            NoDivIntegerExpr::BinaryExpr(lhs, op, rhs) => write!(f, "({lhs} {op} {rhs})"),
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Variable;

    #[test]
    fn test_keep_no_div_add() {
        // 3 + -5
        let expr: NonMinusIntegerExpr<Variable> = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Const(3)),
            NonMinusIntegerOp::Add,
            Box::new(NonMinusIntegerExpr::NegConst(5)),
        );

        // (3/1) + (-5/1)
        let expected_expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(3, 1, false))),
            NoDivIntegerOp::Add,
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(5, 1, true))),
        );

        let got_expr = NoDivIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    #[test]
    fn test_keep_no_div_mul() {
        // 4 + (-7 * -2)
        let expr: NonMinusIntegerExpr<Variable> = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Const(4)),
            NonMinusIntegerOp::Add,
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::NegConst(7)),
                NonMinusIntegerOp::Mul,
                Box::new(NonMinusIntegerExpr::NegConst(2)),
            )),
        );

        // (4/1) + (-7/1 * -2/1)
        let expected_expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(4, 1, false))),
            NoDivIntegerOp::Add,
            Box::new(NoDivIntegerExpr::BinaryExpr(
                Box::new(NoDivIntegerExpr::Frac(Fraction::new(7, 1, true))),
                NoDivIntegerOp::Mul,
                Box::new(NoDivIntegerExpr::Frac(Fraction::new(2, 1, true))),
            )),
        );

        let got_expr = NoDivIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    #[test]
    fn test_simple_division() {
        // x / 2
        let expr = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Atom(Variable::new("x"))),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::Const(2)),
        );

        // 1/2 * x
        let expected_expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(1, 2, false))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
        );

        let got_expr = expr.remove_div().unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    #[test]
    fn test_division_both_const() {
        // (3 + 2) / -4
        let expr: NonMinusIntegerExpr<Variable> = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Const(3)),
                NonMinusIntegerOp::Add,
                Box::new(NonMinusIntegerExpr::Const(2)),
            )),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::NegConst(4)),
        );

        // -5/4
        let expected_expr = NoDivIntegerExpr::Frac(Fraction::new(5, 4, true));

        let got_expr = expr.remove_div().unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    #[test]
    fn test_double_division() {
        // x / (1 / 2)
        let expr = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Atom(Variable::new("x"))),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Const(1)),
                NonMinusIntegerOp::Div,
                Box::new(NonMinusIntegerExpr::Const(2)),
            )),
        );

        // 2/1 * x
        let expected_expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(2, 1, false))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
        );

        let got_expr = expr.remove_div().unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    #[test]
    fn test_nested_division_in_numerator() {
        // (x / 2) / 3
        let expr = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Atom(Variable::new("x"))),
                NonMinusIntegerOp::Div,
                Box::new(NonMinusIntegerExpr::Const(2)),
            )),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::Const(3)),
        );

        // 1/3 * (1/2 * x)
        let expected_expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(1, 3, false))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::BinaryExpr(
                Box::new(NoDivIntegerExpr::Frac(Fraction::new(1, 2, false))),
                NoDivIntegerOp::Mul,
                Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
            )),
        );

        let got_expr = expr.remove_div().unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    #[test]
    fn test_simple_from_var() {
        // x / (3 + 2)
        let expr = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Atom(Variable::new("x"))),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Const(3)),
                NonMinusIntegerOp::Add,
                Box::new(NonMinusIntegerExpr::Const(2)),
            )),
        );

        // 1/5 * x
        let expected_expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(1, 5, false))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
        );

        let got_expr = NoDivIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    #[test]
    fn try_simple_from_param() {
        // n / (5 * 2)
        let expr: NonMinusIntegerExpr<Variable> = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Param(Parameter::new("n"))),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Const(5)),
                NonMinusIntegerOp::Mul,
                Box::new(NonMinusIntegerExpr::Const(2)),
            )),
        );

        // 1/10 * n
        let expected_expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(1, 10, false))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Param(Parameter::new("n"))),
        );
        let got_expr = NoDivIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    #[test]
    fn test_error_on_div_by_var() {
        // 1 / x
        let expr = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Const(1)),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::Atom(Variable::new("x"))),
        );
        let e = NoDivIntegerExpr::try_from(expr);
        assert!(e.is_err());
        assert!(matches!(
            e.unwrap_err(),
            ConstraintRewriteError::NotLinearArithmetic
        ));
    }

    #[test]
    fn test_error_on_div_by_param() {
        // 1 / (n + 5)
        let expr: NonMinusIntegerExpr<Variable> = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Const(1)),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Param(Parameter::new("n"))),
                NonMinusIntegerOp::Add,
                Box::new(NonMinusIntegerExpr::Const(5)),
            )),
        );

        let e = NoDivIntegerExpr::try_from(expr);
        assert!(e.is_err());
        assert!(matches!(
            e.unwrap_err(),
            ConstraintRewriteError::NotLinearArithmetic
        ));
    }

    #[test]
    fn test_error_on_nested_div_by_param() {
        // x + 1 / n: the invalid division is nested below an addition
        let expr = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Atom(Variable::new("x"))),
            NonMinusIntegerOp::Add,
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Const(1)),
                NonMinusIntegerOp::Div,
                Box::new(NonMinusIntegerExpr::Param(Parameter::new("n"))),
            )),
        );

        assert!(matches!(
            NoDivIntegerExpr::try_from(expr),
            Err(ConstraintRewriteError::NotLinearArithmetic)
        ));
    }

    #[test]
    fn test_try_to_fraction() {
        // (3 + 2) * 2 evaluates to 10/1
        let constant: NonMinusIntegerExpr<Variable> = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Const(3)),
                NonMinusIntegerOp::Add,
                Box::new(NonMinusIntegerExpr::Const(2)),
            )),
            NonMinusIntegerOp::Mul,
            Box::new(NonMinusIntegerExpr::Const(2)),
        );
        assert_eq!(
            constant.try_to_fraction(),
            Some(Fraction::new(10, 1, false))
        );

        // n / 2 contains a parameter and cannot be evaluated
        let symbolic: NonMinusIntegerExpr<Variable> = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Param(Parameter::new("n"))),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::Const(2)),
        );
        assert!(symbolic.try_to_fraction().is_none());
    }

    #[test]
    fn test_display_no_div_integer_expr() {
        let expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(1, 2, false))),
        );
        assert_eq!(format!("{expr}"), "(x * 1/2)");

        let expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
            NoDivIntegerOp::Add,
            Box::new(NoDivIntegerExpr::Param(Parameter::new("p"))),
        );

        assert_eq!(format!("{expr}"), "(x + p)");
    }

    #[test]
    #[should_panic]
    fn test_no_div_op_div() {
        let _ = NoDivIntegerOp::from(NonMinusIntegerOp::Div);
    }
}