taco_smt_encoder/expression_encoding/
ta_encoding.rs1use taco_threshold_automaton::{
5 ThresholdAutomaton, VariableConstraint,
6 expressions::{Location, Parameter, Variable},
7};
8
9use crate::{
10 SMTExpr, SMTSolver,
11 expression_encoding::{EncodeToSMT, SMTVariableContext},
12};
13
14pub fn encode_resilience_condition<TA, C>(ta: &TA, solver: &SMTSolver, ctx: &C) -> SMTExpr
17where
18 TA: ThresholdAutomaton,
19 C: SMTVariableContext<Parameter>,
20{
21 if ta.resilience_conditions().count() == 0 {
22 return solver.true_();
23 }
24
25 let rc_it = ta
26 .resilience_conditions()
27 .map(|rc| {
28 rc.encode_to_smt_with_ctx(solver, ctx)
29 .expect("Failed to encode resilience condition")
30 })
31 .chain([solver.true_()]);
32 solver.and_many(rc_it)
33}
34
35pub fn encode_initial_constraints<TA, C>(ta: &TA, solver: &SMTSolver, ctx: &C) -> SMTExpr
43where
44 TA: ThresholdAutomaton,
45 C: SMTVariableContext<Parameter> + SMTVariableContext<Location> + SMTVariableContext<Variable>,
46{
47 if ta.initial_location_constraints().count() == 0
48 && ta.initial_variable_constraints().count() == 0
49 {
50 return solver.true_();
51 }
52
53 let loc_init = ta.initial_location_constraints().map(|loc| {
54 loc.encode_to_smt_with_ctx(solver, ctx)
55 .expect("Failed to encode initial location condition")
56 });
57
58 let var_init = ta.initial_variable_constraints().map(|var| {
59 var.as_boolean_expr()
60 .encode_to_smt_with_ctx(solver, ctx)
61 .expect("Failed to encode initial variable condition")
62 });
63
64 solver.and_many(loc_init.chain(var_init))
65}
66
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, rc::Rc};

    use taco_threshold_automaton::{
        ThresholdAutomaton,
        expressions::{
            BooleanExpression, ComparisonOp, IntegerExpression, Location, Parameter, Variable,
        },
        general_threshold_automaton::builder::GeneralThresholdAutomatonBuilder,
    };

    use crate::{
        SMTSolverBuilder,
        expression_encoding::{
            DeclaresVariable, EncodeToSMT,
            config_ctx::ConfigCtx,
            ta_encoding::{encode_initial_constraints, encode_resilience_condition},
        },
    };

    /// A single resilience condition `rc` is expected to be encoded as `(and rc true)`.
    #[test]
    fn test_encode_rc_constraints() {
        let mut solver = SMTSolverBuilder::default().new_solver();

        // Builds the resilience condition `n > 3 * t`. It is called once for
        // the automaton and once for the expected encoding so both sides are
        // guaranteed to use the same expression.
        let n_gt_3t = || {
            BooleanExpression::ComparisonExpression(
                Box::new(IntegerExpression::Param(Parameter::new("n"))),
                ComparisonOp::Gt,
                Box::new(
                    IntegerExpression::Const(3) * IntegerExpression::Atom(Parameter::new("t")),
                ),
            )
        };

        let test_ta = GeneralThresholdAutomatonBuilder::new("test_ta")
            .with_parameters([
                Parameter::new("n"),
                Parameter::new("t"),
                Parameter::new("f"),
            ])
            .unwrap()
            .initialize()
            .with_resilience_conditions([n_gt_3t()])
            .unwrap()
            .build();

        // Declare every parameter as a non-negative SMT variable and record the
        // mapping; the map serves as the encoding context.
        let mut ctx: HashMap<Parameter, easy_smt::SExpr> = HashMap::new();
        for param in test_ta.parameters() {
            let smt_param = param
                .declare_variable(&mut solver, 0)
                .expect("Failed to declare parameter");

            solver
                .assert(solver.gte(smt_param, solver.numeral(0)))
                .expect("Failed to assert parameter >= 0");

            ctx.insert(param.clone(), smt_param);
        }

        let got_expr = encode_resilience_condition(&test_ta, &solver, &ctx);

        let expected_expr = solver.and(
            n_gt_3t().encode_to_smt_with_ctx(&solver, &ctx).unwrap(),
            solver.true_(),
        );

        assert_eq!(
            got_expr,
            expected_expr,
            "Got:{}\nExpected:{}",
            solver.display(got_expr),
            solver.display(expected_expr)
        );
    }

    /// Location constraints followed by variable constraints are expected to be
    /// conjoined into one `and`.
    #[test]
    fn test_initial_constraints() {
        let mut solver = SMTSolverBuilder::default().new_solver();

        // Builds `<name> <op> <bound>` over a location counter.
        let loc_cmp = |name: &str, op: ComparisonOp, bound| {
            BooleanExpression::ComparisonExpression(
                Box::new(IntegerExpression::Atom(Location::new(name))),
                op,
                Box::new(IntegerExpression::Const(bound)),
            )
        };

        // Builds `var1 > 0`, the only initial variable constraint.
        let var1_gt_0 = || {
            BooleanExpression::ComparisonExpression(
                Box::new(IntegerExpression::Atom(Variable::new("var1"))),
                ComparisonOp::Gt,
                Box::new(IntegerExpression::Const(0)),
            )
        };

        let test_ta = GeneralThresholdAutomatonBuilder::new("test_ta")
            .with_locations([
                Location::new("loc1"),
                Location::new("loc2"),
                Location::new("loc3"),
            ])
            .unwrap()
            .with_variables([
                Variable::new("var1"),
                Variable::new("var2"),
                Variable::new("var3"),
            ])
            .unwrap()
            .initialize()
            .with_initial_location_constraints([
                loc_cmp("loc1", ComparisonOp::Gt, 0),
                loc_cmp("loc2", ComparisonOp::Eq, 0),
                loc_cmp("loc3", ComparisonOp::Eq, 0),
            ])
            .unwrap()
            .with_initial_variable_constraints([var1_gt_0()])
            .unwrap()
            .build();

        let ctx = ConfigCtx::new(&mut solver, &test_ta, Rc::new(HashMap::new()), 0);

        let got_expr = encode_initial_constraints(&test_ta, &solver, &ctx);

        // Same order as the encoder: location constraints first, then the
        // variable constraint.
        let expected_expr = solver.and_many([
            loc_cmp("loc1", ComparisonOp::Gt, 0)
                .encode_to_smt_with_ctx(&solver, &ctx)
                .unwrap(),
            loc_cmp("loc2", ComparisonOp::Eq, 0)
                .encode_to_smt_with_ctx(&solver, &ctx)
                .unwrap(),
            loc_cmp("loc3", ComparisonOp::Eq, 0)
                .encode_to_smt_with_ctx(&solver, &ctx)
                .unwrap(),
            var1_gt_0().encode_to_smt_with_ctx(&solver, &ctx).unwrap(),
        ]);

        assert_eq!(
            got_expr,
            expected_expr,
            "Got:{}\nExpected:{}",
            solver.display(got_expr),
            solver.display(expected_expr)
        );
    }
}