argumentation_weighted/
framework.rs1use crate::error::Error;
4use crate::types::{AttackWeight, WeightedAttack};
5use std::collections::HashSet;
6use std::hash::Hash;
7
8#[derive(Debug, Clone)]
19pub struct WeightedFramework<A: Clone + Eq + Hash> {
20 arguments: HashSet<A>,
21 attacks: Vec<WeightedAttack<A>>,
22}
23
24impl<A: Clone + Eq + Hash> WeightedFramework<A> {
25 #[must_use]
27 pub fn new() -> Self {
28 Self {
29 arguments: HashSet::new(),
30 attacks: Vec::new(),
31 }
32 }
33
34 pub fn add_argument(&mut self, a: A) {
36 self.arguments.insert(a);
37 }
38
39 pub fn add_weighted_attack(
44 &mut self,
45 attacker: A,
46 target: A,
47 weight: f64,
48 ) -> Result<(), Error> {
49 let w = AttackWeight::new(weight)?;
50 self.arguments.insert(attacker.clone());
51 self.arguments.insert(target.clone());
52 self.attacks.push(WeightedAttack {
53 attacker,
54 target,
55 weight: w,
56 });
57 Ok(())
58 }
59
60 pub fn collapse_duplicate_attacks(&mut self) -> Result<(), Error> {
70 use std::collections::HashMap;
71 let mut map: HashMap<(A, A), f64> = HashMap::new();
72 for atk in self.attacks.drain(..) {
73 let key = (atk.attacker, atk.target);
74 *map.entry(key).or_insert(0.0) += atk.weight.value();
75 }
76 let mut new_attacks = Vec::with_capacity(map.len());
77 for ((attacker, target), weight) in map {
78 let w = AttackWeight::new(weight)?;
79 new_attacks.push(WeightedAttack { attacker, target, weight: w });
80 }
81 self.attacks = new_attacks;
82 Ok(())
83 }
84
85 pub fn arguments(&self) -> impl Iterator<Item = &A> {
87 self.arguments.iter()
88 }
89
90 pub fn attacks(&self) -> impl Iterator<Item = &WeightedAttack<A>> {
92 self.attacks.iter()
93 }
94
95 #[must_use]
97 pub fn len(&self) -> usize {
98 self.arguments.len()
99 }
100
101 #[must_use]
103 pub fn is_empty(&self) -> bool {
104 self.arguments.is_empty()
105 }
106
107 #[must_use]
109 pub fn attack_count(&self) -> usize {
110 self.attacks.len()
111 }
112
113 #[must_use]
117 pub fn sorted_weights(&self) -> Vec<f64> {
118 let mut ws: Vec<f64> = self.attacks.iter().map(|a| a.weight.value()).collect();
119 ws.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
120 ws
121 }
122}
123
124impl<A: Clone + Eq + Hash> Default for WeightedFramework<A> {
125 fn default() -> Self {
126 Self::new()
127 }
128}
129
130const _: fn() = || {
132 fn assert_send<T: Send>() {}
133 fn assert_sync<T: Sync>() {}
134 assert_send::<WeightedFramework<String>>();
135 assert_sync::<WeightedFramework<String>>();
136};
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a framework from `(attacker, target, weight)` triples, panicking
    /// if any weight is rejected.
    fn framework_from(
        edges: &[(&'static str, &'static str, f64)],
    ) -> WeightedFramework<&'static str> {
        let mut wf = WeightedFramework::new();
        for &(attacker, target, weight) in edges {
            wf.add_weighted_attack(attacker, target, weight).unwrap();
        }
        wf
    }

    #[test]
    fn empty_framework_has_no_arguments() {
        let wf: WeightedFramework<&str> = WeightedFramework::new();
        assert!(wf.is_empty());
        assert_eq!((wf.len(), wf.attack_count()), (0, 0));
    }

    #[test]
    fn add_weighted_attack_registers_both_endpoints() {
        let wf = framework_from(&[("a", "b", 0.5)]);
        assert_eq!(wf.len(), 2);
        assert_eq!(wf.attack_count(), 1);
    }

    #[test]
    fn add_weighted_attack_rejects_invalid_weight() {
        let mut wf: WeightedFramework<&str> = WeightedFramework::new();
        for bad in [-0.1, f64::NAN] {
            assert!(
                wf.add_weighted_attack("a", "b", bad).is_err(),
                "weight {} should be rejected",
                bad
            );
        }
    }

    #[test]
    fn parallel_edges_are_preserved_before_collapse() {
        let wf = framework_from(&[("a", "b", 0.3), ("a", "b", 0.4)]);
        assert_eq!(wf.attack_count(), 2);
    }

    #[test]
    fn collapse_duplicate_attacks_sums_weights() {
        let mut wf = framework_from(&[("a", "b", 0.3), ("a", "b", 0.4), ("a", "c", 0.5)]);
        wf.collapse_duplicate_attacks().unwrap();
        assert_eq!(wf.attack_count(), 2);
        let merged = wf
            .attacks()
            .find(|atk| atk.attacker == "a" && atk.target == "b")
            .expect("merged a -> b attack should exist");
        assert!((merged.weight.value() - 0.7).abs() < 1e-9);
    }

    #[test]
    fn collapse_duplicate_attacks_returns_err_on_weight_overflow() {
        let mut wf = framework_from(&[("a", "b", f64::MAX), ("a", "b", f64::MAX)]);
        let err = wf.collapse_duplicate_attacks().unwrap_err();
        assert!(matches!(err, Error::InvalidWeight { .. }));
    }

    #[test]
    fn sorted_weights_returns_ascending() {
        let wf = framework_from(&[("a", "b", 0.5), ("a", "c", 0.2), ("a", "d", 0.8)]);
        assert_eq!(wf.sorted_weights(), vec![0.2, 0.5, 0.8]);
    }
}