Fix optimization when looking for attributes with any name ('*').
[tx] / xpathfn.py
1 # -*- coding:utf-8 -*-
2
3 # XPath/XQuery functions
4 # Copyright (C) 2005  Frédéric Jolliton <frederic@jolliton.com>
5
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19
20 from __future__ import division
21
22 # __all__ = [
23 #       'functions'
24 #       ]
25
26 import operator
27
28 from error import Error, XPathError
29 from sequence import *
30 from sequence import _Empty, _EmptyString, _False, _True, _Boolean
31 from nodes import *
32 from misc import typeOf, iterSingle, identity
33 from iterators import iterDescendantOrSelfFull
34
35 functions = {}
36
37 def functionNames( qname ) :
38
39         '''
40         'foo'     -> ['foo']
41         'bar:foo' -> ['bar:foo']
42         'fn:foo'  -> ['fn:foo', 'foo']
43         '''
44
45         names = [ qname ]
46         parts = qname.split( ':' )
47         if len( parts ) == 1 :
48                 pass
49         elif len( parts ) == 2 :
50                 if parts[ 0 ] == 'fn' :
51                     names.append( parts[ 1 ] )
52         else :
53                 raise Error( 'invalid function name %r' % qname )
54         return names
55
56 def registerFast( qname , hold = False ) :
57
58         names = functionNames( qname )
59
60         def fun( f ) :
61                 f.hold = hold
62                 for name in names :
63                         functions[ name ] = f
64                 return f
65         return fun
66
67 def register( qname ) :
68
69         names = functionNames( qname )
70
71         def fun( f ) :
72                 #
73                 # Wrapper ensure than return value is a Sequence
74                 #
75                 def wrapper( *args , **kwargs ) :
76                         result = f( args , kwargs )
77                         if not isinstance( result , Sequence ) :
78                                 result = Sequence( result )
79                         return result
80                 wrapper.hold = False
81                 wrapper.__name__ = 'wrapper_%s' % ( f.__name__ or '<??>' )
82                 wrapper.__dict__ = f.__dict__
83                 wrapper.__doc__  = f.__doc__
84                 for name in names :
85                         functions[ name ] = wrapper
86                 return wrapper
87         return fun
88
89 #----------------------------------------------------------------------------
90
91 def isBoolean( item ) :
92
93         return isinstance( item , bool )
94
95 def isNumber( item ) :
96
97         return isinstance( item , ( int , float , long ) ) \
98                 and not isinstance( item , bool )
99
100 def isString( item ) :
101
102         return isinstance( item , basestring )
103
104 def isAtomic( item ) :
105
106         return not isinstance( item , Node )
107
108 def isAttribute( item ) :
109
110         return isinstance( item , Attribute )
111
112 def isDocument( item ) :
113
114         return isinstance( item , Document )
115
116 def isNode( item ) :
117
118         return isinstance( item , Node )
119
120 def isItem( item ) :
121
122         return True
123
124 #----------------------------------------------------------------------------
125
126 def ZeroOrOne( t ) :
127
128         def fun( sequence ) :
129                 if len( sequence ) == 0 :
130                         return None
131                 elif len( sequence ) == 1 :
132                         item = sequence[ 0 ]
133                         if t( item ) :
134                                 return item
135                 raise XPathError( 'XPTY0004' , 'Expected a sequence of 0 or 1 item (%r)' % t )
136         return fun
137
138 def One( t ) :
139
140         def fun( sequence ) :
141                 if len( sequence ) == 1 :
142                         item = sequence[ 0 ]
143                         if t( item ) :
144                                 return item
145                 raise XPathError( 'XPTY0004' , 'Expected a sequence of 1 item (%r)' % t )
146         return fun
147
148 def zeroOrMoreNodes( sequence ) :
149
150         if sequence.type not in ( SEQUENCE_EMPTY , SEQUENCE_NODES ) :
151                 raise XPathError( 'XPTY0004' , 'Expected a sequence of nodes' ) # really XPTY0004 here ?
152         return sequence
153
154 zeroOrOneNode  = ZeroOrOne( isNode )
155 zeroOrOneItem  = ZeroOrOne( isItem )
156 zeroOrMoreItem = identity
157 oneNumber      = One( isNumber )
158 oneAtomic      = One( isAtomic )
159 oneItem        = One( isItem )
160 oneNode        = One( isNode )
161
162 #----------------------------------------------------------------------------
163
164 def atomize( item ) :
165
166         '''Convert node to string, and keep inchanged atomic value.'''
167
168         if isNode( item ) :
169                 return item.dmStringValue()
170         else :
171                 return item
172
173 def sequenceToData( sequence ) :
174
175         return tuple( atomize( item ) for item in sequence )
176
177 def asString( item ) :
178
179         '''Convert item (node or atomic value) to a string.'''
180
181         result = atomize( item )
182         if not isinstance( result , basestring ) :
183                 result = str( result )
184         return result
185
186 def asNumber( item ) :
187
188         '''
189         Convert item (node or atomic value) as a number.
190         Return None if no conversion can be done.
191         '''
192
193         if isNumber( item ) :
194                 return item
195         else :
196                 item = atomize( item )
197                 try :
198                         return int( item )
199                 except :
200                         try :
201                                 return float( item )
202                         except :
203                                 pass
204
205 def oneAtomizedItem( sequence ) :
206
207         return atomize( oneItem( sequence ) )
208
209 #----------------------------------------------------------------------------
210
211 def sequenceToBoolean( sequence ) :
212
213         if not sequence :
214                 return False
215         else :
216                 item = sequence[ 0 ]
217                 if isNode( item ) :
218                         return True
219                 elif len( sequence ) == 1 :
220                         if isBoolean( item ) :
221                                 return item
222                         elif isString( item ) :
223                                 return bool( item )
224                         elif isNumber( item ) :
225                                 return item != 0
226         raise XError( 'FORG0006' )
227
228 #----------------------------------------------------------------------------
229
230 def _compareString( s1 , s2 ) :
231
232         '''
233         Compare strings 's1' and 's2'. This function also accept iterators
234         returning string. This is usefull to compare two nodes without
235         first converting them to strings for performance reason.
236         '''
237
238         if isinstance( s1 , basestring ) :
239                 if isinstance( s2 , basestring ) :
240                         return cmp( s1 , s2 )
241                 else :
242                         s1 = iterSingle( s1 )
243         elif isinstance( s2 , basestring ) :
244                 s2 = iterSingle( s2 )
245
246         result = 0
247         d1 = ''
248         d2 = ''
249         while 1 :
250                 if not d1 :
251                         try :
252                                 d1 = s1.next()
253                         except StopIteration :
254                                 d1 = ''
255                 if not d2 :
256                         try :
257                                 d2 = s2.next()
258                         except StopIteration :
259                                 d2 = ''
260                 if not d1 :
261                         if not d2 :
262                                 r = 0
263                         else :
264                                 r = -1
265                         break
266                 elif not d2 :
267                         r = 1
268                         break
269                 m = min( len( d1 ) , len( d2 ) )
270                 r = cmp( d1[ : m ] , d2[ : m ] )
271                 if r != 0 :
272                         break
273                 d1 = d1[ m : ]
274                 d2 = d2[ m : ]
275         return r
276
277 #--[ Extensions ]------------------------------------------------------------
278
279 @registerFast( 'ext:descendant-attribute' )
280 def extDescendantAttribute( context , arg ) :
281
282         '''
283         Find all attributes of context node and its descendant with name
284         'arg'.
285         '''
286
287         name = arg[ 0 ]
288         item = context.item
289         if isDocument( item ) :
290                 if name == '*' :
291                         attrs = sum( item.attributesByName.values() , [] )
292                         attrs.sort( lambda a , b : cmp( a.position , b.position ) )
293                         return Sequence( attrs )
294                 else :
295                         return Sequence( item.attributesByName.get( name , () ) )
296         else :
297                 if name == '*' :
298                         return Sequence( attribute
299                                                          for attribute in iterDescendantOrSelfFull( item )
300                                                          if isAttribute( attribute ) )
301                 else :
302                         return Sequence( attribute
303                                                          for attribute in iterDescendantOrSelfFull( item )
304                                                          if isAttribute( attribute ) and attribute.name == name )
305
306 #--[ XPath/XQuery functions ]------------------------------------------------
307
308 @registerFast( 'fn:id' )
309 def fnId( context , arg , node = None ) :
310
311         if node is None :
312                 item = context.item
313                 if item is None :
314                         raise XPathError( 'FONC0001' )
315         else :
316                 item = oneNode( node )
317         doc = item.root
318         if doc is None or not isDocument( doc ) :
319                 raise XPathError( 'FODC0001' ) # unsure
320         ids = map( asString , zeroOrMoreItem( arg ) )
321         result = []
322         for id in ids :
323                 n = doc.ids.get( id )
324                 if n is not None :
325                         result.append( n )
326         result.sort( lambda a , b : cmp( a.position , b.position ) )
327         return Sequence( result )
328
329 @registerFast( 'fn:normalize-space' )
330 def fnNormalizeSpace( context , arg = None ) :
331
332         if arg is None :
333                 item = context.item
334                 if item is None :
335                         raise XPathError( 'FONC0001' )
336                 item = asString( item )
337         else :
338                 item = zeroOrOneItem( arg )
339                 if item is None :
340                         return _EmptyString
341                 item = asString( item )
342         if not item :
343                 return _EmptyString
344         else :
345                 s = ' '.join( item.split() )
346                 if s[ 0 ] != item[ 0 ] :
347                         s = ' ' + s
348                 if s[ -1 ] != item[ -1 ] :
349                         s = s + ' '
350                 return Sequence( s )
351
352 @registerFast( 'fn:lower-case' )
353 def fnLowerCase( context , arg ) :
354
355         arg = zeroOrOneItem( arg )
356         if arg is None :
357                 return _EmptyString
358         else :
359                 return Sequence( asString( arg ).lower() )
360
361 @registerFast( 'fn:upper-case' )
362 def fnUpperCase( context , arg ) :
363
364         arg = zeroOrOneItem( arg )
365         if arg is None :
366                 return _EmptyString
367         else :
368                 return Sequence( asString( arg ).upper() )
369
370 @registerFast( 'fn:false' )
371 def fnFalse( context ) :
372
373         return _False
374
375 @registerFast( 'fn:true' )
376 def fnTrue( context ) :
377
378         return _True
379
380 @registerFast( 'fn:empty' )
381 def fnEmpty( context , sequence ) :
382
383         return _Boolean[ len( sequence ) == 0 ]
384
385 @registerFast( 'fn:exists' )
386 def fnExists( context , sequence ) :
387
388         return _Boolean[ len( sequence ) > 0 ]
389
390 @registerFast( 'fn:string' )
391 def fnString( context , arg = None ) :
392
393         if arg is None :
394                 item = context.item
395                 if item is None :
396                         raise XError( 'FONC0001' )
397         else :
398                 item = zeroOrOneItem( arg )
399                 if item is None :
400                         return _EmptyString
401         return Sequence( asString( item ) )
402
403 @registerFast( 'fn:distinct-values' )
404 def fnDistinctValues( context , arg ) :
405
406         if not arg :
407                 return arg
408         elif arg.type == SEQUENCE_ATOMICS :
409                 result = {}
410                 for item in arg :
411                         result[ item ] = True
412                 return Sequence( result.keys() )
413         else :
414                 # FIXME: Naive implementation O(n^2)
415                 result = []
416                 for item in arg :
417                         for existingItem in result :
418                                 if compareValue( existingItem , item ) == 0 :
419                                         break
420                         else :
421                                 result.append( item )
422                 return Sequence( result )
423
424 @registerFast( 'fn:not' )
425 def fnNot( context , arg ) :
426
427         return _Boolean[ not sequenceToBoolean( arg ) ]
428
429 @registerFast( 'fn:count' )
430 def fnCount( context , arg ) :
431
432         return Sequence( len( arg ) )
433
434 @registerFast( 'fn:root' )
435 def fnRoot( context , sequence = None ) :
436
437         if sequence is None :
438                 item = context.item
439                 if item is None :
440                         raise XError( 'FONC0001' )
441                 elif not isNode( item ) :
442                         raise XError( 'XPTY0006' )
443                 elif isDocument( item ) :
444                         return Sequence( item )
445                 else :
446                         return Sequence( item.root )
447         else :
448                 item = zeroOrOneNode( sequence )
449                 if item is None :
450                         return _Empty
451                 elif isDocument( item ) :
452                         return sequence
453                 else :
454                         return Sequence( item.root )
455
456 @registerFast( 'fn:position' )
457 def fnPosition( context ) :
458
459         return Sequence( context.position )
460
461 @registerFast( 'fn:last' )
462 def fnLast( context ) :
463
464         return Sequence( context.last )
465
466 @registerFast( 'fn:boolean' )
467 def fnBoolean( context , arg ) :
468
469         return Sequence( sequenceToBoolean( arg ) )
470
471 @registerFast( 'fn:trace' )
472 def fnTrace( context , value , label ) :
473
474         label = atomize( oneItem( label ) )
475         print '[TRACE] %s = %s' % ( label , value )
476         return value
477
478 @registerFast( 'fn:node-name' )
479 def fnNodeName( context , arg ) :
480
481         node = zeroOrOneNode( arg )
482         if node is None :
483                 return _Empty
484         else :
485                 return Sequence( node.dmNodeName() )
486
487 @registerFast( 'fn:name' )
488 def fnName( context , arg = None ) :
489
490         if arg is None :
491                 item = context.item
492                 if item is None :
493                         raise XError( 'FONC0001' )
494                 if not isNode( item ) :
495                         raise XError( 'XPTY0006' )
496         else :
497                 item = zeroOrOneNode( arg )
498                 if item is None :
499                         return _EmptyString
500         return Sequence( item.dmNodeName() )
501
502 #----------------------------------------------------------------------------
503
504 @registerFast( 'op:is-same-node' )
505 def opIsSameNode( context , arg1 , arg2 ) :
506
507         arg1 = oneNode( arg1 )
508         arg2 = oneNode( arg2 )
509         return _Boolean[ arg1 is arg2 ]
510
511 @registerFast( 'op:node-before' )
512 def opNodeBefore( context , arg1 , arg2 ) :
513
514         arg1 = oneNode( arg1 )
515         arg2 = oneNode( arg2 )
516         return _Boolean[ arg1.position < arg2.position ]
517
518 @registerFast( 'op:node-after' )
519 def opNodeAfter( context , arg1 , arg2 ) :
520
521         arg1 = oneNode( arg1 )
522         arg2 = oneNode( arg2 )
523         return _Boolean[ arg1.position > arg2.position ]
524
525 @registerFast( 'op:or' , hold = True )
526 def opOr( context , arg1 , arg2 ) :
527
528         return _Boolean[ sequenceToBoolean( arg1( context ) ) \
529                                          or sequenceToBoolean( arg2( context ) ) ]
530
531 @registerFast( 'op:and' , hold = True )
532 def opAnd( context , arg1 , arg2 ) :
533
534         return _Boolean[ sequenceToBoolean( arg1( context ) ) \
535                                          and sequenceToBoolean( arg2( context ) ) ]
536
537 @registerFast( 'op:to' )
538 def opTo( context , arg1 , arg2 ) :
539
540         arg1 = oneNumber( arg1 )
541         arg2 = oneNumber( arg2 )
542         return Sequence( range( int( arg1 ) , int( arg2 ) + 1 ) )
543
544 @registerFast( 'op:union' )
545 def opUnion( context , arg1 , arg2 ) :
546
547         if not arg1 :
548                 return arg2
549         if not arg2 :
550                 return arg1
551         if arg1.type != SEQUENCE_NODES \
552                 or arg2.type != SEQUENCE_NODES :
553                 raise XPathError( 'XPTY0004' , 'union operator expect sequence of nodes only' )
554         arg1 = zeroOrMoreNodes( arg1 )
555         arg2 = zeroOrMoreNodes( arg2 )
556         result = list( set( arg1 ) | set( arg2 ) )
557         result.sort( lambda a , b : cmp( a.position , b.position ) )
558         return Sequence( result )
559
560 @registerFast( 'op:except' )
561 def opExcept( context , arg1 , arg2 ) :
562
563         if not arg1 :
564                 return arg2
565         if not arg2 :
566                 return arg1
567         if arg1.type != SEQUENCE_NODES \
568                 or arg2.type != SEQUENCE_NODES :
569                 raise XPathError( 'XPTY0004' , 'except operator expect sequence of nodes only' )
570         arg1 = zeroOrMoreNodes( arg1 )
571         arg2 = zeroOrMoreNodes( arg2 )
572         result = list( set( arg1 ) - set( arg2 ) )
573         result.sort( lambda a , b : cmp( a.position , b.position ) )
574         return Sequence( result )
575
576 @registerFast( 'op:intersection' )
577 def opIntersection( context , arg1 , arg2 ) :
578
579         if not arg1 :
580                 return arg2
581         if not arg2 :
582                 return arg1
583         if arg1.type != SEQUENCE_NODES \
584                 or arg2.type != SEQUENCE_NODES :
585                 raise XPathError( 'XPTY0004' , 'intersection operator expect sequence of nodes only' )
586         arg1 = zeroOrMoreNodes( arg1 )
587         arg2 = zeroOrMoreNodes( arg2 )
588         result = list( set( arg1 ) & set( arg2 ) )
589         result.sort( lambda a , b : cmp( a.position , b.position ) )
590         return Sequence( result )
591
592 @registerFast( 'op:neg' )
593 def opNeg( context , arg ) :
594
595         arg = oneAtomic( arg )
596         if isNumber( arg ) :
597                 return Sequence( -arg )
598         else :
599                 arg = asNumber( arg )
600                 if arg is not None :
601                         return Sequence( -arg )
602                 else :
603                         return _Empty
604
605 @registerFast( 'op:add' )
606 def opAdd( context , arg1 , arg2 ) :
607
608         arg1 = asNumber( oneItem( arg1 ) )
609         arg2 = asNumber( oneItem( arg2 ) )
610         if arg1 is None or arg2 is None :
611                 return _Empty
612         else :
613                 return Sequence( arg1 + arg2 )
614
615 @registerFast( 'op:substract' )
616 def opSubstract( context , arg1 , arg2 ) :
617
618         arg1 = asNumber( oneItem( arg1 ) )
619         arg2 = asNumber( oneItem( arg2 ) )
620         if arg1 is None or arg2 is None :
621                 return _Empty
622         else :
623                 return Sequence( arg1 - arg2 )
624
625 @registerFast( 'op:multiply' )
626 def opMultiply( context , arg1 , arg2 ) :
627
628         arg1 = asNumber( oneItem( arg1 ) )
629         arg2 = asNumber( oneItem( arg2 ) )
630         if arg1 is None or arg2 is None :
631                 return _Empty
632         else :
633                 return Sequence( arg1 * arg2 )
634
635 @registerFast( 'op:divide' )
636 def opDivide( context , arg1 , arg2 ) :
637
638         arg1 = asNumber( oneItem( arg1 ) )
639         arg2 = asNumber( oneItem( arg2 ) )
640         if arg1 is None or arg2 is None :
641                 return _Empty
642         else :
643                 return Sequence( arg1 / float( arg2 ) )
644
645 @registerFast( 'op:integer-divide' )
646 def opIntegerDivide( context , arg1 , arg2 ) :
647
648         arg1 = asNumber( oneItem( arg1 ) )
649         arg2 = asNumber( oneItem( arg2 ) )
650         if arg1 is None or arg2 is None :
651                 return _Empty
652         else :
653                 return Sequence( arg1 // arg2 )
654
655 # Return -1, 0, 1 or None
656 def compareValue( a , b ) :
657
658         if isNode( a ) :
659                 if isNode( b ) :
660                         return _compareString( a.iterStringValue() , b.iterStringValue() )
661                 elif isString( b ) :
662                         return _compareString( a.iterStringValue() , b )
663         elif isString( a ) :
664                 if isNode( b ) :
665                         return _compareString( a , b.iterStringValue() )
666                 elif isString( b ) :
667                         return cmp( a , b )
668         elif isNumber( a ) :
669                 if isNumber( b ) :
670                         return cmp( a , b )
671
672 # eq, ne, lt, le, gt, ge
673 def makeValueComparator( comparator ) :
674
675         def fun( context , arg1 , arg2 ) :
676                 if not arg1 or not arg2 :
677                         return _Empty
678                 elif len( arg1 ) == 1 and len( arg2 ) == 1 :
679                         item1 = arg1[ 0 ]
680                         item2 = arg2[ 0 ]
681                         r = compareValue( item1 , item2 )
682                         if r is not None :
683                                 return _Boolean[ comparator( r , 0 ) ]
684                         else :
685                                 raise XPathError( 'XPTY0004' , 'cannot compare %s and %s' % ( typeOf( item1 ) , typeOf( item2 ) ) )
686                 else :
687                         raise XPathError( 'XPTY0004' , 'Value comparison operators expect arguments of 0 or 1 items' )
688         return fun
689
690 # !=, =, <, <=, >, >=
691 def makeGeneralComparator( comparator ) :
692
693         def fun( context , arg1 , arg2 ) :
694                 if arg1 and arg2 :
695                         for item1 in arg1 :
696                                 for item2 in arg2 :
697                                         r = False
698                                         if isNumber( item1 ) :
699                                                 if isNumber( item2 ) :
700                                                         r = comparator( item1 , item2 )
701                                                 else :
702                                                         item2 = asNumber( item2 )
703                                                         if item2 is None : # NaN
704                                                                 r = False
705                                                         else :
706                                                                 r = comparator( item1 , item2 )
707                                         elif isNode( item1 ) :
708                                                 if isNumber( item2 ) :
709                                                         item1 = asNumber( item1 )
710                                                         if item1 is None :
711                                                                 r = False
712                                                         else :
713                                                                 r = comparator( item1 , item2 )
714                                                 elif isNode( item2 ) :
715                                                         r = comparator( _compareString( item1.iterStringValue() , item2.iterStringValue() ) , 0 )
716                                                 elif isString( item2 ) :
717                                                         r = comparator( _compareString( item1.iterStringValue() , item2 ) , 0 )
718                                                 else :
719                                                         raise XPathError( 'XPTY0004' , 'cannot compare %s and %s' % ( typeOf( item1 ) , typeOf( item2 ) ) )
720                                         elif isString( item1 ) :
721                                                 if isString( item2 ) :
722                                                         r = comparator( item1 , item2 )
723                                                 elif isNode( item2 ) :
724                                                         r = comparator( _compareString( item1 , item2.iterStringValue() ) , 0 )
725                                                 else :
726                                                         raise XPathError( 'XPTY0004' , 'cannot compare %s and %s' % ( typeOf( item1 ) , typeOf( item2 ) ) )
727                                         else :
728                                                 raise XPathError( 'XPTY0004' , 'cannot compare %s and %s' % ( typeOf( item1 ) , typeOf( item2 ) ) )
729                                         if r :
730                                                 return _True
731                 return _False
732         return fun
733
734 #
735 # Not really XPath operators, but rather dispatcher to the right
736 # operator.
737 #
738 opValueEqual       = makeValueComparator( operator.eq )
739 opValueNotEqual    = makeValueComparator( operator.ne )
740 opValueLessThan    = makeValueComparator( operator.lt )
741 opValueGreaterThan = makeValueComparator( operator.gt )
742
743 opGeneralEqual       = makeGeneralComparator( operator.eq )
744 opGeneralNotEqual    = makeGeneralComparator( operator.ne )
745 opGeneralLessThan    = makeGeneralComparator( operator.lt )
746 opGeneralGreaterThan = makeGeneralComparator( operator.gt )
747
748 # Local Variables:
749 # tab-width: 4
750 # python-indent: 4
751 # End: