Commits

Mike Bayer committed 6376006

- [bug] Fixed cextension bug whereby the
"ambiguous column error" would fail to
function properly if the given index were
a Column object and not a string.
Note there are still some column-targeting
issues here which are fixed in 0.8.
[ticket:2553]
- find more cases where column targeting is being inaccurate, add
more information to result_map to better differentiate "ambiguous"
results from "present" or "not present". In particular, result_map
is sensitive to dupes, even though no error is raised; the conflicting
columns are added to the "obj" member of the tuple so that the two
are both directly accessible in the result proxy
- handwringing over the damn "name fallback" thing in results. can't
really make it perfect yet
- fix up oracle returning clause. not sure why its guarding against
labels, remove that for now and see what the bot says.

  • Participants
  • Parent commits 3548891

Comments (0)

Files changed (9)

     the absense of which was preventing the new
     GAE dialect from being loaded.  [ticket:2529]
 
+  - [bug] Fixed cextension bug whereby the
+    "ambiguous column error" would fail to
+    function properly if the given index were
+    a Column object and not a string.
+    Note there are still some column-targeting
+    issues here which are fixed in 0.8.
+    [ticket:2553]
+
   - [bug] Fixed the repr() of Enum to include
     the "name" and "native_enum" flags.  Helps
     Alembic autogenerate.

lib/sqlalchemy/cextension/resultproxy.c

     PyObject *processors, *values;
     PyObject *processor, *value, *processed_value;
     PyObject *row, *record, *result, *indexobject;
-    PyObject *exc_module, *exception;
+    PyObject *exc_module, *exception, *cstr_obj;
     char *cstr_key;
     long index;
     int key_fallback = 0;
             if (exception == NULL)
                 return NULL;
 
-            cstr_key = PyString_AsString(key);
-            if (cstr_key == NULL)
+            // wow.  this seems quite excessive.
+            cstr_obj = PyObject_Str(key);
+            if (cstr_obj == NULL)
                 return NULL;
+            cstr_key = PyString_AsString(cstr_obj);
+            if (cstr_key == NULL) {
+                Py_DECREF(cstr_obj);
+                return NULL;
+            }
+            Py_DECREF(cstr_obj);
 
             PyErr_Format(exception,
                     "Ambiguous column name '%.200s' in result set! "

lib/sqlalchemy/dialects/mssql/base.py

                                         t, column)
 
                 if add_to_result_map is not None:
-                    self.result_map[column.name
-                                if self.dialect.case_sensitive
-                                else column.name.lower()] = \
-                                    (column.name, (column, ) + add_to_result_map,
-                                                    column.type)
+                    add_to_result_map(
+                            column.name,
+                            column.name,
+                            (column, ),
+                            column.type
+                    )
 
                 return super(MSSQLCompiler, self).\
-                                visit_column(converted,
-                                            result_map=None, **kwargs)
+                                visit_column(converted, **kwargs)
 
         return super(MSSQLCompiler, self).visit_column(
                         column, add_to_result_map=add_to_result_map, **kwargs)

lib/sqlalchemy/dialects/oracle/base.py

 
         columnlist = list(expression._select_iterables(returning_cols))
 
-        # within_columns_clause =False so that labels (foo AS bar) don't render
-        columns = [self.process(c, within_columns_clause=False, result_map=self.result_map) for c in columnlist]
+        columns = [
+                self._label_select_column(None, c, True, False, {})
+                for c in columnlist
+            ]
 
         binds = [create_out_param(c, i) for i, c in enumerate(columnlist)]
 
