Skip to content

Commit fe1e45a

Browse files
more general indexing
1 parent 7d14a44 commit fe1e45a

2 files changed

Lines changed: 36 additions & 3 deletions

File tree

src/karray.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -494,9 +494,6 @@ end
494494
# function _getindex(l::LinearIndexing, A::AbstractArray, I::Union{Real, AbstractArray, Colon}...)
495495
# in abstractarray.jl:487,multidimensional.jl:184.
496496

497-
if VERSION < v"0.5.0"
498-
@typealias6 AbstractUnitRange UnitRange
499-
end
500497

501498
function getindex{T}(A::KnetArray{T}, I::AbstractUnitRange)
502499
if !(1 <= first(I) <= last(I) <= length(A)); throw(BoundsError(A,I)); end
@@ -552,6 +549,22 @@ function setindex!{T}(A::KnetArray{T}, v, I::Colon)
552549
unsafe_copy!(A,1,v,1,length(A))
553550
end
554551

552+
## General Indexing Fallback to linear indexing
553+
554+
function getindex(a::KnetArray, I...)
555+
crange = CartesianRange(to_indices(a, I))
556+
linind = [sub2ind(size(a), t.I...) for t in crange]
557+
b = getindex(a, vec(linind))
558+
shape = size(crange) # TODO drop scalar dimension
559+
reshape(b, shape)
560+
end
561+
562+
function setindex!(a::KnetArray, v, I...)
563+
crange = CartesianRange(to_indices(a, I))
564+
linind = [sub2ind(size(a), t.I...) for t in crange]
565+
setindex!(a, v, vec(linind))
566+
end
567+
555568
for F in (32,64); T=Symbol("Float$F"); @eval begin
556569

557570
## Indexing with KnetArray{Int32}: low level, only Int32 supported, no bounds checking

test/karray.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,26 @@ if gpu() >= 0
123123
@test size(reshape(a, :, 4)) == size(reshape(a, (:, 4))) == (2, 4)
124124
@test size(reshape(a, :, 1, 4)) == (2, 1, 4)
125125
end
126+
127+
@testset "general indexing" begin
128+
# TODO
129+
#
130+
#julia> a=KnetArray(rand(3,3,2));
131+
#
132+
#julia> grad(a->sum(a[1:2,:,1]))(a)
133+
#3×3×2 Knet.KnetArray{Float64,3}:
134+
#[:, :, 1] =
135+
# 1.0 1.0 1.0
136+
# 1.0 1.0 1.0
137+
# 0.0 0.0 0.0
138+
#
139+
#[:, :, 2] =
140+
# 0.0 0.0 0.0
141+
# 0.0 0.0 0.0
142+
# 0.0 0.0 0.0
143+
end
144+
145+
126146
end
127147
end
128148

0 commit comments

Comments
 (0)