Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Batteries/Data/Array.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ public import Batteries.Data.Array.Match
public import Batteries.Data.Array.Merge
public import Batteries.Data.Array.Monadic
public import Batteries.Data.Array.Pairwise
public import Batteries.Data.Array.Scan
250 changes: 250 additions & 0 deletions Batteries/Data/Array/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,261 @@ This will perform the update destructively provided that `a` has a reference cou
abbrev setN (a : Array α) (i : Nat) (x : α) (h : i < a.size := by get_elem_tactic) : Array α :=
a.set i x



/--
This is guaranteed by the Array docs but it is unprovable.
May be asserted to be true in an unsafe context via `Array.unsafe_size_fits_usize
-/
private abbrev size_fits_usize {a : Array α}: Prop := a.size < USize.size

@[grind .]
private theorem nat_index_eq_usize_index {n : Nat} {a : Array α}
{h : a.size_fits_usize} {hn : n ≤ a.size}
: (USize.ofNat n).toNat = n
:= USize.toNat_ofNat_of_lt' (Nat.lt_of_le_of_lt ‹_› ‹_›)


/--
This is guaranteed by the Array docs but it is unprovable.
Can be used in unsafe functions to write more efficient implementations
that avoid boxed integer arithmetic.
-/
private unsafe def unsafe_size_fits_usize {a: Array α} : Array.size_fits_usize (a := a) :=
lcProof


@[inline]
private def scanlMFast [Monad m] (f : β → α → m β) (init : β) (as : Array α)
(start := 0) (stop := as.size) : m (Array β) :=
let stop := min stop as.size
let start := min start as.size

loop f init as
(start := USize.ofNat start) (stop := USize.ofNat stop)
(h_stop := by grind only [USize.size_eq, USize.ofNat_eq_iff_mod_eq_toNat, = Nat.min_def])
(acc := Array.mkEmpty <| stop - start + 1)
where
@[specialize]
loop (f : β → α → m β) (init: β) (as: Array α) (start stop : USize)
(h_stop : stop.toNat ≤ as.size) (acc : Array β) : m (Array β) := do
if h_lt: start < stop then
let next ← f init (as.uget start <| Nat.lt_of_lt_of_le h_lt h_stop)
loop f next as (start + 1) stop h_stop (acc.push init)
else
pure <| acc.push init
termination_by stop.toNat - min start.toNat stop.toNat
decreasing_by
have : start < (start + 1) := by grind only [USize.size_eq]
grind only [Nat.min_def, USize.lt_iff_toNat_lt]


/--
Fold an effectful function `f` over the array from the left, returning the list of partial results.
-/
@[implemented_by scanlMFast]
def scanlM [Monad m] (f : β → α → m β) (init : β) (as : Array α) (start := 0)
(stop := as.size) : m (Array β) :=
loop f init as (min start as.size) (min stop as.size) (Nat.min_le_right _ _) #[]
where
/-- auxiliary tail-recursive function for scanlM -/
loop (f : β → α → m β) (init : β ) (as : Array α) (start stop : Nat)
(h_stop : stop ≤ as.size) (acc : Array β) : m (Array β) := do
if h_lt : start < stop then
loop f (← f init as[start]) as (start + 1) stop h_stop (acc.push init)
else
pure <| acc.push init

private theorem scanlM_loop_eq_scanlMFast_loop [Monad m]
{f : β → α → m β} {init : β} {as : Array α} {h_size : as.size_fits_usize}
{start stop : Nat} {h_start : start ≤ as.size}
{h_stop : stop ≤ as.size} {acc : Array β} :
scanlM.loop f init as start stop h_stop acc
= scanlMFast.loop f init as (USize.ofNat start) (USize.ofNat stop)
(by rw [USize.toNat_ofNat_of_lt' (Nat.lt_of_le_of_lt h_stop h_size)]; exact h_stop) acc := by

generalize h_n : stop - start = n
induction n using Nat.strongRecOn generalizing start acc init
rename_i n ih
rw [scanlM.loop, scanlMFast.loop]
have h_stop_usize := nat_index_eq_usize_index (h := h_size) (hn := h_stop)
have h_start_usize := nat_index_eq_usize_index (h := h_size) (hn := h_start)
split
case isTrue h_lt =>
simp_all only [USize.toNat_ofNat', ↓reduceDIte, uget,
show USize.ofNat start < USize.ofNat stop by simp_all [USize.lt_iff_toNat_lt]]
apply bind_congr
intro next
have h_start_succ : USize.ofNat start + 1 = USize.ofNat (start + 1) := by
simp_all only [← USize.toNat_inj, USize.toNat_add]
grind only [USize.size_eq, nat_index_eq_usize_index]
rw [h_start_succ]
apply ih (stop - (start + 1)) <;> omega
case isFalse h_nlt => grind [USize.lt_iff_toNat_lt]

-- this theorem establishes that given the (unprovable) assumption that as.size < USize.size,
-- the scanlMFast and scanlM are equivalent
private theorem scanlM_eq_scanlMFast [Monad m]
{f : β → α → m β} {init : β} {as : Array α}
{h_size : as.size_fits_usize}
{start stop : Nat}
: scanlM f init as start stop = scanlMFast f init as start stop
:= by
unfold scanlM scanlMFast
apply scanlM_loop_eq_scanlMFast_loop
simp_all only [gt_iff_lt]
apply Nat.min_le_right


@[inline]
private def scanrMFast [Monad m] (f : α → β → m β) (init : β) (as : Array α)
(h_size : as.size_fits_usize) (start := as.size) (stop := 0) : m (Array β) :=
let start := min start as.size
let stop := min stop start

loop f init as
(start := USize.ofNat start) (stop := USize.ofNat stop)
(h_start := by grind only [USize.size_eq, USize.ofNat_eq_iff_mod_eq_toNat, = Nat.min_def])
(acc := Array.replicate (start - stop + 1) init)
(by grind only [!Array.size_replicate, = Nat.min_def, Array.nat_index_eq_usize_index])
where
@[specialize]
loop (f : α → β → m β) (init : β) (as : Array α)
(start stop : USize)
(h_start : start.toNat ≤ as.size)
(acc : Array β)
(h_bound : start.toNat - stop.toNat < acc.size)
: m (Array β) := do
if h_gt : stop < start then
let startM1 := start - 1
have : startM1 < start := by grind only [!USize.sub_add_cancel, USize.lt_iff_le_and_ne,
USize.lt_add_one, USize.le_zero_iff]
have : startM1.toNat < as.size := Nat.lt_of_lt_of_le ‹_› ‹_›
have : (startM1 - stop) < (start - stop) := by grind only
[!USize.sub_add_cancel, USize.sub_right_inj, USize.add_comm, USize.lt_add_one,
USize.add_assoc, USize.add_right_inj]

let next ← f (as.uget startM1 ‹_›) init
loop f next as
(start := startM1)
(stop := stop)
(h_start := Nat.le_of_succ_le_succ (Nat.le_succ_of_le ‹_›))
(acc := acc.uset (startM1 - stop) next
(by grind only [USize.toNat_sub_of_le, USize.le_of_lt, USize.lt_iff_toNat_lt]))
(h_bound :=
(by grind only [USize.toNat_sub_of_le, = uset_eq_set, = size_set, USize.size_eq]))
else
pure acc
termination_by start.toNat - stop.toNat
decreasing_by
grind only [USize.lt_iff_toNat_lt, USize.toNat_sub,
USize.toNat_sub_of_le, USize.le_iff_toNat_le]


@[inline]
private unsafe def scanrMUnsafe [Monad m] (f : α → β → m β) (init : β) (as : Array α)
(start := as.size) (stop := 0) : m (Array β) :=
scanrMFast (h_size := Array.unsafe_size_fits_usize) f init as (start := start) (stop := stop)

/--
Folds a monadic function over a list from the left, accumulating the partial results starting with
`init`. The accumulated value is combined with the each element of the list in order, using `f`.

The optional parameters `start` and `stop` control the region of the array to be folded. Folding
proceeds from `start` (inclusive) to `stop` (exclusive), so no folding occurs unless `start < stop`.
By default, the entire array is folded.

Examples:
```lean example
example [Monad m] (f : α → β → m α) :
Array.scanlM (m := m) f x₀ #[a, b, c] = (do
let x₁ ← f x₀ a
let x₂ ← f x₁ b
let x₃ ← f x₂ c
pure #[x₀, x₁, x₂, x₃])
:= by rfl
```

```lean example
example [Monad m] (f : α → β → m α) :
Array.scanlM (m := m) f x₀ #[a, b, c] (start := 1) = (do
let x₁ ← f x₀ b
let x₂ ← f x₁ c
pure #[x₀, x₁, x₂])
:= by rfl
-/
@[implemented_by scanrMUnsafe]
def scanrM [Monad m]
(f : α → β → m β) (init : β) (as : Array α) (start := as.size) (stop := 0) : m (Array β) :=
let start := min start as.size
loop f init as start stop (Nat.min_le_right _ _) #[]
where
/-- auxiliary tail-recursive function for scanrM -/
loop (f : α → β → m β) (init : β) (as : Array α)
(start stop : Nat)
(h_start : start ≤ as.size)
(acc : Array β)
: m (Array β) := do
if h_gt : stop < start then
let i := start - 1
let next ← f as[i] init
loop f next as i stop (by omega) (acc.push init)
else
pure <| acc.push init |>.reverse
/--

