1use std::collections::{HashMap, HashSet};
14
15use taco_display_utils::join_iterator;
16
17use crate::{
18 ActionDefinition,
19 expressions::{Atomic, IsDeclared, Variable},
20 general_threshold_automaton::{Action, UpdateExpression},
21 lia_threshold_automaton::{
22 LIARule, LIAThresholdAutomaton, LIAVariableConstraint, SingleAtomConstraint,
23 integer_thresholds::{ThresholdCompOp, ThresholdConstraint, WeightedSum},
24 },
25};
26
27impl LIAThresholdAutomaton {
28 pub fn into_ta_without_sum_vars(mut self) -> Self {
38 let sv_to_v: HashMap<_, _> = self
39 .get_sum_var_constraints()
40 .into_iter()
41 .map(|v| {
42 let v = v.get_atoms().clone();
43 let new_var = Self::create_new_variable_for_sumvar(&v, &self);
44 (v, new_var)
45 })
46 .collect();
47
48 self.rules = self
49 .rules
50 .into_iter()
51 .map(|(l, r)| {
52 (
53 l,
54 r.into_iter().map(|r| r.remove_sumvar(&sv_to_v)).collect(),
55 )
56 })
57 .collect();
58 self.init_variable_constr.extend(sv_to_v.values().map(|v| {
60 let sa = SingleAtomConstraint::new(
61 v.clone(),
62 ThresholdConstraint::new(ThresholdCompOp::Lt, WeightedSum::new_empty(), 1),
63 );
64
65 LIAVariableConstraint::SingleVarConstraint(sa)
66 }));
67 self.additional_vars_for_sums
68 .extend(sv_to_v.into_iter().map(|(sv, v)| (v, sv)));
69
70 self
71 }
72
73 pub fn get_replacement_vars_for_sumvars(&self) -> &HashMap<Variable, WeightedSum<Variable>> {
75 &self.additional_vars_for_sums
76 }
77
78 fn create_new_variable_for_sumvar(
80 s: &WeightedSum<Variable>,
81 lta: &LIAThresholdAutomaton,
82 ) -> Variable {
83 let name: String =
84 "sv_".to_string() + &join_iterator(s.get_atoms_appearing().map(|v| v.name()), "_");
85
86 let mut new_var = Variable::new(name.clone());
87 let mut i = 0;
88 while lta.is_declared(&new_var) {
89 new_var = Variable::new(name.clone() + &format!("_{i}"));
90 i += 1;
91 }
92
93 new_var
94 }
95}
96
97impl LIARule {
98 fn remove_sumvar(self, sv_to_v: &HashMap<WeightedSum<Variable>, Variable>) -> Self {
100 let guard = self.guard.remove_sumvar(sv_to_v);
101 let mut acts = self.actions.clone();
102 for sc in sv_to_v.keys() {
103 acts.extend(Self::compute_acts_to_add(&self.actions, sc, sv_to_v));
104 }
105
106 Self {
107 id: self.id,
108 source: self.source,
109 target: self.target,
110 guard,
111 actions: acts,
112 }
113 }
114
115 fn compute_acts_to_add(
118 existing_acts: &[Action],
119 sc: &WeightedSum<Variable>,
120 sv_to_v: &HashMap<WeightedSum<Variable>, Variable>,
121 ) -> HashSet<Action> {
122 let mut new_actions = HashSet::new();
123
124 let mut effect = HashMap::new();
125 for act in existing_acts {
126 if sc.contains(act.variable()) {
127 let var = sv_to_v.get(sc).expect("Failed to get sumvar").clone();
128
129 let accumulated_effect = effect.entry(var.clone()).or_insert(0);
130
131 let scale_to_effect = *sc.get_factor(act.variable()).unwrap();
132
133 if !scale_to_effect.is_integer() {
134 unimplemented!(
135 "Failed to scale boundary properly, currently such constraints are unsupported at the moment!"
136 );
137 }
138 let scale_to_effect: i64 = scale_to_effect.try_into().unwrap();
139
140 match act.update() {
141 UpdateExpression::Inc(i) => {
142 *accumulated_effect += (*i as i64) * scale_to_effect;
143 }
144 UpdateExpression::Dec(i) => *accumulated_effect -= *i as i64 * scale_to_effect,
145 UpdateExpression::Reset => unimplemented!(),
146 UpdateExpression::Unchanged => {}
147 }
148 }
149 }
150
151 for (var, acc_effect) in effect.into_iter().filter(|(_, eff)| eff != &0) {
152 if acc_effect < 0 {
153 new_actions.insert(Action::new_with_update(
154 var,
155 UpdateExpression::Dec(-acc_effect as u32),
156 ));
157 continue;
158 }
159
160 new_actions.insert(Action::new_with_update(
161 var,
162 UpdateExpression::Inc((acc_effect) as u32),
163 ));
164 }
165
166 new_actions
167 }
168}
169
170impl LIAVariableConstraint {
171 fn remove_sumvar(self, sv_to_v: &HashMap<WeightedSum<Variable>, Variable>) -> Self {
175 match self {
176 LIAVariableConstraint::SumVarConstraint(sv) => {
177 let var = sv_to_v
178 .get(sv.get_atoms())
179 .expect("Missing var in translation")
180 .clone();
181 LIAVariableConstraint::SingleVarConstraint(SingleAtomConstraint::new(
182 var,
183 sv.get_threshold_constraint().clone(),
184 ))
185 }
186 LIAVariableConstraint::BinaryGuard(lhs, bc, rhs) => {
187 let lhs = lhs.remove_sumvar(sv_to_v);
188 let rhs = rhs.remove_sumvar(sv_to_v);
189 LIAVariableConstraint::BinaryGuard(Box::new(lhs), bc, Box::new(rhs))
190 }
191 s => s,
192 }
193 }
194}
195
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, HashMap, HashSet};

    use crate::{
        expressions::{BooleanConnective, Location, Parameter, Variable},
        general_threshold_automaton::{
            Action, UpdateExpression, builder::GeneralThresholdAutomatonBuilder,
        },
        lia_threshold_automaton::{
            LIARule, LIAThresholdAutomaton, LIAVariableConstraint, SingleAtomConstraint,
            SumAtomConstraint,
            integer_thresholds::{
                ThresholdCompOp, ThresholdConstraint, ThresholdConstraintOver, WeightedSum,
            },
        },
    };

    /// Name of the variable that replaces the sum `var1 + 2 * var2`.
    const SV_NAME: &str = "sv_var1_var2";

    /// Action applying `update` to the variable called `name`.
    fn act(name: &str, update: UpdateExpression) -> Action {
        Action::new_with_update(Variable::new(name), update)
    }

    /// The weighted sum `var1 + 2 * var2` eliminated by most tests.
    fn var1_plus_2_var2() -> WeightedSum<Variable> {
        WeightedSum::new(BTreeMap::from([
            (Variable::new("var1"), 1),
            (Variable::new("var2"), 2),
        ]))
    }

    /// Translation table mapping `var1 + 2 * var2` to `sv_var1_var2`.
    fn sum_to_sv() -> HashMap<WeightedSum<Variable>, Variable> {
        HashMap::from([(var1_plus_2_var2(), Variable::new(SV_NAME))])
    }

    /// Guard `var1 + 2 * var2 < 3 * n + 5`.
    fn sum_guard() -> LIAVariableConstraint {
        LIAVariableConstraint::SumVarConstraint(SumAtomConstraint(ThresholdConstraintOver::new(
            var1_plus_2_var2(),
            ThresholdConstraint::new(
                ThresholdCompOp::Lt,
                BTreeMap::from([(Parameter::new("n"), 3)]),
                5,
            ),
        )))
    }

    /// Guard `sv_var1_var2 < 3 * n + 5`, i.e. `sum_guard()` after the replacement.
    fn replaced_guard() -> LIAVariableConstraint {
        LIAVariableConstraint::SingleVarConstraint(SingleAtomConstraint(
            ThresholdConstraintOver::new(
                Variable::new(SV_NAME),
                ThresholdConstraint::new(
                    ThresholdCompOp::Lt,
                    BTreeMap::from([(Parameter::new("n"), 3)]),
                    5,
                ),
            ),
        ))
    }

    /// Rule with id 0 from `l1` to `l2` carrying the given guard and actions.
    fn rule_l1_l2(guard: LIAVariableConstraint, actions: Vec<Action>) -> LIARule {
        LIARule {
            id: 0,
            source: Location::new("l1"),
            target: Location::new("l2"),
            guard,
            actions,
        }
    }

    /// Runs `compute_acts_to_add` for the sum `var1 + 2 * var2` over `acts`.
    fn acts_added_for(acts: &[Action]) -> HashSet<Action> {
        LIARule::compute_acts_to_add(acts, &var1_plus_2_var2(), &sum_to_sv())
    }

    #[test]
    fn test_remove_sumvar_liata() {
        let ta = GeneralThresholdAutomatonBuilder::new("test_ta")
            .initialize()
            .build();

        let lta = LIAThresholdAutomaton {
            ta: ta.clone(),
            rules: HashMap::from([(
                Location::new("loc1"),
                vec![rule_l1_l2(
                    sum_guard(),
                    vec![act("var1", UpdateExpression::Inc(1))],
                )],
            )]),
            init_variable_constr: vec![],
            additional_vars_for_sums: HashMap::new(),
        };
        let got_lta = lta.into_ta_without_sum_vars();

        let expected_ta = LIAThresholdAutomaton {
            ta,
            rules: HashMap::from([(
                Location::new("loc1"),
                vec![rule_l1_l2(
                    replaced_guard(),
                    vec![
                        act("var1", UpdateExpression::Inc(1)),
                        act(SV_NAME, UpdateExpression::Inc(1)),
                    ],
                )],
            )]),
            init_variable_constr: vec![LIAVariableConstraint::SingleVarConstraint(
                SingleAtomConstraint::new(
                    Variable::new(SV_NAME),
                    ThresholdConstraint::new(ThresholdCompOp::Lt, WeightedSum::new_empty(), 1),
                ),
            )],
            additional_vars_for_sums: HashMap::from([(
                Variable::new(SV_NAME),
                var1_plus_2_var2(),
            )]),
        };

        assert_eq!(
            got_lta, expected_ta,
            "got: {got_lta}\n expected:{expected_ta}"
        )
    }

    #[test]
    fn test_remove_sumvar_liarule() {
        let input = rule_l1_l2(sum_guard(), vec![act("var1", UpdateExpression::Inc(1))]);

        let got_rule = input.remove_sumvar(&sum_to_sv());

        let expected_rule = rule_l1_l2(
            replaced_guard(),
            vec![
                act("var1", UpdateExpression::Inc(1)),
                act(SV_NAME, UpdateExpression::Inc(1)),
            ],
        );
        assert_eq!(got_rule, expected_rule)
    }

    #[test]
    fn test_remove_sumvar_liarule_bug_c1cs() {
        let original_acts = vec![
            act("nfaulty", UpdateExpression::Unchanged),
            act("nsnt1F", UpdateExpression::Unchanged),
            act("nsnt1", UpdateExpression::Unchanged),
            act("nsnt0", UpdateExpression::Inc(1)),
            act("nsnt0F", UpdateExpression::Unchanged),
        ];
        let input = rule_l1_l2(LIAVariableConstraint::True, original_acts.clone());

        let sv_to_v = HashMap::from([(
            WeightedSum::new(BTreeMap::from([
                (Variable::new("nsnt0F"), 1),
                (Variable::new("nsnt0"), 1),
            ])),
            Variable::new("sv_nsnt0_nsnt0F"),
        )]);

        let got_rule = input.remove_sumvar(&sv_to_v);

        let mut expected_acts = original_acts;
        expected_acts.push(act("sv_nsnt0_nsnt0F", UpdateExpression::Inc(1)));
        let expected_rule = rule_l1_l2(LIAVariableConstraint::True, expected_acts);

        assert_eq!(got_rule, expected_rule)
    }

    #[test]
    fn test_compute_acts_to_add_inc() {
        let got_acts = acts_added_for(&[act("var1", UpdateExpression::Inc(1))]);
        let expected_acts = HashSet::from([act(SV_NAME, UpdateExpression::Inc(1))]);
        assert_eq!(got_acts, expected_acts)
    }

    #[test]
    fn test_compute_acts_to_add_dec() {
        let got_acts = acts_added_for(&[act("var1", UpdateExpression::Dec(2))]);
        let expected_acts = HashSet::from([act(SV_NAME, UpdateExpression::Dec(2))]);
        assert_eq!(got_acts, expected_acts)
    }

    #[test]
    fn test_compute_acts_to_add_none() {
        let got_acts = acts_added_for(&[act("var1", UpdateExpression::Unchanged)]);
        assert_eq!(got_acts, HashSet::new())
    }

    #[test]
    fn test_compute_acts_to_add_diff_var() {
        let got_acts = acts_added_for(&[act("var3", UpdateExpression::Inc(1))]);
        assert_eq!(got_acts, HashSet::new())
    }

    #[test]
    fn test_compute_acts_to_add_two_add() {
        let got_acts = acts_added_for(&[
            act("var1", UpdateExpression::Inc(1)),
            act("var2", UpdateExpression::Inc(1)),
        ]);
        // 1 * var1 + 2 * var2, both incremented by one => +3 on the sum.
        let expected_acts = HashSet::from([act(SV_NAME, UpdateExpression::Inc(3))]);
        assert_eq!(got_acts, expected_acts)
    }

    #[test]
    fn test_remove_sumvar_var_constr() {
        // Constraints without sums pass through untouched.
        assert_eq!(
            LIAVariableConstraint::False.remove_sumvar(&HashMap::new()),
            LIAVariableConstraint::False
        );

        // A bare sum atom becomes an atom over the replacement variable.
        assert_eq!(sum_guard().remove_sumvar(&sum_to_sv()), replaced_guard());

        // Sum atoms nested in a binary guard are replaced as well.
        let conjunction = LIAVariableConstraint::BinaryGuard(
            Box::new(sum_guard()),
            BooleanConnective::And,
            Box::new(LIAVariableConstraint::True),
        );
        let expected = LIAVariableConstraint::BinaryGuard(
            Box::new(replaced_guard()),
            BooleanConnective::And,
            Box::new(LIAVariableConstraint::True),
        );
        assert_eq!(conjunction.remove_sumvar(&sum_to_sv()), expected);
    }
}