From 8c3d30f6f6027b34518c8216f7f18092d05ce387 Mon Sep 17 00:00:00 2001 From: jacekpoz Date: Wed, 8 May 2024 22:00:25 +0200 Subject: [PATCH] fix zad1 --- libselect/src/normal_select.rs | 47 +++++++++++++++++++++--------- libselect/src/randomized_select.rs | 6 ++-- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/libselect/src/normal_select.rs b/libselect/src/normal_select.rs index 11fbfd4..51025c0 100644 --- a/libselect/src/normal_select.rs +++ b/libselect/src/normal_select.rs @@ -25,15 +25,26 @@ impl NormalSelect { } } - pub fn partition(&mut self, list: &mut Vec, lo: usize, hi: usize, mut pivot_index: usize) -> usize { + pub fn partition(&mut self, list: &mut Vec, lo: usize, hi: usize, pivot: u64) -> usize { + use CompareResult::*; + + let mut pivot_index = 0; + + for i in lo..=hi { + if self.compare(pivot, list[i]) == EQUAL { + pivot_index = i; + break; + } + } + if pivot_index != lo { self.swap(list, pivot_index, lo); - pivot_index = lo; } + pivot_index = lo; + let mut j = lo; for i in (lo + 1)..=hi { - use CompareResult::*; - if self.compare(list[pivot_index],list[i]) == GREATER { + if self.compare(list[pivot_index], list[i]) == GREATER { j += 1; self.swap(list, j, i); } @@ -47,15 +58,15 @@ impl NormalSelect { return (lo, list[lo]); } - if self.should_print { println!("list: {:?}; lo: {lo}; hi: {hi}", list); } + if self.should_print { println!("list: {:?}; lo: {lo}; hi: {hi}; ord_stat: {ord_stat}; k: {k}", list); } let mut medians: Vec = vec![]; - let mut chunks = list.chunks(k) + let mut chunks = list[lo..=hi].chunks(k) .collect::>(); for mut chunk in chunks.iter_mut().map(|c| c.to_vec()) { self.insertion.sort_mut(&mut chunk); - medians.push(chunk[chunk.len() / 2]); + medians.push(chunk[(chunk.len() - 1) / 2]); } if self.should_print { println!("medians: {:?}", medians); } @@ -66,17 +77,25 @@ impl NormalSelect { let n = hi - lo; - let (pivot, _) = self._select(&mut medians, 0, n.div_ceil(k) - 1, (n.div_ceil(k) - 1).div_ceil(2), k); - if self.should_print { println!("pivot: {pivot}"); } + let (pivot_index, pivot) = self._select(&mut medians, 0, n.div_ceil(k) - 1, (n.div_ceil(k) - 1).div_ceil(2), k); + if self.should_print { + println!("pivot_index: {pivot_index}; pivot: {pivot}"); + } let r = self.partition(list, lo, hi, pivot); - if self.should_print { println!("r: {r}"); } - let i = r - lo + 1; + let k = r - lo + 1; + if self.should_print { + println!("r: {r}"); + println!("k: {k}"); + } - if i == ord_stat { + if k == ord_stat { + if self.should_print { println!("k == ord_stat"); } return (r, list[r]); - } else if i < ord_stat { - return self._select(list, r + 1, hi, ord_stat - i, k); + } else if k < ord_stat { + if self.should_print { println!("k < ord_stat"); } + return self._select(list, r + 1, hi, ord_stat - k, k); } else { + if self.should_print { println!("k > ord_stat"); } return self._select(list, lo, r - 1, ord_stat, k); } } diff --git a/libselect/src/randomized_select.rs b/libselect/src/randomized_select.rs index 3c032c8..0a6c17c 100644 --- a/libselect/src/randomized_select.rs +++ b/libselect/src/randomized_select.rs @@ -12,9 +12,9 @@ pub struct RandomizedSelect { impl RandomizedSelect { fn rand_partition(&mut self, list: &mut Vec, lo: usize, hi: usize) -> usize { - let pivot = thread_rng().gen_range(lo..hi); - if self.should_print { println!("pivot: {pivot}"); } - self.normal.partition(list, lo, hi, pivot) + let pivot_index = thread_rng().gen_range(lo..hi); + if self.should_print { println!("pivot_index: {pivot_index}"); } + self.normal.partition(list, lo, hi, list[pivot_index]) } fn _select(&mut self, list: &mut Vec, lo: usize, hi: usize, k: usize) -> (usize, u64) {