1 2
from __future__ import absolute_import, division, unicode_literals
2

3 2
import os
4 2
import sys
5 2
import traceback as tb
6

7 2
from collections import OrderedDict, defaultdict
8

9 2
import param
10

11 2
from .layout import Row, Column, HSpacer, VSpacer
12 2
from .pane import HoloViews, Pane, Markdown
13 2
from .widgets import Button, Select
14 2
from .param import Param
15 2
from .util import param_reprs
16

17

18 2
class PipelineError(RuntimeError):
    """
    Error a pipeline stage may raise to report a user-facing message.

    When the Pipeline fails to switch to a stage, the error popup shows
    the message of a PipelineError verbatim. For any other exception
    type the popup falls back to a generic placeholder, unless debug
    mode is enabled.
    """
23

24

25 2
def traverse(graph, v, visited):
    """
    Traverse the graph from a node and mark visited vertices.

    Arguments
    ---------
    graph: dict
      Adjacency mapping from a vertex to an iterable of its successors.
      Vertices without an entry are treated as having no successors.
    v: hashable
      The vertex to start from.
    visited: list or dict
      Mutable mapping from vertex to a boolean flag; every vertex
      reachable from ``v`` (including ``v`` itself) is set to True in
      place. Flags for unreachable vertices are left untouched.
    """
    # Iterative depth-first search: unlike recursion this cannot hit the
    # interpreter's recursion limit on long chains of stages.
    visited[v] = True
    stack = [v]
    while stack:
        node = stack.pop()
        for i in graph.get(node, []):
            if not visited[i]:
                visited[i] = True
                stack.append(i)
34

35

36 2
def find_route(graph, current, target):
    """
    Find a route to the target node from the current node.

    Performs a depth-first search following successors in the order
    they are declared in ``graph``.

    Arguments
    ---------
    graph: dict
      Adjacency mapping from a node to a sequence of successor nodes.
    current: hashable
      The node to start searching from (not included in the result).
    target: hashable
      The node to reach.

    Returns
    -------
    A list of nodes leading from (but excluding) ``current`` up to and
    including ``target``, or None if ``target`` cannot be reached.
    """
    # Nodes already expanded; prevents infinite recursion when the graph
    # contains a cycle (define_graph only rejects cycles through the root).
    seen = set()

    def _search(node):
        next_nodes = graph.get(node)
        if next_nodes is None:
            return None
        if target in next_nodes:
            return [target]
        seen.add(node)
        for n in next_nodes:
            if n in seen:
                continue
            route = _search(n)
            if route is not None:
                return [n] + route
        return None

    return _search(current)
52

53

54 2
def get_root(graph):
    """
    Search for the root node by finding nodes without inputs.

    Arguments
    ---------
    graph: dict
      Adjacency mapping from a node to a sequence of successor nodes.

    Returns
    -------
    The single node of ``graph`` that is not the target of any edge.

    Raises
    ------
    ValueError
      If there is more than one such node, or none at all (e.g. because
      the graph is cyclic or empty).
    """
    # A set makes each membership test O(1) instead of scanning all edges.
    targets = {t for ts in graph.values() for t in ts}
    roots = [src for src in graph if src not in targets]

    if len(roots) > 1:
        raise ValueError("Graph has more than one node with no "
                         "incoming edges. Ensure that the graph "
                         "only has a single source node.")
    elif not roots:
        raise ValueError("Graph has no source node. Ensure that the "
                         "graph is not cyclic and has a single "
                         "starting point.")
    return roots[0]
74

75

76 2
def is_traversable(root, graph, stages):
    """
    Check if the graph is fully traversable from the root node.

    Arguments
    ---------
    root: hashable
      The stage to start traversing from.
    graph: dict
      Adjacency mapping between stage names.
    stages: list
      All stage names; every one of them must be reachable from
      ``root`` for the graph to count as traversable.

    Returns
    -------
    True if every entry of ``stages`` is reachable from ``root``.
    """
    # Re-key the graph by list position so reachability can be tracked
    # in a flat list of flags.
    index_of = stages.index
    adjacency = {}
    for src, tgts in graph.items():
        adjacency[index_of(src)] = tuple(index_of(t) for t in tgts)
    seen = [False for _ in stages]
    traverse(adjacency, index_of(root), seen)
    return all(seen)
