# Source code for nnabla.numpy_compat_functions

# Copyright 2021 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


def dot(a, b, out=None):
    '''
    A compatible operation with ``numpy.dot``.

    Note:
        Any operation between nnabla's Variable/NdArray and numpy array is not supported.

        If both arguments are 1-D, it is inner product of vectors.
        If both arguments are 2-D, it is matrix multiplication.
        If either a or b is 0-D(scalar), it is equivalent to multiply.
        If b is a 1-D array, it is a sum product over the last axis of a and b.
        If b is an M-D array (M>=2), it is a sum product over the last axis of a and
        the second-to-last axis of b.

    Args:
        a (Variable, NdArray or scalar): Left input array.
        b (Variable, NdArray or scalar): Right input array.
        out: Output argument. This must have the same shape, dtype, and type as the
            result that would be returned for F.dot(a,b).

    Returns:
        ~nnabla.Variable or ~nnabla.NdArray (or the ``numpy.dot`` result when neither
        input is an nnabla object). When ``out`` is given, it is filled and returned,
        following the ``numpy.dot`` convention.

    Raises:
        ValueError: If ``out`` does not match the result's type, shape and dtype.

    Examples:

    .. code-block:: python

        import numpy as np
        import nnabla as nn
        import nnabla.functions as F

        # 2-D matrix * 2-D matrix
        arr1 = np.arange(5*6).reshape(5, 6)
        arr2 = np.arange(6*8).reshape(6, 8)
        nd1 = nn.NdArray.from_numpy_array(arr1)
        nd2 = nn.NdArray.from_numpy_array(arr2)
        ans1 = F.dot(nd1, nd2)
        print(ans1.shape)
        #(5, 8)

        var1 = nn.Variable.from_numpy_array(arr1)
        var2 = nn.Variable.from_numpy_array(arr2)
        ans2 = F.dot(var1, var2)
        ans2.forward()
        print(ans2.shape)
        #(5, 8)

        out1 = nn.NdArray((5, 8))
        out1.cast(np.float32)
        F.dot(nd1, nd2, out1)
        print(out1.shape)
        #(5, 8)

        out2 = nn.Variable((5, 8))
        out2.data.cast(np.float32)
        F.dot(var1, var2, out2)
        out2.forward()
        print(out2.shape)
        #(5, 8)

        # N-D matrix * M-D matrix (M>=2)
        arr1 = np.arange(5*6*7*8).reshape(5, 6, 7, 8)
        arr2 = np.arange(2*3*8*6).reshape(2, 3, 8, 6)
        nd1 = nn.NdArray.from_numpy_array(arr1)
        nd2 = nn.NdArray.from_numpy_array(arr2)
        ans1 = F.dot(nd1, nd2)
        print(ans1.shape)
        #(5, 6, 7, 2, 3, 6)

        var1 = nn.Variable.from_numpy_array(arr1)
        var2 = nn.Variable.from_numpy_array(arr2)
        ans2 = F.dot(var1, var2)
        ans2.forward()
        print(ans2.shape)
        #(5, 6, 7, 2, 3, 6)

        out1 = nn.NdArray((5, 6, 7, 2, 3, 6))
        out1.cast(np.float32)
        F.dot(nd1, nd2, out1)
        print(out1.shape)
        #(5, 6, 7, 2, 3, 6)

        out2 = nn.Variable((5, 6, 7, 2, 3, 6))
        out2.data.cast(np.float32)
        F.dot(var1, var2, out2)
        out2.forward()
        print(out2.shape)
        #(5, 6, 7, 2, 3, 6)
    '''
    # Imported inside the function, presumably to avoid a circular import when
    # nnabla itself loads this module -- NOTE(review): confirm.
    import nnabla as nn
    import nnabla.functions as F

    def _chk(x):
        # Return (numpy view of x's data, True) for nnabla objects, else (x, False).
        if isinstance(x, nn.NdArray):
            return x.data, True
        if isinstance(x, nn.Variable):
            return x.d, True
        return x, False

    a_val, a_is_nn = _chk(a)
    b_val, b_is_nn = _chk(b)

    if a_is_nn and b_is_nn:
        if a.ndim == 1 and b.ndim == 1:
            # Inner product of two vectors.
            result = F.sum(a * b)
        elif a.ndim == 2 and b.ndim == 2:
            # Plain matrix multiplication.
            result = F.affine(a, b)
        elif a.ndim == 0 or b.ndim == 0:
            # A 0-D operand degenerates to a scalar multiply.
            if a.ndim == 0:
                result = F.mul_scalar(b, a_val)
                if isinstance(a, nn.NdArray) and isinstance(b, nn.Variable):
                    # Scalar was an NdArray: evaluate eagerly and hand back an NdArray.
                    result.forward()
                    result = result.data
            else:
                result = F.mul_scalar(a, b_val)
                if isinstance(a, nn.Variable) and isinstance(b, nn.NdArray):
                    result.forward()
                    result = result.data
        elif b.ndim == 1:
            # Contract a's last axis with the vector b, then drop the unit axis.
            h = F.affine(a, F.reshape(b, (-1, 1)), base_axis=a.ndim - 1)
            result = F.reshape(h, h.shape[:-1])
        else:
            # b.ndim >= 2: move b's second-to-last axis to the front so affine
            # contracts a's last axis with it, matching numpy.dot's output layout.
            axes = list(range(b.ndim))
            axes.insert(0, axes.pop(b.ndim - 2))
            result = F.affine(a, F.transpose(b, axes), base_axis=a.ndim - 1)
    else:
        result = np.dot(a, b)

    if out is None:
        return result

    out_data, _ = _chk(out)
    result_data, _ = _chk(result)
    if not (type(out) is type(result)
            and out_data.shape == result_data.shape
            and out_data.dtype == result_data.dtype):
        raise ValueError(
            "Output argument must have the same shape, type and dtype as the result that would be "
            "returned for F.dot(a,b).")

    if isinstance(out, nn.NdArray):
        out.cast(result.data.dtype)[...] = result.data
    elif isinstance(out, nn.Variable):
        out.rewire_on(result)
    elif isinstance(out, np.ndarray):
        # Write into the caller's buffer; merely rebinding `out` would leave it untouched.
        out[...] = result
    else:
        # Immutable scalars (e.g. numpy scalar results) cannot be filled in place.
        out = result
    # numpy.dot returns `out` when it is supplied.
    return out