taco_smt_encoder/
expression_encoding.rs

1//! Encoding of expressions into [`SMTExpr`]
2//!
3//! This module provides traits and types to encode boolean or integer
4//! expressions into SMT expressions. The encoding is done using the `easy-smt`
5//! crate, therefore the expressions are encoded into `SExpr` types.
6
7use std::{collections::HashMap, fmt, io};
8
9use std::hash::Hash;
10
11use easy_smt::SExpr;
12use taco_threshold_automaton::{
13    RuleDefinition,
14    expressions::{
15        Atomic, BooleanConnective, BooleanExpression, ComparisonOp, IntegerExpression, IntegerOp,
16        Location, Parameter, Variable,
17    },
18    general_threshold_automaton::Rule,
19};
20
21use crate::{SMTExpr, SMTSolution, SMTSolver, SMTSolverBuilder, SMTSolverContext};
22
23pub mod config_ctx;
24pub mod ctx_mgr;
25pub mod step_ctx;
26pub mod ta_encoding;
27
28/// Trait defining the encoding of type `T` into an S[`SMTExpr`] given a context
29///
30/// If this trait is implemented for a type `T`, expressions of type `T` can be
31/// encoded into an SMT expression, using the given SMT solver and context of
32/// type `C`. This trait does not restrict the type of the context.
33pub trait EncodeToSMT<T, C> {
34    /// Encode the type into an SMT expression
35    ///
36    /// Encode the expression of type `T` into an SMT expression using the
37    /// given SMT solver and context of type `C`.
38    fn encode_to_smt_with_ctx(
39        &self,
40        solver: &SMTSolver,
41        ctx: &C,
42    ) -> Result<SMTExpr, SMTSolverError>;
43}
44
45/// Trait of a context that provides mapping from type declaring SMT variables
46/// `T` to  associated SMT expression
47///
48/// This trait is implemented by contexts that can provide the corresponding
49/// [`SMTExpr`] for a given expression stored in the context. Implementing
50/// this trait allows to encode integer and boolean expressions into
51/// [`SMTExpr`]s.
52pub trait SMTVariableContext<T> {
53    /// Get the corresponding [`SMTExpr`] for the given expression of type
54    /// `T`
55    ///
56    /// This function returns an error if the expression is not stored in the
57    /// context, or it tried to declare a variable and the connection to the
58    /// SMT solver broke.
59    fn get_expr_for(&self, expr: &T) -> Result<SMTExpr, SMTSolverError>;
60
61    /// Get all expressions for which the current context holds an assignment
62    fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a T>
63    where
64        T: 'a;
65
66    /// Get the solution assignment
67    fn get_solution<'a>(
68        &'a self,
69        solver: &mut SMTSolver,
70        res: SMTSolution,
71    ) -> Result<HashMap<T, u32>, SMTSolverError>
72    where
73        T: 'a + Eq + Hash + Clone,
74    {
75        match res {
76            SMTSolution::UNSAT => Err(SMTSolverError::ExtractionFromUnsat),
77            SMTSolution::SAT => {
78                let expr_vec = self
79                    .get_exprs()
80                    .into_iter()
81                    .map(|t| self.get_expr_for(t))
82                    .collect::<Result<Vec<_>, _>>()?;
83
84                let solution = solver.get_value(expr_vec)?;
85
86                solution
87                    .into_iter()
88                    .zip(self.get_exprs())
89                    .map(|((_, sol), t)| {
90                        let solution_value = solver.get_u32(sol).ok_or_else(|| {
91                            let sol = solver.display(sol);
92                            SMTSolverError::SolutionExtractionParseIntError(sol.to_string())
93                        })?;
94
95                        Ok((t.clone(), solution_value))
96                    })
97                    .collect::<Result<HashMap<T, u32>, SMTSolverError>>()
98            }
99        }
100    }
101}
102
103impl<T: Hash + Eq + DeclaresVariable> SMTVariableContext<T> for HashMap<T, SExpr> {
104    fn get_expr_for(&self, expr: &T) -> Result<SMTExpr, SMTSolverError> {
105        self.get(expr)
106            .ok_or_else(|| expr.get_undeclared_error())
107            .copied()
108    }
109
110    fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a T>
111    where
112        T: 'a,
113    {
114        self.keys()
115    }
116}
117
118/// Trait for types that have associated SMT variables
119///
120/// This trait is implemented by types that are usually encoded as variables
121/// when constructing formulas over threshold automata. A type implementing this
122/// trait needs to provide a naming scheme for the smt variables associated with
123/// the type and a method to construct an error if the variable is not declared.
124pub trait DeclaresVariable {
125    /// Get the name of the associated variable for the given index
126    ///
127    /// This function should return the name of the variable associated with
128    /// the type for the given index. The naming scheme **must** result in a
129    /// unique name across all variables used in the SMT solver.
130    ///
131    /// The `index` field is used for variables that are usually indexed. For
132    /// example, along a path, the location variable (representing the number of
133    /// processes in this location) is usually indexed by the configuration.
134    fn get_name(&self, index: u32) -> String;
135
136    /// Declare the variable in the SMT solver
137    ///
138    /// This function declares the variable in the SMT solver and returns the
139    /// SMT expression. Returns an error if the connection to the SMT solver
140    /// broke.
141    fn declare_variable(
142        &self,
143        solver: &mut SMTSolver,
144        index: u32,
145    ) -> Result<SMTExpr, SMTSolverError> {
146        let name = self.get_name(index);
147        Ok(solver.declare_const(&name, solver.int_sort())?)
148    }
149
150    /// Returns an error indicating that `self` was not declared in the current
151    /// SMT context.
152    fn get_undeclared_error(&self) -> SMTSolverError;
153}
154
155impl DeclaresVariable for Location {
156    fn get_name(&self, index: u32) -> String {
157        format!("loc_{}_{}", self.name(), index)
158    }
159
160    fn get_undeclared_error(&self) -> SMTSolverError {
161        SMTSolverError::UndeclaredLocation(self.clone())
162    }
163}
164
165impl DeclaresVariable for Parameter {
166    fn get_name(&self, _index: u32) -> String {
167        format!("param_{}", self.name())
168    }
169
170    fn get_undeclared_error(&self) -> SMTSolverError {
171        SMTSolverError::UndeclaredParameter(self.clone())
172    }
173}
174
175impl DeclaresVariable for Variable {
176    fn get_name(&self, index: u32) -> String {
177        format!("var_{}_{}", self.name(), index)
178    }
179
180    fn get_undeclared_error(&self) -> SMTSolverError {
181        SMTSolverError::UndeclaredVariable(self.clone())
182    }
183}
184
185impl DeclaresVariable for Rule {
186    fn get_name(&self, index: u32) -> String {
187        format!("rule_{}_{}", self.id(), index)
188    }
189
190    fn get_undeclared_error(&self) -> SMTSolverError {
191        SMTSolverError::UndeclaredRule(self.clone())
192    }
193}
194
195/// Trait that allows to extract the assignment of a variable from the solution
196/// found by the SMT solver
197///
198/// This trait is implemented by contexts allowing to extract the solution found
199/// by the SMT solver for a given variable.
200pub trait GetAssignment<T>: SMTVariableContext<T> + SMTSolverContext
201where
202    T: DeclaresVariable,
203{
204    /// Get the assigned solution of the variable
205    ///
206    /// This function returns the assignment of the variable from the SMT
207    /// solver. If the query is unsatisfiable, it returns `None`. If the query
208    /// is satisfiable, it returns the assigned value.
209    ///
210    /// Returns an error if the connection to the SMT solver broke.
211    fn get_assignment(&mut self, res: SMTSolution, var: &T) -> Result<Option<u64>, SMTSolverError> {
212        match res {
213            SMTSolution::SAT => {
214                let expr = self.get_expr_for(var)?;
215                let solver = self.get_smt_solver_mut();
216
217                let solution = solver.get_value(vec![expr])?;
218                debug_assert!(solution.len() == 1);
219                debug_assert!(solution[0].0 == expr);
220
221                let solution_int = solver.get_u64(solution[0].1);
222
223                let sol = solution_int.ok_or_else(|| {
224                    let sol = solver.display(expr);
225                    SMTSolverError::SolutionExtractionParseIntError(sol.to_string())
226                })?;
227
228                Ok(Some(sol))
229            }
230            SMTSolution::UNSAT => Err(SMTSolverError::ExtractionFromUnsat),
231        }
232    }
233}
234
235impl<T, U> EncodeToSMT<IntegerExpression<T>, U> for IntegerExpression<T>
236where
237    U: SMTVariableContext<T> + SMTVariableContext<Parameter>,
238    T: DeclaresVariable + Atomic,
239{
240    fn encode_to_smt_with_ctx(&self, solver: &SMTSolver, ctx: &U) -> Result<SExpr, SMTSolverError> {
241        match self {
242            IntegerExpression::Atom(a) => ctx.get_expr_for(a),
243            IntegerExpression::Const(c) => Ok(solver.numeral(*c)),
244            IntegerExpression::Param(parameter) => ctx.get_expr_for(parameter),
245            IntegerExpression::BinaryExpr(lhs, op, rhs) => {
246                let lhs = lhs.encode_to_smt_with_ctx(solver, ctx)?;
247                let rhs = rhs.encode_to_smt_with_ctx(solver, ctx)?;
248
249                Ok(match op {
250                    IntegerOp::Add => solver.plus(lhs, rhs),
251                    IntegerOp::Sub => solver.sub(lhs, rhs),
252                    IntegerOp::Mul => solver.times(lhs, rhs),
253                    IntegerOp::Div => solver.div(lhs, rhs),
254                })
255            }
256            IntegerExpression::Neg(expr) => {
257                let expr = expr.encode_to_smt_with_ctx(solver, ctx)?;
258                Ok(solver.negate(expr))
259            }
260        }
261    }
262}
263
264impl<T, U> EncodeToSMT<BooleanExpression<T>, U> for BooleanExpression<T>
265where
266    U: SMTVariableContext<T> + SMTVariableContext<Parameter>,
267    T: DeclaresVariable + Atomic,
268{
269    fn encode_to_smt_with_ctx(&self, solver: &SMTSolver, ctx: &U) -> Result<SExpr, SMTSolverError> {
270        match self {
271            BooleanExpression::ComparisonExpression(lhs, op, rhs) => {
272                let lhs = lhs.encode_to_smt_with_ctx(solver, ctx)?;
273                let rhs = rhs.encode_to_smt_with_ctx(solver, ctx)?;
274
275                Ok(match op {
276                    ComparisonOp::Gt => solver.gt(lhs, rhs),
277                    ComparisonOp::Geq => solver.gte(lhs, rhs),
278                    ComparisonOp::Eq => solver.eq(lhs, rhs),
279                    ComparisonOp::Neq => solver.not(solver.eq(lhs, rhs)),
280                    ComparisonOp::Leq => solver.lte(lhs, rhs),
281                    ComparisonOp::Lt => solver.lt(lhs, rhs),
282                })
283            }
284            BooleanExpression::BinaryExpression(lhs, op, rhs) => {
285                let lhs = lhs.encode_to_smt_with_ctx(solver, ctx)?;
286                let rhs = rhs.encode_to_smt_with_ctx(solver, ctx)?;
287
288                Ok(match op {
289                    BooleanConnective::And => solver.and(lhs, rhs),
290                    BooleanConnective::Or => solver.or(lhs, rhs),
291                })
292            }
293            BooleanExpression::Not(expr) => {
294                let expr = expr.encode_to_smt_with_ctx(solver, ctx)?;
295
296                Ok(solver.not(expr))
297            }
298            BooleanExpression::True => Ok(solver.true_()),
299            BooleanExpression::False => Ok(solver.false_()),
300        }
301    }
302}
303
304/// Error occurring in the interaction with the SMT solver
305#[derive(Debug)]
306pub enum SMTSolverError {
307    /// Error from the SMT solver
308    EasySMTErr(io::Error),
309    /// Timeout in the SMT solver
310    SolverTimeout,
311    /// Undeclared Parameter accessed
312    UndeclaredParameter(Parameter),
313    /// Undeclared Location accessed
314    UndeclaredLocation(Location),
315    /// Undeclared Variable accessed
316    UndeclaredVariable(Variable),
317    /// Failed to parse integer from solution
318    SolutionExtractionParseIntError(String),
319    /// Undeclared Rule accessed
320    UndeclaredRule(Rule),
321    /// Attempted to extract solution from an unsatisfiable expression
322    ExtractionFromUnsat,
323    /// Specification is trivially unsatisfiable, reason in spec
324    TriviallyUnsat(String),
325}
326
327impl std::error::Error for SMTSolverError {}
328
329impl fmt::Display for SMTSolverError {
330    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
331        match self {
332            SMTSolverError::EasySMTErr(err) => {
333                write!(f, "Error from connection to SMT solver: {err}")
334            }
335            SMTSolverError::UndeclaredParameter(param) => {
336                write!(f, "Undeclared parameter: {param}")
337            }
338            SMTSolverError::UndeclaredLocation(location) => {
339                write!(f, "Undeclared location: {location}")
340            }
341            SMTSolverError::UndeclaredVariable(variable) => {
342                write!(f, "Undeclared variable: {variable}")
343            }
344            SMTSolverError::SolverTimeout => write!(f, "Timeout in SMT solver"),
345            SMTSolverError::SolutionExtractionParseIntError(s) => write!(
346                f,
347                "Failed to parse SMT solver supplied solution into integer: {s} not an integer"
348            ),
349            SMTSolverError::UndeclaredRule(rule) => write!(f, "Undeclared rule: {rule}"),
350            SMTSolverError::ExtractionFromUnsat => write!(
351                f,
352                "Attempted to extract the solution assignment from an unsatisfiable expression"
353            ),
354            SMTSolverError::TriviallyUnsat(r) => {
355                write!(f, "Specification can never be true. Reason: {r}")
356            }
357        }
358    }
359}
360
361impl From<io::Error> for SMTSolverError {
362    fn from(error: io::Error) -> Self {
363        SMTSolverError::EasySMTErr(error)
364    }
365}
366
367/// A simple static context for encoding expressions related to threshold
368/// automata into SMT constraints and checking satisfiability
369///
370/// This context is used to encode expressions into SMT constraints and check
371/// their satisfiability. It should be only used to encode simple expressions,
372/// and is not designed to handle expressions over multiple configurations.
373pub struct StaticSMTContext {
374    solver: SMTSolver,
375    params: HashMap<Parameter, SMTExpr>,
376    locs: HashMap<Location, SMTExpr>,
377    vars: HashMap<Variable, Vec<SMTExpr>>,
378}
379
380impl StaticSMTContext {
381    /// Create a new static SMT context
382    ///
383    /// This function creates a new static SMT context with the given solver
384    /// builder and declares variables for the given parameters, locations, and
385    /// variables.
386    pub fn new(
387        solver_builder: SMTSolverBuilder,
388        params: impl IntoIterator<Item = Parameter>,
389        locs: impl IntoIterator<Item = Location>,
390        vars: impl IntoIterator<Item = Variable>,
391    ) -> Result<Self, SMTSolverError> {
392        let mut solver = solver_builder.new_solver();
393
394        let params = params
395            .into_iter()
396            .map(|p| {
397                let expr = p.declare_variable(&mut solver, 0)?;
398                Ok((p, expr))
399            })
400            .collect::<Result<HashMap<_, _>, SMTSolverError>>()?;
401
402        let locs = locs
403            .into_iter()
404            .map(|l| {
405                let expr = l.declare_variable(&mut solver, 0)?;
406                Ok((l, expr))
407            })
408            .collect::<Result<HashMap<_, _>, SMTSolverError>>()?;
409
410        let vars = vars
411            .into_iter()
412            .map(|v| {
413                let expr = v.declare_variable(&mut solver, 0)?;
414                Ok((v, vec![expr]))
415            })
416            .collect::<Result<HashMap<_, _>, SMTSolverError>>()?;
417
418        Ok(StaticSMTContext {
419            solver,
420            params,
421            locs,
422            vars,
423        })
424    }
425
426    /// Get the SMT expression for the `true` value
427    pub fn get_true(&self) -> SMTExpr {
428        self.solver.true_()
429    }
430
431    /// Encode the given expression into an SMT expression and return the
432    /// expression
433    pub fn encode_to_smt<T>(&self, expr: &T) -> Result<SMTExpr, SMTSolverError>
434    where
435        T: EncodeToSMT<T, Self>,
436    {
437        expr.encode_to_smt_with_ctx(&self.solver, self)
438    }
439}
440
441impl SMTSolverContext for StaticSMTContext {
442    fn get_smt_solver_mut(&mut self) -> &mut SMTSolver {
443        &mut self.solver
444    }
445
446    fn get_smt_solver(&self) -> &SMTSolver {
447        &self.solver
448    }
449}
450
451impl SMTVariableContext<Parameter> for StaticSMTContext {
452    fn get_expr_for(&self, param: &Parameter) -> Result<SMTExpr, SMTSolverError> {
453        self.params
454            .get(param)
455            .cloned()
456            .ok_or_else(|| SMTSolverError::UndeclaredParameter(param.clone()))
457    }
458
459    fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a Parameter>
460    where
461        Parameter: 'a,
462    {
463        self.params.keys()
464    }
465}
466
467impl SMTVariableContext<Location> for StaticSMTContext {
468    fn get_expr_for(&self, loc: &Location) -> Result<SMTExpr, SMTSolverError> {
469        self.locs
470            .get(loc)
471            .cloned()
472            .ok_or_else(|| SMTSolverError::UndeclaredLocation(loc.clone()))
473    }
474
475    fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a Location>
476    where
477        Location: 'a,
478    {
479        self.locs.keys()
480    }
481}
482
483impl SMTVariableContext<Variable> for StaticSMTContext {
484    fn get_expr_for(&self, var: &Variable) -> Result<SMTExpr, SMTSolverError> {
485        self.vars
486            .get(var)
487            .and_then(|v| v.first().cloned())
488            .ok_or_else(|| SMTSolverError::UndeclaredVariable(var.clone()))
489    }
490
491    fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a Variable>
492    where
493        Variable: 'a,
494    {
495        self.vars.keys()
496    }
497}
498
499impl GetAssignment<Parameter> for StaticSMTContext {}
500impl GetAssignment<Variable> for StaticSMTContext {}
501impl GetAssignment<Location> for StaticSMTContext {}
502
503#[cfg(test)]
504mod tests {
505    use std::vec;
506
507    use easy_smt::Response;
508    use taco_threshold_automaton::general_threshold_automaton::builder::RuleBuilder;
509
510    use super::*;
511
512    #[test]
513    fn test_get_true() {
514        let builder = SMTSolverBuilder::default();
515
516        let mut ctx = StaticSMTContext::new(
517            builder,
518            vec![Parameter::new("p")],
519            vec![Location::new("loc")],
520            vec![Variable::new("x")],
521        )
522        .unwrap();
523
524        let top = ctx.get_true();
525
526        assert_eq!(top, ctx.get_smt_solver_mut().true_())
527    }
528
529    #[test]
530    fn test_rule_variable_name() {
531        let r = RuleBuilder::new(42, Location::new("src"), Location::new("tgt")).build();
532
533        assert!(r.get_name(0).contains("rule"));
534        assert!(r.get_name(0).contains("42"));
535        assert!(r.get_name(0).contains("0"));
536    }
537
538    #[test]
539    fn test_boolean_expr_encoding_true_false() {
540        let builder = SMTSolverBuilder::default();
541
542        let true_expr: BooleanExpression<Variable> = BooleanExpression::True;
543        let false_expr: BooleanExpression<Variable> = BooleanExpression::False;
544
545        let mut ctx = StaticSMTContext::new(
546            builder,
547            vec![Parameter::new("p")],
548            vec![Location::new("loc")],
549            vec![Variable::new("x")],
550        )
551        .unwrap();
552
553        let true_encoded = ctx.encode_to_smt(&true_expr).unwrap();
554        let false_encoded = ctx.encode_to_smt(&false_expr).unwrap();
555
556        let solver = ctx.get_smt_solver_mut();
557        let true_str = solver.display(true_encoded).to_string();
558        let false_str = solver.display(false_encoded).to_string();
559        assert_eq!(true_str, "true");
560        assert_eq!(false_str, "false");
561
562        solver.assert(true_encoded).unwrap();
563        assert_eq!(solver.check().unwrap(), Response::Sat);
564
565        solver.assert(false_encoded).unwrap();
566        assert_eq!(solver.check().unwrap(), Response::Unsat);
567    }
568
569    #[test]
570    fn test_boolean_expr_encoding_two_var_gt() {
571        let builder = SMTSolverBuilder::default();
572
573        let expr = BooleanExpression::ComparisonExpression(
574            Box::new(IntegerExpression::Atom(Variable::new("x"))),
575            ComparisonOp::Gt,
576            Box::new(IntegerExpression::Atom(Variable::new("y"))),
577        );
578
579        let mut ctx = StaticSMTContext::new(
580            builder,
581            vec![Parameter::new("p")],
582            vec![Location::new("loc")],
583            vec![Variable::new("x"), Variable::new("y")],
584        )
585        .unwrap();
586
587        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
588
589        let solver = ctx.get_smt_solver_mut();
590        let encoded_str = solver.display(encoded_expr).to_string();
591        assert_eq!(encoded_str, "(> var_x_0 var_y_0)");
592
593        solver.assert(encoded_expr).unwrap();
594        assert_eq!(solver.check().unwrap(), Response::Sat);
595    }
596
597    #[test]
598    fn test_boolean_expr_encoding_var_eq_const() {
599        let builder = SMTSolverBuilder::default();
600
601        let expr = BooleanExpression::ComparisonExpression(
602            Box::new(IntegerExpression::Atom(Variable::new("x"))),
603            ComparisonOp::Eq,
604            Box::new(IntegerExpression::Const(5)),
605        );
606
607        let mut ctx = StaticSMTContext::new(
608            builder,
609            vec![Parameter::new("p")],
610            vec![Location::new("loc")],
611            vec![Variable::new("x")],
612        )
613        .unwrap();
614
615        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
616
617        let solver = ctx.get_smt_solver_mut();
618        let encoded_str = solver.display(encoded_expr).to_string();
619        assert_eq!(encoded_str, "(= var_x_0 5)");
620
621        solver.assert(encoded_expr).unwrap();
622        assert_eq!(solver.check().unwrap(), Response::Sat);
623    }
624
625    #[test]
626    fn test_boolean_expr_encoding_var_and_const() {
627        let builder = SMTSolverBuilder::default();
628
629        let expr = BooleanExpression::BinaryExpression(
630            Box::new(BooleanExpression::ComparisonExpression(
631                Box::new(IntegerExpression::Atom(Variable::new("x"))),
632                ComparisonOp::Geq,
633                Box::new(IntegerExpression::Const(5)),
634            )),
635            BooleanConnective::And,
636            Box::new(BooleanExpression::ComparisonExpression(
637                Box::new(IntegerExpression::Atom(Variable::new("y"))),
638                ComparisonOp::Lt,
639                Box::new(IntegerExpression::Const(10)),
640            )),
641        );
642
643        let mut ctx = StaticSMTContext::new(
644            builder,
645            vec![Parameter::new("p")],
646            vec![Location::new("loc")],
647            vec![Variable::new("x"), Variable::new("y")],
648        )
649        .unwrap();
650
651        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
652
653        let solver = ctx.get_smt_solver_mut();
654        let encoded_str = solver.display(encoded_expr).to_string();
655        assert_eq!(encoded_str, "(and (>= var_x_0 5) (< var_y_0 10))");
656
657        solver.assert(encoded_expr).unwrap();
658        assert_eq!(solver.check().unwrap(), Response::Sat);
659    }
660
661    #[test]
662    fn test_boolean_expr_encoding_var_or_const() {
663        let builder = SMTSolverBuilder::default();
664
665        let expr = BooleanExpression::BinaryExpression(
666            Box::new(BooleanExpression::ComparisonExpression(
667                Box::new(IntegerExpression::Atom(Variable::new("x"))),
668                ComparisonOp::Leq,
669                Box::new(IntegerExpression::Const(5)),
670            )),
671            BooleanConnective::Or,
672            Box::new(BooleanExpression::ComparisonExpression(
673                Box::new(IntegerExpression::Param(Parameter::new("p"))),
674                ComparisonOp::Neq,
675                Box::new(IntegerExpression::Const(10)),
676            )),
677        );
678
679        let mut ctx = StaticSMTContext::new(
680            builder,
681            vec![Parameter::new("p")],
682            vec![Location::new("loc")],
683            vec![Variable::new("x"), Variable::new("y")],
684        )
685        .unwrap();
686
687        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
688
689        let solver = ctx.get_smt_solver_mut();
690        let encoded_str = solver.display(encoded_expr).to_string();
691        assert_eq!(encoded_str, "(or (<= var_x_0 5) (not (= param_p 10)))");
692
693        solver.assert(encoded_expr).unwrap();
694        assert_eq!(solver.check().unwrap(), Response::Sat);
695    }
696
697    #[test]
698    fn test_boolean_expr_encoding_not_var() {
699        let builder = SMTSolverBuilder::default();
700
701        let expr = BooleanExpression::Not(Box::new(BooleanExpression::ComparisonExpression(
702            Box::new(IntegerExpression::Atom(Location::new("loc"))),
703            ComparisonOp::Eq,
704            Box::new(IntegerExpression::Const(5)),
705        )));
706
707        let mut ctx = StaticSMTContext::new(
708            builder,
709            vec![Parameter::new("p")],
710            vec![Location::new("loc")],
711            vec![Variable::new("x")],
712        )
713        .unwrap();
714
715        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
716
717        let solver = ctx.get_smt_solver_mut();
718        let encoded_str = solver.display(encoded_expr).to_string();
719        assert_eq!(encoded_str, "(not (= loc_loc_0 5))");
720
721        solver.assert(encoded_expr).unwrap();
722        assert_eq!(solver.check().unwrap(), Response::Sat);
723    }
724
725    #[test]
726    fn test_integer_expr_encoding_add() {
727        let builder = SMTSolverBuilder::default();
728
729        let expr = IntegerExpression::BinaryExpr(
730            Box::new(IntegerExpression::Atom(Parameter::new("p"))),
731            IntegerOp::Add,
732            Box::new(IntegerExpression::Const(5)),
733        );
734
735        let mut ctx = StaticSMTContext::new(
736            builder,
737            vec![Parameter::new("p")],
738            vec![Location::new("loc")],
739            vec![Variable::new("x")],
740        )
741        .unwrap();
742
743        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
744
745        let solver = ctx.get_smt_solver_mut();
746        let encoded_str = solver.display(encoded_expr).to_string();
747        assert_eq!(encoded_str, "(+ param_p 5)");
748    }
749
750    #[test]
751    fn test_integer_expr_encoding_sub() {
752        let builder = SMTSolverBuilder::default();
753
754        let expr = IntegerExpression::BinaryExpr(
755            Box::new(IntegerExpression::Atom(Parameter::new("p"))),
756            IntegerOp::Sub,
757            Box::new(IntegerExpression::Const(3)),
758        );
759
760        let mut ctx = StaticSMTContext::new(
761            builder,
762            vec![Parameter::new("p")],
763            vec![Location::new("loc")],
764            vec![Variable::new("x")],
765        )
766        .unwrap();
767
768        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
769
770        let solver = ctx.get_smt_solver_mut();
771        let encoded_str = solver.display(encoded_expr).to_string();
772        assert_eq!(encoded_str, "(- param_p 3)");
773    }
774
775    #[test]
776    fn test_integer_expr_encoding_mul() {
777        let builder = SMTSolverBuilder::default();
778
779        let expr = IntegerExpression::BinaryExpr(
780            Box::new(IntegerExpression::Atom(Parameter::new("p"))),
781            IntegerOp::Mul,
782            Box::new(IntegerExpression::Const(2)),
783        );
784
785        let mut ctx = StaticSMTContext::new(
786            builder,
787            vec![Parameter::new("p")],
788            vec![Location::new("loc")],
789            vec![Variable::new("x")],
790        )
791        .unwrap();
792
793        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
794
795        let solver = ctx.get_smt_solver_mut();
796        let encoded_str = solver.display(encoded_expr).to_string();
797        assert_eq!(encoded_str, "(* param_p 2)");
798    }
799
800    #[test]
801    fn test_integer_expr_encoding_div() {
802        let builder = SMTSolverBuilder::default();
803
804        let expr = IntegerExpression::BinaryExpr(
805            Box::new(IntegerExpression::Atom(Variable::new("x"))),
806            IntegerOp::Div,
807            Box::new(IntegerExpression::Const(4)),
808        );
809
810        let mut ctx = StaticSMTContext::new(
811            builder,
812            vec![Parameter::new("p")],
813            vec![Location::new("loc")],
814            vec![Variable::new("x")],
815        )
816        .unwrap();
817
818        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
819
820        let solver = ctx.get_smt_solver_mut();
821        let encoded_str = solver.display(encoded_expr).to_string();
822        assert_eq!(encoded_str, "(div var_x_0 4)");
823    }
824
825    #[test]
826    fn test_integer_expr_encoding_neg() {
827        let builder = SMTSolverBuilder::default();
828
829        let expr = IntegerExpression::Neg(Box::new(IntegerExpression::Atom(Parameter::new("p"))));
830
831        let mut ctx = StaticSMTContext::new(
832            builder,
833            vec![Parameter::new("p")],
834            vec![Location::new("loc")],
835            vec![Variable::new("x")],
836        )
837        .unwrap();
838
839        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
840
841        let solver = ctx.get_smt_solver_mut();
842        let encoded_str = solver.display(encoded_expr).to_string();
843        assert_eq!(encoded_str, "(- param_p)");
844    }
845
846    #[test]
847    fn test_get_assignment_sat() {
848        let builder = SMTSolverBuilder::default();
849
850        let expr = BooleanExpression::BinaryExpression(
851            Box::new(BooleanExpression::ComparisonExpression(
852                Box::new(IntegerExpression::Atom(Variable::new("x"))),
853                ComparisonOp::Leq,
854                Box::new(IntegerExpression::Const(5)),
855            )),
856            BooleanConnective::And,
857            Box::new(BooleanExpression::ComparisonExpression(
858                Box::new(IntegerExpression::Atom(Variable::new("x"))),
859                ComparisonOp::Gt,
860                Box::new(IntegerExpression::Const(4)),
861            )),
862        );
863
864        let mut ctx = StaticSMTContext::new(
865            builder,
866            vec![Parameter::new("p")],
867            vec![Location::new("loc")],
868            vec![Variable::new("x"), Variable::new("y")],
869        )
870        .unwrap();
871
872        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
873        let res = ctx.assert_and_check_expr(encoded_expr).unwrap();
874
875        let assignment = ctx.get_assignment(res, &Variable::new("x")).unwrap();
876
877        assert_eq!(assignment, Some(5));
878    }
879
880    #[test]
881    fn test_get_assignment_unsat() {
882        let builder = SMTSolverBuilder::default();
883
884        let expr = BooleanExpression::BinaryExpression(
885            Box::new(BooleanExpression::ComparisonExpression(
886                Box::new(IntegerExpression::Atom(Variable::new("x"))),
887                ComparisonOp::Leq,
888                Box::new(IntegerExpression::Const(5)),
889            )),
890            BooleanConnective::And,
891            Box::new(BooleanExpression::ComparisonExpression(
892                Box::new(IntegerExpression::Atom(Variable::new("x"))),
893                ComparisonOp::Gt,
894                Box::new(IntegerExpression::Const(5)),
895            )),
896        );
897
898        let mut ctx = StaticSMTContext::new(
899            builder,
900            vec![Parameter::new("p")],
901            vec![Location::new("loc")],
902            vec![Variable::new("x"), Variable::new("y")],
903        )
904        .unwrap();
905
906        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
907        let res = ctx.assert_and_check_expr(encoded_expr).unwrap();
908
909        let assignment = ctx.get_assignment(res, &Variable::new("x"));
910
911        assert!(assignment.is_err());
912        assert!(matches!(
913            assignment.unwrap_err(),
914            SMTSolverError::ExtractionFromUnsat
915        ));
916    }
917
918    #[test]
919    fn test_get_expr_smt_context() {
920        let builder = SMTSolverBuilder::default();
921
922        let ctx = StaticSMTContext::new(
923            builder,
924            vec![Parameter::new("p")],
925            vec![Location::new("loc")],
926            vec![Variable::new("x")],
927        )
928        .unwrap();
929
930        let exprs = SMTVariableContext::<Parameter>::get_exprs(&ctx)
931            .into_iter()
932            .cloned()
933            .collect::<Vec<Parameter>>();
934        assert_eq!(exprs.len(), 1);
935        assert!(exprs.contains(&Parameter::new("p")));
936
937        let exprs = SMTVariableContext::<Location>::get_exprs(&ctx)
938            .into_iter()
939            .cloned()
940            .collect::<Vec<Location>>();
941        assert_eq!(exprs.len(), 1);
942        assert!(exprs.contains(&Location::new("loc")));
943
944        let exprs = SMTVariableContext::<Variable>::get_exprs(&ctx)
945            .into_iter()
946            .cloned()
947            .collect::<Vec<Variable>>();
948        assert_eq!(exprs.len(), 1);
949        assert!(exprs.contains(&Variable::new("x")));
950    }
951
952    #[test]
953    fn test_get_assignment_no_solution() {
954        let builder = SMTSolverBuilder::default();
955
956        let mut solver = builder.new_solver();
957
958        let vars = vec![Variable::new("x")];
959        let vars_smt = vars
960            .into_iter()
961            .map(|v| (v.clone(), v.declare_variable(&mut solver, 0).unwrap()))
962            .collect::<HashMap<_, _>>();
963
964        let res = solver.assert_and_check_expr(solver.false_());
965        let sol = vars_smt.get_solution(&mut solver, res.unwrap());
966
967        assert!(sol.is_err());
968        assert!(matches!(
969            sol.unwrap_err(),
970            SMTSolverError::ExtractionFromUnsat
971        ));
972    }
973
974    #[test]
975    fn test_parse_int_err_get_sol() {
976        let builder = SMTSolverBuilder::default();
977
978        let mut solver = builder.new_solver();
979
980        let vars = vec![Variable::new("x")];
981        let vars_smt = vars
982            .into_iter()
983            .map(|v| (v.clone(), v.declare_variable(&mut solver, 0).unwrap()))
984            .collect::<HashMap<_, _>>();
985
986        let expr = solver.lt(vars_smt[&Variable::new("x")], solver.numeral(0));
987
988        let res = solver.assert_and_check_expr(expr);
989        let sol = vars_smt.get_solution(&mut solver, res.unwrap());
990
991        assert!(sol.is_err());
992        assert!(matches!(
993            sol.unwrap_err(),
994            SMTSolverError::SolutionExtractionParseIntError(_)
995        ));
996    }
997
998    #[test]
999    fn test_parse_int_err_get_assignment() {
1000        let builder = SMTSolverBuilder::default();
1001
1002        let expr = BooleanExpression::ComparisonExpression(
1003            Box::new(IntegerExpression::Atom(Variable::new("x"))),
1004            ComparisonOp::Lt,
1005            Box::new(IntegerExpression::Const(0)),
1006        );
1007
1008        let mut ctx = StaticSMTContext::new(
1009            builder,
1010            vec![Parameter::new("p")],
1011            vec![Location::new("loc")],
1012            vec![Variable::new("x"), Variable::new("y")],
1013        )
1014        .unwrap();
1015
1016        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
1017        let res = ctx.assert_and_check_expr(encoded_expr).unwrap();
1018
1019        let sol = ctx.get_assignment(res, &Variable::new("x"));
1020
1021        assert!(sol.is_err());
1022        assert!(matches!(
1023            sol.unwrap_err(),
1024            SMTSolverError::SolutionExtractionParseIntError(_)
1025        ));
1026    }
1027
1028    #[test]
1029    fn test_get_undeclared_error() {
1030        let loc = Location::new("loc");
1031        let param = Parameter::new("param");
1032        let var = Variable::new("var");
1033        let rule = RuleBuilder::new(42, loc.clone(), loc.clone()).build();
1034
1035        assert!(matches!(
1036            loc.get_undeclared_error(),
1037            SMTSolverError::UndeclaredLocation(_)
1038        ));
1039        assert!(matches!(
1040            param.get_undeclared_error(),
1041            SMTSolverError::UndeclaredParameter(_)
1042        ));
1043        assert!(matches!(
1044            var.get_undeclared_error(),
1045            SMTSolverError::UndeclaredVariable(_)
1046        ));
1047        assert!(matches!(
1048            rule.get_undeclared_error(),
1049            SMTSolverError::UndeclaredRule(_)
1050        ));
1051    }
1052
1053    #[test]
1054    fn test_from_io_error() {
1055        let io_error = io::Error::other("Some error");
1056        let err = SMTSolverError::from(io_error);
1057
1058        assert!(matches!(err, SMTSolverError::EasySMTErr(_)));
1059    }
1060
1061    #[test]
1062    fn test_display_smt_err() {
1063        let err = SMTSolverError::EasySMTErr(io::Error::other("Some error"));
1064        let display = err.to_string();
1065        assert!(display.contains("Error from connection to SMT solver"));
1066
1067        let err = SMTSolverError::UndeclaredParameter(Parameter::new("p"));
1068        let display = err.to_string();
1069        assert!(display.contains("Undeclared parameter"));
1070        assert!(display.contains(Parameter::new("p").name()));
1071
1072        let err = SMTSolverError::UndeclaredLocation(Location::new("loc"));
1073        let display = err.to_string();
1074        assert!(display.contains("Undeclared location"));
1075        assert!(display.contains(Location::new("loc").name()));
1076
1077        let err = SMTSolverError::UndeclaredVariable(Variable::new("x"));
1078        let display = err.to_string();
1079        assert!(display.contains("Undeclared variable"));
1080        assert!(display.contains(Variable::new("x").name()));
1081
1082        let err = SMTSolverError::SolverTimeout;
1083        let display = err.to_string();
1084        assert!(display.contains("Timeout in SMT solver"));
1085
1086        let err = SMTSolverError::SolutionExtractionParseIntError("not_an_int".to_string());
1087        let display = err.to_string();
1088        assert!(display.contains("Failed to parse SMT solver supplied solution into integer"));
1089        assert!(display.contains("not_an_int"));
1090
1091        let err = SMTSolverError::UndeclaredRule(
1092            RuleBuilder::new(42, Location::new("src"), Location::new("tgt")).build(),
1093        );
1094        let display = err.to_string();
1095        assert!(display.contains("Undeclared rule"));
1096        assert!(display.contains("42"));
1097
1098        let err = SMTSolverError::ExtractionFromUnsat;
1099        let display = err.to_string();
1100        assert!(display.contains(
1101            "Attempted to extract the solution assignment from an unsatisfiable expression"
1102        ));
1103
1104        let err = SMTSolverError::TriviallyUnsat("Some reason".to_string());
1105        let display = err.to_string();
1106        assert!(display.contains("Specification can never be true"));
1107        assert!(display.contains("Some reason"));
1108    }
1109}