1use std::{collections::HashMap, fmt, io};
8
9use std::hash::Hash;
10
11use easy_smt::SExpr;
12use taco_threshold_automaton::{
13 RuleDefinition,
14 expressions::{
15 Atomic, BooleanConnective, BooleanExpression, ComparisonOp, IntegerExpression, IntegerOp,
16 Location, Parameter, Variable,
17 },
18 general_threshold_automaton::Rule,
19};
20
21use crate::{SMTExpr, SMTSolution, SMTSolver, SMTSolverBuilder, SMTSolverContext};
22
23pub mod config_ctx;
24pub mod ctx_mgr;
25pub mod step_ctx;
26pub mod ta_encoding;
27
/// Translation of a value into an SMT expression.
///
/// `T` names the type being encoded (every impl in this file uses the implementing type
/// itself), and `C` is the context used to resolve atoms and parameters to the SMT
/// constants that were declared for them.
pub trait EncodeToSMT<T, C> {
    /// Encodes `self` into an SMT expression built with `solver`, resolving declared items
    /// through `ctx`.
    ///
    /// # Errors
    /// In the impls in this file, returns the error reported by `ctx` when an item referenced
    /// by `self` cannot be resolved.
    fn encode_to_smt_with_ctx(
        &self,
        solver: &SMTSolver,
        ctx: &C,
    ) -> Result<SMTExpr, SMTSolverError>;
}
44
45pub trait SMTVariableContext<T> {
53 fn get_expr_for(&self, expr: &T) -> Result<SMTExpr, SMTSolverError>;
60
61 fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a T>
63 where
64 T: 'a;
65
66 fn get_solution<'a>(
68 &'a self,
69 solver: &mut SMTSolver,
70 res: SMTSolution,
71 ) -> Result<HashMap<T, u32>, SMTSolverError>
72 where
73 T: 'a + Eq + Hash + Clone,
74 {
75 match res {
76 SMTSolution::UNSAT => Err(SMTSolverError::ExtractionFromUnsat),
77 SMTSolution::SAT => {
78 let expr_vec = self
79 .get_exprs()
80 .into_iter()
81 .map(|t| self.get_expr_for(t))
82 .collect::<Result<Vec<_>, _>>()?;
83
84 let solution = solver.get_value(expr_vec)?;
85
86 solution
87 .into_iter()
88 .zip(self.get_exprs())
89 .map(|((_, sol), t)| {
90 let solution_value = solver.get_u32(sol).ok_or_else(|| {
91 let sol = solver.display(sol);
92 SMTSolverError::SolutionExtractionParseIntError(sol.to_string())
93 })?;
94
95 Ok((t.clone(), solution_value))
96 })
97 .collect::<Result<HashMap<T, u32>, SMTSolverError>>()
98 }
99 }
100 }
101}
102
103impl<T: Hash + Eq + DeclaresVariable> SMTVariableContext<T> for HashMap<T, SExpr> {
104 fn get_expr_for(&self, expr: &T) -> Result<SMTExpr, SMTSolverError> {
105 self.get(expr)
106 .ok_or_else(|| expr.get_undeclared_error())
107 .copied()
108 }
109
110 fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a T>
111 where
112 T: 'a,
113 {
114 self.keys()
115 }
116}
117
118pub trait DeclaresVariable {
125 fn get_name(&self, index: u32) -> String;
135
136 fn declare_variable(
142 &self,
143 solver: &mut SMTSolver,
144 index: u32,
145 ) -> Result<SMTExpr, SMTSolverError> {
146 let name = self.get_name(index);
147 Ok(solver.declare_const(&name, solver.int_sort())?)
148 }
149
150 fn get_undeclared_error(&self) -> SMTSolverError;
153}
154
155impl DeclaresVariable for Location {
156 fn get_name(&self, index: u32) -> String {
157 format!("loc_{}_{}", self.name(), index)
158 }
159
160 fn get_undeclared_error(&self) -> SMTSolverError {
161 SMTSolverError::UndeclaredLocation(self.clone())
162 }
163}
164
165impl DeclaresVariable for Parameter {
166 fn get_name(&self, _index: u32) -> String {
167 format!("param_{}", self.name())
168 }
169
170 fn get_undeclared_error(&self) -> SMTSolverError {
171 SMTSolverError::UndeclaredParameter(self.clone())
172 }
173}
174
175impl DeclaresVariable for Variable {
176 fn get_name(&self, index: u32) -> String {
177 format!("var_{}_{}", self.name(), index)
178 }
179
180 fn get_undeclared_error(&self) -> SMTSolverError {
181 SMTSolverError::UndeclaredVariable(self.clone())
182 }
183}
184
185impl DeclaresVariable for Rule {
186 fn get_name(&self, index: u32) -> String {
187 format!("rule_{}_{}", self.id(), index)
188 }
189
190 fn get_undeclared_error(&self) -> SMTSolverError {
191 SMTSolverError::UndeclaredRule(self.clone())
192 }
193}
194
195pub trait GetAssignment<T>: SMTVariableContext<T> + SMTSolverContext
201where
202 T: DeclaresVariable,
203{
204 fn get_assignment(&mut self, res: SMTSolution, var: &T) -> Result<Option<u64>, SMTSolverError> {
212 match res {
213 SMTSolution::SAT => {
214 let expr = self.get_expr_for(var)?;
215 let solver = self.get_smt_solver_mut();
216
217 let solution = solver.get_value(vec![expr])?;
218 debug_assert!(solution.len() == 1);
219 debug_assert!(solution[0].0 == expr);
220
221 let solution_int = solver.get_u64(solution[0].1);
222
223 let sol = solution_int.ok_or_else(|| {
224 let sol = solver.display(expr);
225 SMTSolverError::SolutionExtractionParseIntError(sol.to_string())
226 })?;
227
228 Ok(Some(sol))
229 }
230 SMTSolution::UNSAT => Err(SMTSolverError::ExtractionFromUnsat),
231 }
232 }
233}
234
235impl<T, U> EncodeToSMT<IntegerExpression<T>, U> for IntegerExpression<T>
236where
237 U: SMTVariableContext<T> + SMTVariableContext<Parameter>,
238 T: DeclaresVariable + Atomic,
239{
240 fn encode_to_smt_with_ctx(&self, solver: &SMTSolver, ctx: &U) -> Result<SExpr, SMTSolverError> {
241 match self {
242 IntegerExpression::Atom(a) => ctx.get_expr_for(a),
243 IntegerExpression::Const(c) => Ok(solver.numeral(*c)),
244 IntegerExpression::Param(parameter) => ctx.get_expr_for(parameter),
245 IntegerExpression::BinaryExpr(lhs, op, rhs) => {
246 let lhs = lhs.encode_to_smt_with_ctx(solver, ctx)?;
247 let rhs = rhs.encode_to_smt_with_ctx(solver, ctx)?;
248
249 Ok(match op {
250 IntegerOp::Add => solver.plus(lhs, rhs),
251 IntegerOp::Sub => solver.sub(lhs, rhs),
252 IntegerOp::Mul => solver.times(lhs, rhs),
253 IntegerOp::Div => solver.div(lhs, rhs),
254 })
255 }
256 IntegerExpression::Neg(expr) => {
257 let expr = expr.encode_to_smt_with_ctx(solver, ctx)?;
258 Ok(solver.negate(expr))
259 }
260 }
261 }
262}
263
264impl<T, U> EncodeToSMT<BooleanExpression<T>, U> for BooleanExpression<T>
265where
266 U: SMTVariableContext<T> + SMTVariableContext<Parameter>,
267 T: DeclaresVariable + Atomic,
268{
269 fn encode_to_smt_with_ctx(&self, solver: &SMTSolver, ctx: &U) -> Result<SExpr, SMTSolverError> {
270 match self {
271 BooleanExpression::ComparisonExpression(lhs, op, rhs) => {
272 let lhs = lhs.encode_to_smt_with_ctx(solver, ctx)?;
273 let rhs = rhs.encode_to_smt_with_ctx(solver, ctx)?;
274
275 Ok(match op {
276 ComparisonOp::Gt => solver.gt(lhs, rhs),
277 ComparisonOp::Geq => solver.gte(lhs, rhs),
278 ComparisonOp::Eq => solver.eq(lhs, rhs),
279 ComparisonOp::Neq => solver.not(solver.eq(lhs, rhs)),
280 ComparisonOp::Leq => solver.lte(lhs, rhs),
281 ComparisonOp::Lt => solver.lt(lhs, rhs),
282 })
283 }
284 BooleanExpression::BinaryExpression(lhs, op, rhs) => {
285 let lhs = lhs.encode_to_smt_with_ctx(solver, ctx)?;
286 let rhs = rhs.encode_to_smt_with_ctx(solver, ctx)?;
287
288 Ok(match op {
289 BooleanConnective::And => solver.and(lhs, rhs),
290 BooleanConnective::Or => solver.or(lhs, rhs),
291 })
292 }
293 BooleanExpression::Not(expr) => {
294 let expr = expr.encode_to_smt_with_ctx(solver, ctx)?;
295
296 Ok(solver.not(expr))
297 }
298 BooleanExpression::True => Ok(solver.true_()),
299 BooleanExpression::False => Ok(solver.false_()),
300 }
301 }
302}
303
/// Errors raised while encoding expressions for, or extracting solutions from, the SMT solver.
#[derive(Debug)]
pub enum SMTSolverError {
    /// I/O error from the connection to the SMT solver (this is what `From<io::Error>`
    /// produces).
    EasySMTErr(io::Error),
    /// The SMT solver timed out. NOTE(review): not constructed in this file.
    SolverTimeout,
    /// A parameter was looked up that has no SMT expression in the context.
    UndeclaredParameter(Parameter),
    /// A location was looked up that has no SMT expression in the context.
    UndeclaredLocation(Location),
    /// A variable was looked up that has no SMT expression in the context.
    UndeclaredVariable(Variable),
    /// A model value returned by the solver could not be read as an unsigned integer
    /// (e.g. it is negative); carries a solver-rendered term for the error message.
    SolutionExtractionParseIntError(String),
    /// A rule was looked up that has no SMT expression in the context.
    UndeclaredRule(Rule),
    /// A solution was requested although the check result was unsatisfiable.
    ExtractionFromUnsat,
    /// The specification can never be satisfied; the string gives the reason.
    /// NOTE(review): not constructed in this file; confirm semantics at the call sites.
    TriviallyUnsat(String),
}
326
// NOTE(review): `source()` is not overridden, so the `io::Error` inside `EasySMTErr` is not
// exposed as an error cause; its message is embedded in the `Display` output instead. If a
// cause chain is wanted, return it from `source()` and drop it from `Display` so the message
// is not printed twice by chain-aware reporters.
impl std::error::Error for SMTSolverError {}
328
329impl fmt::Display for SMTSolverError {
330 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
331 match self {
332 SMTSolverError::EasySMTErr(err) => {
333 write!(f, "Error from connection to SMT solver: {err}")
334 }
335 SMTSolverError::UndeclaredParameter(param) => {
336 write!(f, "Undeclared parameter: {param}")
337 }
338 SMTSolverError::UndeclaredLocation(location) => {
339 write!(f, "Undeclared location: {location}")
340 }
341 SMTSolverError::UndeclaredVariable(variable) => {
342 write!(f, "Undeclared variable: {variable}")
343 }
344 SMTSolverError::SolverTimeout => write!(f, "Timeout in SMT solver"),
345 SMTSolverError::SolutionExtractionParseIntError(s) => write!(
346 f,
347 "Failed to parse SMT solver supplied solution into integer: {s} not an integer"
348 ),
349 SMTSolverError::UndeclaredRule(rule) => write!(f, "Undeclared rule: {rule}"),
350 SMTSolverError::ExtractionFromUnsat => write!(
351 f,
352 "Attempted to extract the solution assignment from an unsatisfiable expression"
353 ),
354 SMTSolverError::TriviallyUnsat(r) => {
355 write!(f, "Specification can never be true. Reason: {r}")
356 }
357 }
358 }
359}
360
361impl From<io::Error> for SMTSolverError {
362 fn from(error: io::Error) -> Self {
363 SMTSolverError::EasySMTErr(error)
364 }
365}
366
/// SMT context over a fixed set of parameters, locations and variables, each of which is
/// declared as an integer constant when the context is constructed.
pub struct StaticSMTContext {
    /// Solver in which all constants are declared.
    solver: SMTSolver,
    /// SMT constant for each parameter.
    params: HashMap<Parameter, SMTExpr>,
    /// SMT constant for each location.
    locs: HashMap<Location, SMTExpr>,
    /// SMT constants for each variable. The constructor stores exactly one (index 0) and
    /// lookups read only the first entry. NOTE(review): the `Vec` presumably leaves room for
    /// several indices; confirm before relying on more than one entry.
    vars: HashMap<Variable, Vec<SMTExpr>>,
}
379
380impl StaticSMTContext {
381 pub fn new(
387 solver_builder: SMTSolverBuilder,
388 params: impl IntoIterator<Item = Parameter>,
389 locs: impl IntoIterator<Item = Location>,
390 vars: impl IntoIterator<Item = Variable>,
391 ) -> Result<Self, SMTSolverError> {
392 let mut solver = solver_builder.new_solver();
393
394 let params = params
395 .into_iter()
396 .map(|p| {
397 let expr = p.declare_variable(&mut solver, 0)?;
398 Ok((p, expr))
399 })
400 .collect::<Result<HashMap<_, _>, SMTSolverError>>()?;
401
402 let locs = locs
403 .into_iter()
404 .map(|l| {
405 let expr = l.declare_variable(&mut solver, 0)?;
406 Ok((l, expr))
407 })
408 .collect::<Result<HashMap<_, _>, SMTSolverError>>()?;
409
410 let vars = vars
411 .into_iter()
412 .map(|v| {
413 let expr = v.declare_variable(&mut solver, 0)?;
414 Ok((v, vec![expr]))
415 })
416 .collect::<Result<HashMap<_, _>, SMTSolverError>>()?;
417
418 Ok(StaticSMTContext {
419 solver,
420 params,
421 locs,
422 vars,
423 })
424 }
425
426 pub fn get_true(&self) -> SMTExpr {
428 self.solver.true_()
429 }
430
431 pub fn encode_to_smt<T>(&self, expr: &T) -> Result<SMTExpr, SMTSolverError>
434 where
435 T: EncodeToSMT<T, Self>,
436 {
437 expr.encode_to_smt_with_ctx(&self.solver, self)
438 }
439}
440
441impl SMTSolverContext for StaticSMTContext {
442 fn get_smt_solver_mut(&mut self) -> &mut SMTSolver {
443 &mut self.solver
444 }
445
446 fn get_smt_solver(&self) -> &SMTSolver {
447 &self.solver
448 }
449}
450
451impl SMTVariableContext<Parameter> for StaticSMTContext {
452 fn get_expr_for(&self, param: &Parameter) -> Result<SMTExpr, SMTSolverError> {
453 self.params
454 .get(param)
455 .cloned()
456 .ok_or_else(|| SMTSolverError::UndeclaredParameter(param.clone()))
457 }
458
459 fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a Parameter>
460 where
461 Parameter: 'a,
462 {
463 self.params.keys()
464 }
465}
466
467impl SMTVariableContext<Location> for StaticSMTContext {
468 fn get_expr_for(&self, loc: &Location) -> Result<SMTExpr, SMTSolverError> {
469 self.locs
470 .get(loc)
471 .cloned()
472 .ok_or_else(|| SMTSolverError::UndeclaredLocation(loc.clone()))
473 }
474
475 fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a Location>
476 where
477 Location: 'a,
478 {
479 self.locs.keys()
480 }
481}
482
483impl SMTVariableContext<Variable> for StaticSMTContext {
484 fn get_expr_for(&self, var: &Variable) -> Result<SMTExpr, SMTSolverError> {
485 self.vars
486 .get(var)
487 .and_then(|v| v.first().cloned())
488 .ok_or_else(|| SMTSolverError::UndeclaredVariable(var.clone()))
489 }
490
491 fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a Variable>
492 where
493 Variable: 'a,
494 {
495 self.vars.keys()
496 }
497}
498
// `GetAssignment` consists only of a default method, so these empty impls opt
// `StaticSMTContext` into single-value extraction for each item kind it stores.
impl GetAssignment<Parameter> for StaticSMTContext {}
impl GetAssignment<Variable> for StaticSMTContext {}
impl GetAssignment<Location> for StaticSMTContext {}
502
#[cfg(test)]
mod tests {
    use std::vec;

