Commits

Gael Pasgrimaud  committed 0e6d225

allow to use extra args in attr() / css()

  • Participants
  • Parent commits cb35449

Comments (0)

Files changed (2)

File pyquery/README.txt

     >>> p.attr.id = "plop"
     >>> p.attr.id
     'plop'
-    >>> p.attr["id"] = "hello"
+    >>> p.attr["id"] = "ola"
     >>> p.attr["id"]
-    'hello'
+    'ola'
+    >>> p.attr(id='hello', class_='hello2')
+    [<p#hello.hello2>]
+    >>> p.attr.class_
+    'hello2'
+    >>> p.attr.class_ = 'hello'
 
 You can also play with css classes::
 
     >>> p.css['font-size'] = "15px"
     >>> p.attr.style
     'font-size: 15px'
+    >>> p.css(font_size="16px")
+    [<p#hello.hello>]
+    >>> p.attr.style
+    'font-size: 16px'
     >>> p.css = {"font-size": "17px"}
     >>> p.attr.style
     'font-size: 17px'

File pyquery/pyquery.py

         class _element(object):
             """real element to support set/get/del attr and item and js call
             style"""
-            def __call__(prop, name, value=NoDefault):
-                if isinstance(name, basestring):
-                    # this is to set css attr
-                    name = name.replace('_', '-')
-                return self.pget(instance, name, value)
+            def __call__(prop, *args, **kwargs):
+                return self.pget(instance, *args, **kwargs)
             __getattr__ = __getitem__ = __setattr__ = __setitem__ = __call__
             def __delitem__(prop, name):
                 if self.pdel is not NoDefault:
     ##############
     # Attributes #
     ##############
-    def attr(self, name, value=NoDefault):
+    def attr(self, *args, **kwargs):
+
+        mapping = {'class_': 'class', 'for_': 'for'}
+
+        attr = value = NoDefault
+        length = len(args)
+        if length == 1:
+            attr = args[0]
+            attr = mapping.get(attr, attr)
+        elif length == 2:
+            attr, value = args
+            attr = mapping.get(attr, attr)
+        elif kwargs:
+            attr = {}
+            for k, v in kwargs.items():
+                attr[mapping.get(k, k)] = v
+        else:
+            raise ValueError('Invalid arguments %s %s' % (args, kwargs))
+
         if not self:
             return None
-        if value is NoDefault:
-            return self[0].get(name)
+        elif isinstance(attr, dict):
+            for tag in self:
+                for key, value in attr.items():
+                    tag.set(key, value)
+        elif value is NoDefault:
+            return self[0].get(attr)
         elif value is None or value == '':
-            return self.removeAttr(name)
-        elif type(name) == dict:
-            for tag in self:
-                for key, value in name.items():
-                    tag.set(key, value)
+            return self.removeAttr(attr)
         else:
             for tag in self:
-                tag.set(name, value)
+                tag.set(attr, value)
         return self
 
     def removeAttr(self, name):
             tag.set('class', ' '.join(classes))
         return self
 
-    def css(self, attr, value=NoDefault):
+    def css(self, *args, **kwargs):
+
+        attr = value = NoDefault
+        length = len(args)
+        if length == 1:
+            attr = args[0]
+        elif length == 2:
+            attr, value = args
+        elif kwargs:
+            attr = kwargs
+        else:
+            raise ValueError('Invalid arguments %s %s' % (args, kwargs))
+
         if isinstance(attr, dict):
             for tag in self:
-                stripped_keys = [key.strip() for key in attr.keys()]
+                stripped_keys = [key.strip().replace('_', '-')
+                                 for key in attr.keys()]
                 current = [el.strip()
                            for el in (tag.get('style') or '').split(';')
                            if el.strip()
                            and not el.split(':')[0].strip() in stripped_keys]
                 for key, value in attr.items():
+                    key = key.replace('_', '-')
                     current.append('%s: %s' % (key, value))
                 tag.set('style', '; '.join(current))
         elif isinstance(value, basestring):
+            attr = attr.replace('_', '-')
             for tag in self:
                 current = [el.strip()
                            for el in (tag.get('style') or '').split(';')