Fixed several functions (see details.)
authorFrederic Jolliton <frederic@jolliton.com>
Mon, 12 Sep 2005 08:40:38 +0000 (08:40 +0000)
committerFrederic Jolliton <frederic@jolliton.com>
Mon, 12 Sep 2005 08:40:38 +0000 (08:40 +0000)
 * Updated registerFast to check return value of functions in debug mode
   (to test if they are correctly returning a Sequence.)

 * Fixed register which was creating a wrapper that was not forwarding
   parameters correctly.

 * Updated op:union (small optimization.)

 * Fixed op:intersect and op:except. They were returning wrong result
   in some cases.

 * Updated various functions to correctly return a Sequence instead of an
   item.

 * Removed some call to zeroOrMoreItem, since this was a no-op.
git-archimport-id: frederic@jolliton.com--2005-main/tx--main--0.1--patch-31

xpathfn.py

index 0d739c0..a88a96d 100644 (file)
@@ -41,6 +41,8 @@ from iterators import iterDescendantOrSelfFull
 
 functions = {}
 
+_debug = True
+
 _NAN    = float( 'NaN' )
 _POSINF = float( 'inf' )
 _NEGINF = float( '-inf' )
@@ -74,9 +76,22 @@ def registerFast( qname , hold = False ) :
 
        def fun( f ) :
                f.hold = hold
+               if _debug :
+                       def wrapper( *args , **kwargs ) :
+                               result = f( *args , **kwargs )
+                               if not isinstance( result , Sequence ) :
+                                       print 'WARNING: Function %r doesn\'t return a Sequence !' % ( qname , )
+                                       result = Sequence( result )
+                               return result
+                       wrapper.hold = False
+                       wrapper.__name__ = 'wrapper_%s' % ( f.__name__ or '<??>' )
+                       wrapper.__dict__ = f.__dict__
+                       wrapper.__doc__  = f.__doc__
+               else :
+                       wrapper = f
                for name in names :
-                       functions[ name ] = f
-               return f
+                       functions[ name ] = wrapper
+               return wrapper
        return fun
 
 def register( qname ) :
@@ -88,7 +103,7 @@ def register( qname ) :
                # Wrapper ensure than return value is a Sequence
                #
                def wrapper( *args , **kwargs ) :
-                       result = f( args , kwargs )
+                       result = f( *args , **kwargs )
                        if not isinstance( result , Sequence ) :
                                result = Sequence( result )
                        return result
@@ -290,10 +305,8 @@ def sequenceToBoolean( sequence ) :
 @registerFast( 'ext:descendant-attribute' )
 def extDescendantAttribute( context , arg ) :
 
-       '''
-       Find all attributes of context node and its descendant with name
-       'arg'.
-       '''
+       '''Find all attributes with name 'arg' of context node and its
+       descendant.'''
 
        name = arg[ 0 ]
        item = context.item
@@ -388,7 +401,7 @@ def fnCeiling( context , arg ) :
                elif item == _NEGINF :
                        return _NegativeInfinity
                else :
-                       return math.ceil( item )
+                       return Sequence( math.ceil( item ) )
 
 # XO/6.4.3
 @registerFast( 'fn:floor' )
@@ -403,7 +416,7 @@ def fnFloor( context , arg ) :
                elif item == _NEGINF :
                        return _NegativeInfinity
                else :
-                       return math.floor( item )
+                       return Sequence( math.floor( item ) )
 
 # XO/6.4.4
 @registerFast( 'fn:round' )
@@ -418,7 +431,7 @@ def fnRound( context , arg ) :
                elif item == _NEGINF :
                        return _NegativeInfinity
                else :
-                       return math.floor( item + 0.5 )
+                       return Sequence( math.floor( item + 0.5 ) )
 
 # XO/7.2.1
 @registerFast( 'fn:codepoints-to-string' )
