# orange-modelmaps / examples / projections / radviz.py

import os.path
import numpy as np
import matplotlib.pyplot as plt
import _modelmaps as mm

from itertools import groupby
from operator import itemgetter
from Orange import clustering, data, distance, utils
from Orange.orng import orngVizRank as vr

# Root directory for generated figures and saved model maps. Set the
# MODELMAPS_ROOT environment variable to point elsewhere without editing
# this file. Roots used on other machines, kept for reference:
#   "/home/miha/work/res/modelmaps"
#   "/Network/Servers/xgridcontroller.private/lab/mihas/modelmaps"
ROOT = os.environ.get("MODELMAPS_ROOT", "C:\\Users\\Miha\\work\\res\\modelmaps")

def radviz_in_vr_mm(DATASET, centroids):
    """Build RadViz projection models for DATASET, plot the best ones, save the model map.

    One RadViz projection model is built for every 3-attribute subset of the
    dataset's features. The six models with the highest "P" score are drawn
    into ``ROOT/_projections_/radviz_<DATASET>_vizrank.pdf``. The model matrix
    and model table are then saved under ``ROOT/_projections_/radviz_<DATASET>``.

    :param DATASET: dataset name; ``<DATASET>.tab`` is loaded from Orange's
        dataset install directory.
    :param centroids: number of k-means centroids for the model-map
        clustering section below. It is unused while that section stays
        commented out.
    """
    build_map = mm.BuildModelMap(os.path.join(utils.environ.dataset_install_dir, "%s%s" % (DATASET, ".tab")))

    features = mm.get_feature_subsets(build_map.data().domain, min_features=3, max_features=3)

    models = [build_map.build_projection_model(attrs, vr.RADVIZ) for attrs in features]
    table = build_map.build_model_data(models)

    # VIZRANK

    def save_figure(scored_attributes, method):
        """Draw (score, attribute_names) pairs as panels of a 3x2 grid and save it as a PDF.

        The grid has six cells, so callers pass at most six pairs.
        """
        fig = plt.figure(figsize=(6, 9), dpi=300)
        fig.subplots_adjust(wspace=0.3, hspace=0.6, top=0.9, bottom=0.05, left=0.1, right=0.95)

        for i, (score, attrs) in enumerate(scored_attributes):
            add_subplot(fig, score, attrs, i=(i + 1))

        fig.text(0.5, 0.965, r"%s: %s" % (method, DATASET), ha='center', color='black', weight='bold', size='large')
        fig.savefig(os.path.join(ROOT, "_projections_", "radviz_%s_%s.pdf" % (DATASET, method.lower().replace(" ", ""))))
        # Release the figure; pyplot otherwise keeps every created figure alive.
        plt.close(fig)

    def format_ticks(ax, axis, values):
        """Use x-small tick labels on `axis` ('x' or 'y'), labelling each distinct value if fewer than 10."""
        # tick_params persists for ticks matplotlib creates later at draw time,
        # unlike setting the font size on the current tick-label objects.
        ax.tick_params(axis=axis, labelsize='x-small')
        distinct = sorted(set(values))
        if len(distinct) < 10:
            # Pin the tick positions to the values themselves. Setting only the
            # labels would attach them to arbitrary auto-generated ticks.
            if axis == 'x':
                ax.set_xticks(distinct)
                ax.set_xticklabels(distinct, size='x-small')
            else:
                ax.set_yticks(distinct)
                ax.set_yticklabels(distinct, size='x-small')

    def add_subplot(fig, score, attrs, i=1):
        """Draw one projection as panel `i` (1-based) of the 3x2 grid on `fig`.

        :param score: the projection's "P" score; the title shows it multiplied by 100.
        :param attrs: attribute names of the projection.
        """
        graph = data.preprocess.scaling.ScaleScatterPlotData()
        graph.setData(build_map.data(), graph.rawSubsetData)
        attr_indices = [graph.attribute_name_index[attr] for attr in attrs]
        # scaled_data is indexed attribute-first here, so row k holds the
        # values of attrs[k] for every instance.
        selected_data = np.take(graph.scaled_data, attr_indices, axis=0)
        class_list = graph.original_data[graph.data_class_index]

        ax = fig.add_subplot(3, 2, i)

        # NOTE(review): this draws a scatterplot of the first two attributes
        # only, not a RadViz layout of all three. Confirm this is intended.
        ax.scatter(selected_data[0], selected_data[1], c=class_list, s=50., alpha=0.75)
        format_ticks(ax, 'x', selected_data[0])
        format_ticks(ax, 'y', selected_data[1])

        ax.set_xlabel(attrs[0], size='small')
        ax.set_ylabel(attrs[1], size='small')

        ax.set_title(r"$\overline{P}=%.2f$" % (score*100), weight='bold', size='medium', position=(0.5, 1.1),
                        horizontalalignment='center', verticalalignment='center')

    # Six best projections by "P", as the (score, attribute_names) pairs save_figure expects.
    ranked = sorted(table, key=lambda ex: ex["P"].value, reverse=True)
    scored = [(ex["P"].value, ex["attributes"].value.split(", ")) for ex in ranked[:6]]
    save_figure(scored, "VizRank")


    # MODEL MAP

#    class ModelDistanceConstructor(distance.DistanceConstructor):
#
#        def __new__(cls, data=None):
#            self = distance.DistanceConstructor.__new__(cls)
#            return self.__call__(data) if data else self
#
#        def __call__(self, table):
#            return ModelDistance()
#
#    class ModelDistance(distance.Distance):
#        def __call__(self, e1, e2):
#            return mm.distance_manhattan(e1["model"].value, e2["model"].value)
#
#    def data_center(table):
#        onemodel = table[0]["model"].value
#        model = mm.Model("SCATTERPLOT", None,
#            np.mean(np.array([ex["model"].value.probabilities for ex in table]), axis=0),
#            onemodel.class_values, [],
#            [onemodel.class_values.keys()[0]] * len(onemodel.instance_predictions),
#            [onemodel.class_values.keys()[0]] * len(onemodel.instance_classes))
#
#        return model.get_instance(table.domain)
#
#    clustering.kmeans.data_center = data_center
#    kmeans = clustering.kmeans.Clustering(table, centroids=centroids, distance=ModelDistanceConstructor, initialization=clustering.kmeans.init_diversity)
#
#    clusters = sorted(zip(kmeans.clusters, range(len(kmeans.clusters))), key=itemgetter(0))
#
#    best_projs = []
#    for k, g in groupby(clusters, key=itemgetter(0)):
#        best_projs.append(max(((table[i]["P"].value, i) for c, i in g), key=itemgetter(0)))
#
#    best_projs.sort(key=itemgetter(0), reverse=True)
#    scored = [(score, table[key]["attributes"].value.split(", ")) for score, key in best_projs]
#
#    save_figure(scored[:6], "Model Map")


    # SAVE MODEL MAP

    smx = build_map.build_model_matrix(models)
    mm.save(os.path.join(ROOT, "_projections_", "radviz_%s" % DATASET), smx, table, build_map.data())

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    radviz_in_vr_mm("zoo", 10)
    #radviz_in_vr_mm("vehicle", 6)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.