Skip to content

Commit

Permalink
Merge pull request rails#52428 from Shopify/refactor-adapters
Browse files Browse the repository at this point in the history
  • Loading branch information
byroot authored Jul 29, 2024
2 parents 0b3320b + 8078ebc commit 4294d71
Show file tree
Hide file tree
Showing 20 changed files with 328 additions and 465 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,14 @@ def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil, retu
# +binds+ as the bind substitutes. +name+ is logged along with
# the executed +sql+ statement.
def exec_delete(sql, name = nil, binds = [])
internal_exec_query(sql, name, binds)
affected_rows(internal_execute(sql, name, binds))
end

# Executes update +sql+ statement in the context of this connection using
# +binds+ as the bind substitutes. +name+ is logged along with
# the executed +sql+ statement.
def exec_update(sql, name = nil, binds = [])
internal_exec_query(sql, name, binds)
affected_rows(internal_execute(sql, name, binds))
end

def exec_insert_all(sql, name) # :nodoc:
Expand Down Expand Up @@ -532,30 +532,79 @@ def high_precision_current_timestamp
HIGH_PRECISION_CURRENT_TIMESTAMP
end

def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false, allow_retry: false) # :nodoc:
raise NotImplementedError
# Same as raw_execute but returns an ActiveRecord::Result object.
def raw_exec_query(...) # :nodoc:
cast_result(raw_execute(...))
end

# Execute a query and returns an ActiveRecord::Result
def internal_exec_query(...) # :nodoc:
cast_result(internal_execute(...))
end

private
def internal_execute(sql, name = "SCHEMA", allow_retry: false, materialize_transactions: true)
sql = transform_query(sql)
check_if_write_query(sql)
# Lowest level way to execute a query. Doesn't check for illegal writes, doesn't annotate queries, yields a native result object.
def raw_execute(sql, name = nil, binds = [], prepare: false, async: false, allow_retry: false, materialize_transactions: true)
type_casted_binds = type_casted_binds(binds)
notification_payload = {
sql: sql,
name: name,
binds: binds,
type_casted_binds: type_casted_binds,
async: async,
connection: self,
transaction: current_transaction.user_transaction.presence,
statement_name: nil,
row_count: 0,
}
@instrumenter.instrument("sql.active_record", notification_payload) do
with_raw_connection(allow_retry: allow_retry, materialize_transactions: materialize_transactions) do |conn|
perform_query(conn, sql, binds, type_casted_binds, prepare: prepare, notification_payload: notification_payload)
end
rescue ActiveRecord::StatementInvalid => ex
raise ex.set_query(sql, binds)
end
end

def perform_query(raw_connection, sql, binds, type_casted_binds, prepare:, notification_payload:)
raise NotImplementedError
end

# Receive a native adapter result object and returns an ActiveRecord::Result object.
def cast_result(raw_result)
raise NotImplementedError
end

def affected_rows(raw_result)
raise NotImplementedError
end

def preprocess_query(sql)
check_if_write_query(sql)
mark_transaction_written_if_write(sql)

raw_execute(sql, name, allow_retry: allow_retry, materialize_transactions: materialize_transactions)
# We call tranformers after the write checks so we don't add extra parsing work.
# This means we assume no transformer whille change a read for a write
# but it would be insane to do such a thing.
ActiveRecord.query_transformers.each do |transformer|
sql = transformer.call(sql, self)
end

sql
end

# Same as #internal_exec_query, but yields a native adapter result
def internal_execute(sql, name = "SQL", binds = [], prepare: false, async: false, allow_retry: false, materialize_transactions: true, &block)
sql = preprocess_query(sql)
raw_execute(sql, name, binds, prepare: prepare, async: async, allow_retry: allow_retry, materialize_transactions: materialize_transactions, &block)
end

def execute_batch(statements, name = nil)
statements.each do |statement|
internal_execute(statement, name)
raw_execute(statement, name)
end
end

def raw_execute(sql, name, async: false, allow_retry: false, materialize_transactions: true)
raise NotImplementedError
end

DEFAULT_INSERT_VALUE = Arel.sql("DEFAULT").freeze
private_constant :DEFAULT_INSERT_VALUE

