Source code for nqs_sdk.bindings.env_builder

import json
from typing import Any, List, Optional, Tuple

from nqs_sdk import MutBuilderSharedState, ProtocolFactoryAdapter, SimulationClock, Simulator, SimulatorBuilder
from nqs_sdk.bindings.protocols.protocol_infos import ProtocolInfos
from nqs_sdk.bindings.spots.spot_generator import SpotGenerator
from nqs_sdk.interfaces.protocol import Protocol
from nqs_sdk.interfaces.protocol_factory import ProtocolFactory
from nqs_sdk.interfaces.protocol_metafactory import ProtocolMetaFactory
from nqs_sdk.interfaces.tx_generator import TxGenerator
from nqs_sdk.utils.json_decimal_encoder import DecimalEncoder


class Agent:
    """A named participant of the simulation together with its starting wallet."""

    def __init__(self, name: str, wallet: dict[str, float]) -> None:
        """Store the agent's name and wallet.

        Args:
            name: Identifier of the agent.
            wallet: Mapping of token symbol to amount. It is kept by reference,
                not copied.
        """
        self.name, self.wallet = name, wallet
class SimTxGenerator(ProtocolFactory):
    """Protocol factory that contributes no protocols, only registered transaction generators."""

    def __init__(self) -> None:
        # Generators kept in registration order; ``build`` hands out this very list.
        self.tx_generators: list[TxGenerator] = []

    def id(self) -> str:
        """Return the identifier of this factory."""
        return "sim_tx_generator"

    def register_tx_generator(self, tx_generator: TxGenerator) -> None:
        """Add *tx_generator* to the generators returned by ``build``."""
        generators = self.tx_generators
        generators.append(tx_generator)

    def build(
        self,
        clock: SimulationClock,
        builder_state: MutBuilderSharedState,
        common_config: Any,
        backtest: bool,
        config: Any,
    ) -> Tuple[List[Protocol], List[TxGenerator]]:
        """Return an empty protocol list and the registered transaction generators.

        All arguments are accepted for interface compatibility and ignored.
        """
        no_protocols: List[Protocol] = []
        return no_protocols, self.tx_generators
