上一主题“为K线图表添砖加瓦——让CTA策略的运行看得见”可以说是失败了。原因已经找到:是因为掉到类变量和实例变量的坑里了。具体过程参见这个主题:
https://www.vnpy.com/forum/topic/3860-wei-kxian-tu-biao-tian-zhuan-jia-wa-rang-ctace-lue-de-yun-xing-kan-de-jian
1 解决了不同合约同时运行_kx_strategy策略时,K线图会互相影响的问题;
2 去掉策略管理器中的“K线图表”按钮,保持与原来的界面一致;在_kx_strategy策略中增加一个show_chart参数,如果想显示K线图,就将它配置为True,否则不会显示K线图;
3 增加了在策略被移除时删除该策略K线图表的功能;
4 K线图表中的显示内容在_kx_strategy策略中配置,而不是一个固定的主图和附图搭配。参照我的init_kx_chart()方法,您也可以为自己的策略配置自己的K线主图和附图指标;
5 添加了最后一根临时K线的显示。
本次修改涉及以下文件:
vnpy\app\cta_strategy\base.py
vnpy\app\cta_strategy\engine.py
vnpy\app\cta_strategy\ui\widget.py
vnpy\app\cta_backtester\engine.py
"""
Defines constants and objects used in CtaStrategy App.
"""
from dataclasses import dataclass, field
from enum import Enum
from datetime import timedelta
from vnpy.trader.constant import Direction, Offset, Interval
# Name under which the CTA engine is registered with the main engine.
APP_NAME: str = "CtaStrategy"
# Prefix of locally generated stop-order ids ("STOP.<n>").
STOPORDER_PREFIX: str = "STOP"
class StopOrderStatus(Enum):
    """Lifecycle state of a stop order; the value is the display text."""

    # Created and waiting for its trigger price.
    WAITING = "等待中"
    # Cancelled (server stop orders that are rejected also map here).
    CANCELLED = "已撤销"
    # Triggered: the limit order(s) have been sent.
    TRIGGERED = "已触发"
class EngineType(Enum):
    """Kind of engine a strategy is running under; the value is the display text."""

    LIVE = "实盘"         # live trading
    BACKTESTING = "回测"  # backtesting
class BacktestingMode(Enum):
    """Data granularity used by the backtester (as named; the consumer is not in this file)."""

    BAR = 1
    TICK = 2
@dataclass
class StopOrder:
    """
    A stop order owned by one strategy.

    ``vt_orderids`` holds the ids of the order(s) actually sent for it:
    the limit orders sent when a local stop order triggers, or the server
    order id for a server-side stop order.
    """

    # Required fields, in constructor order.
    vt_symbol: str
    direction: Direction
    offset: Offset
    price: float
    volume: float
    stop_orderid: str
    strategy_name: str
    # Optional fields.
    lock: bool = False  # forwarded to the offset converter when sending
    vt_orderids: list = field(default_factory=list)
    status: StopOrderStatus = StopOrderStatus.WAITING
# Event types put on the event engine by the CTA app.
EVENT_CTA_LOG = "eCtaLog"              # LogData for the CTA log monitor
EVENT_CTA_STRATEGY = "eCtaStrategy"    # strategy status/data dict
EVENT_CTA_STOPORDER = "eCtaStopOrder"  # StopOrder updates

# hxxjava add: extra events for the K-line chart; their publishers are not
# part of this file — confirm where they are emitted.
EVENT_CTA_TICK = "eCtaTick"
EVENT_CTA_HISTORY_BAR = "eCtaHistoryBar"
EVENT_CTA_BAR = "eCtaBar"
EVENT_CTA_ORDER = "eCtaOrder"
EVENT_CTA_TRADE = "eCtaTrade"
# Length of one period for each supported Interval.
INTERVAL_DELTA_MAP = dict([
    (Interval.MINUTE, timedelta(minutes=1)),
    (Interval.HOUR, timedelta(hours=1)),
    (Interval.DAILY, timedelta(days=1)),
])
""""""
import importlib
import os
import traceback
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from tzlocal import get_localzone
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.object import (
OrderRequest,
SubscribeRequest,
HistoryRequest,
LogData,
TickData,
BarData,
ContractData
)
from vnpy.trader.event import (
EVENT_TICK,
EVENT_ORDER,
EVENT_TRADE,
EVENT_POSITION
)
from vnpy.trader.constant import (
Direction,
OrderType,
Interval,
Exchange,
Offset,
Status
)
from vnpy.trader.utility import load_json, save_json, extract_vt_symbol, round_to
from vnpy.trader.database import database_manager
from vnpy.trader.rqdata import rqdata_client
from vnpy.trader.converter import OffsetConverter
from .base import (
APP_NAME,
EVENT_CTA_LOG,
EVENT_CTA_STRATEGY,
EVENT_CTA_STOPORDER,
EngineType,
StopOrder,
StopOrderStatus,
STOPORDER_PREFIX
)
from .template import CtaTemplate
# Translate the order Status of a server-side stop order into the
# StopOrderStatus reported to strategies via on_stop_order.
STOP_STATUS_MAP = dict([
    (Status.SUBMITTING, StopOrderStatus.WAITING),
    (Status.NOTTRADED, StopOrderStatus.WAITING),
    (Status.PARTTRADED, StopOrderStatus.TRIGGERED),
    (Status.ALLTRADED, StopOrderStatus.TRIGGERED),
    (Status.CANCELLED, StopOrderStatus.CANCELLED),
    (Status.REJECTED, StopOrderStatus.CANCELLED),
])
class CtaEngine(BaseEngine):
    """
    Live-trading engine of the CtaStrategy app.

    Loads strategy classes plus their persisted settings and variables,
    routes tick/order/trade events to strategies, manages local stop orders
    and offers order-sending and history-loading helpers to strategies.
    """

    engine_type = EngineType.LIVE  # live trading engine

    setting_filename = "cta_strategy_setting.json"
    data_filename = "cta_strategy_data.json"

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        """Create empty registries; nothing is loaded until init_engine()."""
        super(CtaEngine, self).__init__(
            main_engine, event_engine, APP_NAME)

        self.strategy_setting = {}  # strategy_name: dict
        self.strategy_data = {}  # strategy_name: dict

        self.classes = {}  # class_name: stategy_class
        self.strategies = {}  # strategy_name: strategy

        self.symbol_strategy_map = defaultdict(
            list)  # vt_symbol: strategy list
        self.orderid_strategy_map = {}  # vt_orderid: strategy
        self.strategy_orderid_map = defaultdict(
            set)  # strategy_name: orderid list

        self.stop_order_count = 0  # for generating stop_orderid
        self.stop_orders = {}  # stop_orderid: stop_order

        # Single worker so strategy initialisations run one at a time.
        self.init_executor = ThreadPoolExecutor(max_workers=1)

        self.rq_client = None
        self.rq_symbols = set()

        self.vt_tradeids = set()  # for filtering duplicate trade

        self.offset_converter = OffsetConverter(self.main_engine)

    def init_engine(self):
        """
        Init RQData, load strategy classes, settings and saved data,
        then register event handlers.
        """
        self.init_rqdata()
        self.load_strategy_class()
        self.load_strategy_setting()
        self.load_strategy_data()
        self.register_event()
        self.write_log("CTA策略引擎初始化成功")

    def close(self):
        """Stop every strategy when the engine is closed."""
        self.stop_all_strategies()

    def register_event(self):
        """Subscribe to tick, order, trade and position events."""
        self.event_engine.register(EVENT_TICK, self.process_tick_event)
        self.event_engine.register(EVENT_ORDER, self.process_order_event)
        self.event_engine.register(EVENT_TRADE, self.process_trade_event)
        self.event_engine.register(EVENT_POSITION, self.process_position_event)

    def init_rqdata(self):
        """
        Init RQData client.
        """
        result = rqdata_client.init()
        if result:
            self.write_log("RQData数据接口初始化成功")

    def query_bar_from_rq(
        self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
    ):
        """
        Query bar data from RQData.
        """
        req = HistoryRequest(
            symbol=symbol,
            exchange=exchange,
            interval=interval,
            start=start,
            end=end
        )
        data = rqdata_client.query_history(req)
        return data

    def process_tick_event(self, event: Event):
        """Check local stop orders, then push the tick to inited strategies."""
        tick = event.data

        strategies = self.symbol_strategy_map[tick.vt_symbol]
        if not strategies:
            return

        self.check_stop_order(tick)

        for strategy in strategies:
            if strategy.inited:
                self.call_strategy_func(strategy, strategy.on_tick, tick)

    def process_order_event(self, event: Event):
        """Update the offset converter and forward the order to its strategy."""
        order = event.data

        self.offset_converter.update_order(order)

        strategy = self.orderid_strategy_map.get(order.vt_orderid, None)
        if not strategy:
            return

        # Remove vt_orderid if order is no longer active.
        vt_orderids = self.strategy_orderid_map[strategy.strategy_name]
        if order.vt_orderid in vt_orderids and not order.is_active():
            vt_orderids.remove(order.vt_orderid)

        # For server stop order, call strategy on_stop_order function
        if order.type == OrderType.STOP:
            so = StopOrder(
                vt_symbol=order.vt_symbol,
                direction=order.direction,
                offset=order.offset,
                price=order.price,
                volume=order.volume,
                stop_orderid=order.vt_orderid,
                strategy_name=strategy.strategy_name,
                status=STOP_STATUS_MAP[order.status],
                vt_orderids=[order.vt_orderid],
            )
            self.call_strategy_func(strategy, strategy.on_stop_order, so)

        # Call strategy on_order function
        self.call_strategy_func(strategy, strategy.on_order, order)

    def process_trade_event(self, event: Event):
        """
        Update the owning strategy's pos from a trade, call on_trade,
        persist its variables and refresh the GUI.
        """
        trade = event.data

        # Filter duplicate trade push
        if trade.vt_tradeid in self.vt_tradeids:
            return
        self.vt_tradeids.add(trade.vt_tradeid)

        self.offset_converter.update_trade(trade)

        strategy = self.orderid_strategy_map.get(trade.vt_orderid, None)
        if not strategy:
            return

        # Update strategy pos before calling on_trade method
        if trade.direction == Direction.LONG:
            strategy.pos += trade.volume
        else:
            strategy.pos -= trade.volume

        self.call_strategy_func(strategy, strategy.on_trade, trade)

        # Sync strategy variables to data file
        self.sync_strategy_data(strategy)

        # Update GUI
        self.put_strategy_event(strategy)

    def process_position_event(self, event: Event):
        """Feed position updates into the offset converter."""
        position = event.data

        self.offset_converter.update_position(position)

    def check_stop_order(self, tick: TickData):
        """
        Trigger local stop orders of tick.vt_symbol whose price was crossed.

        A triggered stop order is sent as a limit order at the limit-up/down
        price (or level-5 ask/bid when no limit price is available). Only
        when that limit order is placed is the stop order removed and
        reported as TRIGGERED.
        """
        for stop_order in list(self.stop_orders.values()):
            if stop_order.vt_symbol != tick.vt_symbol:
                continue

            long_triggered = (
                stop_order.direction == Direction.LONG and tick.last_price >= stop_order.price
            )
            short_triggered = (
                stop_order.direction == Direction.SHORT and tick.last_price <= stop_order.price
            )
            if not (long_triggered or short_triggered):
                continue

            strategy = self.strategies[stop_order.strategy_name]

            # To get excuted immediately after stop order is
            # triggered, use limit price if available, otherwise
            # use ask_price_5 or bid_price_5
            if stop_order.direction == Direction.LONG:
                if tick.limit_up:
                    price = tick.limit_up
                else:
                    price = tick.ask_price_5
            else:
                if tick.limit_down:
                    price = tick.limit_down
                else:
                    price = tick.bid_price_5

            contract = self.main_engine.get_contract(stop_order.vt_symbol)
            # send_limit_order dereferences the contract; without this guard
            # a missing contract would raise out of the tick handler. The
            # stop order stays WAITING and is retried on the next tick.
            if not contract:
                self.write_log(
                    f"停止单触发失败,找不到合约{stop_order.vt_symbol}", strategy
                )
                continue

            vt_orderids = self.send_limit_order(
                strategy,
                contract,
                stop_order.direction,
                stop_order.offset,
                price,
                stop_order.volume,
                stop_order.lock
            )

            # Update stop order status if placed successfully
            if vt_orderids:
                # Remove from relation map.
                self.stop_orders.pop(stop_order.stop_orderid)

                strategy_vt_orderids = self.strategy_orderid_map[strategy.strategy_name]
                if stop_order.stop_orderid in strategy_vt_orderids:
                    strategy_vt_orderids.remove(stop_order.stop_orderid)

                # Change stop order status to triggered and update to strategy.
                stop_order.status = StopOrderStatus.TRIGGERED
                stop_order.vt_orderids = vt_orderids

                self.call_strategy_func(
                    strategy, strategy.on_stop_order, stop_order
                )
                self.put_stop_order_event(stop_order)

    def send_server_order(
        self,
        strategy: CtaTemplate,
        contract: ContractData,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
        type: OrderType,
        lock: bool
    ):
        """
        Send a new order to server.

        The request is split by the offset converter (lock/close-today
        handling); returns the list of vt_orderids that were accepted.
        """
        # Create request and send order.
        original_req = OrderRequest(
            symbol=contract.symbol,
            exchange=contract.exchange,
            direction=direction,
            offset=offset,
            type=type,
            price=price,
            volume=volume,
        )

        # Convert with offset converter
        req_list = self.offset_converter.convert_order_request(original_req, lock)

        # Send Orders
        vt_orderids = []

        for req in req_list:
            req.reference = strategy.strategy_name  # Add strategy name as order reference

            vt_orderid = self.main_engine.send_order(
                req, contract.gateway_name)

            # Check if sending order successful
            if not vt_orderid:
                continue

            vt_orderids.append(vt_orderid)

            self.offset_converter.update_order_request(req, vt_orderid)

            # Save relationship between orderid and strategy.
            self.orderid_strategy_map[vt_orderid] = strategy
            self.strategy_orderid_map[strategy.strategy_name].add(vt_orderid)

        return vt_orderids

    def send_limit_order(
        self,
        strategy: CtaTemplate,
        contract: ContractData,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
        lock: bool
    ):
        """
        Send a limit order to server.
        """
        return self.send_server_order(
            strategy,
            contract,
            direction,
            offset,
            price,
            volume,
            OrderType.LIMIT,
            lock
        )

    def send_server_stop_order(
        self,
        strategy: CtaTemplate,
        contract: ContractData,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
        lock: bool
    ):
        """
        Send a stop order to server.

        Should only be used if stop order supported
        on the trading server.
        """
        return self.send_server_order(
            strategy,
            contract,
            direction,
            offset,
            price,
            volume,
            OrderType.STOP,
            lock
        )

    def send_local_stop_order(
        self,
        strategy: CtaTemplate,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
        lock: bool
    ):
        """
        Create a new local stop order.

        The order is kept in self.stop_orders and checked on every tick by
        check_stop_order(). Returns a one-element list with its id.
        """
        self.stop_order_count += 1
        stop_orderid = f"{STOPORDER_PREFIX}.{self.stop_order_count}"

        stop_order = StopOrder(
            vt_symbol=strategy.vt_symbol,
            direction=direction,
            offset=offset,
            price=price,
            volume=volume,
            stop_orderid=stop_orderid,
            strategy_name=strategy.strategy_name,
            lock=lock
        )

        self.stop_orders[stop_orderid] = stop_order

        vt_orderids = self.strategy_orderid_map[strategy.strategy_name]
        vt_orderids.add(stop_orderid)

        self.call_strategy_func(strategy, strategy.on_stop_order, stop_order)
        self.put_stop_order_event(stop_order)

        return [stop_orderid]

    def cancel_server_order(self, strategy: CtaTemplate, vt_orderid: str):
        """
        Cancel existing order by vt_orderid.
        """
        order = self.main_engine.get_order(vt_orderid)
        if not order:
            self.write_log(f"撤单失败,找不到委托{vt_orderid}", strategy)
            return

        req = order.create_cancel_request()
        self.main_engine.cancel_order(req, order.gateway_name)

    def cancel_local_stop_order(self, strategy: CtaTemplate, stop_orderid: str):
        """
        Cancel a local stop order.

        Unknown ids are ignored. The owning strategy is looked up from the
        stop order itself.
        """
        stop_order = self.stop_orders.get(stop_orderid, None)
        if not stop_order:
            return
        strategy = self.strategies[stop_order.strategy_name]

        # Remove from relation map.
        self.stop_orders.pop(stop_orderid)

        vt_orderids = self.strategy_orderid_map[strategy.strategy_name]
        if stop_orderid in vt_orderids:
            vt_orderids.remove(stop_orderid)

        # Change stop order status to cancelled and update to strategy.
        stop_order.status = StopOrderStatus.CANCELLED

        self.call_strategy_func(strategy, strategy.on_stop_order, stop_order)
        self.put_stop_order_event(stop_order)

    def send_order(
        self,
        strategy: CtaTemplate,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
        stop: bool,
        lock: bool
    ):
        """
        Send an order for the strategy's vt_symbol.

        Price and volume are rounded to the contract's pricetick/min_volume.
        Stop orders go to the server if the contract supports them,
        otherwise they are kept locally. Returns a list of ids, or "" when
        the contract is unknown.
        """
        contract = self.main_engine.get_contract(strategy.vt_symbol)
        if not contract:
            self.write_log(f"委托失败,找不到合约:{strategy.vt_symbol}", strategy)
            return ""

        # Round order price and volume to nearest incremental value
        price = round_to(price, contract.pricetick)
        volume = round_to(volume, contract.min_volume)

        if stop:
            if contract.stop_supported:
                return self.send_server_stop_order(strategy, contract, direction, offset, price, volume, lock)
            else:
                return self.send_local_stop_order(strategy, direction, offset, price, volume, lock)
        else:
            return self.send_limit_order(strategy, contract, direction, offset, price, volume, lock)

    def cancel_order(self, strategy: CtaTemplate, vt_orderid: str):
        """Cancel a local stop order (STOP-prefixed id) or a server order."""
        if vt_orderid.startswith(STOPORDER_PREFIX):
            self.cancel_local_stop_order(strategy, vt_orderid)
        else:
            self.cancel_server_order(strategy, vt_orderid)

    def cancel_all(self, strategy: CtaTemplate):
        """
        Cancel all active orders of a strategy.
        """
        vt_orderids = self.strategy_orderid_map[strategy.strategy_name]
        if not vt_orderids:
            return

        # Iterate a copy: cancelling a local stop order mutates the set.
        for vt_orderid in copy(vt_orderids):
            self.cancel_order(strategy, vt_orderid)

    def get_engine_type(self):
        """Return EngineType.LIVE."""
        return self.engine_type

    def get_pricetick(self, strategy: CtaTemplate):
        """
        Return contract pricetick data.
        """
        contract = self.main_engine.get_contract(strategy.vt_symbol)

        if contract:
            return contract.pricetick
        else:
            return None

    def load_bar(
        self,
        vt_symbol: str,
        days: int,
        interval: Interval,
        callback: Callable[[BarData], None],
        use_database: bool
    ):
        """
        Load the last *days* days of bars and feed them to *callback*.

        Unless use_database is True, try the gateway (if the contract
        supports history data) or else RQData. The local database is used
        whenever nothing was obtained that way.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now(get_localzone())
        start = end - timedelta(days)
        bars = []

        # Pass gateway and RQData if use_database set to True
        if not use_database:
            # Query bars from gateway if available
            contract = self.main_engine.get_contract(vt_symbol)

            if contract and contract.history_data:
                req = HistoryRequest(
                    symbol=symbol,
                    exchange=exchange,
                    interval=interval,
                    start=start,
                    end=end
                )
                bars = self.main_engine.query_history(req, contract.gateway_name)

            # Try to query bars from RQData, if not found, load from database.
            else:
                bars = self.query_bar_from_rq(symbol, exchange, interval, start, end)

        if not bars:
            bars = database_manager.load_bar_data(
                symbol=symbol,
                exchange=exchange,
                interval=interval,
                start=start,
                end=end,
            )

        for bar in bars:
            callback(bar)

    def load_tick(
        self,
        vt_symbol: str,
        days: int,
        callback: Callable[[TickData], None]
    ):
        """Load the last *days* days of ticks from the database into *callback*."""
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now()
        start = end - timedelta(days)

        ticks = database_manager.load_tick_data(
            symbol=symbol,
            exchange=exchange,
            start=start,
            end=end,
        )

        for tick in ticks:
            callback(tick)

    def call_strategy_func(
        self, strategy: CtaTemplate, func: Callable, params: Any = None
    ):
        """
        Call function of a strategy and catch any exception raised.

        ``func(params)`` is called when params is given, ``func()`` when it
        is None. Any exception stops the strategy (trading and inited set
        to False) and the traceback is logged.
        """
        try:
            # Compare with None so falsy arguments are still passed through.
            if params is not None:
                func(params)
            else:
                func()
        except Exception:
            strategy.trading = False
            strategy.inited = False

            msg = f"触发异常已停止\n{traceback.format_exc()}"
            self.write_log(msg, strategy)

    def add_strategy(
        self, class_name: str, strategy_name: str, vt_symbol: str, setting: dict
    ):
        """
        Add a new strategy.

        Fails (logs and returns) on a duplicate name or unknown class.
        """
        if strategy_name in self.strategies:
            self.write_log(f"创建策略失败,存在重名{strategy_name}")
            return

        strategy_class = self.classes.get(class_name, None)
        if not strategy_class:
            self.write_log(f"创建策略失败,找不到策略类{class_name}")
            return

        strategy = strategy_class(self, strategy_name, vt_symbol, setting)
        self.strategies[strategy_name] = strategy

        # Add vt_symbol to strategy map.
        strategies = self.symbol_strategy_map[vt_symbol]
        strategies.append(strategy)

        # Update to setting file.
        self.update_strategy_setting(strategy_name, setting)

        self.put_strategy_event(strategy)

    def init_strategy(self, strategy_name: str):
        """
        Init a strategy (queued on the single-worker init executor).
        """
        self.init_executor.submit(self._init_strategy, strategy_name)

    def _init_strategy(self, strategy_name: str):
        """
        Init strategies in queue.

        Runs on_init, restores saved variables, subscribes market data and
        marks the strategy as inited.
        """
        strategy = self.strategies[strategy_name]

        if strategy.inited:
            self.write_log(f"{strategy_name}已经完成初始化,禁止重复操作")
            return

        self.write_log(f"{strategy_name}开始执行初始化")

        # Call on_init function of strategy
        self.call_strategy_func(strategy, strategy.on_init)

        # Restore strategy data(variables). Compare with None so that
        # saved falsy values (0, False, "") are restored too.
        data = self.strategy_data.get(strategy_name, None)
        if data:
            for name in strategy.variables:
                value = data.get(name, None)
                if value is not None:
                    setattr(strategy, name, value)

        # Subscribe market data
        contract = self.main_engine.get_contract(strategy.vt_symbol)
        if contract:
            req = SubscribeRequest(
                symbol=contract.symbol, exchange=contract.exchange)
            self.main_engine.subscribe(req, contract.gateway_name)
        else:
            self.write_log(f"行情订阅失败,找不到合约{strategy.vt_symbol}", strategy)

        # Put event to update init completed status.
        # NOTE(review): this runs even if on_init raised (call_strategy_func
        # reset inited to False) — consider checking before marking inited.
        strategy.inited = True
        self.put_strategy_event(strategy)
        self.write_log(f"{strategy_name}初始化完成")

    def start_strategy(self, strategy_name: str):
        """
        Start a strategy (must be inited and not already trading).
        """
        strategy = self.strategies[strategy_name]
        if not strategy.inited:
            self.write_log(f"策略{strategy.strategy_name}启动失败,请先初始化")
            return

        if strategy.trading:
            self.write_log(f"{strategy_name}已经启动,请勿重复操作")
            return

        self.call_strategy_func(strategy, strategy.on_start)
        strategy.trading = True

        self.put_strategy_event(strategy)

    def stop_strategy(self, strategy_name: str):
        """
        Stop a strategy: call on_stop, cancel its orders and save variables.
        """
        strategy = self.strategies[strategy_name]
        if not strategy.trading:
            return

        # Call on_stop function of the strategy
        self.call_strategy_func(strategy, strategy.on_stop)

        # Change trading status of strategy to False
        strategy.trading = False

        # Cancel all orders of the strategy
        self.cancel_all(strategy)

        # Sync strategy variables to data file
        self.sync_strategy_data(strategy)

        # Update GUI
        self.put_strategy_event(strategy)

    def edit_strategy(self, strategy_name: str, setting: dict):
        """
        Edit parameters of a strategy and persist the new setting.
        """
        strategy = self.strategies[strategy_name]
        strategy.update_setting(setting)

        self.update_strategy_setting(strategy_name, setting)
        self.put_strategy_event(strategy)

    def remove_strategy(self, strategy_name: str):
        """
        Remove a strategy.

        Returns True on success; returns None (after logging) when the
        strategy is still trading.
        """
        strategy = self.strategies[strategy_name]
        if strategy.trading:
            self.write_log(f"策略{strategy.strategy_name}移除失败,请先停止")
            return

        # Remove setting
        self.remove_strategy_setting(strategy_name)

        # Remove from symbol strategy map
        strategies = self.symbol_strategy_map[strategy.vt_symbol]
        strategies.remove(strategy)

        # Remove from active orderid map
        if strategy_name in self.strategy_orderid_map:
            vt_orderids = self.strategy_orderid_map.pop(strategy_name)

            # Remove vt_orderid strategy map
            for vt_orderid in vt_orderids:
                if vt_orderid in self.orderid_strategy_map:
                    self.orderid_strategy_map.pop(vt_orderid)

        # Remove from strategies
        self.strategies.pop(strategy_name)

        return True

    def load_strategy_class(self):
        """
        Load strategy class from source code (built-in and cwd "strategies").
        """
        path1 = Path(__file__).parent.joinpath("strategies")
        self.load_strategy_class_from_folder(
            path1, "vnpy.app.cta_strategy.strategies")

        path2 = Path.cwd().joinpath("strategies")
        self.load_strategy_class_from_folder(path2, "strategies")

    def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
        """
        Load strategy class from certain folder.

        Every .py/.pyd/.so file is imported as ``<module_name>.<stem>``,
        where stem is the part before the first dot.
        """
        for dirpath, dirnames, filenames in os.walk(str(path)):
            for filename in filenames:
                if filename.split(".")[-1] in ("py", "pyd", "so"):
                    strategy_module_name = ".".join([module_name, filename.split(".")[0]])
                    self.load_strategy_class_from_module(strategy_module_name)

    def load_strategy_class_from_module(self, module_name: str):
        """
        Load strategy class from module file.

        Registers every CtaTemplate subclass found in the module. Import
        failures are logged instead of raised.
        """
        try:
            module = importlib.import_module(module_name)

            for name in dir(module):
                value = getattr(module, name)
                if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate):
                    self.classes[value.__name__] = value
        except:  # noqa
            msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
            self.write_log(msg)

    def load_strategy_data(self):
        """
        Load strategy data from json file.
        """
        self.strategy_data = load_json(self.data_filename)

    def sync_strategy_data(self, strategy: CtaTemplate):
        """
        Sync strategy data into json file.
        """
        data = strategy.get_variables()
        data.pop("inited")  # Strategy status (inited, trading) should not be synced.
        data.pop("trading")

        self.strategy_data[strategy.strategy_name] = data
        save_json(self.data_filename, self.strategy_data)

    def get_all_strategy_class_names(self):
        """
        Return names of strategy classes loaded.
        """
        return list(self.classes.keys())

    def get_strategy_class_parameters(self, class_name: str):
        """
        Get default parameters of a strategy class.
        """
        strategy_class = self.classes[class_name]
        return {
            name: getattr(strategy_class, name)
            for name in strategy_class.parameters
        }

    def get_strategy_parameters(self, strategy_name):
        """
        Get parameters of a strategy.
        """
        strategy = self.strategies[strategy_name]
        return strategy.get_parameters()

    def init_all_strategies(self):
        """Queue initialisation of every strategy."""
        for strategy_name in self.strategies.keys():
            self.init_strategy(strategy_name)

    def start_all_strategies(self):
        """Start every strategy."""
        for strategy_name in self.strategies.keys():
            self.start_strategy(strategy_name)

    def stop_all_strategies(self):
        """Stop every strategy."""
        for strategy_name in self.strategies.keys():
            self.stop_strategy(strategy_name)

    def load_strategy_setting(self):
        """
        Load setting file and add the strategies it describes.
        """
        self.strategy_setting = load_json(self.setting_filename)

        for strategy_name, strategy_config in self.strategy_setting.items():
            self.add_strategy(
                strategy_config["class_name"],
                strategy_name,
                strategy_config["vt_symbol"],
                strategy_config["setting"]
            )

    def update_strategy_setting(self, strategy_name: str, setting: dict):
        """
        Update setting file.
        """
        strategy = self.strategies[strategy_name]

        self.strategy_setting[strategy_name] = {
            "class_name": strategy.__class__.__name__,
            "vt_symbol": strategy.vt_symbol,
            "setting": setting,
        }
        save_json(self.setting_filename, self.strategy_setting)

    def remove_strategy_setting(self, strategy_name: str):
        """
        Remove a strategy from the setting file (no-op if absent).
        """
        if strategy_name not in self.strategy_setting:
            return

        self.strategy_setting.pop(strategy_name)
        save_json(self.setting_filename, self.strategy_setting)

    def put_stop_order_event(self, stop_order: StopOrder):
        """
        Put an event to update stop order status.
        """
        event = Event(EVENT_CTA_STOPORDER, stop_order)
        self.event_engine.put(event)

    def put_strategy_event(self, strategy: CtaTemplate):
        """
        Put an event to update strategy status.
        """
        data = strategy.get_data()
        event = Event(EVENT_CTA_STRATEGY, data)
        self.event_engine.put(event)

    # ----------------------------------------------------------------------
    def get_position_detail(self, vt_symbol: str):
        """
        Return the offset converter's position holding for *vt_symbol*.

        The holding is expected to expose long/short pos, pnl, yd/td,
        price, frozen and active_orders attributes. If the converter raises,
        the error is logged and a placeholder with the same attribute names,
        all zero/empty, is returned instead.
        """
        from types import SimpleNamespace

        try:
            return self.offset_converter.get_position_holding(vt_symbol)
        except Exception:
            # Do not call get_position_holding() again here: it would raise
            # the same error from inside the handler.
            self.write_log(
                f"获取{vt_symbol}持仓信息失败,等待获取持仓信息\n{traceback.format_exc()}"
            )
            return SimpleNamespace(
                active_orders={},
                long_pos=0,
                long_pnl=0,
                long_yd=0,
                long_td=0,
                long_pos_frozen=0,
                long_price=0,
                short_pos=0,
                short_pnl=0,
                short_yd=0,
                short_td=0,
                short_price=0,
                short_pos_frozen=0,
            )

    def write_log(self, msg: str, strategy: CtaTemplate = None):
        """
        Create cta engine log event (prefixed with the strategy name if given).
        """
        if strategy:
            msg = f"{strategy.strategy_name}: {msg}"

        log = LogData(msg=msg, gateway_name="CtaStrategy")
        event = Event(type=EVENT_CTA_LOG, data=log)
        self.event_engine.put(event)

    def send_email(self, msg: str, strategy: CtaTemplate = None):
        """
        Send email to default receiver.
        """
        if strategy:
            subject = f"{strategy.strategy_name}"
        else:
            subject = "CTA策略引擎"

        self.main_engine.send_email(subject, msg)
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtCore, QtGui, QtWidgets
from vnpy.trader.ui.widget import (
BaseCell,
EnumCell,
MsgCell,
TimeCell,
BaseMonitor
)
from ..base import (
APP_NAME,
EVENT_CTA_LOG,
EVENT_CTA_STOPORDER,
EVENT_CTA_STRATEGY,
)
from ..engine import CtaEngine
from vnpy.usertools.kx_chart import NewChartWidget # hxxjava add
class CtaManager(QtWidgets.QWidget):
    """
    Main widget of the CtaStrategy app.

    Top bar: strategy-class selector, "add" button and global
    init/start/stop/clear-log buttons. Body: a scrollable column of
    StrategyManager panels (one per strategy) next to the stop-order and
    log monitors.
    """

    signal_log = QtCore.pyqtSignal(Event)
    signal_strategy = QtCore.pyqtSignal(Event)

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        super(CtaManager, self).__init__()

        self.main_engine = main_engine
        self.event_engine = event_engine
        self.cta_engine = main_engine.get_engine(APP_NAME)

        # strategy_name -> StrategyManager panel
        self.managers = {}

        self.init_ui()
        # Handlers are registered before init_engine(), which adds the saved
        # strategies and so puts EVENT_CTA_STRATEGY events.
        self.register_event()
        self.cta_engine.init_engine()
        self.update_class_combo()

    def init_ui(self):
        """Create the child widgets and lay them out."""
        self.setWindowTitle("CTA策略")

        self.class_combo = QtWidgets.QComboBox()

        def make_button(text, slot):
            button = QtWidgets.QPushButton(text)
            button.clicked.connect(slot)
            return button

        engine = self.cta_engine
        add_button = make_button("添加策略", self.add_strategy)
        init_button = make_button("全部初始化", engine.init_all_strategies)
        start_button = make_button("全部启动", engine.start_all_strategies)
        stop_button = make_button("全部停止", engine.stop_all_strategies)
        clear_button = make_button("清空日志", self.clear_log)

        # Strategy panels are inserted above this trailing stretch.
        self.scroll_layout = QtWidgets.QVBoxLayout()
        self.scroll_layout.addStretch()

        scroll_widget = QtWidgets.QWidget()
        scroll_widget.setLayout(self.scroll_layout)

        scroll_area = QtWidgets.QScrollArea()
        scroll_area.setWidgetResizable(True)
        scroll_area.setWidget(scroll_widget)

        self.log_monitor = LogMonitor(self.main_engine, self.event_engine)
        self.stop_order_monitor = StopOrderMonitor(self.main_engine, self.event_engine)

        top_bar = QtWidgets.QHBoxLayout()
        top_bar.addWidget(self.class_combo)
        top_bar.addWidget(add_button)
        top_bar.addStretch()
        for button in (init_button, start_button, stop_button, clear_button):
            top_bar.addWidget(button)

        # Left: strategies (spans both rows); right: stop orders over log.
        grid = QtWidgets.QGridLayout()
        grid.addWidget(scroll_area, 0, 0, 2, 1)
        grid.addWidget(self.stop_order_monitor, 0, 1)
        grid.addWidget(self.log_monitor, 1, 1)

        main_layout = QtWidgets.QVBoxLayout()
        main_layout.addLayout(top_bar)
        main_layout.addLayout(grid)
        self.setLayout(main_layout)

    def update_class_combo(self):
        """Fill the class selector with every loaded strategy class name."""
        names = self.cta_engine.get_all_strategy_class_names()
        self.class_combo.addItems(names)

    def register_event(self):
        """Route EVENT_CTA_STRATEGY to the GUI thread via a Qt signal."""
        self.signal_strategy.connect(self.process_strategy_event)
        self.event_engine.register(EVENT_CTA_STRATEGY, self.signal_strategy.emit)

    def process_strategy_event(self, event):
        """
        Update strategy status onto its monitor, creating the panel for a
        strategy seen for the first time.
        """
        data = event.data
        strategy_name = data["strategy_name"]

        existing = self.managers.get(strategy_name)
        if existing is not None:
            existing.update_data(data)
            return

        panel = StrategyManager(self, self.cta_engine, data)
        self.scroll_layout.insertWidget(0, panel)
        self.managers[strategy_name] = panel

    def remove_strategy(self, strategy_name):
        """Drop and schedule deletion of a strategy's panel."""
        self.managers.pop(strategy_name).deleteLater()

    def add_strategy(self):
        """Ask for the settings of the selected class and add the strategy."""
        class_name = str(self.class_combo.currentText())
        if not class_name:
            return

        parameters = self.cta_engine.get_strategy_class_parameters(class_name)
        editor = SettingEditor(parameters, class_name=class_name)
        if editor.exec_() != editor.Accepted:
            return

        setting = editor.get_setting()
        vt_symbol = setting.pop("vt_symbol")
        strategy_name = setting.pop("strategy_name")
        self.cta_engine.add_strategy(class_name, strategy_name, vt_symbol, setting)

    def clear_log(self):
        """Remove every row from the log monitor."""
        self.log_monitor.setRowCount(0)

    def show(self):
        """Show the widget maximized."""
        self.showMaximized()
class StrategyManager(QtWidgets.QFrame):
    """
    Manager panel for a single strategy.

    Shows the strategy's parameters and variables and offers
    init/start/stop/edit/remove buttons. Optionally owns a K-line chart
    window (``kx_chart``). The chart is opened on init when the strategy
    setting has a truthy ``show_chart`` entry, and closed when the strategy
    is removed.
    """

    def __init__(
        self, cta_manager: CtaManager, cta_engine: CtaEngine, data: dict
    ):
        """"""
        super(StrategyManager, self).__init__()

        self.cta_manager = cta_manager
        self.cta_engine = cta_engine

        self.strategy_name = data["strategy_name"]
        self._data = data

        # K-line chart window, created by open_kx_chart(). It is set here so
        # remove_strategy() also works for a never-initialised strategy
        # (previously an AttributeError).
        self.kx_chart = None

        self.init_ui()

    def init_ui(self):
        """Build the header label, buttons and the two data tables."""
        self.setFixedHeight(300)
        self.setFrameShape(self.Box)
        self.setLineWidth(1)

        self.init_button = QtWidgets.QPushButton("初始化")
        self.init_button.clicked.connect(self.init_strategy)

        self.start_button = QtWidgets.QPushButton("启动")
        self.start_button.clicked.connect(self.start_strategy)
        self.start_button.setEnabled(False)

        self.stop_button = QtWidgets.QPushButton("停止")
        self.stop_button.clicked.connect(self.stop_strategy)
        self.stop_button.setEnabled(False)

        self.edit_button = QtWidgets.QPushButton("编辑")
        self.edit_button.clicked.connect(self.edit_strategy)

        self.remove_button = QtWidgets.QPushButton("移除")
        self.remove_button.clicked.connect(self.remove_strategy)

        strategy_name = self._data["strategy_name"]
        vt_symbol = self._data["vt_symbol"]
        class_name = self._data["class_name"]
        author = self._data["author"]

        label_text = (
            f"{strategy_name} - {vt_symbol} ({class_name} by {author})"
        )
        label = QtWidgets.QLabel(label_text)
        label.setAlignment(QtCore.Qt.AlignCenter)

        self.parameters_monitor = DataMonitor(self._data["parameters"])
        self.variables_monitor = DataMonitor(self._data["variables"])

        hbox = QtWidgets.QHBoxLayout()
        hbox.addWidget(self.init_button)
        hbox.addWidget(self.start_button)
        hbox.addWidget(self.stop_button)
        hbox.addWidget(self.edit_button)
        hbox.addWidget(self.remove_button)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(label)
        vbox.addLayout(hbox)
        vbox.addWidget(self.parameters_monitor)
        vbox.addWidget(self.variables_monitor)
        self.setLayout(vbox)

    def update_data(self, data: dict):
        """Refresh the tables and button states from a strategy data dict."""
        self._data = data

        self.parameters_monitor.update_data(data["parameters"])
        self.variables_monitor.update_data(data["variables"])

        # Update button status
        variables = data["variables"]
        inited = variables["inited"]
        trading = variables["trading"]

        if not inited:
            return
        self.init_button.setEnabled(False)

        if trading:
            self.start_button.setEnabled(False)
            self.stop_button.setEnabled(True)
            self.edit_button.setEnabled(False)
            self.remove_button.setEnabled(False)
        else:
            self.start_button.setEnabled(True)
            self.stop_button.setEnabled(False)
            self.edit_button.setEnabled(True)
            self.remove_button.setEnabled(True)

    def init_strategy(self):
        """Open the K-line chart if configured, then queue strategy init."""
        self.open_kx_chart()  # hxxjava add
        self.cta_engine.init_strategy(self.strategy_name)

    def start_strategy(self):
        """"""
        self.cta_engine.start_strategy(self.strategy_name)

    def stop_strategy(self):
        """"""
        self.cta_engine.stop_strategy(self.strategy_name)

    def edit_strategy(self):
        """Open the settings editor and apply the edited parameters."""
        strategy_name = self._data["strategy_name"]

        parameters = self.cta_engine.get_strategy_parameters(strategy_name)
        editor = SettingEditor(parameters, strategy_name=strategy_name)
        n = editor.exec_()

        if n == editor.Accepted:
            setting = editor.get_setting()
            self.cta_engine.edit_strategy(strategy_name, setting)

    def remove_strategy(self):
        """Remove the strategy; on success drop this panel and its chart."""
        result = self.cta_engine.remove_strategy(self.strategy_name)

        # Only remove strategy gui manager if it has been removed from engine
        if result:
            self.cta_manager.remove_strategy(self.strategy_name)
            self._close_kx_chart()  # hxxjava add

    def _close_kx_chart(self):
        """Close and forget the K-line chart window, if one is open."""
        if self.kx_chart is not None:
            self.kx_chart.close()
            self.kx_chart = None

    def open_kx_chart(self):  # hxxjava add
        """
        Create and show the strategy's K-line chart when its setting has a
        truthy "show_chart".

        The strategy configures the chart's contents through its own
        init_kx_chart() method.
        """
        strategy = self.cta_engine.strategies[self.strategy_name]
        setting = self.cta_engine.strategy_setting[self.strategy_name]['setting']
        show_chart = setting.get("show_chart", None)

        # init can be clicked again before the button gets disabled (or
        # after a failed init); close any chart already open so it is not
        # left behind as an orphan window.
        self._close_kx_chart()

        if not show_chart:
            return

        event_engine = self.cta_engine.event_engine
        kx_interval = setting.get("kx_interval", None)
        self.kx_chart = NewChartWidget(event_engine=event_engine, strategy_name=self.strategy_name)
        self.kx_chart.setWindowTitle(f"K线图表:{self.strategy_name},周期:{kx_interval}")
        strategy.init_kx_chart(self.kx_chart)
        self.kx_chart.register_event()  # subscribe the chart to its events
        self.kx_chart.show()  # display the K-line chart
class DataMonitor(QtWidgets.QTableWidget):
    """
    One-row, read-only table for a ``{name: value}`` dict (strategy
    parameters or variables). Each key becomes a column header; values are
    shown as centred text.
    """

    def __init__(self, data: dict):
        super(DataMonitor, self).__init__()

        self._data = data
        self.cells = {}  # name -> QTableWidgetItem

        self.init_ui()

    def init_ui(self):
        """Create the header and one centred cell per key."""
        names = list(self._data.keys())
        self.setColumnCount(len(names))
        self.setHorizontalHeaderLabels(names)

        self.setRowCount(1)
        vertical_header = self.verticalHeader()
        vertical_header.setSectionResizeMode(QtWidgets.QHeaderView.Stretch)
        vertical_header.setVisible(False)
        self.setEditTriggers(self.NoEditTriggers)

        for column, name in enumerate(names):
            item = QtWidgets.QTableWidgetItem(str(self._data[name]))
            item.setTextAlignment(QtCore.Qt.AlignCenter)
            self.setItem(0, column, item)
            self.cells[name] = item

    def update_data(self, data: dict):
        """Overwrite the text of each named cell (keys must already exist)."""
        for name, value in data.items():
            self.cells[name].setText(str(value))
class StopOrderMonitor(BaseMonitor):
    """
    Monitor for local stop order.

    Rows are keyed by stop_orderid and fed by EVENT_CTA_STOPORDER events.
    """

    event_type = EVENT_CTA_STOPORDER
    data_key = "stop_orderid"
    sorting = True

    # column field -> display title, cell class, whether the cell is refreshed
    headers = {
        "stop_orderid": {"display": "停止委托号", "cell": BaseCell, "update": False},
        "vt_orderids": {"display": "限价委托号", "cell": BaseCell, "update": True},
        "vt_symbol": {"display": "本地代码", "cell": BaseCell, "update": False},
        "direction": {"display": "方向", "cell": EnumCell, "update": False},
        "offset": {"display": "开平", "cell": EnumCell, "update": False},
        "price": {"display": "价格", "cell": BaseCell, "update": False},
        "volume": {"display": "数量", "cell": BaseCell, "update": False},
        "status": {"display": "状态", "cell": EnumCell, "update": True},
        "lock": {"display": "锁仓", "cell": BaseCell, "update": False},
        "strategy_name": {"display": "策略名", "cell": BaseCell, "update": False},
    }
class LogMonitor(BaseMonitor):
    """
    Monitor for log data.

    Fed by EVENT_CTA_LOG events; rows are not keyed and not sorted.
    """

    event_type = EVENT_CTA_LOG
    data_key = ""
    sorting = False

    headers = {
        "time": {"display": "时间", "cell": TimeCell, "update": False},
        "msg": {"display": "信息", "cell": MsgCell, "update": False},
    }

    def init_ui(self):
        """Build the table, then let the message column (index 1) stretch."""
        super().init_ui()
        header = self.horizontalHeader()
        header.setSectionResizeMode(1, QtWidgets.QHeaderView.Stretch)

    def insert_new_row(self, data):
        """Insert a new row at the top of the table and fit its height to the text."""
        super().insert_new_row(data)
        self.resizeRowToContents(0)
class SettingEditor(QtWidgets.QDialog):
    """
    For creating new strategy and editing strategy parameters.
    """

    def __init__(
        self, parameters: dict, strategy_name: str = "", class_name: str = ""
    ):
        """
        :param parameters: parameter name -> current/default value; the type of
            each value selects the input validator and the conversion used by
            get_setting()
        :param strategy_name: shown in the title when editing an existing strategy
        :param class_name: when non-empty, the dialog adds a new strategy of this class
        """
        super(SettingEditor, self).__init__()

        self.parameters = parameters
        self.strategy_name = strategy_name
        self.class_name = class_name

        # parameter name -> (line edit, type of the original value)
        self.edits = {}

        self.init_ui()

    def init_ui(self):
        """Build one line edit per parameter inside a scrollable form."""
        form = QtWidgets.QFormLayout()

        if not self.class_name:
            self.setWindowTitle(f"参数编辑:{self.strategy_name}")
            button_text = "确定"
            parameters = self.parameters
        else:
            # A new strategy additionally needs a name and a vt_symbol.
            self.setWindowTitle(f"添加策略:{self.class_name}")
            button_text = "添加"
            parameters = {"strategy_name": "", "vt_symbol": "", **self.parameters}

        validator_classes = {
            int: QtGui.QIntValidator,
            float: QtGui.QDoubleValidator,
        }
        for name, value in parameters.items():
            value_type = type(value)
            line_edit = QtWidgets.QLineEdit(str(value))

            validator_class = validator_classes.get(value_type)
            if validator_class is not None:
                line_edit.setValidator(validator_class())

            form.addRow(f"{name} {value_type}", line_edit)
            self.edits[name] = (line_edit, value_type)

        confirm = QtWidgets.QPushButton(button_text)
        confirm.clicked.connect(self.accept)
        form.addRow(confirm)

        container = QtWidgets.QWidget()
        container.setLayout(form)

        scroll = QtWidgets.QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setWidget(container)

        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(scroll)
        self.setLayout(layout)

    def get_setting(self):
        """
        Return the edited values converted back to each parameter's original
        type. For bool parameters only the exact text "True" yields True.
        """
        setting = {}
        if self.class_name:
            setting["class_name"] = self.class_name

        for name, (line_edit, value_type) in self.edits.items():
            text = line_edit.text()
            setting[name] = (text == "True") if value_type == bool else value_type(text)

        return setting
import os
import importlib
import traceback
from datetime import datetime
from threading import Thread
from pathlib import Path
from inspect import getfile
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.constant import Interval
from vnpy.trader.utility import extract_vt_symbol
from vnpy.trader.object import HistoryRequest
from vnpy.trader.rqdata import rqdata_client
from vnpy.trader.database import database_manager
from vnpy.app.cta_strategy import CtaTemplate
from vnpy.app.cta_strategy.backtesting import BacktestingEngine, OptimizationSetting
# App name passed to BaseEngine by BacktesterEngine.__init__.
APP_NAME = "CtaBacktester"

# Event types put on the event engine by BacktesterEngine:
# log messages, backtest completion and optimization completion.
EVENT_BACKTESTER_LOG = "eBacktesterLog"
EVENT_BACKTESTER_BACKTESTING_FINISHED = "eBacktesterBacktestingFinished"
EVENT_BACKTESTER_OPTIMIZATION_FINISHED = "eBacktesterOptimizationFinished"
from vnpy.app.cta_strategy.base import EngineType # hxxjava add
class BacktesterEngine(BaseEngine):
    """
    For running CTA strategy backtesting.

    Backtesting, optimization and history-data downloading each run in a
    worker thread stored in ``self.thread``. A new task is refused while it
    is set, so every task resets it to None when it ends. That includes
    failures; otherwise one crashing task would block all later ones.
    """

    # hxxjava add: engine type exposed for strategies during backtesting.
    engine_type = EngineType.BACKTESTING

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        """"""
        super().__init__(main_engine, event_engine, APP_NAME)

        self.classes = {}
        self.backtesting_engine = None
        self.thread = None

        # Backtesting result
        self.result_df = None
        self.result_statistics = None

        # Optimization result
        self.result_values = None

    def init_engine(self):
        """Create the backtesting engine, load strategy classes and init RQData."""
        self.write_log("初始化CTA回测引擎")

        self.backtesting_engine = BacktestingEngine()
        # Redirect log from backtesting engine outside.
        self.backtesting_engine.output = self.write_log

        self.load_strategy_class()
        self.write_log("策略文件加载完成")

        self.init_rqdata()

    def init_rqdata(self):
        """
        Init RQData client.
        """
        result = rqdata_client.init()
        if result:
            self.write_log("RQData数据接口初始化成功")

    def write_log(self, msg: str):
        """Publish msg as an EVENT_BACKTESTER_LOG event."""
        event = Event(EVENT_BACKTESTER_LOG)
        event.data = msg
        self.event_engine.put(event)

    def load_strategy_class(self):
        """
        Load strategy class from source code.

        Looks in the bundled cta_strategy/strategies folder and in the
        ``strategies`` folder under the current working directory.
        """
        app_path = Path(__file__).parent.parent
        path1 = app_path.joinpath("cta_strategy", "strategies")
        self.load_strategy_class_from_folder(
            path1, "vnpy.app.cta_strategy.strategies")

        path2 = Path.cwd().joinpath("strategies")
        self.load_strategy_class_from_folder(path2, "strategies")

    def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
        """
        Load strategy class from certain folder.

        Every ``.py`` and ``.pyd`` file found is imported as
        ``<module_name>.<file stem>``.
        """
        for _dirpath, _dirnames, filenames in os.walk(path):
            for filename in filenames:
                # Load python source code file
                if filename.endswith(".py"):
                    # Strip only the trailing extension; str.replace() would
                    # also remove ".py" occurring inside the file name.
                    stem = filename[:-len(".py")]
                # Load compiled pyd binary file
                elif filename.endswith(".pyd"):
                    stem = filename.split(".")[0]
                else:
                    continue

                strategy_module_name = ".".join([module_name, stem])
                self.load_strategy_class_from_module(strategy_module_name)

    def load_strategy_class_from_module(self, module_name: str):
        """
        Load strategy class from module file.

        Every CtaTemplate subclass (other than CtaTemplate itself) found in the
        module is registered under its class name. Import errors are logged.
        """
        try:
            module = importlib.import_module(module_name)
            importlib.reload(module)

            for name in dir(module):
                value = getattr(module, name)
                if (
                    isinstance(value, type)
                    and issubclass(value, CtaTemplate)
                    and value is not CtaTemplate
                ):
                    self.classes[value.__name__] = value
        except Exception:
            msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
            self.write_log(msg)

    def reload_strategy_class(self):
        """Forget all loaded strategy classes and load them again."""
        self.classes.clear()
        self.load_strategy_class()
        self.write_log("策略文件重载刷新完成")

    def get_strategy_class_names(self):
        """Return the names of all loaded strategy classes."""
        return list(self.classes.keys())

    def run_backtesting(
        self,
        class_name: str,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime,
        rate: float,
        slippage: float,
        size: int,
        pricetick: float,
        capital: int,
        inverse: bool,
        setting: dict
    ):
        """
        Run one backtest (executed in the worker thread).

        On success, result_df / result_statistics are filled and
        EVENT_BACKTESTER_BACKTESTING_FINISHED is emitted. On failure the
        traceback is logged and the worker slot is released.
        """
        self.result_df = None
        self.result_statistics = None

        engine = self.backtesting_engine

        try:
            engine.clear_data()
            engine.set_parameters(
                vt_symbol=vt_symbol,
                interval=interval,
                start=start,
                end=end,
                rate=rate,
                slippage=slippage,
                size=size,
                pricetick=pricetick,
                capital=capital,
                inverse=inverse
            )

            strategy_class = self.classes[class_name]
            engine.add_strategy(
                strategy_class,
                setting
            )

            engine.load_data()
            engine.run_backtesting()
            self.result_df = engine.calculate_result()
            self.result_statistics = engine.calculate_statistics(output=False)
        except Exception:
            msg = f"策略回测失败,触发异常:\n{traceback.format_exc()}"
            self.write_log(msg)
            # Release the worker slot so later tasks are not blocked forever.
            self.thread = None
            return

        # Clear thread object handler.
        self.thread = None

        # Put backtesting done event
        event = Event(EVENT_BACKTESTER_BACKTESTING_FINISHED)
        self.event_engine.put(event)

    def start_backtesting(
        self,
        class_name: str,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime,
        rate: float,
        slippage: float,
        size: int,
        pricetick: float,
        capital: int,
        inverse: bool,
        setting: dict
    ):
        """
        Start run_backtesting() in a worker thread.

        Returns False (and logs) if another task is still running, else True.
        """
        if self.thread:
            self.write_log("已有任务在运行中,请等待完成")
            return False

        self.write_log("-" * 40)
        self.thread = Thread(
            target=self.run_backtesting,
            args=(
                class_name,
                vt_symbol,
                interval,
                start,
                end,
                rate,
                slippage,
                size,
                pricetick,
                capital,
                inverse,
                setting
            )
        )
        self.thread.start()

        return True

    def get_result_df(self):
        """"""
        return self.result_df

    def get_result_statistics(self):
        """"""
        return self.result_statistics

    def get_result_values(self):
        """"""
        return self.result_values

    def get_default_setting(self, class_name: str):
        """Return the default parameters of the named strategy class."""
        strategy_class = self.classes[class_name]
        return strategy_class.get_class_parameters()

    def run_optimization(
        self,
        class_name: str,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime,
        rate: float,
        slippage: float,
        size: int,
        pricetick: float,
        capital: int,
        inverse: bool,
        optimization_setting: OptimizationSetting,
        use_ga: bool
    ):
        """
        Run a parameter optimization (executed in the worker thread).

        Uses the genetic algorithm when use_ga is True, else the multi-process
        brute-force search. On success, result_values is filled and
        EVENT_BACKTESTER_OPTIMIZATION_FINISHED is emitted. On failure the
        traceback is logged and the worker slot is released.
        """
        if use_ga:
            self.write_log("开始遗传算法参数优化")
        else:
            self.write_log("开始多进程参数优化")

        self.result_values = None

        engine = self.backtesting_engine

        try:
            engine.clear_data()
            engine.set_parameters(
                vt_symbol=vt_symbol,
                interval=interval,
                start=start,
                end=end,
                rate=rate,
                slippage=slippage,
                size=size,
                pricetick=pricetick,
                capital=capital,
                inverse=inverse
            )

            strategy_class = self.classes[class_name]
            engine.add_strategy(
                strategy_class,
                {}
            )

            if use_ga:
                self.result_values = engine.run_ga_optimization(
                    optimization_setting,
                    output=False
                )
            else:
                self.result_values = engine.run_optimization(
                    optimization_setting,
                    output=False
                )
        except Exception:
            msg = f"参数优化失败,触发异常:\n{traceback.format_exc()}"
            self.write_log(msg)
            # Release the worker slot so later tasks are not blocked forever.
            self.thread = None
            return

        # Clear thread object handler.
        self.thread = None

        # Report the mode that actually ran.
        if use_ga:
            self.write_log("遗传算法参数优化完成")
        else:
            self.write_log("多进程参数优化完成")

        # Put optimization done event
        event = Event(EVENT_BACKTESTER_OPTIMIZATION_FINISHED)
        self.event_engine.put(event)

    def start_optimization(
        self,
        class_name: str,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime,
        rate: float,
        slippage: float,
        size: int,
        pricetick: float,
        capital: int,
        inverse: bool,
        optimization_setting: OptimizationSetting,
        use_ga: bool
    ):
        """
        Start run_optimization() in a worker thread.

        Returns False (and logs) if another task is still running, else True.
        """
        if self.thread:
            self.write_log("已有任务在运行中,请等待完成")
            return False

        self.write_log("-" * 40)
        self.thread = Thread(
            target=self.run_optimization,
            args=(
                class_name,
                vt_symbol,
                interval,
                start,
                end,
                rate,
                slippage,
                size,
                pricetick,
                capital,
                inverse,
                optimization_setting,
                use_ga
            )
        )
        self.thread.start()

        return True

    def run_downloading(
        self,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime
    ):
        """
        Query bar data from RQData.

        Uses the gateway's own history query when the contract supports it,
        otherwise RQData, and saves the bars to the database. The worker slot
        is always released when the method returns.
        """
        self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")

        try:
            symbol, exchange = extract_vt_symbol(vt_symbol)
        except ValueError:
            self.write_log(f"{vt_symbol}解析失败,请检查交易所后缀")
            self.thread = None
            return

        try:
            # Interval(interval) raises ValueError for an unknown interval, so
            # the request is built inside the try to keep the slot releasable.
            req = HistoryRequest(
                symbol=symbol,
                exchange=exchange,
                interval=Interval(interval),
                start=start,
                end=end
            )

            contract = self.main_engine.get_contract(vt_symbol)

            # If history data provided in gateway, then query
            if contract and contract.history_data:
                data = self.main_engine.query_history(
                    req, contract.gateway_name
                )
            # Otherwise use RQData to query data
            else:
                data = rqdata_client.query_history(req)

            if data:
                database_manager.save_bar_data(data)
                self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
            else:
                self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
        except Exception:
            msg = f"数据下载失败,触发异常:\n{traceback.format_exc()}"
            self.write_log(msg)

        # Clear thread object handler.
        self.thread = None

    def start_downloading(
        self,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime
    ):
        """
        Start run_downloading() in a worker thread.

        Returns False (and logs) if another task is still running, else True.
        """
        if self.thread:
            self.write_log("已有任务在运行中,请等待完成")
            return False

        self.write_log("-" * 40)
        self.thread = Thread(
            target=self.run_downloading,
            args=(
                vt_symbol,
                interval,
                start,
                end
            )
        )
        self.thread.start()

        return True

    def get_all_trades(self):
        """"""
        return self.backtesting_engine.get_all_trades()

    def get_all_orders(self):
        """"""
        return self.backtesting_engine.get_all_orders()

    def get_all_daily_results(self):
        """"""
        return self.backtesting_engine.get_all_daily_results()

    def get_history_data(self):
        """"""
        return self.backtesting_engine.history_data

    def get_strategy_class_file(self, class_name: str):
        """Return the path of the source file defining the named strategy class."""
        strategy_class = self.classes[class_name]
        file_path = getfile(strategy_class)
        return file_path
from typing import Any,List,Dict,Tuple
import copy
from vnpy.app.cta_strategy import (
CtaTemplate,
BarGenerator,
ArrayManager,
StopOrder,
Direction
)
from vnpy.trader.engine import MainEngine,EventEngine
from vnpy.app.cta_strategy.engine import CtaEngine
from vnpy.event.engine import Event
from vnpy.trader.object import (
LogData,
TickData,
BarData,
TradeData,
OrderData,
)
from vnpy.app.cta_strategy import StopOrder
from vnpy.app.cta_strategy.base import EngineType
from vnpy.trader.constant import Interval
from vnpy.app.cta_strategy.base import (
APP_NAME,
EVENT_CTA_LOG,
EVENT_CTA_TICK,
EVENT_CTA_HISTORY_BAR,
EVENT_CTA_BAR,
EVENT_CTA_ORDER,
EVENT_CTA_TRADE,
EVENT_CTA_STOPORDER,
EVENT_CTA_STRATEGY,
)
from vnpy.usertools.kx_chart import ( # hxxjava add
NewChartWidget,
CandleItem,
VolumeItem,
LineItem,
SmaItem,
RsiItem,
MacdItem,
)
from vnpy.usertools.kx_chart import NewChartWidget # hxxjava add
class _kx_strategy(CtaTemplate):
    """
    Demo strategy that builds N-minute bars (N = kx_interval) and, in live
    trading with show_chart enabled, forwards its ticks, bars, orders,
    trades and stop orders to a K-line chart window.

    Every event payload is a ``(strategy_name, data)`` tuple, so several
    instances of this strategy can each drive their own chart.
    """

    author = "hxxjava"

    kx_interval = 1        # bar window in minutes
    show_chart = False     # show the K-line chart window

    parameters = [
        "kx_interval",
        "show_chart"
    ]

    kx_count: int = 0
    cta_manager = None

    variables = ["kx_count"]

    def __init__(
        self,
        cta_engine: Any,
        strategy_name: str,
        vt_symbol: str,
        setting: dict,
    ):
        """"""
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)

        self.bg = BarGenerator(self.on_bar, self.kx_interval, self.on_Nmin_bar)
        self.am = ArrayManager()

        self.engine_type = cta_engine.engine_type
        # The backtesting engine has no main_engine, so look the event engine
        # up defensively instead of crashing here; send_event() only uses it
        # in live trading.
        main_engine = getattr(cta_engine, "main_engine", None)
        self.even_engine = getattr(main_engine, "event_engine", None)

        # Mutable state must be created per instance: a list declared in the
        # class body would be shared by every strategy of this class.
        self.all_bars: List[BarData] = []
        self.current_tick: TickData = None   # newest tick received
        self.current_bar: BarData = None     # in-progress (temporary) bar
        self.last_tick: TickData = None      # tick processed before the newest

    def on_init(self):
        """
        Callback when strategy is inited.
        """
        self.load_bar(20)
        if self.all_bars:
            # Send a snapshot: all_bars keeps growing after this call, and the
            # event may be handled asynchronously by the chart.
            self.send_event(EVENT_CTA_HISTORY_BAR, list(self.all_bars))

    def on_start(self):
        """ """
        self.write_log("已开始")

    def on_stop(self):
        """"""
        self.write_log("_kx_strategy 已停止")

    def on_tick(self, tick: TickData):
        """
        Callback of new tick data update.
        """
        self.current_tick = tick  # remember the newest tick

        if self.inited:
            # First update the temporary (unfinished) bar with this tick ...
            cur_bar = self.get_cur_bar(tick)
            if cur_bar:
                # ... and publish it so the chart can redraw the last bar.
                self.send_event(EVENT_CTA_BAR, cur_bar)

        # Then feed the tick to the generator for 1-minute and N-minute bars.
        self.bg.update_tick(tick)
        self.send_event(EVENT_CTA_TICK, tick)
        self.last_tick = tick

    def on_bar(self, bar: BarData):
        """
        Callback of new bar data update.
        """
        if self.inited:
            self.write_log(f"I got a 1min BarData")
        self.bg.update_bar(bar)

    def on_Nmin_bar(self, bar: BarData):
        """
        Callback of a finished N-minute bar.
        """
        self.all_bars.append(bar)
        self.kx_count = len(self.all_bars)

        if self.inited:
            self.write_log(f"I got a {self.kx_interval}min BarData")
            self.send_event(EVENT_CTA_BAR, bar)
            if self.current_tick:
                # A new N-minute bar has finished: start a new temporary bar.
                self.current_bar = None
                self.get_cur_bar(self.current_tick)

        self.put_event()

    def on_trade(self, trade: TradeData):
        """
        Callback of new trade data update.
        """
        self.send_event(EVENT_CTA_TRADE, trade)

    def on_order(self, order: OrderData):
        """
        Callback of new order data update.
        """
        self.send_event(EVENT_CTA_ORDER, order)

    def on_stop_order(self, stop_order: StopOrder):
        """
        Callback of stop order update.
        """
        self.send_event(EVENT_CTA_STOPORDER, stop_order)

    def get_cur_bar(self, tick: TickData) -> BarData:
        """
        Update the temporary (unfinished) bar with tick and return a copy of it.

        A new temporary bar is started whenever self.current_bar is None
        (on_Nmin_bar() resets it after each finished bar). Otherwise only its
        high/low/close, volume and open interest are updated. self.last_tick
        is this strategy's own previous tick, set at the end of on_tick().

        Returns None before the strategy is inited, before any previous tick
        exists, or when tick is older than the previous tick.
        """
        last_tick = self.last_tick
        if not self.inited or not last_tick:
            return None

        if tick.datetime < last_tick.datetime:
            return None

        if not self.current_bar:
            # Generate timestamp for bar data
            if self.bg.interval == Interval.MINUTE:
                dt = tick.datetime.replace(second=0, microsecond=0)
            else:
                dt = tick.datetime.replace(minute=0, second=0, microsecond=0)

            self.current_bar = BarData(
                symbol=tick.symbol,
                exchange=tick.exchange,
                datetime=dt,
                gateway_name=tick.gateway_name,
                open_price=tick.last_price,
                high_price=tick.last_price,
                low_price=tick.last_price,
            )
        else:
            # Otherwise, update high/low price into window bar
            self.current_bar.high_price = max(self.current_bar.high_price, tick.last_price)
            self.current_bar.low_price = min(self.current_bar.low_price, tick.last_price)

        # Update last price/volume into window bar
        self.current_bar.close_price = tick.last_price
        self.current_bar.volume += tick.volume - last_tick.volume
        self.current_bar.open_interest = tick.open_interest

        return copy.deepcopy(self.current_bar)

    def send_event(self, event_type: str, data: Any):
        """
        Publish (strategy_name, data) under event_type, but only in live
        trading, with show_chart enabled and an event engine available.
        """
        if self.engine_type == EngineType.LIVE and self.show_chart and self.even_engine:
            self.even_engine.put(Event(event_type, (self.strategy_name, data)))

    def init_kx_chart(self, kx_chart: NewChartWidget = None):  # hxxjava add ----- called from the UI
        """Configure the plots and indicator items shown in kx_chart."""
        if kx_chart:
            kx_chart.add_plot("candle", hide_x_axis=True)
            kx_chart.add_plot("volume", maximum_height=150)
            kx_chart.add_plot("rsi", maximum_height=150)
            kx_chart.add_plot("macd", maximum_height=150)
            kx_chart.add_item(CandleItem, "candle", "candle")
            kx_chart.add_item(VolumeItem, "volume", "volume")
            kx_chart.add_item(LineItem, "line", "candle")
            kx_chart.add_item(SmaItem, "sma", "candle")
            kx_chart.add_item(RsiItem, "rsi", "rsi")
            kx_chart.add_item(MacdItem, "macd", "macd")
            kx_chart.add_last_price_line()
            kx_chart.add_cursor()
启动VnTrader,进入策略管理界面,完成如下步骤:
1)从策略下拉框中选择_kx_strategy策略
2)点击添加策略按钮,进入参数设置界面
3)输入策略名称、vt_symbol、kx_interval和show_chart的值。注意kx_interval是你想要的K线周期,单位是分钟;show_chart参数为True表示需要显示K线图表,其他值则不显示。
4)初始化策略。如果show_chart为True,初始化完成后会显示K线图表窗口,并显示最近20日的历史K线
5)按启动按钮启动策略。如果处于交易时段,K线图表会显示最新收到的K线,同时还会实时显示尚未完成的临时K线
class test_strategy中的
all_bars
relate_names
都是类变量,它们在该类的所有实例中都是同一个对象。所以不同的策略实例在on_Nmin_bar()函数中执行的这条语句:
self.all_bars.append(bar)
其实都是在向同一个类变量all_bars列表中添加bar,所以导致了错误!
把它们变成实例变量就可以了,方法如下:
在__init__()函数中
self.all_bars:List[BarData] =[]
self.relate_names:List[str] = []
class CLanguage:
    """Demo class contrasting class variables with instance variables."""

    # Class variables: stored on the class and shared unless an instance
    # binds an attribute of the same name.
    name, addr = "xxx", "http://"

    def __init__(self):
        # Instance variables: these shadow the class variables above.
        self.name, self.addr = "C语言中文网", "http://c.biancheng.net"

    def say(self):
        """Instance method; creates the instance variable catalog when called."""
        self.catalog = 13
def test():
    """
    Print instance and class variables of CLanguage, showing that rebinding
    attributes on one instance changes neither other instances nor the class.
    """
    first = CLanguage()
    # Values assigned by __init__.
    for attr in ("name", "addr"):
        print(getattr(first, attr))

    # Rebinding only affects this instance.
    first.name, first.addr = "python教程", "http://c.biancheng.net/python"
    for attr in ("name", "addr"):
        print(getattr(first, attr))

    # A fresh instance still gets the __init__ values.
    second = CLanguage()
    for attr in ("name", "addr"):
        print(getattr(second, attr))

    # The class variables themselves are untouched.
    print(CLanguage.name)
    print(CLanguage.addr)
# Run the class-variable vs instance-variable demo when executed as a script.
if __name__ == "__main__":
    test()
输出结果:
C语言中文网
http://c.biancheng.net
python教程
http://c.biancheng.net/python
C语言中文网
http://c.biancheng.net
xxx
http://
对不起大家了,本次分享遇到了问题,导致K线图的数据和显示出错了,请大家耐心等待问题的解决。
具体问题请参考:https://www.vnpy.com/forum/topic/3893-ctace-lue-de-wen-ti-:liang-ge-he-yue-bu-ke-yi-gong-yong-yi-ge-ce-lue
用户实现了自己的CTA策略,可能会放在多个合约上运行。用户策略里会声明一系列的策略成员变量,这些成员变量应该是每个策略实例各自独有的。可是我发现事实并非如此!——不同的合约竟然共用着同一个策略成员变量!
下面代码保存在用户策略文件夹下的test_strategy.py文件中
from typing import Any,List,Dict,Tuple
import copy
from vnpy.app.cta_strategy import (
CtaTemplate,
BarGenerator,
ArrayManager,
StopOrder,
Direction
)
from vnpy.trader.engine import MainEngine,EventEngine
from vnpy.app.cta_strategy.engine import CtaEngine
from vnpy.event.engine import Event
from vnpy.trader.object import (
LogData,
TickData,
BarData,
TradeData,
OrderData,
)
class test_strategy(CtaTemplate):
    """
    Minimal strategy used to investigate "shared member variables".

    all_bars and relate_names used to be declared in the class body. A list
    created there is one object shared by every instance of the class, so
    strategies running on different contracts appended into the same lists.
    They are now created per instance in __init__.
    """

    author = "hxxjava"

    kx_interval = 1

    parameters = [
        "kx_interval"
    ]

    kx_count: int = 0

    variables = ["kx_count"]

    def __init__(
        self,
        cta_engine: Any,
        strategy_name: str,
        vt_symbol: str,
        setting: dict,
    ):
        """"""
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)
        self.bg = BarGenerator(self.on_bar, self.kx_interval, self.on_Nmin_bar)
        self.am = ArrayManager()

        # Per-instance containers (see class docstring).
        self.all_bars: List[BarData] = []
        self.relate_names: List[str] = []
        self.relate_names.append(vt_symbol)

    def on_init(self):
        """
        Callback when strategy is inited.
        """
        self.write_log("test_strategy 初始化")
        self.load_bar(20)
        self.write_log(f"relate_names={self.relate_names} !!!")

    def on_start(self):
        """ """
        self.write_log(f"test_strategy 已开始 self.kx_interval={self.kx_interval}",)

    def on_stop(self):
        """"""
        self.write_log("test_strategy 已停止")

    def on_tick(self, tick: TickData):
        """
        Callback of new tick data update.
        """
        self.bg.update_tick(tick)

    def on_bar(self, bar: BarData):
        """
        Callback of new bar data update.
        """
        if self.inited:
            self.write_log(f"I got a 1min BarData")
        self.bg.update_bar(bar)

    def on_Nmin_bar(self, bar: BarData):
        """
        Callback of a finished N-minute bar.
        """
        self.all_bars.append(bar)
        self.kx_count = len(self.all_bars)
        if self.inited:
            self.write_log(f"I got a {self.kx_interval}min BarData {self.kx_count}")

    def on_trade(self, trade: TradeData):
        """
        Callback of new trade data update.
        """

    def on_order(self, order: OrderData):
        """
        Callback of new order data update.
        """

    def on_stop_order(self, stop_order: StopOrder):
        """
        Callback of stop order update.
        """
ag2012.SHFE策略实例的all_bars数组中保存了ag2012.SHFE最近20日的全部10分钟K线数据。
rb2010策略实例的all_bars数组中,不只保存了rb2010最近20日的全部5分钟K线数据,还保存了ag2012的10分钟K线数据。
不同的用户策略应该拥有各自不同的成员变量:
all_bars
relate_names
可是从实际的测试结果看,它们却是相同的,这是不应该的!
先来看cta_engine里创建策略实例的这个函数:
def add_strategy(
    self, class_name: str, strategy_name: str, vt_symbol: str, setting: dict
):
    """
    Add a new strategy.

    Instantiates the strategy class registered under class_name and registers
    the new instance by name and by vt_symbol. It then saves its setting and
    publishes a strategy event. Logs and returns without creating anything
    if the name is taken or the class is unknown.
    """
    if strategy_name in self.strategies:
        self.write_log(f"创建策略失败,存在重名{strategy_name}")
        return

    strategy_class = self.classes.get(class_name, None)
    if not strategy_class:
        self.write_log(f"创建策略失败,找不到策略类{class_name}")
        return

    # A brand-new instance is created on every call.
    strategy = strategy_class(self, strategy_name, vt_symbol, setting)
    self.strategies[strategy_name] = strategy

    # Add vt_symbol to strategy map.
    self.symbol_strategy_map[vt_symbol].append(strategy)

    # Update to setting file.
    self.update_strategy_setting(strategy_name, setting)
    self.put_strategy_event(strategy)
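从这里的代码可以看出,每次添加策略都会通过 strategy_class(self, strategy_name, vt_symbol, setting) 创建一个新的策略实例。但策略类本身是根据class_name从self.classes字典中查询得到的,同一个策略类的所有实例共享该类的类变量。因此,在类体中声明的可变列表(如all_bars、relate_names)会被所有实例共用,这才是问题的根源。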
启动VnTrader,进入策略管理界面,完成如下步骤:
1)从策略下拉框中选择_kx_strategy策略
2)点击添加策略按钮进入3界面
3)输入策略名称、vt_symbol、kx_interval的值,注意kx_interval这里是你想要的K线周期,单位是分钟。
4)添加一个可以显示K线的策略
5)初始化策略,完成后”K线图表“和”启动“两个按钮都是有效的
6)点击K线图表按钮,创建一个空的K线图窗口
7)按启动按钮启动策略
8)再次找到K线图表窗口,你会发现里面已经是下图这个样子了。
说明一下:如果在交易时段,它会不停地更新K线状态和各个附图指标,而且随着tick数据的不断更新,它还可以更新最后一个未结束的当前K线。
这里的用户目录就是你自己windows系统的默认用户目录,如果你没有特别指定的话, [用户目录]\strategies就是你的其他策略文件存放的目录。
该文件的内容如下:
from typing import Any,List,Dict,Tuple
import copy
from vnpy.app.cta_strategy import (
CtaTemplate,
BarGenerator,
ArrayManager,
StopOrder,
Direction
)
from vnpy.trader.engine import MainEngine,EventEngine
from vnpy.app.cta_strategy.engine import CtaEngine
from vnpy.event.engine import Event
from vnpy.trader.object import (
LogData,
TickData,
BarData,
TradeData,
OrderData,
)
from vnpy.app.cta_strategy import (
StopOrder
)
from vnpy.trader.constant import Interval
from vnpy.app.cta_strategy.base import (
APP_NAME,
EVENT_CTA_LOG,
EVENT_CTA_TICK,
EVENT_CTA_HISTORY_BAR,
EVENT_CTA_BAR,
EVENT_CTA_ORDER,
EVENT_CTA_TRADE,
EVENT_CTA_STOPORDER,
EVENT_CTA_STRATEGY,
)
class _kx_strategy(CtaTemplate):
    """
    Demo strategy (earlier version) that builds N-minute bars
    (N = kx_interval) and publishes ticks, bars, orders, trades and stop
    orders on the event engine for a K-line chart.
    """

    author = "hxxjava"

    kx_interval = 1
    parameters = [
        "kx_interval"
    ]

    kx_count: int = 0

    variables = ["kx_count"]

    def __init__(
        self,
        cta_engine: Any,
        strategy_name: str,
        vt_symbol: str,
        setting: dict,
    ):
        """"""
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)
        self.bg = BarGenerator(self.on_bar, self.kx_interval, self.on_Nmin_bar)
        self.am = ArrayManager()

        # The backtesting engine has no main_engine; look the event engine up
        # defensively so the strategy can still be constructed there.
        main_engine = getattr(cta_engine, "main_engine", None)
        self.even_engine = getattr(main_engine, "event_engine", None)

        # Per-instance state: a list declared in the class body would be
        # shared by every strategy created from this class.
        self.all_bars: List[BarData] = []
        self.current_tick: TickData = None   # newest tick received
        self.current_bar: BarData = None     # in-progress (temporary) bar

    def _put(self, event_type: str, data: Any):
        """Publish data under event_type when an event engine is available."""
        if self.even_engine:
            self.even_engine.put(Event(event_type, data))

    def on_init(self):
        """
        Callback when strategy is inited.
        """
        self.write_log("_kx_strategy 初始化")
        self.load_bar(20)

    def on_start(self):
        """ """
        self.write_log("_kx_strategy 已开始")
        if self.all_bars:
            # Send a snapshot; all_bars keeps growing after this call.
            self._put(EVENT_CTA_HISTORY_BAR, list(self.all_bars))

    def on_stop(self):
        """"""
        self.write_log("_kx_strategy 已停止")

    def on_tick(self, tick: TickData):
        """
        Callback of new tick data update.
        """
        self.current_tick = tick  # remember the newest tick

        if self.inited:
            # First update the temporary (unfinished) bar with this tick ...
            cur_bar = self.get_cur_bar(tick)
            if cur_bar:
                # ... and publish it so the chart can redraw the last bar.
                self._put(EVENT_CTA_BAR, cur_bar)

        # Then feed the tick to the generator for 1-minute and N-minute bars.
        self.bg.update_tick(tick)
        self._put(EVENT_CTA_TICK, tick)

    def on_bar(self, bar: BarData):
        """
        Callback of new bar data update.
        """
        if self.inited:
            self.write_log(f"I got a 1min BarData")
        self.bg.update_bar(bar)

    def on_Nmin_bar(self, bar: BarData):
        """
        Callback of a finished N-minute bar.
        """
        self.all_bars.append(bar)
        self.kx_count = len(self.all_bars)

        if self.inited:
            self.write_log(f"I got a {self.kx_interval}min BarData")
            self._put(EVENT_CTA_BAR, bar)
            if self.current_tick:
                # A new N-minute bar has finished: start a new temporary bar.
                self.current_bar = None
                self.get_cur_bar(self.current_tick)

        self.put_event()

    def on_trade(self, trade: TradeData):
        """
        Callback of new trade data update.
        """
        self._put(EVENT_CTA_TRADE, trade)

    def on_order(self, order: OrderData):
        """
        Callback of new order data update.
        """
        self._put(EVENT_CTA_ORDER, order)

    def on_stop_order(self, stop_order: StopOrder):
        """
        Callback of stop order update.
        """
        self._put(EVENT_CTA_STOPORDER, stop_order)

    def get_cur_bar(self, tick: TickData) -> BarData:
        """
        Update the temporary (unfinished) bar with tick and return a copy of it.

        A new temporary bar is started whenever self.current_bar is None.
        Otherwise only its high/low/close, volume and open interest are
        updated. The previous tick is read from the BarGenerator's last_tick
        (self.bg.last_tick).

        Returns None before the strategy is inited, before any previous tick
        exists, or when tick is older than the previous tick.
        """
        last_tick = self.bg.last_tick
        if not self.inited or not last_tick:
            return None

        if tick.datetime < last_tick.datetime:
            return None

        if not self.current_bar:
            # Generate timestamp for bar data
            if self.bg.interval == Interval.MINUTE:
                dt = tick.datetime.replace(second=0, microsecond=0)
            else:
                dt = tick.datetime.replace(minute=0, second=0, microsecond=0)

            self.current_bar = BarData(
                symbol=tick.symbol,
                exchange=tick.exchange,
                datetime=dt,
                gateway_name=tick.gateway_name,
                open_price=tick.last_price,
                high_price=tick.last_price,
                low_price=tick.last_price,
            )
        else:
            # Otherwise, update high/low price into window bar
            self.current_bar.high_price = max(self.current_bar.high_price, tick.last_price)
            self.current_bar.low_price = min(self.current_bar.low_price, tick.last_price)

        # Update last price/volume into window bar
        self.current_bar.close_price = tick.last_price
        self.current_bar.volume += tick.volume - last_tick.volume
        self.current_bar.open_interest = tick.open_interest

        return copy.deepcopy(self.current_bar)
from datetime import datetime
from typing import List, Tuple, Dict
import numpy as np
import pyqtgraph as pg
import talib
import copy
from vnpy.trader.ui import create_qapp, QtCore, QtGui, QtWidgets
from vnpy.trader.database import database_manager
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData
from vnpy.chart import ChartWidget, VolumeItem, CandleItem
from vnpy.chart.item import ChartItem
from vnpy.chart.manager import BarManager
from vnpy.chart.base import NORMAL_FONT
from vnpy.trader.engine import MainEngine
from vnpy.event import Event, EventEngine
from vnpy.trader.event import (
EVENT_TICK,
EVENT_TRADE,
EVENT_ORDER,
EVENT_POSITION,
EVENT_ACCOUNT,
EVENT_LOG
)
from vnpy.app.cta_strategy.base import ( # hxxjava add
EVENT_CTA_TICK,
EVENT_CTA_BAR,
EVENT_CTA_ORDER,
EVENT_CTA_TRADE,
EVENT_CTA_HISTORY_BAR
)
from vnpy.trader.object import TickData,BarData # hxxjava add
class LineItem(CandleItem):
    """Close-price polyline drawn on top of the candle plot."""

    def __init__(self, manager: BarManager):
        """
        :param manager: bar manager shared with the other chart items
        """
        super().__init__(manager)

        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the segment joining the previous bar's close to this bar's close."""
        previous = self._manager.get_bar(ix - 1)

        end = QtCore.QPointF(ix, bar.close_price)
        # The first bar has no predecessor: draw a zero-length segment.
        start = QtCore.QPointF(ix - 1, previous.close_price) if previous else end

        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)
        painter.setPen(self.white_pen)
        painter.drawLine(start, end)
        painter.end()

        return picture
class SmaItem(CandleItem):
    """
    Simple moving average of close prices, drawn as a line on the candle plot.
    """

    def __init__(self, manager: BarManager, sma_window: int = 10):
        """
        :param manager: bar manager shared with the other chart items
        :param sma_window: number of bars averaged (default 10)
        """
        super().__init__(manager)

        self.blue_pen: QtGui.QPen = pg.mkPen(color=(100, 100, 255), width=2)

        self.sma_window = sma_window
        self.sma_data: Dict[int, float] = {}

    def get_sma_value(self, ix: int) -> float:
        """
        Return the SMA value for bar index ix.

        Returns 0 for negative indexes and NaN while fewer than sma_window
        bars are available. Values are cached in self.sma_data. The newest
        bar is always recomputed, because it may be the in-progress
        (temporary) bar whose close price changes with every tick.
        """
        if ix < 0:
            return 0

        # When initialize, calculate all sma values
        if not self.sma_data:
            bars = self._manager.get_all_bars()
            close_data = [bar.close_price for bar in bars]
            sma_array = talib.SMA(np.array(close_data), self.sma_window)
            for n, value in enumerate(sma_array):
                self.sma_data[n] = value

        # Return if already calculated, unless this is the newest bar.
        is_newest = ix >= self._manager.get_count() - 1
        if ix in self.sma_data and not is_newest:
            return self.sma_data[ix]

        # Else calculate new value. Indexes without a bar (fewer than
        # sma_window bars so far) are skipped instead of dereferencing None;
        # talib then yields NaN for the short input.
        close_data = []
        for n in range(ix - self.sma_window, ix + 1):
            bar = self._manager.get_bar(n)
            if bar:
                close_data.append(bar.close_price)

        if close_data:
            sma_value = talib.SMA(np.array(close_data), self.sma_window)[-1]
        else:
            sma_value = np.nan

        self.sma_data[ix] = sma_value
        return sma_value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the SMA segment from bar ix - 1 to bar ix."""
        sma_value = self.get_sma_value(ix)
        last_sma_value = self.get_sma_value(ix - 1)

        # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # Set painter color
        painter.setPen(self.blue_pen)

        # Only join two real averages: NaN means "not enough bars yet", and
        # the 0 placeholder for ix - 1 < 0 would draw a spike down to zero.
        if ix >= 1 and not (np.isnan(last_sma_value) or np.isnan(sma_value)):
            start_point = QtCore.QPointF(ix - 1, last_sma_value)
            end_point = QtCore.QPointF(ix, sma_value)
            painter.drawLine(start_point, end_point)

        # Finish
        painter.end()
        return picture

    def get_info_text(self, ix: int) -> str:
        """Return the cursor info text; "SMA -" when no value is available."""
        sma_value = self.sma_data.get(ix)
        if sma_value is None or np.isnan(sma_value):
            return "SMA -"
        return f"SMA {sma_value:.1f}"
class RsiItem(ChartItem):
""""""
def __init__(self, manager: BarManager):
""""""
super().__init__(manager)
self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=2)
self.rsi_window = 14
self.rsi_data: Dict[int, float] = {}
def get_rsi_value(self, ix: int) -> float:
""""""
if ix < 0:
return 50
# When initialize, calculate all rsi value
if not self.rsi_data:
bars = self._manager.get_all_bars()
close_data = [bar.close_price for bar in bars]
rsi_array = talib.RSI(np.array(close_data), self.rsi_window)
for n, value in enumerate(rsi_array):
self.rsi_data[n] = value
# Return if already calcualted
if ix in self.rsi_data:
return self.rsi_data[ix]
# Else calculate new value
close_data = []
for n in range(ix - self.rsi_window, ix + 1):
bar = self._manager.get_bar(n)
close_data.append(bar.close_price)
rsi_array = talib.RSI(np.array(close_data), self.rsi_window)
rsi_value = rsi_array[-1]
self.rsi_data[ix] = rsi_value
return rsi_value
def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
""""""
rsi_value = self.get_rsi_value(ix)
last_rsi_value = self.get_rsi_value(ix - 1)
# Create objects
picture = QtGui.QPicture()
painter = QtGui.QPainter(picture)
# Draw RSI line
painter.setPen(self.yellow_pen)
if np.isnan(last_rsi_value) or np.isnan(rsi_value):
# print(ix - 1, last_rsi_value,ix, rsi_value,)
pass
else:
end_point = QtCore.QPointF(ix, rsi_value)
start_point = QtCore.QPointF(ix - 1, last_rsi_value)
painter.drawLine(start_point, end_point)
# Draw oversold/overbought line
painter.setPen(self.white_pen)
painter.drawLine(
QtCore.QPointF(ix, 70),
QtCore.QPointF(ix - 1, 70),
)
painter.drawLine(
QtCore.QPointF(ix, 30),
QtCore.QPointF(ix - 1, 30),
)
# Finish
painter.end()
return picture
def boundingRect(self) -> QtCore.QRectF:
""""""
# min_price, max_price = self._manager.get_price_range()
rect = QtCore.QRectF(
0,
0,
len(self._bar_picutures),
100
)
return rect
def get_y_range( self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
""" """
return 0, 100
def get_info_text(self, ix: int) -> str:
    """
    Return the cursor info text for bar ``ix``, e.g. ``"RSI 56.3"``.

    Shows ``"RSI -"`` when no value has been calculated for ``ix`` or when
    the value is NaN (the warm-up bars at the start have no RSI).
    """
    rsi_value = self.rsi_data.get(ix)
    if rsi_value is None or np.isnan(rsi_value):
        return "RSI -"
    return f"RSI {rsi_value:.1f}"
def to_int(value: float) -> int:
    """
    Round ``value`` to the nearest whole number and return it as an int.

    Python's round() rounds halves to even, so 2.5 -> 2 and 3.5 -> 4.
    """
    return int(round(value, 0))


def adjust_range(in_range: Tuple[float, float], ratio: float = 0.05) -> Tuple[float, float]:
    """
    Widen a (low, high) y-axis range by ``ratio`` of its span on each side.

    With the default ratio of 0.05 the displayed span becomes 1.1 times the
    original span, leaving a small margin above and below the plotted data.

    Args:
        in_range: the (low, high) range to widen.
        ratio: fraction of the span added below ``low`` and above ``high``.

    Returns:
        The widened (low, high) range.
    """
    low, high = in_range
    margin = abs(low - high) * ratio
    return low - margin, high + margin
class MacdItem(ChartItem):
    """
    Chart item drawing the MACD indicator: the DIFF line (white), the DEA
    line (yellow) and the MACD histogram (red above zero, green otherwise).
    """

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
        self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=1)
        self.red_pen: QtGui.QPen = pg.mkPen(color=(255, 0, 0), width=1)
        self.green_pen: QtGui.QPen = pg.mkPen(color=(0, 255, 0), width=1)
        # Histogram brushes, created once instead of for every bar drawn
        self.red_brush: QtGui.QBrush = pg.mkBrush(255, 0, 0)
        self.green_brush: QtGui.QBrush = pg.mkBrush(0, 255, 0)

        self.short_window = 12
        self.long_window = 26
        self.M = 9

        # Bar index -> (diff, dea, macd)
        self.macd_data: Dict[int, Tuple[float, float, float]] = {}

        # Cached y-axis ranges keyed by (min_ix, max_ix). These must be
        # instance attributes: as class attributes the dict was shared by all
        # MacdItem objects, so charts of different contracts mixed up ranges.
        self._values_ranges: Dict[Tuple[int, int], Tuple[float, float]] = {}
        self.last_range: Tuple[int, int] = (-1, -1)  # latest displayed bar index range

    def _warmup_count(self) -> int:
        """Number of bars needed before talib's MACD yields its first value."""
        return max(self.short_window, self.long_window) + self.M - 1

    def _calculate_macd(self, close_data: List[float]):
        """Run talib.MACD over close_data; returns (diffs, deas, macds) arrays."""
        return talib.MACD(
            np.array(close_data, dtype=float),
            fastperiod=self.short_window,
            slowperiod=self.long_window,
            signalperiod=self.M,
        )

    def update_history(self, history: List[BarData]):
        """
        Redraw from a new history. Cached values are keyed by bar index and
        belong to the previous bar set, so they are discarded first.
        """
        self.macd_data.clear()
        self._values_ranges.clear()
        self.last_range = (-1, -1)
        return super().update_history(history)

    def get_macd_value(self, ix: int) -> Tuple[float, float, float]:
        """
        Return (diff, dea, macd) for bar ``ix``, computing it on demand.

        On the first call every bar is calculated in one batch. Later calls
        use the cache, except for the newest bar: it may be the temporary
        (still updating) bar, so it is always recalculated. Returns
        (0.0, 0.0, 0.0) for negative indexes; the values are NaN while there
        is not enough history.
        """
        if ix < 0:
            return (0.0, 0.0, 0.0)

        # When initialize, calculate all macd value
        if not self.macd_data:
            close_data = [bar.close_price for bar in self._manager.get_all_bars()]
            if close_data:
                diffs, deas, macds = self._calculate_macd(close_data)
                for n, value in enumerate(zip(diffs, deas, macds)):
                    self.macd_data[n] = value

        # The newest bar may be updated in place, so its cached value can be
        # stale; values of all older bars are final.
        is_last_bar = ix >= self._manager.get_count() - 1
        if ix in self.macd_data and not is_last_bar:
            return self.macd_data[ix]

        # Recalculate over every bar up to ix so the result agrees with the
        # batch values (MACD's moving averages depend on the whole history).
        # This also avoids asking the manager for bars at negative indexes.
        bars = self._manager.get_all_bars()[: ix + 1]
        if len(bars) < self._warmup_count():
            value = (np.nan, np.nan, np.nan)
        else:
            diffs, deas, macds = self._calculate_macd([bar.close_price for bar in bars])
            value = (diffs[-1], deas[-1], macds[-1])

        self.macd_data[ix] = value
        # A new or changed value makes the cached y-axis ranges stale.
        self._values_ranges.clear()
        return value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the DIFF/DEA segments ending at bar ``ix`` and its histogram bar."""
        diff, dea, macd = self.get_macd_value(ix)
        last_diff, last_dea, _ = self.get_macd_value(ix - 1)

        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # DIFF and DEA lines; a segment is skipped while either end is NaN.
        for pen, value, last_value in (
            (self.white_pen, diff, last_diff),
            (self.yellow_pen, dea, last_dea),
        ):
            if not (np.isnan(value) or np.isnan(last_value)):
                painter.setPen(pen)
                painter.drawLine(
                    QtCore.QPointF(ix - 1, last_value),
                    QtCore.QPointF(ix, value),
                )

        # MACD histogram bar, red above zero and green otherwise
        if not np.isnan(macd):
            if macd > 0:
                painter.setPen(self.red_pen)
                painter.setBrush(self.red_brush)
            else:
                painter.setPen(self.green_pen)
                painter.setBrush(self.green_brush)
            painter.drawRect(QtCore.QRectF(ix - 0.3, 0, 0.6, macd))

        painter.end()
        return picture

    def boundingRect(self) -> QtCore.QRectF:
        """Return the rectangle covering every bar and the indicator's y-range."""
        min_y, max_y = self.get_y_range()
        # QRectF takes (x, y, width, height): the height is the span of the
        # range, not its upper bound.
        return QtCore.QRectF(0, min_y, len(self._bar_picutures), max_y - min_y)

    def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        """
        Return the padded y-axis range covering DIFF, DEA and MACD.

        Args:
            min_ix: first bar index of the displayed window, or None.
            max_ix: last bar index of the displayed window, or None.

        When both are None, the most recently calculated window is reused,
        or the whole data set if no window has been calculated yet. Ranges
        are cached per window. Returns (0.0, 1.0) while there are not enough
        bars for MACD, or when the window holds no valid value.
        """
        # Modified by hxxjava, 2020-6-29
        offset = self._warmup_count()
        last_ix = len(self.macd_data) - 1
        if last_ix + 1 < offset:
            return 0.0, 1.0

        if min_ix is None and max_ix is None:
            # Display window unchanged
            if self.last_range in self._values_ranges:
                return adjust_range(self._values_ranges[self.last_range])
            window = (0, last_ix)
        else:
            # Display window changed: skip the warm-up bars (no MACD values)
            # and clamp to the bars that have been calculated.
            start = offset if min_ix is None else max(min_ix, offset)
            end = last_ix if max_ix is None else min(max_ix, last_ix)
            window = (start, end)
            if window in self._values_ranges:
                self.last_range = window
                return adjust_range(self._values_ranges[window])

        rows = [
            self.macd_data[ix]
            for ix in range(window[0], window[1] + 1)
            if ix in self.macd_data
        ]
        values = np.array(rows, dtype=float)
        if values.size == 0 or np.isnan(values).all():
            # Empty window (e.g. only warm-up bars visible) or no valid value:
            # np.nanmin/np.nanmax would raise or return NaN here.
            return 0.0, 1.0

        result = (float(np.nanmin(values)), float(np.nanmax(values)))
        self.last_range = window
        self._values_ranges[window] = result
        return adjust_range(result)

    def get_info_text(self, ix: int) -> str:
        """Return the cursor info text (diff, dea, macd) for bar ``ix``."""
        values = self.macd_data.get(ix)
        if values is None:
            return "diff - \ndea - \nmacd -"

        def fmt(value: float) -> str:
            # NaN (warm-up bars) is shown as "-"
            return "-" if np.isnan(value) else f"{value:.3f}"

        diff, dea, macd = values
        words = [
            f"diff {fmt(diff)}", " ",
            f"dea {fmt(dea)}", " ",
            f"macd {fmt(macd)}",
        ]
        return "\n".join(words)
class NewChartWidget(ChartWidget):
    """
    ChartWidget fed by CTA engine events (historical bars, live bars and
    ticks), with an optional horizontal line marking the latest price.
    """
    MIN_BAR_COUNT = 100

    signal_cta_history_bar: QtCore.pyqtSignal = QtCore.pyqtSignal(Event)
    signal_cta_tick: QtCore.pyqtSignal = QtCore.pyqtSignal(Event)
    signal_cta_bar: QtCore.pyqtSignal = QtCore.pyqtSignal(Event)

    def __init__(
        self,
        main_engine: MainEngine = None,
        event_engine: EventEngine = None,
        vt_symbol: str = "",
    ):
        """
        Args:
            main_engine: accepted for interface compatibility; not used.
            event_engine: engine delivering the CTA events; required by
                register_event() and unregister_event().
            vt_symbol: when non-empty, only bars/ticks of this symbol are
                shown and events of other symbols are ignored. Empty (the
                default) accepts every event.
        """
        super().__init__()
        self.event_engine = event_engine
        self.vt_symbol = vt_symbol
        self.last_price_line: pg.InfiniteLine = None
        # (event type, handler) pairs given to the event engine, kept so that
        # unregister_event() removes exactly the same handler objects.
        self._registered_handlers: List[tuple] = []

    def register_event(self) -> None:
        """Connect the Qt signals and register them with the event engine."""
        if self._registered_handlers:
            # Already registered: a second registration would deliver every
            # event twice.
            return

        self.signal_cta_history_bar.connect(self.process_cta_history_bar)
        self.signal_cta_tick.connect(self.process_tick_event)
        self.signal_cta_bar.connect(self.process_cta_bar)

        self._registered_handlers = [
            (EVENT_CTA_HISTORY_BAR, self.signal_cta_history_bar.emit),
            (EVENT_CTA_TICK, self.signal_cta_tick.emit),
            (EVENT_CTA_BAR, self.signal_cta_bar.emit),
        ]
        for event_type, handler in self._registered_handlers:
            self.event_engine.register(event_type, handler)

    def unregister_event(self) -> None:
        """Undo register_event(): unregister the handlers and disconnect the signals."""
        if not self._registered_handlers:
            return
        for event_type, handler in self._registered_handlers:
            self.event_engine.unregister(event_type, handler)
        self._registered_handlers = []

        self.signal_cta_history_bar.disconnect(self.process_cta_history_bar)
        self.signal_cta_tick.disconnect(self.process_tick_event)
        self.signal_cta_bar.disconnect(self.process_cta_bar)

    def _accepts(self, vt_symbol: str) -> bool:
        """Return True if data of vt_symbol belongs on this chart."""
        return not self.vt_symbol or vt_symbol == self.vt_symbol

    def process_cta_history_bar(self, event: Event) -> None:
        """Handle a pushed list of historical bars."""
        history_bars: List[BarData] = event.data
        if not history_bars or not self._accepts(history_bars[0].vt_symbol):
            return
        self.update_history(history_bars)

    def process_tick_event(self, event: Event) -> None:
        """Handle a pushed tick: move the last price line."""
        tick: TickData = event.data
        if self.last_price_line and self._accepts(tick.vt_symbol):
            self.last_price_line.setValue(tick.last_price)

    def process_cta_bar(self, event: Event) -> None:
        """Handle a pushed bar (new or updated)."""
        bar: BarData = event.data
        if self._accepts(bar.vt_symbol):
            self.update_bar(bar)

    def add_last_price_line(self):
        """Add a horizontal line with a value label to the first plot."""
        plot = list(self._plots.values())[0]
        color = (255, 255, 255)

        self.last_price_line = pg.InfiniteLine(
            angle=0,
            movable=False,
            label="{value:.1f}",
            pen=pg.mkPen(color, width=1),
            labelOpts={
                "color": color,
                "position": 1,
                "anchors": [(1, 1), (1, 1)]
            }
        )
        self.last_price_line.label.setFont(NORMAL_FONT)
        plot.addItem(self.last_price_line)

    def update_history(self, history: List[BarData]) -> None:
        """
        Update a list of bar data.

        An empty list is ignored (there would be no bar to place the last
        price line on).
        """
        if not history:
            return

        self._manager.update_history(history)

        for item in self._items.values():
            item.update_history(history)

        self._update_plot_limits()

        self.move_to_right()

        self.update_last_price_line(history[-1])

    def update_bar(self, bar: BarData) -> None:
        """
        Update single bar data.

        Scrolls to the newest bar only if the view is already near the right edge.
        """
        self._manager.update_bar(bar)

        for item in self._items.values():
            item.update_bar(bar)

        self._update_plot_limits()

        if self._right_ix >= (self._manager.get_count() - self._bar_count / 2):
            self.move_to_right()

        self.update_last_price_line(bar)

    def update_last_price_line(self, bar: BarData) -> None:
        """Move the last price line, if any, to the bar's close price."""
        if self.last_price_line:
            self.last_price_line.setValue(bar.close_price)
if __name__ == "__main__":
    # Demo: show stored bars of one contract, optionally replaying them.
    from collections import deque

    app = create_qapp()

    symbol = "rb2010"
    exchange = Exchange.SHFE
    interval = Interval.MINUTE
    start = datetime(2020, 6, 1)
    end = datetime(2020, 6, 30)

    dynamic = False  # replay the bars one by one instead of showing them at once
    n = 200  # number of bars loaded as the initial history

    bars = database_manager.load_bar_data(
        symbol=symbol,
        exchange=exchange,
        interval=interval,
        start=start,
        end=end
    )
    if not bars:
        # The chart needs at least one bar to display
        raise SystemExit(f"No bar data for {symbol}.{exchange.value} from {start} to {end}")

    event_engine = EventEngine()
    widget = NewChartWidget(event_engine=event_engine)
    widget.setWindowTitle(f"K线图表——{symbol}.{exchange.value},{interval},{start}-{end}")

    widget.add_plot("candle", hide_x_axis=True)
    widget.add_plot("volume", maximum_height=150)
    widget.add_plot("rsi", maximum_height=150)
    widget.add_plot("macd", maximum_height=150)
    widget.add_item(CandleItem, "candle", "candle")
    widget.add_item(VolumeItem, "volume", "volume")
    widget.add_item(LineItem, "line", "candle")
    widget.add_item(SmaItem, "sma", "candle")
    widget.add_item(RsiItem, "rsi", "rsi")
    widget.add_item(MacdItem, "macd", "macd")
    widget.add_last_price_line()
    widget.add_cursor()

    if dynamic:
        history = bars[:n]  # the earliest n bars form the history
        new_data = deque(bars[n:])  # the rest are replayed by the timer
    else:
        history = bars[-n:]  # the latest n bars form the history
        new_data = deque()  # nothing to replay

    widget.update_history(history)

    def update_bar():
        """Feed the next pending bar into the chart."""
        if new_data:
            widget.update_bar(new_data.popleft())

    timer = QtCore.QTimer()
    timer.timeout.connect(update_bar)
    if dynamic:
        timer.start(100)

    widget.show()
    app.exec_()
vnpy\app\cta_strategy\base.py
vnpy\app\cta_strategy\ui\widget.py
内容如下:
"""
Defines constants and objects used in CtaStrategy App.
"""
from dataclasses import dataclass, field
from enum import Enum
from datetime import timedelta
from vnpy.trader.constant import Direction, Offset, Interval
# Name of the CTA strategy app; the UI fetches the CTA engine with
# main_engine.get_engine(APP_NAME).
APP_NAME = "CtaStrategy"
# NOTE(review): presumably the prefix of local stop order ids (stop_orderid) — confirm in the engine
STOPORDER_PREFIX = "STOP"
class StopOrderStatus(Enum):
    # Status of a local stop order (StopOrder.status); the values are the
    # Chinese display texts.
    WAITING = "等待中"  # waiting
    CANCELLED = "已撤销"  # cancelled
    TRIGGERED = "已触发"  # triggered
class EngineType(Enum):
    # Kind of engine a strategy runs in; values are the Chinese display texts.
    LIVE = "实盘"  # live trading
    BACKTESTING = "回测"  # backtesting
class BacktestingMode(Enum):
    # NOTE(review): presumably selects whether a backtest replays bars or
    # ticks — confirm in the backtesting engine.
    BAR = 1
    TICK = 2
@dataclass
class StopOrder:
    # A local stop order; the field order defines the generated __init__.
    vt_symbol: str
    direction: Direction
    offset: Offset
    price: float
    volume: float
    stop_orderid: str
    strategy_name: str
    lock: bool = False  # lock-position flag (column "锁仓" in StopOrderMonitor)
    vt_orderids: list = field(default_factory=list)  # ids of the limit orders sent for it
    status: StopOrderStatus = StopOrderStatus.WAITING
# Event types published by the CTA app
EVENT_CTA_LOG = "eCtaLog"
EVENT_CTA_STRATEGY = "eCtaStrategy"
EVENT_CTA_STOPORDER = "eCtaStopOrder"

# Event types added by hxxjava to feed strategy data to the K-line chart
EVENT_CTA_TICK = "eCtaTick" # hxxjava add
EVENT_CTA_HISTORY_BAR = "eCtaHistoryBar" # hxxjava add
EVENT_CTA_BAR = "eCtaBar" # hxxjava add
EVENT_CTA_ORDER = "eCtaOrder" # hxxjava add
EVENT_CTA_TRADE = "eCtaTrade" # hxxjava add

# Duration of one bar for each supported interval
INTERVAL_DELTA_MAP = {
    Interval.MINUTE: timedelta(minutes=1),
    Interval.HOUR: timedelta(hours=1),
    Interval.DAILY: timedelta(days=1),
}
修改如下:
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtCore, QtGui, QtWidgets
from vnpy.trader.ui.widget import (
BaseCell,
EnumCell,
MsgCell,
TimeCell,
BaseMonitor
)
from ..base import (
APP_NAME,
EVENT_CTA_LOG,
EVENT_CTA_STOPORDER,
EVENT_CTA_STRATEGY,
)
from ..engine import CtaEngine
from vnpy.usertools.kx_chart import ( # hxxjava add
NewChartWidget,
CandleItem,
VolumeItem,
LineItem,
SmaItem,
RsiItem,
MacdItem,
)
class CtaManager(QtWidgets.QWidget):
    """
    Main window of the CTA strategy app: strategy class selector, global
    init/start/stop buttons, one StrategyManager panel per strategy, the
    stop order monitor and the log monitor.
    """

    signal_log = QtCore.pyqtSignal(Event)
    signal_strategy = QtCore.pyqtSignal(Event)

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        super(CtaManager, self).__init__()

        self.main_engine = main_engine
        self.event_engine = event_engine
        self.cta_engine = main_engine.get_engine(APP_NAME)

        # strategy_name -> StrategyManager panel
        self.managers = {}

        self.init_ui()
        self.register_event()
        self.cta_engine.init_engine()
        self.update_class_combo()

    def init_ui(self):
        """Create the widgets and lay them out."""
        self.setWindowTitle("CTA策略")

        # Create widgets
        self.class_combo = QtWidgets.QComboBox()

        add_button = QtWidgets.QPushButton("添加策略")
        add_button.clicked.connect(self.add_strategy)

        init_button = QtWidgets.QPushButton("全部初始化")
        init_button.clicked.connect(self.cta_engine.init_all_strategies)

        start_button = QtWidgets.QPushButton("全部启动")
        start_button.clicked.connect(self.cta_engine.start_all_strategies)

        stop_button = QtWidgets.QPushButton("全部停止")
        stop_button.clicked.connect(self.cta_engine.stop_all_strategies)

        clear_button = QtWidgets.QPushButton("清空日志")
        clear_button.clicked.connect(self.clear_log)

        # Strategy panels are inserted above the stretch in a scroll area
        self.scroll_layout = QtWidgets.QVBoxLayout()
        self.scroll_layout.addStretch()

        scroll_widget = QtWidgets.QWidget()
        scroll_widget.setLayout(self.scroll_layout)

        scroll_area = QtWidgets.QScrollArea()
        scroll_area.setWidgetResizable(True)
        scroll_area.setWidget(scroll_widget)

        self.log_monitor = LogMonitor(self.main_engine, self.event_engine)

        self.stop_order_monitor = StopOrderMonitor(
            self.main_engine, self.event_engine
        )

        # Set layout
        hbox1 = QtWidgets.QHBoxLayout()
        hbox1.addWidget(self.class_combo)
        hbox1.addWidget(add_button)
        hbox1.addStretch()
        hbox1.addWidget(init_button)
        hbox1.addWidget(start_button)
        hbox1.addWidget(stop_button)
        hbox1.addWidget(clear_button)

        grid = QtWidgets.QGridLayout()
        grid.addWidget(scroll_area, 0, 0, 2, 1)
        grid.addWidget(self.stop_order_monitor, 0, 1)
        grid.addWidget(self.log_monitor, 1, 1)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addLayout(hbox1)
        vbox.addLayout(grid)

        self.setLayout(vbox)

    def update_class_combo(self):
        """Fill the class selector with the engine's strategy class names."""
        self.class_combo.addItems(
            self.cta_engine.get_all_strategy_class_names()
        )

    def register_event(self):
        """Route EVENT_CTA_STRATEGY to the GUI thread via a Qt signal."""
        self.signal_strategy.connect(self.process_strategy_event)

        self.event_engine.register(
            EVENT_CTA_STRATEGY, self.signal_strategy.emit
        )

    def process_strategy_event(self, event):
        """
        Update strategy status onto its monitor, creating the panel the
        first time a strategy is seen.
        """
        data = event.data
        strategy_name = data["strategy_name"]

        if strategy_name in self.managers:
            manager = self.managers[strategy_name]
            manager.update_data(data)
        else:
            manager = StrategyManager(self, self.cta_engine, data)
            self.scroll_layout.insertWidget(0, manager)
            self.managers[strategy_name] = manager

    def remove_strategy(self, strategy_name):
        """Delete the panel of strategy_name; unknown names are ignored."""
        manager = self.managers.pop(strategy_name, None)
        if manager is not None:
            manager.deleteLater()

    def add_strategy(self):
        """Ask for the settings of a new strategy and add it to the engine."""
        class_name = str(self.class_combo.currentText())
        if not class_name:
            return

        parameters = self.cta_engine.get_strategy_class_parameters(class_name)
        editor = SettingEditor(parameters, class_name=class_name)
        n = editor.exec_()

        if n == editor.Accepted:
            setting = editor.get_setting()
            vt_symbol = setting.pop("vt_symbol")
            strategy_name = setting.pop("strategy_name")

            self.cta_engine.add_strategy(
                class_name, strategy_name, vt_symbol, setting
            )

    def clear_log(self):
        """Remove every row of the log monitor."""
        self.log_monitor.setRowCount(0)

    def show(self):
        """Show the window maximized."""
        self.showMaximized()
class StrategyManager(QtWidgets.QFrame):
    """
    Manager for a strategy: shows its parameters and variables, and offers
    buttons to initialize it, open its K-line chart, start, stop, edit and
    remove it.
    """

    def __init__(
        self, cta_manager: CtaManager, cta_engine: CtaEngine, data: dict
    ):
        """"""
        super(StrategyManager, self).__init__()

        self.cta_manager = cta_manager
        self.cta_engine = cta_engine

        self.strategy_name = data["strategy_name"]
        self._data = data

        # K-line chart window of this strategy, created on first use.
        # Kept per instance so that each strategy has its own chart.
        self.kx_chart: NewChartWidget = None  # hxxjava add

        self.init_ui()

    def init_ui(self):
        """Create the buttons, the title label and the data monitors."""
        self.setFixedHeight(300)
        self.setFrameShape(self.Box)
        self.setLineWidth(1)

        self.init_button = QtWidgets.QPushButton("初始化")
        self.init_button.clicked.connect(self.init_strategy)

        self.start_button = QtWidgets.QPushButton("启动")
        self.start_button.clicked.connect(self.start_strategy)
        self.start_button.setEnabled(False)

        # hxxjava add start
        self.kx_button = QtWidgets.QPushButton("K线图表")
        self.kx_button.clicked.connect(self.open_kx_chart)
        self.kx_button.setEnabled(False)
        # hxxjava add end

        self.stop_button = QtWidgets.QPushButton("停止")
        self.stop_button.clicked.connect(self.stop_strategy)
        self.stop_button.setEnabled(False)

        self.edit_button = QtWidgets.QPushButton("编辑")
        self.edit_button.clicked.connect(self.edit_strategy)

        self.remove_button = QtWidgets.QPushButton("移除")
        self.remove_button.clicked.connect(self.remove_strategy)

        strategy_name = self._data["strategy_name"]
        vt_symbol = self._data["vt_symbol"]
        class_name = self._data["class_name"]
        author = self._data["author"]

        label_text = (
            f"{strategy_name} - {vt_symbol} ({class_name} by {author})"
        )
        label = QtWidgets.QLabel(label_text)
        label.setAlignment(QtCore.Qt.AlignCenter)

        self.parameters_monitor = DataMonitor(self._data["parameters"])
        self.variables_monitor = DataMonitor(self._data["variables"])

        hbox = QtWidgets.QHBoxLayout()
        hbox.addWidget(self.init_button)
        hbox.addWidget(self.kx_button)  # hxxjava add
        hbox.addWidget(self.start_button)
        hbox.addWidget(self.stop_button)
        hbox.addWidget(self.edit_button)
        hbox.addWidget(self.remove_button)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(label)
        vbox.addLayout(hbox)
        vbox.addWidget(self.parameters_monitor)
        vbox.addWidget(self.variables_monitor)
        self.setLayout(vbox)

    def update_data(self, data: dict):
        """Show new parameter/variable values and update the button states."""
        self._data = data

        self.parameters_monitor.update_data(data["parameters"])
        self.variables_monitor.update_data(data["variables"])

        # Update button status
        variables = data["variables"]
        inited = variables["inited"]
        trading = variables["trading"]

        if not inited:
            return
        self.init_button.setEnabled(False)

        if trading:
            self.kx_button.setEnabled(False)  # hxxjava add
            self.start_button.setEnabled(False)
            self.stop_button.setEnabled(True)
            self.edit_button.setEnabled(False)
            self.remove_button.setEnabled(False)
        else:
            self.kx_button.setEnabled(True)  # hxxjava add
            self.start_button.setEnabled(True)
            self.stop_button.setEnabled(False)
            self.edit_button.setEnabled(True)
            self.remove_button.setEnabled(True)

    def init_strategy(self):
        """"""
        self.cta_engine.init_strategy(self.strategy_name)

    def start_strategy(self):
        """"""
        self.cta_engine.start_strategy(self.strategy_name)

    def stop_strategy(self):
        """"""
        self.cta_engine.stop_strategy(self.strategy_name)

    def edit_strategy(self):
        """Open the parameter editor and apply the accepted settings."""
        strategy_name = self._data["strategy_name"]

        parameters = self.cta_engine.get_strategy_parameters(strategy_name)
        editor = SettingEditor(parameters, strategy_name=strategy_name)
        n = editor.exec_()

        if n == editor.Accepted:
            setting = editor.get_setting()
            self.cta_engine.edit_strategy(strategy_name, setting)

    def remove_strategy(self):
        """Remove the strategy from the engine, then its chart window and panel."""
        result = self.cta_engine.remove_strategy(self.strategy_name)

        # Only remove strategy gui manager if it has been removed from engine
        if result:
            if self.kx_chart is not None:
                # The chart is a separate top-level window (created without a
                # parent), so it must be closed explicitly.
                # NOTE(review): its event handlers stay registered on the
                # event engine; the window is only closed.
                self.kx_chart.close()
            self.cta_manager.remove_strategy(self.strategy_name)

    def open_kx_chart(self):  # hxx add
        """Create (on first call) and show this strategy's K-line chart window."""
        if self.kx_chart is None:
            event_engine = self.cta_engine.event_engine
            # Pass the engine by keyword: NewChartWidget's first positional
            # parameter is main_engine, not a parent widget.
            self.kx_chart = NewChartWidget(event_engine=event_engine)
            self.kx_chart.setWindowTitle(f"K线图表——{self.strategy_name}")

            self.kx_chart.add_plot("candle", hide_x_axis=True)
            self.kx_chart.add_plot("volume", maximum_height=150)
            self.kx_chart.add_plot("rsi", maximum_height=150)
            self.kx_chart.add_plot("macd", maximum_height=150)

            self.kx_chart.add_item(CandleItem, "candle", "candle")
            self.kx_chart.add_item(VolumeItem, "volume", "volume")
            self.kx_chart.add_item(LineItem, "line", "candle")
            self.kx_chart.add_item(SmaItem, "sma", "candle")
            self.kx_chart.add_item(RsiItem, "rsi", "rsi")
            self.kx_chart.add_item(MacdItem, "macd", "macd")

            self.kx_chart.add_last_price_line()
            self.kx_chart.add_cursor()
            self.kx_chart.register_event()

        self.kx_chart.show()
class DataMonitor(QtWidgets.QTableWidget):
    """
    Single-row, read-only table showing a dict (strategy parameters or
    variables), one column per key.
    """

    def __init__(self, data: dict):
        """"""
        super().__init__()

        self._data = data
        self.cells = {}

        self.init_ui()

    def init_ui(self):
        """Build the header from the dict keys and fill the single row."""
        names = list(self._data.keys())
        self.setColumnCount(len(names))
        self.setHorizontalHeaderLabels(names)

        self.setRowCount(1)
        row_header = self.verticalHeader()
        row_header.setSectionResizeMode(QtWidgets.QHeaderView.Stretch)
        row_header.setVisible(False)
        self.setEditTriggers(self.NoEditTriggers)

        for column, name in enumerate(names):
            item = QtWidgets.QTableWidgetItem(str(self._data[name]))
            item.setTextAlignment(QtCore.Qt.AlignCenter)
            self.setItem(0, column, item)
            self.cells[name] = item

    def update_data(self, data: dict):
        """Refresh the text of the cell of every key in data."""
        for name in data:
            self.cells[name].setText(str(data[name]))
class StopOrderMonitor(BaseMonitor):
    """
    Monitor for local stop order.

    Table filled from EVENT_CTA_STOPORDER events; the columns (in order)
    come from ``headers``.
    """

    event_type = EVENT_CTA_STOPORDER
    # NOTE(review): presumably the attribute BaseMonitor uses to find the
    # existing row of a stop order — confirm in BaseMonitor.
    data_key = "stop_orderid"
    sorting = True

    # Column definitions: data attribute -> display title, cell class and an
    # "update" flag (presumably whether the cell is refreshed on later events
    # for the same row — see BaseMonitor).
    headers = {
        "stop_orderid": {"display": "停止委托号","cell": BaseCell,"update": False,},
        "vt_orderids": {"display": "限价委托号", "cell": BaseCell, "update": True},
        "vt_symbol": {"display": "本地代码", "cell": BaseCell, "update": False},
        "direction": {"display": "方向", "cell": EnumCell, "update": False},
        "offset": {"display": "开平", "cell": EnumCell, "update": False},
        "price": {"display": "价格", "cell": BaseCell, "update": False},
        "volume": {"display": "数量", "cell": BaseCell, "update": False},
        "status": {"display": "状态", "cell": EnumCell, "update": True},
        "lock": {"display": "锁仓", "cell": BaseCell, "update": False},
        "strategy_name": {"display": "策略名", "cell": BaseCell, "update": False},
    }
class LogMonitor(BaseMonitor):
    """
    Monitor for log data: a (time, message) table fed by EVENT_CTA_LOG,
    with new rows inserted at the top.
    """

    event_type = EVENT_CTA_LOG
    data_key = ""
    sorting = False

    headers = {
        "time": {"display": "时间", "cell": TimeCell, "update": False},
        "msg": {"display": "信息", "cell": MsgCell, "update": False},
    }

    def init_ui(self):
        """Build the table, then let the message column take the remaining width."""
        super().init_ui()
        message_column = 1
        self.horizontalHeader().setSectionResizeMode(
            message_column, QtWidgets.QHeaderView.Stretch
        )

    def insert_new_row(self, data):
        """Insert data as the new top row and fit that row's height to its text."""
        super().insert_new_row(data)
        top_row = 0
        self.resizeRowToContents(top_row)
class SettingEditor(QtWidgets.QDialog):
    """
    For creating new strategy and editing strategy parameters.

    One line edit per parameter; get_setting() converts the texts back to
    the types of the original parameter values.
    """

    def __init__(
        self, parameters: dict, strategy_name: str = "", class_name: str = ""
    ):
        """
        Args:
            parameters: parameter name -> current/default value.
            strategy_name: set when editing an existing strategy.
            class_name: set when adding a new strategy; also adds the
                strategy_name and vt_symbol fields.
        """
        super(SettingEditor, self).__init__()

        self.parameters = parameters
        self.strategy_name = strategy_name
        self.class_name = class_name

        # name -> (line edit, type of the original value)
        self.edits = {}

        self.init_ui()

    def init_ui(self):
        """Build one labelled line edit per parameter plus the confirm button."""
        form = QtWidgets.QFormLayout()

        # Add vt_symbol and name edit if add new strategy
        if self.class_name:
            self.setWindowTitle(f"添加策略:{self.class_name}")
            button_text = "添加"
            parameters = {"strategy_name": "", "vt_symbol": ""}
            parameters.update(self.parameters)
        else:
            self.setWindowTitle(f"参数编辑:{self.strategy_name}")
            button_text = "确定"
            parameters = self.parameters

        for name, value in parameters.items():
            type_ = type(value)

            edit = QtWidgets.QLineEdit(str(value))
            if type_ is int:
                validator = QtGui.QIntValidator()
                edit.setValidator(validator)
            elif type_ is float:
                validator = QtGui.QDoubleValidator()
                edit.setValidator(validator)

            form.addRow(f"{name} {type_}", edit)

            self.edits[name] = (edit, type_)

        button = QtWidgets.QPushButton(button_text)
        button.clicked.connect(self.accept)
        form.addRow(button)

        widget = QtWidgets.QWidget()
        widget.setLayout(form)

        scroll = QtWidgets.QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setWidget(widget)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(scroll)
        self.setLayout(vbox)

    def get_setting(self):
        """Return the edited values converted to their original types."""
        setting = {}

        if self.class_name:
            setting["class_name"] = self.class_name

        for name, tp in self.edits.items():
            edit, type_ = tp
            value_text = edit.text()

            if type_ == bool:
                # bool("False") would be True, so parse the text explicitly.
                # Any capitalisation ("true", "TRUE") and surrounding spaces
                # are accepted; every other text means False.
                value = value_text.strip().lower() == "true"
            else:
                value = type_(value_text)

            setting[name] = value

        return setting
哦,知道了。我明白怎么解决了。
可以利用每次登录后重传的今日所有委托和成交数据,自己保存一份到磁盘,然后再把实时的委托和成交也保存下来,这样本地就有了完整的委托和成交记录,再查询计算就不难了。这个想法靠谱吗?
VNPY系统确实是非常好的开发平台!可是做CTA交易的朋友一定有这样的感觉:策略在安静地运行,你却看不见它的K线,也无法直观地看到它的交易情况,只好经常在第三方的行情客户端之间来回切换,为的是看看策略运行得怎么样。如果能够为每个CTA策略都提供一个观看图表的机会,那该有多好!
本人经过一周时间的努力,已经基本让CTA策略可以看得见了。
1 实现一个基于ChartWidget的K线图NewChartWidget,它可以接受历史K线数据、tick数据、bar数据、order数据、trade数据的推送
2 定义五个与行情和交易相关的消息类型
为K线图表添砖加瓦——让CTA策略看得见(1)
EVENT_CTA_HISTORY_BAR —— 历史K线消息
EVENT_CTA_TICK —— TICK消息
EVENT_CTA_BAR —— BAR消息
EVENT_CTA_ORDER —— 委托单消息
EVENT_CTA_TRADE —— 成交单
EVENT_CTA_STOPORDER —— 停止单消息(系统本来就有)
3 利用策略的on_start(),on_tick(),on_bar(),on_order(),on_stop_order(),on_trade()接口,发送这些消息
4 可以显示实时的临时K线
5 可以显示行情,也可以显示委托、成交的发生位置
经过修改后,在CTA策略管理界面中(如上图所示),每个策略都会增加一个“K线图表”按钮。新建策略后,只有在策略初始化之后,“K线图表”按钮和“启动”按钮才会同时变为可用。如果你想观看该策略的K线图,先点击“K线图表”按钮,就会显示一个空的K线图窗口;再和原来一样点击“启动”按钮,K线图会立即把策略在初始化阶段读取的历史K线显示出来,你所关心的主图指标和附图指标当然也会一并显示。
目前只实现了几个有代表性的指标,还不能自由配置。未来应该设计成每个策略可以独立配置自己K线图表的主图和附图的数量及指标,因为不同策略的算法不一样,周期也可能不一样。
在vnpy目录下创建一个自己的目录,例如user_tools,与api、app、ctp、trader等目录并列,然后把你的my_strategy_tool.py和newtemplate.py移动到vnpy\user_tools目录下,再修改你的mutisignal.py中的import语句(Python的模块路径用点号分隔,而不是反斜杠):
from vnpy.user_tools.my_strategy_tool import NewArrayManager
from vnpy.user_tools.newtemplate import NewBarGenerator
要点:不要把不包含策略类的py文件放在strategies目录中,那不是个好主意。
本地停止单在VNPY系统中是在本地维护的,符合条件后会立即转化为限价委托单执行。
CTA策略管理界面CtaManager中,包含有一个停止单监视器StopOrderMonitor,用来显示策略已经发出的停止单的。
# Column definitions of StopOrderMonitor, as pasted from
# vnpy/app/cta_strategy/ui/widget.py
headers = {
    "stop_orderid": {"display": "停止委托号","cell": BaseCell,"update": False,},
    "vt_orderids": {"display": "限价委托号", "cell": BaseCell, "update": True},
    "vt_symbol": {"display": "本地代码", "cell": BaseCell, "update": False},
    "direction": {"display": "方向", "cell": EnumCell, "update": False},
    "offset": {"display": "开平", "cell": EnumCell, "update": False},
    "price": {"display": "价格", "cell": BaseCell, "update": False},
    "volume": {"display": "数量", "cell": BaseCell, "update": False},
    "status": {"display": "状态", "cell": EnumCell, "update": True},
    "lock": {"display": "锁仓", "cell": BaseCell, "update": False},
    "strategy_name": {"display": "策略名", "cell": BaseCell, "update": False},
}
在实际使用的时候,大家是否有这种感觉,这么多停止单,你可以知道停止单的编号,却不知道它是:
@dataclass
class StopOrder:
    # A local stop order (as defined in vnpy/app/cta_strategy/base.py);
    # the field order defines the generated __init__.
    vt_symbol: str
    direction: Direction
    offset: Offset
    price: float
    volume: float
    stop_orderid: str
    strategy_name: str
    lock: bool = False  # lock-position flag
    vt_orderids: list = field(default_factory=list)  # ids of the limit orders sent for it
    status: StopOrderStatus = StopOrderStatus.WAITING
它们位于vnpy\api\ctp\include\ctp\ThostFtdcUserApiStruct.h中。
///Query orders: request structure passed to ReqQryOrder
struct CThostFtdcQryOrderField
{
	///Broker ID
	TThostFtdcBrokerIDType	BrokerID;
	///Investor ID
	TThostFtdcInvestorIDType	InvestorID;
	///Instrument ID
	TThostFtdcInstrumentIDType	InstrumentID;
	///Exchange ID
	TThostFtdcExchangeIDType	ExchangeID;
	///Order system ID
	TThostFtdcOrderSysIDType	OrderSysID;
	///Start of the insert-time filter
	TThostFtdcTimeType	InsertTimeStart;
	///End of the insert-time filter
	TThostFtdcTimeType	InsertTimeEnd;
	///Invest unit ID
	TThostFtdcInvestUnitIDType	InvestUnitID;
};
///Query trades: request structure passed to ReqQryTrade
struct CThostFtdcQryTradeField
{
	///Broker ID
	TThostFtdcBrokerIDType	BrokerID;
	///Investor ID
	TThostFtdcInvestorIDType	InvestorID;
	///Instrument ID
	TThostFtdcInstrumentIDType	InstrumentID;
	///Exchange ID
	TThostFtdcExchangeIDType	ExchangeID;
	///Trade ID
	TThostFtdcTradeIDType	TradeID;
	///Start of the trade-time filter
	TThostFtdcTimeType	TradeTimeStart;
	///End of the trade-time filter
	TThostFtdcTimeType	TradeTimeEnd;
	///Invest unit ID
	TThostFtdcInvestUnitIDType	InvestUnitID;
};
先找到历史委托查询和成交查询的参数字段定义:
CtpGateway继承了BaseGateway,同时包含两个重要的接口:
self.td_api = CtpTdApi(self)  # trading API
self.md_api = CtpMdApi(self)  # market data API
另外还有1个向行情服务器订阅行情的订阅函数:
def subscribe(self, req: SubscribeRequest):
    """Subscribe to market data by forwarding req to the market data API."""
    md_api = self.md_api
    md_api.subscribe(req)
以及4个向交易服务器发送请求的函数:发送委托请求、取消委托请求、查询账户请求和查询持仓请求。
def send_order(self, req: OrderRequest):
    """
    Send an order request and return its vt_orderid; RFQ requests are
    sent as quote requests through the trading API's send_rfq.
    """
    td_api = self.td_api
    sender = td_api.send_rfq if req.type == OrderType.RFQ else td_api.send_order
    return sender(req)
def cancel_order(self, req: CancelRequest):
    """Forward a cancel request to the trading API."""
    td_api = self.td_api
    td_api.cancel_order(req)
def query_account(self):
    """Send an account query through the trading API."""
    td_api = self.td_api
    td_api.query_account()
def query_position(self):
    """Send a position query through the trading API."""
    td_api = self.td_api
    td_api.query_position()
CtpGateway有账户查询函数和持仓查询函数,可是没有历史委托和历史成交查询函数。
可它仅仅提供了可怜的5个主动执行交易的函数:
send_order()——发委托申请
cancel_order()——取消委托单
send_rfq()
query_account()——查询账户
query_position()——查询持仓
# vnpy.api.generator.ctp_td_source.function.cpp 提供了丰富的接口函数:
int TdApi::reqAuthenticate(const dict &req, int reqid)
int TdApi::reqUserLogin(const dict &req, int reqid)
int TdApi::reqUserLogout(const dict &req, int reqid)
int TdApi::reqUserPasswordUpdate(const dict &req, int reqid)
int TdApi::reqTradingAccountPasswordUpdate(const dict &req, int reqid)
int TdApi::reqUserAuthMethod(const dict &req, int reqid)
int TdApi::reqGenUserCaptcha(const dict &req, int reqid)
int TdApi::reqGenUserText(const dict &req, int reqid)
int TdApi::reqUserLoginWithCaptcha(const dict &req, int reqid)
int TdApi::reqUserLoginWithText(const dict &req, int reqid)
int TdApi::reqUserLoginWithOTP(const dict &req, int reqid)
int TdApi::reqOrderInsert(const dict &req, int reqid)
int TdApi::reqParkedOrderInsert(const dict &req, int reqid)
int TdApi::reqParkedOrderAction(const dict &req, int reqid)
int TdApi::reqOrderAction(const dict &req, int reqid)
int TdApi::reqQueryMaxOrderVolume(const dict &req, int reqid)
int TdApi::reqSettlementInfoConfirm(const dict &req, int reqid)
int TdApi::reqRemoveParkedOrder(const dict &req, int reqid)
int TdApi::reqRemoveParkedOrderAction(const dict &req, int reqid)
int TdApi::reqExecOrderInsert(const dict &req, int reqid)
int TdApi::reqExecOrderAction(const dict &req, int reqid)
int TdApi::reqForQuoteInsert(const dict &req, int reqid)
int TdApi::reqQuoteInsert(const dict &req, int reqid)
int TdApi::reqQuoteAction(const dict &req, int reqid)
int TdApi::reqBatchOrderAction(const dict &req, int reqid)
int TdApi::reqOptionSelfCloseInsert(const dict &req, int reqid)
int TdApi::reqOptionSelfCloseAction(const dict &req, int reqid)
int TdApi::reqCombActionInsert(const dict &req, int reqid)
int TdApi::reqQryOrder(const dict &req, int reqid)
int TdApi::reqQryTrade(const dict &req, int reqid)
int TdApi::reqQryInvestorPosition(const dict &req, int reqid)
int TdApi::reqQryTradingAccount(const dict &req, int reqid)
int TdApi::reqQryInvestor(const dict &req, int reqid)
int TdApi::reqQryTradingCode(const dict &req, int reqid)
int TdApi::reqQryInstrumentMarginRate(const dict &req, int reqid)
int TdApi::reqQryInstrumentCommissionRate(const dict &req, int reqid)
int TdApi::reqQryExchange(const dict &req, int reqid)
int TdApi::reqQryProduct(const dict &req, int reqid)
int TdApi::reqQryInstrument(const dict &req, int reqid)
int TdApi::reqQryDepthMarketData(const dict &req, int reqid)
int TdApi::reqQrySettlementInfo(const dict &req, int reqid)
int TdApi::reqQryTransferBank(const dict &req, int reqid)
int TdApi::reqQryInvestorPositionDetail(const dict &req, int reqid)
int TdApi::reqQryNotice(const dict &req, int reqid)
int TdApi::reqQrySettlementInfoConfirm(const dict &req, int reqid)
int TdApi::reqQryInvestorPositionCombineDetail(const dict &req, int reqid)
int TdApi::reqQryCFMMCTradingAccountKey(const dict &req, int reqid)
int TdApi::reqQryEWarrantOffset(const dict &req, int reqid)
int TdApi::reqQryInvestorProductGroupMargin(const dict &req, int reqid)
int TdApi::reqQryExchangeMarginRate(const dict &req, int reqid)
int TdApi::reqQryExchangeMarginRateAdjust(const dict &req, int reqid)
int TdApi::reqQryExchangeRate(const dict &req, int reqid)
int TdApi::reqQrySecAgentACIDMap(const dict &req, int reqid)
int TdApi::reqQryProductExchRate(const dict &req, int reqid)
int TdApi::reqQryProductGroup(const dict &req, int reqid)
int TdApi::reqQryMMInstrumentCommissionRate(const dict &req, int reqid)
int TdApi::reqQryMMOptionInstrCommRate(const dict &req, int reqid)
int TdApi::reqQryInstrumentOrderCommRate(const dict &req, int reqid)
int TdApi::reqQrySecAgentTradingAccount(const dict &req, int reqid)
int TdApi::reqQrySecAgentCheckMode(const dict &req, int reqid)
int TdApi::reqQrySecAgentTradeInfo(const dict &req, int reqid)
int TdApi::reqQryOptionInstrTradeCost(const dict &req, int reqid)
int TdApi::reqQryOptionInstrCommRate(const dict &req, int reqid)
int TdApi::reqQryExecOrder(const dict &req, int reqid)
int TdApi::reqQryForQuote(const dict &req, int reqid)
int TdApi::reqQryQuote(const dict &req, int reqid)
int TdApi::reqQryOptionSelfClose(const dict &req, int reqid)
int TdApi::reqQryInvestUnit(const dict &req, int reqid)
int TdApi::reqQryCombInstrumentGuard(const dict &req, int reqid)
int TdApi::reqQryCombAction(const dict &req, int reqid)
int TdApi::reqQryTransferSerial(const dict &req, int reqid)
int TdApi::reqQryAccountregister(const dict &req, int reqid)
int TdApi::reqQryContractBank(const dict &req, int reqid)
int TdApi::reqQryParkedOrder(const dict &req, int reqid)
int TdApi::reqQryParkedOrderAction(const dict &req, int reqid)
int TdApi::reqQryTradingNotice(const dict &req, int reqid)
int TdApi::reqQryBrokerTradingParams(const dict &req, int reqid)
int TdApi::reqQryBrokerTradingAlgos(const dict &req, int reqid)
int TdApi::reqQueryCFMMCTradingAccountToken(const dict &req, int reqid)
int TdApi::reqFromBankToFutureByFuture(const dict &req, int reqid)
int TdApi::reqFromFutureToBankByFuture(const dict &req, int reqid)
int TdApi::reqQueryBankAccountMoneyByFuture(const dict &req, int reqid)
这些函数中,下面两个函数能否加入到CtpTdApi接口中,再封装进CtpGateway,供用户策略使用?
int TdApi::reqQryOrder(const dict &req, int reqid)
{
CThostFtdcQryOrderField myreq = CThostFtdcQryOrderField();
memset(&myreq, 0, sizeof(myreq));
getString(req, "BrokerID", myreq.BrokerID);
getString(req, "InvestorID", myreq.InvestorID);
getString(req, "InstrumentID", myreq.InstrumentID);
getString(req, "ExchangeID", myreq.ExchangeID);
getString(req, "OrderSysID", myreq.OrderSysID);
getString(req, "InsertTimeStart", myreq.InsertTimeStart); // 开始时间
getString(req, "InsertTimeEnd", myreq.InsertTimeEnd); // 结束时间
getString(req, "InvestUnitID", myreq.InvestUnitID);
int i = this->api->ReqQryOrder(&myreq, reqid);
return i;
};
int TdApi::reqQryTrade(const dict &req, int reqid)
{
CThostFtdcQryTradeField myreq = CThostFtdcQryTradeField();
memset(&myreq, 0, sizeof(myreq));
getString(req, "BrokerID", myreq.BrokerID);
getString(req, "InvestorID", myreq.InvestorID);
getString(req, "InstrumentID", myreq.InstrumentID);
getString(req, "ExchangeID", myreq.ExchangeID);
getString(req, "TradeID", myreq.TradeID);
getString(req, "TradeTimeStart", myreq.TradeTimeStart); // 开始时间
getString(req, "TradeTimeEnd", myreq.TradeTimeEnd); // 结束时间
getString(req, "InvestUnitID", myreq.InvestUnitID);
int i = this->api->ReqQryTrade(&myreq, reqid);
return i;
};
交易中,时常会关心上次盈亏和当前算法盈亏状况,这可以通过对开仓以来的所有成交单进行计算得到。
但是因为种种原因,客户端策略可能会丢失部分成交单,造成这种结果的原因很多,如网络故障、电脑宕机、
没有在全部交易时段运行策略等,都会导致部分成交单未被记录。
交易所存有完整的成交记录,而且ctp_td_source.function.cpp提供了全面的历史委托查询和历史成交查询函数,
为什么咱们不可以把这些功能多引入一些呢?
vnpy\gateway\da\da_gateway.py(798): self.reqQryTrade(da_req, self.reqid)
vnpy\gateway\tora\td.py(583): err = self._native_api.ReqQryTrade(info, self._get_new_req_id())
vnpy\gateway\uft\uft_gateway.py(489): self.reqQryTrade({}, self.reqid)
那请问已经发生过的历史成交单,是否可以从交易所查询获得呢?目前从哪里查询到?
我们知道一次完整的交易包括开仓,加仓(可能),减仓(可能),平仓离场,持仓从0到非0,再回到0,我们认为这是
一次完整的交易。这个过程中策略可以通过系统推送on_trade(self,trade:TradeData):中的成交单trade进行记录,联合当前合约
的乘数、佣金、滑点,就可以实现仓位,持仓均价,平仓盈亏等的计算,对交易的意义极大。
但是可能因为各种原因,如断网、宕机,软件错误等,都会造成成交单trade记录的中断和缺少,因此能够从交易所查询到历史成
交单就可以很好地解决这样的问题。
问题:vnpy里目前如何从交易所或者从回测引擎中查询到历史成交单?
明白了,谢谢!
在策略中有需要查询或者计算上次完整交易的盈亏数值吗?
陈老师的进阶课中讲过BacktestingEngine是如何计算逐日盯市盈亏的,这已经接近我的需要,
但不知道是否有更加规范的方法,比如从回测引擎或者实盘引擎里就可以查询到?
遇到了同样的问题。在安装的时候,第一次安装plotly还是报错了,需要升级pip,具体指令在报错的时候已经提示了。
升级pip之后再安装,才能顺利装上plotly,再次启动就没有问题了。
看完了陈老师的线上公开课,花了2天时间终于把MACD副图曲线添加上了。
MACD曲线和RSI,SMA之类的不同之处在于它的y方向显示范围是可变的,需要根据K线显示范围的变化及时做出调整,有执行效率问题。
本人用字典保存已经计算过的y方向显示范围,避免了重复计算,运行相当流畅。当然这需要一定的存储开销,但开销不大,
而且是值得的。代码如下:
from datetime import datetime
from typing import List, Tuple, Dict
import numpy as np
import pyqtgraph as pg
import talib
import copy
from vnpy.trader.ui import create_qapp, QtCore, QtGui, QtWidgets
from vnpy.trader.database import database_manager
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData
from vnpy.chart import ChartWidget, VolumeItem, CandleItem
from vnpy.chart.item import ChartItem
from vnpy.chart.manager import BarManager
from vnpy.chart.base import NORMAL_FONT
class LineItem(CandleItem):
    """Close-price polyline drawn on the candle plot.

    Inherits range/bounding behaviour from CandleItem and only replaces the
    per-bar picture with a white segment between consecutive closes.
    """

    def __init__(self, manager: BarManager):
        """Initialise the base item and prepare the 1px white pen."""
        super().__init__(manager)
        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the segment from the previous bar's close to this bar's close.

        When the manager has no bar at ``ix - 1`` the segment collapses to a
        single point at this bar's close.
        """
        previous = self._manager.get_bar(ix - 1)
        current_point = QtCore.QPointF(ix, bar.close_price)
        previous_point = (
            QtCore.QPointF(ix - 1, previous.close_price) if previous else current_point
        )

        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)
        painter.setPen(self.white_pen)
        painter.drawLine(previous_point, current_point)
        painter.end()
        return picture
class SmaItem(CandleItem):
    """Simple moving average of close prices, drawn on the candle plot."""

    def __init__(self, manager: BarManager):
        """Initialise the base item, pen, window length and value cache."""
        super().__init__(manager)

        self.blue_pen: QtGui.QPen = pg.mkPen(color=(100, 100, 255), width=2)

        self.sma_window = 10
        # SMA values keyed by bar index.
        self.sma_data: Dict[int, float] = {}

    def update_history(self, history: List[BarData]):
        """Discard cached SMA values so a reloaded history is recomputed."""
        self.sma_data.clear()
        return super().update_history(history)

    def get_sma_value(self, ix: int) -> float:
        """Return the SMA at bar index ``ix``.

        Negative indices return 0. The first call computes and caches the SMA
        for every bar the manager holds. Later calls for uncached indices use
        only the trailing window. Indices without a full window of bars yield
        NaN.
        """
        if ix < 0:
            return 0

        # When initialising, calculate all SMA values in one talib call.
        if not self.sma_data:
            bars = self._manager.get_all_bars()
            if bars:
                close_data = [bar.close_price for bar in bars]
                sma_array = talib.SMA(np.array(close_data), self.sma_window)
                for n, value in enumerate(sma_array):
                    self.sma_data[n] = value

        # Return if already calculated.
        if ix in self.sma_data:
            return self.sma_data[ix]

        # Otherwise compute from the trailing window only. Clamp the start at 0
        # and skip missing bars, so a short history does not crash on
        # ``None.close_price``.
        close_data = []
        for n in range(max(ix - self.sma_window, 0), ix + 1):
            bar = self._manager.get_bar(n)
            if bar:
                close_data.append(bar.close_price)

        if len(close_data) < self.sma_window:
            sma_value = np.nan
        else:
            sma_array = talib.SMA(np.array(close_data), self.sma_window)
            sma_value = sma_array[-1]

        self.sma_data[ix] = sma_value
        return sma_value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the SMA segment between bar ``ix - 1`` and bar ``ix``."""
        sma_value = self.get_sma_value(ix)
        last_sma_value = self.get_sma_value(ix - 1)

        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)
        painter.setPen(self.blue_pen)

        # Skip the segment while either end is still undefined (NaN warm-up),
        # the same way the RSI and MACD items do.
        if not (np.isnan(last_sma_value) or np.isnan(sma_value)):
            start_point = QtCore.QPointF(ix - 1, last_sma_value)
            end_point = QtCore.QPointF(ix, sma_value)
            painter.drawLine(start_point, end_point)

        painter.end()
        return picture

    def get_info_text(self, ix: int) -> str:
        """Return the cursor text for bar ``ix`` (``"SMA -"`` when not computed)."""
        if ix in self.sma_data:
            sma_value = self.sma_data[ix]
            text = f"SMA {sma_value:.1f}"
        else:
            text = "SMA -"
        return text
class RsiItem(ChartItem):
    """RSI sub-chart item with fixed 30/70 reference lines and a 0-100 y-range."""

    def __init__(self, manager: BarManager):
        """Initialise the base item, pens, window length and value cache."""
        super().__init__(manager)

        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
        self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=2)

        self.rsi_window = 14
        # RSI values keyed by bar index.
        self.rsi_data: Dict[int, float] = {}

    def update_history(self, history: List[BarData]):
        """Discard cached RSI values so a reloaded history is recomputed."""
        self.rsi_data.clear()
        return super().update_history(history)

    def get_rsi_value(self, ix: int) -> float:
        """Return the RSI at bar index ``ix``.

        Negative indices return 50. The first call computes and caches the RSI
        for every bar the manager holds. Later calls for uncached indices use
        only a trailing window of ``rsi_window + 1`` closes. Indices without
        enough bars yield NaN.
        """
        if ix < 0:
            return 50

        # When initialising, calculate all RSI values in one talib call.
        if not self.rsi_data:
            bars = self._manager.get_all_bars()
            if bars:
                close_data = [bar.close_price for bar in bars]
                rsi_array = talib.RSI(np.array(close_data), self.rsi_window)
                for n, value in enumerate(rsi_array):
                    self.rsi_data[n] = value

        # Return if already calculated.
        if ix in self.rsi_data:
            return self.rsi_data[ix]

        # Otherwise compute from the trailing window. Clamp the start at 0 and
        # skip missing bars, so a short history does not crash on
        # ``None.close_price``.
        # NOTE(review): talib's RSI uses smoothed averages, so this short-window
        # value may differ slightly from one computed over the full history.
        close_data = []
        for n in range(max(ix - self.rsi_window, 0), ix + 1):
            bar = self._manager.get_bar(n)
            if bar:
                close_data.append(bar.close_price)

        if len(close_data) <= self.rsi_window:
            rsi_value = np.nan
        else:
            rsi_array = talib.RSI(np.array(close_data), self.rsi_window)
            rsi_value = rsi_array[-1]

        self.rsi_data[ix] = rsi_value
        return rsi_value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the RSI segment and the 30/70 reference lines for bar ``ix``."""
        rsi_value = self.get_rsi_value(ix)
        last_rsi_value = self.get_rsi_value(ix - 1)

        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # RSI line; skipped while either end is undefined (NaN warm-up).
        painter.setPen(self.yellow_pen)
        if not (np.isnan(last_rsi_value) or np.isnan(rsi_value)):
            end_point = QtCore.QPointF(ix, rsi_value)
            start_point = QtCore.QPointF(ix - 1, last_rsi_value)
            painter.drawLine(start_point, end_point)

        # Overbought (70) / oversold (30) reference lines.
        painter.setPen(self.white_pen)
        painter.drawLine(
            QtCore.QPointF(ix, 70),
            QtCore.QPointF(ix - 1, 70),
        )
        painter.drawLine(
            QtCore.QPointF(ix, 30),
            QtCore.QPointF(ix - 1, 30),
        )

        painter.end()
        return picture

    def boundingRect(self) -> QtCore.QRectF:
        """Cover every bar horizontally and the fixed 0-100 RSI scale vertically."""
        rect = QtCore.QRectF(
            0,
            0,
            len(self._bar_picutures),
            100
        )
        return rect

    def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        """RSI is bounded, so the y-range is always 0-100 regardless of the window."""
        return 0, 100

    def get_info_text(self, ix: int) -> str:
        """Return the cursor text for bar ``ix`` (``"RSI -"`` when not computed)."""
        if ix in self.rsi_data:
            rsi_value = self.rsi_data[ix]
            text = f"RSI {rsi_value:.1f}"
        else:
            text = "RSI -"
        return text
def to_int(value: float) -> int:
    """Round ``value`` to the nearest integer and return it as an ``int``.

    Uses Python's built-in ``round``, so exact halves go to the even neighbour.
    """
    return int(round(value))
def adjust_range(in_range: Tuple[float, float], margin: float = 0.05) -> Tuple[float, float]:
    """Pad a y-axis display range by ``margin`` of its span on each side.

    With the default margin of 0.05, the span grows to 1.1 times its original
    size: 5% is subtracted from the first bound and 5% is added to the second.

    Args:
        in_range: ``(low, high)`` bounds of the range to widen.
        margin: Fraction of ``abs(high - low)`` added on each side.

    Returns:
        The padded ``(low, high)`` tuple. A zero-width range is returned
        unchanged.
    """
    low, high = in_range[0], in_range[1]
    pad = abs(low - high) * margin
    return (low - pad, high + pad)
class MacdItem(ChartItem):
    """MACD sub-chart item.

    For every bar it draws the DIF line (white), the DEA line (yellow) and the
    MACD histogram (red above zero, green otherwise). It also caches the y-axis
    range computed for each visible (min_ix, max_ix) window.
    """

    # Most recently requested (min_ix, max_ix) window. This class-level value is
    # only a default; each instance overwrites it in __init__.
    last_range: Tuple[int, int] = (-1, -1)

    def __init__(self, manager: BarManager):
        """Initialise the base item, pens, MACD parameters and caches."""
        super().__init__(manager)

        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
        self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=1)
        self.red_pen: QtGui.QPen = pg.mkPen(color=(255, 0, 0), width=1)
        self.green_pen: QtGui.QPen = pg.mkPen(color=(0, 255, 0), width=1)

        self.short_window = 12
        self.long_window = 26
        self.M = 9

        # (diff, dea, macd) keyed by bar index.
        self.macd_data: Dict[int, Tuple[float, float, float]] = {}

        # y-range cache keyed by (min_ix, max_ix). It must be per instance: as a
        # class attribute it was shared by every MacdItem, so charts of different
        # contracts read and overwrote each other's cached ranges.
        self._values_ranges: Dict[Tuple[int, int], Tuple[float, float]] = {}
        self.last_range = (-1, -1)

    def update_history(self, history: List[BarData]):
        """Discard cached MACD values and y-ranges before a history reload."""
        self.macd_data.clear()
        self._values_ranges.clear()
        self.last_range = (-1, -1)
        return super().update_history(history)

    def _warmup_count(self) -> int:
        """Return the number of leading bars this item treats as MACD warm-up.

        The value is ``max(short_window, long_window) + M - 1``.
        """
        return max(self.short_window, self.long_window) + self.M - 1

    def get_macd_value(self, ix: int) -> Tuple[float, float, float]:
        """Return ``(diff, dea, macd)`` at bar index ``ix``.

        Negative indices return zeros. The first call computes and caches the
        values for every bar the manager holds. Later calls for uncached indices
        use only a trailing window of closes. Indices without enough bars yield
        NaN values.
        """
        if ix < 0:
            return (0.0, 0.0, 0.0)

        # When initialising, calculate all MACD values in one talib call.
        if not self.macd_data:
            bars = self._manager.get_all_bars()
            if bars:
                close_data = [bar.close_price for bar in bars]
                diffs, deas, macds = talib.MACD(
                    np.array(close_data),
                    fastperiod=self.short_window,
                    slowperiod=self.long_window,
                    signalperiod=self.M,
                )
                for n, values in enumerate(zip(diffs, deas, macds)):
                    self.macd_data[n] = values

        # Return if already calculated.
        if ix in self.macd_data:
            return self.macd_data[ix]

        # Otherwise compute from the trailing window. Clamp the start at 0 and
        # skip missing bars, so a short history does not crash on
        # ``None.close_price``.
        # NOTE(review): MACD is EMA-based, so a short-window value may differ
        # slightly from one computed over the full history.
        warmup = self._warmup_count()
        close_data = []
        for n in range(max(ix - warmup, 0), ix + 1):
            bar = self._manager.get_bar(n)
            if bar:
                close_data.append(bar.close_price)

        if len(close_data) < warmup:
            value = (np.nan, np.nan, np.nan)
        else:
            diffs, deas, macds = talib.MACD(
                np.array(close_data),
                fastperiod=self.short_window,
                slowperiod=self.long_window,
                signalperiod=self.M,
            )
            value = (diffs[-1], deas[-1], macds[-1])

        self.macd_data[ix] = value
        return value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the DIF/DEA segments and the MACD histogram bar for bar ``ix``."""
        macd_value = self.get_macd_value(ix)
        last_macd_value = self.get_macd_value(ix - 1)

        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # DIF line (skipped while either end is NaN).
        if not (np.isnan(macd_value[0]) or np.isnan(last_macd_value[0])):
            painter.setPen(self.white_pen)
            painter.drawLine(
                QtCore.QPointF(ix - 1, last_macd_value[0]),
                QtCore.QPointF(ix, macd_value[0]),
            )

        # DEA line (skipped while either end is NaN).
        if not (np.isnan(macd_value[1]) or np.isnan(last_macd_value[1])):
            painter.setPen(self.yellow_pen)
            painter.drawLine(
                QtCore.QPointF(ix - 1, last_macd_value[1]),
                QtCore.QPointF(ix, macd_value[1]),
            )

        # MACD histogram bar from 0 to the value: red when positive, green otherwise.
        if not np.isnan(macd_value[2]):
            if macd_value[2] > 0:
                painter.setPen(self.red_pen)
                painter.setBrush(pg.mkBrush(255, 0, 0))
            else:
                painter.setPen(self.green_pen)
                painter.setBrush(pg.mkBrush(0, 255, 0))
            painter.drawRect(QtCore.QRectF(ix - 0.3, 0, 0.6, macd_value[2]))

        painter.end()
        return picture

    def boundingRect(self) -> QtCore.QRectF:
        """Cover every bar horizontally and the current MACD y-range vertically."""
        min_y, max_y = self.get_y_range()
        # QRectF takes (x, y, width, height): the height is the span, not max_y.
        return QtCore.QRectF(0, min_y, len(self._bar_picutures), max_y - min_y)

    def _compute_y_range(self, min_ix: int, max_ix: int):
        """Return ``(min, max)`` over all three MACD series for bars min_ix..max_ix.

        Returns None when the window holds no data or only NaN values.
        """
        values = [
            self.macd_data[i]
            for i in range(min_ix, max_ix + 1)
            if i in self.macd_data
        ]
        if not values:
            return None
        array = np.array(values, dtype=float)
        if np.all(np.isnan(array)):
            return None
        return float(np.nanmin(array)), float(np.nanmax(array))

    def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        """Return the padded y-range of the MACD values in a bar-index window.

        When both indices are None (the visible window did not change), the
        range of the last requested window is reused. If none is cached, the
        range of all computed data is used. Requested windows are clamped to
        start after the warm-up bars and to end at the last computed bar.
        Returns ``(0.0, 1.0)`` when there is not enough usable data.
        """
        offset = self._warmup_count()
        count = len(self.macd_data)
        if not self.macd_data or count < offset:
            return 0.0, 1.0

        if min_ix is None and max_ix is None:
            if self.last_range in self._values_ranges:
                return adjust_range(self._values_ranges[self.last_range])
            key = (0, count - 1)
        else:
            low = offset if min_ix is None else max(min_ix, offset)
            high = count - 1 if max_ix is None else min(max_ix, count - 1)
            key = (low, high)

        result = self._values_ranges.get(key)
        if result is None:
            result = self._compute_y_range(*key)
            if result is None:
                # Empty or all-NaN window: nanmin/nanmax would raise or give NaN.
                return 0.0, 1.0
            self._values_ranges[key] = result

        # Remember the window even on a cache hit, so boundingRect() (which
        # calls with no indices) uses the range currently on screen.
        self.last_range = key
        return adjust_range(result)

    def get_info_text(self, ix: int) -> str:
        """Return the cursor text for bar ``ix`` (placeholders when not computed)."""
        if ix in self.macd_data:
            diff, dea, macd = self.macd_data[ix]
            words = [
                f"diff {diff:.3f}", " ",
                f"dea {dea:.3f}", " ",
                f"macd {macd:.3f}"
            ]
            text = "\n".join(words)
        else:
            text = "diff - \ndea - \nmacd -"
        return text
class NewChartWidget(ChartWidget):
    """ChartWidget that adds a horizontal line tracking the latest close price."""

    MIN_BAR_COUNT = 100

    def __init__(self, parent: QtWidgets.QWidget = None):
        """Create the widget; the last-price line is added on demand."""
        super().__init__(parent)

        self.last_price_line: pg.InfiniteLine = None

    def add_last_price_line(self):
        """Add a white horizontal price line, labelled with its value, to the first plot."""
        plot = list(self._plots.values())[0]
        color = (255, 255, 255)

        self.last_price_line = pg.InfiniteLine(
            angle=0,
            movable=False,
            label="{value:.1f}",
            pen=pg.mkPen(color, width=1),
            labelOpts={
                "color": color,
                "position": 1,
                "anchors": [(1, 1), (1, 1)]
            }
        )
        self.last_price_line.label.setFont(NORMAL_FONT)
        plot.addItem(self.last_price_line)

    def update_history(self, history: List[BarData]) -> None:
        """
        Update a list of bar data.
        """
        self._manager.update_history(history)

        for item in self._items.values():
            item.update_history(history)

        self._update_plot_limits()

        self.move_to_right()

        # An empty history has no last bar; history[-1] would raise IndexError.
        if history:
            self.update_last_price_line(history[-1])

    def update_bar(self, bar: BarData) -> None:
        """
        Update single bar data.
        """
        self._manager.update_bar(bar)

        for item in self._items.values():
            item.update_bar(bar)

        self._update_plot_limits()

        # Keep following the right edge only when the view is already near it.
        if self._right_ix >= (self._manager.get_count() - self._bar_count / 2):
            self.move_to_right()

        self.update_last_price_line(bar)

    def update_last_price_line(self, bar: BarData) -> None:
        """Move the last-price line (if it was added) to ``bar.close_price``."""
        if self.last_price_line:
            self.last_price_line.setValue(bar.close_price)
if __name__ == "__main__":
    app = create_qapp()

    symbol = "rb2010"
    exchange = Exchange.SHFE
    interval = Interval.MINUTE
    start = datetime(2020, 6, 1)
    end = datetime(2020, 6, 30)

    dynamic = False  # replay bars one at a time (True) or show the latest n at once (False)
    n = 200          # number of bars shown up front as history

    bars = database_manager.load_bar_data(
        symbol=symbol,
        exchange=exchange,
        interval=interval,
        start=start,
        end=end
    )
    # Fail with a clear message instead of an IndexError deep inside the chart.
    if not bars:
        raise SystemExit(f"No bar data for {symbol}.{exchange.value} between {start} and {end}")

    widget = NewChartWidget()
    widget.setWindowTitle(f"K线图表——{symbol}.{exchange.value},{interval},{start}-{end}")
    widget.add_plot("candle", hide_x_axis=True)
    widget.add_plot("volume", maximum_height=150)
    widget.add_plot("rsi", maximum_height=150)
    widget.add_plot("macd", maximum_height=150)
    widget.add_item(CandleItem, "candle", "candle")
    widget.add_item(VolumeItem, "volume", "volume")
    widget.add_item(LineItem, "line", "candle")
    widget.add_item(SmaItem, "sma", "candle")
    widget.add_item(RsiItem, "rsi", "rsi")
    widget.add_item(MacdItem, "macd", "macd")
    widget.add_last_price_line()
    widget.add_cursor()

    if dynamic:
        history = bars[:n]   # earliest n bars as history
        new_data = bars[n:]  # the rest is replayed by the timer
    else:
        history = bars[-n:]  # latest n bars as history
        new_data = []        # nothing to replay

    widget.update_history(history)

    def update_bar():
        """Feed the next pending bar to the chart; stop the timer once none remain."""
        if new_data:
            widget.update_bar(new_data.pop(0))
        else:
            timer.stop()

    timer = QtCore.QTimer()
    timer.timeout.connect(update_bar)
    if dynamic:
        timer.start(100)

    widget.show()
    app.exec_()