Commits

Christian Wyglendowski committed dd74995

Added find, filter and end traversal (plus basic tests).

Comments (0)

Files changed (2)

     # Traversing #
     ##############
 
+    def filter(self, selector):
+        """Filter elements in self using selector."""
+        return self.__class__(selector, self, **dict(parent=self))
+
+    def find(self, selector):
+        """Find elements using selector traversing down from self."""
+        xpath = selector_to_xpath(selector)
+        results = [child.xpath(xpath) for tag in self for child in tag.getchildren()]
+        # Flatten the results
+        elements = []
+        for r in results:
+            elements.extend(r)
+        return self.__class__(elements, **dict(parent=self))
+
     def each(self, func):
         """apply func on each nodes
         """
     def size(self):
         return len(self)
 
+    def end(self):
+        return self._parent
+
     ##############
     # Attributes #
     ##############
            </html>
            """
 
+    html3 = """
+           <html>
+            <body>
+              <div id="node1"><span>node1</span></div>
+              <div id="node2"><span>node2</span><span> booyah</span></div>
+            </body>
+           </html>
+           """
+
     def test_selector_from_doc(self):
         doc = etree.fromstring(self.html)
         assert len(self.klass(doc)) == 1
         assert isinstance(n, self.klass)
         assert n._parent is e
 
+    def test_filter(self):
+        assert len(self.klass('div', self.html).filter('.node3')) == 1
+        assert len(self.klass('div', self.html).filter('#node2')) == 1
+
+    def test_find(self):
+        assert len(self.klass('#node1', self.html3).find('span')) == 1
+        assert len(self.klass('#node2', self.html3).find('span')) == 2
+        assert len(self.klass('div', self.html3).find('span')) == 3
+
+    def test_end(self):
+        assert len(self.klass('div', self.html3).find('span').end()) == 2
+        assert len(self.klass('#node2', self.html3).find('span').end()) == 1
+
 def application(environ, start_response):
     req = Request(environ)
     response = Response()