Skip to content

Commit b63214d

Browse files
Move null checking to compiler-side using nullable annotations
- Remove redundant runtime null checks from non-nullable Parameter weight setters in Linear, Bilinear, Convolution, PReLU, GroupNorm, LayerNorm, and NormBase (already enforced by #nullable enable) - Add [DisallowNull] to Parameter? properties in RNNCell, LSTMCell, GRUCell, Embedding, and EmbeddingBag where getter can return null but setter must not accept null - Enable #nullable enable in Transforms.cs and remove runtime null check from ComposeTransform constructor - Remove runtime null check from Module.get_buffer() (string parameter is non-nullable under #nullable enable) - Add DisallowNullAttribute polyfill in netstandard.cs for netstandard2.0 compatibility
1 parent bd0de70 commit b63214d

File tree

15 files changed

+55
-67
lines changed

15 files changed

+55
-67
lines changed

src/TorchSharp/Distributions/Transforms.cs

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Collections.Generic;
44
using System.Linq;
55

6+
#nullable enable
67
namespace TorchSharp
78
{
89

@@ -26,11 +27,11 @@ public abstract class Transform
2627
{
2728
protected bool _bijective = false;
2829

29-
protected constraints.Constraint _domain;
30+
protected constraints.Constraint _domain = null!;
3031

31-
protected constraints.Constraint _codomain;
32+
protected constraints.Constraint _codomain = null!;
3233

33-
protected Transform _inv = null;
34+
protected Transform? _inv = null;
3435

3536
public virtual int event_dim {
3637
get {
@@ -42,7 +43,7 @@ public virtual int event_dim {
4243

4344
public virtual Transform inv {
4445
get {
45-
Transform result = null;
46+
Transform? result = null;
4647
if (this._inv != null)
4748
result = _inv;
4849
if (result == null) {
@@ -118,19 +119,19 @@ public _InverseTransform(torch.distributions.transforms.Transform transform)
118119

119120
public override constraints.Constraint domain {
120121
get {
121-
return _inv.domain;
122+
return _inv!.domain;
122123
}
123124
}
124125

125126
public override constraints.Constraint codomain {
126127
get {
127-
return _inv.codomain;
128+
return _inv!.codomain;
128129
}
129130
}
130131

131132
public override bool bijective {
132133
get {
133-
return _inv.bijective;
134+
return _inv!.bijective;
134135
}
135136
}
136137

@@ -140,41 +141,39 @@ public override bool bijective {
140141

141142
protected internal override Tensor log_abs_det_jacobian(Tensor x, Tensor y)
142143
{
143-
return -_inv.log_abs_det_jacobian(y, x);
144+
return -_inv!.log_abs_det_jacobian(y, x);
144145
}
145146

146147
protected internal override Tensor _call(Tensor x)
147148
{
148-
return _inv._inverse(x);
149+
return _inv!._inverse(x);
149150
}
150151

151152
protected internal override Tensor _inverse(Tensor y)
152153
{
153-
return _inv._call(y);
154+
return _inv!._call(y);
154155
}
155156

156157
protected internal override Tensor _sign()
157158
{
158-
return _inv.sign;
159+
return _inv!.sign;
159160
}
160161

161162
public override long[] forward_shape(long[] shape)
162163
{
163-
return _inv.forward_shape(shape);
164+
return _inv!.forward_shape(shape);
164165
}
165166

166167
public override long[] inverse_shape(long[] shape)
167168
{
168-
return _inv.inverse_shape(shape);
169+
return _inv!.inverse_shape(shape);
169170
}
170171
}
171172

172173
public class ComposeTransform : Transform
173174
{
174175
public ComposeTransform(IEnumerable<Transform> parts, int cache_size = 0)
175176
{
176-
if (parts == null) throw new ArgumentNullException("parts cannot be null");
177-
178177
_parts = parts.ToArray();
179178
_reverse_parts = parts.Reverse().ToArray();
180179
}
@@ -186,13 +185,13 @@ public ComposeTransform(IEnumerable<Transform> parts, int cache_size = 0)
186185

187186
public override Transform inv {
188187
get {
189-
Transform i = _inv;
188+
Transform? i = _inv;
190189

191190
if (i == null) {
192191
i = new ComposeTransform(_reverse_parts.Select(p => p.inv));
193192
_inv = i;
194193
}
195-
return _inv;
194+
return _inv!;
196195
}
197196
}
198197

src/TorchSharp/NN/Activation/PReLU.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ public override string GetName()
4040
public Parameter weight {
4141
get => _weight!;
4242
set {
43-
if (value is null) throw new ArgumentNullException(nameof(weight));
4443
if (value.Handle != _weight?.Handle) {
4544
_weight?.Dispose();
4645
_weight = (value.DetachFromDisposeScope() as Parameter)!;

src/TorchSharp/NN/Bilinear.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ public Parameter? bias {
6060
public Parameter weight {
6161
get => _weight!;
6262
set {
63-
if (value is null) throw new ArgumentNullException(nameof(weight));
6463
if (value.Handle != _weight?.Handle) {
6564
_weight?.Dispose();
6665
_weight = (value.DetachFromDisposeScope() as Parameter)!;

src/TorchSharp/NN/Convolution/Convolution.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ public Parameter? bias {
145145
public Parameter weight {
146146
get => _weight!;
147147
set {
148-
if (value is null) throw new ArgumentNullException(nameof(weight));
149148
if (value.Handle != _weight?.Handle) {
150149
_weight?.Dispose();
151150
_weight = (value.DetachFromDisposeScope() as Parameter)!;

src/TorchSharp/NN/Embedding.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
22
using System;
3+
using System.Diagnostics.CodeAnalysis;
34
using static TorchSharp.torch;
45
using static TorchSharp.PInvoke.NativeMethods;
56

@@ -21,16 +22,15 @@ public override Tensor forward(Tensor input)
2122
return new Tensor(res);
2223
}
2324

25+
[DisallowNull]
2426
public Parameter? weight {
2527
get {
2628
var res = THSNN_Embedding_weight(handle);
2729
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
2830
return (res == IntPtr.Zero) ? null : new Parameter(res);
2931
}
3032
set {
31-
// Please ignore, for now, that the litorch call thinks you *can* set it to null.
32-
if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'");
33-
THSNN_Embedding_set_weight(handle, value is null ? IntPtr.Zero : value.Handle);
33+
THSNN_Embedding_set_weight(handle, value.Handle);
3434
torch.CheckForErrors();
3535
ConditionallyRegisterParameter("weight", value);
3636
}

src/TorchSharp/NN/EmbeddingBag.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
22
using System;
3+
using System.Diagnostics.CodeAnalysis;
34
using static TorchSharp.torch;
45
using static TorchSharp.PInvoke.NativeMethods;
56

@@ -82,16 +83,15 @@ public Tensor call(Tensor input)
8283
return base.call(input, null, null);
8384
}
8485

86+
[DisallowNull]
8587
public Parameter? weight {
8688
get {
8789
var res = THSNN_EmbeddingBag_weight(handle);
8890
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
8991
return (res == IntPtr.Zero) ? null : new Parameter(res);
9092
}
9193
set {
92-
// Please ignore, for now, that the litorch call thinks you *can* set it to null.
93-
if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'");
94-
THSNN_EmbeddingBag_set_weight(handle, value is null ? IntPtr.Zero : value.Handle);
94+
THSNN_EmbeddingBag_set_weight(handle, value.Handle);
9595
torch.CheckForErrors();
9696
ConditionallyRegisterParameter("weight", value);
9797
}

src/TorchSharp/NN/Linear.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ public Parameter? bias {
7070
public Parameter weight {
7171
get => _weight!;
7272
set {
73-
if (value is null) throw new ArgumentNullException(nameof(weight));
7473
if (value.Handle != _weight?.Handle) {
7574
_weight?.Dispose();
7675
_weight = (value.DetachFromDisposeScope() as Parameter)!;

src/TorchSharp/NN/Module.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,6 @@ public virtual bool has_parameter(string target)
697697
/// <returns>The tensor referenced by target</returns>
698698
public virtual Tensor? get_buffer(string target)
699699
{
700-
if (target is null) throw new ArgumentNullException("target");
701700
if (_internal_buffers.TryGetValue(target, out var buffer)) {
702701
return buffer.Item1;
703702
}

src/TorchSharp/NN/Normalization/GroupNorm.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ public Parameter? bias {
5757
public Parameter weight {
5858
get => _weight!;
5959
set {
60-
if (value is null) throw new ArgumentNullException(nameof(weight));
6160
if (value.Handle != _weight?.Handle) {
6261
_weight?.Dispose();
6362
_weight = (value.DetachFromDisposeScope() as Parameter)!;

src/TorchSharp/NN/Normalization/LayerNorm.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ public Parameter? bias {
7575
public Parameter weight {
7676
get => _weight!;
7777
set {
78-
if (value is null) throw new ArgumentNullException(nameof(weight));
7978
if (value.Handle != _weight?.Handle) {
8079
_weight?.Dispose();
8180
_weight = (value.DetachFromDisposeScope() as Parameter)!;

0 commit comments

Comments
 (0)