    use easy_smt::Response;
    use taco_threshold_automaton::general_threshold_automaton::builder::RuleBuilder;

    use super::*;

    // `get_true` must return the same term as the solver's own `true_()` constant.
    #[test]
    fn test_get_true() {
        let builder = SMTSolverBuilder::default();

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x")],
        )
        .unwrap();

        let top = ctx.get_true();

        assert_eq!(top, ctx.get_smt_solver_mut().true_())
    }

    // Rule constant names embed the "rule" prefix, the rule id and the index.
    #[test]
    fn test_rule_variable_name() {
        let r = RuleBuilder::new(42, Location::new("src"), Location::new("tgt")).build();

        assert!(r.get_name(0).contains("rule"));
        assert!(r.get_name(0).contains("42"));
        assert!(r.get_name(0).contains("0"));
    }

    // `True`/`False` encode to the SMT literals; asserting `false` makes the solver unsat.
    #[test]
    fn test_boolean_expr_encoding_true_false() {
        let builder = SMTSolverBuilder::default();

        let true_expr: BooleanExpression<Variable> = BooleanExpression::True;
        let false_expr: BooleanExpression<Variable> = BooleanExpression::False;

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x")],
        )
        .unwrap();

        let true_encoded = ctx.encode_to_smt(&true_expr).unwrap();
        let false_encoded = ctx.encode_to_smt(&false_expr).unwrap();

        let solver = ctx.get_smt_solver_mut();
        let true_str = solver.display(true_encoded).to_string();
        let false_str = solver.display(false_encoded).to_string();
        assert_eq!(true_str, "true");
        assert_eq!(false_str, "false");

        solver.assert(true_encoded).unwrap();
        assert_eq!(solver.check().unwrap(), Response::Sat);

        solver.assert(false_encoded).unwrap();
        assert_eq!(solver.check().unwrap(), Response::Unsat);
    }

    // A comparison of two variable atoms uses their index-0 constant names.
    #[test]
    fn test_boolean_expr_encoding_two_var_gt() {
        let builder = SMTSolverBuilder::default();

        let expr = BooleanExpression::ComparisonExpression(
            Box::new(IntegerExpression::Atom(Variable::new("x"))),
            ComparisonOp::Gt,
            Box::new(IntegerExpression::Atom(Variable::new("y"))),
        );

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x"), Variable::new("y")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();

        let solver = ctx.get_smt_solver_mut();
        let encoded_str = solver.display(encoded_expr).to_string();
        assert_eq!(encoded_str, "(> var_x_0 var_y_0)");

        solver.assert(encoded_expr).unwrap();
        assert_eq!(solver.check().unwrap(), Response::Sat);
    }

    // Integer constants are emitted as numerals.
    #[test]
    fn test_boolean_expr_encoding_var_eq_const() {
        let builder = SMTSolverBuilder::default();

        let expr = BooleanExpression::ComparisonExpression(
            Box::new(IntegerExpression::Atom(Variable::new("x"))),
            ComparisonOp::Eq,
            Box::new(IntegerExpression::Const(5)),
        );

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();

        let solver = ctx.get_smt_solver_mut();
        let encoded_str = solver.display(encoded_expr).to_string();
        assert_eq!(encoded_str, "(= var_x_0 5)");

        solver.assert(encoded_expr).unwrap();
        assert_eq!(solver.check().unwrap(), Response::Sat);
    }

    // Conjunction of two comparisons (covers `Geq` and `Lt`).
    #[test]
    fn test_boolean_expr_encoding_var_and_const() {
        let builder = SMTSolverBuilder::default();

        let expr = BooleanExpression::BinaryExpression(
            Box::new(BooleanExpression::ComparisonExpression(
                Box::new(IntegerExpression::Atom(Variable::new("x"))),
                ComparisonOp::Geq,
                Box::new(IntegerExpression::Const(5)),
            )),
            BooleanConnective::And,
            Box::new(BooleanExpression::ComparisonExpression(
                Box::new(IntegerExpression::Atom(Variable::new("y"))),
                ComparisonOp::Lt,
                Box::new(IntegerExpression::Const(10)),
            )),
        );

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x"), Variable::new("y")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();

        let solver = ctx.get_smt_solver_mut();
        let encoded_str = solver.display(encoded_expr).to_string();
        assert_eq!(encoded_str, "(and (>= var_x_0 5) (< var_y_0 10))");

        solver.assert(encoded_expr).unwrap();
        assert_eq!(solver.check().unwrap(), Response::Sat);
    }

    // Disjunction mixing a variable and a parameter; `Neq` becomes `(not (= ..))`.
    #[test]
    fn test_boolean_expr_encoding_var_or_const() {
        let builder = SMTSolverBuilder::default();

        let expr = BooleanExpression::BinaryExpression(
            Box::new(BooleanExpression::ComparisonExpression(
                Box::new(IntegerExpression::Atom(Variable::new("x"))),
                ComparisonOp::Leq,
                Box::new(IntegerExpression::Const(5)),
            )),
            BooleanConnective::Or,
            Box::new(BooleanExpression::ComparisonExpression(
                Box::new(IntegerExpression::Param(Parameter::new("p"))),
                ComparisonOp::Neq,
                Box::new(IntegerExpression::Const(10)),
            )),
        );

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x"), Variable::new("y")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();

        let solver = ctx.get_smt_solver_mut();
        let encoded_str = solver.display(encoded_expr).to_string();
        assert_eq!(encoded_str, "(or (<= var_x_0 5) (not (= param_p 10)))");

        solver.assert(encoded_expr).unwrap();
        assert_eq!(solver.check().unwrap(), Response::Sat);
    }

    // Negation of a comparison over a location atom.
    #[test]
    fn test_boolean_expr_encoding_not_var() {
        let builder = SMTSolverBuilder::default();

        let expr = BooleanExpression::Not(Box::new(BooleanExpression::ComparisonExpression(
            Box::new(IntegerExpression::Atom(Location::new("loc"))),
            ComparisonOp::Eq,
            Box::new(IntegerExpression::Const(5)),
        )));

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();

        let solver = ctx.get_smt_solver_mut();
        let encoded_str = solver.display(encoded_expr).to_string();
        assert_eq!(encoded_str, "(not (= loc_loc_0 5))");

        solver.assert(encoded_expr).unwrap();
        assert_eq!(solver.check().unwrap(), Response::Sat);
    }

    // `IntegerOp::Add` maps to SMT `+`.
    #[test]
    fn test_integer_expr_encoding_add() {
        let builder = SMTSolverBuilder::default();

        let expr = IntegerExpression::BinaryExpr(
            Box::new(IntegerExpression::Atom(Parameter::new("p"))),
            IntegerOp::Add,
            Box::new(IntegerExpression::Const(5)),
        );

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();

        let solver = ctx.get_smt_solver_mut();
        let encoded_str = solver.display(encoded_expr).to_string();
        assert_eq!(encoded_str, "(+ param_p 5)");
    }

    // `IntegerOp::Sub` maps to binary SMT `-`.
    #[test]
    fn test_integer_expr_encoding_sub() {
        let builder = SMTSolverBuilder::default();

        let expr = IntegerExpression::BinaryExpr(
            Box::new(IntegerExpression::Atom(Parameter::new("p"))),
            IntegerOp::Sub,
            Box::new(IntegerExpression::Const(3)),
        );

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();

        let solver = ctx.get_smt_solver_mut();
        let encoded_str = solver.display(encoded_expr).to_string();
        assert_eq!(encoded_str, "(- param_p 3)");
    }

    // `IntegerOp::Mul` maps to SMT `*`.
    #[test]
    fn test_integer_expr_encoding_mul() {
        let builder = SMTSolverBuilder::default();

        let expr = IntegerExpression::BinaryExpr(
            Box::new(IntegerExpression::Atom(Parameter::new("p"))),
            IntegerOp::Mul,
            Box::new(IntegerExpression::Const(2)),
        );

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();

        let solver = ctx.get_smt_solver_mut();
        let encoded_str = solver.display(encoded_expr).to_string();
        assert_eq!(encoded_str, "(* param_p 2)");
    }

    // `IntegerOp::Div` maps to SMT integer division `div`.
    #[test]
    fn test_integer_expr_encoding_div() {
        let builder = SMTSolverBuilder::default();

        let expr = IntegerExpression::BinaryExpr(
            Box::new(IntegerExpression::Atom(Variable::new("x"))),
            IntegerOp::Div,
            Box::new(IntegerExpression::Const(4)),
        );

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();

        let solver = ctx.get_smt_solver_mut();
        let encoded_str = solver.display(encoded_expr).to_string();
        assert_eq!(encoded_str, "(div var_x_0 4)");
    }

    // `IntegerExpression::Neg` maps to unary SMT `-`.
    #[test]
    fn test_integer_expr_encoding_neg() {
        let builder = SMTSolverBuilder::default();

        let expr = IntegerExpression::Neg(Box::new(IntegerExpression::Atom(Parameter::new("p"))));

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();

        let solver = ctx.get_smt_solver_mut();
        let encoded_str = solver.display(encoded_expr).to_string();
        assert_eq!(encoded_str, "(- param_p)");
    }

    // `x <= 5 && x > 4` has the unique integer solution `x = 5`.
    #[test]
    fn test_get_assignment_sat() {
        let builder = SMTSolverBuilder::default();

        let expr = BooleanExpression::BinaryExpression(
            Box::new(BooleanExpression::ComparisonExpression(
                Box::new(IntegerExpression::Atom(Variable::new("x"))),
                ComparisonOp::Leq,
                Box::new(IntegerExpression::Const(5)),
            )),
            BooleanConnective::And,
            Box::new(BooleanExpression::ComparisonExpression(
                Box::new(IntegerExpression::Atom(Variable::new("x"))),
                ComparisonOp::Gt,
                Box::new(IntegerExpression::Const(4)),
            )),
        );

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x"), Variable::new("y")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
        let res = ctx.assert_and_check_expr(encoded_expr).unwrap();

        let assignment = ctx.get_assignment(res, &Variable::new("x")).unwrap();

        assert_eq!(assignment, Some(5));
    }

    // Extracting a value after an unsat check reports `ExtractionFromUnsat`.
    #[test]
    fn test_get_assignment_unsat() {
        let builder = SMTSolverBuilder::default();

        let expr = BooleanExpression::BinaryExpression(
            Box::new(BooleanExpression::ComparisonExpression(
                Box::new(IntegerExpression::Atom(Variable::new("x"))),
                ComparisonOp::Leq,
                Box::new(IntegerExpression::Const(5)),
            )),
            BooleanConnective::And,
            Box::new(BooleanExpression::ComparisonExpression(
                Box::new(IntegerExpression::Atom(Variable::new("x"))),
                ComparisonOp::Gt,
                Box::new(IntegerExpression::Const(5)),
            )),
        );

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x"), Variable::new("y")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
        let res = ctx.assert_and_check_expr(encoded_expr).unwrap();

        let assignment = ctx.get_assignment(res, &Variable::new("x"));

        assert!(assignment.is_err());
        assert!(matches!(
            assignment.unwrap_err(),
            SMTSolverError::ExtractionFromUnsat
        ));
    }

    // `get_exprs` lists exactly the items passed to the constructor, per item kind.
    #[test]
    fn test_get_expr_smt_context() {
        let builder = SMTSolverBuilder::default();

        let ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x")],
        )
        .unwrap();

        let exprs = SMTVariableContext::<Parameter>::get_exprs(&ctx)
            .into_iter()
            .cloned()
            .collect::<Vec<Parameter>>();
        assert_eq!(exprs.len(), 1);
        assert!(exprs.contains(&Parameter::new("p")));

        let exprs = SMTVariableContext::<Location>::get_exprs(&ctx)
            .into_iter()
            .cloned()
            .collect::<Vec<Location>>();
        assert_eq!(exprs.len(), 1);
        assert!(exprs.contains(&Location::new("loc")));

        let exprs = SMTVariableContext::<Variable>::get_exprs(&ctx)
            .into_iter()
            .cloned()
            .collect::<Vec<Variable>>();
        assert_eq!(exprs.len(), 1);
        assert!(exprs.contains(&Variable::new("x")));
    }

    // Despite its name, this exercises `get_solution` on a plain `HashMap` context:
    // an unsat result yields `ExtractionFromUnsat`.
    #[test]
    fn test_get_assignment_no_solution() {
        let builder = SMTSolverBuilder::default();

        let mut solver = builder.new_solver();

        let vars = vec![Variable::new("x")];
        let vars_smt = vars
            .into_iter()
            .map(|v| (v.clone(), v.declare_variable(&mut solver, 0).unwrap()))
            .collect::<HashMap<_, _>>();

        let res = solver.assert_and_check_expr(solver.false_());
        let sol = vars_smt.get_solution(&mut solver, res.unwrap());

        assert!(sol.is_err());
        assert!(matches!(
            sol.unwrap_err(),
            SMTSolverError::ExtractionFromUnsat
        ));
    }

    // Forcing `x < 0` makes the model value negative, which `get_solution` cannot read as u32.
    #[test]
    fn test_parse_int_err_get_sol() {
        let builder = SMTSolverBuilder::default();

        let mut solver = builder.new_solver();

        let vars = vec![Variable::new("x")];
        let vars_smt = vars
            .into_iter()
            .map(|v| (v.clone(), v.declare_variable(&mut solver, 0).unwrap()))
            .collect::<HashMap<_, _>>();

        let expr = solver.lt(vars_smt[&Variable::new("x")], solver.numeral(0));

        let res = solver.assert_and_check_expr(expr);
        let sol = vars_smt.get_solution(&mut solver, res.unwrap());

        assert!(sol.is_err());
        assert!(matches!(
            sol.unwrap_err(),
            SMTSolverError::SolutionExtractionParseIntError(_)
        ));
    }

    // Same as above for `get_assignment`: a negative value cannot be read as u64.
    #[test]
    fn test_parse_int_err_get_assignment() {
        let builder = SMTSolverBuilder::default();

        let expr = BooleanExpression::ComparisonExpression(
            Box::new(IntegerExpression::Atom(Variable::new("x"))),
            ComparisonOp::Lt,
            Box::new(IntegerExpression::Const(0)),
        );

        let mut ctx = StaticSMTContext::new(
            builder,
            vec![Parameter::new("p")],
            vec![Location::new("loc")],
            vec![Variable::new("x"), Variable::new("y")],
        )
        .unwrap();

        let encoded_expr = ctx.encode_to_smt(&expr).unwrap();
        let res = ctx.assert_and_check_expr(encoded_expr).unwrap();

        let sol = ctx.get_assignment(res, &Variable::new("x"));

        assert!(sol.is_err());
        assert!(matches!(
            sol.unwrap_err(),
            SMTSolverError::SolutionExtractionParseIntError(_)
        ));
    }

    // Each item kind reports its own "undeclared" error variant.
    #[test]
    fn test_get_undeclared_error() {
        let loc = Location::new("loc");
        let param = Parameter::new("param");
        let var = Variable::new("var");
        let rule = RuleBuilder::new(42, loc.clone(), loc.clone()).build();

        assert!(matches!(
            loc.get_undeclared_error(),
            SMTSolverError::UndeclaredLocation(_)
        ));
        assert!(matches!(
            param.get_undeclared_error(),
            SMTSolverError::UndeclaredParameter(_)
        ));
        assert!(matches!(
            var.get_undeclared_error(),
            SMTSolverError::UndeclaredVariable(_)
        ));
        assert!(matches!(
            rule.get_undeclared_error(),
            SMTSolverError::UndeclaredRule(_)
        ));
    }

    // `io::Error` converts into `EasySMTErr`.
    #[test]
    fn test_from_io_error() {
        let io_error = io::Error::other("Some error");
        let err = SMTSolverError::from(io_error);

        assert!(matches!(err, SMTSolverError::EasySMTErr(_)));
    }

    // Every variant's `Display` output contains its fixed message and any attached payload.
    #[test]
    fn test_display_smt_err() {
        let err = SMTSolverError::EasySMTErr(io::Error::other("Some error"));
        let display = err.to_string();
        assert!(display.contains("Error from connection to SMT solver"));

        let err = SMTSolverError::UndeclaredParameter(Parameter::new("p"));
        let display = err.to_string();
        assert!(display.contains("Undeclared parameter"));
        assert!(display.contains(Parameter::new("p").name()));

        let err = SMTSolverError::UndeclaredLocation(Location::new("loc"));
        let display = err.to_string();
        assert!(display.contains("Undeclared location"));
        assert!(display.contains(Location::new("loc").name()));

        let err = SMTSolverError::UndeclaredVariable(Variable::new("x"));
        let display = err.to_string();
        assert!(display.contains("Undeclared variable"));
        assert!(display.contains(Variable::new("x").name()));

        let err = SMTSolverError::SolverTimeout;
        let display = err.to_string();
        assert!(display.contains("Timeout in SMT solver"));

        let err = SMTSolverError::SolutionExtractionParseIntError("not_an_int".to_string());
        let display = err.to_string();
        assert!(display.contains("Failed to parse SMT solver supplied solution into integer"));
        assert!(display.contains("not_an_int"));

        let err = SMTSolverError::UndeclaredRule(
            RuleBuilder::new(42, Location::new("src"), Location::new("tgt")).build(),
        );
        let display = err.to_string();
        assert!(display.contains("Undeclared rule"));
        assert!(display.contains("42"));

        let err = SMTSolverError::ExtractionFromUnsat;
        let display = err.to_string();
        assert!(display.contains(
            "Attempted to extract the solution assignment from an unsatisfiable expression"
        ));

        let err = SMTSolverError::TriviallyUnsat("Some reason".to_string());
        let display = err.to_string();
        assert!(display.contains("Specification can never be true"));
        assert!(display.contains("Some reason"));
    }
}