Source code for nnabla.experimental.tb_graph_writer

# Copyright 2020,2021 Sony Corporation.
# Copyright 2021 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorboardX.proto.event_pb2 as event_pb2
from nnabla.parameter import get_parameters
from tensorboardX import SummaryWriter
from tensorboardX.proto.attr_value_pb2 import AttrValue
from tensorboardX.proto.graph_pb2 import GraphDef
from tensorboardX.proto.node_def_pb2 import NodeDef
from tensorboardX.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboardX.proto.types_pb2 import DT_FLOAT
from tensorboardX.proto.versions_pb2 import VersionDef


[docs]class TBGraphWriter(SummaryWriter): """This class is a wrapper of tensorboardX summary writer, which enable nn.Variable can be visualized as a graph by tensorboard. Prerequisite: Install tensorflow and tensorboardX, simply by the following commands: .. code-block:: plaintext pip install tensorflow pip install tensorboardX Please refer to the following example to use this class: Example: .. code-block:: python import numpy as np import nnabla as nn import nnabla.functions as F import nnabla.parametric_functions as PF def show_a_graph(): try: from nnabla.experimental.tb_graph_writer import TBGraphWriter except Exception as e: print("please install tensorflow and tensorboardX at first.") raise e nn.clear_parameters() x = nn.Variable((2, 3, 4, 4)) with nn.parameter_scope('c1'): h = PF.convolution(x, 8, (3, 3), pad=(1, 1)) h = F.relu(PF.batch_normalization(h)) with nn.parameter_scope('f1'): y = PF.affine(h, 10) with TBGraphWriter(log_dir='log_out') as tb: tb.from_variable(y, output_name="y") The corresponding graph can be visualized as the following: .. image:: ./tb_graph_writer_files/graph.PNG :height: 561px :width: 233px :scale: 50% Or, you may show scalar value or histogram of values along the increasing of iteration number as the following: .. code-block:: python with TBGraphWriter(log_dir='log_out') as tb: values = [] for i in range(360): s = np.sin(i / 180.0 * np.pi) tb.add_scalar("show_curve/sin", s, i) values.append(s) nd_values = np.array(values) for i in range(10): tb.add_histogram("histogram", nd_values, i) nd_values += 0.05 It looks like: .. image:: ./tb_graph_writer_files/scalar.PNG :height: 418px :width: 522px :scale: 50% d .. image:: ./tb_graph_writer_files/histogram.PNG :height: 328px :width: 387px :scale: 50% This class writes a protobuf file to `log_dir` specified folder, thus, user should launch tensorboard to specify this folder: .. code-block:: plaintext tensorboard --logdir=log_out Then, user may check graph in a web browser, by typing the address: http://localhost:6006 See Also: https://tensorboardx.readthedocs.io/en/latest/tensorboard.html """ def __init__(self, log_dir="log_out", comment='', **kwargs): super(TBGraphWriter, self).__init__(log_dir, comment, **kwargs)
[docs] def from_variable(self, leaf, output_name="output"): """ Args: leaf (`nn.Variable`): Leaf node of graph, normally, the output variable of a network output_name ('str'): A given name of output variable of graph """ def parse_variable(v, var_num): def add_variable(v, v_idx): v_name = parameters.get(v.data, None) exist = False if not v_name: v_name, exist = get_variable_name(v, v_idx) if not exist: shape_proto = TensorShapeProto( dim=[TensorShapeProto.Dim(size=d) for d in v.shape]) if v.parent is None: inputs = [] else: inputs = [get_func_name(v.parent)] # print("Variable: {}:{}".format(v_name, inputs)) nodes.append(NodeDef( name=v_name.encode(encoding='utf-8'), op='Variable', input=inputs, attr={ 'shape': AttrValue(shape=shape_proto), 'dtype': AttrValue(type=DT_FLOAT) } )) return v_name def get_unique_variable_name(v_name_base): v_num = 0 v_name = v_name_base + '_' + str(v_num) while v_name in unique_var_names: v_num += 1 v_name = v_name_base + '_' + str(v_num) unique_var_names.add(v_name) return v_name def get_variable_name(v, v_idx): v_name = variables.get(v, None) if v_name: return v_name, True else: if v.parent is None: v_name_base = "Input" v_name = get_unique_variable_name(v_name_base) elif not nodes: v_name = output_name else: f_name_sections = get_func_name(v.parent).split("/") f_name = f_name_sections[-1] f_scope = f_name_sections[:-1] base_name = "v_{}".format(f_name) v_name_base = "/".join(f_scope + [base_name]) v_name = get_unique_variable_name(v_name_base) variables[v] = v_name return v_name, False def get_func_name(func): func_name = func_names.get(func, None) if func_name: return func_name name_scope = loc_var['name_scope'] for v in func.inputs: v_name = parameters.get(v.data, None) if v_name: name_scope = '/'.join(v_name.split('/')[:-1]) break if name_scope: func_name_base = '/'.join([name_scope, func.name]) else: func_name_base = func.name func_num = 0 func_name = func_name_base + str(func_num) while func_name in unique_func_names: func_num += 1 func_name = func_name_base + '_' + str(func_num) unique_func_names.add(func_name) func_names[func] = func_name return func_name def add_func(v): input_names = [] for index, v_input in enumerate(v.parent.inputs): v_name = add_variable(v_input, index) input_names.append(v_name) # print("Function: {}:{}".format(get_func_name(v.parent), input_names)) f_name = get_func_name(v.parent) if f_name in func_set: return False attrs = [] for k, a in v.parent.info.args.items(): attr = "{}={}".format(k, a) attrs.append(attr) attr_str = ','.join(attrs).encode(encoding='utf-8') nodes.append(NodeDef( name=f_name, op=v.parent.info.type_name, input=input_names, attr={"parameters": AttrValue(s=attr_str)} )) func_set.add(f_name) return True name_scope = loc_var['name_scope'] if not nodes: add_variable(v, var_num) if v.parent is None: add_variable(v, var_num) else: if not add_func(v): return for idx, in_var in enumerate(v.parent.inputs): name_scope_stack.append(name_scope) parse_variable(in_var, idx) name_scope = name_scope_stack.pop() nodes = [] variables = {} loc_var = {} loc_var['name_scope'] = '' name_scope_stack = [] func_names = {} func_set = set() unique_func_names = set() unique_var_names = set() parameters = {v.data: k for k, v in get_parameters(grad_only=False).items()} parse_variable(leaf, 0) nodes = nodes[::-1] current_graph = GraphDef(node=nodes, versions=VersionDef(producer=22)) event = event_pb2.Event( graph_def=current_graph.SerializeToString()) self.file_writer.add_event(event)
def from_graph_def(self, graph_def): variables = graph_def.variables parameters = graph_def.parameters functions = graph_def.functions inputs = graph_def.inputs nodes = [] scope = {} for n, v in parameters.items(): shape_proto = TensorShapeProto( dim=[TensorShapeProto.Dim(size=d) for d in v.shape]) node = NodeDef( name=n.encode(encoding='utf-8'), op='Parameter', input=[], attr={ 'shape': AttrValue(shape=shape_proto), 'dtype': AttrValue(type=DT_FLOAT) } ) nodes.append(node) scope[n] = node for n, v in inputs.items(): shape_proto = TensorShapeProto( dim=[TensorShapeProto.Dim(size=d) for d in v.shape]) nodes.append(NodeDef( name=n.encode(encoding='utf-8'), op='Variable', input=[], attr={ 'shape': AttrValue(shape=shape_proto), 'dtype': AttrValue(type=DT_FLOAT) } )) for func_name, func in functions.items(): for o in func['outputs']: if o in scope: node = scope[o] node.input.extend([func_name]) else: if o in variables: v = variables[o] shape_proto = TensorShapeProto( dim=[TensorShapeProto.Dim(size=d) for d in v.shape]) node = NodeDef( name=o.encode(encoding='utf-8'), op='Variable', input=[func_name], attr={ 'shape': AttrValue(shape=shape_proto), 'dtype': AttrValue(type=DT_FLOAT) } ) nodes.append(node) for i in func['inputs']: if i in variables: v = variables[i] shape_proto = TensorShapeProto( dim=[TensorShapeProto.Dim(size=d) for d in v.shape]) node = NodeDef( name=o.encode(encoding='utf-8'), op='Variable', input=[], attr={ 'shape': AttrValue(shape=shape_proto), 'dtype': AttrValue(type=DT_FLOAT) } ) nodes.append(node) scope[o] = node nodes.append(NodeDef( name=func_name, op=func['type'], input=func['inputs'], attr={"arguments": AttrValue(s='a=1'.encode(encoding='utf-8'))} )) current_graph = GraphDef(node=nodes, versions=VersionDef(producer=22)) event = event_pb2.Event( graph_def=current_graph.SerializeToString()) self.file_writer.add_event(event)