# Source: django-tsearch2 / tsearch2 / managers.py (full commit)

from django.db import models
from django.db import connection


class SearchManager(models.Manager):
    """
    Manager that maintains and queries a full-text search (tsvector) column.

    ``fields`` selects which model fields feed the index. It may be a list or
    tuple of field names (each indexed with ``default_weight``), a dict mapping
    field name to weight, or None to use every CharField/TextField of the model.

    ``config`` is the text search configuration name handed to ``to_tsvector``
    and ``plainto_tsquery``; it defaults to 'pg_catalog.english'.
    """

    def __init__(self, fields=None, config=None):
        self.fields = fields
        # Weight label applied to fields that do not specify their own.
        self.default_weight = 'D'
        self.config = config or 'pg_catalog.english'
        self._vector_field_cache = None
        super(SearchManager, self).__init__()

    # NOTE: disabled hook, kept for reference. It would expose this manager on
    # the model class so instances could reach it to update their indexes.
    #def contribute_to_class( self, cls, name ):
    #    setattr( cls, '_search_manager', self )
    #    super( SearchManager, self ).contribute_to_class( cls, name )

    def _find_text_fields(self):
        """
        Return the names of all CharField and TextField fields defined for this manager's model.
        """
        return [f.name for f in self.model._meta.fields if isinstance(f, (models.CharField, models.TextField))]

    def _vector_field(self):
        """
        Returns the VectorField defined for this manager's model.

        The attribute is looked up on the model class by the name stored in
        ``_meta.tsvector_field_name`` and cached on the manager afterwards.
        """
        if self._vector_field_cache is None:
            self._vector_field_cache = getattr(self.model, self.model._meta.tsvector_field_name)
        return self._vector_field_cache

    vector_field = property(_vector_field)

    def _vector_sql(self, field, weight=None):
        """
        Returns the SQL used to build a tsvector from the given (django) field name.

        ``weight`` falls back to ``self.default_weight`` when omitted.

        The configuration name, column name and weight are interpolated into
        the SQL text as-is, so they must come from trusted (developer-supplied)
        configuration, never from user input.
        """
        if weight is None:
            weight = self.default_weight
        column = self.model._meta.get_field(field).column
        return "setweight( to_tsvector( '%s', coalesce(\"%s\",'') ), '%s' )" % (self.config, column, weight)

    def update_index(self, pk=None):
        """
        Updates the full-text index for one, many, or all instances of this manager's model.

        ``pk`` may be None (update every row), a single primary key value, or a
        list/tuple of primary key values. An empty list/tuple is a no-op.
        Primary key values are sent to the database as bound parameters rather
        than being pasted into the SQL text.
        """
        # Resolve the default field set on first use and remember it.
        if self.fields is None:
            self.fields = self._find_text_fields()

        # One tsvector-producing SQL clause per indexed field.
        if isinstance(self.fields, (list, tuple)):
            clauses = [self._vector_sql(name) for name in self.fields]
        else:
            clauses = [self._vector_sql(name, weight) for name, weight in self.fields.items()]

        sql = 'UPDATE "%s" SET "%s" = %s' % (
            self.model._meta.db_table,
            self.vector_field.column,
            ' || '.join(clauses),
        )

        params = None
        if pk is not None:
            # The SQL below is %-formatted by the driver once parameters are
            # supplied, so literal '%' characters have to be doubled.
            pk_column = self.model._meta.pk.column.replace('%', '%%')
            if isinstance(pk, (list, tuple)):
                if not pk:
                    # "IN ()" is not valid SQL and an empty list matches no rows.
                    return
                placeholders = ', '.join(['%s'] * len(pk))
                condition = '"%s" IN (%s)' % (pk_column, placeholders)
                params = list(pk)
            else:
                condition = '"%s" = %%s' % pk_column
                params = [pk]
            sql = '%s WHERE %s' % (sql.replace('%', '%%'), condition)

        cursor = connection.cursor()
        try:
            if params is None:
                cursor.execute(sql)
            else:
                cursor.execute(sql, params)
            # NOTE(review): the explicit COMMIT is kept for backward
            # compatibility. It bypasses Django's transaction management and
            # will also commit any transaction the caller has open -- confirm
            # whether it is still wanted for the Django version in use.
            cursor.execute("COMMIT;")
        finally:
            # Release the cursor even when one of the statements fails.
            cursor.close()

    def search(self, query, rank_field=None, rank_normalization=32):
        """
        Returns a queryset after having applied the full-text search query. If rank_field
        is specified, it is the name of the field that will be put on each returned instance.
        When specifying a rank_field, the results will automatically be ordered by -rank_field.

        The search text and configuration name are passed as query parameters,
        so quotes, backslashes and percent signs in ``query`` are safe.

        For possible rank_normalization values, refer to:
        http://www.postgresql.org/docs/8.4/static/textsearch-controls.html#TEXTSEARCH-RANKING
        """
        # ``unicode`` on Python 2, ``str`` on Python 3.
        text_type = type(u'')

        # These fragments are %-formatted with parameters later on, so a
        # literal '%' in the column name has to be doubled.
        column = self.vector_field.column.replace('%', '%%')
        ts_query = 'plainto_tsquery(%s::regconfig, %s)'
        ts_params = [self.config, text_type(query)]

        where = '"%s" @@ %s' % (column, ts_query)

        select = {}
        select_params = []
        order = []
        if rank_field is not None:
            select[rank_field] = 'ts_rank( "%s", %s, %%s )' % (column, ts_query)
            select_params = ts_params + [int(rank_normalization)]
            order = ['-%s' % rank_field]

        return self.all().extra(
            select=select,
            select_params=select_params,
            where=[where],
            params=ts_params,
            order_by=order,
        )