This commit is contained in:
jacekpoz 2024-05-08 22:00:25 +02:00
parent 49125ee428
commit 8c3d30f6f6
Signed by: poz
SSH key fingerprint: SHA256:JyLeVWE4bF3tDnFeUpUaJsPsNlJyBldDGV/dIKSLyN8
2 changed files with 36 additions and 17 deletions

View file

@ -25,14 +25,25 @@ impl NormalSelect {
} }
} }
pub fn partition(&mut self, list: &mut Vec<u64>, lo: usize, hi: usize, mut pivot_index: usize) -> usize { pub fn partition(&mut self, list: &mut Vec<u64>, lo: usize, hi: usize, pivot: u64) -> usize {
use CompareResult::*;
let mut pivot_index = 0;
for i in lo..=hi {
if self.compare(pivot, list[i]) == EQUAL {
pivot_index = i;
break;
}
}
if pivot_index != lo { if pivot_index != lo {
self.swap(list, pivot_index, lo); self.swap(list, pivot_index, lo);
pivot_index = lo;
} }
pivot_index = lo;
let mut j = lo; let mut j = lo;
for i in (lo + 1)..=hi { for i in (lo + 1)..=hi {
use CompareResult::*;
if self.compare(list[pivot_index], list[i]) == GREATER { if self.compare(list[pivot_index], list[i]) == GREATER {
j += 1; j += 1;
self.swap(list, j, i); self.swap(list, j, i);
@ -47,15 +58,15 @@ impl NormalSelect {
return (lo, list[lo]); return (lo, list[lo]);
} }
if self.should_print { println!("list: {:?}; lo: {lo}; hi: {hi}", list); } if self.should_print { println!("list: {:?}; lo: {lo}; hi: {hi}; ord_stat: {ord_stat}; k: {k}", list); }
let mut medians: Vec<u64> = vec![]; let mut medians: Vec<u64> = vec![];
let mut chunks = list.chunks(k) let mut chunks = list[lo..=hi].chunks(k)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for mut chunk in chunks.iter_mut().map(|c| c.to_vec()) { for mut chunk in chunks.iter_mut().map(|c| c.to_vec()) {
self.insertion.sort_mut(&mut chunk); self.insertion.sort_mut(&mut chunk);
medians.push(chunk[chunk.len() / 2]); medians.push(chunk[(chunk.len() - 1) / 2]);
} }
if self.should_print { println!("medians: {:?}", medians); } if self.should_print { println!("medians: {:?}", medians); }
@ -66,17 +77,25 @@ impl NormalSelect {
let n = hi - lo; let n = hi - lo;
let (pivot, _) = self._select(&mut medians, 0, n.div_ceil(k) - 1, (n.div_ceil(k) - 1).div_ceil(2), k); let (pivot_index, pivot) = self._select(&mut medians, 0, n.div_ceil(k) - 1, (n.div_ceil(k) - 1).div_ceil(2), k);
if self.should_print { println!("pivot: {pivot}"); } if self.should_print {
println!("pivot_index: {pivot_index}; pivot: {pivot}");
}
let r = self.partition(list, lo, hi, pivot); let r = self.partition(list, lo, hi, pivot);
if self.should_print { println!("r: {r}"); } let k = r - lo + 1;
let i = r - lo + 1; if self.should_print {
println!("r: {r}");
println!("k: {k}");
}
if i == ord_stat { if k == ord_stat {
if self.should_print { println!("k == ord_stat"); }
return (r, list[r]); return (r, list[r]);
} else if i < ord_stat { } else if k < ord_stat {
return self._select(list, r + 1, hi, ord_stat - i, k); if self.should_print { println!("k < ord_stat"); }
return self._select(list, r + 1, hi, ord_stat - k, k);
} else { } else {
if self.should_print { println!("k > ord_stat"); }
return self._select(list, lo, r - 1, ord_stat, k); return self._select(list, lo, r - 1, ord_stat, k);
} }
} }

View file

@ -12,9 +12,9 @@ pub struct RandomizedSelect {
impl RandomizedSelect { impl RandomizedSelect {
fn rand_partition(&mut self, list: &mut Vec<u64>, lo: usize, hi: usize) -> usize { fn rand_partition(&mut self, list: &mut Vec<u64>, lo: usize, hi: usize) -> usize {
let pivot = thread_rng().gen_range(lo..hi); let pivot_index = thread_rng().gen_range(lo..hi);
if self.should_print { println!("pivot: {pivot}"); } if self.should_print { println!("pivot_index: {pivot_index}"); }
self.normal.partition(list, lo, hi, pivot) self.normal.partition(list, lo, hi, list[pivot_index])
} }
fn _select(&mut self, list: &mut Vec<u64>, lo: usize, hi: usize, k: usize) -> (usize, u64) { fn _select(&mut self, list: &mut Vec<u64>, lo: usize, hi: usize, k: usize) -> (usize, u64) {