86

87

88 2
def get_depth(node, graph, depth=0):
    """
    Return the number of levels on the longest path starting at a node.

    Arguments
    ---------
    node: hashable
      The node to measure from.
    graph: dict
      Adjacency mapping from a node to its successor nodes.
    depth: int
      Level of ``node`` itself; used by the recursion.

    Returns
    -------
    ``depth + 1`` for a node without successors, otherwise the maximum
    over all successors.
    """
    child_depths = [get_depth(child, graph, depth + 1)
                    for child in graph.get(node, [])]
    if not child_depths:
        return depth + 1
    return max(child_depths)
93

94

95 2
def get_breadths(node, graph, depth=0, breadths=None):
    """
    Group the nodes reachable from ``node`` by their distance from it.

    Arguments
    ---------
    node: hashable
      The starting node; it is recorded at level ``depth`` on the
      initial (non-recursive) call.
    graph: dict
      Adjacency mapping from a node to its successor nodes.
    depth: int
      Level of ``node``.
    breadths: defaultdict(list) or None
      Accumulator shared across the recursion; created when None.

    Returns
    -------
    A defaultdict mapping each level to the list of distinct nodes
    found at that level, in discovery order. A node reachable by paths
    of different lengths appears at each of those levels.
    """
    if breadths is None:
        breadths = defaultdict(list)
        breadths[depth].append(node)
    level = depth + 1
    for child in graph.get(node, []):
        # Look the level up inside the loop so that a leaf node never
        # creates an empty entry for the next level.
        row = breadths[level]
        if child not in row:
            row.append(child)
        get_breadths(child, graph, level, breadths)
    return breadths
104

105

106

