taco_smt_encoder/expression_encoding/
ta_encoding.rs

//! Helper functions to encode expressions of a threshold automaton into SMT
//! constraints
3
4use taco_threshold_automaton::{
5    ThresholdAutomaton, VariableConstraint,
6    expressions::{Location, Parameter, Variable},
7};
8
9use crate::{
10    SMTExpr, SMTSolver,
11    expression_encoding::{EncodeToSMT, SMTVariableContext},
12};
13
14/// Encode the resilience conditions of a threshold automaton into an SMT
15/// expression and return it
16pub fn encode_resilience_condition<TA, C>(ta: &TA, solver: &SMTSolver, ctx: &C) -> SMTExpr
17where
18    TA: ThresholdAutomaton,
19    C: SMTVariableContext<Parameter>,
20{
21    if ta.resilience_conditions().count() == 0 {
22        return solver.true_();
23    }
24
25    let rc_it = ta
26        .resilience_conditions()
27        .map(|rc| {
28            rc.encode_to_smt_with_ctx(solver, ctx)
29                .expect("Failed to encode resilience condition")
30        })
31        .chain([solver.true_()]);
32    solver.and_many(rc_it)
33}
34
35/// Encodes the initial variable and location constraints of the threshold
36/// automaton in to an SMT expression
37///
38/// This function encodes the initial location and variable constraints of
39/// the threshold automaton in to an SMT expression. To do so, it needs the
40/// initial configuration of the threshold automaton, which is used as an
41/// SMT variable context.
42pub fn encode_initial_constraints<TA, C>(ta: &TA, solver: &SMTSolver, ctx: &C) -> SMTExpr
43where
44    TA: ThresholdAutomaton,
45    C: SMTVariableContext<Parameter> + SMTVariableContext<Location> + SMTVariableContext<Variable>,
46{
47    if ta.initial_location_constraints().count() == 0
48        && ta.initial_variable_constraints().count() == 0
49    {
50        return solver.true_();
51    }
52
53    let loc_init = ta.initial_location_constraints().map(|loc| {
54        loc.encode_to_smt_with_ctx(solver, ctx)
55            .expect("Failed to encode initial location condition")
56    });
57
58    let var_init = ta.initial_variable_constraints().map(|var| {
59        var.as_boolean_expr()
60            .encode_to_smt_with_ctx(solver, ctx)
61            .expect("Failed to encode initial variable condition")
62    });
63
64    solver.and_many(loc_init.chain(var_init))
65}
66
67#[cfg(test)]
68mod tests {
69    use std::{collections::HashMap, rc::Rc};
70
71    use taco_threshold_automaton::{
72        ThresholdAutomaton,
73        expressions::{
74            BooleanExpression, ComparisonOp, IntegerExpression, Location, Parameter, Variable,
75        },
76        general_threshold_automaton::builder::GeneralThresholdAutomatonBuilder,
77    };
78
79    use crate::{
80        SMTSolverBuilder,
81        expression_encoding::{
82            DeclaresVariable, EncodeToSMT,
83            config_ctx::ConfigCtx,
84            ta_encoding::{encode_initial_constraints, encode_resilience_condition},
85        },
86    };
87
88    #[test]
89    fn test_encode_rc_constraints() {
90        let mut solver = SMTSolverBuilder::default().new_solver();
91
92        let test_ta = GeneralThresholdAutomatonBuilder::new("test_ta")
93            .with_parameters([
94                Parameter::new("n"),
95                Parameter::new("t"),
96                Parameter::new("f"),
97            ])
98            .unwrap()
99            .initialize()
100            .with_resilience_conditions([BooleanExpression::ComparisonExpression(
101                Box::new(IntegerExpression::Param(Parameter::new("n"))),
102                ComparisonOp::Gt,
103                Box::new(
104                    IntegerExpression::Const(3) * IntegerExpression::Atom(Parameter::new("t")),
105                ),
106            )])
107            .unwrap()
108            .build();
109
110        let ctx: HashMap<Parameter, easy_smt::SExpr> = test_ta
111            .parameters()
112            .map(|p| {
113                let param = p
114                    .declare_variable(&mut solver, 0)
115                    .expect("Failed to declare parameter");
116
117                solver
118                    .assert(solver.gte(param, solver.numeral(0)))
119                    .expect("Failed to assert parameter >= 0");
120
121                (p.clone(), param)
122            })
123            .collect::<HashMap<_, _>>();
124
125        let got_expr = encode_resilience_condition(&test_ta, &solver, &ctx);
126
127        let expected_expr = solver.and(
128            BooleanExpression::ComparisonExpression(
129                Box::new(IntegerExpression::Param(Parameter::new("n"))),
130                ComparisonOp::Gt,
131                Box::new(
132                    IntegerExpression::Const(3) * IntegerExpression::Atom(Parameter::new("t")),
133                ),
134            )
135            .encode_to_smt_with_ctx(&solver, &ctx)
136            .unwrap(),
137            solver.true_(),
138        );
139
140        assert_eq!(
141            got_expr,
142            expected_expr,
143            "Got:{}\nExpected:{}",
144            solver.display(got_expr),
145            solver.display(expected_expr)
146        )
147    }
148
149    #[test]
150    fn test_initial_constraints() {
151        let mut solver = SMTSolverBuilder::default().new_solver();
152
153        let test_ta = GeneralThresholdAutomatonBuilder::new("test_ta")
154            .with_locations([
155                Location::new("loc1"),
156                Location::new("loc2"),
157                Location::new("loc3"),
158            ])
159            .unwrap()
160            .with_variables([
161                Variable::new("var1"),
162                Variable::new("var2"),
163                Variable::new("var3"),
164            ])
165            .unwrap()
166            .initialize()
167            .with_initial_location_constraints([
168                BooleanExpression::ComparisonExpression(
169                    Box::new(IntegerExpression::Atom(Location::new("loc1"))),
170                    ComparisonOp::Gt,
171                    Box::new(IntegerExpression::Const(0)),
172                ),
173                BooleanExpression::ComparisonExpression(
174                    Box::new(IntegerExpression::Atom(Location::new("loc2"))),
175                    ComparisonOp::Eq,
176                    Box::new(IntegerExpression::Const(0)),
177                ),
178                BooleanExpression::ComparisonExpression(
179                    Box::new(IntegerExpression::Atom(Location::new("loc3"))),
180                    ComparisonOp::Eq,
181                    Box::new(IntegerExpression::Const(0)),
182                ),
183            ])
184            .unwrap()
185            .with_initial_variable_constraints([BooleanExpression::ComparisonExpression(
186                Box::new(IntegerExpression::Atom(Variable::new("var1"))),
187                ComparisonOp::Gt,
188                Box::new(IntegerExpression::Const(0)),
189            )])
190            .unwrap()
191            .build();
192
193        let ctx = ConfigCtx::new(&mut solver, &test_ta, Rc::new(HashMap::new()), 0);
194
195        let got_expr = encode_initial_constraints(&test_ta, &solver, &ctx);
196
197        let expected_expr = solver.and_many([
198            BooleanExpression::ComparisonExpression(
199                Box::new(IntegerExpression::Atom(Location::new("loc1"))),
200                ComparisonOp::Gt,
201                Box::new(IntegerExpression::Const(0)),
202            )
203            .encode_to_smt_with_ctx(&solver, &ctx)
204            .unwrap(),
205            BooleanExpression::ComparisonExpression(
206                Box::new(IntegerExpression::Atom(Location::new("loc2"))),
207                ComparisonOp::Eq,
208                Box::new(IntegerExpression::Const(0)),
209            )
210            .encode_to_smt_with_ctx(&solver, &ctx)
211            .unwrap(),
212            BooleanExpression::ComparisonExpression(
213                Box::new(IntegerExpression::Atom(Location::new("loc3"))),
214                ComparisonOp::Eq,
215                Box::new(IntegerExpression::Const(0)),
216            )
217            .encode_to_smt_with_ctx(&solver, &ctx)
218            .unwrap(),
219            BooleanExpression::ComparisonExpression(
220                Box::new(IntegerExpression::Atom(Variable::new("var1"))),
221                ComparisonOp::Gt,
222                Box::new(IntegerExpression::Const(0)),
223            )
224            .encode_to_smt_with_ctx(&solver, &ctx)
225            .unwrap(),
226        ]);
227
228        assert_eq!(got_expr, expected_expr)
229    }
230}