1
import re
2
from lxml import etree
3

4
from six import string_types
5

6
from w3lib.html import HTML5_WHITESPACE
7

8
# Matches one or more consecutive characters drawn from w3lib's
# HTML5_WHITESPACE set; compiled once at import time so has_class() can
# collapse whitespace runs in a ``class`` attribute without recompiling.
regex = '[{}]+'.format(HTML5_WHITESPACE)
replace_html5_whitespaces = re.compile(regex).sub
10

11

12
def set_xpathfunc(fname, func):
    """Register a custom extension function to use in XPath expressions.

    The function ``func`` registered under ``fname`` identifier will be called
    for every matching node, being passed a ``context`` parameter as well as
    any parameters passed from the corresponding XPath expression.

    If ``func`` is ``None``, the extension function will be removed.

    See more `in lxml documentation`_.

    .. _`in lxml documentation`: https://lxml.de/extensions.html#xpath-extension-functions

    """
    # Function namespace looked up with a ``None`` URI, i.e. functions that
    # are called without a namespace prefix in XPath expressions.
    ns_fns = etree.FunctionNamespace(None)
    if func is not None:
        ns_fns[fname] = func
    else:
        del ns_fns[fname]
31

32

33
def setup():
    """Register the extension functions defined in this module.

    Currently this registers :func:`has_class` under the XPath function
    name ``has-class``.
    """
    set_xpathfunc('has-class', has_class)
35

36

37
def has_class(context, *classes):
    """has-class function.

    Return True if all ``classes`` are present in element's class attr.

    ``context`` is the evaluation context handed over by the XPath engine;
    this function reads ``context.eval_context`` (a dict-like scratch space)
    and ``context.context_node`` (the element being tested).

    Raises ``ValueError`` if no class name is given or if any of the
    arguments is not a string.

    """
    # Validate the arguments only once: the flag stored in ``eval_context``
    # short-circuits the check on subsequent calls that see the same dict.
    if not context.eval_context.get('args_checked'):
        if not classes:
            raise ValueError(
                'XPath error: has-class must have at least 1 argument')
        for c in classes:
            if not isinstance(c, string_types):
                raise ValueError(
                    'XPath error: has-class arguments must be strings')
        context.eval_context['args_checked'] = True

    node_cls = context.context_node.get('class')
    if node_cls is None:
        return False
    # Pad with spaces and collapse every run of HTML5 whitespace to a single
    # space, so each class name in the attribute is delimited by exactly one
    # space on both sides and can be found with a plain substring test.
    node_cls = ' ' + node_cls + ' '
    node_cls = replace_html5_whitespaces(' ', node_cls)
    return all(' ' + cls + ' ' in node_cls for cls in classes)

Read our documentation on viewing source code.

Loading