107 2
class Pipeline(param.Parameterized):
    """
    A Pipeline represents a directed graph of stages, which each
    returns a panel object to render. A pipeline therefore represents
    a UI workflow of multiple linear or branching stages.

    The Pipeline layout consists of a number of sub-components:

    * header:

      * title: The name of the current stage.
      * error: A field to display the error state.
      * network: A network diagram representing the pipeline.
      * buttons: All navigation buttons and selectors.
      * prev_button: The button to go to the previous stage.
      * prev_selector: The selector widget to select between
        previous branching stages.
      * next_button: The button to go to the next stage.
      * next_selector: The selector widget to select the next
        branching stages.

    * stage: The contents of the current pipeline stage.

    By default any outputs of one stage annotated with the
    param.output decorator are fed into the next stage. Additionally,
    if the inherit_params parameter is set any parameters which are
    declared on both the previous and next stage are also inherited.

    The stages are declared using the add_stage method and must each
    be given a unique name. By default any stages will simply be
    connected linearly, however an explicit graph can be declared using
    the define_graph method.
    """

    auto_advance = param.Boolean(default=False, doc="""
        Whether to automatically advance if the ready parameter is True.""")

    debug = param.Boolean(default=False, doc="""
        Whether to raise errors, useful for debugging while building
        an application.""")

    inherit_params = param.Boolean(default=True, doc="""
        Whether parameters should be inherited between pipeline
        stages.""")

    next_parameter = param.String(default=None, doc="""
        Parameter name to watch to switch between different branching
        stages""")

    ready_parameter = param.String(default=None, doc="""
        Parameter name to watch to check whether a stage is ready.""")

    show_header = param.Boolean(default=True, doc="""
        Whether to show the header with the title, network diagram,
        and buttons.""")

    next = param.Action(default=lambda x: x.param.trigger('next'))

    previous = param.Action(default=lambda x: x.param.trigger('previous'))

    def __init__(self, stages=None, graph=None, **params):
        """
        Arguments
        ---------
        stages: list or None
          Sequence of (name, stage) or (name, stage, kwargs) tuples,
          each forwarded to add_stage in order.
        graph: dict or None
          Adjacency mapping forwarded to define_graph; when empty the
          stages are connected linearly in the order they were added.
        **params: dict
          Values for the Pipeline's own parameters.

        Raises
        ------
        ImportError
          If holoviews cannot be imported.
        ValueError
          If an entry of ``stages`` is not a 2- or 3-tuple.
        """
        try:
            import holoviews as hv
        except Exception:
            raise ImportError('Pipeline requires holoviews to be installed')

        super(Pipeline, self).__init__(**params)

        # Fresh containers instead of shared mutable default arguments.
        if stages is None:
            stages = []
        if graph is None:
            graph = {}

        # Initialize internal state
        self._stage = None          # name of the current stage
        self._stages = OrderedDict()  # name -> (stage, stage kwargs)
        self._states = {}           # name -> instantiated stage
        self._state = None          # instance of the current stage
        self._linear = True         # False once a custom graph is defined
        self._block = False
        self._error = None          # name of the stage that last failed
        self._graph = {}
        self._route = []            # visited stage names, ends with current

        # Declare UI components
        self._progress_sel = hv.streams.Selection1D()
        self._progress_sel.add_subscriber(self._set_stage)
        self.prev_button = Param(self.param.previous).layout[0]
        self.prev_button.width = 125
        self.prev_selector = Select(width=125)
        self.next_button = Param(self.param.next).layout[0]
        self.next_button.width = 125
        self.next_selector = Select(width=125)
        self.prev_button.disabled = True
        self.next_selector.param.watch(self._update_progress, 'value')
        self.network = HoloViews(backend='bokeh')
        self.title = Markdown('# Header', margin=(0, 0, 0, 5))
        self.error = Row(width=100)
        self.buttons = Row(self.prev_button, self.next_button)
        self.header = Row(
            Column(self.title, self.error),
            self.network,
            self.buttons,
            sizing_mode='stretch_width'
        )
        self.network.object = self._make_progress()
        spinner = Pane(os.path.join(os.path.dirname(__file__), 'assets', 'spinner.gif'))
        self._spinner_layout = Row(
            HSpacer(),
            Column(VSpacer(), spinner, VSpacer()),
            HSpacer()
        )
        self.stage = Row()
        self.layout = Column(self.header, self.stage, sizing_mode='stretch_width')

        # Initialize stages and the graph
        for spec in stages:
            kwargs = {}
            if len(spec) == 2:
                name, stage = spec
            elif len(spec) == 3:
                name, stage, kwargs = spec
            else:
                raise ValueError('Pipeline stages must be declared as '
                                 '(name, stage) or (name, stage, kwargs) '
                                 'tuples, got %r.' % (spec,))
            self.add_stage(name, stage, **kwargs)
        self.define_graph(graph)

    def _validate(self, stage):
        """
        Ensure ``stage`` is a Parameterized class or instance that has
        not already been added; raise ValueError otherwise.
        """
        if any(stage is s for n, (s, kw) in self._stages.items()):
            raise ValueError('Stage %s is already in pipeline' % stage)
        elif not ((isinstance(stage, type) and issubclass(stage, param.Parameterized))
                  or isinstance(stage, param.Parameterized)):
            raise ValueError('Pipeline stages must be Parameterized classes or instances.')

    def __repr__(self):
        repr_str = 'Pipeline:'
        for i, (name, (stage, _)) in enumerate(self._stages.items()):
            if isinstance(stage, param.Parameterized):
                cls_name = type(stage).__name__
            else:
                cls_name = stage.__name__
            params = ', '.join(param_reprs(stage))
            repr_str += '\n    [%d] %s: %s(%s)' % (i, name, cls_name, params)
        return repr_str

    def __str__(self):
        return self.__repr__()

    def __getitem__(self, index):
        """Return the stage (class or instance) registered under ``index``."""
        return self._stages[index][0]

    def _unblock(self, event):
        """
        Watcher on the current stage's ready parameter: enable or
        disable the next button and auto-advance when configured.
        """
        # Ignore events from stale stage instances, and swallow the
        # single event following a step back (see _previous).
        if self._state is not event.obj or self._block:
            self._block = False
            return

        button = self.next_button
        if button.disabled and event.new:
            button.disabled = False
        elif not button.disabled and not event.new:
            button.disabled = True

        stage_kwargs = self._stages[self._stage][-1]
        if event.new and stage_kwargs.get('auto_advance', self.auto_advance):
            self._next()

    def _select_next(self, event):
        """Watcher on the next_parameter: mirror its value in next_selector."""
        if self._state is not event.obj:
            return
        self.next_selector.value = event.new
        self._update_progress()

    def _init_stage(self):
        """
        Instantiate (or update) the current stage and return its panel.

        Outputs declared with ``param.output`` on already visited
        predecessor stages are passed in as parameter values; if
        ``inherit_params`` is enabled, parameters shared with those
        predecessors are inherited as well. Widgets for parameters
        supplied by an output are hidden.
        """
        stage, stage_kwargs = self._stages[self._stage]

        previous = [src for src, tgts in self._graph.items() if self._stage in tgts]
        prev_states = [self._states[prev] for prev in previous if prev in self._states]

        outputs = []
        kwargs, results = {}, {}
        for state in prev_states:
            for name, (_, method, index) in state.param.outputs().items():
                if name not in stage.param:
                    continue
                # Call each output method only once even if it provides
                # several outputs.
                if method not in results:
                    results[method] = method()
                result = results[method]
                if index is not None:
                    result = result[index]
                kwargs[name] = result
                outputs.append(name)
            if stage_kwargs.get('inherit_params', self.inherit_params):
                # The control parameters are stage specific and must not
                # be carried over to the next stage.
                ignored = [stage_kwargs.get(p) or getattr(self, p, None)
                           for p in ('ready_parameter', 'next_parameter')]
                params = [k for k, v in state.param.objects('existing').items()
                          if k not in ignored]
                kwargs.update({k: v for k, v in state.param.get_param_values()
                               if k in stage.param and k != 'name' and k in params})

        if isinstance(stage, param.Parameterized):
            stage.param.set_param(**kwargs)
            self._state = stage
        else:
            self._state = stage(**kwargs)

        # Hide widgets for parameters that are supplied by the previous stage
        for output in outputs:
            self._state.param[output].precedence = -1

        ready_param = stage_kwargs.get('ready_parameter', self.ready_parameter)
        if ready_param and ready_param in stage.param:
            self._state.param.watch(self._unblock, ready_param, onlychanged=False)

        next_param = stage_kwargs.get('next_parameter', self.next_parameter)
        if next_param and next_param in stage.param:
            self._state.param.watch(self._select_next, next_param, onlychanged=False)

        self._states[self._stage] = self._state
        return self._state.panel()

    def _set_stage(self, index):
        """
        Navigate to the stage clicked in the progress diagram.

        ``index`` is the list of selected node indices published by the
        Selection1D stream; only the first one is used.
        """
        if not index:
            return
        stage = self._progress_sel.source.iloc[index[0], 2]
        if stage in self.next_selector.options:
            self.next_selector.value = stage
            self.param.trigger('next')
        elif stage in self.prev_selector.options:
            self.prev_selector.value = stage
            self.param.trigger('previous')
        elif stage in self._route:
            # Step back along the route until the clicked stage is the
            # current one. Stop if a step did not shorten the route (the
            # transition failed) so this can never loop forever.
            while self._stage != stage and len(self._route) > 1:
                route_length = len(self._route)
                self.param.trigger('previous')
                if len(self._route) >= route_length:
                    break
        else:
            # Try currently selected route
            route = find_route(self._graph, self._next_stage, stage)
            if route is None:
                # Try alternate route
                route = find_route(self._graph, self._stage, stage)
                if route is None:
                    raise ValueError('Could not find route to target node.')
            else:
                route = [self._next_stage] + route
            for r in route:
                if r not in self.next_selector.options:
                    break
                self.next_selector.value = r
                self.param.trigger('next')

    @property
    def _next_stage(self):
        return self.next_selector.value

    @property
    def _prev_stage(self):
        return self.prev_selector.value

    def _update_button(self):
        """
        Refresh the selectors and buttons for the current stage.

        Expects ``self._route`` to end with the current stage, so the
        stage the user arrived from is ``self._route[-2]``.
        """
        stage, kwargs = self._stages[self._stage]

        # Next: every successor; default to the stage's next_parameter.
        options = list(self._graph.get(self._stage, []))
        next_param = kwargs.get('next_parameter', self.next_parameter)
        option = getattr(self._state, next_param) if next_param and next_param in stage.param else None
        if option is None:
            option = options[0] if options else None
        self.next_selector.options = options
        self.next_selector.value = option
        self.next_selector.disabled = not bool(options)

        # Previous: every predecessor; default to where the route came from.
        previous = [src for src, tgts in self._graph.items() if self._stage in tgts]
        prev_option = None
        if previous:
            prev_option = self._route[-2] if len(self._route) > 1 else None
            if prev_option not in previous:
                prev_option = previous[0]
        self.prev_selector.options = previous
        self.prev_selector.value = prev_option
        self.prev_selector.disabled = not bool(previous)

        # Disable previous button
        self.prev_button.disabled = self._prev_stage is None

        # Disable next button
        if self._next_stage is None:
            self.next_button.disabled = True
        else:
            ready = kwargs.get('ready_parameter', self.ready_parameter)
            if ready and ready in stage.param:
                # Read readiness from the live stage instance; the
                # registered stage may be a class holding only defaults.
                self.next_button.disabled = not getattr(self._state, ready)
            else:
                self.next_button.disabled = False

    def _get_error_button(self, e):
        """
        Build a button that pops up an error description when clicked.

        Only PipelineError messages are shown to the user; in debug mode
        the active traceback (from ``sys.exc_info``) is appended.
        """
        msg = str(e) if isinstance(e, PipelineError) else ""
        if self.debug:
            exc_type, exc_value, exc_tb = sys.exc_info()
            tb_list = tb.format_tb(exc_tb, None) + tb.format_exception_only(exc_type, exc_value)
            message = (("%s\n\nTraceback (innermost last):\n" + "%-20s %s") %
                       (msg, ''.join(tb_list[-5:-1]), tb_list[-1]))
        else:
            message = msg or "Undefined error, enable debug mode."
        # The text is embedded in a JavaScript template literal; escape
        # backslashes (e.g. Windows paths), backticks and ${ so arbitrary
        # error text can neither break nor inject into the callback code.
        escaped = (message.replace('\\', '\\\\')
                          .replace('`', '\\`')
                          .replace('${', '\\${'))
        button = Button(name='Error', button_type='danger', width=100,
                        align='center', margin=(0, 0, 0, 5))
        button.js_on_click(code="alert(`{tb}`)".format(tb=escaped))
        return button

    @param.depends('next', watch=True)
    def _next(self):
        """
        Advance to the stage selected in next_selector. On failure,
        restore the previous stage and show an error button.
        """
        prev_state, prev_stage = self._state, self._stage
        self._stage = self._next_stage
        self.stage[0] = self._spinner_layout
        try:
            self.stage[0] = self._init_stage()
        except Exception as e:
            self._error = self._stage
            self._stage = prev_stage
            self._state = prev_state
            self.stage[0] = prev_state.panel()
            self.error[:] = [self._get_error_button(e)]
            if self.debug:
                raise
            return e
        else:
            self.error[:] = []
            self._error = None
            # Record the new stage before refreshing the buttons so that
            # _update_button sees the route ending with the current stage.
            self._route.append(self._stage)
            self._update_button()
            stage_kwargs = self._stages[self._stage][-1]
            ready_param = stage_kwargs.get('ready_parameter', self.ready_parameter)
            if (ready_param and getattr(self._state, ready_param, False) and
                stage_kwargs.get('auto_advance', self.auto_advance)):
                self._next()
        finally:
            self._update_progress()

    @param.depends('previous', watch=True)
    def _previous(self):
        """
        Go back to the stage selected in prev_selector, reusing its
        existing state when available. On failure, restore the current
        stage and show an error button.
        """
        prev_state, prev_stage = self._state, self._stage
        self._stage = self._prev_stage
        try:
            if self._stage in self._states:
                self._state = self._states[self._stage]
                self.stage[0] = self._state.panel()
            else:
                self.stage[0] = self._init_stage()
            self._block = True
        except Exception as e:
            self.error[:] = [self._get_error_button(e)]
            self._error = self._stage
            self._stage = prev_stage
            self._state = prev_state
            if self.debug:
                raise
        else:
            self.error[:] = []
            self._error = None
            # Drop the stage we just left. If an alternative predecessor
            # was chosen, record it instead, so the route always ends
            # with the current stage (required by _update_button).
            del self._route[-1:]
            if not self._route:
                self._route.append(self._stage)
            elif self._route[-1] != self._stage:
                self._route[-1] = self._stage
            self._update_button()
        finally:
            self._update_progress()

    def _update_progress(self, *args):
        """Refresh the title and the progress diagram."""
        self.title.object = '## Stage: ' + self._stage
        self.network.object = self._make_progress()

    def _make_progress(self):
        """
        Build the HoloViews network diagram showing all stages, colored
        by state (active, error, next or inactive).
        """
        import holoviews as hv
        import holoviews.plotting.bokeh # noqa

        if self._graph:
            root = get_root(self._graph)
            breadths = get_breadths(root, self._graph)
            max_breadth = max(len(v) for v in breadths.values())
            # Rightmost column index; nodes are placed at x = level.
            max_level = max(breadths)
        else:
            max_breadth, max_level = 0, 0
            breadths = {}

        height = 80 + (max_breadth-1) * 20

        edges = []
        for src, tgts in self._graph.items():
            for t in tgts:
                edges.append((src, t))

        nodes = []
        for level, subnodes in breadths.items():
            step = 1./len(subnodes)
            for i, n in enumerate(subnodes[::-1]):
                if n == self._stage:
                    state = 'active'
                elif n == self._error:
                    state = 'error'
                elif n == self._next_stage:
                    state = 'next'
                else:
                    state = 'inactive'
                nodes.append((level, step/2.+i*step, n, state))

        cmap = {'inactive': 'white', 'active': '#5cb85c', 'error': 'red',
                'next': 'yellow'}

        def tap_renderer(plot, element):
            # Restrict the tap tool to the node glyphs.
            from bokeh.models import TapTool
            gr = plot.handles['glyph_renderer']
            tap = plot.state.select_one(TapTool)
            tap.renderers = [gr]

        nodes = hv.Nodes(nodes, ['x', 'y', 'Stage'], 'State').opts(
            alpha=0, default_tools=['tap'], hooks=[tap_renderer],
            hover_alpha=0, selection_alpha=0, nonselection_alpha=0,
            axiswise=True, size=10, backend='bokeh'
        )
        self._progress_sel.source = nodes
        graph = hv.Graph((edges, nodes)).opts(
            edge_hover_line_color='black', node_color='State', cmap=cmap,
            tools=[], default_tools=['hover'], selection_policy=None,
            node_hover_fill_color='gray', axiswise=True, backend='bokeh')
        labels = hv.Labels(nodes, ['x', 'y'], 'Stage').opts(
            yoffset=-.30, default_tools=[], axiswise=True, backend='bokeh'
        )
        plot = (graph * labels * nodes) if self._linear else (graph * nodes)
        plot.opts(
            xaxis=None, yaxis=None, min_width=400, responsive=True,
            show_frame=False, height=height, xlim=(-0.25, max_level+0.25),
            ylim=(0, 1), default_tools=['hover'], toolbar=None, backend='bokeh'
        )
        return plot

    def _repr_mimebundle_(self, include=None, exclude=None):
        return self.layout._repr_mimebundle_(include, exclude)

    #----------------------------------------------------------------
    # Public API
    #----------------------------------------------------------------

    def add_stage(self, name, stage, **kwargs):
        """
        Adds a new, named stage to the Pipeline.

        Arguments
        ---------
        name: str
          A unique string name for the Pipeline stage
        stage: param.Parameterized
          A Parameterized class or instance which represents the
          Pipeline stage.
        **kwargs: dict
          Additional arguments declaring the behavior of the stage;
          each key must be a parameter of the Pipeline.

        Raises
        ------
        ValueError
          If the stage is invalid or already added, the name is already
          in use, or a keyword is not a Pipeline parameter.
        RuntimeError
          If a custom graph has already been defined.
        """
        self._validate(stage)
        if name in self._stages:
            # Re-using a name would overwrite the stage and link it to
            # itself via the linear graph, creating a cycle.
            raise ValueError('Stage name %s is already in use; stage '
                             'names must be unique.' % name)
        for k in kwargs:
            if k not in self.param:
                raise ValueError("Keyword argument %s is not a valid parameter. " % k)

        if not self._linear and self._graph:
            raise RuntimeError("Cannot add stage after graph has been defined.")

        self._stages[name] = (stage, kwargs)
        if len(self._stages) == 1:
            self._stage = name
            self._route = [name]
            self._graph = {}
            self.stage[:] = [self._init_stage()]
        else:
            # Link the current tail of the linear chain to the new stage.
            previous = [s for s in self._stages if s not in self._graph][0]
            self._graph[previous] = (name,)
        self._update_progress()
        self._update_button()

    def define_graph(self, graph, force=False):
        """
        Declares a custom graph structure for the Pipeline overriding
        the default linear flow. The graph should be defined as an
        adjacency mapping.

        Arguments
        ---------
        graph: dict
          Dictionary declaring the relationship between different
          pipeline stages. Should map from a single stage name to
          one or more stage names (a single name, or a tuple or list
          of names). An empty dict connects the stages linearly.
        force: bool
          Whether to allow overriding a previously defined custom
          graph.

        Raises
        ------
        ValueError
          If the graph references unknown stages, a custom graph was
          already defined (and ``force`` is False), or the graph has no
          single root from which every stage can be reached.
        """
        stages = list(self._stages)
        if not stages:
            self._graph = {}
            return

        # Normalize targets to tuples; lists are accepted as well.
        graph = {k: tuple(v) if isinstance(v, (list, tuple)) else (v,)
                 for k, v in graph.items()}

        not_found = []
        for source, targets in graph.items():
            if source not in stages:
                not_found.append(source)
            not_found += [t for t in targets if t not in stages]
        if not_found:
            raise ValueError(
                'Pipeline stage(s) %s not found, ensure all stages '
                'referenced in the graph have been added.' %
                (not_found[0] if len(not_found) == 1 else not_found)
            )

        if graph:
            if not (self._linear or force):
                raise ValueError("Graph has already been defined, "
                                 "cannot override existing graph.")
            self._linear = False
        else:
            graph = {s: (t,) for s, t in zip(stages[:-1], stages[1:])}

        root = get_root(graph)
        if not is_traversable(root, graph, stages):
            raise ValueError('Graph is not fully traversable from stage: %s.'
                             % root)

        self._stage = root
        self._graph = graph
        self._route = [root]
        if not self._linear:
            self.buttons[:] = [
                Column(self.prev_selector, self.prev_button),
                Column(self.next_selector, self.next_button)
            ]
        self.stage[:] = [self._init_stage()]
        self._update_progress()
        self._update_button()

Read our documentation on viewing source code .

Loading