@@ -454,7 +467,7 @@ def fnStringJoin( context , arg1 , arg2 ) :
        if not arg1 :
                return _EmptyString
        else :
-               return asString( arg2 ).join( map( asString , arg1 ) )
+               return Sequence( asString( arg2 ).join( map( asString , arg1 ) ) )
 
 # XO/7.4.3
 @registerFast( 'fn:substring' )
@@ -607,14 +620,14 @@ def fnContains( context , arg1 , arg2 , collation = None ) :
        arg1 = zeroOrOneItem( arg1 )
        arg2 = zeroOrOneItem( arg2 )
        if arg1 is arg2 :
-               return True
+               return _True
        else :
                arg2 = asString( arg2 )
                if arg2 == '' :
-                       return True
+                       return _True
                arg1 = asString( arg1 )
                if arg1 == '' :
-                       return False
+                       return _False
                return _Boolean[ arg2 in arg1 ]
 
 # XO/7.5.2
@@ -842,10 +855,9 @@ def fnIndexOf( context , seqParam , srchParam , collation = None ) :
        if collation is not None :
                raise XPathError( 'FOCH0004' , 'Collation not supported' )
 
-       seqParam = zeroOrMoreItem( seqParam )
        srchParam = oneItem( srchParam )
 
-       return tuple( i + 1 for i , item in enumerate( seqParam ) if compareValue( item , srchParam ) == 0 )
+       return Sequence( i + 1 for i , item in enumerate( seqParam ) if compareValue( item , srchParam ) == 0 )
 
 # XO/15.1.4
 @registerFast( 'fn:empty' )
@@ -885,9 +897,7 @@ def fnDistinctValues( context , arg ) :
 @registerFast( 'fn:insert-before' )
 def fnInsertBefore( context , target , position , inserts ) :
 
-       target = zeroOrMoreItem( target )
        position = asNumber( oneItem( position ) )
-       inserts = zeroOrMoreItem( inserts )
 
        position = max( 0 , int( position - 1 ) )
 
@@ -897,10 +907,9 @@ def fnInsertBefore( context , target , position , inserts ) :
 @registerFast( 'fn:remove' )
 def fnRemove( context , target , position ) :
 
-       target = zeroOrMoreItem( target )
        position = asNumber( oneItem( position ) )
 
-       position = max( -1 , int( position - 1 ) )
+       position = int( position - 1 )
        if position < 0 :
                return target
        else :
@@ -954,7 +963,7 @@ def fnUnordered( context , sourceSeq ) :
 @registerFast( 'fn:zero-or-one' )
 def fnZeroOrOne( context , arg ) :
 
-       zeroOrMoreItem( arg )
+       zeroOrOneItem( arg )
        return arg
 
 # XO/15.2.2
@@ -968,7 +977,7 @@ def fnOneOrMore( context , arg ) :
 @registerFast( 'fn:exactly-one' )
 def fnExactlyONe( context , arg ) :
 
-       zeroOrMoreItem( arg )
+       oneItem( arg )
        return arg
 
 # FIXME: Rewrite it as iterative (following one tree iteratively, and
@@ -1039,27 +1048,24 @@ def fnDeepEqual( context , parameter1 , parameter2 , collation = None ) :
        if collation is not None :
                raise XPathError( 'FOCH0004' , 'Collation not supported' )
 
-       parameter1 = zeroOrMoreItem( parameter1 )
-       parameter2 = zeroOrMoreItem( parameter2 )
-
        if not parameter1 and not parameter2 :
-               return True
+               return _True
        if len( parameter1 ) != len( parameter2 ) :
-               return False
+               return _False
        for i1 , i2 in zip( parameter1 , parameter2 ) :
                if isNode( i1 ) :
                        if isNode( i2 ) :
                                if not treeDeepEqual( i1 , i2 ) :
-                                       return False
+                                       return _False
                        else :
-                               return False
+                               return _False
                else :
                        if isNode( i2 ) :
