Commits

Andrew Dalke committed 26bf3fe

Implemented threshold support. Find the largest common substructure which
is in at least --threshold (between 0.0 and 1.0, inclusive) of the inputs.
No matter what, the found structure will be common to at least two structures.

Comments (0)

Files changed (1)

 
 # Figure out which canonical bonds SMARTS occur in every molecule
 def get_canonical_bondtype_counts(typed_mols):
-    # Get all of the canonical bond counts in the first molecule
-    bondtype_counts = get_counts(typed_mols[0].canonical_bondtypes)
-
-    # Iteratively intersect it with the other typed molecules
-    for typed_mol in typed_mols[1:]:
-        new_counts = get_counts(typed_mol.canonical_bondtypes)
-        bondtype_counts = intersect_counts(bondtype_counts, new_counts)
-
-    return bondtype_counts
-
+    overall_counts = defaultdict(list)
+    for typed_mol in typed_mols:
+        bondtype_counts = get_counts(typed_mols[0].canonical_bondtypes)
+
+        for k,v in bondtype_counts.items():
+            overall_counts[k].append(v)
+    return overall_counts
 
 # If I know which bondtypes exist in all of the structures, I can
 # remove all bonds which aren't in all structures. RDKit's Molecule
 # Caches all previous results.
 
 class CachingTargetsMatcher(dict):
-    def __init__(self, targets):
+    def __init__(self, targets, required_match_count=None):
         self.targets = targets
+        if required_match_count is None:
+            required_match_count = len(targets)
+        self.required_match_count = required_match_count
+        self._num_allowed_errors = len(targets) - required_match_count
         super(dict, self).__init__()
+
+    def shift_targets(self):
+        assert self._num_allowed_errors >= 0, (self.required_match_count, self._num_allowed_errors)
+        if self._num_allowed_errors > 1:
+            self.targets = self.targets[1:]
+            self._num_allowed_errors = len(self.targets) - self.required_match_count
         
     def __missing__(self, smarts):
+        num_allowed_errors = self._num_allowed_errors
+        if num_allowed_errors < 0:
+            raise AssertionError("I should never be called")
+            self[smarts] = False
+            return False
+        
         pat = Chem.MolFromSmarts(smarts)
         if pat is None:
             raise AssertionError("Bad SMARTS: %r" % (smarts,))
+
+        num_allowed_errors = self._num_allowed_errors
         for target in self.targets:
             if not MATCH(target, pat):
-                # Does not match. No need to continue processing
-                self[smarts] = False
-                return False
-                # TODO: should I move the mismatch structure forward
-                # so that it's tested earlier next time?
-        # Matches everything
+                if num_allowed_errors == 0:
+                    # Does not match. No need to continue processing
+                    self[smarts] = False
+                    return False
+                num_allowed_errors -= 1
+        # Matches enough structures, which means it will always
+        # match enough structures. (Even after shifting.)
         self[smarts] = True
         return True
 
             if matches_all_targets[smarts]:
                 best_sizes = hits.add_new_match(subgraph, mol, smarts)
             else:
-                raise AssertionError("This should never happen: %r" % (smarts,))
+                # This can happen if there's a threshold
+                #raise AssertionError("This should never happen: %r" % (smarts,))
                 continue
 
             a1, a2 = bond.atom_indices
         print >>sys.stderr, "  %d subgraphs enumerated, %d processed" % (
             self.num_seeds_added, self.num_seeds_processed)
 
