Skip to content

Commit

Permalink
Add api get_change_ratio
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Sep 28, 2022
1 parent 7385cf8 commit 3650cfd
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 28 deletions.
4 changes: 2 additions & 2 deletions examples/research/top_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
# 每月涨幅前30,市值90%分布在100亿以下
# 重复上榜的有1/4左右
# 连续两个月上榜的1/10左右
def top_tags(data_provider="em", start_timestamp="2010-01-01"):
def top_tags(data_provider="em", start_timestamp="2020-01-01", end_timestamp="2021-01-01"):
records = []
for _, timestamp, df in get_top_performance_by_month(
start_timestamp=start_timestamp, list_days=250, data_provider=data_provider
start_timestamp=start_timestamp, end_timestamp=end_timestamp, list_days=250, data_provider=data_provider
):
for entity_id in df.index[:30]:
query_timestamp = timestamp
Expand Down
35 changes: 22 additions & 13 deletions src/zvt/api/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,20 +449,29 @@ def show_industry_composition(entity_ids, timestamp):
drawer.draw_pie(show=True)


def get_change_ratio(
entity_type="stock",
start_timestamp=None,
end_timestamp=None,
adjust_type: Union[AdjustType, str] = None,
provider="em",
):
def cal_ratio(df):
positive = df[df["direction"]]
other = df[~df["direction"]]
return len(positive) / len(other)

if not adjust_type:
adjust_type = default_adjust_type(entity_type=entity_type)
data_schema = get_kdata_schema(entity_type=entity_type, adjust_type=adjust_type)
kdata_df = data_schema.query_data(provider=provider, start_timestamp=start_timestamp, end_timestamp=end_timestamp)
kdata_df["direction"] = kdata_df["change_pct"] > 0
ratio_df = kdata_df.groupby("timestamp").apply(lambda df: cal_ratio(df))
return ratio_df


if __name__ == "__main__":
df = get_performance_stats_by_month()
print(df)
dfs = []
for timestamp, _, df in get_top_performance_by_month(start_timestamp="2012-01-01", list_days=250):
if pd_is_not_null(df):
entity_ids = df.index.tolist()
the_date = pre_month_end_date(timestamp)
show_industry_composition(entity_ids=entity_ids, timestamp=timestamp)
for entity_id in df.index:
from zvt.utils.time_utils import month_end_date, pre_month_start_date

end_date = month_end_date(pre_month_start_date(timestamp))
TechnicalFactor(entity_ids=[entity_id], end_timestamp=end_date).draw(show=True)
get_change_ratio(start_timestamp="2022-06-01")

# the __all__ is generated
__all__ = [
Expand Down
4 changes: 2 additions & 2 deletions src/zvt/contract/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ class AdjustType(Enum):
#: 不复权
bfq = "bfq"
#: pre adjusted
#: 不复权
#: 前复权
qfq = "qfq"
#: post adjusted
#: 不复权
#: 后复权
hfq = "hfq"


Expand Down
4 changes: 2 additions & 2 deletions src/zvt/contract/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def del_data(data_schema: Type[Mixin], filters: List = None, provider=None):
session.commit()


def get_one(data_schema, id: str, provider: str = None, session: Session = None):
def get_by_id(data_schema, id: str, provider: str = None, session: Session = None):
"""
get one record by id from data schema
Expand Down Expand Up @@ -661,7 +661,7 @@ def get_entity_ids(
"get_schema_columns",
"common_filter",
"del_data",
"get_one",
"get_by_id",
"get_data",
"data_exist",
"get_data_count",
Expand Down
4 changes: 2 additions & 2 deletions src/zvt/contract/base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class OneStateService(StatefulService):

def __init__(self) -> None:
super().__init__()
self.state_domain = self.state_schema.get_one(id=self.name)
self.state_domain = self.state_schema.get_by_id(id=self.name)
if self.state_domain:
self.state: dict = self.decode_state(self.state_domain.state)
else:
Expand Down Expand Up @@ -106,7 +106,7 @@ def persist_state(self, entity_id):
state = self.states.get(entity_id)
if state:
domain_id = f"{self.name}_{entity_id}"
state_domain = self.state_schema.get_one(domain_id)
state_domain = self.state_schema.get_by_id(domain_id)
state_str = self.encode_state(state)
if not state_domain:
state_domain = self.state_schema(id=domain_id, entity_id=entity_id, state_name=self.name)
Expand Down
6 changes: 3 additions & 3 deletions src/zvt/contract/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ def test_data_correctness(cls, provider, data_samples):
assert item[0][k] == data[k]

@classmethod
def get_one(cls, id, provider_index: int = 0, provider: str = None):
from .api import get_one
def get_by_id(cls, id, provider_index: int = 0, provider: str = None):
from .api import get_by_id

if not provider:
provider = cls.providers[provider_index]
return get_one(data_schema=cls, id=id, provider=provider)
return get_by_id(data_schema=cls, id=id, provider=provider)

@classmethod
def query_data(
Expand Down
2 changes: 1 addition & 1 deletion src/zvt/tag/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_tag_timestamps(self):
def get_tag_domain(self, entity_id, timestamp, **fill_kv):
the_date = to_time_str(timestamp, fmt=TIME_FORMAT_DAY)
the_id = f"{entity_id}_{the_date}"
the_domain = self.data_schema.get_one(id=the_id)
the_domain = self.data_schema.get_by_id(id=the_id)

if the_domain:
for k, v in fill_kv.items():
Expand Down
3 changes: 0 additions & 3 deletions src/zvt/trader/trader.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,6 @@ def on_finish(self, timestmap):
self.on_trading_finish(timestmap)
# show the result
if self.draw_result:
import plotly.io as pio

pio.renderers.default = "browser"
reader = AccountStatsReader(trader_names=[self.trader_name])
df = reader.data_df
drawer = Drawer(
Expand Down

0 comments on commit 3650cfd

Please sign in to comment.