Fold a function `f` over the list from the left, returning the list of partial results.
```
scanl (· + ·) 0 #[1, 2, 3] = #[0, 1, 3, 6]
```
-/
@[inline]
def scanl (f : β → α → β) (init : β) (as : Array α) (start := 0) (stop := as.size) : Array β :=
Id.run <| as.scanlM (pure <| f · ·) init start stop

/--

Fold a function `f` over the list from the right, returning the list of partial results.
```
scanl (+) 0 #[1, 2, 3] = #[0, 1, 3, 6]
```
-/
@[inline]
def scanr (f : α → β → β) (init : β) (as : Array α) (start := as.size) (stop := 0) : Array β :=
Id.run <| as.scanrM (pure <| f · ·) init start stop

end Array


namespace Subarray

/--
Fold an effectful function `f` over the array from the left, returning the list of partial results.
-/
@[inline]
def scanlM [Monad m] (f : β → α → m β) (init : β) (as : Subarray α) : m (Array β) :=
as.array.scanlM f init (start := as.start) (stop := as.stop)

/--
Fold an effectful function `f` over the array from the right, returning the list of partial results.
-/
@[inline]
def scanrM [Monad m] (f : α → β → m β) (init : β) (as : Subarray α) : m (Array β) :=
as.array.scanrM f init (start := as.start) (stop := as.stop)

/--
Fold a pure function `f` over the array from the left, returning the list of partial results.
-/
@[inline]
def scanl (f : β → α → β) (init : β) (as : Subarray α): Array β :=
as.array.scanl f init (start := as.start) (stop := as.stop)

/--
Fold a pure function `f` over the array from the left, returning the list of partial results.
-/
def scanr (f : α → β → β) (init : β) (as : Subarray α): Array β :=
as.array.scanr f init (start := as.start) (stop := as.stop)

/--
Check whether a subarray is empty.
-/
Expand Down
4 changes: 4 additions & 0 deletions Batteries/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ theorem idxOf?_toList [BEq α] {a : α} {l : Array α} :
(a.eraseIdxIfInBounds i).size = if i < a.size then a.size-1 else a.size := by
grind

theorem toList_drop {as: Array α} {n : Nat}
: (as.drop n).toList = as.toList.drop n
:= by simp only [drop, toList_extract, size_eq_length_toList, List.drop_eq_extract]

/-! ### set -/

theorem size_set! (a : Array α) (i v) : (a.set! i v).size = a.size := by simp
Expand Down
Loading