# attention-visual
#
# **Repository Path**: lddsdu/attention-visual
#
# ## Basic Information
# - **Project Name**: attention-visual
# - **Description**: No description available
# - **Primary Language**: Unknown
# - **License**: Not specified
# - **Default Branch**: master
# - **Homepage**: None
# - **GVP Project**: No
#
# ## Statistics
# - **Stars**: 0
# - **Forks**: 0
# - **Created**: 2019-07-12
# - **Last Updated**: 2020-12-19
#
# ## Categories & Tags
# **Categories**: Uncategorized
# **Tags**: None
#
# ## README
# #### Experimental results
# **Attention**

import numpy as np
import matplotlib.pyplot as plt


def _fill_rect(ax, x_left, x_right, y_bottom, y_top, **style):
    """Shade the axis-aligned rectangle [x_left, x_right] x [y_bottom, y_top]."""
    # Two x samples with scalar bounds describe the same polygon as a dense
    # linspace would, without allocating hundreds of identical points.
    ax.fill_between([x_left, x_right], y_top, y_bottom, **style)


def plot_attention_weight(x_labels, y_labels, color, font_size, values):
    """Draw an attention-weight grid with an opacity legend and show it.

    Every entry ``values[i][j]`` is drawn as a unit square centred on data
    coordinate ``(i, j)``, filled with ``color`` at opacity ``values[i][j]``.
    To the right of the grid there is one blank (white) column followed by a
    gradient column acting as the legend, annotated with "0.0" and "1.0".

    :param x_labels: labels written (rotated 90 degrees) along the bottom,
        e.g. ["Q", "bird", "head", "legs", "slides", "wood"]
    :param y_labels: labels written along the left side,
        e.g. ["wood /r/DistinctFrom carpet", "word /r/AtLocation a fire",
        "split /r/DistinctFrom like"]
    :param color: matplotlib colour used for the cells and the legend
    :param font_size: font size for all text in the figure
    :param values: 2-D array-like; ``values.shape[1]`` must equal
        ``len(y_labels)``. Rows are laid out along the x axis (their count is
        not checked against ``x_labels``). Entries are passed straight to
        matplotlib as ``alpha``, so they are expected to lie in [0, 1].
    :raises ValueError: if ``values`` is not 2-D or its second dimension does
        not match ``len(y_labels)``
    :return: None; the figure is displayed with ``plt.show()``
    """
    values = np.asarray(values)
    y_len = len(y_labels)
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if values.ndim != 2 or values.shape[1] != y_len:
        raise ValueError(
            "values must be 2-D with shape[1] == len(y_labels) (%d), got shape %s"
            % (y_len, values.shape)
        )

    # Two extra columns: the blank separator and the legend gradient.
    x_len = len(x_labels) + 2
    plt.figure(figsize=(x_len + 1, y_len + 1))

    # Label positions are expressed in axes coordinates (0..1).
    x_unit = 1.0 / (x_len + 1)
    y_unit = 1.0 / (y_len + 1)
    x_half_unit = x_unit / 2
    y_half_unit = y_unit / 2

    ax = plt.gca()
    plt.axis("off")

    for x_index, x_label in enumerate(x_labels):
        ax.text((x_index + 0.5) * x_unit + x_half_unit, 0.0, x_label,
                horizontalalignment='center', verticalalignment='top',
                rotation=90, transform=ax.transAxes, fontsize=font_size)

    for y_index, y_label in enumerate(y_labels):
        ax.text(0.0, (y_index + 0.5) * y_unit + y_half_unit, y_label,
                horizontalalignment='right', verticalalignment='center',
                transform=ax.transAxes, fontsize=font_size)

    # One unit square per weight, in data coordinates.
    for x_index, row in enumerate(values):
        for y_index, weight in enumerate(row):
            _fill_rect(ax, x_index - 0.5, x_index + 0.5,
                       y_index - 0.5, y_index + 0.5,
                       facecolor=color, alpha=weight)

    # Derive the extra columns from the row count rather than from a loop
    # variable that is undefined (or stale) when ``values`` has no rows.
    n_rows = values.shape[0]
    separator_left = n_rows - 0.5
    legend_left = n_rows + 0.5
    legend_right = n_rows + 1.5

    # Blank column between the grid and the legend.
    _fill_rect(ax, separator_left, legend_left, -0.5, y_len - 0.5,
               facecolor="white")

    # Legend: thin horizontal strips whose opacity rises from 0 towards 1.
    grain_num = 1000
    grain = 1.0 / grain_num
    for level in np.linspace(0, 1.0 - grain, grain_num - 1):
        strip_bottom = level * y_len - 0.5
        # The extra ``+ grain`` makes each strip slightly taller than its
        # slot (presumably so adjacent strips overlap and leave no gaps).
        strip_top = strip_bottom + grain * y_len + grain
        _fill_rect(ax, legend_left, legend_right, strip_bottom, strip_top,
                   facecolor=color, alpha=level)

    plt.text(x_unit * (x_len + 1), y_half_unit, "0.0",
             horizontalalignment='center', verticalalignment='top',
             transform=ax.transAxes, fontsize=font_size)
    plt.text(x_unit * (x_len + 1), 1. - y_half_unit / 2.0, "1.0",
             horizontalalignment='center', verticalalignment='top',
             transform=ax.transAxes, fontsize=font_size)
    plt.show()


if __name__ == '__main__':
    # Demo 1: random weights, normalised so that every row sums to 1.
    values = np.random.random((12, 6))
    values = values / np.sum(values, axis=1, keepdims=True)
    plot_attention_weight(
        ["Q", "bird", "head", "legs", "slides", "wood",
         "Q", "bird", "head", "legs", "slides", "wood"],
        ["wood /r/DistinctFrom carpet",
         "word /r/AtLocation a fire",
         "split /r/DistinctFrom like"] * 2,
        color="green", font_size=20, values=values)

    # Demo 2: a small hand-written matrix, transposed to shape (3, 2).
    values = np.asarray([[0.1, 0.3, 0.5], [0.9, 0.7, 0.5]], dtype=np.float32)
    values = np.transpose(values)
    print(values)
    print(values.shape)
    x = ["a", "b", "c"]
    y = ["d", "e"]
    plot_attention_weight(x, y, color="red", font_size=18, values=values)

# **The resulting figures are shown below, respectively**
#
# ![attention_weight](assets/attention_weight.png)
# ![attention_weight2](assets/attention_weight2.png)
# ![attention_weight2](assets/attention_weight3.png)