# Source: django-tsearch2 / tsearch2 / managers.py (full commit)

from django.db import models
from django.db import connection


class SearchManager(models.Manager):
    """
    Model manager adding PostgreSQL full-text search helpers.

    ``fields`` maps (django) field names -- or ``<fk_field>__<column>`` to pull a
    column from a related table -- to the weight passed to ``setweight()``; it
    is used by ``update_index`` to rebuild the model's tsvector column.
    ``config`` is the text-search configuration passed to ``to_tsvector`` and
    ``plainto_tsquery`` (defaults to ``'pg_catalog.english'``).
    """

    def __init__(self, fields=None, config=None):
        self.fields = fields
        self.config = config or 'pg_catalog.english'
        self._vector_field_cache = None
        # Extra FROM tables / WHERE conditions collected while building the
        # UPDATE statement in update_index(); reset at the start of each call.
        self._from_clause = []
        self._where_clause = []
        super(SearchManager, self).__init__()

    @property
    def vector_field(self):
        """
        The model attribute named by ``Meta.tsvector_field_name`` (presumably
        the tsvector field set up elsewhere in this package -- verify).
        """
        return getattr(self.model, self.model._meta.tsvector_field_name)

    def _vector_sql(self, name, weight):
        """
        Return the SQL that builds a weighted tsvector from one (django) field name.

        ``name`` is either a local field name or ``<fk_field>__<column>``. In the
        latter case the related table is added to the FROM list and the join
        condition to the WHERE list used by update_index().
        """
        meta = self.model._meta
        if '__' not in name:
            field = meta.get_field(name)
            table = meta.db_table
            column = field.column
        else:
            # Only one level of relation is used: the first two "__" parts.
            fk_name, column = name.split('__')[:2]
            field = meta.get_field(fk_name)
            table = field.rel.to._meta.db_table
            if table not in self._from_clause:
                self._from_clause.append(table)
            join = "%s.%s=%s.%s" % (meta.db_table, field.attname, table, field.rel.field_name)
            # Several weighted columns may come from the same related table;
            # its join condition only needs to appear once.
            if join not in self._where_clause:
                self._where_clause.append(join)
        return "setweight( to_tsvector( '%s', coalesce(%s.%s,'') ), '%s' )" % (self.config, table, column, weight)

    def get_where_clause(self):
        """Return the accumulated WHERE conditions joined with AND, or ''."""
        if self._where_clause:
            return " WHERE %s" % " AND ".join(self._where_clause)
        return ''

    def get_from_clause(self):
        """Return the accumulated extra FROM tables as a FROM clause, or ''."""
        if self._from_clause:
            return " FROM %s" % ", ".join(self._from_clause)
        return ''

    def update_index(self, pk=None):
        """
        Update the full-text index for one, many, or all instances of this manager's model.

        ``pk`` may be a single primary-key value, a list/tuple of values, or a
        falsy value (the default) to reindex every row. Primary-key values are
        sent as query parameters rather than formatted into the SQL. The
        UPDATE is followed by an explicit COMMIT.
        """
        # Start from a clean slate: _vector_sql() and the pk filter below append
        # to these lists, and leftovers from a previous call would add stale
        # conditions (e.g. "id=1 AND id=2", which matches nothing).
        self._from_clause = []
        self._where_clause = []
        meta = self.model._meta

        # Build SELECT clause for all specified fields.
        select = ' || '.join(self._vector_sql(field, weight) for field, weight in self.fields.items())

        params = []
        if pk:
            if isinstance(pk, (list, tuple)):
                placeholders = ', '.join(['%s'] * len(pk))
                self._where_clause.append("%s.%s IN (%s)" % (meta.db_table, meta.pk.column, placeholders))
                params.extend(pk)
            else:
                # "%%s" leaves a literal %s placeholder for the driver to fill.
                self._where_clause.append("%s.%s=%%s" % (meta.db_table, meta.pk.column))
                params.append(pk)

        sql = "UPDATE %s SET %s = %s %s %s;" % (meta.db_table, self.vector_field.column, select, self.get_from_clause(), self.get_where_clause())
        cursor = connection.cursor()
        try:
            if params:
                cursor.execute(sql, params)
            else:
                # No placeholders to fill: run the statement as-is.
                cursor.execute(sql)
            cursor.execute("COMMIT;")
        finally:
            cursor.close()

    def search(self, query, rank_field=None, rank_normalization=32):
        """
        Returns a queryset after having applied the full-text search query. If rank_field
        is specified, it is the name of the field that will be put on each returned instance.
        When specifying a rank_field, the results will automatically be ordered by -rank_field.

        The search text is bound as a query parameter, so the database driver
        quotes it; hand-escaping single quotes is not safe when the server
        treats backslashes as escapes (standard_conforming_strings=off).

        For possible rank_normalization values, refer to:
        http://www.postgresql.org/docs/8.4/static/textsearch-controls.html#TEXTSEARCH-RANKING
        """
        column = self.vector_field.column
        query_text = unicode(query)
        # "%%s" survives this formatting as a %s placeholder for the query text.
        ts_query = u"plainto_tsquery('%s', %%s)" % self.config
        where = u'"%s" @@ %s' % (column, ts_query)
        select = {}
        select_params = []
        order = []
        if rank_field is not None:
            select[rank_field] = u'ts_rank( "%s", %s, %d )' % (column, ts_query, rank_normalization)
            select_params = [query_text]
            order = ['-%s' % rank_field]
        return self.all().extra(
            select=select,
            select_params=select_params,
            where=[where],
            params=[query_text],
            order_by=order,
        )