Fixed / overloading when enabling true division.
[tx] / xpath.py
1 # -*- coding:utf-8 -*-
2
3 # XPath main module
4 # Copyright (C) 2005  Frédéric Jolliton <frederic@jolliton.com>
5
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19
20 __all__ = [
21         'compile' ,
22         'XPath'
23         ]
24
25 #
26 # TODO:
27 #
28 # [ ] Use iterator instead of sequence in some place to optimize speed
29 #     and memory usage ?
30 # [ ] Identify more place for optimization
31 #
32
33 g_debug = False
34 g_dontOptimize = False
35
36 from error import Error, XPathError
37 from sequence import _Empty, _False, _True
38 from context import Context, nullContext
39 from nodes import Node, Document, Element, Attribute, Comment, Text
40 from iterators import *
41 from misc import NotReached, plural
42 import xpathparser
43 from xpath_misc import lispy
44
45 from xpathfn import *
46
47 # [TR/xquery-operators]
48
49 from xpathfn import functions as xpathFunctions
50
51 # [xpath20] 3.2.1.1 Axes
52
53 axes = {
54           'child'              : iterChild
55         , 'child-or-top'       : iterChildOrTop
56         , 'descendant'         : iterDescendant
57         , 'attribute'          : iterAttribute
58         , 'attribute-or-top'   : iterAttributeOrTop
59         , 'self'               : iterSelf
60         , 'descendant-or-self' : iterDescendantOrSelf
61         , 'following-sibling'  : iterFollowingSibling
62         , 'following'          : iterFollowing
63         , 'namespace'          : None
64         , 'parent'             : iterParent
65         , 'ancestor'           : iterAncestor
66         , 'preceding-sibling'  : iterPrecedingSibling
67         , 'preceding'          : iterPreceding
68         , 'ancestor-or-self'   : iterAncestorOrSelf
69 }
70
71 unaryOperators = {
72           'minus' : opMinus ,
73           'plus'  : opPlus
74 }
75
76 binaryOperators = {
77           'eq'        : opValueEqual
78         , 'ne'        : opValueNotEqual
79         , 'lt'        : opValueLessThan
80         , 'gt'        : opValueGreaterThan
81         , 'le'        : opValueLessOrEqual
82         , 'ge'        : opValueGreaterOrEqual
83         , '='         : opGeneralEqual
84         , '!='        : opGeneralNotEqual
85         , '<'         : opGeneralLessThan
86         , '>'         : opGeneralGreaterThan
87         , '<='        : opGeneralLessOrEqual
88         , '>='        : opGeneralGreaterOrEqual
89         , 'is'        : opIsSameNode
90         , '<<'        : opNodeBefore
91         , '>>'        : opNodeAfter
92         , 'and'       : opAnd
93         , '|'         : opUnion
94         , 'union'     : opUnion
95         , 'except'    : opExcept
96         , 'intersect' : opIntersection
97         , '+'         : opAdd
98         , '-'         : opSubstract
99         , '*'         : opMultiply
100         , 'div'       : opDivide
101         , 'idiv'      : opIntegerDivide
102         , 'range'     : opTo
103 }
104
105 def functionArity( f ) :
106
107         n = f.func_code.co_argcount
108         min = n - len( f.func_defaults or () )
109         max = n
110         if f.func_code.co_flags & 4 :
111                 max = None
112         return min , max
113
114 def assertFunctionArity( fun , label , n , diff = 0 ) :
115
116         range = functionArity( fun )
117         error = None
118         if range[ 0 ] == range[ 1 ] :
119                 expected = range[ 0 ] - diff
120                 if n != expected :
121                         if expected == 0 :
122                                 error = 'no' , 0
123                         else :
124                                 error = 'exactly %d' % expected , expected
125         else :
126                 if n + diff < range[ 0 ] :
127                         error = 'at least %d' % ( range[ 0 ] - diff ) , range[ 0 ] - diff
128                 elif range[ 1 ] is not None and n + diff > range[ 1 ] :
129                         error = 'at most %d' % ( range[ 1 ] - diff ) , range[ 1 ] - diff
130         if error is not None :
131                 raise XPathError( 'XPST0017' ,
132                                                   '%s() takes %s argument%s (%d given)' \
133                                                   % ( label ,
134                                                           error[ 0 ] ,
135                                                           plural( error[ 1 ] , 's' ) , n ) )
136
137 #--[ Functions used during evaluation ]--------------------------------------
138 # ...
139
140 def contextItem( context ) :
141
142         if context.item is None :
143                 raise XPathError( 'XPDY0002' , 'undefined context item' )
144         return Sequence( context.item )
145
146 def unionOfSequences( sequences ) :
147
148         '''
149         Merge several sequence of nodes
150         '''
151
152         unionType = SEQUENCE_EMPTY
153         for sequence in sequences :
154                 unionType |= sequence.type
155                 if unionType == SEQUENCE_MIXED :
156                         raise XPathError( 'XPTY0018' , 'expected either nodes or atomic items but not both' )
157
158         # FIXME: Don't work if nodes come from several documents
159         if unionType == SEQUENCE_EMPTY :
160                 return Sequence()
161         elif len( sequences ) == 1 :
162                 return sequences[ 0 ]
163         elif unionType == SEQUENCE_NODES :
164                 nodes = {}
165                 for sequence in sequences :
166                         for node in sequence :
167                                 nodes[ node.position ] = node
168                 nodes = nodes.items()
169                 nodes.sort()
170                 return Sequence( node[ 1 ] for node in nodes )
171         else :
172                 return Sequence( sequences )
173
174 #----------------------------------------------------------------------------
175
176 def ensureString( s ) :
177
178         if isinstance( s , str ) :
179                 return s
180         elif isinstance( s , unicode ) :
181                 return s.encode( 'ascii' , 'ignore' )
182         else :
183                 return str( s )
184
185 def ensureCallable( e ) :
186
187         if not callable( e ) :
188                 def fun( context ) :
189                         return e
190                 fun.__name__ = '_static_'
191         else :
192                 fun = e
193         return fun
194
195 #--[ Functions producing functions ]-----------------------------------------
196
197 #
198 # Return (node -> bool)
199 #
200
201 def is_node() :
202
203         def fun( node ) :
204                 return isinstance( node , Node )
205         fun.__name__ = '_is_node_'
206
207         return fun
208
209 def is_document() :
210
211         def fun( node ) :
212                 return isinstance( node , Document )
213         fun.__name__ = '_is_document_'
214
215         return fun
216
217 def is_element( name = None ) :
218
219         if name is None or name == '*' :
220                 def fun( node ) :
221                         return isinstance( node , Element )
222         else :
223                 def fun( node ) :
224                         return isinstance( node , Element ) and node.name == name
225         fun.__name__ = '_is_element(..)_'
226
227         return fun
228
229 def is_attribute( name = None ) :
230
231         if name is None or name == '*' :
232                 def fun( node ) :
233                         return isinstance( node , Attribute )
234         else :
235                 def fun( node ) :
236                         return isinstance( node , Attribute ) and node.name == name
237         fun.__name__ = '_attribute(..)_'
238
239         return fun
240
241 def is_comment() :
242
243         def fun( node ) :
244                 return isinstance( node , Comment )
245         fun.__name__ = '_is_comment_'
246
247         return fun
248
249 def is_text() :
250
251         def fun( node ) :
252                 return isinstance( node , Text )
253         fun.__name__ = '_is_text_'
254
255         return fun
256
257 def makeTest( t ) :
258
259         '''
260         Return (node -> bool)
261         '''
262
263         if t[ 0 ] == 'element' :
264                 test = is_element( *t[ 1 : ] )
265         elif t[ 0 ] == 'attribute' :
266                 test = is_attribute( *t[ 1 : ] )
267         elif t[ 0 ] == 'node' :
268                 test = is_node()
269         elif t[ 0 ] == 'document' :
270                 test = is_document()
271         elif t[ 0 ] == 'text' :
272                 test = is_text()
273         elif t[ 0 ] == 'comment' :
274                 test = is_comment()
275         else :
276                 raise NotReached
277         return test
278
279 #--[ Sequence filter ]-------------------------------------------------------
280
281 def at_position( n ) :
282
283         '''
284         Return (context x sequence -> sequence)
285         '''
286
287         def fun( context , sequence ) :
288                 if 1 <= n <= len( sequence ) :
289                         return Sequence( sequence[ n - 1 ] )
290                 else :
291                         return Sequence()
292         fun.__name__ = '_at_position_'
293
294         return fun      
295
296 #----------------------------------------------------------------------------
297
298 def doStep( producer , predicates = () ) :
299
300         '''
301         A step is for example: descendant-or-self::element(foo)[@bar>10][2][quux and ../baz]
302         where the 'descendant-or-self::element(foo)' part is performed by function 'producer',
303         and     each predicates ([@bar>10],..) are performed by functions listed in 'predicates'.
304
305         producer: (context -> sequence)
306         predicates: list-of (context x sequence -> sequence)
307         Return (context x sequence -> sequence)
308         '''
309
310         def fun( context , sequence ) :
311
312                 if sequence.type not in ( SEQUENCE_EMPTY , SEQUENCE_NODES ) :
313                         raise XPathError( 'XPTY0020' , 'XPath step should start with a sequence of nodes' )
314
315                 results = []
316                 focus = context.getFocus()
317                 for i , item in enumerate( sequence ) :
318                         #
319                         # Update context
320                         #
321                         context.last = len( sequence )
322                         context.item = item
323                         context.position = i + 1
324                         #
325                         # Evaluate the step
326                         #
327                         result = producer( context )
328                         #
329                         # Apply predicates
330                         #
331                         for predicate in predicates :
332                                 result = predicate( context , result )
333                         #
334                         # Check consistency
335                         #
336                         if sequence.type not in ( SEQUENCE_EMPTY , SEQUENCE_NODES ) :
337                                 raise XPathError( 'XPTY0019' , 'XPath step should return a sequence of nodes.' )
338                         results.append( result )
339                 context.restoreFocus( focus )
340
341                 return unionOfSequences( results )
342
343         fun.__name__ = '_do_step_'
344
345         return fun
346
347 def makePredicate( predicate ) :
348
349         '''
350         predicates: sequence | (context -> sequence)
351         Return list-of (context x sequence -> sequence)
352         '''
353
354         if isinstance( predicate , Sequence ) :
355                 if len( predicate ) == 1 :
356                         try :
357                                 index = predicate[ 0 ]
358                         except :
359                                 raise Error( 'Expected an index as static predicate.' )
360                         return at_position( int( predicate[ 0 ] ) )
361                 else :
362                         raise Error( 'Unexpected static sequence' )
363         else :
364                 def fun( context , sequence ) :
365
366                         focus = context.getFocus()
367                         result = []
368                         for i , item in enumerate( sequence ) :
369                                 i += 1
370                                 #
371                                 # Update context
372                                 #
373                                 context.item     = item
374                                 context.position = i
375                                 context.last     = len( sequence )
376                                 #
377                                 # Do the test
378                                 #
379                                 r = predicate( context )
380                                 keep = False
381                                 if len( r ) == 1 and isNumber( r[ 0 ] ) :
382                                         n = int( r[ 0 ] )
383                                         if n == i :
384                                                 keep = True
385                                 elif sequenceToBoolean( r ) :
386                                         keep = True
387                                 if keep :
388                                         result.append( item )
389                         context.restoreFocus( focus )
390                         return Sequence( result )
391
392                 fun.__name__ = '_predicate_'
393
394                 return fun
395
396 def makeVariableReference( name ) :
397
398         '''
399         Return (context -> sequence)
400         '''
401         def fun( context ) :
402                 return Sequence( context.getVariable( name ) )
403         fun.__name__ = '_var_%s_' % ensureString( name )
404
405         return fun
406
407 def makeAxis( it , test ) :
408
409         if it is None :
410                 raise XPathError( 'XPST0010' , 'Unsupported axis' )
411         def fun( context ) :
412                 return Sequence( node for node in it( context.item ) if test( node ) )
413         fun.__name__ = '_iterate_'
414         return fun
415
416 def makeFunctionCall( expr ) :
417
418         '''
419         Return (context -> sequence)
420         '''
421
422         functionName , args = expr[ 0 ] , expr[ 1 : ]
423         fargs = map( makeExpr , args )
424         fargs = map( ensureCallable , fargs )
425         if functionName in xpathFunctions :
426                 #
427                 # Known function.
428                 #
429                 fn = xpathFunctions[ functionName ]
430
431                 assertFunctionArity( fn , functionName , len( fargs ) , 1 )
432
433                 def fun( context ) :
434                         args = map( lambda f : f( context ) , fargs )
435                         return fn( context , *args )
436                 fun.__name__ = '_fn_%s_' % ensureString( functionName )
437         else :
438                 #
439                 # Unknown function. Resolution at run-time
440                 #
441                 def fun( context ) :
442                         fn = context.getFunction( functionName )
443                         if fn is None :
444                                 raise XPathError( 'XPST0017' , 'Undefined function %r' % ( functionName , ) )
445                         assertFunctionArity( fn , functionName , len( fargs ) , 1 )
446                         args = map( lambda f : f( context ) , fargs )
447                         return fn( context , *args )
448                 fun.__name__ = '_fn[rt]_%s_' % ensureString( functionName )
449         return fun
450
451 def makeFilter( exprlist , predicates ) :
452
453         if exprlist[ 0 ] == 'exprlist' :
454                 process = makeExprList
455         else :
456                 process = makeStep
457         exprlist = process( exprlist )
458         predicates = map( makeExprList , predicates )
459         predicates = map( makePredicate , predicates )
460         if not callable( exprlist ) :
461                 r = exprlist
462                 for predicate in predicates :
463                         r = predicate( nullContext , r )
464                 return r
465         else :
466                 def fun( context ) :
467                         sequence = exprlist( context )
468                         for predicate in predicates :
469                                 sequence = predicate( context , sequence )
470                         return sequence
471                 fun.__name__ = '_predicate_'
472                 return fun
473
474 def makeStep( step ) :
475
476         first , rest = step[ 0 ] , step[ 1 : ]
477         if first == '.' :
478                 return contextItem
479         elif first == '/' :
480                 return fnRoot
481         elif first in axes :
482                 return makeAxis( axes[ first ] , makeTest( rest[ 0 ] ) )
483         elif first == 'integer' :
484                 return Sequence( int( rest[ 0 ] ) )
485         elif first == 'decimal' :
486                 return Sequence( float( rest[ 0 ] ) )
487         elif first == 'double' :
488                 return Sequence( float( rest[ 0 ] ) )
489         elif first == 'string' :
490                 return Sequence( rest[ 0 ] )
491         elif first in ( 'filter' , 'predicates' ) :
492                 return makeFilter( rest[ 0 ] , rest[ 1 : ] )
493         elif first == 'void' :
494                 return lambda context : _Empty
495         elif first == 'call' :
496                 return makeFunctionCall( rest )
497         elif first == 'varref' :
498                 return makeVariableReference( rest[ 0 ] )
499         elif first == 'exprlist' :
500                 return makeExprList( step )
501         else :
502                 raise Error( 'Unexpected step %r' % ( step , ) )
503
504
505 # bindings = {}
506 # match( bindings , ('+', 1, ('+', 2, 3) ) , ('+', 'e1' , ('+', 'e2', 'e3')))
507 # bindings => {1: 'e1', 2: 'e2', 3: 'e3'}
508 def match( bindings , pattern , expression ) :
509
510         def match_( pattern , expression ) :
511                 if len( pattern ) == len( expression ) :
512                         for a , b in zip( pattern , expression ) :
513                                 if isinstance( a , ( int , long ) ) :
514                                         if a in bindings :
515                                                 if bindings[ a ] != b :
516                                                         return False
517                                         else :
518                                                 bindings[ a ] = b
519                                 elif isinstance( a , tuple ) :
520                                         if not isinstance( b , tuple ) :
521                                                 return False
522                                         if not match_( a , b ) :
523                                                 return False
524                                 elif a == b :
525                                         pass
526                                 else :
527                                         return False
528                 else :
529                         return False
530                 return True
531         return match_( pattern , expression )
532
533 def replace( bindings , pattern ) :
534
535         if isinstance( pattern , int ) :
536                 return bindings.get( pattern )
537         elif isinstance( pattern , tuple ) :
538                 return tuple( replace( bindings , item ) for item in pattern )
539         else :
540                 return pattern
541
542 def rewrite( expression , sourcePattern , targetPattern , recurse = False ) :
543
544         bindings = {}
545         result = None
546         if match( bindings , sourcePattern , expression ) :
547                 result = replace( bindings , targetPattern )
548         return result
549
550 def Rule( a , b ) : return ( a , b )
551 def From( *a ) : return a
552 def To( *a ) : return a
553
554 rules = [
555         # <STEP>/() -> ()
556         Rule( From( 1 , ( 'void' , ) ) ,
557                   To( 'void'  ) ) ,
558
559         # descendant-or-self::node()/child::<TEST> -> descendant::<TEST>
560         Rule( From( ( 'descendant-or-self' , ( 'node' , ) ) ,
561                                 ( 'child' , 1 ) ) ,
562                   To( 'descendant' , 1 ) ) ,
563
564 #   WRONG ! Don't work for //p[1]
565 #       # descendant-or-self::node()/child::<TEST>[<PRED>] -> descendant::<TEST>[<PRED>]
566 #       Rule( From( ( 'descendant-or-self' , ( 'node' , ) ) ,
567 #                               ( 'predicates' , ( 'child' , 1 ) , 2 ) ) ,
568 #                 To( 'predicates' , ( 'descendant' , 1 ) , 2 ) ) ,
569
570         # child::<TEST>/descendant::<TEST> -> descendant::<TEST>
571         Rule( From( ( 'child' , 1 ) ,
572                                 ( 'descendant' , 1 ) ) ,
573                   To( 'descendant' , 1 ) ) ,
574
575 #   PROBABLY INCORRECT
576 #       # child::<TEST>/descendant::<TEST>[<PRED>] -> descendant::<TEST>[<PRED>]
577 #       Rule( From( ( 'child' , 1 ) ,
578 #                               ( 'predicates' , ( 'descendant' , 1 ) , 2 ) ) ,
579 #                 To( 'predicates' , ( 'descendant' , 1 ) , 2 ) ) ,
580
581         # descendant-or-self::node()/attribute::<NAME> -> ext:descendant-attribute(<NAME>)
582         Rule( From( ( 'descendant-or-self' , ( 'node' , ) ) ,
583                                 ( 'attribute' , ( 'attribute' , 1 ) ) ) ,
584                   To( 'call' , 'ext:descendant-attribute' , ( 'path' , ( 'string' , 1 ) ) ) ) ,
585
586         # descendant-or-self::node()/attribute::<NAME>[<PRED>]
587         # -> ext:descendant-attribute(<NAME>)/../attribute::attribute(<NAME>)[PRED]
588         Rule( From( ( 'descendant-or-self' , ( 'node' , ) ) ,
589                                 ( 'predicates' , ( 'attribute' , ( 'attribute' , 1 ) ) , 2 ) ) ,
590                   To( ( 'call' , 'ext:descendant-attribute' ,
591                                 ( 'path' , ( 'string' , 1 ) ) ) ,
592                           ( 'parent' , ( 'element' , '*' ) ) ,
593                           ( 'predicates' , ( 'attribute' , ( 'attribute' , 1 ) ) ,
594                                 2 ) ) ) ,
595
596         # descendant-or-self::node()/child::<NAME>[attribute::<ATT>]
597         #   -> ext:descendant-attribute(<ATT>)/parent::<NAME>
598         Rule( From( ( 'descendant-or-self' , ( 'node' , ) ) ,
599                                 ( 'predicates' , ( 'child' , 1 ) ,
600                                   ( 'exprlist' , ( 'path' , ( 'attribute' , ( 'attribute' , 2 ) ) ) ) ) ) ,
601                   To( ( 'call' , 'ext:descendant-attribute' , ( 'path' , ( 'string' , 2 ) ) ) ,
602                           ( 'parent' , 1 ) ) )
603 ]
604
605 def optimizeSteps( steps ) :
606
607         if g_dontOptimize :
608                 return steps
609         current = 0
610         if g_debug :
611                 print 'OPTIMIZE:'
612                 print lispy( ( 'path' , ) + steps )
613         while current < len( steps ) - 1 :
614                 a = steps[ current ]
615                 b = steps[ current + 1 ]
616                 replace = None
617                 for rule in rules :
618                         r = rewrite( ( a , b ) , *rule )
619                         if r is not None :
620                                 replace = r
621                                 break
622                 if replace is not None :
623                         if g_debug :
624                                 print 'OPT' , a
625                                 print '  +' , b
626                                 print ' ->' , replace
627                         if isinstance( replace , tuple ) and replace and isinstance( replace[ 0 ] , tuple ) :
628                                 pass
629                         else :
630                                 replace = ( replace , )
631                         steps = steps[ : current ] + replace + steps[ current + 2 : ]
632                         if current > 0 :
633                                 current -= 1 # look again at previous pair
634                 else :
635                         current += 1
636         if g_debug :
637                 print 'RESULT:'
638                 print lispy( ( 'path' , ) + steps )
639         return steps
640
641 def makePath( stepExpressions ) :
642
643         '''
644         Return (context -> sequence)
645         '''
646
647         stepExpressions = optimizeSteps( stepExpressions )
648         steps = map( makeStep , stepExpressions )
649         first , rest = steps[ 0 ] , steps[ 1 : ]
650
651         if not callable( first ) :
652                 if rest :
653                         raise XPathError( 'XPST0003' , 'atomic values cannot be used to start a path' )
654                 return first
655
656         first = ensureCallable( first )
657         rest = map( ensureCallable , rest )
658
659         rest = map( doStep , rest )
660
661         def fun( context ) :
662                 sequence = first( context )
663                 for step_ in rest :
664                         sequence = step_( context , sequence )
665                 return sequence
666         fun.__name__ = '_path_'
667
668         return fun
669
670 def makeIf( predicate , consequent , alternative ) :
671
672         '''
673         Return sequence | (context -> sequence)
674         '''
675
676         predicate = makeExprList( predicate )
677         consequent = makeExpr( consequent )
678         alternative = makeExpr( alternative )
679
680         if not callable( predicate ) :
681                 if sequenceToBoolean( predicate ) :
682                         return consequent
683                 else :
684                         return alternative
685         else :
686                 consequent = ensureCallable( consequent )
687                 alternative = ensureCallable( alternative )
688
689                 def fun( context ) :
690                         if sequenceToBoolean( predicate( context ) ) :
691                                 return consequent( context )
692                         else :
693                                 return alternative( context )
694                 fun.__name__ = '_if_'
695                 return fun
696
697 def makeUnary( opName , arg ) :
698
699         '''
700         Return sequence | (context -> sequence)
701         '''
702
703         op = unaryOperators[ opName ]
704
705         arg = makeExpr( arg )
706         if not callable( arg ) :
707                 return op( nullContext , arg )
708         else :
709                 arg = ensureCallable( arg )
710                 def fun( context ) :
711                         return op( context , arg( context ) )
712                 fun.__name__ = '_op1_%s_' % ensureString( opName )
713                 return fun
714
715 def makeBinary( opName , arg1 , arg2 ) :
716
717         '''
718         Return sequence | (context -> sequence)
719         '''
720
721         op = binaryOperators[ opName ]
722
723         arg1 = makeExpr( arg1 )
724         arg2 = makeExpr( arg2 )
725         if not callable( arg1 ) and not callable( arg2 ) :
726                 return op( nullContext , arg1 , arg2 )
727         else :
728                 arg1 = ensureCallable( arg1 )
729                 arg2 = ensureCallable( arg2 )
730                 if getattr( op , 'hold' , False ) :
731                         def fun( context ) :
732                                 return op( context , arg1 , arg2 )
733                 else :
734                         def fun( context ) :
735                                 return op( context , arg1( context ) , arg2( context ) )
736                 fun.__name__ = '_op2_%s_' % ensureString( opName )
737                 return fun
738
739 def makeFor( clauses , returnExpr ) :
740
741         '''
742         Return (context -> sequence)
743         '''
744
745         vars = []
746         for var , seqExpr in clauses :
747                 seqExpr = makeExpr( seqExpr )
748                 vars.append( ( var , seqExpr ) )
749         returnExpr = makeExpr( returnExpr )
750
751         def makeFun( name , seq , returnExpr ) :
752                 seq = ensureCallable( seq )
753                 def fun( context ) :
754                         result = []
755                         for item in seq( context ) :
756                                 context.variables[ name ] = item
757                                 result.append( returnExpr( context ) )
758                         return Sequence( result )
759                 return fun
760
761         fun = returnExpr
762         for name , seq in reversed( vars ) :
763                 fun = makeFun( name , seq , fun )
764         return fun
765
766 def makeQuantified( quantifier , clauses , test ) :
767
768         '''
769         Return (context -> sequence)
770         '''
771
772         vars = []
773         for var , seqExpr in clauses :
774                 seqExpr = makeExpr( seqExpr )
775                 vars.append( ( var , seqExpr ) )
776         testExpr = makeExpr( test )
777
778         def makeFun( name , seq , testExpr ) :
779                 seq = ensureCallable( seq )
780                 if quantifier == 'some' :
781                         def fun( context ) :
782                                 result = []
783                                 for item in seq( context ) :
784                                         context.variables[ name ] = item
785                                         r = sequenceToBoolean( testExpr( context ) )
786                                         if r is True :
787                                                 return _True
788                                 return _False
789                         return fun
790                 elif quantifier == 'every' :
791                         def fun( context ) :
792                                 result = []
793                                 for item in seq( context ) :
794                                         context.variables[ name ] = item
795                                         r = sequenceToBoolean( testExpr( context ) )
796                                         if r is False :
797                                                 return _False
798                                 return _True
799                         return fun
800                 else :
801                         raise NotReached
802
803         fun = testExpr
804         for name , seq in reversed( vars ) :
805                 fun = makeFun( name , seq , fun )
806         return fun
807
808 def makeExpr( e ) :
809
810         '''
811         Return sequence | (context -> sequence)
812         '''
813
814         if e[ 0 ] == 'path' :
815                 return makePath( e[ 1 : ] )
816         elif e[ 0 ] == 'if' :
817                 return makeIf( e[ 1 ] , e[ 2 ] , e[ 3 ] )
818         elif e[ 0 ] == 'for' :
819                 return makeFor( e[ 1 ] , e[ 2 ] )
820         elif e[ 0 ] in ( 'some' , 'every' ) :
821                 return makeQuantified( e[ 0 ] , e[ 1 ] , e[ 2 ] )
822         elif e[ 0 ] in unaryOperators :
823                 return makeUnary( e[ 0 ] , e[ 1 ] )
824         elif e[ 0 ] in binaryOperators :
825                 return makeBinary( e[ 0 ] , e[ 1 ] , e[ 2 ] )
826         else :
827                 raise Error( 'Unexpected expression %r' % ( e , ) )
828
829 def makeExprList( e ) :
830
831         '''Return (context -> sequence) or a static sequence.
832
833         The returned function evaluate a static XPath expression
834         and return sequence as result.'''
835
836         assert e[ 0 ] == 'exprlist' , 'Unexpected expression %r' % ( e , )
837
838         exprs = []
839         dynamic = False
840         #
841         # Compile each expression
842         #
843         for expr in e[ 1 : ] :
844                 r = makeExpr( expr )
845                 if callable( r ) :
846                         dynamic = True
847                         exprs.append( r )
848                 elif isinstance( r , tuple ) :
849                         exprs += list( r )
850                 else :
851                         exprs.append( r )
852
853         if not dynamic :
854                 #
855                 # If fully static, then just return concatenation of all
856                 # the sequences.
857                 #
858                 return Sequence( exprs )
859         else :
860                 def fun( context ) :
861                         result = []
862                         for item in exprs :
863                                 if callable( item ) :
864                                         # Dynamic
865                                         result.append( item( context ) )
866                                 else :
867                                         # Constant
868                                         result.append( item )
869                         return Sequence( result )
870                 fun.__name__ = '_exprlist_'
871                 return fun
872
873 def compile( s ) :
874
875         '''Compile XPath expression 's' into a Python function.'''
876
877         e = xpathparser.parser( s )
878
879         return makeExprList( e )
880
881 _cache = {}
882
883 def flushCache() :
884
885         _cache.clear()
886
887 class XPath( object ) :
888
889         __slots__ = [ 'path' , 'fun' ]
890
891         def __init__( self , path ) :
892
893                 self.path = path
894                 if path in _cache :
895                         self.fun = _cache[ path ]
896                 else :
897                         try :
898                                 self.fun = compile( path )
899                         except xpathparser.NoMatch :
900                                 raise XPathError( 'XPST0003' , 'Unable to parse %r' % path )
901                         _cache[ path ] = self.fun
902                         if len( _cache ) > 50 :
903                                 _cache.clear() # FIXME: OUCH!
904
905         def eval( self , doc = None , variables = {} , functions = {} ) :
906
907                 f = self.fun
908                 if callable( f ) :
909                         context = Context()
910                         if doc is not None :
911                                 context.item = doc
912                                 context.position = 1
913                                 context.last = 1
914                         context.variables.update( variables )
915                         context.functions.update( functions )
916                         return f( context )
917                 else :
918                         return f
919
920 # Local Variables:
921 # tab-width: 4
922 # python-indent: 4
923 # End: