// taco_smt_encoder/expression_encoding/config_ctx.rs

//! This module defines the [`ConfigCtx`] type, which declares all SMT
//! variables required to represent a configuration inside an SMT encoding.
3
4use core::fmt;
5use std::{collections::HashMap, rc::Rc};
6
7use taco_display_utils::join_iterator;
8use taco_threshold_automaton::{
9    ThresholdAutomaton,
10    expressions::{Location, Parameter, Variable},
11    path::Configuration,
12};
13
14use crate::{
15    SMTExpr, SMTSolution, SMTSolver,
16    expression_encoding::{DeclaresVariable, SMTSolverError, SMTVariableContext},
17};
18
19/// Trait for Context
20pub trait ConfigFromSMT:
21    SMTVariableContext<Parameter> + SMTVariableContext<Location> + SMTVariableContext<Variable>
22{
23    /// Extract the assignment found by the SMT solver
24    fn get_assigned_configuration(
25        &self,
26        solver: &mut SMTSolver,
27        res: SMTSolution,
28    ) -> Result<Configuration, SMTSolverError> {
29        if !res.is_sat() {
30            return Err(SMTSolverError::ExtractionFromUnsat);
31        }
32
33        let var_assignment = self.get_solution(solver, res)?;
34        let loc_assignment = self.get_solution(solver, res)?;
35
36        Ok(Configuration::new(var_assignment, loc_assignment))
37    }
38}
39
/// SMT context for a configuration of a threshold automaton
///
/// Stores, for the configuration numbered `index`, one SMT constant per
/// location and one per variable of the automaton, together with a shared
/// reference to the parameter constants.
#[derive(Debug, Clone)]
pub struct ConfigCtx {
    /// index of the configuration (handed to `declare_variable` when the
    /// location and variable constants are created)
    index: usize,
    /// reference to parameter variables; shared via `Rc` and reused as-is
    /// rather than redeclared for every configuration
    params: Rc<HashMap<Parameter, SMTExpr>>,
    /// location variables, one SMT constant per location
    loc_vars: HashMap<Location, SMTExpr>,
    /// variable variables, one SMT constant per automaton variable
    variable_vars: HashMap<Variable, SMTExpr>,
}
52
53impl ConfigCtx {
54    /// Create a new set of variables for a configuration of the given threshold
55    /// automaton
56    pub fn new(
57        solver: &mut SMTSolver,
58        ta: &impl ThresholdAutomaton,
59        params: Rc<HashMap<Parameter, SMTExpr>>,
60        index: usize,
61    ) -> ConfigCtx {
62        let loc_vars = ta
63            .locations()
64            .map(|l| {
65                let loc = l
66                    .declare_variable(solver, index as u32)
67                    .expect("Failed to declare locations");
68
69                solver
70                    .assert(solver.gte(loc, solver.numeral(0)))
71                    .expect("Failed to assume loc >= 0");
72
73                (l.clone(), loc)
74            })
75            .collect();
76
77        let variable_vars = ta
78            .variables()
79            .map(|v| {
80                let var = v
81                    .declare_variable(solver, index as u32)
82                    .expect("Failed to declare variables");
83
84                solver
85                    .assert(solver.gte(var, solver.numeral(0)))
86                    .expect("Failed to assume var >= 0");
87
88                (v.clone(), var)
89            })
90            .collect();
91
92        Self {
93            index,
94            params,
95            loc_vars,
96            variable_vars,
97        }
98    }
99}
100
// `ConfigCtx` implements `SMTVariableContext` for parameters, locations and
// variables, so it relies on the default `get_assigned_configuration`.
impl ConfigFromSMT for ConfigCtx {}
102
103impl SMTVariableContext<Parameter> for ConfigCtx {
104    fn get_expr_for(&self, expr: &Parameter) -> Result<SMTExpr, SMTSolverError> {
105        self.params
106            .get(expr)
107            .cloned()
108            .ok_or_else(|| SMTSolverError::UndeclaredParameter(expr.clone()))
109    }
110
111    fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a Parameter>
112    where
113        Parameter: 'a,
114    {
115        self.params.keys()
116    }
117}
118
119impl SMTVariableContext<Location> for ConfigCtx {
120    fn get_expr_for(&self, expr: &Location) -> Result<SMTExpr, SMTSolverError> {
121        self.loc_vars
122            .get(expr)
123            .cloned()
124            .ok_or_else(|| SMTSolverError::UndeclaredLocation(expr.clone()))
125    }
126
127    fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a Location>
128    where
129        Location: 'a,
130    {
131        self.loc_vars.keys()
132    }
133}
134
135impl SMTVariableContext<Variable> for ConfigCtx {
136    fn get_expr_for(&self, expr: &Variable) -> Result<SMTExpr, SMTSolverError> {
137        self.variable_vars
138            .get(expr)
139            .ok_or_else(|| SMTSolverError::UndeclaredVariable(expr.clone()))
140            .cloned()
141    }
142
143    fn get_exprs<'a>(&'a self) -> impl IntoIterator<Item = &'a Variable>
144    where
145        Variable: 'a,
146    {
147        self.variable_vars.keys()
148    }
149}
150
151impl fmt::Display for ConfigCtx {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        f.write_fmt(core::format_args!(
154            "ConfigCtx {}: Locations: [ {} ], Variables: [ {} ]",
155            self.index,
156            join_iterator(self.loc_vars.keys(), ", "),
157            join_iterator(self.variable_vars.keys(), ", "),
158        ))
159    }
160}
161
#[cfg(test)]
mod tests {

    use std::{collections::HashMap, rc::Rc};

    use taco_threshold_automaton::{
        ThresholdAutomaton,
        expressions::{Location, Parameter, Variable},
        general_threshold_automaton::builder::GeneralThresholdAutomatonBuilder,
        path::Configuration,
    };

    use crate::{
        SMTSolverBuilder, SMTSolverContext,
        expression_encoding::{
            DeclaresVariable, SMTSolverError, SMTVariableContext, config_ctx::ConfigFromSMT,
        },
    };

    use super::ConfigCtx;

    /// Build the automaton shared by all tests: locations `loc1`..`loc3`,
    /// parameters `n` and `f`, and variables `var1`..`var3`.
    fn build_test_ta() -> impl ThresholdAutomaton {
        GeneralThresholdAutomatonBuilder::new("test_ta")
            .with_locations([
                Location::new("loc1"),
                Location::new("loc2"),
                Location::new("loc3"),
            ])
            .unwrap()
            .with_parameters([Parameter::new("n"), Parameter::new("f")])
            .unwrap()
            .with_variables([
                Variable::new("var1"),
                Variable::new("var2"),
                Variable::new("var3"),
            ])
            .unwrap()
            .initialize()
            .build()
    }

    #[test]
    fn test_all_variables_declared() {
        let ta = build_test_ta();

        let mut solver = SMTSolverBuilder::default().new_solver();

        let n = solver
            .declare_const(Parameter::new("n").get_name(0), solver.int_sort())
            .unwrap();
        let f = solver
            .declare_const(Parameter::new("f").get_name(0), solver.int_sort())
            .unwrap();

        let params = Rc::new(HashMap::from([
            (Parameter::new("n"), n),
            (Parameter::new("f"), f),
        ]));

        let cfg = ConfigCtx::new(&mut solver, &ta, params, 0);

        assert!(cfg.get_expr_for(&Location::new("loc1")).is_ok());
        assert!(cfg.get_expr_for(&Location::new("loc2")).is_ok());
        assert!(cfg.get_expr_for(&Location::new("loc3")).is_ok());

        assert!(cfg.get_expr_for(&Variable::new("var1")).is_ok());
        assert!(cfg.get_expr_for(&Variable::new("var2")).is_ok());
        assert!(cfg.get_expr_for(&Variable::new("var3")).is_ok());

        assert!(cfg.get_expr_for(&Parameter::new("n")).is_ok());
        assert!(cfg.get_expr_for(&Parameter::new("f")).is_ok());
    }

    #[test]
    fn test_get_assigned_config() {
        let ta = build_test_ta();

        let mut solver = SMTSolverBuilder::default().new_solver();

        let params = Rc::new(HashMap::new());

        let cfg = ConfigCtx::new(&mut solver, &ta, params, 0);

        let loc1_constr = solver.eq(
            cfg.get_expr_for(&Location::new("loc1")).unwrap(),
            solver.numeral(1),
        );
        let loc2_constr = solver.eq(
            cfg.get_expr_for(&Location::new("loc2")).unwrap(),
            solver.numeral(2),
        );
        let loc3_constr = solver.eq(
            cfg.get_expr_for(&Location::new("loc3")).unwrap(),
            solver.numeral(3),
        );

        let var1_constr = solver.eq(
            cfg.get_expr_for(&Variable::new("var1")).unwrap(),
            solver.numeral(1),
        );
        let var2_constr = solver.eq(
            cfg.get_expr_for(&Variable::new("var2")).unwrap(),
            solver.numeral(2),
        );
        let var3_constr = solver.eq(
            cfg.get_expr_for(&Variable::new("var3")).unwrap(),
            solver.numeral(3),
        );

        let smt_expr = solver.and_many([
            loc1_constr,
            loc2_constr,
            loc3_constr,
            var1_constr,
            var2_constr,
            var3_constr,
        ]);

        let res = solver.assert_and_check_expr(smt_expr).unwrap();

        let got_cfg = cfg.get_assigned_configuration(&mut solver, res).unwrap();

        let expected_cfg = Configuration::new(
            HashMap::from([
                (Variable::new("var1"), 1),
                (Variable::new("var2"), 2),
                (Variable::new("var3"), 3),
            ]),
            HashMap::from([
                (Location::new("loc1"), 1),
                (Location::new("loc2"), 2),
                (Location::new("loc3"), 3),
            ]),
        );

        assert_eq!(got_cfg, expected_cfg)
    }

    #[test]
    fn test_get_assigned_config_unsat() {
        let ta = build_test_ta();

        let mut solver = SMTSolverBuilder::default().new_solver();

        let params = Rc::new(HashMap::new());

        let cfg = ConfigCtx::new(&mut solver, &ta, params, 0);

        let smt_expr = solver.false_();

        let res = solver.assert_and_check_expr(smt_expr).unwrap();

        let got_cfg = cfg.get_assigned_configuration(&mut solver, res);

        // A single pattern match checks both "is an error" and the error kind.
        assert!(matches!(
            got_cfg,
            Err(SMTSolverError::ExtractionFromUnsat)
        ))
    }

    #[test]
    fn test_display_config_ctx() {
        let cfg = ConfigCtx {
            index: 42,
            params: Rc::new(HashMap::new()),
            loc_vars: HashMap::new(),
            variable_vars: HashMap::new(),
        };

        let rendered = format!("{cfg}");
        assert!(rendered.contains("ConfigCtx"));
        assert!(rendered.contains("42"));
    }
}
362}