diff --git a/libselect/src/normal_select.rs b/libselect/src/normal_select.rs index 3378844..11fbfd4 100644 --- a/libselect/src/normal_select.rs +++ b/libselect/src/normal_select.rs @@ -42,7 +42,7 @@ impl NormalSelect { j } - fn _select(&mut self, list: &mut Vec, lo: usize, hi: usize, k: usize) -> (usize, u64) { + fn _select(&mut self, list: &mut Vec, lo: usize, hi: usize, ord_stat: usize, k: usize) -> (usize, u64) { if lo == hi { return (lo, list[lo]); } @@ -50,7 +50,7 @@ impl NormalSelect { if self.should_print { println!("list: {:?}; lo: {lo}; hi: {hi}", list); } let mut medians: Vec = vec![]; - let mut chunks = list.chunks(5) + let mut chunks = list.chunks(k) .collect::>(); for mut chunk in chunks.iter_mut().map(|c| c.to_vec()) { @@ -66,20 +66,24 @@ impl NormalSelect { let n = hi - lo; - let (pivot, _) = self._select(&mut medians, 0, n.div_ceil(5) - 1, (n.div_ceil(5) - 1).div_ceil(2)); + let (pivot, _) = self._select(&mut medians, 0, n.div_ceil(k) - 1, (n.div_ceil(k) - 1).div_ceil(2), k); if self.should_print { println!("pivot: {pivot}"); } let r = self.partition(list, lo, hi, pivot); if self.should_print { println!("r: {r}"); } let i = r - lo + 1; - if i == k { + if i == ord_stat { return (r, list[r]); - } else if i < k { - return self._select(list, r + 1, hi, k - i); + } else if i < ord_stat { + return self._select(list, r + 1, hi, ord_stat - i, k); } else { - return self._select(list, lo, r - 1, k); + return self._select(list, lo, r - 1, ord_stat, k); } } + + pub fn select_k(&mut self, list: &mut Vec, order_statistic: usize, k: usize) -> u64 { + self._select(list, 0, list.len() - 1, order_statistic, k).1 + } } impl Select for NormalSelect { @@ -93,7 +97,7 @@ impl Select for NormalSelect { } fn select_mut(&mut self, list: &mut Vec, order_statistic: usize) -> u64 { - self._select(list, 0, list.len() - 1, order_statistic).1 + self.select_k(list, order_statistic, 5) } fn num_comp(&self) -> u64 {