core/slice/sort/shared/pivot.rs
1//! This module contains the logic for pivot selection.
2
3use crate::{hint, intrinsics};
4
/// Length at or above which the pivot is chosen as a recursive pseudomedian
/// instead of a single median of three.
///
/// Spelled `8 * 8` because reaching it means each of the three sampled
/// regions (each one eighth of the input) holds at least 8 elements. Such a
/// region can itself be split into non-empty eighths for the next level of
/// recursion.
const PSEUDO_MEDIAN_REC_THRESHOLD: usize = 8 * 8;
7
8/// Selects a pivot from `v`. Algorithm taken from glidesort by Orson Peters.
9///
10/// This chooses a pivot by sampling an adaptive amount of points, approximating
11/// the quality of a median of sqrt(n) elements.
12#[inline]
13pub fn choose_pivot<T, F: FnMut(&T, &T) -> bool>(v: &[T], is_less: &mut F) -> usize {
14 // We use unsafe code and raw pointers here because we're dealing with
15 // heavy recursion. Passing safe slices around would involve a lot of
16 // branches and function call overhead.
17
18 let len = v.len();
19 if len < 8 {
20 intrinsics::abort();
21 }
22
23 // SAFETY: a, b, c point to initialized regions of len_div_8 elements,
24 // satisfying median3 and median3_rec's preconditions as v_base points
25 // to an initialized region of n = len elements.
26 let index = unsafe {
27 let v_base = v.as_ptr();
28 let len_div_8 = len / 8;
29
30 let a = v_base; // [0, floor(n/8))
31 let b = v_base.add(len_div_8 * 4); // [4*floor(n/8), 5*floor(n/8))
32 let c = v_base.add(len_div_8 * 7); // [7*floor(n/8), 8*floor(n/8))
33
34 if len < PSEUDO_MEDIAN_REC_THRESHOLD {
35 median3(&*a, &*b, &*c, is_less).offset_from_unsigned(v_base)
36 } else {
37 median3_rec(a, b, c, len_div_8, is_less).offset_from_unsigned(v_base)
38 }
39 };
40 // SAFETY: preconditions must have been met for offset_from_unsigned()
41 unsafe {
42 hint::assert_unchecked(index < v.len());
43 index
44 }
45}
46
47/// Calculates an approximate median of 3 elements from sections a, b, c, or
48/// recursively from an approximation of each, if they're large enough. By
49/// dividing the size of each section by 8 when recursing we have logarithmic
50/// recursion depth and overall sample from f(n) = 3*f(n/8) -> f(n) =
51/// O(n^(log(3)/log(8))) ~= O(n^0.528) elements.
52///
53/// SAFETY: a, b, c must point to the start of initialized regions of memory of
54/// at least n elements.
55unsafe fn median3_rec<T, F: FnMut(&T, &T) -> bool>(
56 mut a: *const T,
57 mut b: *const T,
58 mut c: *const T,
59 n: usize,
60 is_less: &mut F,
61) -> *const T {
62 // SAFETY: a, b, c still point to initialized regions of n / 8 elements,
63 // by the exact same logic as in choose_pivot.
64 unsafe {
65 if n * 8 >= PSEUDO_MEDIAN_REC_THRESHOLD {
66 let n8 = n / 8;
67 a = median3_rec(a, a.add(n8 * 4), a.add(n8 * 7), n8, is_less);
68 b = median3_rec(b, b.add(n8 * 4), b.add(n8 * 7), n8, is_less);
69 c = median3_rec(c, c.add(n8 * 4), c.add(n8 * 7), n8, is_less);
70 }
71 median3(&*a, &*b, &*c, is_less)
72 }
73}
74
/// Returns a pointer to the median of `a`, `b` and `c` according to
/// `is_less`.
///
/// Comparison order:
///
/// - `is_less(a, b)` and then `is_less(a, c)` are always evaluated.
/// - The third comparison, `is_less(b, c)`, runs only when those two
///   outcomes agree, i.e. when `a` is not already known to be the median.
///
/// The returned pointer is always one of the three arguments.
#[inline(always)]
fn median3<T, F: FnMut(&T, &T) -> bool>(a: &T, b: &T, c: &T, is_less: &mut F) -> *const T {
    // The compiler tends to lower this to branchless code when that is
    // profitable, and otherwise skips the third comparison when it is not
    // needed.
    let a_lt_b = is_less(a, b);
    let a_lt_c = is_less(a, c);
    if a_lt_b != a_lt_c {
        // Either b <= a < c or c <= a < b, so `a` is the median.
        a
    } else {
        // `a` is an extreme of the three values:
        // - a_lt_b == true: `a` is below both, so the median is min(b, c).
        // - a_lt_b == false: neither is below `a`, so the median is max(b, c).
        // Flipping the `b < c` outcome by `a_lt_b` selects the right one.
        let b_lt_c = is_less(b, c);
        if b_lt_c != a_lt_b { c } else { b }
    }
}
94}