diff --git a/README.md b/README.md index cf4ab46..fb3d1c3 100644 --- a/README.md +++ b/README.md @@ -40,8 +40,8 @@ Call `preload` when defining your field: # Post.includes(:comments, :authors) preload [:comments, :authors] - # Post.includes(:comments, authors: [:posts, :followers]) - preload [:comments, { authors: [:posts, :followers] }] + # Post.includes(:comments, authors: [:followers, :posts]) + preload [:comments, { authors: [:followers, :posts] }] resolve ->(obj, args, ctx) { obj.comments } end diff --git a/lib/graphql/preload/instrument.rb b/lib/graphql/preload/instrument.rb index c231698..9153404 100644 --- a/lib/graphql/preload/instrument.rb +++ b/lib/graphql/preload/instrument.rb @@ -33,17 +33,17 @@ def instrument(_type, field) promises << preload(record, sub_association) end when Hash - association.each do |sub_association, property| + association.each do |sub_association, nested_association| promises << preload_single_association(record, sub_association).then do associated_records = record.public_send(sub_association) case associated_records when ActiveRecord::Base - preload(associated_records, property) + preload(associated_records, nested_association) else Promise.all( Array.wrap(associated_records).map do |associated_record| - preload(associated_record, property) + preload(associated_record, nested_association) end ) end @@ -56,7 +56,6 @@ def instrument(_type, field) end private def preload_single_association(record, association) - return Promise.resolve(record) if record.association(association).loaded? GraphQL::Preload::Loader.for(record.class, association).load(record) end end diff --git a/lib/graphql/preload/loader.rb b/lib/graphql/preload/loader.rb index 424c8f8..e902fe2 100644 --- a/lib/graphql/preload/loader.rb +++ b/lib/graphql/preload/loader.rb @@ -4,6 +4,10 @@ module Preload class Loader < GraphQL::Batch::Loader attr_reader :association, :model + def cache_key(record) + record.object_id + end + def initialize(model, association) @association = association @model = model @@ -13,38 +17,41 @@ def initialize(model, association) def load(record) unless record.is_a?(model) - raise TypeError, "loader for #{model.name} can't load associations for #{record.class.name} objects" + raise TypeError, "Loader for #{model} can't load associations for #{record.class} objects" end - if record.association(association).loaded? - Promise.resolve(record) - else - super - end + return Promise.resolve(record) if association_loaded?(record) + super end def perform(records) + preload_association(records) + records.each { |record| fulfill(record, record) } + end + + private def association_loaded?(record) + record.association(association).loaded? + end + + private def preload_association(records) if ActiveRecord::VERSION::MAJOR > 3 ActiveRecord::Associations::Preloader.new.preload(records, association) else ActiveRecord::Associations::Preloader.new(records, association).run end - - records.each { |record| fulfill(record, record) } end private def validate_association unless association.is_a?(Symbol) - raise ArgumentError, 'association must be a Symbol object' + raise ArgumentError, 'Association must be a Symbol object' end unless model < ActiveRecord::Base - raise ArgumentError, 'model must be an ActiveRecord::Base descendant' + raise ArgumentError, 'Model must be an ActiveRecord::Base descendant' end return if model.reflect_on_association(association) - - raise TypeError, "association :#{association} does not exist on #{model.name}" + raise TypeError, "Association :#{association} does not exist on #{model}" end end end diff --git a/lib/graphql/preload/version.rb b/lib/graphql/preload/version.rb index 6fd26b8..2b5c303 100644 --- a/lib/graphql/preload/version.rb +++ b/lib/graphql/preload/version.rb @@ -1,5 +1,5 @@ module GraphQL module Preload - VERSION = '1.0.2'.freeze + VERSION = '1.0.3'.freeze end end