Expand Down Expand Up @@ -637,6 +686,8 @@ def select(sql, name = nil, binds = [], prepare: false, async: false, allow_retr
raise AsynchronousQueryInsideTransactionError, "Asynchronous queries are not allowed inside transactions"
end

# We make sure to run query transformers on the orignal thread
sql = preprocess_query(sql)
future_result = async.new(
pool,
sql,
Expand All @@ -649,14 +700,14 @@ def select(sql, name = nil, binds = [], prepare: false, async: false, allow_retr
else
future_result.execute!(self)
end
return future_result
end

result = internal_exec_query(sql, name, binds, prepare: prepare, allow_retry: allow_retry)
if async
FutureResult.wrap(result)
future_result
else
result
result = internal_exec_query(sql, name, binds, prepare: prepare, allow_retry: allow_retry)
if async
FutureResult.wrap(result)
else
result
end
end
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def sanitize_as_sql_comment(value) # :nodoc:

private
def type_casted_binds(binds)
binds.map do |value|
binds&.map do |value|
if ActiveModel::Attribute === value
type_cast(value.value_for_database)
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1106,24 +1106,25 @@ def type_map
end
end

def translate_exception_class(e, sql, binds)
message = "#{e.class.name}: #{e.message}"
def translate_exception_class(native_error, sql, binds)
return native_error if native_error.is_a?(ActiveRecordError)

exception = translate_exception(
e, message: message, sql: sql, binds: binds
message = "#{native_error.class.name}: #{native_error.message}"

active_record_error = translate_exception(
native_error, message: message, sql: sql, binds: binds
)
exception.set_backtrace e.backtrace
exception
active_record_error.set_backtrace(native_error.backtrace)
active_record_error
end

def log(sql, name = "SQL", binds = [], type_casted_binds = [], statement_name = nil, async: false, &block) # :doc:
def log(sql, name = "SQL", binds = [], type_casted_binds = [], async: false, &block) # :doc:
@instrumenter.instrument(
"sql.active_record",
sql: sql,
name: name,
binds: binds,
type_casted_binds: type_casted_binds,
statement_name: statement_name,
async: async,
connection: self,
transaction: current_transaction.user_transaction.presence,
Expand All @@ -1134,13 +1135,6 @@ def log(sql, name = "SQL", binds = [], type_casted_binds = [], statement_name =
raise ex.set_query(sql, binds)
end

def transform_query(sql)
ActiveRecord.query_transformers.each do |transformer|
sql = transformer.call(sql, self)
end
sql
end

def translate_exception(exception, message:, sql:, binds:)
# override in derived class
case exception
Expand All @@ -1152,7 +1146,7 @@ def translate_exception(exception, message:, sql:, binds:)
end

def without_prepared_statement?(binds)
!prepared_statements || binds.empty?
!prepared_statements || binds.nil? || binds.empty?
end

def column_for(table_name, column_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,6 @@ def index_algorithms

# HELPER METHODS ===========================================

# The two drivers have slightly different ways of yielding hashes of results, so
# this method must be implemented to provide a uniform interface.
def each_hash(result) # :nodoc:
raise NotImplementedError
end

# Must return the MySQL error number from the exception, if the exception has an
# error number.
def error_number(exception) # :nodoc:
Expand All @@ -226,17 +220,6 @@ def disable_referential_integrity # :nodoc:
# DATABASE STATEMENTS ======================================
#++

# Mysql2Adapter doesn't have to free a result after using it, but we use this method
# to write stuff in an abstract way without concerning ourselves about whether it
# needs to be explicitly freed or not.
def execute_and_free(sql, name = nil, async: false, allow_retry: false) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)

mark_transaction_written_if_write(sql)
yield raw_execute(sql, name, async: async, allow_retry: allow_retry)
end

def begin_db_transaction # :nodoc:
internal_execute("BEGIN", "TRANSACTION", allow_retry: true, materialize_transactions: false)
end
Expand Down Expand Up @@ -787,11 +770,6 @@ def warning_ignored?(warning)
warning.level == "Note" || super
end

# Make sure we carry over any changes to ActiveRecord.default_timezone that have been
# made since we established the connection
def sync_timezone_changes(raw_connection)
end

# See https://dev.mysql.com/doc/mysql-errors/en/server-error-reference.html
ER_DB_CREATE_EXISTS = 1007
ER_FILSORT_ABORT = 1028
Expand Down Expand Up @@ -961,13 +939,11 @@ def configure_connection
end.join(", ")

# ...and send them all in one query
internal_execute("SET #{encoding} #{sql_mode_assignment} #{variable_assignments}")
raw_execute("SET #{encoding} #{sql_mode_assignment} #{variable_assignments}", "SCHEMA")
end

def column_definitions(table_name) # :nodoc:
execute_and_free("SHOW FULL FIELDS FROM #{quote_table_name(table_name)}", "SCHEMA") do |result|
each_hash(result)
end
internal_exec_query("SHOW FULL FIELDS FROM #{quote_table_name(table_name)}", "SCHEMA")
end

def create_table_info(table_name) # :nodoc:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,43 @@ module SchemaStatements # :nodoc:
def indexes(table_name)
indexes = []
current_index = nil
execute_and_free("SHOW KEYS FROM #{quote_table_name(table_name)}", "SCHEMA") do |result|
each_hash(result) do |row|
if current_index != row[:Key_name]
next if row[:Key_name] == "PRIMARY" # skip the primary key
current_index = row[:Key_name]

mysql_index_type = row[:Index_type].downcase.to_sym
case mysql_index_type
when :fulltext, :spatial
index_type = mysql_index_type
when :btree, :hash
index_using = mysql_index_type
end

indexes << [
row[:Table],
row[:Key_name],
row[:Non_unique].to_i == 0,
[],
lengths: {},
orders: {},
type: index_type,
using: index_using,
comment: row[:Index_comment].presence
]
internal_exec_query("SHOW KEYS FROM #{quote_table_name(table_name)}", "SCHEMA").each do |row|
if current_index != row["Key_name"]
next if row["Key_name"] == "PRIMARY" # skip the primary key
current_index = row["Key_name"]

mysql_index_type = row["Index_type"].downcase.to_sym
case mysql_index_type
when :fulltext, :spatial
index_type = mysql_index_type
when :btree, :hash
index_using = mysql_index_type
end

if row[:Expression]
expression = row[:Expression].gsub("\\'", "'")
expression = +"(#{expression})" unless expression.start_with?("(")
indexes.last[-2] << expression
indexes.last[-1][:expressions] ||= {}
indexes.last[-1][:expressions][expression] = expression
indexes.last[-1][:orders][expression] = :desc if row[:Collation] == "D"
else
indexes.last[-2] << row[:Column_name]
indexes.last[-1][:lengths][row[:Column_name]] = row[:Sub_part].to_i if row[:Sub_part]
indexes.last[-1][:orders][row[:Column_name]] = :desc if row[:Collation] == "D"
end
indexes << [
row["Table"],
row["Key_name"],
row["Non_unique"].to_i == 0,
[],
lengths: {},
orders: {},
type: index_type,
using: index_using,
comment: row["Index_comment"].presence
]
end

if expression = row["Expression"]
expression = expression.gsub("\\'", "'")
expression = +"(#{expression})" unless expression.start_with?("(")
indexes.last[-2] << expression
indexes.last[-1][:expressions] ||= {}
indexes.last[-1][:expressions][expression] = expression
indexes.last[-1][:orders][expression] = :desc if row["Collation"] == "D"
else
indexes.last[-2] << row["Column_name"]
indexes.last[-1][:lengths][row["Column_name"]] = row["Sub_part"].to_i if row["Sub_part"]
indexes.last[-1][:orders][row["Column_name"]] = :desc if row["Collation"] == "D"
end
end

Expand Down Expand Up @@ -182,12 +180,12 @@ def default_type(table_name, field_name)
end

def new_column_from_field(table_name, field, _definitions)
field_name = field.fetch(:Field)
type_metadata = fetch_type_metadata(field[:Type], field[:Extra])
default, default_function = field[:Default], nil
field_name = field.fetch("Field")
type_metadata = fetch_type_metadata(field["Type"], field["Extra"])
default, default_function = field["Default"], nil

if type_metadata.type == :datetime && /\ACURRENT_TIMESTAMP(?:\([0-6]?\))?\z/i.match?(default)
default = "#{default} ON UPDATE #{default}" if /on update CURRENT_TIMESTAMP/i.match?(field[:Extra])
default = "#{default} ON UPDATE #{default}" if /on update CURRENT_TIMESTAMP/i.match?(field["Extra"])
default, default_function = nil, default
elsif type_metadata.extra == "DEFAULT_GENERATED"
default = +"(#{default})" unless default.start_with?("(")
Expand All @@ -203,13 +201,13 @@ def new_column_from_field(table_name, field, _definitions)
end

MySQL::Column.new(
field[:Field],
field["Field"],
default,
type_metadata,
field[:Null] == "YES",
field["Null"] == "YES",
default_function,
collation: field[:Collation],
comment: field[:Comment].presence
collation: field["Collation"],
comment: field["Comment"].presence
)
end

Expand Down
Loading

0 comments on commit 4294d71

Please sign in to comment.