diff --git a/Sources/Fluent/QueryBuilder/QueryBuilder+Aggregate.swift b/Sources/Fluent/QueryBuilder/QueryBuilder+Aggregate.swift index 72a931b0..93fdbf26 100644 --- a/Sources/Fluent/QueryBuilder/QueryBuilder+Aggregate.swift +++ b/Sources/Fluent/QueryBuilder/QueryBuilder+Aggregate.swift @@ -1,21 +1,33 @@ extension QueryBuilder { // MARK: Aggregate - /// Returns the sum of all entries for the supplied field, falling back to the default value if the Future containing the sum resolves to an Error. + /// Returns the sum of all entries for the supplied field. /// /// let totalLikes = try Post.query(on: conn).sum(\.likes) - /// let totalViralPostLikes = try Post.query(on: conn).filter(\.likes >= 10_000_000).sum(\.likes, default: 0) + /// + /// If a default value is supplied, it will be used when the sum's result + /// set is empty and no sum can be determined. + /// + /// let totalViralPostLikes = try Post.query(on: conn) + /// .filter(\.likes >= 10_000_000) + /// .sum(\.likes, default: 0) /// /// - parameters: /// - field: Field to sum. /// - default: Optional default to use. /// - returns: A `Future` containing the sum. public func sum(_ field: KeyPath, default: T? = nil) -> Future where T: Decodable { - return aggregate(Database.queryAggregateSum, field: field).catchMap { error in - guard let d = `default` else { - throw error + return self.count().flatMap { count in + switch count { + case 0: + if let d = `default` { + return self.connection.map { _ in d } + } else { + throw FluentError(identifier: "noSumResults", reason: "Sum query returned 0 results and no default was supplied.") + } + default: + return self.aggregate(Database.queryAggregateSum, field: field) } - return d } }