Spaces:
Runtime error
Runtime error
| import pprint | |
| import numpy as np | |
| from . import viz2d | |
| from .tools import RadioHideTool, ToggleTool, __plot_dict__ | |
class FormatPrinter(pprint.PrettyPrinter):
    """PrettyPrinter that renders selected types with custom %-format strings.

    Args:
        formats: mapping from an exact type (e.g. ``np.float32``) to a
            %-style format string applied to objects of exactly that type.
        **kwargs: forwarded unchanged to ``pprint.PrettyPrinter`` (e.g.
            ``width``, ``indent``, ``compact``); omitted arguments keep the
            stock PrettyPrinter defaults.
    """

    def __init__(self, formats, **kwargs):
        super().__init__(**kwargs)
        self.formats = formats

    def format(self, obj, ctx, maxlvl, lvl):
        """Return the ``(repr, readable, recursive)`` triple for *obj*.

        Objects whose exact type (``type(obj)``, so subclasses are not matched)
        is a key of ``self.formats`` are rendered with the associated format
        string; everything else is delegated to ``PrettyPrinter.format``.
        """
        if type(obj) in self.formats:
            return self.formats[type(obj)] % obj, True, False
        return super().format(obj, ctx, maxlvl, lvl)
class TwoViewFrame:
    """Interactive matplotlib frame comparing several models on one image pair.

    One grid row is drawn per model (the keys of ``preds``), each showing
    ``view0`` and ``view1`` from ``data``. A radio tool ("switch plot", key
    ``R``) selects which entry of ``plot_dict`` is drawn on top of the images,
    and a toggle tool ("toggle summary", key ``t``) shows or hides per-model
    summary text when ``summaries`` is given.
    """

    default_conf = {
        "default": "matches",
        "summary_visible": False,
    }

    plot_dict = __plot_dict__

    # NOTE(review): class-level mutable list shared by every instance; it is
    # not read or written anywhere in this class — confirm usage elsewhere.
    childs = []

    # Maps the ``event`` constructor argument to the background to display.
    # Index 0 (None) is rejected by init_frame with a ValueError.
    event_to_image = [None, "color", "depth", "color+depth"]

    def __init__(self, conf, data, preds, title=None, event=1, summaries=None):
        """Build the figure, register the interactive tools and draw.

        Args:
            conf: config accessed by attribute (``conf.default``,
                ``conf.summary_visible``); presumably shaped like
                ``default_conf`` — TODO confirm.
            data: mapping with ``"view0"`` / ``"view1"`` entries (see
                ``init_frame``) plus any keys the plot functions require.
            preds: mapping from model name to that model's prediction mapping.
            title: optional window title.
            event: index into ``event_to_image`` selecting the background.
            summaries: optional mapping from model name to an object that is
                pretty-printed as overlay text.
        """
        self.conf = conf
        self.data = data
        self.preds = preds
        self.names = list(preds.keys())
        self.plot = self.event_to_image[event]
        self.summaries = summaries
        self.fig, self.axes, self.summary_arts = self.init_frame()
        if title is not None:
            self.fig.canvas.manager.set_window_title(title)

        # A plot is offered only if each of its required keys is present in
        # the data or in the predictions of *every* model.
        pred_key_sets = [set(pred.keys()) for pred in preds.values()]
        common_keys = set.intersection(*pred_key_sets) if pred_key_sets else set()
        keys = common_keys.union(data.keys())
        self.options = [
            k for k, v in self.plot_dict.items() if set(v.required_keys).issubset(keys)
        ]

        self.handle = None
        self.radios = self.fig.canvas.manager.toolmanager.add_tool(
            "switch plot",
            RadioHideTool,
            options=self.options,
            callback_fn=self.draw,
            active=conf.default,
            keymap="R",
        )
        self.toggle_summary = self.fig.canvas.manager.toolmanager.add_tool(
            "toggle summary",
            ToggleTool,
            toggled=self.conf.summary_visible,
            callback_fn=self.set_summary_visible,
            keymap="t",
        )
        if self.fig.canvas.manager.toolbar is not None:
            self.fig.canvas.manager.toolbar.add_tool("switch plot", "navigation")
        self.draw(conf.default)

    def init_frame(self):
        """Create the image grid, model labels, depth overlay and summaries.

        Returns:
            ``(fig, axes, summary_artists)``, where ``summary_artists`` is the
            list of per-model summary text artists (empty without summaries).

        Raises:
            ValueError: if ``self.plot`` is not "color", "depth" or
                "color+depth".
        """
        view0, view1 = self.data["view0"], self.data["view1"]
        if self.plot in ("color", "color+depth"):
            # permute(1, 2, 0) moves the channel axis last for display;
            # assumes image tensors are (B, C, H, W) — TODO confirm.
            imgs = [
                view0["image"][0].permute(1, 2, 0),
                view1["image"][0].permute(1, 2, 0),
            ]
        elif self.plot == "depth":
            imgs = [view0["depth"][0], view1["depth"][0]]
        else:
            raise ValueError(self.plot)
        imgs = [imgs for _ in self.names]  # one (view0, view1) row per model
        fig, axes = viz2d.plot_image_grid(imgs, return_fig=True, titles=None, figs=5)
        for i, name in enumerate(self.names):
            viz2d.add_text(0, name, axes=axes[i])

        if (
            self.plot == "color+depth"
            and "depth" in view0.keys()
            and view0["depth"] is not None
        ):
            for i, _ in enumerate(self.names):
                viz2d.plot_heatmaps(
                    [view0["depth"][0], view1["depth"][0]],
                    axes=axes[i],
                    cmap="Spectral",
                )

        fig.canvas.mpl_connect("pick_event", self.click_artist)

        if self.summaries is not None:
            formatter = FormatPrinter({np.float32: "%.4f", np.float64: "%.4f"})
            toggle_artists = [
                viz2d.add_text(
                    0,
                    formatter.pformat(self.summaries[n]),
                    axes=axes[i],
                    pos=(0.01, 0.01),
                    va="bottom",
                    backgroundcolor=(0, 0, 0, 0.5),
                    visible=self.conf.summary_visible,
                )
                for i, n in enumerate(self.names)
            ]
        else:
            toggle_artists = []
        return fig, axes, toggle_artists

    def draw(self, value):
        """Clear the current plot and draw ``plot_dict[value]`` instead.

        Also stores ``value`` as ``conf.default`` and returns the new handle
        produced by the plot function.
        """
        self.clear()
        self.conf.default = value
        self.handle = self.plot_dict[value](self.fig, self.axes, self.data, self.preds)
        return self.handle

    def clear(self):
        """Remove artists drawn by the current plot and drop its handle."""
        if self.handle is not None:
            # Not every plot handle implements clear(); probe instead of
            # catching AttributeError so bugs inside clear() still surface.
            handle_clear = getattr(self.handle, "clear", None)
            if handle_clear is not None:
                handle_clear()
            self.handle = None
        for row in self.axes:
            for ax in row:
                # Iterate over snapshots: Artist.remove() mutates the very
                # container being iterated, which would skip artists.
                for line in list(ax.lines):
                    line.remove()
                for collection in list(ax.collections):
                    collection.remove()
        self.fig.artists.clear()
        self.fig.canvas.draw_idle()

    def click_artist(self, event):
        """Pick-event callback: toggle a picked arrow and forward the event.

        Arrow-style artists switch between a plain "-" style and a
        double-headed "<|-|>" style (brought to zorder 1 when selected). Other
        picked artists are left alone. In both cases the event is forwarded
        to the current handle if it defines ``click_artist``.
        """
        art = event.artist
        if hasattr(art, "get_arrowstyle"):
            select = art.get_arrowstyle().arrow == "-"
            art.set_arrowstyle("<|-|>" if select else "-")
            if select:
                art.set_zorder(1)
        if hasattr(self.handle, "click_artist"):
            self.handle.click_artist(event)
        self.fig.canvas.draw_idle()

    def set_summary_visible(self, visible):
        """Show or hide all summary text artists and persist the choice in conf."""
        self.conf.summary_visible = visible
        for summary_artist in self.summary_arts:
            summary_artist.set_visible(visible)
        self.fig.canvas.draw_idle()