Skip to content
This repository was archived by the owner on Feb 11, 2023. It is now read-only.

Commit 6a74859

Browse files
committed
fix mul_dim
1 parent 4efbf9f commit 6a74859

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleTensorNetworks"
22
uuid = "4456351a-5be3-4067-ade9-541926a41e04"
33
authors = ["GiggleLiu <[email protected]> and contributors"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"

src/tensors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function mul_dim(t::LabeledTensor, m::AbstractMatrix; dim::Int)
4343
iA = ntuple(i->i, ndims(data))
4444
iB = (dim, -dim)
4545
iC = ntuple(i->i==dim ? -dim : i, ndims(data))
46-
LabeledTensor(tensorcontract(iA, data, iB, m, iC; compress=false), t.labels)
46+
LabeledTensor(tensorcontract(iA, data, iB, m, iC), t.labels)
4747
end
4848

4949
struct PlotMeta

test/tensorcontract.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,15 @@ end
4040
@test tnet.tensors[] tOut2
4141
@test contracted_labels == [1, 2]
4242
end
43+
44+
@testset "mul_dim" begin
45+
A = zeros(Float64, 10, 32, 21);
46+
B = zeros(Float64, 32, 11);
47+
tA = LabeledTensor(A, [1,2,3])
48+
tB = LabeledTensor(B, [2,4])
49+
tA1 = SimpleTensorNetworks.mul_dim(tA, B; dim=2)
50+
tA2 = tA * tB
51+
@test tA1.array permutedims(tA2.array, (1,3,2))
52+
@test tA1.labels == [1, 2, 3]
53+
@test tA2.labels == [1, 3, 4]
54+
end

0 commit comments

Comments
 (0)