1use core::fmt;
9use std::{collections::HashMap, hash::Hash};
10
/// Outcome of comparing two values under a (possibly partial) order.
///
/// Unlike [`std::cmp::Ordering`], a comparison may report that the two values are
/// [`Incomparable`](PartialOrdCompResult::Incomparable). The type is `Copy`, `Eq` and
/// `Hash`, so results can be stored in sets or used as map keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(i8)]
pub enum PartialOrdCompResult {
    /// Neither value is smaller than, equal to, or greater than the other.
    Incomparable,
    /// The left-hand value is strictly smaller than the right-hand value.
    Smaller,
    /// Both values compare equal.
    Equal,
    /// The left-hand value is strictly greater than the right-hand value.
    Greater,
}
24
25impl PartialOrdCompResult {
26 #[inline(always)]
27 pub fn combine(self, other: Self) -> Self {
28 match (self, other) {
29 (PartialOrdCompResult::Incomparable, _)
30 | (_, PartialOrdCompResult::Incomparable)
31 | (PartialOrdCompResult::Smaller, PartialOrdCompResult::Greater)
32 | (PartialOrdCompResult::Greater, PartialOrdCompResult::Smaller) => {
33 PartialOrdCompResult::Incomparable
34 }
35
36 (PartialOrdCompResult::Smaller, PartialOrdCompResult::Smaller)
37 | (PartialOrdCompResult::Smaller, PartialOrdCompResult::Equal)
38 | (PartialOrdCompResult::Equal, PartialOrdCompResult::Smaller) => {
39 PartialOrdCompResult::Smaller
40 }
41 (PartialOrdCompResult::Equal, PartialOrdCompResult::Equal) => {
42 PartialOrdCompResult::Equal
43 }
44 (PartialOrdCompResult::Greater, PartialOrdCompResult::Greater)
45 | (PartialOrdCompResult::Equal, PartialOrdCompResult::Greater)
46 | (PartialOrdCompResult::Greater, PartialOrdCompResult::Equal) => {
47 PartialOrdCompResult::Greater
48 }
49 }
50 }
51}
52
53impl fmt::Display for PartialOrdCompResult {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 match self {
56 PartialOrdCompResult::Incomparable => write!(f, "≠"),
57 PartialOrdCompResult::Smaller => write!(f, "<"),
58 PartialOrdCompResult::Equal => write!(f, "=="),
59 PartialOrdCompResult::Greater => write!(f, ">"),
60 }
61 }
62}
63
64pub trait PartialOrder {
69 fn part_cmp(&self, other: &Self) -> PartialOrdCompResult;
70
71 fn is_greater_or_equal(&self, other: &Self) -> bool {
73 matches!(
74 self.part_cmp(other),
75 PartialOrdCompResult::Equal | PartialOrdCompResult::Greater
76 )
77 }
78
79 fn is_smaller_or_equal(&self, other: &Self) -> bool {
81 matches!(
82 self.part_cmp(other),
83 PartialOrdCompResult::Equal | PartialOrdCompResult::Smaller
84 )
85 }
86}
87
/// Implements [`PartialOrder`] for each listed type by delegating to its
/// `PartialOrd::partial_cmp`.
///
/// `Less`/`Equal`/`Greater` map to `Smaller`/`Equal`/`Greater`, and a `None`
/// from `partial_cmp` maps to `Incomparable`.
macro_rules! impl_partial_order {
    ( $( $ty:ty )* ) => {
        $(
            impl PartialOrder for $ty {
                fn part_cmp(&self, other: &Self) -> PartialOrdCompResult {
                    use std::cmp::Ordering;

                    self.partial_cmp(other).map_or(
                        PartialOrdCompResult::Incomparable,
                        |ordering| match ordering {
                            Ordering::Less => PartialOrdCompResult::Smaller,
                            Ordering::Equal => PartialOrdCompResult::Equal,
                            Ordering::Greater => PartialOrdCompResult::Greater,
                        },
                    )
                }
            }
        )*
    };
}
106
107impl_partial_order!(usize i8 u8 i16 u16 i32 u32 i64 u64 i128 u128);
108
109impl<T: PartialOrder> PartialOrder for Vec<T> {
110 fn part_cmp(&self, other: &Self) -> PartialOrdCompResult {
111 if self.len() != other.len() {
113 return PartialOrdCompResult::Incomparable;
114 }
115
116 let mut res = PartialOrdCompResult::Equal;
117 for (self_item, other_item) in self.iter().zip(other.iter()) {
118 let cmp = self_item.part_cmp(other_item);
119 res = res.combine(cmp);
120
121 if res == PartialOrdCompResult::Incomparable {
122 break;
123 }
124 }
125
126 res
127 }
128}
129
130impl<V: PartialOrder, K: Hash + Eq> PartialOrder for HashMap<K, V> {
131 fn part_cmp(&self, other: &Self) -> PartialOrdCompResult {
132 if self.len() != other.len() {
134 return PartialOrdCompResult::Incomparable;
135 }
136
137 let mut res = PartialOrdCompResult::Equal;
138 for (self_key, self_val) in self.iter() {
139 let other_val = other.get(self_key);
140 if other_val.is_none() {
141 return PartialOrdCompResult::Incomparable;
142 }
143
144 let cmp = self_val.part_cmp(other_val.unwrap());
145 res = res.combine(cmp);
146
147 if res == PartialOrdCompResult::Incomparable {
148 break;
149 }
150 }
151
152 res
153 }
154}
155
156#[derive(Debug, Clone, PartialEq)]
163pub struct SetMinimalBasis<T: PartialOrder + PartialEq + Hash + Eq> {
164 configs: Vec<T>,
165}
166
167impl<T: PartialOrder + PartialEq + Hash + Eq> SetMinimalBasis<T> {
168 pub fn new<V: Into<Vec<T>>>(configs: V) -> Self {
173 let configs = configs.into();
174 if configs.len() <= 1 {
175 return Self { configs };
176 }
177
178 let mut reduced_set = Vec::new();
179 for cfg in configs.into_iter() {
180 if reduced_set
182 .iter()
183 .any(|existing_cfg: &T| cfg.is_greater_or_equal(existing_cfg))
184 {
185 continue;
186 }
187
188 reduced_set.retain(|existing_cf| !existing_cf.is_greater_or_equal(&cfg));
190
191 reduced_set.push(cfg);
192 }
193
194 Self {
195 configs: reduced_set,
196 }
197 }
198
199 pub fn retain<F>(&mut self, f: F)
201 where
202 F: FnMut(&T) -> bool,
203 {
204 self.configs.retain(f);
205 }
206}
207
208impl<T: PartialOrder + PartialEq + Hash + Eq, I: Iterator<Item = T>> From<I>
209 for SetMinimalBasis<T>
210{
211 fn from(value: I) -> Self {
212 SetMinimalBasis::new(value.into_iter().collect::<Vec<_>>())
213 }
214}
215
216impl<T: PartialOrder + PartialEq + Hash + Eq> IntoIterator for SetMinimalBasis<T> {
217 type Item = T;
218
219 type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
220
221 fn into_iter(self) -> Self::IntoIter {
222 self.configs.into_iter()
223 }
224}
225
226impl<'a, T: PartialOrder + PartialEq + Hash + Eq> IntoIterator for &'a SetMinimalBasis<T> {
227 type Item = &'a T;
228
229 type IntoIter = std::slice::Iter<'a, T>;
230
231 fn into_iter(self) -> Self::IntoIter {
232 self.configs.iter()
233 }
234}
235
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::partial_ord::{
        PartialOrdCompResult::{self, Equal, Greater, Incomparable, Smaller},
        PartialOrder, SetMinimalBasis,
    };

    /// Asserts the full relation between `a` and `b`.
    ///
    /// Checks `part_cmp` in both directions, then the derived predicates:
    /// `is_greater_or_equal` must hold exactly for `Equal | Greater`, and
    /// `is_smaller_or_equal` exactly for `Equal | Smaller`.
    fn assert_relation<T: PartialOrder>(
        a: &T,
        b: &T,
        a_vs_b: PartialOrdCompResult,
        b_vs_a: PartialOrdCompResult,
    ) {
        assert_eq!(a.part_cmp(b), a_vs_b);
        assert_eq!(b.part_cmp(a), b_vs_a);

        assert_eq!(a.is_greater_or_equal(b), matches!(a_vs_b, Equal | Greater));
        assert_eq!(b.is_greater_or_equal(a), matches!(b_vs_a, Equal | Greater));

        assert_eq!(a.is_smaller_or_equal(b), matches!(a_vs_b, Equal | Smaller));
        assert_eq!(b.is_smaller_or_equal(a), matches!(b_vs_a, Equal | Smaller));
    }

    #[test]
    fn test_combine() {
        // (left, right, expected); each pair is checked in both argument orders.
        let cases = [
            (Incomparable, Incomparable, Incomparable),
            (Incomparable, Smaller, Incomparable),
            (Incomparable, Equal, Incomparable),
            (Incomparable, Greater, Incomparable),
            (Smaller, Smaller, Smaller),
            (Smaller, Equal, Smaller),
            (Smaller, Greater, Incomparable),
            (Equal, Equal, Equal),
            (Equal, Greater, Greater),
            (Greater, Greater, Greater),
        ];

        for &(left, right, expected) in cases.iter() {
            assert_eq!(left.combine(right), expected, "{} combined with {}", left, right);
            assert_eq!(right.combine(left), expected, "{} combined with {}", right, left);
        }
    }

    #[test]
    fn test_vecs_equal() {
        let a = vec![1, 2, 3, 4, 5];
        let b = vec![1, 2, 3, 4, 5];
        assert_relation(&a, &b, Equal, Equal);
    }

    #[test]
    fn test_vecs_different_size() {
        let a = vec![1, 2];
        let b = vec![1];
        assert_relation(&a, &b, Incomparable, Incomparable);
    }

    #[test]
    fn test_vecs_gt_one() {
        let a = vec![1, 4, 3, 4, 5];
        let b = vec![1, 2, 3, 4, 5];
        assert_relation(&a, &b, Greater, Smaller);
    }

    #[test]
    fn test_vecs_gt_all() {
        let a = vec![2, 4, 6, 8, 10];
        let b = vec![1, 2, 3, 4, 5];
        assert_relation(&a, &b, Greater, Smaller);
    }

    #[test]
    fn test_vecs_one_gt_on_lt() {
        let a = vec![2, 1, 3, 4, 5];
        let b = vec![1, 2, 3, 4, 5];
        assert_relation(&a, &b, Incomparable, Incomparable);
    }

    #[test]
    fn test_hm_equal() {
        let a = HashMap::from([("a", 1), ("b", 2), ("c", 3)]);
        let b = HashMap::from([("a", 1), ("b", 2), ("c", 3)]);
        assert_relation(&a, &b, Equal, Equal);
    }

    #[test]
    fn test_hm_different_size() {
        let a = HashMap::from([("a", 1), ("b", 2), ("c", 3)]);
        let b = HashMap::from([("a", 1), ("b", 2)]);
        assert_relation(&a, &b, Incomparable, Incomparable);
    }

    #[test]
    fn test_hm_gt_one() {
        let a = HashMap::from([("a", 1), ("b", 3), ("c", 3)]);
        let b = HashMap::from([("a", 1), ("b", 2), ("c", 3)]);
        assert_relation(&a, &b, Greater, Smaller);
    }

    #[test]
    fn test_hm_gt_all() {
        let a = HashMap::from([("a", 2), ("b", 4), ("c", 6)]);
        let b = HashMap::from([("a", 1), ("b", 2), ("c", 3)]);
        assert_relation(&a, &b, Greater, Smaller);
    }

    #[test]
    fn test_hm_gt_one_lt_one() {
        let a = HashMap::from([("a", 1), ("b", 3), ("c", 3)]);
        let b = HashMap::from([("a", 1), ("b", 2), ("c", 4)]);
        assert_relation(&a, &b, Incomparable, Incomparable);
    }

    #[test]
    fn test_hm_different_keys() {
        let a = HashMap::from([("a", 1), ("x", 2), ("c", 3)]);
        let b = HashMap::from([("a", 1), ("b", 2), ("c", 3)]);
        assert_relation(&a, &b, Incomparable, Incomparable);
    }

    #[test]
    fn test_display_partial_ord_comp_result() {
        let expectations = [
            (Incomparable, "≠"),
            (Equal, "=="),
            (Greater, ">"),
            (Smaller, "<"),
        ];

        for &(result, symbol) in expectations.iter() {
            assert_eq!(result.to_string(), symbol);
        }
    }

    #[test]
    fn test_new_min_basis() {
        let config_1 = vec![1, 2, 3];
        let config_2 = vec![2, 2, 3];
        let config_3 = vec![4, 2, 1];
        let config_4 = vec![3, 2, 1];

        let inputs = vec![config_1.clone(), config_2, config_3, config_4.clone()];
        let min_basis_set = SetMinimalBasis::from(inputs.into_iter());

        let expected = vec![config_1, config_4];
        assert_eq!(
            min_basis_set,
            SetMinimalBasis {
                configs: expected.clone(),
            }
        );

        let collected: Vec<_> = min_basis_set.into_iter().collect();
        assert_eq!(collected, expected);
        assert_eq!(collected.to_vec(), expected);
    }

    #[test]
    fn test_min_basis_retain() {
        let config_1 = vec![1, 2, 3];
        let config_4 = vec![3, 2, 1];
        let inputs = vec![config_1.clone(), vec![1, 2], vec![5, 4], config_4.clone()];

        let mut min_basis_set = SetMinimalBasis::from(inputs.into_iter());
        min_basis_set.retain(|cfg| cfg.len() == 3);

        assert_eq!(
            min_basis_set,
            SetMinimalBasis {
                configs: vec![config_1, config_4],
            }
        );
    }
}