-                               return False
+                               return _False
                        else :
                                if compareValue( i1 , i2 ) != 0 :
-                                       return False
-       return True
+                                       return _False
+       return _True
 
 # XO/15.4.1
 @registerFast( 'fn:count' )
@@ -1142,7 +1148,7 @@ def fnId( context , arg , node = None ) :
        doc = item.root
        if doc is None or not isDocument( doc ) :
                raise XPathError( 'FODC0001' , '..' ) # unsure
-       ids = map( asString , zeroOrMoreItem( arg ) )
+       ids = map( asString , arg )
        result = []
        for id in ids :
                n = doc.ids.get( id )
@@ -1220,15 +1226,13 @@ def opTo( context , arg1 , arg2 ) :
 @registerFast( 'op:union' )
 def opUnion( context , arg1 , arg2 ) :
 
+       if arg1.type not in ( SEQUENCE_EMPTY , SEQUENCE_NODES ) \
+               or arg2.type not in ( SEQUENCE_EMPTY , SEQUENCE_NODES ) :
+               raise XPathError( 'XPTY0004' , '\'union\' operator expect sequence of nodes only' )
        if not arg1 :
                return arg2
        if not arg2 :
                return arg1
-       if arg1.type != SEQUENCE_NODES \
-               or arg2.type != SEQUENCE_NODES :
-               raise XPathError( 'XPTY0004' , 'union operator expect sequence of nodes only' )
-       arg1 = zeroOrMoreNodes( arg1 )
-       arg2 = zeroOrMoreNodes( arg2 )
        result = list( set( arg1 ) | set( arg2 ) )
        result.sort( lambda a , b : cmp( a.position , b.position ) )
        return Sequence( result )
@@ -1236,15 +1240,13 @@ def opUnion( context , arg1 , arg2 ) :
 @registerFast( 'op:except' )
 def opExcept( context , arg1 , arg2 ) :
 
+       if arg1.type not in ( SEQUENCE_EMPTY , SEQUENCE_NODES ) \
+               or arg2.type not in ( SEQUENCE_EMPTY , SEQUENCE_NODES ) :
+               raise XPathError( 'XPTY0004' , '\'except\' operator expect sequence of nodes only' )
        if not arg1 :
-               return arg2
+               return _Empty
        if not arg2 :
                return arg1
-       if arg1.type != SEQUENCE_NODES \
-               or arg2.type != SEQUENCE_NODES :
-               raise XPathError( 'XPTY0004' , 'except operator expect sequence of nodes only' )
-       arg1 = zeroOrMoreNodes( arg1 )
-       arg2 = zeroOrMoreNodes( arg2 )
        result = list( set( arg1 ) - set( arg2 ) )
        result.sort( lambda a , b : cmp( a.position , b.position ) )
        return Sequence( result )
@@ -1252,15 +1254,11 @@ def opExcept( context , arg1 , arg2 ) :
 @registerFast( 'op:intersection' )
 def opIntersection( context , arg1 , arg2 ) :
 
-       if not arg1 :
-               return arg2
-       if not arg2 :
-               return arg1
-       if arg1.type != SEQUENCE_NODES \
-               or arg2.type != SEQUENCE_NODES :
-               raise XPathError( 'XPTY0004' , 'intersection operator expect sequence of nodes only' )
-       arg1 = zeroOrMoreNodes( arg1 )
-       arg2 = zeroOrMoreNodes( arg2 )
+       if arg1.type not in ( SEQUENCE_EMPTY , SEQUENCE_NODES ) \
+               or arg2.type not in ( SEQUENCE_EMPTY , SEQUENCE_NODES ) :
+               raise XPathError( 'XPTY0004' , '\'intersect\' operator expect sequence of nodes only' )
+       if not arg1 or not arg2 :
+               return _Empty
        result = list( set( arg1 ) & set( arg2 ) )
        result.sort( lambda a , b : cmp( a.position , b.position ) )
        return Sequence( result )