fix partition, also fixes select

This commit is contained in:
jacekpoz 2024-05-06 11:22:57 +02:00
parent ee93f577e1
commit 6e52722bdb
Signed by: poz
SSH key fingerprint: SHA256:JyLeVWE4bF3tDnFeUpUaJsPsNlJyBldDGV/dIKSLyN8
2 changed files with 61 additions and 76 deletions

View file

@ -1,85 +1,53 @@
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
fn partition_diff(arr: &mut [u64], lo: usize, hi: usize, pivot: usize) -> usize { pub fn insertion_sort_mut(list: &mut [u64]) {
let mut i = lo as i64 - 1; for i in 1..list.len() {
for j in lo..hi { let mut j = i as usize;
if arr[j] == arr[pivot] {
i = j as i64; while j > 0 && list[j - 1] > list[j] {
list.swap(j - 1, j);
j -= 1;
}
}
}
pub fn insertion_sort(list: &[u64]) -> Vec<u64> {
let mut clone = Vec::from(list);
insertion_sort_mut(&mut clone);
clone
}
fn partition(arr: &mut [u64], lo: usize, hi: usize, pivot: u64) -> usize {
let mut pivot_index = lo;
for i in lo..=hi {
if arr[i] == pivot {
pivot_index = i;
break; break;
} }
} }
if i != lo as i64 - 1 { if pivot_index != lo {
arr.swap(hi - 1, i as usize); arr.swap(pivot_index, lo);
pivot_index = lo;
} }
let mut j = lo;
let mut k = lo; for i in (lo + 1)..=hi {
if arr[pivot_index] > arr[i] {
for j in lo..(hi - 1) { j += 1;
if arr[j] < arr[pivot] { arr.swap(j, i);
arr.swap(j, k);
k += 1;
} }
} }
arr.swap(j, lo);
arr.swap(hi - 1, k); j
k
} }
// pub fn lomuto_partition(arr: &mut [u64], lo: usize, hi: usize, pivot: usize) -> usize {
// println!("called lomuto_partition({:?}, {}, {}, {})", arr, lo, hi, pivot);
// let mut swap = 0;
//
// for i in lo..pivot {
// if arr[i] < arr[pivot] {
// if swap != i {
// arr.swap(swap, i);
// }
// swap += 1;
// }
// }
//
// if swap != pivot {
// arr.swap(swap, pivot);
// }
//
// swap
// }
// pub fn partition(arr: &mut [u64], lo: usize, hi: usize, pivot: usize) -> usize {
// println!("called partition({:?}, {}, {}, {})", arr, lo, hi, pivot);
// let mut i = lo;
// let mut j = hi - 1;
//
// loop {
// while arr[i] < arr[pivot] {
// println!("partition i: {i}");
// i += 1;
// }
// while j > 0 && arr[j] > arr[pivot] {
// j -= 1;
// }
// if j == 0 || i >= j {
// break;
// } else if arr[i] == arr[j] {
// i += 1;
// j -= 1;
// } else {
// arr.swap(i, j);
// }
// }
// arr.swap(i, pivot);
// i
// }
pub fn rand_partition(arr: &mut [u64], lo: usize, hi: usize) -> usize { pub fn rand_partition(arr: &mut [u64], lo: usize, hi: usize) -> usize {
let pivot = thread_rng().gen_range(lo..hi); let pivot = thread_rng().gen_range(lo..hi);
partition_diff(arr, lo, hi, pivot) partition(arr, lo, hi, arr[pivot])
} }
fn _rand_select(arr: &mut [u64], lo: usize, hi: usize, k: usize) -> (usize, u64) { fn _rand_select(arr: &mut [u64], lo: usize, hi: usize, k: usize) -> (usize, u64) {
// println!("called _rand_select({:?}, {}, {}, {})", arr, lo, hi, k); if lo == hi {
if lo + 1 == hi {
return (lo, arr[lo]); return (lo, arr[lo]);
} }
let r = rand_partition(arr, lo, hi); let r = rand_partition(arr, lo, hi);
@ -98,22 +66,25 @@ pub fn rand_select(arr: &mut [u64], k: usize) -> (usize, u64) {
} }
fn _select(arr: &mut [u64], lo: usize, hi: usize, k: usize) -> (usize, u64) { fn _select(arr: &mut [u64], lo: usize, hi: usize, k: usize) -> (usize, u64) {
// println!("called _select({:?}, {}, {}, {})", arr, lo, hi, k); if lo == hi {
if lo + 1 == hi {
return (lo, arr[lo]); return (lo, arr[lo]);
} }
let mut medians = arr.chunks(5) let mut medians: Vec<u64> = vec![];
.map(|chunk| chunk[chunk.len() / 2]) let mut chunks = arr.chunks(5)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for chunk in chunks.iter_mut() {
let chunk = insertion_sort(chunk);
medians.push(chunk[chunk.len() / 2]);
}
let n = hi - lo; let n = hi - lo;
let (pivot, _) = _select(&mut medians, 0, n.div_ceil(5), (n.div_ceil(5) - 1).div_ceil(2)); let (_, pivot) = _select(&mut medians, 0, n.div_ceil(5) - 1, (n.div_ceil(5) - 1).div_ceil(2));
// println!("calling partition from _select"); let r = partition(arr, lo, hi, pivot);
let r = partition_diff(arr, lo, hi, pivot);
// println!("_select r: {r}");
let i = r - lo + 1; let i = r - lo + 1;
if i == k { if i == k {
return (r, arr[r]); return (r, arr[r]);
} else if i < k { } else if i < k {

View file

@ -1,6 +1,20 @@
use libsort::{select, rand_select}; use libsort::{select, rand_select};
fn main() { fn main() {
println!("select: {:?}", select(&mut [3, 1, 9, 20, 3, 5, 31, 8, 29, 34, 2], 5)); let arr = [3, 28, 9, 2, 4, 42, 23, 123, 24, 12, 43, 287, 723, 61];
println!("rand_select: {:?}", rand_select(&mut [3, 1, 9, 20, 3, 5, 31, 8, 29, 34, 2], 5)); print!("select: ");
for pos in 1..=arr.len() {
// println!("input: {:?}", arr);
// println!("positional statistic: {}", pos);
// println!("select: {:?}", select(&mut arr.clone(), pos).1);
// println!("rand_select: {:?}", rand_select(&mut arr.clone(), pos).1);
print!("{} ", select(&mut arr.clone(), pos).1);
}
print!("\n");
print!("rand_select: ");
for pos in 1..=arr.len() {
print!("{} ", rand_select(&mut arr.clone(), pos).1);
}
print!("\n");
} }