argumentation_weighted/
weight_source.rs1use crate::error::Error;
12use crate::framework::WeightedFramework;
13use std::hash::Hash;
14
/// A provider of attack weights for `(attacker, target)` pairs of arguments.
///
/// [`populate_from_source`] consults an implementor once per pair and skips
/// any pair for which it returns `None`.
pub trait WeightSource<A> {
    /// Returns the weight of the attack from `attacker` on `target`, or
    /// `None` when this source has no weight for that pair.
    fn weight_for(
        &self,
        attacker: &A,
        target: &A,
    ) -> Option<f64>;
}

/// Adapter that turns any callable `Fn(&A, &A) -> Option<f64>` into a
/// weight source.
///
/// The derives are conditional on `F`. Capture-free closures and fn pointers
/// become `Copy`/`Clone`, and `Debug` is available whenever `F: Debug`
/// (for example, fn pointers).
#[derive(Clone, Copy, Debug)]
pub struct ClosureWeightSource<F>(pub F);

33impl<A, F> WeightSource<A> for ClosureWeightSource<F>
34where
35 F: Fn(&A, &A) -> Option<f64>,
36{
37 fn weight_for(&self, attacker: &A, target: &A) -> Option<f64> {
38 (self.0)(attacker, target)
39 }
40}
41
42pub fn populate_from_source<A, W, I>(
51 framework: &mut WeightedFramework<A>,
52 pairs: I,
53 source: &W,
54) -> Result<(), Error>
55where
56 A: Clone + Eq + Hash,
57 W: WeightSource<A>,
58 I: IntoIterator<Item = (A, A)>,
59{
60 for (attacker, target) in pairs {
61 if let Some(weight) = source.weight_for(&attacker, &target) {
62 framework.add_weighted_attack(attacker, target, weight)?;
63 }
64 }
65 Ok(())
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// Test source that assigns the same weight to every pair.
    struct FixedSource(f64);

    impl WeightSource<&'static str> for FixedSource {
        fn weight_for(&self, _attacker: &&'static str, _target: &&'static str) -> Option<f64> {
            Some(self.0)
        }
    }

    #[test]
    fn closure_weight_source_returns_closure_output() {
        let src = ClosureWeightSource(|_a: &&str, _b: &&str| Some(0.42));
        assert_eq!(src.weight_for(&"x", &"y"), Some(0.42));
    }

    #[test]
    fn populate_from_source_adds_all_attacks() {
        let mut wf: WeightedFramework<&str> = WeightedFramework::new();
        let src = FixedSource(0.5);
        populate_from_source(&mut wf, vec![("a", "b"), ("c", "d")], &src).unwrap();
        assert_eq!(wf.attack_count(), 2);
    }

    #[test]
    fn populate_with_no_pairs_adds_nothing() {
        let mut wf: WeightedFramework<&str> = WeightedFramework::new();
        let src = FixedSource(0.5);
        populate_from_source(&mut wf, Vec::new(), &src).unwrap();
        assert_eq!(wf.attack_count(), 0);
    }

    #[test]
    fn populate_skips_none_weights() {
        let src = ClosureWeightSource(
            |_a: &&str, target: &&str| if *target == "b" { Some(0.5) } else { None },
        );
        let mut wf: WeightedFramework<&str> = WeightedFramework::new();
        populate_from_source(&mut wf, vec![("x", "b"), ("x", "c")], &src).unwrap();
        assert_eq!(wf.attack_count(), 1);
    }

    #[test]
    fn populate_propagates_invalid_weights() {
        let src = ClosureWeightSource(|_a: &&str, _b: &&str| Some(-1.0));
        let mut wf: WeightedFramework<&str> = WeightedFramework::new();
        let err = populate_from_source(&mut wf, vec![("x", "y")], &src).unwrap_err();
        assert!(matches!(err, Error::InvalidWeight { .. }));
    }

    #[test]
    fn populate_stops_consulting_source_after_first_error() {
        // Counts how many pairs the source is asked about.
        let calls = Cell::new(0usize);
        let src = ClosureWeightSource(|_a: &&str, target: &&str| {
            calls.set(calls.get() + 1);
            if *target == "bad" { Some(-1.0) } else { Some(0.5) }
        });
        let mut wf: WeightedFramework<&str> = WeightedFramework::new();
        let result = populate_from_source(
            &mut wf,
            vec![("a", "b"), ("x", "bad"), ("c", "d")],
            &src,
        );
        assert!(result.is_err());
        assert_eq!(calls.get(), 2);
    }
}