taco_smt_encoder/expression_encoding/
config_ctx.rs1use core::fmt;
5use std::{collections::HashMap, rc::Rc};
6
7use taco_display_utils::join_iterator;
8use taco_threshold_automaton::{
9 ThresholdAutomaton,
10 expressions::{Location, Parameter, Variable},
11 path::Configuration,
12};
13
14use crate::{
15 SMTExpr, SMTSolution, SMTSolver,
16 expression_encoding::{DeclaresVariable, SMTSolverError, SMTVariableContext},
17};
18
19pub trait ConfigFromSMT:
21 SMTVariableContext<Parameter> + SMTVariableContext<Location> + SMTVariableContext<Variable>
22{
23 fn get_assigned_configuration(
25 &self,
26 solver: &mut SMTSolver,
27 res: SMTSolution,
28 ) -> Result<Configuration, SMTSolverError> {
29 if !res.is_sat() {
30 return Err(SMTSolverError::ExtractionFromUnsat);
31 }
32
33 let var_assignment = self.get_solution(solver, res)?;
34 let loc_assignment = self.get_solution(solver, res)?;
35
36 Ok(Configuration::new(var_assignment, loc_assignment))
37 }
38}
39
40#[derive(Debug, Clone)]
42pub struct ConfigCtx {
43 index: usize,
45 params: Rc<HashMap<Parameter, SMTExpr>>,
47 loc_vars: HashMap<Location, SMTExpr>,
49 variable_vars: HashMap<Variable, SMTExpr>,
51}
52
53impl ConfigCtx {
54 pub fn new(
57 solver: &mut SMTSolver,
58 ta: &impl ThresholdAutomaton,
59 params: Rc<HashMap<Parameter, SMTExpr>>,
60 index: usize,
61 ) -> ConfigCtx {
62 let loc_vars = ta
63 .locations()
64 .map(|l| {
65 let loc = l
66 .declare_variable(solver, index as u32)
67 .expect("Failed to declare locations");
68
69 solver
70 .assert(solver.gte(loc, solver.numeral(0)))
71 .expect("Failed to assume loc >= 0");
72
73 (l.clone(), loc)
74 })
75 .collect();
76
77 let variable_vars = ta
78 .variables()
79 .map(|v| {
80 let var = v
81 .declare_variable(solver, index as u32)
82 .expect("Failed to declare variables");
83
84 solver
85 .assert(solver.gte(var, solver.numeral(0)))
86 .expect("Failed to assume var >= 0");
87
88 (v.clone(), var)
89 })
90 .collect();
91
92 Self {
93 index,
94 params,
95 loc_vars,
96 variable_vars,
97 }
98 }
99}
100
101impl ConfigFromSMT for ConfigCtx {}
102
103impl SMTVariableContext<Parameter> for ConfigCtx {
104 fn get_expr_for(&self, expr: &Parameter) -> Result<SMTExpr, SMTSolverError> {
105 self.params
106 .get(expr)
107 .cloned()
108 .ok_or_else(|| SMTSolverError::UndeclaredParameter(expr.clone()))
109 }
110
111 fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a Parameter>
112 where
113 Parameter: 'a,
114 {
115 self.params.keys()
116 }
117}
118
119impl SMTVariableContext<Location> for ConfigCtx {
120 fn get_expr_for(&self, expr: &Location) -> Result<SMTExpr, SMTSolverError> {
121 self.loc_vars
122 .get(expr)
123 .cloned()
124 .ok_or_else(|| SMTSolverError::UndeclaredLocation(expr.clone()))
125 }
126
127 fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a Location>
128 where
129 Location: 'a,
130 {
131 self.loc_vars.keys()
132 }
133}
134
135impl SMTVariableContext<Variable> for ConfigCtx {
136 fn get_expr_for(&self, expr: &Variable) -> Result<SMTExpr, SMTSolverError> {
137 self.variable_vars
138 .get(expr)
139 .ok_or_else(|| SMTSolverError::UndeclaredVariable(expr.clone()))
140 .cloned()
141 }
142
143 fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a Variable>
144 where
145 Variable: 'a,
146 {
147 self.variable_vars.keys()
148 }
149}
150
151impl fmt::Display for ConfigCtx {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 f.write_fmt(core::format_args!(
154 "ConfigCtx {}: Locations: [ {} ], Variables: [ {} ]",
155 self.index,
156 join_iterator(self.loc_vars.keys(), ", "),
157 join_iterator(self.variable_vars.keys(), ", "),
158 ))
159 }
160}
161
#[cfg(test)]
mod tests {

    use std::{collections::HashMap, rc::Rc};

    use taco_threshold_automaton::{
        ThresholdAutomaton,
        expressions::{Location, Parameter, Variable},
        general_threshold_automaton::builder::GeneralThresholdAutomatonBuilder,
        path::Configuration,
    };

    use crate::{
        SMTSolverBuilder, SMTSolverContext,
        expression_encoding::{
            DeclaresVariable, SMTSolverError, SMTVariableContext, config_ctx::ConfigFromSMT,
        },
    };

    use super::ConfigCtx;

