taco_threshold_automaton/lia_threshold_automaton/general_to_lia/
split_pair.rs1use log::debug;
8use std::fmt::{Debug, Display};
9
10use crate::expressions::{Atomic, Parameter, fraction::Fraction};
11
12use super::{
13 ConstraintRewriteError,
14 remove_div::{NoDivIntegerExpr, NoDivIntegerOp},
15};
16
17#[derive(Debug, Clone, PartialEq)]
19pub enum AtomFactorPair<T: Display + Debug + Clone> {
20 Atom(T, Fraction),
21 Param(Parameter, Fraction),
22 Const(Fraction),
23}
24
25#[derive(Debug, Clone, PartialEq)]
26pub struct LinearIntegerExpr<T: Display + Debug + Clone> {
28 pairs: Vec<AtomFactorPair<T>>,
30}
31
32impl<T: Atomic> NoDivIntegerExpr<T> {
33 pub fn split_into_factor_pairs(self) -> Result<LinearIntegerExpr<T>, ConstraintRewriteError> {
35 let pairs = self.collect_pairs()?;
36
37 Ok(LinearIntegerExpr { pairs })
38 }
39
40 fn collect_pairs(self) -> Result<Vec<AtomFactorPair<T>>, ConstraintRewriteError> {
41 match self {
42 NoDivIntegerExpr::Atom(a) => Ok(vec![AtomFactorPair::Atom(a, 1.into())]),
43 NoDivIntegerExpr::Frac(f) => Ok(vec![AtomFactorPair::Const(f)]),
44 NoDivIntegerExpr::Param(p) => Ok(vec![AtomFactorPair::Param(p, 1.into())]),
45 NoDivIntegerExpr::BinaryExpr(lhs, op, rhs) => match op {
46 NoDivIntegerOp::Add => {
47 let mut lhs = lhs.collect_pairs()?;
48 let rhs = rhs.collect_pairs()?;
49
50 lhs.extend(rhs);
51 Ok(lhs)
52 }
53 NoDivIntegerOp::Mul => {
54 let lhs_f = lhs.clone().try_to_fraction();
55 let rhs_f = rhs.clone().try_to_fraction();
56
57 if lhs_f.is_none() && rhs_f.is_none() {
60 debug!("Failed to split expression ({lhs} * {rhs}) into factor pairs");
61 return Err(ConstraintRewriteError::NotLinearArithmetic);
62 }
63
64 if let (Some(lhs_f), Some(rhs_f)) = (lhs_f, rhs_f) {
66 return Ok(vec![AtomFactorPair::Const(lhs_f * rhs_f)]);
67 }
68
69 let (const_, non_const_expr) = if let Some(lhs) = lhs_f {
71 (lhs, rhs.collect_pairs()?)
72 } else {
73 (rhs_f.unwrap(), lhs.collect_pairs()?)
74 };
75
76 Ok(non_const_expr
78 .into_iter()
79 .map(|atom| match atom {
80 AtomFactorPair::Atom(a, f) => AtomFactorPair::Atom(a, f * const_),
81 AtomFactorPair::Param(p, f) => AtomFactorPair::Param(p, f * const_),
82 AtomFactorPair::Const(f) => AtomFactorPair::Const(f * const_), })
84 .collect())
85 }
86 },
87 }
88 }
89
90 pub fn try_to_fraction(self) -> Option<Fraction> {
92 match self {
93 NoDivIntegerExpr::Atom(_) | NoDivIntegerExpr::Param(_) => None,
94 NoDivIntegerExpr::Frac(f) => Some(f),
95 NoDivIntegerExpr::BinaryExpr(lhs, op, rhs) => {
96 let lhs = lhs.try_to_fraction()?;
97 let rhs = rhs.try_to_fraction()?;
98
99 match op {
100 NoDivIntegerOp::Add => Some(lhs + rhs),
101 NoDivIntegerOp::Mul => Some(lhs * rhs),
102 }
103 }
104 }
105 }
106}
107
108impl<T: Atomic> TryFrom<NoDivIntegerExpr<T>> for LinearIntegerExpr<T> {
109 type Error = ConstraintRewriteError;
110
111 fn try_from(value: NoDivIntegerExpr<T>) -> Result<Self, Self::Error> {
112 value.split_into_factor_pairs()
113 }
114}
115
116impl<T: Display + Debug + Clone> LinearIntegerExpr<T> {
117 pub fn get_const_factor(&self) -> Fraction {
123 self.pairs
124 .iter()
125 .filter_map(|pair| match pair {
126 AtomFactorPair::Const(f) => Some(*f),
127 _ => None,
128 })
129 .fold(Fraction::from(0), |acc, f| acc + f)
130 }
131
132 pub fn get_atom_factor_pairs(&self) -> impl Iterator<Item = (&T, &Fraction)> {
138 self.pairs
139 .iter()
140 .filter(|pair| matches!(pair, AtomFactorPair::Atom(_, _)))
141 .map(|pair| match pair {
142 AtomFactorPair::Atom(a, f) => (a, f),
143 _ => unreachable!(),
144 })
145 }
146
147 pub fn get_param_factor_pairs(&self) -> impl Iterator<Item = (&Parameter, &Fraction)> {
153 self.pairs
154 .iter()
155 .filter(|pair| matches!(pair, AtomFactorPair::Param(_, _)))
156 .map(|pair| match pair {
157 AtomFactorPair::Param(p, f) => (p, f),
158 _ => unreachable!(),
159 })
160 }
161}
162
163impl<T: Display + Debug + Clone> Display for AtomFactorPair<T> {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 match self {
166 AtomFactorPair::Atom(a, frac) => write!(f, "{a} * {frac}"),
167 AtomFactorPair::Param(p, frac) => write!(f, "{p} * {frac}"),
168 AtomFactorPair::Const(c) => write!(f, "{c}"),
169 }
170 }
171}
172
173impl<T: Display + Debug + Clone> Display for LinearIntegerExpr<T> {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 let mut it = self.pairs.iter();
176 if let Some(pair) = it.next() {
177 write!(f, "{pair}")?;
178 }
179
180 it.try_for_each(|pair| write!(f, " + {pair}"))
181 }
182}
183
#[cfg(test)]
mod test {
    use crate::expressions::Variable;

