Marinka Zitnik committed 63aa2e0

update to easily include multiple wild-type ORFs
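The commit replaces the hard-coded "YOR202W" wild-type identifier with a module-level WT_ORFs list that is passed down to attribute filtering, WT/MT splitting, k-means reporting and novelty detection. As a minimal sketch of the intended usage (the second ORF below is a hypothetical placeholder, not part of this commit; `attrs`, `plate`, and `meta` are assumed to be loaded as in the analysis scripts), adding another wild-type ORF now only requires extending the list:

    # Sketch only; assumes data is loaded as in the analysis scripts.
    import utilities as utils

    WT_ORFs = ["YOR202W", "YLR999X"]  # "YLR999X" is a placeholder, not a real ORF

    # Downstream calls now take the list instead of a single string:
    dataf = utils.filter_attribute(attrs, plate, attr_name = "ORF", attr_values = WT_ORFs)
    wt, mt = utils.split_WT_MT(meta, plate, wt_mt_name = "ORF", wt_name = WT_ORFs)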


Files changed (3)

 import utilities as utils
 import plotting
 import methods
+
+#For new wild-type ORFs, just add them to this list
+WT_ORFs = ["YOR202W"]
     
 def strains_1p_WT(meta, plate, res_path, plot_attr_hist = True):
     """
     """
     file_name, attrs = meta
     out_name = res_path + file_name[:-4]
-    dataf = utils.filter_attribute(attrs, plate, attr_name = "ORF", attr_values = ["YOR202W"])
+    dataf = utils.filter_attribute(attrs, plate, attr_name = "ORF", attr_values = WT_ORFs)
     dnpf = utils.data2np(dataf, skip_first = 5, skip_last = 1)
     if plot_attr_hist:
         for i in xrange(dnpf.shape[1]):
     dist_matrix = methods.compute_distance(r_dnp, title = file_name, out_name = out_name, labels = tuple(r_dataf[i][0] for i in range(r_dnp.shape[0])))
     
     
-    wt, mt = utils.split_WT_MT(meta, plate)
+    wt, mt = utils.split_WT_MT(meta, plate, wt_mt_name = "ORF", wt_name = WT_ORFs)
     #WT
     dnp_wt = utils.data2np(wt, skip_first = 5, skip_last = 1)
     dnp_wt = methods.standardize(dnp_wt)
     
     pred_kmeans, score_kmeans, trans_kmeans = methods.k_means(dnp, dataf, k_range = range(2,6), out_name_silhouette = out_name, 
                                                               out_dir_predictions = res_path, out_name_predictions = file_name[:-4],
-                                                              save_silhouette = True, save_predictions = True)
+                                                              save_silhouette = True, save_predictions = True, wt_name = WT_ORFs)
     
     methods.outlier2cluster(en_out[0], pred_kmeans)
     pca_trans = methods.decompose_PCA(r_dnp, r_dataf, n_components = 3, title = file_name, out_name = out_name)
     file_name, attrs = metaj
     out_name = res_path + file_name
     
-    dataf = utils.filter_attribute(attrs, platej, attr_name = "ORF", attr_values = ["YOR202W"])
+    dataf = utils.filter_attribute(attrs, platej, attr_name = "ORF", attr_values = WT_ORFs)
     dnpf = utils.data2np(dataf, skip_first = 5, skip_last = 1)
     dnp = methods.standardize(dnpf)
     
     dist_matrix = methods.compute_distance(r_dnp, title = file_name + " (joined std., o.r.)", out_name = out_name + "__joined_std_out", labels = tuple(r_dataf[i][0] for i in range(len(r_dataf))))
     pred_kmeans, score_kmeans, trans_kmeans = methods.k_means(dnp, dataf, k_range = range(2,6), out_name_silhouette = out_name,
                                                               out_dir_predictions = res_path, out_name_predictions = "combined",
-                                                              save_silhouette = True, save_predictions = True)
+                                                              save_silhouette = True, save_predictions = True, wt_name = WT_ORFs)
     methods.outlier2cluster(en_out[0], pred_kmeans)
     pca_trans = methods.decompose_PCA(r_dnp, r_dataf, n_components = 3, title = file_name, out_name = out_name)
     
     print
     dnp_all, plate_all = [], []
     for smeta, splate in zip(meta, plates):
-        wt, mt = utils.split_WT_MT(smeta, splate)
+        wt, mt = utils.split_WT_MT(smeta, splate, wt_mt_name = "ORF", wt_name = WT_ORFs)
         #WT
         dnp_wt = utils.data2np(wt, skip_first = 5, skip_last = 1)
         dnp_wt = methods.standardize(dnp_wt)
     plates_mt, dnp_mtc = [], []
     for coll in [data_del, data_ts, data_sg]:
         for meta, plate in zip(coll[0], coll[1]):
-            wt, mt = utils.split_WT_MT(meta, plate)
+            wt, mt = utils.split_WT_MT(meta, plate, wt_mt_name = "ORF", wt_name = WT_ORFs)
             #WT
             dnp_wt = utils.data2np(wt, skip_first = 5, skip_last = 1)
             if standardize:
         dnp = []
         
         for meta, plate in zip(coll[0], coll[1]):
-            wt, mt = utils.split_WT_MT(meta, plate)
+            wt, mt = utils.split_WT_MT(meta, plate, wt_mt_name = "ORF", wt_name = WT_ORFs)
             #WT
             dnp_wt = utils.data2np(wt, skip_first = 5, skip_last = 1)
             dnp_wt = methods.standardize(dnp_wt)
     .. seealso:: See also functions :func:`sa.methods.fss_wrapper`, :func:`sa.methods.decompose_MDS` and
                  :func:`sa.utilities.std_prep`. 
     """
-    dnp, plates = utils.std_prep(data_del, data_ts, data_sg, res_path)
+    dnp, plates = utils.std_prep(data_del, data_ts, data_sg, res_path, wt_attr_name = "ORF", wt_name = WT_ORFs)
     dnp = dnp[:, 1:]
     attr_names = data_del[0][0][1][6:-1]
     assert dnp.shape[1] == len(attr_names), "The shapes of attribute space and feature names do not match."
     
     .. seealso:: See also functions :func:`sa.methods.decompose_MDS` and :func:`sa.utilities.std_prep`. 
     """
-    dnp, plates = utils.std_prep(data_del, data_ts, data_sg, res_path)
+    dnp, plates = utils.std_prep(data_del, data_ts, data_sg, res_path, wt_attr_name = "ORF", wt_name = WT_ORFs)
     dnp = dnp[:, 1:]
     attr_names = data_del[0][0][1][6:-1]
     assert dnp.shape[1] == len(attr_names), "The shapes of attribute space and feature names do not match."
     plates_mt, dnp_mtc = [], []
     for coll in [data_del, data_ts, data_sg]:
         for meta, plate in zip(coll[0], coll[1]):
-            wt, mt = utils.split_WT_MT(meta, plate)
+            wt, mt = utils.split_WT_MT(meta, plate, wt_mt_name = "ORF", wt_name = WT_ORFs)
             #WT
             dnp_wt = utils.data2np(wt, skip_first = 5, skip_last = 1)
             dnp_wt = methods.standardize(dnp_wt)
     print "No. observations of wild-type strains: %d" % len(plates_wt)
     
     methods.detect_novelties_SVM(dnp_wtc, dnp_mtc, plates_wt, plates_mt, res_path, save_visualization = True)
-    methods.detect_novelties_GMM(dnp_wtc, dnp_mtc, plates_wt, plates_mt, res_path)
+    methods.detect_novelties_GMM(dnp_wtc, dnp_mtc, plates_wt, plates_mt, res_path, save_visualization = True, wt_name = WT_ORFs)
     
     
     
 #clustering
 
 def k_means(dnp, plates, k_range, out_name_silhouette = None, out_dir_predictions = None, 
-            out_name_predictions = None, save_silhouette = True, save_predictions = True):
+            out_name_predictions = None, save_silhouette = True, save_predictions = True, wt_name = ["YOR202W"]):
     """
     Apply k-Means clustering to the data.
     
     :param save_predictions: Indicator whether to save predictions (cluster membership for each observation) for each 
                             number of clusters in :param:`k_range`.
     :type save_predictions: `bool`
+    :param wt_name: Names of the wild-type ORFs.
+    :type wt_name: `list`
     
     For each number of clusters, compute the mean silhouette coefficient over all observations. Return the
     clustering results for the number of clusters with the highest mean silhouette coefficient.
             assert len(pred) == len(plates), "Error. Number of predictions does not match with observations in plates."
             NK = np.max(pred)
             for j in xrange(NK + 1):
-                wtn = np.sum([1 for i, el in enumerate(plates) if el[-1] == "YOR202W" and pred[i] == j])
+                wtn = np.sum([1 for i, el in enumerate(plates) if el[-1] in wt_name and pred[i] == j])
                 f.write("Cluster %d: size = %d (WT = %d)\n" % (j, np.sum(pred == j), wtn))
             f.write("%s\n" % (",".join(["ORF", "plate", "date", "row", "col", "cluster"])))
             for i, strain in enumerate(plates):
         print "Saving one-class SVM novelties detection plot to file: %s" % fname
         plt.savefig(fname)
 
-def detect_novelties_GMM(dnp_wt, dnp_mt, plates_wt, plates_mt, out_dir, save_visualization = True):
+def detect_novelties_GMM(dnp_wt, dnp_mt, plates_wt, plates_mt, out_dir, save_visualization = True, wt_name = ["YOR202W"]):
     """
     Novelty detection using variational inference for the GMM (Gaussian Mixture Model). 
     
     :param save_visualization: Indicator whether to visualize training and test observations with 
                       labels in a low-dimensional representation. By default, it is set to True.
     :type save_visualization: `bool`
+    :param wt_name: Names of the wild-type ORFs.
+    :type wt_name: `list`
     
     A file named `novelty_detection_GMM.csv` is saved to the directory :param:`out_dir`; it contains
     strain identifiers and predictions (predicted class for each observation and predicted posterior 
     print "No. of available components: %d" % clf.n_components
     print "Components sizes:"
     for i in xrange(clf.n_components):
-        wtn = np.sum([1 for k, el in enumerate(plates_wt) if el[-1] == "YOR202W" and pred_mt[k] == i])
+        wtn = np.sum([1 for k, el in enumerate(plates_wt) if el[-1] in wt_name and pred_mt[k] == i])
         cls.append((np.sum(pred_mt == i), i))
         print "\t comp. no. %d: size = %d (WT = %d)" % (i, cls[-1][0], wtn)
     cls.sort()
     f.close()
     return names, data_k
 
-def std_prep(data_del, data_ts, data_sg, out_dir):
+def std_prep(data_del, data_ts, data_sg, out_dir, wt_attr_name = "ORF", wt_name = ["YOR202W"]):
     """
     Standard preprocessing: (1) standardize WT strains in each plate and remove outliers,
     (2) standardize mutant strains, (3) combine computational and
     :type data_sg: `tuple` (meta_data, plates_data)
     :param out_dir: Full path to directory where data in Orange format is saved.
     :type out_dir: `str`
+    :param wt_attr_name: Identifier of attribute that contains ORFs.
+    :type wt_attr_name: `str`
+    :param wt_name: Names of the wild-type ORFs.
+    :type wt_name: `list`
     
     Return preprocessed computational profiles and plates data.
     
     plates_mt, dnp_mtc = [], []
     for coll in [data_del, data_ts, data_sg]:
         for meta, plate in zip(coll[0], coll[1]):
-            wt, mt = split_WT_MT(meta, plate)
+            wt, mt = split_WT_MT(meta, plate, wt_attr_name, wt_name)
             #WT
             dnp_wt = data2np(wt, skip_first = 5, skip_last = 1)
             dnp_wt = methods.standardize(dnp_wt)
     print "No. filtered observations: %d" % len(filtered)
     return filtered
 
-def split_WT_MT(meta, data, wt_mt_name = "ORF", wt_name = "YOR202W"):
+def split_WT_MT(meta, data, wt_mt_name = "ORF", wt_name = ["YOR202W"]):
     """Split plate data to two groups: (i) wild-type, (ii) mutants. Wild-type strains are in entire border."""
     fname, attrs = meta
     wt = []
     print "File: %s" % fname
     nidx = attrs.index(wt_mt_name)
     for row in data:
-        if row[nidx] == wt_name:
+        if row[nidx] in wt_name:
             wt.append(row)
         else:
             mt.append(row)
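
For reference, the core semantic change in split_WT_MT is that wild-type membership is now a list test rather than a string comparison. A self-contained toy sketch (the rows and attribute layout below are made up for illustration, not real plate data):

    # Toy illustration of the "==" -> "in" change.
    attrs = ["plate", "date", "row", "col", "ORF"]
    rows = [["p1", "d1", "A", "1", "YOR202W"],
            ["p1", "d1", "A", "2", "YAL001C"]]
    WT_ORFs = ["YOR202W"]  # add further wild-type ORFs here

    nidx = attrs.index("ORF")
    wt = [r for r in rows if r[nidx] in WT_ORFs]      # was: r[nidx] == "YOR202W"
    mt = [r for r in rows if r[nidx] not in WT_ORFs]
    print "WT: %d, MT: %d" % (len(wt), len(mt))        # prints: WT: 1, MT: 1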