1use crate::expressions::{Atomic, Parameter, fraction::Fraction};
14use std::fmt::{Debug, Display};
15
16use super::{
17 ConstraintRewriteError,
18 remove_minus::{NonMinusIntegerExpr, NonMinusIntegerOp},
19};
20
/// Binary operators that may appear in a [`NoDivIntegerExpr`].
///
/// Unlike [`NonMinusIntegerOp`], there is no division variant: division nodes
/// are eliminated by [`NonMinusIntegerExpr::remove_div`].
///
/// The enum is fieldless, so it also derives `Eq` and `Hash` and can be used as
/// a key in hash-based collections.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum NoDivIntegerOp {
    /// Addition (`+`).
    Add,
    /// Multiplication (`*`).
    Mul,
}
29
30impl Display for NoDivIntegerOp {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 NoDivIntegerOp::Add => write!(f, "+"),
34 NoDivIntegerOp::Mul => write!(f, "*"),
35 }
36 }
37}
38
39impl From<NonMinusIntegerOp> for NoDivIntegerOp {
40 fn from(op: NonMinusIntegerOp) -> Self {
44 match op {
45 NonMinusIntegerOp::Add => NoDivIntegerOp::Add,
46 NonMinusIntegerOp::Mul => NoDivIntegerOp::Mul,
47 NonMinusIntegerOp::Div => panic!("Division operator not allowed"),
48 }
49 }
50}
51
52#[derive(Debug, PartialEq, Clone)]
53pub enum NoDivIntegerExpr<T: Atomic> {
55 Atom(T),
57 Frac(Fraction),
59 Param(Parameter),
61 BinaryExpr(
64 Box<NoDivIntegerExpr<T>>,
65 NoDivIntegerOp,
66 Box<NoDivIntegerExpr<T>>,
67 ),
68}
69
70impl<T> NonMinusIntegerExpr<T>
71where
72 T: Atomic,
73{
74 pub fn remove_div(self) -> Result<NoDivIntegerExpr<T>, ConstraintRewriteError> {
86 match self {
87 NonMinusIntegerExpr::Atom(a) => Ok(NoDivIntegerExpr::Atom(a)),
88 NonMinusIntegerExpr::Const(c) => Ok(NoDivIntegerExpr::Frac(c.into())),
89 NonMinusIntegerExpr::NegConst(c) => {
90 Ok(NoDivIntegerExpr::Frac(-Into::<Fraction>::into(c)))
91 }
92 NonMinusIntegerExpr::Param(parameter) => Ok(NoDivIntegerExpr::Param(parameter)),
93 NonMinusIntegerExpr::BinaryExpr(numerator, NonMinusIntegerOp::Div, denominator) => {
94 let denominator = denominator
98 .try_to_fraction()
99 .ok_or(ConstraintRewriteError::NotLinearArithmetic)?;
100
101 if let Some(numerator) = numerator.try_to_fraction() {
103 return Ok(NoDivIntegerExpr::Frac(numerator / denominator));
104 }
105
106 Ok(NoDivIntegerExpr::BinaryExpr(
109 Box::new(NoDivIntegerExpr::Frac(Fraction::from(1) / denominator)),
110 NoDivIntegerOp::Mul,
111 Box::new(numerator.remove_div()?),
112 ))
113 }
114 NonMinusIntegerExpr::BinaryExpr(lhs, op, rhs) => {
115 let lhs = lhs.remove_div()?;
116 let rhs = rhs.remove_div()?;
117 Ok(NoDivIntegerExpr::BinaryExpr(
118 Box::new(lhs),
119 op.into(),
120 Box::new(rhs),
121 ))
122 }
123 }
124 }
125
126 pub fn try_to_fraction(&self) -> Option<Fraction> {
133 match self {
134 NonMinusIntegerExpr::Atom(_) | NonMinusIntegerExpr::Param(_) => None,
135 NonMinusIntegerExpr::Const(c) => Some(Fraction::new(*c, 1, false)),
136 NonMinusIntegerExpr::NegConst(c) => Some(Fraction::new(*c, 1, true)),
137 NonMinusIntegerExpr::BinaryExpr(lhs, op, rhs) => {
138 let lhs = lhs.try_to_fraction()?;
139 let rhs = rhs.try_to_fraction()?;
140
141 match op {
142 NonMinusIntegerOp::Add => Some(lhs + rhs),
143 NonMinusIntegerOp::Mul => Some(lhs * rhs),
144 NonMinusIntegerOp::Div => Some(lhs / rhs),
145 }
146 }
147 }
148 }
149}
150
151impl<T: Atomic> TryFrom<NonMinusIntegerExpr<T>> for NoDivIntegerExpr<T> {
152 type Error = ConstraintRewriteError;
153
154 fn try_from(value: NonMinusIntegerExpr<T>) -> Result<Self, Self::Error> {
155 value.remove_div()
156 }
157}
158
159impl<T> Display for NoDivIntegerExpr<T>
160where
161 T: Atomic,
162{
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 match self {
165 NoDivIntegerExpr::Atom(a) => write!(f, "{a}"),
166 NoDivIntegerExpr::Frac(fraction) => write!(f, "{fraction}"),
167 NoDivIntegerExpr::Param(parameter) => write!(f, "{parameter}"),
168 NoDivIntegerExpr::BinaryExpr(lhs, op, rhs) => write!(f, "({lhs} {op} {rhs})"),
169 }
170 }
171}
172
// Unit tests for the division-removal rewrite (`NonMinusIntegerExpr::remove_div`,
// the `TryFrom` conversion built on it) and for the `Display` impls.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Variable;

    // `3 + (-5)` contains no division: the tree shape is kept and constants are
    // only converted to fractions (no constant folding happens outside `/`).
    #[test]
    fn test_keep_no_div_add() {
        let expr: NonMinusIntegerExpr<Variable> = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Const(3)),
            NonMinusIntegerOp::Add,
            Box::new(NonMinusIntegerExpr::NegConst(5)),
        );

        let expected_expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(3, 1, false))),
            NoDivIntegerOp::Add,
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(5, 1, true))),
        );

        let got_expr = NoDivIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    // `4 + ((-7) * (-2))`: nested division-free nodes are rewritten recursively
    // and keep their operators.
    #[test]
    fn test_keep_no_div_mul() {
        let expr: NonMinusIntegerExpr<Variable> = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Const(4)),
            NonMinusIntegerOp::Add,
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::NegConst(7)),
                NonMinusIntegerOp::Mul,
                Box::new(NonMinusIntegerExpr::NegConst(2)),
            )),
        );

        let expected_expr: NoDivIntegerExpr<Variable> = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(4, 1, false))),
            NoDivIntegerOp::Add,
            Box::new(NoDivIntegerExpr::BinaryExpr(
                Box::new(NoDivIntegerExpr::Frac(Fraction::new(7, 1, true))),
                NoDivIntegerOp::Mul,
                Box::new(NoDivIntegerExpr::Frac(Fraction::new(2, 1, true))),
            )),
        );

        let got_expr = NoDivIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    // `x / 2` becomes `(1/2) * x`: the reciprocal is placed on the left.
    #[test]
    fn test_simple_division() {
        let expr = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Atom(Variable::new("x"))),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::Const(2)),
        );

        let expected_expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(1, 2, false))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
        );

        let got_expr = expr.remove_div().unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    // `(3 + 2) / (-4)`: both sides are constant, so the quotient folds into the
    // single fraction -5/4.
    #[test]
    fn test_division_both_const() {
        let expr: NonMinusIntegerExpr<Variable> = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Const(3)),
                NonMinusIntegerOp::Add,
                Box::new(NonMinusIntegerExpr::Const(2)),
            )),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::NegConst(4)),
        );

        let expected_expr = NoDivIntegerExpr::Frac(Fraction::new(5, 4, true));

        let got_expr = expr.remove_div().unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    // `x / (1 / 2)`: the divisor is itself a constant quotient (1/2), whose
    // reciprocal 2/1 multiplies `x`.
    #[test]
    fn test_double_division() {
        let expr = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Atom(Variable::new("x"))),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Const(1)),
                NonMinusIntegerOp::Div,
                Box::new(NonMinusIntegerExpr::Const(2)),
            )),
        );

        let expected_expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(2, 1, false))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
        );

        let got_expr = expr.remove_div().unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    // `x / (3 + 2)`: a compound constant divisor is folded (to 5) before
    // taking its reciprocal.
    #[test]
    fn test_simple_from_var() {
        let expr = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Atom(Variable::new("x"))),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Const(3)),
                NonMinusIntegerOp::Add,
                Box::new(NonMinusIntegerExpr::Const(2)),
            )),
        );

        let expected_expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(1, 5, false))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
        );

        let got_expr = NoDivIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    // `n / (5 * 2)`: a parameter may appear in the numerator; the divisor folds
    // to 10.
    #[test]
    fn try_simple_from_param() {
        let expr: NonMinusIntegerExpr<Variable> = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Param(Parameter::new("n"))),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Const(5)),
                NonMinusIntegerOp::Mul,
                Box::new(NonMinusIntegerExpr::Const(2)),
            )),
        );

        let expected_expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(1, 10, false))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Param(Parameter::new("n"))),
        );
        let got_expr = NoDivIntegerExpr::try_from(expr).unwrap();
        assert_eq!(got_expr, expected_expr);
    }

    // `1 / x`: dividing by a variable is not linear and must be rejected.
    #[test]
    fn test_error_on_div_by_var() {
        let expr = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Const(1)),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::Atom(Variable::new("x"))),
        );
        let e = NoDivIntegerExpr::try_from(expr);
        assert!(e.is_err());
        assert!(matches!(
            e.unwrap_err(),
            ConstraintRewriteError::NotLinearArithmetic
        ));
    }

    // `1 / (n + 5)`: a divisor containing a parameter is rejected as well.
    #[test]
    fn test_error_on_div_by_param() {
        let expr: NonMinusIntegerExpr<Variable> = NonMinusIntegerExpr::BinaryExpr(
            Box::new(NonMinusIntegerExpr::Const(1)),
            NonMinusIntegerOp::Div,
            Box::new(NonMinusIntegerExpr::BinaryExpr(
                Box::new(NonMinusIntegerExpr::Param(Parameter::new("n"))),
                NonMinusIntegerOp::Add,
                Box::new(NonMinusIntegerExpr::Const(5)),
            )),
        );

        let e = NoDivIntegerExpr::try_from(expr);
        assert!(e.is_err());
        assert!(matches!(
            e.unwrap_err(),
            ConstraintRewriteError::NotLinearArithmetic
        ));
    }

    // Binary nodes are printed fully parenthesised with spaced operators.
    #[test]
    fn test_display_no_div_integer_expr() {
        let expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
            NoDivIntegerOp::Mul,
            Box::new(NoDivIntegerExpr::Frac(Fraction::new(1, 2, false))),
        );
        assert_eq!(format!("{expr}"), "(x * 1/2)");

        let expr = NoDivIntegerExpr::BinaryExpr(
            Box::new(NoDivIntegerExpr::Atom(Variable::new("x"))),
            NoDivIntegerOp::Add,
            Box::new(NoDivIntegerExpr::Param(Parameter::new("p"))),
        );

        assert_eq!(format!("{expr}"), "(x + p)");
    }

    // Converting the division operator directly has no valid target and panics.
    #[test]
    #[should_panic]
    fn test_no_div_op_div() {
        let _ = NoDivIntegerOp::from(NonMinusIntegerOp::Div);
    }
}