@@ -313,7 +313,10 @@ def var(self) -> ArrayType:
313313 """
314314 if self .__var is None :
315315 try :
316- var = np .diag (self .cov ).reshape (self .__shape ).copy ()
316+ var = backend .reshape (
317+ backend .diag (self .cov ),
318+ self .__shape ,
319+ ).copy ()
317320 except NotImplementedError as exc :
318321 raise NotImplementedError from exc
319322 else :
@@ -514,8 +517,12 @@ def quantile(self, p: ArrayType) -> ArrayType:
514517 return quantile
515518
516519 def __getitem__ (self , key : ArrayLikeGetitemArgType ) -> "RandomVariable" :
520+ # Shape inference
521+ # For simplicity, this should not be computed using backend, but rather in numpy
522+ shape = np .broadcast_to (np .empty (()), self .shape )[key ].shape
523+
517524 return RandomVariable (
518- shape = np . empty ( shape = self . shape )[ key ]. shape ,
525+ shape = shape ,
519526 dtype = self .dtype ,
520527 sample = lambda rng , size : self .sample (rng , size )[key ],
521528 mode = lambda : self .mode [key ],
@@ -557,8 +564,13 @@ def transpose(self, *axes: int) -> "RandomVariable":
557564 axes :
558565 See documentation of :meth:`numpy.ndarray.transpose`.
559566 """
567+
568+ # Shape inference
569+ # For simplicity, this should not be computed using backend, but rather in numpy
570+ shape = np .broadcast_to (np .empty (()), self .shape ).transpose (* axes ).shape
571+
560572 return RandomVariable (
561- shape = np . empty ( shape = self . shape ). transpose ( * axes ). shape ,
573+ shape = shape ,
562574 dtype = self .dtype ,
563575 sample = lambda rng , size : self .sample (rng , size ).transpose (* axes ),
564576 mode = lambda : self .mode .transpose (* axes ),
0 commit comments