From 673fbfa0bbccbc6e0aa20162e743ab59169b95df Mon Sep 17 00:00:00 2001 From: Mario Carneiro Date: Tue, 23 Apr 2024 00:05:53 -0400 Subject: [PATCH 1/2] feat: merging functions on List + mergeSort --- Std/Data/List/Basic.lean | 129 +++++++++++++++++++++++++++++++++++--- Std/Data/List/Lemmas.lean | 74 +++++++--------------- 2 files changed, 141 insertions(+), 62 deletions(-) diff --git a/Std/Data/List/Basic.lean b/Std/Data/List/Basic.lean index 74eeb5e87d..2e91678191 100644 --- a/Std/Data/List/Basic.lean +++ b/Std/Data/List/Basic.lean @@ -1620,13 +1620,124 @@ See `isSubperm_iff` for a characterization in terms of `List.Subperm`. def isSubperm [BEq α] (l₁ l₂ : List α) : Bool := ∀ x ∈ l₁, count x l₁ ≤ count x l₂ /-- -`O(|l| + |r|)`. Merge two lists using `s` as a switch. --/ -def merge (s : α → α → Bool) (l r : List α) : List α := - loop l r [] +`O(|xs| + |ys|)`. Merge lists `xs` and `ys`. If the lists are sorted according to `lt`, then the +result is sorted as well. If two (or more) elements are equal according to `lt`, they are preserved. +-/ +def merge (lt : α → α → Bool) : (xs ys : List α) → List α + | [], xs + | xs, [] => xs + | x :: xs, y :: ys => + bif lt x y then x :: merge lt xs (y :: ys) else y :: merge lt (x :: xs) ys + +/-- Tail recursive version of `merge`. -/ +@[inline] def mergeTR (lt : α → α → Bool) (xs ys : List α) : List α := go xs ys [] where + /-- Auxiliary for `mergeTR`: `mergeTR.go xs ys acc = acc.toList ++ merge xs ys`. -/ + go : List α → List α → List α → List α + | [], ys, acc => reverseAux acc ys + | xs, [], acc => reverseAux acc xs + | x::xs, y::ys, acc => bif lt x y then go xs (y::ys) (x::acc) else go (x::xs) ys (y::acc) + +@[csimp] theorem merge_eq_mergeTR : @merge = @mergeTR := by + funext α lt xs ys + let rec go (acc) : ∀ xs ys, @mergeTR.go α lt xs ys acc = reverseAux acc (merge lt xs ys) + | [], _ => by simp [mergeTR.go, merge] + | _::_, [] => by simp [mergeTR.go, merge] + | x::xs, y::ys => by + simp [mergeTR.go, merge, cond]; split + · exact go _ xs (y::ys) + · exact go _ (x::xs) ys + simp [mergeTR, go] + +/-- +`O(|xs| + |ys|)`. Merge lists `xs` and `ys`, which must be sorted according to `compare` and must +not contain duplicates. Equal elements are merged using `merge`. If `merge` respects the order +(i.e. for all `x`, `y`, `y'`, `z`, if `x < y < z` and `x < y' < z` then `x < merge y y' < z`) +then the resulting list is again sorted. +-/ +def mergeDedupWith [Ord α] (merge : α → α → α) : (xs ys : List α) → List α + | [], xs + | xs, [] => xs + | x :: xs, y :: ys => + match compare x y with + | .lt => x :: mergeDedupWith merge xs (y :: ys) + | .gt => y :: mergeDedupWith merge (x :: xs) ys + | .eq => merge x y :: mergeDedupWith merge xs ys + +/-- +`O(|xs| + |ys|)`. Merge lists `xs` and `ys`, which must be sorted according to `compare` and must +not contain duplicates. If an element appears in both `xs` and `ys`, only one copy is kept. +-/ +@[inline] def mergeDedup [Ord α] (xs ys : List α) : List α := mergeDedupWith (fun x _ => x) xs ys + +/-- +`O(|xs| * |ys|)`. Merge `xs` and `ys`, which do not need to be sorted. Elements which occur in +both `xs` and `ys` are only added once. If `xs` and `ys` do not contain duplicates, then neither +does the result. +-/ +def mergeUnsortedDedup [BEq α] (xs ys : List α) : List α := + if xs.length < ys.length then go ys xs else go xs ys +where + /-- Auxiliary definition for `mergeUnsortedDedup`. -/ + go (xs ys : List α) := xs ++ ys.filter fun y => xs.any (· == y) + +/-- Replace each run `[x₁, ⋯, xₙ]` of equal elements in `xs` with `f ⋯ (f (f x₁ x₂) x₃) ⋯ xₙ`. -/ +def mergeAdjacentDups [BEq α] (f : α → α → α) : (xs : List α) → List α + | [] => [] + | x :: xs => go x xs where - /-- Inner loop for `List.merge`. Tail recursive. -/ - loop : List α → List α → List α → List α - | [], r, t => reverseAux t r - | l, [], t => reverseAux t l - | a::l, b::r, t => bif s a b then loop l (b::r) (a::t) else loop (a::l) r (b::t) + /-- Auxiliary definition for `mergeAdjacentDups`. -/ + go (hd : α) + | [] => [hd] + | x :: xs => + if x == hd then + go (f hd x) xs + else + hd :: go x xs + +/-- +`O(|xs|)`. Deduplicate a sorted list. The list must be sorted with to an order which agrees with +`==`, i.e. whenever `x == y` then `compare x y == .eq`. +-/ +def dedupSorted [BEq α] (xs : List α) : List α := + xs.mergeAdjacentDups fun x _ => x + +namespace MergeSort + +/-- `O(|l|)`. Split `l` into two lists of approximately equal length. +``` +split [1, 2, 3, 4, 5] = ([1, 3, 5], [2, 4]) +``` +-/ +@[simp] def split : List α → List α × List α + | [] => ([], []) + | a :: l => + let (l₁, l₂) := split l + (a :: l₂, l₁) + +theorem length_split_le : + ∀ l : List α, length (split l).1 ≤ length l ∧ length (split l).2 ≤ length l + | [] => ⟨Nat.le_refl 0, Nat.le_refl 0⟩ + | _ :: l => + let ⟨h₁, h₂⟩ := length_split_le l + ⟨Nat.succ_le_succ h₂, Nat.le_succ_of_le h₁⟩ + +end MergeSort + +/-- `O(|l| log |l|)`. Uses merge sort to sort a list in ascending order by `lt`. -/ +def mergeSort (lt : α → α → Bool) (l : List α) : List α := + match _e : l with + | [] => [] + | [a] => [a] + | _ :: _ :: _ => + let ls := MergeSort.split l + merge lt (mergeSort lt ls.1) (mergeSort lt ls.2) + termination_by length l + decreasing_by + all_goals subst _e + · exact Nat.add_le_add_right (MergeSort.length_split_le _).1 2 + · exact Nat.add_le_add_right (MergeSort.length_split_le _).2 2 + +/-- `O(|xs| log |xs|)`. Sort and deduplicate a list. -/ +def sortDedup [ord : Ord α] (xs : List α) : List α := + have := ord.toBEq + dedupSorted <| xs.mergeSort (compare · · |>.isLT) diff --git a/Std/Data/List/Lemmas.lean b/Std/Data/List/Lemmas.lean index 514ed09dbc..6bb82ae08a 100644 --- a/Std/Data/List/Lemmas.lean +++ b/Std/Data/List/Lemmas.lean @@ -2684,66 +2684,34 @@ theorem indexOf_mem_indexesOf [BEq α] [LawfulBEq α] {xs : List α} (m : x ∈ specialize ih m simpa -theorem merge_loop_nil_left (s : α → α → Bool) (r t) : - merge.loop s [] r t = reverseAux t r := by - rw [merge.loop] - -theorem merge_loop_nil_right (s : α → α → Bool) (l t) : - merge.loop s l [] t = reverseAux t l := by - cases l <;> rw [merge.loop]; intro; contradiction - -theorem merge_loop (s : α → α → Bool) (l r t) : - merge.loop s l r t = reverseAux t (merge s l r) := by - rw [merge]; generalize hn : l.length + r.length = n - induction n using Nat.recAux generalizing l r t with - | zero => - rw [eq_nil_of_length_eq_zero (Nat.eq_zero_of_add_eq_zero_left hn)] - rw [eq_nil_of_length_eq_zero (Nat.eq_zero_of_add_eq_zero_right hn)] - rfl - | succ n ih => - match l, r with - | [], r => simp only [merge_loop_nil_left]; rfl - | l, [] => simp only [merge_loop_nil_right]; rfl - | a::l, b::r => - simp only [merge.loop, cond] - split - · have hn : l.length + (b :: r).length = n := by - apply Nat.add_right_cancel (m:=1) - rw [←hn]; simp only [length_cons, Nat.add_succ, Nat.succ_add] - rw [ih _ _ (a::t) hn, ih _ _ [] hn, ih _ _ [a] hn]; rfl - · have hn : (a::l).length + r.length = n := by - apply Nat.add_right_cancel (m:=1) - rw [←hn]; simp only [length_cons, Nat.add_succ, Nat.succ_add] - rw [ih _ _ (b::t) hn, ih _ _ [] hn, ih _ _ [b] hn]; rfl - -@[simp] theorem merge_nil (s : α → α → Bool) (l) : merge s l [] = l := merge_loop_nil_right .. +@[simp] theorem merge_nil (lt : α → α → Bool) (l) : merge lt l [] = l := by cases l <;> simp [merge] -@[simp] theorem nil_merge (s : α → α → Bool) (r) : merge s [] r = r := merge_loop_nil_left .. +@[simp] theorem nil_merge (lt : α → α → Bool) (r) : merge lt [] r = r := by simp [merge] -theorem cons_merge_cons (s : α → α → Bool) (a b l r) : - merge s (a::l) (b::r) = if s a b then a :: merge s l (b::r) else b :: merge s (a::l) r := by - simp only [merge, merge.loop, cond]; split <;> (next hs => rw [hs, merge_loop]; rfl) +theorem cons_merge_cons (lt : α → α → Bool) (a b l r) : + merge lt (a::l) (b::r) = if lt a b then a :: merge lt l (b::r) else b :: merge lt (a::l) r := by + simp only [merge, cond_eq_if] -@[simp] theorem cons_merge_cons_pos (s : α → α → Bool) (l r) (h : s a b) : - merge s (a::l) (b::r) = a :: merge s l (b::r) := by +@[simp] theorem cons_merge_cons_pos (lt : α → α → Bool) (l r) (h : lt a b) : + merge lt (a::l) (b::r) = a :: merge lt l (b::r) := by rw [cons_merge_cons, if_pos h] -@[simp] theorem cons_merge_cons_neg (s : α → α → Bool) (l r) (h : ¬ s a b) : - merge s (a::l) (b::r) = b :: merge s (a::l) r := by +@[simp] theorem cons_merge_cons_neg (lt : α → α → Bool) (l r) (h : ¬ lt a b) : + merge lt (a::l) (b::r) = b :: merge lt (a::l) r := by rw [cons_merge_cons, if_neg h] -@[simp] theorem length_merge (s : α → α → Bool) (l r) : - (merge s l r).length = l.length + r.length := by +@[simp] theorem length_merge (lt : α → α → Bool) (l r) : + (merge lt l r).length = l.length + r.length := by match l, r with | [], r => simp | l, [] => simp | a::l, b::r => rw [cons_merge_cons] split - · simp_arith [length_merge s l (b::r)] - · simp_arith [length_merge s (a::l) r] + · simp_arith [length_merge lt l (b::r)] + · simp_arith [length_merge lt (a::l) r] -theorem mem_merge_left (s : α → α → Bool) (h : x ∈ l) : x ∈ merge s l r := by +theorem mem_merge_left (lt : α → α → Bool) (h : x ∈ l) : x ∈ merge lt l r := by match l, r with | l, [] => simp [h] | a::l, b::r => @@ -2752,14 +2720,14 @@ theorem mem_merge_left (s : α → α → Bool) (h : x ∈ l) : x ∈ merge s l rw [cons_merge_cons] split · exact mem_cons_self .. - · apply mem_cons_of_mem; exact mem_merge_left s h + · apply mem_cons_of_mem; exact mem_merge_left lt h | .inr h' => rw [cons_merge_cons] split - · apply mem_cons_of_mem; exact mem_merge_left s h' - · apply mem_cons_of_mem; exact mem_merge_left s h + · apply mem_cons_of_mem; exact mem_merge_left lt h' + · apply mem_cons_of_mem; exact mem_merge_left lt h -theorem mem_merge_right (s : α → α → Bool) (h : x ∈ r) : x ∈ merge s l r := by +theorem mem_merge_right (lt : α → α → Bool) (h : x ∈ r) : x ∈ merge lt l r := by match l, r with | [], r => simp [h] | a::l, b::r => @@ -2767,10 +2735,10 @@ theorem mem_merge_right (s : α → α → Bool) (h : x ∈ r) : x ∈ merge s l | .inl rfl => rw [cons_merge_cons] split - · apply mem_cons_of_mem; exact mem_merge_right s h + · apply mem_cons_of_mem; exact mem_merge_right lt h · exact mem_cons_self .. | .inr h' => rw [cons_merge_cons] split - · apply mem_cons_of_mem; exact mem_merge_right s h - · apply mem_cons_of_mem; exact mem_merge_right s h' + · apply mem_cons_of_mem; exact mem_merge_right lt h + · apply mem_cons_of_mem; exact mem_merge_right lt h' From d63bb438f479134b0467f8dbaf77dd480b00e653 Mon Sep 17 00:00:00 2001 From: Mario Carneiro Date: Wed, 24 Apr 2024 09:32:45 -0700 Subject: [PATCH 2/2] Update Std/Data/List/Basic.lean MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: François G. Dorais --- Std/Data/List/Basic.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Std/Data/List/Basic.lean b/Std/Data/List/Basic.lean index 2e91678191..4124bf220b 100644 --- a/Std/Data/List/Basic.lean +++ b/Std/Data/List/Basic.lean @@ -1703,7 +1703,7 @@ def dedupSorted [BEq α] (xs : List α) : List α := namespace MergeSort -/-- `O(|l|)`. Split `l` into two lists of approximately equal length. +/-- `O(|l|)`. Split alternating elements of list `l` into two lists. The two lists will have equal length if `|l|` is even, otherwise the first list will be one element longer than the second. ``` split [1, 2, 3, 4, 5] = ([1, 3, 5], [2, 4]) ```