    use super::*;

    #[test]
    fn test_simple_frac() {
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::Frac(2.into());

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Const(2.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();

        assert_eq!(got, expected);
    }

    #[test]
    fn simple_addition_no_mul() {
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(1.into())),
            NoDivIntegerOp::Add,
            Box::new(NoDivIntegerExpr::Param(Parameter::new("n"))),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Const(1.into()),
                AtomFactorPair::Param(Parameter::new("n"), 1.into()),
            ],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_simple_mul_lhs_const() {
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(1.into())),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Atom(Variable::new("x"), 1.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_simple_mul_rhs_const() {
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Frac(1.into())),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Atom(Variable::new("x"), 1.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_simple_mul_const_param() {
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(1.into())),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Param(Parameter::new("n"))),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Param(Parameter::new("n"), 1.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_simple_mul_both_const() {
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(1.into())),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Frac(5.into())),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Const(5.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_simple_mul_both_non_const() {
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Param(Parameter::new("n"))),
        );

        let got = LinearIntegerExpr::try_from(expr);
        assert!(got.is_err());
        assert!(matches!(
            got.unwrap_err(),
            ConstraintRewriteError::NotLinearArithmetic
        ));
    }

    #[test]
    fn test_mul_with_compound_const_operand() {
        // (1 + 2) * x == x * 3: the constant operand is itself a compound expression.
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::BinaryExpr(
                Box::new(NoDivIntegerExpr::Frac(1.into())),
                NoDivIntegerOp::Add,
                Box::new(NoDivIntegerExpr::Frac(2.into())),
            )),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Atom(Variable::new("x"), 3.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_mul_distributes_over_sum() {
        // (x + 1) * 2 == x * 2 + 2
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::BinaryExpr(
                Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
                NoDivIntegerOp::Add,
                Box::new(NoDivIntegerExpr::Frac(1.into())),
            )),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Frac(2.into())),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Atom(Variable::new("x"), 2.into()),
                AtomFactorPair::Const(2.into()),
            ],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_nested_mul_multiplies_factors() {
        // 2 * (3 * x) == x * 6
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(2.into())),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::BinaryExpr(
                Box::new(NoDivIntegerExpr::Frac(3.into())),
                NoDivIntegerOp::Mul,
                Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
            )),
        );

        let expected: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Atom(Variable::new("x"), 6.into())],
        };

        let got = LinearIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_nested_non_linear_mul_is_rejected() {
        // 2 * (x * n) is not linear, even though the outer product has a constant operand.
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(2.into())),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::BinaryExpr(
                Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
                NoDivIntegerOp::Mul,
                Box::new(NoDivIntegerExpr::Param(Parameter::new("n"))),
            )),
        );

        let got = LinearIntegerExpr::try_from(expr);
        assert!(matches!(
            got,
            Err(ConstraintRewriteError::NotLinearArithmetic)
        ));
    }

    #[test]
    fn test_try_to_fraction() {
        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::Frac(5.into());
        assert_eq!(expr.try_to_fraction(), Some(5.into()));

        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::Atom(Variable::new("x"));
        assert_eq!(expr.try_to_fraction(), None);

        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(5.into())),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Frac(3.into())),
        );
        assert_eq!(expr.try_to_fraction(), Some(15.into()));

        let expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(2.into())),
            NoDivIntegerOp::Add,
            Box::new(NoDivIntegerExpr::Frac(5.into())),
        );
        assert_eq!(expr.try_to_fraction(), Some(7.into()))
    }

    #[test]
    fn test_get_atom_factor_pairs() {
        let expr: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Atom(Variable::new("x"), 5.into()),
                AtomFactorPair::Param(Parameter::new("p"), 3.into()),
                AtomFactorPair::Const(7.into()),
            ],
        };

        let pairs: Vec<(&Variable, &Fraction)> = expr.get_atom_factor_pairs().collect();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0], (&Variable::new("x"), &5.into()));
    }

    #[test]
    fn test_get_param_factor_pairs() {
        let expr: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Atom(Variable::new("x"), 5.into()),
                AtomFactorPair::Param(Parameter::new("p"), 3.into()),
                AtomFactorPair::Const(7.into()),
            ],
        };

        let pairs: Vec<(&Parameter, &Fraction)> = expr.get_param_factor_pairs().collect();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0], (&Parameter::new("p"), &3.into()));
    }

    #[test]
    fn test_get_const_factor() {
        let expr: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Atom(Variable::new("x"), 5.into()),
                AtomFactorPair::Param(Parameter::new("p"), 3.into()),
                AtomFactorPair::Const(7.into()),
                AtomFactorPair::Const(5.into()),
            ],
        };

        assert_eq!(expr.get_const_factor(), 12.into());
    }

    #[test]
    fn test_get_const_factor_without_const_terms() {
        let expr: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Atom(Variable::new("x"), 5.into()),
                AtomFactorPair::Param(Parameter::new("p"), 3.into()),
            ],
        };

        assert_eq!(expr.get_const_factor(), 0.into());
    }

    #[test]
    fn test_display_linear_integer_arithmetic_expr() {
        let expr: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![
                AtomFactorPair::Atom(Variable::new("x"), 5.into()),
                AtomFactorPair::Param(Parameter::new("p"), 3.into()),
                AtomFactorPair::Const(7.into()),
            ],
        };
        assert_eq!(expr.to_string(), "x * 5 + p * 3 + 7");

        let expr: LinearIntegerExpr<Variable> = LinearIntegerExpr {
            pairs: vec![AtomFactorPair::Const(7.into())],
        };
        assert_eq!(expr.to_string(), "7");
    }
}