diff --git a/src/ITR/data/base_providers.py b/src/ITR/data/base_providers.py index e306daa9..2c05da75 100644 --- a/src/ITR/data/base_providers.py +++ b/src/ITR/data/base_providers.py @@ -2,7 +2,7 @@ import warnings # needed until quantile behaves better with Pint quantities in arrays from functools import partial, reduce from operator import add -from typing import Any, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type, cast import numpy as np import pandas as pd @@ -889,45 +889,32 @@ def get_company_projected_trajectories(self, company_ids: List[str], year=None) :param year: values for a specific year, or all years if None :return: A pandas DataFrame with projected intensity trajectories per company, indexed by company_id and scope """ - company_ids, scopes, projections = list( - map( - list, - zip( - *[ - ( - c.company_id, - EScope[scope_name], - c.projected_intensities[scope_name].projections, - ) - # FIXME: we should make _companies a dict so we can look things up rather than searching every time! - for c in self._companies - for scope_name in EScope.get_scopes() - if c.company_id in company_ids - if c.projected_intensities[scope_name] - ] - ), - ) - ) - if projections: - index = pd.MultiIndex.from_tuples(zip(company_ids, scopes), names=["company_id", "scope"]) - if year is not None: - if isinstance(projections[0], ICompanyEIProjectionsScopes): - values = [yvp.value for pt in projections for yvp in pt if yvp.year == year] - else: - values = list(map(lambda x: x[year].squeeze(), projections)) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # pint units don't like columns of heterogeneous data...tough! - return pd.Series(data=values, index=index, name=year) - else: - if isinstance(projections[0], ICompanyEIProjectionsScopes): - values = [{yvp.year: yvp.value for yvp in pt} for pt in projections] - else: - values = projections - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - return pd.DataFrame(data=values, index=index) - return pd.DataFrame() + c_ids: List[str] = [] + scopes: List[EScope] = [] + projections: List[DF_ICompanyEIProjections] = [] + + for c in self._companies: + if c.company_id in company_ids: + for scope_name in EScope.get_scopes(): + if c.projected_intensities[scope_name]: + c_ids.append(c.company_id) + scopes.append(EScope[scope_name]) + projections.append(c.projected_intensities[scope_name].projections) + + if len(projections) == 0: + return pd.DataFrame() + index = pd.MultiIndex.from_tuples(zip(c_ids, scopes), names=["company_id", "scope"]) + if year is not None: + values = list(map(cast(Callable[[pd.Series], Any], lambda x: x[year].squeeze()), projections)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # pint units don't like columns of heterogeneous data...tough! + return pd.Series(data=values, index=index, name=year) + else: + values = projections + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return pd.DataFrame(data=values, index=index) def get_company_projected_targets(self, company_ids: List[str], year=None) -> pd.DataFrame: """