Commits

Graham Helliwell committed d272273

Add forcepartial option

Comments (0)

Files changed (2)

guestrepo/__init__.py

 updateopt = [('u', 'update', False,
            "perform a grupdate after the grpull")]
 
+forcepartialopt = [('f', 'forcepartial', False,
+           "force a partial update by silently ignoring disallowed paths")]
+
 threadopt = [('j', 'threads', 1,
           'number of threads to use. '
           'WARNING: ssh prompts are broken, and username and password prompts '
 # format: {"command-name": (function, options-list, help-string)}
 cmdtable = {
    "grpull": (grpull,
-              localopt + threadopt + updateopt,
+              localopt + threadopt + updateopt + forcepartialopt,
               "hg grpull [REPO [+]] [options]"),
    "grpush": (push,
               localopt + threadopt,
               "hg grpush [REPO [+]] [options]"),
    "grupdate": (grupdate,
-              localopt + threadopt + pullopt,
+              localopt + threadopt + pullopt + forcepartialopt,
               "hg grupdate [REPO [+]] [options]"),
    "grfreeze": (freeze,
               [],

guestrepo/guestrepo.py

 import itertools
 import os
 
+import lockedui
 import mercurial.ui
 from mercurial import util, config, hg, commands, \
                       error, node, scmutil, bookmarks, localrepo
 
 GR_CONFIG = '.hgguestrepo'
 GR_MAPPING = '.hggrmapping'
+GR_AUTHFILENAME = '.grauth'
 
 #####################
 # Commands
    else:
       return 0
 
+#####################
 
-def pull(ui, repo, local, args, opts):
-    guests = getguests(repo)
+def pullupdate(ui, repo, shouldpull, shouldupdate, args, opts):
+    local = opts.get('local')
+    abortIfInvalid= not opts.get('forcepartial')
+    succeeded = True
+
+    guests = getguests(repo, None, abortIfInvalid)
     if args:
-        guests = matchguests(repo.root,
-                             os.getcwd(),
-                             args,
-                             guests)
+       guests = matchguests(repo.root,
+                         os.getcwd(),
+                         args,
+                         guests)
 
+    if (shouldpull):
+       succeeded = pull(ui, repo, local, guests, opts)
+    if succeeded and shouldupdate:
+       succeeded = update(ui, repo, local, guests, opts)
+    return 0 if succeeded else 1
+
+def pull(ui, repo, local, guests, opts):
     def pullaction(ui, repo, guest):
         ui.status('pulling %s\n' % guest.canonpath)
         commands.pull(ui, repo, guest.uri)
         workers.join()
     if len(workers.errors) > 0:
         showerrors(ui, workers)
-        return 1
+        return False
     else:
-        return 0
+        return True
 
-def update(ui, repo, local, args, opts):
-    guests = getguests(repo)
-    if args:
-        guests = matchguests(repo.root,
-                             os.getcwd(),
-                             args,
-                             guests)
-
+def update(ui, repo, local, guests, opts):
     path = dirtyrecursive(ui, repo, local, guests)
     if path:
         raise util.Abort("repository %s contains uncommitted changes" % path)
         workers.join()
     if len(workers.errors) > 0:
         showerrors(ui, workers)
-        return 1
+        return False
     else:
-        return 0
+        return True
 
 #####################
 
-_AUTHFILENAME = '.grauth'
-
 class guestrepo(object):
     '''An aggregate representing a guest repository'''
     def __init__(self, name, configpath, canonpath, uri, csid, root):
     return mappingconfig['']
 
 
-def getguests(repo, ctx=None):
+def getguests(repo, ctx=None, abortIfInvalid=True):
     ''' Get the guest repos by parsing the .hgguestrepo.
 
         The uri field of the guests is not set until they are matched using the
                                       configpath,
                                       auditor = pathauditor)
         except util.Abort:
-               canonpath = authorizedoutisderoot(os.path.join(repo.root, configpath))
+               canonpath = authorizedoutisderoot(os.path.join(repo.root, configpath), abortIfInvalid)
+
+        if canonpath == None:
+           continue
 
         if canonpath == '':
             raise util.Abort("guest path '%s' refers to parent repository!" %
 
     return guests
 
-def authorizedoutisderoot(path):
-    '''A path is authorized if anywhere above it in the file tree there's a file called _AUTHFILENAME
+def authorizedoutisderoot(path, abortIfInvalid=True):
+    '''A path is authorized if anywhere above it in the file tree there's a file called .grauth
        If the given path is authorized the normalized path is returned.
-       Otherwise, an Abort exception is thrown.
+       If the given path is not authorized, an Abort exception is thrown if abortIfInvalid==True, or None is returned otherwise.
     '''
     #Don't check the path itself
     path = util.pconvert(os.path.normpath(path))
           if currentpath == previouspath:
              break
 
-          authfile = os.path.join(currentpath, _AUTHFILENAME)
+          authfile = os.path.join(currentpath, GR_AUTHFILENAME)
           if os.path.exists(authfile):
              return path
           previouspath = currentpath
 
     bestgrauthlocation, p = os.path.split(path)
-    bestgrauthlocation = os.path.join(bestgrauthlocation, _AUTHFILENAME)
-    raise util.Abort("'{0}' not under root. To override this check, add an empty file called '{1}'".format(path, bestgrauthlocation))
+    bestgrauthlocation = os.path.join(bestgrauthlocation, GR_AUTHFILENAME)
+    if abortIfInvalid:
+       raise util.Abort("'{0}' not under root. To override this check, add an empty file called '{1}'".format(path, bestgrauthlocation))
+    return None
 
 def rejectnestedguests(guests):
     ''' Given a collection of guests, throw an exception if the path of any
     if mapping:
         resolvemapping(guests, getmapping(repo, mapping == 'local'))
     for guest in guests:
-        newui = ui.copy()
-        newui.associate_guest(guest)
+        newui = lockedui.lockedui(ui)
         if guest.isrepo():
             def work(guest=guest,newui=newui):
                 checkcycles(guest.uri, bannedlocs)