    /// Builds the automaton shared by all tests: locations `loc1`..`loc3`,
    /// parameters `n` and `f`, and variables `var1`..`var3`.
    fn test_ta() -> impl ThresholdAutomaton {
        GeneralThresholdAutomatonBuilder::new("test_ta")
            .with_locations([
                Location::new("loc1"),
                Location::new("loc2"),
                Location::new("loc3"),
            ])
            .unwrap()
            .with_parameters([Parameter::new("n"), Parameter::new("f")])
            .unwrap()
            .with_variables([
                Variable::new("var1"),
                Variable::new("var2"),
                Variable::new("var3"),
            ])
            .unwrap()
            .initialize()
            .build()
    }

    #[test]
    fn test_all_variables_declared() {
        let ta = test_ta();

        let mut solver = SMTSolverBuilder::default().new_solver();

        let n = solver
            .declare_const(Parameter::new("n").get_name(0), solver.int_sort())
            .unwrap();
        let f = solver
            .declare_const(Parameter::new("f").get_name(0), solver.int_sort())
            .unwrap();

        let params = Rc::new(HashMap::from([
            (Parameter::new("n"), n),
            (Parameter::new("f"), f),
        ]));

        let cfg = ConfigCtx::new(&mut solver, &ta, params, 0);

        assert!(cfg.get_expr_for(&Location::new("loc1")).is_ok());
        assert!(cfg.get_expr_for(&Location::new("loc2")).is_ok());
        assert!(cfg.get_expr_for(&Location::new("loc3")).is_ok());

        assert!(cfg.get_expr_for(&Variable::new("var1")).is_ok());
        assert!(cfg.get_expr_for(&Variable::new("var2")).is_ok());
        assert!(cfg.get_expr_for(&Variable::new("var3")).is_ok());

        assert!(cfg.get_expr_for(&Parameter::new("n")).is_ok());
        assert!(cfg.get_expr_for(&Parameter::new("f")).is_ok());
    }

    #[test]
    fn test_undeclared_lookups_fail() {
        let ta = test_ta();

        let mut solver = SMTSolverBuilder::default().new_solver();

        // No parameter expressions are supplied, so every parameter lookup
        // must fail as well.
        let cfg = ConfigCtx::new(&mut solver, &ta, Rc::new(HashMap::new()), 0);

        assert!(matches!(
            cfg.get_expr_for(&Location::new("unknown")),
            Err(SMTSolverError::UndeclaredLocation(_))
        ));
        assert!(matches!(
            cfg.get_expr_for(&Variable::new("unknown")),
            Err(SMTSolverError::UndeclaredVariable(_))
        ));
        assert!(matches!(
            cfg.get_expr_for(&Parameter::new("n")),
            Err(SMTSolverError::UndeclaredParameter(_))
        ));
    }

    #[test]
    fn test_get_assigned_config() {
        let ta = test_ta();

        let mut solver = SMTSolverBuilder::default().new_solver();

        let params = Rc::new(HashMap::new());

        let cfg = ConfigCtx::new(&mut solver, &ta, params, 0);

        let loc1_constr = solver.eq(
            cfg.get_expr_for(&Location::new("loc1")).unwrap(),
            solver.numeral(1),
        );
        let loc2_constr = solver.eq(
            cfg.get_expr_for(&Location::new("loc2")).unwrap(),
            solver.numeral(2),
        );
        let loc3_constr = solver.eq(
            cfg.get_expr_for(&Location::new("loc3")).unwrap(),
            solver.numeral(3),
        );

        let var1_constr = solver.eq(
            cfg.get_expr_for(&Variable::new("var1")).unwrap(),
            solver.numeral(1),
        );
        let var2_constr = solver.eq(
            cfg.get_expr_for(&Variable::new("var2")).unwrap(),
            solver.numeral(2),
        );
        let var3_constr = solver.eq(
            cfg.get_expr_for(&Variable::new("var3")).unwrap(),
            solver.numeral(3),
        );

        let smt_expr = solver.and_many([
            loc1_constr,
            loc2_constr,
            loc3_constr,
            var1_constr,
            var2_constr,
            var3_constr,
        ]);

        let res = solver.assert_and_check_expr(smt_expr).unwrap();

        let got_cfg = cfg.get_assigned_configuration(&mut solver, res).unwrap();

        let expected_cfg = Configuration::new(
            HashMap::from([
                (Variable::new("var1"), 1),
                (Variable::new("var2"), 2),
                (Variable::new("var3"), 3),
            ]),
            HashMap::from([
                (Location::new("loc1"), 1),
                (Location::new("loc2"), 2),
                (Location::new("loc3"), 3),
            ]),
        );

        assert_eq!(got_cfg, expected_cfg)
    }

    #[test]
    fn test_get_assigned_config_unsat() {
        let ta = test_ta();

        let mut solver = SMTSolverBuilder::default().new_solver();

        let params = Rc::new(HashMap::new());

        let cfg = ConfigCtx::new(&mut solver, &ta, params, 0);

        let smt_expr = solver.false_();

        let res = solver.assert_and_check_expr(smt_expr).unwrap();

        let got_cfg = cfg.get_assigned_configuration(&mut solver, res);

        assert!(matches!(got_cfg, Err(SMTSolverError::ExtractionFromUnsat)))
    }

    #[test]
    fn test_display_config_ctx() {
        let cfg = ConfigCtx {
            index: 42,
            params: Rc::new(HashMap::new()),
            loc_vars: HashMap::new(),
            variable_vars: HashMap::new(),
        };

        assert!(format!("{cfg}").contains("ConfigCtx"));
        assert!(format!("{cfg}").contains("42"));
    }
}
362}