-def compute_mcs(enumeration_mols, targets, maximize = Default.maximize,
+def compute_mcs(fragmented_mols, typed_mols, min_num_atoms, threshold_count=None, maximize = Default.maximize,
                 complete_rings_only = Default.complete_rings_only,
                 timeout = Default.timeout,
                 timer = None, verbose=False, verbose_delay=1.0):
     assert timer is not None
+    assert 0 < threshold_count <= len(fragmented_mols), threshold_count
+    assert len(fragmented_mols) == len(typed_mols)
+    assert len(fragmented_mols) >= 2
+    if threshold_count is None:
+        threshold_count = len(fragmented_mols)
+    else:
+        assert threshold_count >= 2, threshold_count
+    
     atom_assignment = Uniquer()
     if verbose:
         if verbose_delay < 0.0:
             raise ValueError("verbose_delay may not be negative")
-        matches_all_targets = VerboseCachingTargetsMatcher(list(targets))
+        matches_all_targets = VerboseCachingTargetsMatcher(typed_mols[1:], threshold_count-1)
         heapops = VerboseHeapOps(matches_all_targets.report, verbose_delay)
         push = heapops.heappush
         pop = heapops.heappop
         end_verbose = heapops.trigger_report
     else:
-        matches_all_targets = CachingTargetsMatcher(list(targets))
+        matches_all_targets = CachingTargetsMatcher(typed_mols[1:], threshold_count-1)
         push = heappush
         pop = heappop
         end_verbose = lambda: 1
-    
+
     try:
         prune, hits_class = _maximize_options[(maximize, bool(complete_rings_only))]
     except KeyError:
 
     hits = hits_class(timer, verbose)
 
-    success = enumerate_subgraphs(enumeration_mols, prune, atom_assignment, matches_all_targets, hits,
-                                  timeout, push, pop)
+    remaining_time = None
+    if timeout is not None:
+        stop_time = time.time() + timeout
+    
+    for query_index, fragmented_query_mol in enumerate(fragmented_mols):
+        enumerated_query_fragments = fragmented_mol_to_enumeration_mols(fragmented_query_mol, min_num_atoms)
+        
+        targets = typed_mols
+        if timeout is not None:
+            remaining_time = stop_time - time.time()
+        success = enumerate_subgraphs(enumerated_query_fragments, prune, atom_assignment, matches_all_targets, hits,
+                                      remaining_time, push, pop)
+        if query_index + threshold_count > len(fragmented_mols):
+            break
+        if not success:
+            break
+        matches_all_targets.shift_targets()
+        
     end_verbose()
     
     return hits.get_result(success)
             diff = None
         times[dest] = diff
 
+def _get_threshold_count(num_mols, threshold):
+    if threshold is None:
+        return num_mols
+
+    x = num_mols * threshold
+    threshold_count = int(x)
+    if threshold_count < x:
+        threshold_count += 1
+    
+    if threshold_count < 2:
+        # You can specify 0.00001 or -2.3 but you'll still get
+        # at least one *common* substructure.
+        threshold_count = 2
+
+    return threshold_count
 
 
 def fmcs(mols, min_num_atoms=2,
          maximize = Default.maximize,
          atom_compare = Default.atom_compare,
          bond_compare = Default.bond_compare,
+         threshold = 1.0,
          match_valences = Default.match_valences,
          ring_matches_ring_only = False,
          complete_rings_only = False,
         if timeout <= 0.0:
             raise ValueError("timeout must be None or a positive value")
 
+    threshold_count = _get_threshold_count(len(mols), threshold)
+    if threshold_count > len(mols):
+        # Threshold is too high. No possible matches.
+        return MCSResult(-1, -1, None, 1)
+        
     if complete_rings_only:
         ring_matches_ring_only = True
 
                                                   match_valences = match_valences,
                                                   ring_matches_ring_only = ring_matches_ring_only)
     bondtype_counts = get_canonical_bondtype_counts(typed_mols)
+    supported_bondtypes = set()
+    for bondtype, count_list in bondtype_counts.items():
+        if len(count_list) >= threshold_count:
+            supported_bondtypes.add(bondtype)
+            # For better filtering, find the largest count which is in threshold
+            # This can likely be done with:
+            #  count_list.sort(reversed=True)
+            #  max_count = count_list[threshold_count-1]
+
+    
     fragmented_mols = [remove_unknown_bondtypes(typed_mol, bondtype_counts) for typed_mol in typed_mols]
     timer.mark("end fragment")
 
 
     timer.mark("end select")
 
-    # Use the first as the query, the rest as the targets
-    query_fragments = fragmented_mol_to_enumeration_mols(sizes[0][4], min_num_atoms)
-
-    targets = [size[3].rdmol for size in sizes[1:]]
+    # Extract the (typed mol, fragmented mol) pairs.
+    fragmented_mols = [size_info[4] for size_info in sizes]  # used as queries
+    typed_mols = [size_info[3].rdmol for size_info in sizes]    # used as targets
 
     timer.mark("start enumeration")
-    mcs_result = compute_mcs(query_fragments, targets, maximize=maximize,
+    mcs_result = compute_mcs(fragmented_mols, typed_mols, min_num_atoms,
+                             threshold_count=threshold_count, maximize=maximize,
                              complete_rings_only=complete_rings_only, timeout=timeout,
                              timer=timer, verbose=verbose, verbose_delay=verbose_delay)
     timer.mark("end fmcs")
         raise argparse.ArgumentTypeError("must be at least 2, not %s" % s)
     return num_atoms
 
+def parse_threshold(s):
+    try:
+        import fractions
+    except ImportError:
+        threshold = float(s)
+        one = 1.0
+    else:
+        threshold = fractions.Fraction(s)
+        one = fractions.Fraction(1)
+    if not (0 <= threshold <= one):
+        raise argparse.ArgumentTypeError("must be a value between 0.0 and 1.0, not %s" % s)
+    return threshold
+
 def parse_timeout(s):
     if s == "none":
         return None
     if timeout < 0.0:
         raise argparse.ArgumentTypeError("Must be a non-negative value, not %r" % (s,))
     return timeout
-        
 
 class starting_from(object):
     def __init__(self, left):
                         "other bond. With 'bondtypes', bonds are the same only if their bond types "
                         "are the same. (Default: bondtypes)"))
 
+parser.add_argument("--threshold", default="1.0", type=parse_threshold, help=
+                    "Minimum structure match threshold. A value of 1.0 means that the common "
+                    "substructure must be in all of the input structures. A value of 0.8 finds "
+                    "the largest substructure which is common to at least 80% of the input "
+                    "structures. (Default: 1.0)")
+
 parser.add_argument("--atom-class-tag", metavar="TAG", help=
                     "Use atom class assignments from the field 'TAG'. The tag data must contain a space "
                     "separated list of integers in the range 1-10000, one for each atom. Atoms are "
                maximize = args.maximize,
                atom_compare = args.atom_compare,
                bond_compare = args.bond_compare,
+               threshold = args.threshold,
                #match_valences = args.match_valences,
                match_valences = False, # Do I really want to support this?
                ring_matches_ring_only = args.ring_matches_ring_only,
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.