Commits

Davide Cittaro committed a0b8874

better pca plots

Comments (0)

Files changed (1)

+#!/usr/bin/env python 
+
+import vcf
+import sys
+import argparse
+import numpy as np
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import matplotlib.mlab as mlab
+import os
+
+def count_lines(filename):
+  c = 0
+  for line in open(filename):
+    if not line.startswith('#'):
+      c += 1
+  return c
+
+def plot_pca(results, ncomp=2, samples=[], classes='', prefix="PCA"):
+  ext_ind = prefix.rfind('.vcf')
+  if ext_ind > 1:
+    prefix = prefix[:ext_ind]
+  if not classes:
+    color_list = 'k' * len(samples)
+  else:
+    # at the moment we support few colors b g r c m y k
+    supported = 'bgrcmyk'
+    color_list = ''
+    class_names = set([x for x in classes])
+    color_map = {}
+    color_legend = []
+    string_legend = []
+    for x, c in enumerate(class_names):
+      color_map[c] = supported[x]
+      color_legend.append(plt.Circle((1,1),fc=supported[x]))
+      string_legend.append(c)
+    for c in classes:
+      color_list += color_map[c]
+
+  if not color_list:
+    color_list = 'k'
+
+  for x in range(ncomp - 1):
+    for y in range(x + 1, ncomp):
+      x_ext = np.max(np.abs(results.Wt[x]))
+      y_ext = np.max(np.abs(results.Wt[y]))
+      x_ext = x_ext + x_ext * 0.1
+      y_ext = y_ext + y_ext * 0.1
+      for z in range(len(results.fracs)):
+        plt.plot(results.Wt[x][z], results.Wt[y][z], 'o', color=color_list[z])
+        if samples:
+          plt.text(results.Wt[x][z], results.Wt[y][z], samples[z])
+      plt.xlabel("PC%d (%.2f%%)" % ((1 + x), (results.fracs[x] * 100)))
+      plt.ylabel("PC%d (%.2f%%)" % ((1 + y), (results.fracs[y] * 100)))
+      plt.title("Variance explained = %.2f%%" % ((100 * (results.fracs[y] + results.fracs[x]))))
+      if classes:
+        plt.legend(color_legend, string_legend, loc=0)
+      plt.xlim(-x_ext, x_ext)
+      plt.ylim(-y_ext, y_ext)
+      plt.hlines(0, -x_ext, x_ext, linestyle='dashed')  
+      plt.vlines(0, -y_ext, y_ext, linestyle='dashed')  
+      plt.savefig("%s_PC%d_%d.png" % (prefix, (x + 1), (y + 1)))
+      plt.close()
+      
+
+      
+  
+
+
+def analyze():
+  option_parser = argparse.ArgumentParser(
+  description = "PCA analysis of a VCF file", 
+  prog="pcavcf.py")
+  option_parser.add_argument("--version", action="version", version="%(prog)s 0.1")
+  option_parser.add_argument("-v", "--vcf", help="input VCF file", action="store", required=True)
+  option_parser.add_argument("-n", "--ncomp", help="Number of PC to retain", action="store", default=2, type=int)
+  option_parser.add_argument("-o", "--vcfout", help="If given, annotate VCF with PCA data", action="store")
+  option_parser.add_argument("-c", "--classes", help="A string representing the classes of samples (single letters)", action="store", default='')
+  option_parser.add_argument("-d", "--dump", help="Dump PCA data into a npz file", action="store_true", default=False)
+  option_parser.add_argument("-l", "--label", help="Label samples in PCA plots", action="store_true", default=False)
+  
+  
+  # parse arguments
+  cli_options = option_parser.parse_args()
+  
+  # get the number of variants
+  n_var = count_lines(cli_options.vcf)
+  
+  if len(set([x for x in cli_options.classes])) > 7:
+    sys.stderr.write("At the moment max 7 classes are supported")
+    sys.exit(1)
+    
+  parser = vcf.Reader(open(cli_options.vcf))
+  
+  n_samples = len(parser.samples)
+  
+  if cli_options.classes and n_samples != len(cli_options.classes):
+    sys.stderr.write("Number of classes and number of samples do not match\n")
+    sys.exit(1)
+  
+  var_data = np.zeros(n_var * n_samples).reshape((n_var, n_samples))
+
+  # assign -1 to REF/REF, 0 to REF/ALT and 1 to ALT/ALT. 
+  # not called data are np.nan
+  sys.stderr.write("Reading VCF file...\n")
+  for x, record in enumerate(parser):
+    for y, s in enumerate(record.samples):
+      if s.called:
+        if s.is_variant:   
+          if s.is_het:
+            v = 0
+          else:    
+            v = 1
+        else:       
+          v = -1
+      else:
+        v = np.nan    
+      var_data[x][y] = v
+
+  nan_mask = np.sum(np.isnan(var_data), axis=1) == 0  # discard missing data
+  t_data = var_data[nan_mask,]
+
+  # perform the pca
+  results = mlab.PCA(t_data)
+  
+  sys.stderr.write("Plotting PCA charts...\n")
+  if cli_options.label:
+    labels = parser.samples
+  else:
+    labels = []  
+  plot_pca(results, ncomp=cli_options.ncomp, samples=labels, classes=cli_options.classes, prefix=cli_options.vcf)
+  
+  
+  
+
+
+if __name__ == '__main__':
+  analyze()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.