class SimulatorEnvBuilder:
    """Incrementally assemble a simulator configuration and build a ``Simulator``.

    Typical use: register protocol meta-factories, protocols, spot generators,
    agents and transaction generators; set the simulation time range, numeraire
    and gas fee; then call :meth:`build`.
    """

    def __init__(
        self, common_args: Optional[dict[str, Any]] = None, save_config: Optional[str] = None, do_backtest: bool = False
    ) -> None:
        """Create an empty builder.

        Args:
            common_args: Initial content of the ``"common"`` config section. It
                is shallow-copied, so this builder never mutates the caller's dict.
            save_config: If truthy, path where :meth:`build` writes the generated
                JSON configuration.
            do_backtest: If True, :meth:`build` produces a
                ``backtest_environment``; otherwise a ``simulation_environment``.
        """
        self.factories: dict[str, ProtocolMetaFactory] = {}
        # Copy so that set_* calls below never leak into the caller's dict.
        self.common_args: dict[str, Any] = {} if common_args is None else dict(common_args)
        # collect_all_observables defaults to False unless explicitly provided.
        self.common_args.setdefault("collect_all_observables", False)
        # token symbol -> decimals, filled from registered protocols
        self.tokens_info: dict[str, int] = {}
        self.sim_tx_generator = SimTxGenerator()
        self.spot_generators: list[SpotGenerator] = []
        self.agents: list[Agent] = []
        self.save_config = save_config
        self.do_backtest = do_backtest

    def set_simulator_time(self, start_block: int, end_block: int, block_step_metrics: int) -> None:
        """Set the block range of the run and the metric sampling step (in blocks)."""
        self.common_args.update(
            {"block_number_start": start_block, "block_number_end": end_block, "block_step_metrics": block_step_metrics}
        )

    def set_numeraire(self, numeraire: str) -> None:
        """Set the numeraire token of the run."""
        self.common_args.update({"numeraire": numeraire})

    def set_gas_fee(self, gas_fee: float, gas_fee_ccy: Optional[str] = None) -> None:
        """Set the gas fee and, when given, its currency.

        If *gas_fee_ccy* is None, any previously set currency is left unchanged.
        :meth:`build` requires a currency that belongs to a registered protocol's tokens.
        """
        self.common_args.update({"gas_fee": gas_fee})
        if gas_fee_ccy is not None:
            self.common_args.update({"gas_fee_ccy": gas_fee_ccy})

    def register_factory(self, factory: ProtocolMetaFactory) -> None:
        """Register *factory* under ``factory.id()``; replaces a factory with the same id."""
        self.factories[factory.id()] = factory

    def register_agent(self, agent_name: str, agent_wallet: dict[str, float]) -> None:
        """Add an agent with the given name and initial wallet."""
        self.agents.append(Agent(agent_name, agent_wallet))

    def register_spot_generator(self, spot_generator: SpotGenerator) -> None:
        """Add a spot price generator."""
        self.spot_generators.append(spot_generator)

    def get_all_agents(self) -> list[Agent]:
        """Return the registered agents (the internal list, not a copy)."""
        return self.agents

    def register_protocol(self, protocol: ProtocolInfos) -> None:
        """Register *protocol* with its factory and record the decimals of its tokens.

        Raises:
            AssertionError: if the protocol's factory has not been registered.
        """
        assert protocol.factory_id in self.factories, f"Factory {protocol.factory_id} not registered"
        self.factories[protocol.factory_id].register_protocol(protocol)
        self.tokens_info.update(protocol.get_token_infos())

    def register_tx_generator(self, tx_generator: TxGenerator) -> None:
        """Register a transaction generator, exposed through the ``sim_tx_generator`` factory."""
        self.sim_tx_generator.register_tx_generator(tx_generator)

    def get_common_args(self) -> dict:
        """Return the ``"common"`` config section (the internal dict, not a copy)."""
        return self.common_args

    def _validate(self) -> None:
        """Check that everything :meth:`build` needs has been registered and is consistent."""
        args = self.common_args
        assert len(self.factories) > 0, "No factories registered"
        assert len(self.spot_generators) > 0, "No spot generators registered"
        assert "numeraire" in args, "Numeraire not registered"
        # .get(): a missing currency must fail this check, not raise a bare KeyError.
        assert args.get("gas_fee_ccy") in self.tokens_info, "Gas fee currency not registered"
        assert "gas_fee" in args, "Gas fee not registered"
        assert args["gas_fee"] >= 0, "Gas fee must be non-negative"
        assert all(key in args for key in ("block_number_start", "block_number_end", "block_step_metrics")), (
            "Simulator time not set"
        )
        assert args["block_number_start"] >= 0, "Start block must be non-negative"
        assert args["block_number_end"] >= args["block_number_start"], "End block must not be before start block"
        assert args["block_step_metrics"] > 0, "Block step metrics must be positive"

    def _backtest_environment(self) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Return the spot list and the ``backtest_environment`` section (historical replay)."""
        spot_list: list[dict[str, Any]] = [
            {"name": name, "historical": {}} for spot_generator in self.spot_generators for name in spot_generator.names
        ]
        protocols_to_replay: dict[str, Any] = {}
        for factory in self.factories.values():
            # FIXME: relies on the internal layout of each factory's config.
            for protocol_name, protocol_config in factory.get_config().items():
                # A fresh list per protocol, so that one protocol's pools never
                # end up in another protocol's replay config.
                pools = list(protocol_config["initial_state"]["historical_state"]["pools"])
                protocols_to_replay[protocol_name] = {"pools": pools}
        # Entry for the transaction-generator factory added in build().
        protocols_to_replay["sim_tx_generator"] = {}
        return spot_list, {"backtest_environment": {"protocols_to_replay": protocols_to_replay}}

    def _simulation_environment(self) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Return the generated spot list and the ``simulation_environment`` section."""
        start = self.common_args["block_number_start"]
        end = self.common_args["block_number_end"]
        spot_list: list[dict[str, Any]] = []
        for spot_generator in self.spot_generators:
            spots = spot_generator.generate_spot(start, end)
            # each generated spot is indexed as (timestamps, path)
            for name, spot in zip(spot_generator.names, spots):
                spot_list.append({"name": name, "custom": {"timestamps": spot[0], "path": spot[1]}})
        protocols_to_simulate: dict[str, Any] = {}
        for factory in self.factories.values():
            protocols_to_simulate.update(factory.get_config())
        # Entry for the transaction-generator factory added in build().
        protocols_to_simulate["sim_tx_generator"] = {}
        tokens_info = {token: {"decimals": decimals} for token, decimals in self.tokens_info.items()}
        environment = {
            "simulation_environment": {"tokens_info": tokens_info, "protocols_to_simulate": protocols_to_simulate}
        }
        return spot_list, environment

    def build(self) -> Simulator:
        """Validate the builder state, generate the JSON config and build the simulator.

        If ``save_config`` was given, the JSON configuration is also written to that path.

        Returns:
            The ``Simulator`` produced by ``SimulatorBuilder``.

        Raises:
            AssertionError: if required settings are missing or inconsistent.
        """
        self._validate()
        if self.do_backtest:
            spot_list, environment = self._backtest_environment()
        else:
            spot_list, environment = self._simulation_environment()

        config = {
            "version": "1.0.0",
            "common": self.common_args,
            "spot": {"spot_list": spot_list},
            **environment,
            "agents": [{"name": agent.name, "wallet": agent.wallet} for agent in self.agents],
        }
        config_json = json.dumps(config, indent=2, cls=DecimalEncoder)
        if self.save_config:
            with open(self.save_config, "w", encoding="utf-8") as f:
                f.write(config_json)

        builder = SimulatorBuilder.from_json(config_json)
        for factory in self.factories.values():
            for subfactory in factory.get_factories():
                if not isinstance(subfactory, ProtocolFactoryAdapter):
                    subfactory = ProtocolFactoryAdapter(subfactory)
                builder.add_factory(subfactory)
        # the tx generator factory is always added last
        builder.add_factory(ProtocolFactoryAdapter(self.sim_tx_generator))
        return builder.build()