-        return 'RETURNING ' + ', '.join(columns) +  " INTO " + ", ".join(binds)
+        return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
 
     def _TODO_visit_compound_select(self, select):
         """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""

lib/sqlalchemy/engine/result.py

                 # unambiguous.
                 primary_keymap[name
                                 if self.case_sensitive
-                                else name.lower()] = (processor, obj, None)
+                                else name.lower()] = rec = (processor, obj, None)
 
             self.keys.append(colname)
             if obj:

lib/sqlalchemy/sql/compiler.py

         return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
 
     def visit_label(self, label,
-                            add_to_result_map = None,
+                            add_to_result_map=None,
                             within_label_clause=False,
                             within_columns_clause=False, **kw):
         # only render labels within the columns clause
                 labelname = label.name
 
             if add_to_result_map is not None:
-                self.result_map[
-                        labelname
-                            if self.dialect.case_sensitive
-                            else labelname.lower()
-                        ] = (
-                            label.name,
-                            (label, label.element, labelname, ) +
-                                label._alt_names +
-                                add_to_result_map,
-                            label.type,
-                        )
+                add_to_result_map(
+                        labelname,
+                        label.name,
+                        (label, label.element, labelname, ) + label._alt_names,
+                        label.type
+                )
 
             return label.element._compiler_dispatch(self,
                                     within_columns_clause=True,
             name = self._truncated_identifier("colident", name)
 
         if add_to_result_map is not None:
-            self.result_map[
-                        name
-                        if self.dialect.case_sensitive
-                        else name.lower()
-                    ] = (
-                        orig_name,
-                        (column, name, column.key) + add_to_result_map,
-                        column.type
-                    )
+            add_to_result_map(
+                name,
+                orig_name,
+                (column, name, column.key),
+                column.type
+            )
 
         if is_literal:
             name = self.escape_literal_column(name)
 
     def visit_function(self, func, add_to_result_map=None, **kwargs):
         if add_to_result_map is not None:
-            self.result_map[
-                        func.name
-                        if self.dialect.case_sensitive
-                        else func.name.lower()
-                    ] = (func.name, add_to_result_map, func.type)
+            add_to_result_map(
+                func.name, func.name, (), func.type
+            )
 
         disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
         if disp:
         else:
             return alias.original._compiler_dispatch(self, **kwargs)
 
+    def _add_to_result_map(self, keyname, name, objects, type_):
+        if not self.dialect.case_sensitive:
+            keyname = keyname.lower()
+
+        if keyname in self.result_map:
+            # conflicting keyname, just double up the list
+            # of objects.  this will cause an "ambiguous name"
+            # error if an attempt is made by the result set to
+            # access.
+            e_name, e_obj, e_type = self.result_map[keyname]
+            self.result_map[keyname] = e_name, e_obj + objects, e_type
+        else:
+            self.result_map[keyname] = name, objects, type_
+
     def _label_select_column(self, select, column, populate_result_map,
                                     asfrom, column_clause_args):
         """produce labeled columns present in a select()."""
         if column.type._has_column_expression:
             col_expr = column.type.column_expression(column)
             if populate_result_map:
-                add_to_result_map = (column, )
+                add_to_result_map = lambda keyname, name, objects, type_: \
+                                    self._add_to_result_map(
+                                            keyname, name,
+                                            objects + (column,), type_)
             else:
                 add_to_result_map = None
         else:
             col_expr = column
             if populate_result_map:
-                add_to_result_map = ()
+                add_to_result_map = self._add_to_result_map
             else:
                 add_to_result_map = None
 
                         **column_clause_args
                     )
 
-
     def format_from_hint_text(self, sqltext, table, hint, iscrud):
         hinttext = self.get_from_hint_text(table, hint)
         if hinttext:

lib/sqlalchemy/sql/expression.py

         self.is_literal = is_literal
 
     def _compare_name_for_result(self, other):
+        # TODO: this still isn't 100% correct
         if self.table is not None and hasattr(other, 'proxy_set'):
-            return other.proxy_set.intersection(self.proxy_set)
+            return self.proxy_set.intersection(other.proxy_set)
         else:
             return super(ColumnClause, self).\
                     _compare_name_for_result(other)

test/dialect/test_oracle.py

                             'addresses.user_id = :user_id_1 ORDER BY '
                             'addresses.id, address_types.id')
 
+    def test_returning_insert(self):
+        t1 = table('t1', column('c1'), column('c2'), column('c3'))
+        self.assert_compile(
+            t1.insert().values(c1=1).returning(t1.c.c2, t1.c.c3),
+            "INSERT INTO t1 (c1) VALUES (:c1) RETURNING "
+                "t1.c2, t1.c3 INTO :ret_0, :ret_1"
+        )
+
     def test_compound(self):
         t1 = table('t1', column('c1'), column('c2'), column('c3'))
         t2 = table('t2', column('c1'), column('c2'), column('c3'))

test/sql/test_query.py

 
     def test_ambiguous_column(self):
         users.insert().execute(user_id=1, user_name='john')
-        r = users.outerjoin(addresses).select().execute().first()
+        result = users.outerjoin(addresses).select().execute()
+        r = result.first()
+
         assert_raises_message(
             exc.InvalidRequestError,
             "Ambiguous column name",
             lambda: r['user_id']
         )
 
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Ambiguous column name",
+            lambda: r[users.c.user_id]
+        )
+
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Ambiguous column name",
+            lambda: r[addresses.c.user_id]
+        )
+
+        # try to trick it - fake_table isn't in the result!
+        # we get the correct error
+        fake_table = Table('fake', MetaData(), Column('user_id', Integer))
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Could not locate column in row for column 'fake.user_id'",
+            lambda: r[fake_table.c.user_id]
+        )
+
         r = util.pickle.loads(util.pickle.dumps(r))
         assert_raises_message(
             exc.InvalidRequestError,
             lambda: r['user_id']
         )
 
+    def test_ambiguous_column_by_col(self):
+        users.insert().execute(user_id=1, user_name='john')
+        ua = users.alias()
+        u2 = users.alias()
+        result = select([users.c.user_id, ua.c.user_id]).execute()
+        row = result.first()
+
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Ambiguous column name",
+            lambda: row[users.c.user_id]
+        )
+
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Ambiguous column name",
+            lambda: row[ua.c.user_id]
+        )
+
+        # Unfortunately, this fails -
+        # we'd like
+        # "Could not locate column in row"
+        # to be raised here, but the check for
+        # "common column" in _compare_name_for_result()
+        # has other requirements to be more liberal.
+        # Ultimately the
+        # expression system would need a way to determine
+        # if given two columns in a "proxy" relationship, if they
+        # refer to a different parent table
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Ambiguous column name",
+            lambda: row[u2.c.user_id]
+        )
+
     @testing.requires.subqueries
     def test_column_label_targeting(self):
         users.insert().execute(user_id=7, user_name='ed')
         keyed3 = self.tables.keyed3
 
         row = testing.db.execute(select([keyed1, keyed3])).first()
-        assert 'b' not in row
         eq_(row.q, "c1")
         assert_raises_message(
             exc.InvalidRequestError,
+            "Ambiguous column name 'b'",
+            getattr, row, "b"
+        )
+        assert_raises_message(
+            exc.InvalidRequestError,
             "Ambiguous column name 'a'",
             getattr, row, "a"
         )