taco_threshold_automaton/lia_threshold_automaton/
integer_thresholds.rs

//! This module defines abstract types for linear arithmetic constraints over
//! parameters, variables and other [`Atomic`] values that can be safely
//! encoded into integer arithmetic SMT constraints.
//!
//! A type that can be safely encoded into integer constraints implements the
//! [`IntoNoDivBooleanExpr`] trait. Expressions implementing this trait are
//! always scaled by the least common multiple of the denominators appearing
//! in the expression, thus removing the rational constants from the encoded
//! expression while preserving the satisfying assignments.
//!
//! This is important because SMT-LIB 2 defines integer division analogously to
//! how integer division is usually implemented on a computer, i.e., the result
//! of dividing two integers is again an integer
//! (see [Theory of `Ints`](https://smt-lib.org/theories-Ints.shtml)).
//!
//! This means, for example, that the expression `1 / 2 == 0` is `true`, which
//! is not the intended semantics here. Therefore, if fractions appear as part
//! of boolean expressions, these expressions are scaled until all fractions
//! can be represented as integers.
//!
//! For example, an expression of the form `x = 1/3 * n` will be encoded as
//! `3 * x = n`.
22
23use std::{
24    collections::{BTreeMap, HashMap},
25    fmt::{self},
26};
27
28use crate::{
29    expressions::{
30        Atomic, BooleanExpression, ComparisonOp, IntegerExpression, Location, Parameter, Variable,
31        fraction::Fraction,
32    },
33    lia_threshold_automaton::{
34        ConstraintRewriteError,
35        general_to_lia::classify_into_lia::split_pairs_into_atom_and_threshold,
36    },
37};
38
39/// Weighted sum of [`Atomic`] values
40///
41/// A weighted sum is an expression of the form `c_1 * v_1 + ... + c_n * v_n`
42/// where `v_1, ..., v_n` are atomic values and `c_1, ..., c_n` are real
43/// valued coefficients to these variables.
44#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
45pub struct WeightedSum<T>
46where
47    T: Atomic,
48{
49    /// Map of atomic values and their weights
50    weight_map: BTreeMap<T, Fraction>,
51}
52
53impl<T: Atomic> WeightedSum<T> {
54    /// Check if the weighted sum is empty / equal to 0
55    pub fn is_zero(&self) -> bool {
56        self.weight_map.is_empty()
57    }
58
59    /// Check whether all coefficients are already integers
60    ///
61    /// This checks whether all coefficients in the weighted sum are integers,
62    /// i.e., they can be converted to integer values without a loss in
63    /// precision.
64    ///
65    /// # Example
66    ///
67    ///
68    /// ```rust
69    /// use taco_threshold_automaton::{
70    ///     expressions::{Variable, fraction::Fraction},
71    ///     lia_threshold_automaton::integer_thresholds::WeightedSum
72    /// };
73    ///
74    /// let ws : WeightedSum<Variable> = WeightedSum::new([
75    ///     (Variable::new("var2"), 1),
76    /// ]);
77    /// assert!(ws.is_integer_form());
78    ///
79    /// let ws : WeightedSum<Variable> = WeightedSum::new([
80    ///     (Variable::new("var2"), Fraction::new(1, 2, false)),
81    /// ]);
82    /// assert!(!ws.is_integer_form());
83    /// ```
84    pub fn is_integer_form(&self) -> bool {
85        self.weight_map.values().all(|coef| coef.is_integer())
86    }
87
88    /// Create a new [`WeightedSum`]
89    ///
90    /// Creates a new [`WeightedSum`], filtering all components with a factor of `0`
91    pub fn new<F, I, V>(weight_map: I) -> Self
92    where
93        F: Into<Fraction>,
94        V: Into<T>,
95        I: IntoIterator<Item = (V, F)>,
96    {
97        // remove all entries with a coefficient of zero
98        let weight_map = weight_map
99            .into_iter()
100            .map(|(v, c)| (v.into(), c.into()))
101            .filter(|(_, coef)| *coef != Fraction::from(0))
102            .collect();
103
104        Self { weight_map }
105    }
106
107    /// Create a new empty [`WeightedSum`]
108    pub fn new_empty() -> Self {
109        Self {
110            weight_map: BTreeMap::new(),
111        }
112    }
113
114    /// Get the factor for `v` if it exists
115    ///
116    /// # Example
117    ///
118    /// ```
119    /// use taco_threshold_automaton::{
120    ///     expressions::{Variable, fraction::Fraction},
121    ///     lia_threshold_automaton::integer_thresholds::WeightedSum
122    /// };
123    ///
124    /// let ws : WeightedSum<Variable> = WeightedSum::new([
125    ///     (Variable::new("var1"), 1),
126    /// ]);
127    /// assert_eq!(ws.get_factor(&Variable::new("var1")), Some(&Fraction::from(1)));
128    /// ```
129    pub fn get_factor(&self, v: &T) -> Option<&Fraction> {
130        self.weight_map.get(v)
131    }
132
133    /// Get the least common multiple across all denominators of the
134    /// coefficients
135    ///
136    /// The least common multiple is computed across all denominators of the
137    /// coefficients in the [`WeightedSum`].
138    ///
139    /// This can be useful for scaling such that all coefficients are integers.
140    ///
141    /// # Example
142    ///
143    /// ```
144    /// use taco_threshold_automaton::{
145    ///     expressions::{Variable, fraction::Fraction},
146    ///     lia_threshold_automaton::integer_thresholds::WeightedSum
147    /// };
148    ///
149    /// let ws : WeightedSum<Variable> = WeightedSum::new([
150    ///     (Variable::new("var1"), 1),
151    /// ]);
152    /// assert_eq!(ws.get_lcm_of_denominators(), 1);
153    ///
154    /// let ws : WeightedSum<Variable> = WeightedSum::new([
155    ///     (Variable::new("var1"), Fraction::from(1)),
156    ///     (Variable::new("var2"), Fraction::new(1, 2, false)),
157    /// ]);
158    /// assert_eq!(ws.get_lcm_of_denominators(), 2);
159    /// ```
160    pub fn get_lcm_of_denominators(&self) -> u32 {
161        self.weight_map
162            .values()
163            .fold(1, |acc, coef| num::Integer::lcm(&acc, &coef.denominator()))
164    }
165
166    /// Scale the weighted sum by a factor
167    ///
168    /// # Example
169    ///
170    /// ```rust
171    /// use taco_threshold_automaton::{
172    ///     expressions::Variable,
173    ///     lia_threshold_automaton::integer_thresholds::WeightedSum
174    /// };
175    ///
176    /// let mut ws : WeightedSum<Variable> = WeightedSum::new([
177    ///     (Variable::new("var2"), 1),
178    ///     (Variable::new("var2"), 3),
179    /// ]);
180    /// let scaled_ws = WeightedSum::new([
181    ///     (Variable::new("var2"), 3),
182    ///     (Variable::new("var2"), 9),
183    /// ]);
184    ///
185    /// ws.scale(3.into());
186    ///
187    /// assert_eq!(ws, scaled_ws);
188    /// ```
189    pub fn scale(&mut self, factor: Fraction) {
190        if factor == Fraction::from(0) {
191            self.weight_map.clear();
192            return;
193        }
194
195        for coef in self.weight_map.values_mut() {
196            *coef *= factor;
197        }
198    }
199
200    /// Check whether `t` is part of the [`WeightedSum`] (and the factor is not 0)
201    ///
202    /// # Example
203    ///
204    /// ```rust
205    /// use taco_threshold_automaton::{
206    ///     expressions::Variable,
207    ///     lia_threshold_automaton::integer_thresholds::WeightedSum
208    /// };
209    ///
210    /// let ws = WeightedSum::new([
211    ///     (Variable::new("var"), 1),
212    ///     (Variable::new("zvar"), 0),
213    /// ]);
214    ///
215    /// assert!(ws.contains(&Variable::new("var")));
216    /// assert!(!ws.contains(&Variable::new("unknown")));
217    /// assert!(!ws.contains(&Variable::new("zvar")));
218    /// ```
219    pub fn contains(&self, t: &T) -> bool {
220        self.weight_map.contains_key(t)
221    }
222
223    /// Encode the weighted sum into an [`IntegerExpression`]
224    ///
225    /// This method will return the [`WeightedSum`] encoded as an
226    /// [`IntegerExpression`]. In case the weighted sum contains a fraction that
227    /// is not in an integer form this function **will encode the fraction**
228    /// without panicking or returning an error
229    fn get_integer_expression<S>(&self) -> IntegerExpression<S>
230    where
231        T: Into<IntegerExpression<S>>,
232        S: Atomic,
233    {
234        self.weight_map
235            .iter()
236            .fold(IntegerExpression::Const(0), |acc, (v, c)| {
237                debug_assert!(c != &Fraction::from(0));
238                debug_assert!(c.is_integer());
239
240                let mut expr = v.clone().into();
241                if *c != Fraction::from(1) {
242                    expr = IntegerExpression::from(*c) * expr;
243                }
244
245                // remove initial accumulator value
246                if acc == IntegerExpression::Const(0) {
247                    return expr;
248                }
249                acc + expr
250            })
251    }
252
253    /// Returns an iterator over all atoms in the weighted sum
254    pub fn get_atoms_appearing(&self) -> impl Iterator<Item = &T> {
255        self.weight_map.keys()
256    }
257}
258
259impl<T, F, I> From<I> for WeightedSum<T>
260where
261    T: Atomic,
262    F: Into<Fraction> + Clone,
263    I: IntoIterator<Item = (T, F)>,
264{
265    fn from(value: I) -> Self {
266        Self::new(value)
267    }
268}
269
270impl<T: Atomic> fmt::Display for WeightedSum<T> {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        //? This works but is not very nice: we make the string representation
273        //? deterministic by sorting by their string representation.
274        let mut it = self.weight_map.iter().collect::<Vec<_>>();
275        it.sort_by_key(|(v, _)| v.to_string());
276        let mut it = it.into_iter();
277
278        if let Some((v, coef)) = it.next() {
279            display_factor_pair_omit_one(f, coef, v)?;
280        }
281
282        it.try_for_each(|(v, coef)| {
283            write!(f, " + ")?;
284            display_factor_pair_omit_one(f, coef, v)
285        })?;
286
287        Ok(())
288    }
289}
290
291impl<'a, T: Atomic> IntoIterator for &'a WeightedSum<T> {
292    type Item = (&'a T, &'a Fraction);
293
294    type IntoIter = std::collections::btree_map::Iter<'a, T, Fraction>;
295
296    fn into_iter(self) -> Self::IntoIter {
297        self.weight_map.iter()
298    }
299}
300
301/// Display a scaled pair but omit factor if it is one
302///
303/// This function converts the pair `factor` * `x` to the string `factor * x`
304/// but will omit `factor *` if factor is equal to one
305fn display_factor_pair_omit_one<T: fmt::Display>(
306    f: &mut std::fmt::Formatter<'_>,
307    factor: &Fraction,
308    x: &T,
309) -> std::fmt::Result {
310    if factor == &1.into() {
311        write!(f, "{x}")
312    } else {
313        write!(f, "{factor} * {x}")
314    }
315}
316
/// [`WeightedSum`] of [`Parameter`]s with additional constant summand
///
/// This struct represents a sum over [`Parameter`]s of the form
/// `c_1 * p_1 + ... + c_n * p_n + c`
/// where `p_1, ..., p_n` are [`Parameter`]s and `c_1, ..., c_n, c` are rational
/// numbers represented by [`Fraction`]s.
// NOTE: the field order determines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Threshold {
    /// The parameter part `c_1 * p_1 + ... + c_n * p_n`
    weighted_parameters: WeightedSum<Parameter>,
    /// The constant summand `c`
    constant: Fraction,
}
328
329impl Threshold {
330    /// Create a new [`Threshold`]
331    pub fn new<S, T>(weighted_parameters: S, constant: T) -> Self
332    where
333        S: Into<WeightedSum<Parameter>>,
334        T: Into<Fraction>,
335    {
336        Self {
337            weighted_parameters: weighted_parameters.into(),
338            constant: constant.into(),
339        }
340    }
341
342    /// This function is designed to rewrite an comparison expression into a form
343    /// where the returned `HashMap<T, Fraction>` forms the new lhs of the equation
344    /// and the returned threshold is the right hand side of the equation
345    pub fn from_integer_comp_expr<T: Atomic>(
346        lhs: IntegerExpression<T>,
347        rhs: IntegerExpression<T>,
348    ) -> Result<(HashMap<T, Fraction>, Threshold), ConstraintRewriteError> {
349        split_pairs_into_atom_and_threshold(lhs, rhs)
350    }
351
352    /// Create a new [`Threshold`] from a constant without any parameters
353    pub fn from_const<T: Into<Fraction>>(c: T) -> Self {
354        Self {
355            weighted_parameters: WeightedSum::new_empty(),
356            constant: c.into(),
357        }
358    }
359
360    /// Scale the [`Threshold`] by a rational factor
361    pub fn scale<T: Into<Fraction>>(&mut self, factor: T) {
362        let factor = factor.into();
363        self.weighted_parameters.scale(factor);
364        self.constant *= factor;
365    }
366
367    /// Check whether the [`Threshold`] only contains a constant
368    pub fn is_constant(&self) -> bool {
369        self.weighted_parameters.is_zero()
370    }
371
372    /// Check whether the [`Threshold`] always evaluates to 0, i.e. there are no
373    /// scaled [`Parameter`] and the constant is 0
374    pub fn is_zero(&self) -> bool {
375        self.weighted_parameters.is_zero() && self.constant == Fraction::from(0)
376    }
377
378    /// If the threshold is constant, return the constant
379    ///
380    /// Returns `None` otherwise
381    pub fn get_const(&self) -> Option<Fraction> {
382        if self.weighted_parameters.is_zero() {
383            return Some(self.constant);
384        }
385
386        None
387    }
388
389    /// Add a constant to the threshold
390    pub fn add_const<F: Into<Fraction>>(&mut self, c: F) {
391        self.constant += c.into()
392    }
393
394    /// Subtract a constant from the threshold
395    pub fn sub_const<F: Into<Fraction>>(&mut self, c: F) {
396        self.constant -= c.into()
397    }
398}
399
400impl fmt::Display for Threshold {
401    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
402        if self.weighted_parameters.is_zero() {
403            return write!(f, "{}", self.constant);
404        }
405
406        write!(f, "{}", self.weighted_parameters)?;
407        if self.constant != Fraction::from(0) {
408            write!(f, " + {}", self.constant)?;
409        }
410
411        Ok(())
412    }
413}
414
/// Restricted form of [`ComparisonOp`]
///
/// Only the operators `>=` and `<` are available; each converts into the
/// [`ComparisonOp`] variant of the same name.
// NOTE: the variant order determines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum ThresholdCompOp {
    /// Greater than or equal (`>=`)
    Geq,
    /// Strictly less than (`<`)
    Lt,
}
423
424impl From<ThresholdCompOp> for ComparisonOp {
425    fn from(value: ThresholdCompOp) -> Self {
426        match value {
427            ThresholdCompOp::Geq => ComparisonOp::Geq,
428            ThresholdCompOp::Lt => ComparisonOp::Lt,
429        }
430    }
431}
432
433impl fmt::Display for ThresholdCompOp {
434    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435        match self {
436            ThresholdCompOp::Geq => write!(f, ">="),
437            ThresholdCompOp::Lt => write!(f, "<"),
438        }
439    }
440}
441
442/// [`Threshold`] in combination with a comparison operator
443///
444/// A [`ThresholdConstraint`] represents a [`Threshold`] in combination with a
445/// comparison operator. Thus it can be used to represent equations of the form
446///  `<COMPOP> c_1 * p_1 + ... + c_n * p_n + c`
447/// where `<COMPOP>` is one of the comparison operators `==`, `!=`, `<`, `>`,
448/// `<=`, `>=` (represented by [`ComparisonOp`]).
449///
450/// Most importantly, this struct implements operations like scaling by a
451/// negative number faithfully. Thus it can be useful when transforming
452/// threshold guards.
453#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
454pub struct ThresholdConstraint(ThresholdCompOp, Threshold);
455
456impl ThresholdConstraint {
457    /// Create a new threshold constraint
458    pub fn new<S, T>(op: ThresholdCompOp, weighted_parameters: S, constant: T) -> Self
459    where
460        S: Into<WeightedSum<Parameter>>,
461        T: Into<Fraction>,
462    {
463        let thr = Threshold::new(weighted_parameters, constant);
464        Self(op, thr)
465    }
466
467    /// Create a new threshold constraint
468    pub fn new_from_thr(op: ThresholdCompOp, thr: Threshold) -> Self {
469        Self(op, thr)
470    }
471
472    /// Scale the threshold constraint by a fraction
473    pub fn scale<T: Into<Fraction>>(&mut self, factor: T) {
474        let factor = factor.into();
475        // first execute * -1, then scale
476        if factor.is_negative() {
477            match self.0 {
478                ThresholdCompOp::Geq => {
479                    self.0 = ThresholdCompOp::Lt;
480                    self.1.add_const(1);
481                }
482                ThresholdCompOp::Lt => {
483                    self.0 = ThresholdCompOp::Geq;
484                    self.1.sub_const(1);
485                }
486            }
487        }
488
489        self.1.scale(factor);
490    }
491
492    /// Get the [`ComparisonOp`] used in the threshold
493    pub fn get_op(&self) -> ThresholdCompOp {
494        self.0
495    }
496
497    /// Get the [`Threshold`] of the threshold constraint
498    pub fn get_threshold(&self) -> &Threshold {
499        &self.1
500    }
501}
502
503impl fmt::Display for ThresholdConstraint {
504    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
505        write!(f, "{} {}", self.0, self.1)
506    }
507}
508
/// This struct represents a [`Threshold`] constraint over an object of type `T`
///
/// It is used to, for example, represent a threshold guard over a variable,
/// i.e., a constraint of the form `variable <OP> threshold`.
// NOTE: the field order determines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub(crate) struct ThresholdConstraintOver<T> {
    /// Subject (left-hand side) of the constraint
    variable: T,
    /// Comparison operator and threshold (right-hand side) of the constraint
    thr_constr: ThresholdConstraint,
}
517
518impl<T> ThresholdConstraintOver<T> {
519    /// Create a new symbolic constraint over an object of type `T`
520    pub fn new(variable: T, thr_constr: ThresholdConstraint) -> Self {
521        Self {
522            variable,
523            thr_constr,
524        }
525    }
526
527    /// Check whether the constraint is an upper guard
528    ///
529    /// An upper guard is a guard of the form `< t` or `<= t` or `!= t` or
530    /// `== t`, as it can become disabled when the threshold is reached.
531    pub fn is_upper_guard(&self) -> bool {
532        matches!(self.thr_constr.get_op(), ThresholdCompOp::Lt)
533    }
534
535    /// Check whether the constraint is a lower guard
536    ///
537    /// A lower guard is a guard of the form `> t` or `>= t` or `!= t` or
538    /// `== t`, as it can become enabled when the threshold is reached.
539    pub fn is_lower_guard(&self) -> bool {
540        matches!(self.thr_constr.get_op(), ThresholdCompOp::Geq)
541    }
542
543    /// Get the subject of the constraint
544    pub fn get_variable(&self) -> &T {
545        &self.variable
546    }
547
548    /// Get the threshold of the constraint
549    pub fn get_threshold(&self) -> &Threshold {
550        self.thr_constr.get_threshold()
551    }
552
553    /// Get the threshold constraint of this constraint
554    pub fn get_threshold_constraint(&self) -> &ThresholdConstraint {
555        &self.thr_constr
556    }
557
558    /// Encode the constraint into a BooleanExpression
559    ///
560    /// This encoding will be guaranteed to not include rational constants
561    pub fn encode_to_boolean_expr<S>(&self) -> BooleanExpression<S>
562    where
563        S: Atomic,
564        T: IntoNoDivBooleanExpr<S>,
565        Threshold: IntoNoDivBooleanExpr<S>,
566    {
567        self.variable.encode_comparison_to_boolean_expression(
568            self.thr_constr.get_op().into(),
569            self.thr_constr.get_threshold(),
570        )
571    }
572}
573
574impl<S, T> From<ThresholdConstraintOver<T>> for BooleanExpression<S>
575where
576    S: Atomic,
577    T: IntoNoDivBooleanExpr<S>,
578    Threshold: IntoNoDivBooleanExpr<S>,
579{
580    fn from(val: ThresholdConstraintOver<T>) -> Self {
581        val.encode_to_boolean_expr()
582    }
583}
584
585impl<T> fmt::Display for ThresholdConstraintOver<T>
586where
587    T: fmt::Display,
588{
589    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
590        write!(f, "{} {}", self.variable, self.thr_constr)
591    }
592}
593
594/// Trait for objects that can be encoded into boolean expressions that
595/// do not contain any integer divisions or real numbers.
596///
597/// When using this trait all expressions will be scaled such that there are no
598/// longer real coefficients appearing in the expression, but the set of
599/// satisfying (integer) solutions will stay the same.
600///
601/// This is done by computing the least common multiple (LCM) of all
602/// denominators of rational coefficients and then scaling both sides of a
603/// `ComparisonExpr` to obtain an equivalent expression only containing integers.
604///
605/// Note that using this trait might result in a lot of LCM computations.
606pub trait IntoNoDivBooleanExpr<T>
607where
608    T: Atomic,
609{
610    /// Encode the object into an `IntegerExpression` without divisions
611    /// appearing
612    ///
613    /// **Important:** The scaling factor must be a multiple of the least common
614    /// multiple (LCM) of the expression. That is the relation
615    /// `scaling_factor % self.get_lcm_of_denominators() == 0` must hold.
616    ///
617    /// This function converts the object into an integer expression without any
618    /// divisions or real numbers. To eliminate rational coefficients, the
619    /// expression is therefore scaled by the given factor.    
620    fn get_scaled_integer_expression(&self, scaling_factor: u32) -> IntegerExpression<T>;
621
622    /// Get the lcm across all denominators in the object
623    ///
624    /// This is useful to determine the (least) scaling factor needed to scale
625    /// this expressions such that it does not contain any real numbers.
626    fn get_lcm_of_denominators(&self) -> u32;
627
628    /// Encode the comparison of two expressions into a boolean expression with a
629    /// that does not contain any real numbers.
630    ///
631    /// This functions scales the given thresholds such that in both thresholds
632    /// all factors are integers. The scaling is done by computing the least
633    /// common multiple (LCM) of the denominators.
634    ///
635    /// Resulting comparison: `self <OP> other`
636    fn encode_comparison_to_boolean_expression<Q>(
637        &self,
638        comparison_op: ComparisonOp,
639        other: &Q,
640    ) -> BooleanExpression<T>
641    where
642        Q: IntoNoDivBooleanExpr<T>,
643    {
644        let lcm_self = self.get_lcm_of_denominators();
645        let lcm_other = other.get_lcm_of_denominators();
646        let lcm = num::integer::lcm(lcm_self, lcm_other);
647
648        let self_s = self.get_scaled_integer_expression(lcm);
649        let other_s = other.get_scaled_integer_expression(lcm);
650
651        BooleanExpression::ComparisonExpression(Box::new(self_s), comparison_op, Box::new(other_s))
652    }
653}
654
655impl<T> IntoNoDivBooleanExpr<T> for Parameter
656where
657    Self: Into<IntegerExpression<T>>,
658    T: Atomic,
659{
660    fn get_scaled_integer_expression(&self, scaling_factor: u32) -> IntegerExpression<T> {
661        if scaling_factor == 0 {
662            return IntegerExpression::Const(0);
663        }
664        if scaling_factor == 1 {
665            return IntegerExpression::from(self.clone());
666        }
667
668        IntegerExpression::from(scaling_factor) * IntegerExpression::from(self.clone())
669    }
670
671    fn get_lcm_of_denominators(&self) -> u32 {
672        1 // Parameters do not have denominators
673    }
674}
675
676impl<T> IntoNoDivBooleanExpr<T> for Location
677where
678    T: Atomic,
679    IntegerExpression<T>: From<Self>,
680    IntegerExpression<T>: From<u32>,
681{
682    fn get_scaled_integer_expression(&self, scaling_factor: u32) -> IntegerExpression<T> {
683        if scaling_factor == 0 {
684            return IntegerExpression::Const(0);
685        }
686        if scaling_factor == 1 {
687            return IntegerExpression::from(self.clone());
688        }
689
690        IntegerExpression::<T>::from(scaling_factor) * IntegerExpression::<T>::from(self.clone())
691    }
692
693    fn get_lcm_of_denominators(&self) -> u32 {
694        1 // Locations do not have denominators
695    }
696}
697
698impl<T> IntoNoDivBooleanExpr<T> for Variable
699where
700    T: Atomic,
701    IntegerExpression<T>: From<Self>,
702    IntegerExpression<T>: From<u32>,
703{
704    fn get_scaled_integer_expression(&self, scaling_factor: u32) -> IntegerExpression<T> {
705        if scaling_factor == 0 {
706            return IntegerExpression::Const(0);
707        }
708        if scaling_factor == 1 {
709            return IntegerExpression::from(self.clone());
710        }
711
712        IntegerExpression::<T>::from(scaling_factor) * IntegerExpression::<T>::from(self.clone())
713    }
714
715    fn get_lcm_of_denominators(&self) -> u32 {
716        1 // Variables do not have denominators
717    }
718}
719
720impl<T, S> IntoNoDivBooleanExpr<S> for WeightedSum<T>
721where
722    T: Atomic + IntoNoDivBooleanExpr<S>,
723    S: Atomic,
724    IntegerExpression<S>: From<T>,
725{
726    fn get_scaled_integer_expression(&self, scaling_factor: u32) -> IntegerExpression<S> {
727        let mut scaled = self.clone();
728        scaled.scale(Fraction::from(scaling_factor));
729        debug_assert!(
730            scaled.is_integer_form(),
731            "Scaled weighted sum is not in integer form"
732        );
733
734        scaled.get_integer_expression()
735    }
736
737    fn get_lcm_of_denominators(&self) -> u32 {
738        self.get_lcm_of_denominators()
739    }
740}
741
742impl<T> IntoNoDivBooleanExpr<T> for Threshold
743where
744    WeightedSum<Parameter>: IntoNoDivBooleanExpr<T>,
745    T: Atomic,
746{
747    fn get_scaled_integer_expression(&self, scaling_factor: u32) -> IntegerExpression<T> {
748        let scaled_constant = self.constant * Fraction::from(scaling_factor);
749        debug_assert!(
750            scaled_constant.is_integer(),
751            "Scaled constant is not an integer"
752        );
753
754        if self.weighted_parameters.is_zero() {
755            return IntegerExpression::from(scaled_constant);
756        }
757
758        let intexpr_ws = self
759            .weighted_parameters
760            .get_scaled_integer_expression(scaling_factor);
761
762        if scaled_constant == Fraction::from(0) {
763            return intexpr_ws;
764        }
765
766        intexpr_ws + IntegerExpression::from(scaled_constant)
767    }
768
769    fn get_lcm_of_denominators(&self) -> u32 {
770        let lcm_ws = self.weighted_parameters.get_lcm_of_denominators();
771        let lcm_const = self.constant.denominator();
772        num::integer::lcm(lcm_ws, lcm_const)
773    }
774}
775
776#[cfg(test)]
777mod tests {
778    use std::collections::{BTreeMap, HashSet};
779
780    use crate::{
781        expressions::{
782            BooleanExpression, ComparisonOp, IntegerExpression, Location, Parameter, Variable,
783            fraction::Fraction,
784        },
785        lia_threshold_automaton::integer_thresholds::{
786            IntoNoDivBooleanExpr, Threshold, ThresholdCompOp, ThresholdConstraint,
787            ThresholdConstraintOver, WeightedSum,
788        },
789    };
790
791    #[test]
792    fn test_threshold_is_zero() {
793        let thr = Threshold::new(Vec::<(Parameter, Fraction)>::new(), 0);
794        assert!(thr.is_zero());
795
796        let thr = Threshold::new(BTreeMap::from([(Parameter::new("n"), 1)]), 0);
797        assert!(!thr.is_zero());
798
799        let thr = Threshold::new(Vec::<(Parameter, Fraction)>::new(), 1);
800        assert!(!thr.is_zero());
801
802        let thr = Threshold::new(
803            BTreeMap::from([(Parameter::new("n"), 1)]),
804            -Fraction::from(1),
805        );
806        assert!(!thr.is_zero());
807    }
808
809    #[test]
810    fn test_threshold_is_constant() {
811        let thr = Threshold::new(Vec::<(Parameter, Fraction)>::new(), 0);
812        assert!(thr.is_constant());
813
814        let thr = Threshold::new(BTreeMap::from([(Parameter::new("n"), 1)]), 0);
815        assert!(!thr.is_constant());
816
817        let thr = Threshold::new(Vec::<(Parameter, Fraction)>::new(), 1);
818        assert!(thr.is_constant());
819
820        let thr = Threshold::new(
821            BTreeMap::from([(Parameter::new("n"), 1)]),
822            -Fraction::from(1),
823        );
824        assert!(!thr.is_constant());
825    }
826
827    #[test]
828    fn test_threshold_scale() {
829        let mut thr = Threshold::new(BTreeMap::from([(Parameter::new("n"), 1)]), 0);
830        thr.scale(2);
831
832        let expected = Threshold::new(BTreeMap::from([(Parameter::new("n"), 2)]), 0);
833        assert_eq!(thr, expected);
834
835        let mut thr = Threshold::new(BTreeMap::from([(Parameter::new("n"), 0)]), 1);
836        thr.scale(-Fraction::from(2));
837
838        let expected = Threshold::new(
839            BTreeMap::from([(Parameter::new("n"), Fraction::from(0))]),
840            -Fraction::from(2),
841        );
842        assert_eq!(thr, expected);
843
844        let mut thr = Threshold::new(BTreeMap::from([(Parameter::new("n"), 1)]), 0);
845        thr.scale(0);
846
847        let expected = Threshold::new(BTreeMap::from([(Parameter::new("n"), 0)]), 0);
848        assert_eq!(thr, expected);
849    }
850
851    #[test]
852    fn test_threshold_add_cons() {
853        let mut thr = Threshold::new(Vec::<(Parameter, Fraction)>::new(), 0);
854        assert!(thr.is_constant());
855
856        thr.add_const(1);
857        assert_eq!(thr.get_const().unwrap(), Fraction::from(1));
858
859        thr.sub_const(1);
860        assert_eq!(thr.get_const().unwrap(), Fraction::from(0));
861
862        let mut thr = Threshold::new(BTreeMap::from([(Parameter::new("n"), 1)]), 0);
863        thr.add_const(1);
864        assert_eq!(thr.get_const(), None);
865
866        thr.sub_const(1);
867        assert_eq!(thr.get_const(), None);
868    }
869
870    #[test]
871    fn test_threshold_constr_getters() {
872        let thrc = ThresholdConstraint::new(
873            ThresholdCompOp::Geq,
874            BTreeMap::from([
875                (Parameter::new("n"), Fraction::from(1)),
876                (Parameter::new("m"), -Fraction::from(2)),
877            ]),
878            1,
879        );
880
881        let thr = Threshold::new(
882            [
883                (Parameter::new("n"), Fraction::from(1)),
884                (Parameter::new("m"), -Fraction::from(2)),
885            ],
886            1,
887        );
888
889        assert_eq!(thrc.get_threshold(), &thr)
890    }
891
892    #[test]
893    fn test_threshold_display() {
894        let thr = ThresholdConstraint::new(
895            ThresholdCompOp::Geq,
896            BTreeMap::from([
897                (Parameter::new("n"), Fraction::from(1)),
898                (Parameter::new("m"), -Fraction::from(2)),
899            ]),
900            1,
901        );
902
903        let expected = ">= -2 * m + n + 1";
904        assert_eq!(thr.to_string(), expected);
905
906        let thr =
907            ThresholdConstraint::new(ThresholdCompOp::Geq, Vec::<(Parameter, Fraction)>::new(), 0);
908        assert_eq!(thr.to_string(), ">= 0");
909    }
910
911    #[test]
912    fn test_contains_weighted_sum() {
913        let ws = WeightedSum::new(BTreeMap::from([
914            (Variable::new("var1"), 1),
915            (Variable::new("var2"), 2),
916        ]));
917
918        assert!(ws.contains(&Variable::new("var1")));
919        assert!(!ws.contains(&Variable::new("var3")));
920    }
921
922    #[test]
923    fn test_get_atoms_appearing_ws() {
924        let ws = WeightedSum::new(BTreeMap::from([
925            (Variable::new("var1"), 1),
926            (Variable::new("var2"), 2),
927        ]));
928
929        let got_atoms: HashSet<_> = ws.get_atoms_appearing().cloned().collect();
930        let expected_atoms = HashSet::from([Variable::new("var1"), Variable::new("var2")]);
931        assert_eq!(got_atoms, expected_atoms)
932    }
933
934    #[test]
935    fn test_into_scaled_integer_expression_weighted_sum() {
936        let ws: WeightedSum<Parameter> = WeightedSum::new(BTreeMap::from([
937            (Parameter::new("var1"), 1),
938            (Parameter::new("var2"), 2),
939            (Parameter::new("var3"), 0),
940        ]));
941
942        let expected: IntegerExpression<Parameter> =
943            IntegerExpression::Param(Parameter::new("var1"))
944                + (IntegerExpression::Const(2) * IntegerExpression::Param(Parameter::new("var2")));
945
946        let ws = ws.get_scaled_integer_expression(1);
947
948        assert_eq!(ws, expected);
949
950        let ws: WeightedSum<Parameter> = WeightedSum::new(BTreeMap::from([
951            (Parameter::new("var1"), 1),
952            (Parameter::new("var2"), 2),
953            (Parameter::new("var3"), 0),
954        ]));
955
956        let expected: IntegerExpression<Parameter> = (IntegerExpression::Const(5)
957            * IntegerExpression::Param(Parameter::new("var1")))
958            + (IntegerExpression::Const(10) * IntegerExpression::Param(Parameter::new("var2")));
959
960        let ws = ws.get_scaled_integer_expression(5);
961
962        assert_eq!(ws, expected);
963    }
964
965    #[test]
966    fn test_weighted_sum_into_iter() {
967        let ws: WeightedSum<Variable> = WeightedSum::new([
968            (Variable::new("var1"), 1),
969            (Variable::new("var2"), 2),
970            (Variable::new("var3"), 0),
971        ]);
972
973        let expected: Vec<(Variable, Fraction)> = vec![
974            (Variable::new("var1"), 1.into()),
975            (Variable::new("var2"), 2.into()),
976        ];
977
978        let result: Vec<(Variable, Fraction)> =
979            (&ws).into_iter().map(|(a, b)| (a.clone(), *b)).collect();
980        assert_eq!(result, expected);
981    }
982
983    #[test]
984    fn test_to_boolean_expr_thr_constr_over_var() {
985        let constr = ThresholdConstraintOver::new(
986            Variable::new("v"),
987            ThresholdConstraint::new(ThresholdCompOp::Geq, [(Parameter::new("n"), 1)], 1),
988        );
989
990        let b_expr = BooleanExpression::from(constr);
991
992        let expected_b_expr = BooleanExpression::ComparisonExpression(
993            Box::new(IntegerExpression::Atom(Variable::new("v"))),
994            ComparisonOp::Geq,
995            Box::new(IntegerExpression::Param(Parameter::new("n")) + IntegerExpression::Const(1)),
996        );
997
998        assert_eq!(b_expr, expected_b_expr)
999    }
1000
1001    #[test]
1002    fn test_into_no_div_for_param() {
1003        let param = Parameter::new("n");
1004
1005        assert_eq!(
1006            <Parameter as IntoNoDivBooleanExpr<Parameter>>::get_lcm_of_denominators(&param),
1007            1
1008        );
1009
1010        let expected_int_expr: IntegerExpression<Parameter> =
1011            IntegerExpression::Const(42) * IntegerExpression::Param(param.clone());
1012        assert_eq!(param.get_scaled_integer_expression(42), expected_int_expr)
1013    }
1014
1015    #[test]
1016    fn test_into_no_div_for_loc() {
1017        let loc = Location::new("n");
1018
1019        assert_eq!(
1020            <Location as IntoNoDivBooleanExpr<Location>>::get_lcm_of_denominators(&loc),
1021            1
1022        );
1023
1024        let expected_int_expr: IntegerExpression<Location> =
1025            IntegerExpression::Const(42) * IntegerExpression::Atom(loc.clone());
1026        assert_eq!(loc.get_scaled_integer_expression(42), expected_int_expr)
1027    }
1028
1029    #[test]
1030    fn test_into_no_div_for_var() {
1031        let var = Variable::new("n");
1032
1033        assert_eq!(
1034            <Variable as IntoNoDivBooleanExpr<Variable>>::get_lcm_of_denominators(&var),
1035            1
1036        );
1037
1038        let expected_int_expr: IntegerExpression<Variable> =
1039            IntegerExpression::Const(42) * IntegerExpression::Atom(var.clone());
1040        assert_eq!(var.get_scaled_integer_expression(42), expected_int_expr)
1041    }
1042}