达哥242b4bb891d6461b wrote:
感谢分享!从中学到很多。
不过在以上代码中未发现有关TradingWidget定义的代码,望大佬分享。
答复:
有关TradingWidget定义的代码已经补上。
aqua-pink wrote:
您好,请问
发起条件单
cond_order = ConditionOrder(... ...)
self.send_condition_order(cond_order)这一部分的入参应该是什么样的?
可以作为on_trade()成交后设定条件单吗?
def execute(self,strategy:CtaTemplate):
"""
一段将交易指令转化为条件单的例子 ,其中:
self.price:开仓价
self.stop_price:止盈价
self.exit_price:止损价
"""
if self.offset == Offset.OPEN:
open_condition = Condition.LE if self.direction == Direction.LONG else Condition.BE
stop_condition = Condition.BT if self.direction == Direction.LONG else Condition.LT
exit_condition = Condition.LT if self.direction == Direction.LONG else Condition.BT
# 止盈条件单
stop_order = ConditionOrder(strategy_name=strategy.strategy_name,vt_symbol=self.vt_symbol,
direction=self.direction,offset=self.offset,
price=self.stop_price,volume=self.volume/2,condition=stop_condition)
# 开仓条件单
open_order = ConditionOrder(strategy_name=strategy.strategy_name,vt_symbol=self.vt_symbol,
direction=self.direction,offset=self.offset,
price=self.price,volume=self.volume,condition=open_condition)
# 止损条件单
exit_order = ConditionOrder(strategy_name=strategy.strategy_name,vt_symbol=self.vt_symbol,
direction=self.direction,offset=self.offset,
price=self.exit_price,volume=self.volume,condition=exit_condition)
for cond_order in [open_order,stop_order,exit_order]:
result = strategy.send_condition_order(cond_order)
print(f"{strategy.strategy_name}发送开仓条件单{'成功' if result else '成功'}:{cond_order}")
elif self.offset == Offset.CLOSE:
tj1 = self.direction == Direction.LONG and strategy.pos < 0
tj2 = self.direction == Direction.SHORT and strategy.pos > 0
if tj1 or tj2:
exit_condition = Condition.LT if self.direction == Direction.LONG else Condition.BT
exit_order = ConditionOrder(strategy_name=strategy.strategy_name,vt_symbol=self.vt_symbol,
direction=self.direction,offset=self.offset,
price=self.price,volume=abs(strategy.pos),condition=exit_condition,
execute_price=ExecutePrice.MARKET)
result = strategy.send_condition_order(exit_order)
print(f"{strategy.strategy_name}发送平仓条件单{'成功' if result else '成功'}:{exit_order}")
G_will wrote:
当前默认的 tick 驱动 bar 合成,在一些非活跃合约上很不平滑,大概想直接用定时器估计更合适一些,但是估计需要考虑tick 接收时间误差、定时器周期准确等等。大佬有什么思路吗?
当你想以比市场价更高的价格买,或者以比市场价更低的价格卖时,使用send_order()是会立即执行的,但是用停止单却可以做到这一点,这是停止单的优点。
但是实际使用中停止单也是有缺点的:
这是本人给它取的名字,它其实是本人以前提到的交易线(TradeLine)的改进和增强。
它主要就是为解决停止单上述缺点而设计的,当然应该具备上述优点。
在vnpy_ctastrategy\base.py中增加如下代码:
class Condition(Enum): # hxxjava add
""" 条件单的条件 """
BT = ">"
LT = "<"
BE = ">="
LE = "<="
class ExecutePrice(Enum): # hxxjava add
""" 执行价格 """
SETPRICE = "设定价"
MARKET = "市场价"
EXTREME = "极限价"
class CondOrderStatus(Enum): # hxxjava add
""" 条件单状态 """
WAITING = "等待中"
CANCELLED = "已撤销"
TRIGGERED = "已触发"
@dataclass
class ConditionOrder: # hxxjava add
""" 条件单 """
strategy_name: str
vt_symbol: str
direction: Direction
offset: Offset
price: float
volume: float
condition:Condition
execute_price:ExecutePrice = ExecutePrice.SETPRICE
create_time: datetime = datetime.now()
trigger_time: datetime = None
cond_orderid: str = "" # 条件单编号
status: CondOrderStatus = CondOrderStatus.WAITING
def __post_init__(self):
""" """
if not self.cond_orderid:
self.cond_orderid = datetime.now().strftime("%m%d%H%M%S%f")[:13]
EVENT_CONDITION_ORDER = "eConditionOrder" # hxxjava add
修改vnpy_ctastrategy\ui\widget.py中的class CtaManager,代码如下:
class CtaManager(QtWidgets.QWidget):
""""""
signal_log: QtCore.Signal = QtCore.Signal(Event)
signal_strategy: QtCore.Signal = QtCore.Signal(Event)
def __init__(self, main_engine: MainEngine, event_engine: EventEngine) -> None:
""""""
super().__init__()
self.main_engine: MainEngine = main_engine
self.event_engine: EventEngine = event_engine
self.cta_engine: CtaEngine = main_engine.get_engine(APP_NAME)
self.managers: Dict[str, StrategyManager] = {}
self.init_ui()
self.register_event()
self.cta_engine.init_engine()
self.update_class_combo()
def init_ui(self) -> None:
""""""
self.setWindowTitle("CTA策略")
# Create widgets
self.class_combo: QtWidgets.QComboBox = QtWidgets.QComboBox()
add_button: QtWidgets.QPushButton = QtWidgets.QPushButton("添加策略")
add_button.clicked.connect(self.add_strategy)
init_button: QtWidgets.QPushButton = QtWidgets.QPushButton("全部初始化")
init_button.clicked.connect(self.cta_engine.init_all_strategies)
start_button: QtWidgets.QPushButton = QtWidgets.QPushButton("全部启动")
start_button.clicked.connect(self.cta_engine.start_all_strategies)
stop_button: QtWidgets.QPushButton = QtWidgets.QPushButton("全部停止")
stop_button.clicked.connect(self.cta_engine.stop_all_strategies)
clear_button: QtWidgets.QPushButton = QtWidgets.QPushButton("清空日志")
clear_button.clicked.connect(self.clear_log)
roll_button: QtWidgets.QPushButton = QtWidgets.QPushButton("移仓助手")
roll_button.clicked.connect(self.roll)
self.scroll_layout: QtWidgets.QVBoxLayout = QtWidgets.QVBoxLayout()
self.scroll_layout.addStretch()
scroll_widget: QtWidgets.QWidget = QtWidgets.QWidget()
scroll_widget.setLayout(self.scroll_layout)
self.scroll_area: QtWidgets.QScrollArea = QtWidgets.QScrollArea()
self.scroll_area.setWidgetResizable(True)
self.scroll_area.setWidget(scroll_widget)
self.log_monitor: LogMonitor = LogMonitor(self.main_engine, self.event_engine)
self.stop_order_monitor: StopOrderMonitor = StopOrderMonitor(
self.main_engine, self.event_engine
)
self.strategy_combo = QtWidgets.QComboBox()
self.strategy_combo.setMinimumWidth(200)
find_button = QtWidgets.QPushButton("查找")
find_button.clicked.connect(self.find_strategy)
# hxxjava add
self.condition_order_monitor = ConditionOrderMonitor(self.cta_engine)
# Set layout
hbox1: QtWidgets.QHBoxLayout = QtWidgets.QHBoxLayout()
hbox1.addWidget(self.class_combo)
hbox1.addWidget(add_button)
hbox1.addStretch()
hbox1.addWidget(self.strategy_combo)
hbox1.addWidget(find_button)
hbox1.addStretch()
hbox1.addWidget(init_button)
hbox1.addWidget(start_button)
hbox1.addWidget(stop_button)
hbox1.addWidget(clear_button)
hbox1.addWidget(roll_button)
grid = QtWidgets.QGridLayout()
# grid.addWidget(self.scroll_area, 0, 0, 2, 1)
grid.addWidget(self.scroll_area, 0, 0, 3, 1) # hxxjava change 3 rows , 1 column
grid.addWidget(self.stop_order_monitor, 0, 1)
grid.addWidget(self.condition_order_monitor, 1, 1) # hxxjava add
# grid.addWidget(self.log_monitor, 1, 1)
grid.addWidget(self.log_monitor, 2, 1) # hxxjava change
vbox: QtWidgets.QVBoxLayout = QtWidgets.QVBoxLayout()
vbox.addLayout(hbox1)
vbox.addLayout(grid)
self.setLayout(vbox)
def update_class_combo(self) -> None:
""""""
names = self.cta_engine.get_all_strategy_class_names()
names.sort()
self.class_combo.addItems(names)
def update_strategy_combo(self) -> None:
""""""
names = list(self.managers.keys())
names.sort()
self.strategy_combo.clear()
self.strategy_combo.addItems(names)
def register_event(self) -> None:
""""""
self.signal_strategy.connect(self.process_strategy_event)
self.event_engine.register(
EVENT_CTA_STRATEGY, self.signal_strategy.emit
)
def process_strategy_event(self, event) -> None:
"""
Update strategy status onto its monitor.
"""
data = event.data
strategy_name: str = data["strategy_name"]
if strategy_name in self.managers:
manager: StrategyManager = self.managers[strategy_name]
manager.update_data(data)
else:
manager: StrategyManager = StrategyManager(self, self.cta_engine, data)
self.scroll_layout.insertWidget(0, manager)
self.managers[strategy_name] = manager
self.update_strategy_combo()
def remove_strategy(self, strategy_name) -> None:
""""""
manager: StrategyManager = self.managers.pop(strategy_name)
manager.deleteLater()
self.update_strategy_combo()
def add_strategy(self) -> None:
""""""
class_name: str = str(self.class_combo.currentText())
if not class_name:
return
parameters: dict = self.cta_engine.get_strategy_class_parameters(class_name)
editor: SettingEditor = SettingEditor(parameters, class_name=class_name)
n: int = editor.exec_()
if n == editor.Accepted:
setting: dict = editor.get_setting()
vt_symbol: str = setting.pop("vt_symbol")
strategy_name: str = setting.pop("strategy_name")
self.cta_engine.add_strategy(
class_name, strategy_name, vt_symbol, setting
)
def find_strategy(self) -> None:
""""""
strategy_name = self.strategy_combo.currentText()
manager = self.managers[strategy_name]
self.scroll_area.ensureWidgetVisible(manager)
def clear_log(self) -> None:
""""""
self.log_monitor.setRowCount(0)
def show(self) -> None:
""""""
self.showMaximized()
def roll(self) -> None:
""""""
dialog: RolloverTool = RolloverTool(self)
dialog.exec_()
在vnpy_ctastrategy\ui\widget.py中增加如下代码:
class ConditionOrderMonitor(BaseMonitor): # hxxjava add
"""
Monitor for condition order.
"""
event_type = EVENT_CONDITION_ORDER
data_key = "cond_orderid"
sorting = True
headers = {
"cond_orderid": {
"display": "条件单号",
"cell": BaseCell,
"update": False,
},
"vt_symbol": {"display": "本地代码", "cell": BaseCell, "update": False},
"direction": {"display": "方向", "cell": EnumCell, "update": False},
"offset": {"display": "开平", "cell": EnumCell, "update": False},
"price": {"display": "触发价", "cell": BaseCell, "update": False},
"volume": {"display": "数量", "cell": BaseCell, "update": False},
"condition": {"display": "触发条件", "cell": EnumCell, "update": False},
"execute_price": {"display": "执行价", "cell": EnumCell, "update": False},
"create_time": {"display": "生成时间", "cell": TimeCell, "update": False},
"trigger_time": {"display": "触发时间", "cell": TimeCell, "update": False},
"status": {"display": "状态", "cell": EnumCell, "update": True},
"strategy_name": {"display": "策略名称", "cell": BaseCell, "update": False},
}
def __init__(self,cta_engine : MyCtaEngine):
""""""
super().__init__(cta_engine.main_engine, cta_engine.event_engine)
self.cta_engine = cta_engine
def init_ui(self):
"""
Connect signal.
"""
super().init_ui()
self.setToolTip("双击单元格可停止条件单")
self.itemDoubleClicked.connect(self.stop_condition_order)
def stop_condition_order(self, cell):
"""
Stop algo if cell double clicked.
"""
order = cell.get_data()
if order:
self.cta_engine.cancel_condition_order(order.cond_orderid)
修改策略管理器StrategyManager的代码如下:
class StrategyManager(QtWidgets.QFrame):
"""
Manager for a strategy
"""
def __init__(
self, cta_manager: CtaManager, cta_engine: CtaEngine, data: dict
):
""""""
super(StrategyManager, self).__init__()
self.cta_manager = cta_manager
self.cta_engine = cta_engine
self.strategy_name = data["strategy_name"]
self._data = data
self.tradetool : TradingWidget = None # hxxjava add
self.init_ui()
def init_ui(self):
""""""
self.setFixedHeight(300)
self.setFrameShape(self.Box)
self.setLineWidth(1)
self.init_button = QtWidgets.QPushButton("初始化")
self.init_button.clicked.connect(self.init_strategy)
self.start_button = QtWidgets.QPushButton("启动")
self.start_button.clicked.connect(self.start_strategy)
self.start_button.setEnabled(False)
self.stop_button = QtWidgets.QPushButton("停止")
self.stop_button.clicked.connect(self.stop_strategy)
self.stop_button.setEnabled(False)
self.trade_button = QtWidgets.QPushButton("交易") # hxxjava add
self.trade_button.clicked.connect(self.show_tradetool) # hxxjava add
self.trade_button.setEnabled(False) # hxxjava add
self.edit_button = QtWidgets.QPushButton("编辑")
self.edit_button.clicked.connect(self.edit_strategy)
self.remove_button = QtWidgets.QPushButton("移除")
self.remove_button.clicked.connect(self.remove_strategy)
strategy_name = self._data["strategy_name"]
vt_symbol = self._data["vt_symbol"]
class_name = self._data["class_name"]
author = self._data["author"]
label_text = (
f"{strategy_name} - {vt_symbol} ({class_name} by {author})"
)
label = QtWidgets.QLabel(label_text)
label.setAlignment(QtCore.Qt.AlignCenter)
self.parameters_monitor = DataMonitor(self._data["parameters"])
self.variables_monitor = DataMonitor(self._data["variables"])
hbox = QtWidgets.QHBoxLayout()
hbox.addWidget(self.init_button)
hbox.addWidget(self.start_button)
hbox.addWidget(self.trade_button) # hxxjava add
hbox.addWidget(self.stop_button)
hbox.addWidget(self.edit_button)
hbox.addWidget(self.remove_button)
# hxxjava change to self.vbox,old is vbox
self.vbox = QtWidgets.QVBoxLayout()
self.vbox.addWidget(label)
self.vbox.addLayout(hbox)
self.vbox.addWidget(self.parameters_monitor)
self.vbox.addWidget(self.variables_monitor)
self.setLayout(self.vbox)
def update_data(self, data: dict):
""""""
self._data = data
self.parameters_monitor.update_data(data["parameters"])
self.variables_monitor.update_data(data["variables"])
# Update button status
variables = data["variables"]
inited = variables["inited"]
trading = variables["trading"]
if not inited:
return
self.init_button.setEnabled(False)
if trading:
self.start_button.setEnabled(False)
self.trade_button.setEnabled(True) # hxxjava
self.stop_button.setEnabled(True)
self.edit_button.setEnabled(False)
self.remove_button.setEnabled(False)
else:
self.start_button.setEnabled(True)
self.trade_button.setEnabled(False) # hxxjava
self.stop_button.setEnabled(False)
self.edit_button.setEnabled(True)
self.remove_button.setEnabled(True)
def init_strategy(self):
""""""
self.cta_engine.init_strategy(self.strategy_name)
def start_strategy(self):
""""""
self.cta_engine.start_strategy(self.strategy_name)
def show_tradetool(self): # hxxjava add
""" 为策略显示交易工具 """
if not self.tradetool:
strategy = self.cta_engine.strategies.get(self.strategy_name,None)
if strategy and strategy.trading:
self.tradetool = TradingWidget(strategy,self.cta_engine.event_engine)
self.vbox.addWidget(self.tradetool)
else:
is_visible = self.tradetool.isVisible()
self.tradetool.setVisible(not is_visible)
def stop_strategy(self):
""""""
self.cta_engine.stop_strategy(self.strategy_name)
def edit_strategy(self):
""""""
strategy_name = self._data["strategy_name"]
parameters = self.cta_engine.get_strategy_parameters(strategy_name)
editor = SettingEditor(parameters, strategy_name=strategy_name)
n = editor.exec_()
if n == editor.Accepted:
setting = editor.get_setting()
self.cta_engine.edit_strategy(strategy_name, setting)
def remove_strategy(self):
""""""
result = self.cta_engine.remove_strategy(self.strategy_name)
# Only remove strategy gui manager if it has been removed from engine
if result:
self.cta_manager.remove_strategy(self.strategy_name)
创建vnpy\usertools\trading_widget.py文件,其中内容:
"""
条件单交易组件
作者:hxxjava
日线:2022-5-10
"""
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
from vnpy.trader.constant import Direction,Offset
from vnpy.trader.event import EVENT_TICK
from vnpy.event.engine import Event,EventEngine
from vnpy_ctastrategy.base import Condition,CondOrderStatus,ExecutePrice,ConditionOrder
from vnpy_ctastrategy.template import CtaTemplate
class TradingWidget(QtWidgets.QWidget):
"""
CTA strategy manual trading widget.
"""
signal_tick = QtCore.pyqtSignal(Event)
def __init__(self, strategy: CtaTemplate, event_engine: EventEngine):
""""""
super().__init__()
self.strategy: CtaTemplate = strategy
self.event_engine: EventEngine = event_engine
self.vt_symbol: str = strategy.vt_symbol
self.price_digits: int = 0
self.init_ui()
self.register_event()
def init_ui(self) -> None:
""""""
# 交易方向:多/空
self.direction_combo = QtWidgets.QComboBox()
self.direction_combo.addItems(
[Direction.LONG.value, Direction.SHORT.value])
# 开平选择:开/平
self.offset_combo = QtWidgets.QComboBox()
self.offset_combo.addItems([offset.value for offset in Offset])
# 条件类型
conditions = [Condition.BE,Condition.LE,Condition.BT,Condition.LT]
self.condition_combo = QtWidgets.QComboBox()
self.condition_combo.addItems(
[condition.value for condition in conditions])
double_validator = QtGui.QDoubleValidator()
double_validator.setBottom(0)
self.price_line = QtWidgets.QLineEdit()
self.price_line.setValidator(double_validator)
self.exit_line = QtWidgets.QLineEdit()
self.exit_line.setValidator(double_validator)
self.volume_line = QtWidgets.QLineEdit()
self.volume_line.setValidator(double_validator)
self.price_check = QtWidgets.QCheckBox()
self.price_check.setToolTip("设置价格随行情更新")
execute_prices = [ExecutePrice.SETPRICE,ExecutePrice.MARKET,ExecutePrice.EXTREME]
self.execute_price_combo = QtWidgets.QComboBox()
self.execute_price_combo.addItems(
[execute_price.value for execute_price in execute_prices])
send_button = QtWidgets.QPushButton("发出")
send_button.clicked.connect(self.send_condition_order)
hbox = QtWidgets.QHBoxLayout()
hbox.addWidget(QtWidgets.QLabel(f"合约:{self.vt_symbol}"))
hbox.addWidget(QtWidgets.QLabel("方向"))
hbox.addWidget(self.direction_combo)
hbox.addWidget(QtWidgets.QLabel("开平"))
hbox.addWidget(self.offset_combo)
hbox.addWidget(QtWidgets.QLabel("条件"))
hbox.addWidget(self.condition_combo)
hbox.addWidget(QtWidgets.QLabel("触发价"))
hbox.addWidget(self.price_line)
hbox.addWidget(self.price_check)
hbox.addWidget(QtWidgets.QLabel("数量"))
hbox.addWidget(self.volume_line)
hbox.addWidget(QtWidgets.QLabel("执行价"))
hbox.addWidget(self.execute_price_combo)
hbox.addWidget(send_button)
# Overall layout
self.setLayout(hbox)
def register_event(self) -> None:
""""""
self.signal_tick.connect(self.process_tick_event)
self.event_engine.register(EVENT_TICK, self.signal_tick.emit)
def process_tick_event(self, event: Event) -> None:
""""""
tick = event.data
if tick.vt_symbol != self.vt_symbol:
return
if self.price_check.isChecked():
self.price_line.setText(f"{tick.last_price}")
def send_condition_order(self) -> bool:
"""
Send new order manually.
"""
try:
direction = Direction(self.direction_combo.currentText())
offset = Offset(self.offset_combo.currentText())
condition = Condition(self.condition_combo.currentText())
price = float(self.price_line.text())
volume = float(self.volume_line.text())
execute_price = ExecutePrice(self.execute_price_combo.currentText())
order = ConditionOrder(
strategy_name = self.strategy.strategy_name,
vt_symbol=self.vt_symbol,
direction=direction,
offset=offset,
price=price,
volume=volume,
condition=condition,
execute_price=execute_price
)
self.strategy.send_condition_order(order=order)
print(f"发出条件单 : vt_symbol={self.vt_symbol},success ! {order}")
return True
except:
print(f"发出条件单 : vt_symbol={self.vt_symbol},input error !")
return False
在vnpy_ctastrategy\engine.py中对CtaEngine进行如下扩展:
class MyCtaEngine(CtaEngine):
""" """
condition_filename = "condition_order.json" # 历史条件单存储文件
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super().__init__(main_engine,event_engine)
self.condition_orders:Dict[str,ConditionOrder] = {} # strategy_name: dict
def load_active_condtion_orders(self):
""" """
return {}
def process_tick_event(self,event:Event):
""" 用tick的价格检查条件单 """
super().process_tick_event(event)
tick:TickData = event.data
all_condition_orders = [order for order in self.condition_orders.values() \
if order.vt_symbol == tick.vt_symbol and order.status == CondOrderStatus.WAITING]
for order in all_condition_orders:
# 检查条件单是否满足条件
self.check_condition_order(order,tick)
def check_condition_order(self,order:ConditionOrder,tick:TickData):
""" 检查条件单是否满足条件 """
strategy = self.strategies.get(order.strategy_name,None)
if not strategy or not strategy.trading:
return False
price = tick.last_price
is_be = order.condition == Condition.BE and price >= order.price
is_le = order.condition == Condition.LE and price <= order.price
is_bt = order.condition == Condition.BT and price > order.price
is_lt = order.condition == Condition.LT and price < order.price
if is_be or is_le or is_bt or is_lt:
# 满足触发条件
if order.execute_price == ExecutePrice.MARKET:
# 取市场价
price = tick.last_price
elif order.execute_price == ExecutePrice.EXTREME:
# 取极限价
price = tick.limit_up if order.direction == Direction.LONG else tick.limit_down
else:
# 取设定价
price = order.price
# 执行委托
order_ids = strategy.send_order(
direction = order.direction,
offset=order.offset,
price=price,
volume=order.volume
)
if order_ids:
order.trigger_time = tick.datetime
order.status = CondOrderStatus.TRIGGERED
self.event_engine.put(Event(EVENT_CONDITION_ORDER,order))
def send_condition_order(self,order:ConditionOrder):
""" """
strategy = self.strategies.get(order.strategy_name,None)
if not strategy or not strategy.trading:
return False
if order.cond_orderid not in self.condition_orders:
self.condition_orders[order.cond_orderid] = order
self.event_engine.put(Event(EVENT_CONDITION_ORDER,order))
return True
return False
def cancel_condition_order(self,cond_orderid:str):
""" """
order:ConditionOrder = self.condition_orders.get(cond_orderid,None)
if not order:
return False
order.status = CondOrderStatus.CANCELLED
self.event_engine.put(Event(EVENT_CONDITION_ORDER,order))
return True
def cancel_all_condition_orders(self,strategy_name:str):
""" """
for order in self.condition_orders.values():
if order.strategy_name == strategy_name and order.status == CondOrderStatus.WAITING:
order.status = CondOrderStatus.CANCELLED
self.call_strategy_func(strategy,strategy.on_condition_order)
self.event_engine.put(Event(EVENT_CONDITION_ORDER,order))
return True
对vnpy_ctastrategy__init__.py中的CtaTemplate进行如下修改:
from .engine import MyCtaEngine # hxxjava add
class CtaStrategyApp(BaseApp):
""""""
app_name = APP_NAME
app_module = __module__
app_path = Path(__file__).parent
display_name = "CTA策略"
# engine_class = CtaEngine
engine_class = MyCtaEngine # hxxjava add
widget_name = "CtaManager"
icon_name = str(app_path.joinpath("ui", "cta.ico"))
对vnpy_ctastrategy\template.py中的CtaTemplate进行如下扩展:
@virtual
def on_condition_order(self, cond_order: ConditionOrder):
"""
Callback of condition order update.
"""
pass
def send_condition_order(self,order:ConditionOrder): # hxxjava add
""" """
if not self.trading:
return False
return self.cta_engine.send_condition_order(order)
def cancel_condition_order(self,cond_orderid:str): # hxxjava add
""" """
return self.cta_engine.cancel_condition_order(cond_orderid)
def cancel_all_condition_orders(self): # hxxjava add
""" """
return self.cta_engine.cancel_all_condition_orders(self)
1)CTA策略中的条件单被触发点回调通知:
def on_condition_order(self, cond_order: ConditionOrder):
"""
Callback of condition order update.
"""
print(f"条件单已经执行,cond_order = {cond_order}")
2)发起条件单
cond_order = ConditionOrder(... ...)
self.send_condition_order(cond_order)
3)取消条件单
self.cancel_condition_order(cond_orderid)
4)取消策略的所有条件单
self.cancel_all_condition_orders()
如果您在启动vntrader的时候勾选了【ChartWizard 实时K线图表模块】,您会简单主界面上vnpy系统提供的K线图表功能图标,进入该功能模块后就可以输入本地代码,新建K线图表了。
使用了该功能之后,你会发现它有如下缺点:
这样一个太简单的K线图表是远远满足了交易者对K线图表的需求的,有多少人使用就可想而知了。
绝大多数交易策略都是基于K线来实现的。可是很少部分是只在1分钟K线的基础上运行的,可能是n分钟,n小时,n天...,只能提供一分钟的K线图是不够用的。
所以应该提供用户如下的选择:
用户之所以想看K线图,可能是想看看自己策略的算法是否正确,这一般都是使用了一个或者多个运行在窗口K线上指标计算的值计算的入场和出场信号。
这也是可以显示的,而这种指标不可能全部是系统自带的指标显示控件能够涵盖的,所以应该有方法让用户自己增加自己的指标显示部件。
所以应该提供下面功能:
由于vnpy系统升级之最新的3.0版本,python底层的对象继承机制发生变化,导致原来的一部分绘图部件因为多继承而发生初始化失败,无法使用,必须升级。
近期不少vnpy的会员朋友不断地私信我,反映这些绘图部件用不了了,因为本人最近忙于交易策略的开发,无暇顾及,实在是抽不出时间,请大家谅解!
现在问题已经解决,可以放心使用。
修改vnpy\chart\manager.py中的BarManager,为它添加一个函数:
def get_bar_idx(self,trade_dt:datetime) -> int: # hxxjava add
"""
get the index of a bar which the trade time belongs to.
return:
-1 : belongs to none
0,1,... : bar's index
"""
a1 = np.array(sorted(self._datetime_index_map.keys()))
a2 = a1 <= trade_dt
return np.sum(a2 == True) - 1
当然别忘了在该文件的引用部分添加下面的语句
import numpy as np # hxxjava add
from datetime import datetime
from typing import List, Tuple, Dict
from vnpy.trader.ui import create_qapp, QtCore, QtGui, QtWidgets
from pyqtgraph import ScatterPlotItem
import pyqtgraph as pg
import numpy as np
import talib
import copy
from vnpy.chart import ChartWidget, VolumeItem, CandleItem
from vnpy.chart.item import ChartItem
from vnpy.chart.manager import BarManager
from vnpy.trader.object import (
BarData,
OrderData,
TradeData
)
from vnpy.trader.object import Direction, Exchange, Interval, Offset, Status, Product, OptionType, OrderType
class BarItem(CandleItem):
""" 美国线 """
BAR_WIDTH = 0.3
def __init__(self, manager: BarManager):
""""""
super().__init__(manager)
self.bar_pen: QtGui.QPen = pg.mkPen(color="w", width=2)
self.bar_brush: QtGui.QBrush = pg.mkBrush(color="w")
def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
""""""
# Create objects
candle_picture = QtGui.QPicture()
painter = QtGui.QPainter(candle_picture)
# Set painter color
painter.setPen(self.bar_pen)
painter.setBrush(self.bar_brush)
open,high,low,close = bar.open_price,bar.high_price,bar.low_price,bar.close_price
painter.drawLine(QtCore.QPointF(ix - self.BAR_WIDTH, open),QtCore.QPointF(ix, open))
painter.drawLine(QtCore.QPointF(ix, high),QtCore.QPointF(ix, low))
painter.drawLine(QtCore.QPointF(ix + self.BAR_WIDTH, close),QtCore.QPointF(ix, close))
# Finish
painter.end()
return candle_picture
class LineItem(CandleItem):
""""""
def __init__(self, manager: BarManager):
""""""
super().__init__(manager)
self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
""""""
last_bar = self._manager.get_bar(ix - 1)
# Create objects
picture = QtGui.QPicture()
painter = QtGui.QPainter(picture)
# Set painter color
painter.setPen(self.white_pen)
# Draw Line
end_point = QtCore.QPointF(ix, bar.close_price)
if last_bar:
start_point = QtCore.QPointF(ix - 1, last_bar.close_price)
else:
start_point = end_point
painter.drawLine(start_point, end_point)
# Finish
painter.end()
return picture
def get_info_text(self, ix: int) -> str:
""""""
text = ""
bar = self._manager.get_bar(ix)
if bar:
text = f"Close:{bar.close_price}"
return text
class SmaItem(CandleItem):
""""""
def __init__(self, manager: BarManager):
""""""
super().__init__(manager)
self.line_pen: QtGui.QPen = pg.mkPen(color=(100, 100, 255), width=2)
self.sma_window = 10
self.sma_data: Dict[int, float] = {}
def set_pen(self,pen:QtGui.QPen):
""" 设置绘图的笔 """
self.line_pen = pen
def set_sma_window(self,sma_window:int):
""" 设置Sma的窗口 """
self.sma_window = sma_window
def get_sma_value(self, ix: int) -> float:
""""""
if ix < 0:
return 0
# When initialize, calculate all rsi value
if not self.sma_data:
bars = self._manager.get_all_bars()
close_data = [bar.close_price for bar in bars]
sma_array = talib.SMA(np.array(close_data), self.sma_window)
for n, value in enumerate(sma_array):
self.sma_data[n] = value
# Return if already calcualted
if ix in self.sma_data:
return self.sma_data[ix]
# Else calculate new value
close_data = []
for n in range(ix - self.sma_window, ix + 1):
bar = self._manager.get_bar(n)
close_data.append(bar.close_price)
sma_array = talib.SMA(np.array(close_data), self.sma_window)
sma_value = sma_array[-1]
self.sma_data[ix] = sma_value
return sma_value
def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
""""""
sma_value = self.get_sma_value(ix)
last_sma_value = self.get_sma_value(ix - 1)
# Create objects
picture = QtGui.QPicture()
painter = QtGui.QPainter(picture)
# Set painter color
painter.setPen(self.line_pen)
# Draw Line
start_point = QtCore.QPointF(ix-1, last_sma_value)
end_point = QtCore.QPointF(ix, sma_value)
painter.drawLine(start_point, end_point)
# Finish
painter.end()
return picture
def get_info_text(self, ix: int) -> str:
""""""
if ix in self.sma_data:
sma_value = self.sma_data[ix]
text = f"SMA{self.sma_window} {sma_value:.1f}"
else:
text = "SMA{self.sma_window} -"
return text
class RsiItem(ChartItem):
""""""
def __init__(self, manager: BarManager):
""""""
super().__init__(manager)
self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=2)
self.rsi_window = 14
self.rsi_data: Dict[int, float] = {}
def get_rsi_value(self, ix: int) -> float:
""""""
if ix < 0:
return 50
# When initialize, calculate all rsi value
if not self.rsi_data:
bars = self._manager.get_all_bars()
close_data = [bar.close_price for bar in bars]
rsi_array = talib.RSI(np.array(close_data), self.rsi_window)
for n, value in enumerate(rsi_array):
self.rsi_data[n] = value
# Return if already calcualted
if ix in self.rsi_data:
return self.rsi_data[ix]
# Else calculate new value
close_data = []
for n in range(ix - self.rsi_window, ix + 1):
bar = self._manager.get_bar(n)
close_data.append(bar.close_price)
rsi_array = talib.RSI(np.array(close_data), self.rsi_window)
rsi_value = rsi_array[-1]
self.rsi_data[ix] = rsi_value
return rsi_value
def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
""""""
rsi_value = self.get_rsi_value(ix)
last_rsi_value = self.get_rsi_value(ix - 1)
# Create objects
picture = QtGui.QPicture()
painter = QtGui.QPainter(picture)
# Draw RSI line
painter.setPen(self.yellow_pen)
if np.isnan(last_rsi_value) or np.isnan(rsi_value):
# print(ix - 1, last_rsi_value,ix, rsi_value,)
pass
else:
end_point = QtCore.QPointF(ix, rsi_value)
start_point = QtCore.QPointF(ix - 1, last_rsi_value)
painter.drawLine(start_point, end_point)
# Draw oversold/overbought line
painter.setPen(self.white_pen)
painter.drawLine(
QtCore.QPointF(ix, 70),
QtCore.QPointF(ix - 1, 70),
)
painter.drawLine(
QtCore.QPointF(ix, 30),
QtCore.QPointF(ix - 1, 30),
)
# Finish
painter.end()
return picture
def boundingRect(self) -> QtCore.QRectF:
""""""
# min_price, max_price = self._manager.get_price_range()
rect = QtCore.QRectF(
0,
0,
len(self._bar_picutures),
100
)
return rect
def get_y_range( self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
""" """
return 0, 100
def get_info_text(self, ix: int) -> str:
""""""
if ix in self.rsi_data:
rsi_value = self.rsi_data[ix]
text = f"RSI {rsi_value:.1f}"
# print(text)
else:
text = "RSI -"
return text
def to_int(value: float) -> int:
""""""
return int(round(value, 0))
def adjust_range(in_range:Tuple[float, float])->Tuple[float, float]:
""" 将y方向的显示范围扩大到1.1 """
ret_range:Tuple[float, float]
diff = abs(in_range[0] - in_range[1])
ret_range = (in_range[0]-diff*0.05,in_range[1]+diff*0.05)
return ret_range
class MacdItem(ChartItem):
""""""
_values_ranges: Dict[Tuple[int, int], Tuple[float, float]] = {}
last_range:Tuple[int, int] = (-1,-1) # 最新显示K线索引范围
def __init__(self, manager: BarManager):
""""""
super().__init__(manager)
self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=1)
self.red_pen: QtGui.QPen = pg.mkPen(color=(255, 0, 0), width=1)
self.green_pen: QtGui.QPen = pg.mkPen(color=(0, 255, 0), width=1)
self.short_window = 12
self.long_window = 26
self.M = 9
self.macd_data: Dict[int, Tuple[float,float,float]] = {}
def get_macd_value(self, ix: int) -> Tuple[float,float,float]:
""""""
if ix < 0:
return (0.0,0.0,0.0)
# When initialize, calculate all macd value
if not self.macd_data:
bars = self._manager.get_all_bars()
close_data = [bar.close_price for bar in bars]
diffs,deas,macds = talib.MACD(np.array(close_data),
fastperiod=self.short_window,
slowperiod=self.long_window,
signalperiod=self.M)
for n in range(0,len(diffs)):
self.macd_data[n] = (diffs[n],deas[n],macds[n])
# Return if already calcualted
if ix in self.macd_data:
return self.macd_data[ix]
# Else calculate new value
close_data = []
for n in range(ix-self.long_window-self.M+1, ix + 1):
bar = self._manager.get_bar(n)
close_data.append(bar.close_price)
diffs,deas,macds = talib.MACD(np.array(close_data),
fastperiod=self.short_window,
slowperiod=self.long_window,
signalperiod=self.M)
diff,dea,macd = diffs[-1],deas[-1],macds[-1]
self.macd_data[ix] = (diff,dea,macd)
return (diff,dea,macd)
def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
""""""
macd_value = self.get_macd_value(ix)
last_macd_value = self.get_macd_value(ix - 1)
# # Create objects
picture = QtGui.QPicture()
painter = QtGui.QPainter(picture)
# # Draw macd lines
if np.isnan(macd_value[0]) or np.isnan(last_macd_value[0]):
# print("略过macd lines0")
pass
else:
end_point0 = QtCore.QPointF(ix, macd_value[0])
start_point0 = QtCore.QPointF(ix - 1, last_macd_value[0])
painter.setPen(self.white_pen)
painter.drawLine(start_point0, end_point0)
if np.isnan(macd_value[1]) or np.isnan(last_macd_value[1]):
# print("略过macd lines1")
pass
else:
end_point1 = QtCore.QPointF(ix, macd_value[1])
start_point1 = QtCore.QPointF(ix - 1, last_macd_value[1])
painter.setPen(self.yellow_pen)
painter.drawLine(start_point1, end_point1)
if not np.isnan(macd_value[2]):
if (macd_value[2]>0):
painter.setPen(self.red_pen)
painter.setBrush(pg.mkBrush(255,0,0))
else:
painter.setPen(self.green_pen)
painter.setBrush(pg.mkBrush(0,255,0))
painter.drawRect(QtCore.QRectF(ix-0.3,0,0.6,macd_value[2]))
else:
# print("略过macd lines2")
pass
painter.end()
return picture
def boundingRect(self) -> QtCore.QRectF:
""""""
min_y, max_y = self.get_y_range()
rect = QtCore.QRectF(
0,
min_y,
len(self._bar_picutures),
max_y
)
return rect
def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
# 获得3个指标在y轴方向的范围
# hxxjava 修改,2020-6-29
# 当显示范围改变时,min_ix,max_ix的值不为None,当显示范围不变时,min_ix,max_ix的值不为None,
offset = max(self.short_window,self.long_window) + self.M - 1
if not self.macd_data or len(self.macd_data) < offset:
# print(f'(min_ix,max_ix){(min_ix,max_ix)} offset={offset},len(self.macd_data)={len(self.macd_data)}')
# hxxjava 修改,2021-5-8,因为升级vnpy,其依赖的pyqtgraph版本也升级了,原来为return 0,1
return -100, 100
# print("len of range dict:",len(self._values_ranges),",macd_data:",len(self.macd_data),(min_ix,max_ix))
if min_ix != None: # 调整最小K线索引
min_ix = max(min_ix,offset)
if max_ix != None: # 调整最大K线索引
max_ix = min(max_ix, len(self.macd_data)-1)
last_range = (min_ix,max_ix) # 请求的最新范围
if last_range == (None,None): # 当显示范围不变时
if self.last_range in self._values_ranges:
# 如果y方向范围已经保存
# 读取y方向范围
result = self._values_ranges[self.last_range]
# print("1:",self.last_range,result)
return adjust_range(result)
else:
# 如果y方向范围没有保存
# 从macd_data重新计算y方向范围
min_ix,max_ix = 0,len(self.macd_data)-1
macd_list = list(self.macd_data.values())[min_ix:max_ix + 1]
ndarray = np.array(macd_list)
max_price = np.nanmax(ndarray)
min_price = np.nanmin(ndarray)
# 保存y方向范围,同时返回结果
result = (min_price, max_price)
self.last_range = (min_ix,max_ix)
self._values_ranges[self.last_range] = result
# print("2:",self.last_range,result)
return adjust_range(result)
""" 以下为显示范围变化时 """
if last_range in self._values_ranges:
# 该范围已经保存过y方向范围
# 取得y方向范围,返回结果
result = self._values_ranges[last_range]
# print("3:",last_range,result)
return adjust_range(result)
# 该范围没有保存过y方向范围
# 从macd_data重新计算y方向范围
macd_list = list(self.macd_data.values())[min_ix:max_ix + 1]
ndarray = np.array(macd_list)
max_price = np.nanmax(ndarray)
min_price = np.nanmin(ndarray)
# 取得y方向范围,返回结果
result = (min_price, max_price)
self.last_range = last_range
self._values_ranges[self.last_range] = result
# print("4:",self.last_range,result)
return adjust_range(result)
def get_info_text(self, ix: int) -> str:
""" """
barscount = len(self._manager._bars) # hxxjava debug
if ix in self.macd_data:
diff,dea,macd = self.macd_data[ix]
words = [
f"diff {diff:.3f}"," ",
f"dea {dea:.3f}"," ",
f"macd {macd:.3f}",
f"barscount={ix,barscount}"
]
text = "\n".join(words)
else:
text = "diff - \ndea - \nmacd -"
return text
def tip_func(x,y,data):
""" """
return f"{data}"
class BaseScatter(pg.ScatterPlotItem):
""" """
def __init__(self, plot:pg.PlotItem,manager:BarManager,*args, **kargs):
""" """
super().__init__(args=args,kargs=kargs)
self.plot = plot
self.manager = manager
self.plot.addItem(self)
self.opts['hoverable'] = True
def hoverEvent(self, ev):
""" """
if self.opts['hoverable']:
old = self.data['hovered']
if ev.exit:
new = np.zeros_like(self.data['hovered'])
else:
new = self._maskAt(ev.pos())
if self._hasHoverStyle():
self.data['sourceRect'][old ^ new] = 0
self.data['hovered'] = new
self.updateSpots()
points = self.points()[new][::-1]
# Show information about hovered points in a tool tip
vb = self.getViewBox()
if vb is not None and self.opts['tip'] is not None:
cutoff = 10
# tip = [self.opts['tip'](x=pt.pos().x(), y=pt.pos().y(), data=pt.data())
tip = [tip_func(x=pt.pos().x(), y=pt.pos().y(), data=pt.data()) for pt in points[:cutoff]]
if len(points) > cutoff:
tip.append('({} others...)'.format(len(points) - cutoff))
vb.setToolTip('\n\n'.join(tip))
self.sigHovered.emit(self, points, ev)
class TradeItem(BaseScatter):
""" 成交单绘图部件 """
TRADE_COLOR_MAP = {
(Direction.LONG,Offset.OPEN):'red',
(Direction.LONG,Offset.CLOSE):'magenta',
(Direction.LONG,Offset.CLOSETODAY):'magenta',
(Direction.LONG,Offset.CLOSEYESTERDAY):'magenta',
(Direction.SHORT,Offset.OPEN):'green',
(Direction.SHORT,Offset.CLOSE):'yellow',
(Direction.SHORT,Offset.CLOSETODAY):'yellow',
(Direction.SHORT,Offset.CLOSEYESTERDAY):'yellow',
}
TRADE_COMMAND_MAP = {
(Direction.LONG,Offset.OPEN):'买开',
(Direction.LONG,Offset.CLOSE):'买平',
(Direction.LONG,Offset.CLOSETODAY):'买平今',
(Direction.LONG,Offset.CLOSEYESTERDAY):'买平昨',
(Direction.SHORT,Offset.OPEN):'卖开',
(Direction.SHORT,Offset.CLOSE):'卖平',
(Direction.SHORT,Offset.CLOSETODAY):'卖平今',
(Direction.SHORT,Offset.CLOSEYESTERDAY):'卖平昨',
}
def __init__(self, plot:pg.PlotItem,manager:BarManager):
""" """
super().__init__(plot=plot,manager=manager,size=15, pxMode=True,pen=pg.mkPen(None), brush=pg.mkBrush(255, 255, 255, 120))
self.trades : List = []
def _to_scatter_data(self,trade:TradeData):
""" """
idx = self.manager.get_bar_idx(trade.datetime)
if idx == -1:
return {}
bar:BarData = self.manager.get_bar(idx)
color = self.TRADE_COLOR_MAP[(trade.direction,trade.offset)]
size = 10
LL,HH = self.manager.get_price_range()
y_adjustment = (HH-LL) * 0.01
if trade.direction == Direction.LONG:
symbol = 't1'
y = bar.low_price - y_adjustment
else:
symbol = 't'
y = bar.high_price + y_adjustment
# pen = pg.mkPen(QtGui.QColor(color))
# brush = pg.mkBrush(QtGui.QColor(color))
scatter_data = {
"pos": (idx, y),
"size": size,
"pen": color,
"brush": color,
"symbol": symbol,
"data": "成交单:{},单号:{},指令:{},价格:{},手数:{},时间:{}".format(
trade.vt_symbol,
trade.vt_tradeid,
self.TRADE_COMMAND_MAP[(trade.direction,trade.offset)],
trade.price,trade.volume,
trade.datetime.strftime('%Y-%m-%d %H:%M:%S')
)
}
return scatter_data
def add_trades(self, trades: List[TradeData]):
""""""
# 将trade转换为scatter数据
# self.updated = False
self.trades.extend(trades)
spots = []
for trade in self.trades:
scatter = self._to_scatter_data(trade)
if not scatter:
continue
spots.append(scatter)
# self.clear()
# self.plot.removeItem(self)
self.setData(spots,hoverable=True)
def add_trade(self,trade:TradeData):
""" """
self.trades.append(trade)
spots = []
for trade in self.trades:
scatter = self._to_scatter_data(trade)
if not scatter:
continue
spots.append(scatter)
# self.clear()
# self.plot.removeItem(self)
self.setData(spots,hoverable=True)
class OrderItem(BaseScatter):
""" 成交单绘图部件 """
ORDER_COLOR_MAP = {
(Direction.LONG,Offset.OPEN):'red',
(Direction.LONG,Offset.CLOSE):'magenta',
(Direction.LONG,Offset.CLOSETODAY):'magenta',
(Direction.LONG,Offset.CLOSEYESTERDAY):'magenta',
(Direction.SHORT,Offset.OPEN):'green',
(Direction.SHORT,Offset.CLOSE):'yellow',
(Direction.SHORT,Offset.CLOSETODAY):'yellow',
(Direction.SHORT,Offset.CLOSEYESTERDAY):'yellow',
}
ORDER_COMMAND_MAP = {
(Direction.LONG,Offset.OPEN):'买开',
(Direction.LONG,Offset.CLOSE):'买平',
(Direction.LONG,Offset.CLOSETODAY):'买平今',
(Direction.LONG,Offset.CLOSEYESTERDAY):'买平昨',
(Direction.SHORT,Offset.OPEN):'卖开',
(Direction.SHORT,Offset.CLOSE):'卖平',
(Direction.SHORT,Offset.CLOSETODAY):'卖平今',
(Direction.SHORT,Offset.CLOSEYESTERDAY):'卖平昨',
}
def __init__(self, plot:pg.PlotItem,manager:BarManager):
""" """
super().__init__(plot=plot,manager=manager,size=15, pxMode=True,pen=pg.mkPen(None), brush=pg.mkBrush(255, 255, 255, 120))
self.orders : List[OrderData] = []
def _to_scatter_data(self,order:OrderData):
""" """
if not order.datetime:
return {}
idx = self.manager.get_bar_idx(order.datetime)
if idx == -1:
return {}
bar:BarData = self.manager.get_bar(idx)
color = self.ORDER_COLOR_MAP[(order.direction,order.offset)]
size = 10
LL,HH = self.manager.get_price_range()
y_adjustment = (HH-LL) * 0.02
if order.direction == Direction.LONG:
symbol = 'o'
y = bar.low_price - y_adjustment
else:
symbol = 'o'
y = bar.high_price + y_adjustment
# pen = pg.mkPen(QtGui.QColor(color))
# brush = pg.mkBrush(QtGui.QColor(color))
scatter_data = {
"pos": (idx, y),
"size": size,
"pen": color,
"brush": color,
"symbol": symbol,
"data": "委托单:{},单号:{},指令:{},价格:{},手数:{},时间:{}".format(
order.vt_symbol,
order.vt_orderid,
self.ORDER_COMMAND_MAP[(order.direction,order.offset)],
order.price,order.volume,
order.datetime.strftime('%Y-%m-%d %H:%M:%S')
)
}
return scatter_data
def add_orders(self, orders: List[OrderData]):
""""""
# 将trade转换为scatter数据
# self.updated = False
filter_orders = [order for order in orders if order.datetime is not None and order.traded > 0]
if not filter_orders:
return
self.orders.extend(filter_orders)
spots = []
for order in self.orders:
scatter = self._to_scatter_data(order)
if not scatter:
continue
spots.append(scatter)
print(f"spots={spots}")
# self.clear()
# self.plot.removeItem(self)
self.setData(spots,hoverable=True)
def add_order(self,order:OrderData):
""" """
if order.datetime is None or order.traded == 0:
return
self.orders.append(order)
spots = []
for order in self.orders:
scatter = self._to_scatter_data(order)
if not scatter:
continue
spots.append(scatter)
print(f"spots={spots}")
# self.clear()
# self.plot.removeItem(self)
self.setData(spots,hoverable=True)
创建OrderItem和TradeItem时,必须传递主图或者附图的plot和bar管理器BarManager,示例代码如下:
candle_plot = self.chart.get_plot('candle')
manager = self.chart._manager
self.trade_item:TradeItem = TradeItem(plot=candle_plot,manager=manager)
当十字光标移动到成交单图标时,如果当根K线上发生过多次成交,你可能只看见一个图标,但其实是有多个图标被绘制的,这反应在图中的光标提示中,如图所示:
看效果图:
heavywater wrote:
大佬能解释一下get_tick_status吗?看了几遍都没看懂,主要是不明白里面的current_status和next_status是怎么确定的。多谢啦
def get_tick_status(self,tick:TickData):
"""
得到一个tick数据的合约所处交易状态
"""
status:StatusData = None
instrument = left_alphas(tick.symbol) # 提取tick所属的品种
tick_time = tick.datetime.strftime("%H:%M:%S")
vt_symbol = f"{instrument}.{tick.exchange.value}" # 例如:rb.SHFE,TA.CZCE,因为状态字典是按 "品种.交易所" 为字典键值的
if vt_symbol in self.trade_status_map:
status_dict = self.trade_status_map[vt_symbol]
curr_key = status_dict["current"] # 得到当前交易状态的键值,当前状态是有CTP接口收到状态时更新的,这里只是使用
next_key = status_dict["next"] # 得到下一个交易状态的键值,当前状态是有CTP接口收到状态时更新的,这里只是使用,
# 使用一个完整交易日后,它就是会肯定指向一个有意义的值了
curr_status:StatusData = status_dict[curr_key] # 得到当前交易状态
next_status:StatusData = status_dict[next_key] # 得到下一个交易状态
if curr_status.enter_time < next_status.enter_time:
# 交易时间段不跨日
if curr_status.enter_time <= tick_time < next_status.enter_time:
status = curr_status
elif next_status.enter_time <= tick_time:
# 超过了当前时间段和下一时间段的开始时间,并且在下一个交易时间段开始时间之后,认为是找到了
# 正确的前提条件是,current和next状态的更是是正常的。
status = next_status
else:
# 交易时间段跨日
if curr_status.enter_time <= tick_time:
status = curr_status
elif next_status.enter_time <= tick_time:
status = next_status
return status
交易时间段是对客户端委托是否有效的规定,不是对行情播报的限制。
但是反过来行情又是客户委托最终成交的果,因为某种原因导致对撮合成功超过交易时间段的截止时间,可是交易所仍然要报告结果,这种结果一般会在几秒内就结束了。
打个比方:
知道,是因为3.0版本升级,导致OrderItem和TradeItem的多继承出了问题。
目前没有时间,等有空再重写这两个主图组件就可以了。
尽管此时还是没有开盘,甚至还没有开始集合竞价,可是您的策略已经从on_tick()接口被推送了一个tick,而且该tick的时间不是当天下午的收盘,也不是您订阅该合约的时间 ! 我把这个tick打印出来了,请看:
TickData(
gateway_name='CTP',
symbol='TA205',
exchange=<Exchange.CZCE: 'CZCE'>,
datetime=datetime.datetime(2022, 4, 12, 20, 1, 18, 500000, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>),
name='精对苯二甲酸205',
volume=0,
turnover=0.0,
open_interest=441112.0,
last_price=6128.0,
last_volume=0,
limit_up=6374.0,
limit_down=5538.0,
open_price=0,
high_price=0,
low_price=0,
pre_close=6128.0,
bid_price_1=0,
bid_price_2=0,
bid_price_3=0,
bid_price_4=0,
bid_price_5=0,
ask_price_1=0,
ask_price_2=0,
ask_price_3=0,
ask_price_4=0,
ask_price_5=0,
bid_volume_1=0,
bid_volume_2=0,
bid_volume_3=0,
bid_volume_4=0,
bid_volume_5=0,
ask_volume_1=0,
ask_volume_2=0,
ask_volume_3=0,
ask_volume_4=0,
ask_volume_5=0,
localtime=None
)
很快开始集合竞价,在20:59的时候,策略可能又会收到一个包含开盘价的tick。
1分钟后是21:00,正式进入连续竞价阶段,策略又会收到等多的tick。
为了后面叙述的方便,我们把20:50时收到tick叫tick1,20:59时收到tick叫tick2。
假如CTA策略使用了30分钟K线,那么随着集合竞价结束,在21:00的时候,策略中的BarGeneraor对象,就会为您生成两个莫名其妙的30分钟K线:
启动策略不到10分钟时间,就已经虚多了2个30分钟K线。
CTA策略的初始化是由CtaEngine驱动的,其执行逻辑在vnpy_ctastrategy\engine的CtaEngine._init_strategy()中:
def _init_strategy(self, strategy_name: str):
"""
Init strategies in queue.
"""
strategy = self.strategies[strategy_name]
if strategy.inited:
self.write_log(f"{strategy_name}已经完成初始化,禁止重复操作")
return
self.write_log(f"{strategy_name}开始执行初始化")
# Call on_init function of strategy
self.call_strategy_func(strategy, strategy.on_init)
# Restore strategy data(variables)
data = self.strategy_data.get(strategy_name, None)
if data:
for name in strategy.variables:
value = data.get(name, None)
if value is not None:
setattr(strategy, name, value)
# Subscribe market data
contract = self.main_engine.get_contract(strategy.vt_symbol)
if contract:
req = SubscribeRequest(
symbol=contract.symbol, exchange=contract.exchange)
self.main_engine.subscribe(req, contract.gateway_name)
else:
self.write_log(f"行情订阅失败,找不到合约{strategy.vt_symbol}", strategy)
# Put event to update init completed status.
strategy.inited = True
self.put_strategy_event(strategy)
self.write_log(f"{strategy_name}初始化完成")
_init_strategy()执行过程是先为策略加载历史数据,再订阅策略交易合约的行情。
要想解决问题,就必须问题的根源在哪里?
因为订阅合约行情执行的是CtpMdApi的subscribe():
def subscribe(self, req: SubscribeRequest) -> None:
"""订阅行情"""
if self.login_status:
self.subscribeMarketData(req.symbol)
self.subscribed.add(req.symbol)
self.subscribeMarketData(req.symbol)只要是首次订阅,接口都会立即从OnRtnDepthMarketData推送1个该合约最新的深度行情通知,而时间是:
///最后修改时间
TThostFtdcTimeType UpdateTime;
///最后修改毫秒
TThostFtdcMillisecType UpdateMillisec;
这里的最后修改时间和最后修改毫秒本应该是该合约最后交易的时间,也可能是交易所行情服务器中CTP行情接口重新打开的时间!这就是为什么我们开动tick1的时间是2022-4-12 20:1:18.500000的原因。
TA205.CZCE在每个交易日的集合竞价时段的第4分钟会产生一个集合竞价tick。
你可能会说这个没有毛病,从20:30~21:00,确实是可以生成一个30分钟K线,为什么它不可以只包含一个tick呢?
这么说也过得去,可是问题是咱们在加载其他历史数据的时候,无论我们使用米筐、tushare或者什么其他第三方历史数据时,加载的1分钟K线,从来都没有这样的数据。
或者我们把策略产生的30分钟K线与通达信、大智慧或者文华6等软件生成的30分钟K线比较一下,它们都是没有出现这第二个30分钟K线情况的。从这种种也可以看出来这个tick的处理是不对的,tick2必须归入到21:00~21:30。
struct CThostFtdcDepthMarketDataField
{
///交易日
TThostFtdcDateType TradingDay;
///合约代码
TThostFtdcInstrumentIDType InstrumentID;
///交易所代码
TThostFtdcExchangeIDType ExchangeID;
///合约在交易所的代码
TThostFtdcExchangeInstIDType ExchangeInstID;
///最新价
TThostFtdcPriceType LastPrice;
///上次结算价
TThostFtdcPriceType PreSettlementPrice;
///昨收盘
TThostFtdcPriceType PreClosePrice;
///昨持仓量
TThostFtdcLargeVolumeType PreOpenInterest;
///今开盘
TThostFtdcPriceType OpenPrice;
///最高价
TThostFtdcPriceType HighestPrice;
///最低价
TThostFtdcPriceType LowestPrice;
///数量
TThostFtdcVolumeType Volume;
///成交金额
TThostFtdcMoneyType Turnover;
///持仓量
TThostFtdcLargeVolumeType OpenInterest;
///今收盘
TThostFtdcPriceType ClosePrice;
///本次结算价
TThostFtdcPriceType SettlementPrice;
///涨停板价
TThostFtdcPriceType UpperLimitPrice;
///跌停板价
TThostFtdcPriceType LowerLimitPrice;
///昨虚实度
TThostFtdcRatioType PreDelta;
///今虚实度
TThostFtdcRatioType CurrDelta;
///最后修改时间
TThostFtdcTimeType UpdateTime;
///最后修改毫秒
TThostFtdcMillisecType UpdateMillisec;
///申买价一
TThostFtdcPriceType BidPrice1;
///申买量一
TThostFtdcVolumeType BidVolume1;
///申卖价一
TThostFtdcPriceType AskPrice1;
///申卖量一
TThostFtdcVolumeType AskVolume1;
///申买价二
TThostFtdcPriceType BidPrice2;
///申买量二
TThostFtdcVolumeType BidVolume2;
///申卖价二
TThostFtdcPriceType AskPrice2;
///申卖量二
TThostFtdcVolumeType AskVolume2;
///申买价三
TThostFtdcPriceType BidPrice3;
///申买量三
TThostFtdcVolumeType BidVolume3;
///申卖价三
TThostFtdcPriceType AskPrice3;
///申卖量三
TThostFtdcVolumeType AskVolume3;
///申买价四
TThostFtdcPriceType BidPrice4;
///申买量四
TThostFtdcVolumeType BidVolume4;
///申卖价四
TThostFtdcPriceType AskPrice4;
///申卖量四
TThostFtdcVolumeType AskVolume4;
///申买价五
TThostFtdcPriceType BidPrice5;
///申买量五
TThostFtdcVolumeType BidVolume5;
///申卖价五
TThostFtdcPriceType AskPrice5;
///申卖量五
TThostFtdcVolumeType AskVolume5;
///当日均价
TThostFtdcPriceType AveragePrice;
///业务日期
TThostFtdcDateType ActionDay;
};
其中 UpdateMillisec 为最后修改毫秒,int型
错误和不合适之处已经改正,见注释:
def onRtnDepthMarketData(self, data: dict) -> None:
"""行情数据推送"""
# 过滤没有时间戳的异常行情数据
if not data["UpdateTime"]:
return
# 过滤还没有收到合约数据前的行情推送
symbol: str = data["InstrumentID"]
contract: ContractData = symbol_contract_map.get(symbol, None)
if not contract:
return
# 对大商所的交易日字段取本地日期
if not data["ActionDay"] or contract.exchange == Exchange.DCE:
# 这里废了那么大的劲,却使用了一个更新滞后的变量,属实不应该
# self.current_date是由定时器几秒更新一次,
# 对于一些跨夜品种,会导致几秒钟的tick的日期错误
# date_str: str = self.current_date
date_str: str = datetime.now().strftime("%Y%m%d") # hxxjava change
else:
date_str: str = data["ActionDay"]
# 这里不好,为什么要故意降低接口的时间精度,放着毫秒不要而费劲地变化为0.1秒精度?
# timestamp: str = f"{date_str} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}"
timestamp: str = f"{date_str} {data['UpdateTime']}." + str(data['UpdateMillisec']*1000).zfill(6) # hxxjava edit
dt: datetime = datetime.strptime(timestamp, "%Y%m%d %H:%M:%S.%f")
dt: datetime = CHINA_TZ.localize(dt)
tick: TickData = TickData(
symbol=symbol,
exchange=contract.exchange,
datetime=dt,
name=contract.name,
volume=data["Volume"],
turnover=data["Turnover"],
open_interest=data["OpenInterest"],
last_price=adjust_price(data["LastPrice"]),
limit_up=data["UpperLimitPrice"],
limit_down=data["LowerLimitPrice"],
open_price=adjust_price(data["OpenPrice"]),
high_price=adjust_price(data["HighestPrice"]),
low_price=adjust_price(data["LowestPrice"]),
pre_close=adjust_price(data["PreClosePrice"]),
bid_price_1=adjust_price(data["BidPrice1"]),
ask_price_1=adjust_price(data["AskPrice1"]),
bid_volume_1=data["BidVolume1"],
ask_volume_1=data["AskVolume1"],
gateway_name=self.gateway_name
)
if data["BidVolume2"] or data["AskVolume2"]:
tick.bid_price_2 = adjust_price(data["BidPrice2"])
tick.bid_price_3 = adjust_price(data["BidPrice3"])
tick.bid_price_4 = adjust_price(data["BidPrice4"])
tick.bid_price_5 = adjust_price(data["BidPrice5"])
tick.ask_price_2 = adjust_price(data["AskPrice2"])
tick.ask_price_3 = adjust_price(data["AskPrice3"])
tick.ask_price_4 = adjust_price(data["AskPrice4"])
tick.ask_price_5 = adjust_price(data["AskPrice5"])
tick.bid_volume_2 = data["BidVolume2"]
tick.bid_volume_3 = data["BidVolume3"]
tick.bid_volume_4 = data["BidVolume4"]
tick.bid_volume_5 = data["BidVolume5"]
tick.ask_volume_2 = data["AskVolume2"]
tick.ask_volume_3 = data["AskVolume3"]
tick.ask_volume_4 = data["AskVolume4"]
tick.ask_volume_5 = data["AskVolume5"]
self.gateway.on_tick(tick)
文章太长,再分一贴吧。
一个合约的交易时间段信息,就包含在一个字符串中。通常看起来是这样的:
"21:00-23:00,09:00-10:15,10:30-11:30,13:30-15:00"
它看似简单,实则非常复杂!简单在于它只是一个字符串,其实它能够表达非常复杂的交易时间规定。例如交易时间可以少到只有1段,也可以4到5个段,可跨日,也可以跨多日,如遇到周末或者长假。但是长假太难处理了,我们这也不处理各种各样的假日规定,因为那个太复杂了!不过好在时下很多软件,著名的和非著名的软件,几乎都不处理跨长假的问题,不处理的原因也是和我分析的一样,不过这也没有影响他们多软件被广大用户接受的程度。所以我们也不处理跨长假的问题。
当然想处理跨长假也不成,条件不具备呀。因为毕竟我们不是交易所,不知道各种各样的休假规定,不同市场,不同国家的节假日,千奇百怪,太难处理了。而且我们也不能说不处理哪个市场或者国家的投资品种吧?绝大部分软件都不处理长假对K线对齐方式的影响,原因就在于此,没有什么别的说辞!
在vnpy\usertools下创建一个名称为trading_hours.py,其代码如下:
"""
本文件主要实现合约的交易时间段:TradingHours
作者:hxxjava
日期:2022-03-28
修改:2022-06-09 修改内容:TradingHours的get_intraday_window()处理时间段错误
"""
from calendar import month
from typing import Callable,List,Dict, Tuple, Union
from enum import Enum
from datetime import datetime,date,timedelta, tzinfo
import numpy as np
import pytz
CHINA_TZ = pytz.timezone("Asia/Shanghai")
from vnpy.trader.constant import Interval
def to_china_tz(dt: datetime) -> datetime:
"""
Convert a datetime object to a CHINA_TZ localize datetime object
"""
return CHINA_TZ.localize(dt.replace(tzinfo=None))
INTERVAL_MAP = {
Interval.MINUTE:60,
Interval.HOUR:3600,
Interval.DAILY:3600*24,
Interval.WEEKLY:3600*24*7,
}
def get_time_segments(trading_hours:str) -> List:
"""
从交易时间段字符串中提取出各段的起止时间(天内的秒数) 列表
"""
time_sepments = []
# 提取各段
str_segments = trading_hours.split(',')
pre_start,day_offset = None,0
for s in reversed(str_segments): # 反向遍历各段
# 提取段的起止时间
start,stop = s.split('-')
# 计算开始时间天内秒
hh,mm = start.split(':')
start_s = int(hh)*3600+int(mm)*60
# 计算截止时间天内秒
hh,mm = stop.split(':')
stop_s = int(hh)*3600+int(mm)*60
if pre_start and start > pre_start:
day_offset -= 1
pre_start = start
# 加入列表
time_sepments.insert(0,(day_offset,start_s,stop_s))
return time_sepments
def in_segments(trade_segments:List,trade_dt:datetime):
""" 判断一个时间是否在一个交易时间段列表中 """
trade_dt = to_china_tz(trade_dt)
for start,stop in trade_segments:
if start <= trade_dt < stop:
return True
return False
class TradingHours(object):
"""
交易时间段处理
"""
def __init__(self,trading_hours:str):
"""
初始化函数 。
参数说明:
trading_hours:交易时间段字符串,例如:"21:00-23:00,09:00-10:15,10:30-11:30,13:30-15:00"
pre_open: 集合竞价时段长度,单位分钟。例如:国内期货pre_open=5
after_close: 交易日收盘后结算时长。例如国内期货持续到15:20,那么after_close=20
"""
self.time_segments:List[Tuple[int,int,int]] = get_time_segments(trading_hours)
def day_trade_time(self,interval:Interval) -> int:
"""
一个交易日的交易时长,单位由interval 规定,不足的部分+1
"""
seconds = 0.0
for _,start,stop in self.time_segments:
seconds += stop - start + (0 if start < stop else INTERVAL_MAP[Interval.DAILY])
if not interval:
return seconds
else:
return np.ceil(seconds/INTERVAL_MAP[interval])
def get_auction_closes_segments(self,trade_dt:datetime) -> Tuple[date,List]:
"""
得到一个交易时间所在的交易日及集合竞价时间段和所有休市时段的列表
"""
if not self.auction_closes:
return (None,[])
trade_dt = to_china_tz(trade_dt)
dates = [trade_dt.date()+timedelta(days=i) for i in range(-3,4)]
# 根据 self.auction_closes 构造出一周内的日期时间格式的非连续交易时间段字典
week_seqments = {
dt:
[(to_china_tz(datetime(dt.year,dt.month,dt.day))+timedelta(days=days-(2 if days == -1 and dt.weekday()==0 else 0),seconds=start),
to_china_tz(datetime(dt.year,dt.month,dt.day))+timedelta(days=days-(2 if days == -1 and dt.weekday()==0 else 0)+(1 if start>stop else 0),seconds=stop))
for days,start,stop in self.auction_closes]
for dt in dates if dt.weekday() not in [5,6]
}
# 在非交易时间段字典中查找trade_dt所在集合竞价时间段,确定所属交易日
for dt,datetime_segments in week_seqments.items():
# 遍历一周中的每日
if in_segments(datetime_segments,trade_dt):
return (dt,datetime_segments)
return (None,[])
def get_trade_hours(self,trade_dt:datetime) -> Tuple[date,List[Tuple[datetime,datetime]]]:
"""
得到一个时间的交易日及日期时间格式的交易时间段列表,无效交易时间返回空
"""
# 构造trade_dt加前后三天共7的日期
trade_dt = to_china_tz(trade_dt)
dates = [trade_dt.date()+timedelta(days=i) for i in range(-3,4)]
# 根据 self.time_segments 构造出一周内的日期时间格式的交易时间段字典
week_seqments = {
dt:
[(to_china_tz(datetime(dt.year,dt.month,dt.day))+timedelta(days=days-(2 if days == -1 and dt.weekday()==0 else 0),seconds=start),
to_china_tz(datetime(dt.year,dt.month,dt.day))+timedelta(days=days-(2 if days == -1 and dt.weekday()==0 else 0)+(1 if start>stop else 0),seconds=stop))
for days,start,stop in self.time_segments]
for dt in dates if dt.weekday() not in [5,6]
}
trade_day,trading_segments = None,[]
# 在交易时间段字典中查找trade_dt所在交易时间段,确定所属交易日
for dt,datetime_segments in week_seqments.items():
# 遍历一周中的每日
for start,stop in datetime_segments:
# 遍历一日中的每个交易时间段
if start <= trade_dt < stop:
# 找到了,确定dt为trade_dt的交易日
trade_day = dt
break
if trade_day:
# 已经找到,停止
trading_segments = datetime_segments
break
return (trade_day,trading_segments)
def get_trading_segments(self,tradeday:date): # List[Tuple[datetime,datetime]]
"""
得到某个交易日的就要时间段。注:只考虑周末,不考虑法定假
"""
segments = []
weekday = tradeday.weekday()
if weekday not in [5,6]:
# 周一至周五
# 周一跨日需插入2天
insert_days = -2 if tradeday.weekday() == 0 else 0
y,m,d = tradeday.year,tradeday.month,tradeday.day
for day,start,stop in self.time_segments:
days = insert_days + day if day < 0 else day
start_dt = datetime(y,m,d,0,0,0) + timedelta(days=days,seconds=start)
stop_dt = datetime(y,m,d,0,0,0) + timedelta(days=days+(0 if start < stop else 1),seconds=stop)
segments.append((start_dt,stop_dt))
return segments
def get_intraday_window(self,trade_dt:datetime,window:int) -> Tuple[date,List[Tuple[datetime,datetime]]]:
"""
得到一个时间的日内交易时间、窗口索引、窗口开始时间和截止时间
"""
trade_dt = to_china_tz(trade_dt)
interval = Interval.MINUTE
oneday_minutes = self.day_trade_time(interval)
if window > oneday_minutes:
raise f"In day window can't exceed {oneday_minutes} minutes !"
result = (None,[])
if window == 0:
# window==0 无意义
return result
# 求dt的交易日
trade_day,segment_datetimes = self.get_trade_hours(trade_dt)
if not trade_day:
# 无效的交易日
return result
if np.sum([start <= trade_dt < stop for start,stop in segment_datetimes]) == 0:
# 如果dt不在各个交易时间段内为无效的交易时间
return result
# 交易日的开盘时间
t0 = segment_datetimes[0][0]
# 构造各个交易时间段的起止数组
starts = np.array([(seg_dt[0]-t0).seconds*1.0 for seg_dt in segment_datetimes])
stops = np.array([(seg_dt[1]-t0).seconds*1.0 for seg_dt in segment_datetimes])
# 求dt在交易日中的自然时间
nature_t = (trade_dt - t0).seconds
# 求dt已经走过的交易时间
traded_t = np.sum(nature_t - starts[starts<=nature_t]) - np.sum(nature_t-stops[stops<nature_t])
if traded_t < 0:
# 开盘之前的为无效交易时间
return result
# 求当前所在窗口的宽度、索引、开始交易时间及截止时间
window_width = window * INTERVAL_MAP[interval]
window_idx = np.floor(traded_t/window_width)
window_start = window_idx * window_width
window_stop = window_start + window_width
# 求各个交易时间段的宽度
segment_widths = stops - starts
# print("!!!3",window_start,window_stop,segment_widths)
# 求各个交易时间段累计日内交易时间
sums = [np.sum(segment_widths[:(i+1)]) for i in range(len(segment_widths))]
if window_stop > sums[-1]:
# 不可以跨日处理
window_stop = sums[-1]
# 累计日内交易时间数组
seg_sum = np.array(sums)
# 每段开始累计日内交易时间数组
seg_start_sum = np.array([0] + sums)
# 求窗口开始和截止时间的时间段索引
s1,s2 = seg_sum - window_start,seg_sum - window_stop
start_idx,stop_idx = np.sum(s1 <= 0),np.sum(s2<0)
# 求窗口开始和截止时间的在其时间段中的偏移量
start_offset = (window_start-seg_start_sum)[start_idx]
stop_offset = (window_stop-seg_start_sum)[stop_idx]
# 求窗口包含的时间片段列表
window_segments = []
for idx in range(start_idx,stop_idx+1):
start,stop = segment_datetimes[idx]
t1 = start + timedelta(seconds=start_offset) if idx == start_idx else start
t2 = start + timedelta(seconds=stop_offset) if idx == stop_idx else stop
window_segments.append((t1,t2))
# 窗口所属交易日及包含的时间片段列表
result = (trade_day,window_segments)
return result
def get_week_tradedays(self,trade_dt:datetime) -> List[date]:
""" 得到一个交易时间所在周的交易日 """
trade_dt = to_china_tz(trade_dt)
trade_day,trade_segments = self.get_trade_hours(trade_dt)
if not trade_day:
return []
monday = trade_dt.date() - timedelta(days=trade_dt.weekday())
week_dates = [monday + timedelta(days=i) for i in range(5)]
if trade_day not in week_dates:
next_7days = [(trade_dt + timedelta(days=i+1)) for i in range(7)]
week_dates = [day.date() for day in next_7days if day.weekday() not in [5,6]]
return week_dates
def get_month_tradedays(self,trade_dt:datetime) -> List[date]:
""" 得到一个交易时间所在月的交易日 """
trade_dt = to_china_tz(trade_dt)
trade_day,trade_segments = self.get_trade_hours(trade_dt)
if not trade_day:
return []
first_day = date(year=trade_day.year,month=trade_day.month,day=1)
this_month = trade_day.month
days32 = [first_day + timedelta(days = i) for i in range(32)]
month_dates = [day for day in days32 if day.weekday() not in [5,6] and day.month==this_month]
return month_dates
def get_year_tradedays(self,trade_dt:datetime) -> List[date]:
""" 得到一个交易时间所在年的交易日 """
trade_dt = to_china_tz(trade_dt)
trade_day,trade_segments = self.get_trade_hours(trade_dt)
if not trade_day:
return []
new_years_day = date(year=trade_day.year,month=1,day=1)
this_year = trade_day.year
days366 = [new_years_day + timedelta(days = i) for i in range(366)]
trade_dates = [day for day in days366 if day.weekday() not in [5,6] and day.year==this_year]
return trade_dates
def has_night_tradetime(self) -> bool:
""" 有夜盘交易时间吗? """
for (days,start,stop) in self.time_segments:
if start >= 18*INTERVAL_MAP(Interval.HOUR):
return True
return False
def has_day_tradetime(self) -> bool:
""" 有日盘交易时间吗 ? """
for (days,start,stop) in self.time_segments:
if start < 18*INTERVAL_MAP(Interval.HOUR):
return True
return False
先给它取个名称,就叫MyBarGenerator吧,它是对BarGenerator的扩展。
不过在构思MyBarGenerator的时候,我发现它其实不应该叫“日内对齐等交易时长K线生成器”。因为我们不应该只局限于日内的n分钟K线生成器,难道vnpy系统就不应该、不能够或者不使用日线以上的K线了吗?我们只能够使用日内K线进行量化交易吗?难道大家都没有过这方面的需求吗?我想答案是否定的。
那好,所幸就设计一个全功能的K线生成器:MyBarGenerator。
为此我们需要扩展Interval的定义,因为Interval是表示K线周期的常量,可是它的格局不够,最大只能到周一级WEEKLY。也就是说您用目前的Interval是没有办法表达月和年这样的周期的。
class Interval(Enum):
"""
Interval of bar data.
"""
MINUTE = "1m"
HOUR = "1h"
DAILY = "d"
WEEKLY = "w"
TICK = "tick"
MONTHLY = "month" # hxxjava add
YEARLY = "year" # hxxjava add
顺便在这里吐槽一下BarGenerator:
在系统且并详细分析之后,把K线分类为:日内K线、日K线,周K线、月K线及年K线等周期K线五类。
1)日内K线包括1~n分钟K线,如1分钟、n分钟两类,其中n小于正常交易日的最大交易分钟数。日内K线取消对小时周期单位支持,因为可以通过n分钟的方式来实现。如:
2)日K线:每个交易日产生一个,它包含一到多个交易时间段。根据是否包含夜盘交易时间段,又可以分为跨日K线和不跨日K线。
3)周K线:由周一至周五中所有交易日的交易数据合成得到,它其实是一种特殊的n日K线,只是n<=5而已。
4)月K线:由每月1日至月末最后一个交易日的交易数据合成得到,除去所有周末,它最多包含23个交易日,遇到本月有长假日,其所包含的交易日会更少。
5)年K线:由每年1月1日至12月31日中的所有交易日的交易数据合成得到,除去所有周末。它可以理解为由一年中的所有交易日数据合成的,也可以理解为由一年中的12个月的交易日数据合成的。
1)日内K线(包括1~n分钟K线)生成规则:
2)日K线生成规则:
3)周K线生成规则:
4)月K线生成规则:
5)年K线生成规则:
年K线可以由两种方式进行合成:一种是用日K线合成,另一种是用月K线合成。我们这里选择用日K线来合成。
在vnpy\usertools\utility.py中加入如下面的两个部分:
from copy import deepcopy
from typing import List,Dict,Tuple,Optional,Sequence,Callable
from datetime import date,datetime,timedelta
from vnpy.trader.constant import Interval
from vnpy.trader.object import TickData,BarData
from vnpy.trader.utility import extract_vt_symbol
from vnpy.usertools.trading_hours import TradingHours,in_segments
from vnpy.usertools.trade_hours import CHINA_TZ
def generate_temp_bar(small:BarData,big:BarData,interval:Interval):
""" get temp intra day small_bar """
small_bar:BarData = deepcopy(small) # 1 minute small_bar
big_bar:BarData = deepcopy(big)
if big_bar and small_bar:
big_bar.high_price = max(big_bar.high_price,small_bar.high_price)
big_bar.low_price = min(big_bar.low_price,small_bar.low_price)
big_bar.close_price = small_bar.close_price
big_bar.open_interest = small_bar.open_interest
big_bar.volume += small_bar.volume
big_bar.turnover += small_bar.turnover
elif not big_bar and small_bar:
big_bar = BarData(
symbol=small_bar.symbol,
exchange=small_bar.exchange,
interval=interval,
datetime=small_bar.datetime,
gateway_name=small_bar.gateway_name,
open_price=small_bar.open_price,
high_price=small_bar.high_price,
low_price=small_bar.low_price,
close_price = small_bar.close_price,
open_interest = small_bar.open_interest,
volume = small_bar.volume,
turnover = small_bar.turnover
)
return big_bar
class MyBarGenerator():
"""
An align bar generator.
Comment's for parameters:
on_bar : callback function on 1 minute bar is generated.
window : window bar's width.
on_window_bar : callback function on x interval bar is generated.
interval : window bar's unit.
trading_hours: trading hours with which the window bar can be generated.
"""
def __init__(
self,
on_bar: Callable,
window: int = 0,
on_window_bar: Callable = None,
interval: Interval = Interval.MINUTE,
trading_hours:str = ""
):
""" Constructor """
self.bar: BarData = None
self.on_bar: Callable = on_bar
self.interval: Interval = interval
self.interval_count: int = 0
self.intra_day_bar: BarData = None
self.day_bar: BarData = None
self.week_bar: BarData = None
self.month_bar: BarData = None
self.year_bar: BarData = None
self.day_bar_cnt:int = 0 # 日K线的1分钟K线计数
self.week_daybar_cnt:int = 0 # 周K线的日K线计数
self.window: int = window
self.on_window_bar: Callable = on_window_bar
self.last_tick: TickData = None
if interval not in [Interval.MINUTE,Interval.DAILY,Interval.WEEKLY,Interval.MONTHLY,Interval.YEARLY]:
raise ValueError(f"MyBarGenerator support MINUTE,DAILY,WEEKLY,MONTHLY and YEARLY bar generation only , please check it !")
if not trading_hours:
raise ValueError(f"MyBarGenerator need trading hours setting , please check it !")
# trading hours object
self.trading_hours = TradingHours(trading_hours)
self.day_total_minutes = int(self.trading_hours.day_trade_time(Interval.MINUTE))
self.tick_windows = (None,[])
# current intraday window bar's contains trading day and time segments list
self.intraday_bar_window = (None,[]) # (trade_day,[])
# current daily bar's window containts trading day and time segment list
self.daily_bar_window = (None,[])
# current weekly bar's window containts all trade days
self.weekly_bar_window = []
# current monthly bar's window containts all trade days
self.monthly_bar_window = []
# current yearly bar's window containts all trade days
self.yearly_bar_window = []
def update_tick(self, tick: TickData) -> None:
"""
Update new tick data into generator.
"""
new_minute = False
# Filter tick data with 0 last price
if not tick.last_price:
return
# Filter tick data with older timestamp
if self.last_tick and tick.datetime < self.last_tick.datetime:
print(f"特别tick【{tick}】!")
return
if self.tick_windows == (None,[]) or not in_segments(self.tick_windows[1],tick.datetime):
# 判断tick是否在连续交易时间段或者集合竞价时间段中
self.tick_windows = self.trading_hours.get_trade_hours(tick.datetime)
if self.tick_windows == (None,[]):
# 不在连续交易时间段
print(f"特别tick【{tick}】")
return
if not self.bar:
new_minute = True
elif (
(self.bar.datetime.minute != tick.datetime.minute)
or (self.bar.datetime.hour != tick.datetime.hour)
):
self.bar.datetime = self.bar.datetime.replace(
second=0, microsecond=0
)
self.on_bar(self.bar)
new_minute = True
if new_minute:
self.bar = BarData(
symbol=tick.symbol,
exchange=tick.exchange,
interval=Interval.MINUTE,
datetime=to_china_tz(tick.datetime),
gateway_name=tick.gateway_name,
open_price=tick.last_price,
high_price=tick.last_price,
low_price=tick.last_price,
close_price=tick.last_price,
open_interest=tick.open_interest
)
else:
self.bar.high_price = max(self.bar.high_price, tick.last_price)
if tick.high_price > self.last_tick.high_price:
self.bar.high_price = max(self.bar.high_price, tick.high_price)
self.bar.low_price = min(self.bar.low_price, tick.last_price)
if tick.low_price < self.last_tick.low_price:
self.bar.low_price = min(self.bar.low_price, tick.low_price)
self.bar.close_price = tick.last_price
self.bar.open_interest = tick.open_interest
self.bar.datetime = to_china_tz(tick.datetime)
if self.last_tick:
volume_change = tick.volume - self.last_tick.volume
self.bar.volume += max(volume_change, 0)
turnover_change = tick.turnover - self.last_tick.turnover
self.bar.turnover += max(turnover_change, 0)
self.last_tick = tick
def update_bar(self, bar: BarData) -> None:
"""
Update 1 minute bar into generator
"""
if self.interval == Interval.MINUTE and self.window > 0:
# update inday bar
self.update_intraday_bar(bar)
elif self.interval in [Interval.DAILY,Interval.WEEKLY,Interval.MONTHLY,Interval.YEARLY]:
# update daily,weekly,monthly or yearly bar
self.update_daily_bar(bar)
def update_intraday_bar(self, bar: BarData) -> None:
""" update intra day x window bar """
if bar:
bar.datetime = to_china_tz(bar.datetime)
if self.interval != Interval.MINUTE or self.window <= 1:
return
if self.intraday_bar_window == (None,[]):
# 首次调用日内K线更新函数
trade_day,time_segments = self.trading_hours.get_intraday_window(bar.datetime,self.window)
if (trade_day,time_segments) == (None,[]):
# 无效的1分钟K线
return
# 更新当前日内K线交易时间
self.intraday_bar_window = (trade_day,time_segments)
# 创建新的日内K线
self.intra_day_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.MINUTE,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
elif not in_segments(self.intraday_bar_window[1],bar.datetime):
# 1分钟K线不属于当前日内K线
str1 = f"bar.datetime={bar.datetime}\nintraday_bar_window:{self.intraday_bar_window}"
trade_day,time_segments = self.trading_hours.get_intraday_window(bar.datetime,self.window)
if (trade_day,time_segments) == (None,[]):
# 无效的1分钟K线
return
# 当前日内K线已经生成,推送当前日内K线
if self.on_window_bar:
self.on_window_bar(self.intra_day_bar)
# 更新当前日内K线交易时间
self.intraday_bar_window = (trade_day,time_segments)
str1 += f"\nintraday_bar_window:{self.intraday_bar_window}"
print(str1)
# 创建新的日内K线
self.intra_day_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.MINUTE,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
# 1分钟K线属于当前日内K线
# 更新当前日内K线
self.intra_day_bar.high_price = max(self.intra_day_bar.high_price,bar.high_price)
self.intra_day_bar.low_price = min(self.intra_day_bar.low_price,bar.low_price)
self.intra_day_bar.close_price = bar.close_price
self.intra_day_bar.open_interest = bar.open_interest
self.intra_day_bar.volume += bar.volume
self.intra_day_bar.turnover += bar.turnover
# 判断当前日内K线是否结束
close_time = self.intraday_bar_window[1][-1][1]
next_minute_dt = bar.datetime + timedelta(minutes=1)
if close_time <= next_minute_dt:
# 当前日K内线已经结束
# 当前日内K线已经生成,推送之
if self.on_window_bar:
print(f"close_time={close_time},next_minute_dt={next_minute_dt}")
self.on_window_bar(self.intra_day_bar)
self.intraday_bar_window = (None,[])
self.intra_day_bar = None
def update_daily_bar(self, bar: BarData) -> bool:
""" update daily bar using 1 minute bar """
result = False
if bar:
bar.datetime = to_china_tz(bar.datetime)
if self.daily_bar_window == (None,[]):
# 首次调用日K线更新函数
trade_day,trade_segments = self.trading_hours.get_trade_hours(bar.datetime)
if (trade_day,trade_segments) == (None,[]):
# 无效的1分钟K线
return result
# 更新当前日K线交易时间
self.daily_bar_window = (trade_day,trade_segments)
# 创建新的日K线
self.day_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.DAILY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
self.day_bar_cnt = 1
if not in_segments(self.daily_bar_window[1],bar.datetime):
# 1分钟K线不属于当前日K线
trade_day,trade_segments = self.trading_hours.get_trade_hours(bar.datetime)
if (trade_day,trade_segments) == (None,[]):
# 无效的1分钟K线
return
# 当前日K线已经生成
if self.interval == Interval.DAILY:
# 推送当前日K线
if self.on_window_bar:
self.on_window_bar(self.day_bar)
self.day_bar_cnt = 0
else:
# 更新更大周期K线
if self.update_weekly_bar(self.day_bar):
self.week_daybar_cnt += 1
self.update_monthly_bar(self.day_bar)
self.update_yearly_bar(self.day_bar)
# 更新当前日K线交易时间
self.daily_bar_window = (trade_day,trade_segments)
# 创建新的日K线
self.day_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.DAILY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
# 1分钟K线属于当前交易日
# 更新当前日K线
self.day_bar.high_price = max(self.day_bar.high_price,bar.high_price)
self.day_bar.low_price = min(self.day_bar.low_price,bar.low_price)
self.day_bar.close_price = bar.close_price
self.day_bar.open_interest = bar.open_interest
self.day_bar.volume += bar.volume
self.day_bar.turnover += bar.turnover
result = True
# 判断当前日K线是否结束
close_time = self.daily_bar_window[1][-1][1]
next_minute_dt = bar.datetime + timedelta(minutes=1)
if close_time <= next_minute_dt or self.day_total_minutes == self.day_bar_cnt:
# 当前日K线已经结束
# 当前日K线已经生成
if self.interval == Interval.DAILY:
# 推送当前日K线
if self.on_window_bar:
self.on_window_bar(self.day_bar)
else:
# 更新更大周期K线
if self.update_weekly_bar(self.day_bar):
self.week_daybar_cnt += 1
self.update_monthly_bar(self.day_bar)
self.update_yearly_bar(self.day_bar)
self.daily_bar_window = (None,[])
self.day_bar = None
self.day_bar_cnt = 0
return result
def update_weekly_bar(self, bar: BarData) -> bool:
""" update weekly bar using a daily bar """
result = False
if bar:
bar.datetime = to_china_tz(bar.datetime)
if self.interval != Interval.WEEKLY:
# 设定周期单位不是周,不处理
return result
if not self.weekly_bar_window:
# 首次调用周K线更新函数
week_tradedays = self.trading_hours.get_week_tradedays(bar.datetime)
if not week_tradedays:
# 无效的日K线
return result
# 更新当前周K线交易日列表
self.weekly_bar_window = week_tradedays
# 创建新的周K线
self.week_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.WEEKLY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
if bar.datetime not in self.weekly_bar_window:
# 日线不属于当前周K线
week_tradedays = self.trading_hours.get_week_tradedays(bar.datetime)
if not week_tradedays:
# 无效的日K线
return result
# 当前周K线已经生成,推送
if self.on_window_bar:
self.on_window_bar(self.week_bar)
self.week_daybar_cnt = 0
# 更新当前周K线交易日列表
self.weekly_bar_window = week_tradedays
# 创建新的周K线
self.week_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.WEEKLY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
# 更新当前周K线
self.week_bar.high_price = max(self.week_bar.high_price,bar.high_price)
self.week_bar.low_price = min(self.week_bar.low_price,bar.low_price)
self.week_bar.close_price = bar.close_price
self.week_bar.open_interest = bar.open_interest
self.week_bar.volume += bar.volume
self.week_bar.turnover += bar.turnover
result = True
# 判断当前周K线是否结束
trade_day,_ = self.trading_hours.get_trade_hours(bar.datetime)
if trade_day >= self.weekly_bar_window[-1] or self.week_daybar_cnt == 5:
# 当前周K线已经结束,推送当前周K线
if self.on_window_bar:
self.on_window_bar(self.week_bar)
self.week_daybar_cnt = 0
# 复位当前周交易日列表及周K线
self.weekly_bar_window = []
self.week_bar = None
return result
def update_monthly_bar(self, bar: BarData) -> bool:
""" update monthly bar using a daily bar """
result = False
if bar:
bar.datetime = to_china_tz(bar.datetime)
if self.interval != Interval.MONTHLY:
# 设定周期单位不是月,不处理
return result
if not self.monthly_bar_window:
# 首次调用月K线更新函数
month_tradedays = self.trading_hours.get_month_tradedays(bar.datetime)
if not month_tradedays:
# 无效的日K线
return result
# 更新当前月K线交易日列表
self.monthly_bar_window = month_tradedays
# 创建新的月K线
self.month_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.MONTHLY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
if bar.datetime not in self.monthly_bar_window:
# 日线不属于当前月K线
month_tradedays = self.trading_hours.get_month_tradedays(bar.datetime)
if not month_tradedays:
# 无效的日K线
return result
# 当前月K线已经生成,推送
if self.on_window_bar:
self.on_window_bar(self.month_bar)
# 更新当前月交易日列表
self.monthly_bar_window = month_tradedays
# 创建新的月K线
self.month_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.MONTHLY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
# 更新当前月K线
self.month_bar.high_price = max(self.month_bar.high_price,bar.high_price)
self.month_bar.low_price = min(self.month_bar.low_price,bar.low_price)
self.month_bar.close_price = bar.close_price
self.month_bar.open_interest = bar.open_interest
self.month_bar.volume += bar.volume
self.month_bar.turnover += bar.turnover
result = True
# 判断当前月K线是否结束
trade_day,_ = self.trading_hours.get_trade_hours(bar.datetime)
if trade_day >= self.monthly_bar_window[-1]:
# 当前月K线已经结束,推送当前月K线
if self.on_window_bar:
self.on_window_bar(self.month_bar)
# 复位当前月交易日列表及月K线
self.monthly_bar_window = []
self.month_bar = None
return result
def update_yearly_bar(self, bar: BarData) -> bool:
""" update yearly bar using a daily bar """
result = False
if bar:
bar.datetime = to_china_tz(bar.datetime)
if self.interval != Interval.YEARLY:
# 设定周期单位不是年,不处理
return result
if not self.yearly_bar_window:
# 首次调用年K线更新函数
year_tradedays = self.trading_hours.get_year_tradedays(bar.datetime)
if not year_tradedays:
# 无效的日K线
return result
# 更新当前年K线交易日列表
self.yearly_bar_window = year_tradedays
# 创建新的年K线
self.year_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.YEARLY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
if bar.datetime not in self.yearly_bar_window:
# 日线不属于当前年K线
year_tradedays = self.trading_hours.get_year_tradedays(bar.datetime)
if not year_tradedays:
# 无效的日K线
return result
# 当前年K线已经生成,推送
if self.on_window_bar:
self.on_window_bar(self.year_bar)
# 更新当前年交易日列表
self.yearly_bar_window = year_tradedays
# 创建新的年K线
self.year_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.YEARLY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
# 更新当前年K线
self.year_bar.high_price = max(self.year_bar.high_price,bar.high_price)
self.year_bar.low_price = min(self.year_bar.low_price,bar.low_price)
self.year_bar.close_price = bar.close_price
self.year_bar.open_interest = bar.open_interest
self.year_bar.volume += bar.volume
self.year_bar.turnover += bar.turnover
result = True
# 判断当前年K线是否结束
trade_day,_ = self.trading_hours.get_trade_hours(bar.datetime)
if trade_day >= self.yearly_bar_window[-1]:
# 当前年K线已经结束,推送当前年K线
if self.on_window_bar:
self.on_window_bar(self.year_bar)
# 复位当前年交易日列表及年K线
self.yearly_bar_window = []
self.year_bar = None
return result
def get_temp_bar(self) -> BarData:
""" 返回临时1分钟K线 """
bar = deepcopy(self.bar)
if bar:
bar.datetime = bar.datetime.replace(second=0,microsecond=0)
return bar
def get_temp_window_bar(self,bar:BarData = None) -> BarData:
"""
返回临时窗口K线
"""
temp_bar:BarData = None
if not bar:
# 如果没有传入1分钟K线,取当前生成器的1分钟K线
bar = self.bar
if self.interval == Interval.MINUTE:
if self.window == 0:
temp_bar = deepcopy(self.bar)
else:
temp_bar = generate_temp_bar(bar,self.intra_day_bar,Interval.MINUTE)
elif self.interval == Interval.DAILY:
temp_bar = generate_temp_bar(bar,self.day_bar,Interval.DAILY)
elif self.interval == Interval.WEEKLY:
day_bar = generate_temp_bar(bar,self.day_bar,Interval.DAILY)
temp_bar = generate_temp_bar(day_bar,self.week_bar,Interval.WEEKLY)
elif self.interval == Interval.MONTHLY:
day_bar = generate_temp_bar(bar,self.day_bar,Interval.DAILY)
temp_bar = generate_temp_bar(day_bar,self.month_bar,Interval.MONTHLY)
elif self.interval == Interval.YEARLY:
day_bar = generate_temp_bar(bar,self.day_bar,Interval.DAILY)
temp_bar = generate_temp_bar(day_bar,self.year_bar,Interval.YEARLY)
return temp_bar
def generate(self) -> Optional[BarData]:
"""
Generate the bar data and call callback immediately.
"""
bar = self.bar
if self.bar:
bar.datetime = bar.datetime.replace(second=0, microsecond=0)
self.on_bar(bar)
self.bar = None
return bar
交易时间段是交易所对一个合约连续交易时间的规定,它只规定了在哪些时间段内市场是可以连续交易的,也就是说投资者交易开仓、平仓和撤单的。
但是交易时间段不包括一个合约交易的所有交易时间的规定,例如集合竞价时间段、日内中间休市时间段和交易日收盘休市时间段这三类时间段的规定。
集合竞价时间段在交易日的开盘时间之前。能够该时间段的参与的投资者可能有资格的限制,就是说可能不是市场的参与者都有资格能够在在集合竞价时段中进行交易的。
而且不同市场,不同合约的集合竞价时间段的长度是不一样的,不同的交易日也可能不同,例如:
1)国内市场
2)国外市场
总之,集合竞价时段变化多端,非常复杂,在K线时长上需要特别关注和处理,否则您生成的是什么K线,正确与否是无从谈起的。没准您多了个莫名其妙的K线都不知道。
这种特别处理请参考:分析一下盘中启动CTA策略带来的第一根$K线错误
修改vnpy_ctastrategy\backtesting.py,修改后全部内容如下:
from collections import defaultdict
from datetime import date, datetime, timedelta
import imp
from pipes import Template
from typing import Callable, List
from functools import lru_cache, partial
import traceback
import numpy as np
from pandas import DataFrame
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from vnpy.trader.constant import (Direction, Offset, Exchange,
Interval, Status)
from vnpy.trader.database import get_database
from vnpy.trader.object import OrderData, TradeData, BarData, TickData, ContractData # hxxjava add ContractData
from vnpy.trader.utility import round_to
from vnpy.trader.optimize import (
OptimizationSetting,
check_optimization_setting,
run_bf_optimization,
run_ga_optimization
)
from .base import (
BacktestingMode,
EngineType,
STOPORDER_PREFIX,
StopOrder,
StopOrderStatus,
INTERVAL_DELTA_MAP
)
from .template import CtaTemplate
class BacktestingEngine:
""""""
engine_type = EngineType.BACKTESTING
gateway_name = "BACKTESTING"
def __init__(self):
""""""
self.vt_symbol = ""
self.symbol = ""
self.exchange = None
self.start = None
self.end = None
self.rate = 0
self.slippage = 0
self.size = 1
self.pricetick = 0
self.capital = 1_000_000
self.risk_free: float = 0
self.annual_days: int = 240
self.mode = BacktestingMode.BAR
self.inverse = False
self.strategy_class = None
self.strategy = None
self.tick: TickData
self.bar: BarData
self.datetime = None
self.interval = None
self.days = 0
self.callback = None
self.history_data = []
self.stop_order_count = 0
self.stop_orders = {}
self.active_stop_orders = {}
self.limit_order_count = 0
self.limit_orders = {}
self.active_limit_orders = {}
self.trade_count = 0
self.trades = {}
self.logs = []
self.daily_results = {}
self.daily_df = None
self.load_all_trading_hours() # hxxjava add
self.load_contracts() # hxxjava add
def clear_data(self):
"""
Clear all data of last backtesting.
"""
self.strategy = None
self.tick = None
self.bar = None
self.datetime = None
self.stop_order_count = 0
self.stop_orders.clear()
self.active_stop_orders.clear()
self.limit_order_count = 0
self.limit_orders.clear()
self.active_limit_orders.clear()
self.trade_count = 0
self.trades.clear()
self.logs.clear()
self.daily_results.clear()
def set_parameters(
self,
vt_symbol: str,
interval: Interval,
start: datetime,
rate: float,
slippage: float,
size: float,
pricetick: float,
capital: int = 0,
end: datetime = None,
mode: BacktestingMode = BacktestingMode.BAR,
inverse: bool = False,
risk_free: float = 0,
annual_days: int = 240
):
""""""
self.mode = mode
self.vt_symbol = vt_symbol
self.interval = Interval(interval)
self.rate = rate
self.slippage = slippage
self.size = size
self.pricetick = pricetick
self.start = start
self.symbol, exchange_str = self.vt_symbol.split(".")
self.exchange = Exchange(exchange_str)
self.capital = capital
self.end = end
self.mode = mode
self.inverse = inverse
self.risk_free = risk_free
self.annual_days = annual_days
def load_all_trading_hours(self) -> None: # hxxjava add end
""" """
from vnpy.trader.datafeed import get_datafeed
df = get_datafeed()
if not df.inited:
df.init()
self.all_trading_hours = df.load_all_trading_hours()
print(f"BachtestingEngine.all_trading_hours len={len(self.all_trading_hours)}")
def load_contracts(self) -> None: # hxxjava add end
""" """
database = get_database()
contracts:List[ContractData] = database.load_contract_data()
self.contracts = {}
for c in contracts:
self.contracts[c.vt_symbol] = c
print(f"BachtestingEngine.contracts len={len(self.contracts)}")
def get_trading_hours(self,strategy:CtaTemplate) -> str: # hxxjava add
"""
get vt_symbol's trading hours
"""
ths = self.all_trading_hours.get(strategy.vt_symbol.upper(),"")
return ths["trading_hours"] if ths else ""
def get_contract(self, strategy:CtaTemplate) :# -> Optional[ContractData]:
"""
Get contract data by vt_symbol.
"""
return self.contracts.get(strategy.vt_symbol,None)
def add_strategy(self, strategy_class: type, setting: dict):
""""""
self.strategy_class = strategy_class
self.strategy = strategy_class(
self, strategy_class.__name__, self.vt_symbol, setting
)
def load_data(self):
""""""
self.output("开始加载历史数据")
if not self.end:
self.end = datetime.now()
if self.start >= self.end:
self.output("起始日期必须小于结束日期")
return
self.history_data.clear() # Clear previously loaded history data
# Load 30 days of data each time and allow for progress update
total_days = (self.end - self.start).days
progress_days = max(int(total_days / 10), 1)
progress_delta = timedelta(days=progress_days)
interval_delta = INTERVAL_DELTA_MAP[self.interval]
start = self.start
end = self.start + progress_delta
progress = 0
while start < self.end:
progress_bar = "#" * int(progress * 10 + 1)
self.output(f"加载进度:{progress_bar} [{progress:.0%}]")
end = min(end, self.end) # Make sure end time stays within set range
if self.mode == BacktestingMode.BAR:
data = load_bar_data(
self.symbol,
self.exchange,
self.interval,
start,
end
)
else:
data = load_tick_data(
self.symbol,
self.exchange,
start,
end
)
self.history_data.extend(data)
progress += progress_days / total_days
progress = min(progress, 1)
start = end + interval_delta
end += progress_delta
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
def run_backtesting(self):
""""""
if self.mode == BacktestingMode.BAR:
func = self.new_bar
else:
func = self.new_tick
self.strategy.on_init()
# Use the first [days] of history data for initializing strategy
day_count = 0
ix = 0
for ix, data in enumerate(self.history_data):
if self.datetime and data.datetime.day != self.datetime.day:
day_count += 1
if day_count >= self.days:
break
self.datetime = data.datetime
try:
self.callback(data)
except Exception:
self.output("触发异常,回测终止")
self.output(traceback.format_exc())
return
self.strategy.inited = True
self.output("策略初始化完成")
self.strategy.on_start()
self.strategy.trading = True
self.output("开始回放历史数据")
# Use the rest of history data for running backtesting
backtesting_data = self.history_data[ix:]
if len(backtesting_data) <= 1:
self.output("历史数据不足,回测终止")
return
total_size = len(backtesting_data)
batch_size = max(int(total_size / 10), 1)
for ix, i in enumerate(range(0, total_size, batch_size)):
batch_data = backtesting_data[i: i + batch_size]
for data in batch_data:
try:
func(data)
except Exception:
self.output("触发异常,回测终止")
self.output(traceback.format_exc())
return
progress = min(ix / 10, 1)
progress_bar = "=" * (ix + 1)
self.output(f"回放进度:{progress_bar} [{progress:.0%}]")
self.strategy.on_stop()
self.output("历史数据回放结束")
def calculate_result(self):
""""""
self.output("开始计算逐日盯市盈亏")
if not self.trades:
self.output("成交记录为空,无法计算")
return
# Add trade data into daily reuslt.
for trade in self.trades.values():
d = trade.datetime.date()
daily_result = self.daily_results[d]
daily_result.add_trade(trade)
# Calculate daily result by iteration.
pre_close = 0
start_pos = 0
for daily_result in self.daily_results.values():
daily_result.calculate_pnl(
pre_close,
start_pos,
self.size,
self.rate,
self.slippage,
self.inverse
)
pre_close = daily_result.close_price
start_pos = daily_result.end_pos
# Generate dataframe
results = defaultdict(list)
for daily_result in self.daily_results.values():
for key, value in daily_result.__dict__.items():
results[key].append(value)
self.daily_df = DataFrame.from_dict(results).set_index("date")
self.output("逐日盯市盈亏计算完成")
return self.daily_df
def calculate_statistics(self, df: DataFrame = None, output=True):
""""""
self.output("开始计算策略统计指标")
# Check DataFrame input exterior
if df is None:
df = self.daily_df
# Check for init DataFrame
if df is None:
# Set all statistics to 0 if no trade.
start_date = ""
end_date = ""
total_days = 0
profit_days = 0
loss_days = 0
end_balance = 0
max_drawdown = 0
max_ddpercent = 0
max_drawdown_duration = 0
total_net_pnl = 0
daily_net_pnl = 0
total_commission = 0
daily_commission = 0
total_slippage = 0
daily_slippage = 0
total_turnover = 0
daily_turnover = 0
total_trade_count = 0
daily_trade_count = 0
total_return = 0
annual_return = 0
daily_return = 0
return_std = 0
sharpe_ratio = 0
return_drawdown_ratio = 0
else:
# Calculate balance related time series data
df["balance"] = df["net_pnl"].cumsum() + self.capital
# When balance falls below 0, set daily return to 0
pre_balance = df["balance"].shift(1)
pre_balance.iloc[0] = self.capital
x = df["balance"] / pre_balance
x[x <= 0] = np.nan
df["return"] = np.log(x).fillna(0)
df["highlevel"] = (
df["balance"].rolling(
min_periods=1, window=len(df), center=False).max()
)
df["drawdown"] = df["balance"] - df["highlevel"]
df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100
# Calculate statistics value
start_date = df.index[0]
end_date = df.index[-1]
total_days = len(df)
profit_days = len(df[df["net_pnl"] > 0])
loss_days = len(df[df["net_pnl"] < 0])
end_balance = df["balance"].iloc[-1]
max_drawdown = df["drawdown"].min()
max_ddpercent = df["ddpercent"].min()
max_drawdown_end = df["drawdown"].idxmin()
if isinstance(max_drawdown_end, date):
max_drawdown_start = df["balance"][:max_drawdown_end].idxmax()
max_drawdown_duration = (max_drawdown_end - max_drawdown_start).days
else:
max_drawdown_duration = 0
total_net_pnl = df["net_pnl"].sum()
daily_net_pnl = total_net_pnl / total_days
total_commission = df["commission"].sum()
daily_commission = total_commission / total_days
total_slippage = df["slippage"].sum()
daily_slippage = total_slippage / total_days
total_turnover = df["turnover"].sum()
daily_turnover = total_turnover / total_days
total_trade_count = df["trade_count"].sum()
daily_trade_count = total_trade_count / total_days
total_return = (end_balance / self.capital - 1) * 100
annual_return = total_return / total_days * self.annual_days
daily_return = df["return"].mean() * 100
return_std = df["return"].std() * 100
if return_std:
daily_risk_free = self.risk_free / np.sqrt(self.annual_days)
sharpe_ratio = (daily_return - daily_risk_free) / return_std * np.sqrt(self.annual_days)
else:
sharpe_ratio = 0
return_drawdown_ratio = -total_return / max_ddpercent
# Output
if output:
self.output("-" * 30)
self.output(f"首个交易日:\t{start_date}")
self.output(f"最后交易日:\t{end_date}")
self.output(f"总交易日:\t{total_days}")
self.output(f"盈利交易日:\t{profit_days}")
self.output(f"亏损交易日:\t{loss_days}")
self.output(f"起始资金:\t{self.capital:,.2f}")
self.output(f"结束资金:\t{end_balance:,.2f}")
self.output(f"总收益率:\t{total_return:,.2f}%")
self.output(f"年化收益:\t{annual_return:,.2f}%")
self.output(f"最大回撤: \t{max_drawdown:,.2f}")
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
self.output(f"最长回撤天数: \t{max_drawdown_duration}")
self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
self.output(f"总手续费:\t{total_commission:,.2f}")
self.output(f"总滑点:\t{total_slippage:,.2f}")
self.output(f"总成交金额:\t{total_turnover:,.2f}")
self.output(f"总成交笔数:\t{total_trade_count}")
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
self.output(f"日均手续费:\t{daily_commission:,.2f}")
self.output(f"日均滑点:\t{daily_slippage:,.2f}")
self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
self.output(f"日均成交笔数:\t{daily_trade_count}")
self.output(f"日均收益率:\t{daily_return:,.2f}%")
self.output(f"收益标准差:\t{return_std:,.2f}%")
self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}")
self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}")
statistics = {
"start_date": start_date,
"end_date": end_date,
"total_days": total_days,
"profit_days": profit_days,
"loss_days": loss_days,
"capital": self.capital,
"end_balance": end_balance,
"max_drawdown": max_drawdown,
"max_ddpercent": max_ddpercent,
"max_drawdown_duration": max_drawdown_duration,
"total_net_pnl": total_net_pnl,
"daily_net_pnl": daily_net_pnl,
"total_commission": total_commission,
"daily_commission": daily_commission,
"total_slippage": total_slippage,
"daily_slippage": daily_slippage,
"total_turnover": total_turnover,
"daily_turnover": daily_turnover,
"total_trade_count": total_trade_count,
"daily_trade_count": daily_trade_count,
"total_return": total_return,
"annual_return": annual_return,
"daily_return": daily_return,
"return_std": return_std,
"sharpe_ratio": sharpe_ratio,
"return_drawdown_ratio": return_drawdown_ratio,
}
# Filter potential error infinite value
for key, value in statistics.items():
if value in (np.inf, -np.inf):
value = 0
statistics[key] = np.nan_to_num(value)
self.output("策略统计指标计算完成")
return statistics
def show_chart(self, df: DataFrame = None):
""""""
# Check DataFrame input exterior
if df is None:
df = self.daily_df
# Check for init DataFrame
if df is None:
return
fig = make_subplots(
rows=4,
cols=1,
subplot_titles=["Balance", "Drawdown", "Daily Pnl", "Pnl Distribution"],
vertical_spacing=0.06
)
balance_line = go.Scatter(
x=df.index,
y=df["balance"],
mode="lines",
name="Balance"
)
drawdown_scatter = go.Scatter(
x=df.index,
y=df["drawdown"],
fillcolor="red",
fill='tozeroy',
mode="lines",
name="Drawdown"
)
pnl_bar = go.Bar(y=df["net_pnl"], name="Daily Pnl")
pnl_histogram = go.Histogram(x=df["net_pnl"], nbinsx=100, name="Days")
fig.add_trace(balance_line, row=1, col=1)
fig.add_trace(drawdown_scatter, row=2, col=1)
fig.add_trace(pnl_bar, row=3, col=1)
fig.add_trace(pnl_histogram, row=4, col=1)
fig.update_layout(height=1000, width=1000)
fig.show()
def run_bf_optimization(self, optimization_setting: OptimizationSetting, output=True):
""""""
if not check_optimization_setting(optimization_setting):
return
evaluate_func: callable = wrap_evaluate(self, optimization_setting.target_name)
results = run_bf_optimization(
evaluate_func,
optimization_setting,
get_target_value,
output=self.output
)
if output:
for result in results:
msg: str = f"参数:{result[0]}, 目标:{result[1]}"
self.output(msg)
return results
run_optimization = run_bf_optimization
def run_ga_optimization(self, optimization_setting: OptimizationSetting, output=True):
""""""
if not check_optimization_setting(optimization_setting):
return
evaluate_func: callable = wrap_evaluate(self, optimization_setting.target_name)
results = run_ga_optimization(
evaluate_func,
optimization_setting,
get_target_value,
output=self.output
)
if output:
for result in results:
msg: str = f"参数:{result[0]}, 目标:{result[1]}"
self.output(msg)
return results
def update_daily_close(self, price: float):
""""""
d = self.datetime.date()
daily_result = self.daily_results.get(d, None)
if daily_result:
daily_result.close_price = price
else:
self.daily_results[d] = DailyResult(d, price)
def new_bar(self, bar: BarData):
""""""
self.bar = bar
self.datetime = bar.datetime
self.cross_limit_order()
self.cross_stop_order()
self.strategy.on_bar(bar)
self.update_daily_close(bar.close_price)
def new_tick(self, tick: TickData):
""""""
self.tick = tick
self.datetime = tick.datetime
self.cross_limit_order()
self.cross_stop_order()
self.strategy.on_tick(tick)
self.update_daily_close(tick.last_price)
def cross_limit_order(self):
"""
Cross limit order with last bar/tick data.
"""
if self.mode == BacktestingMode.BAR:
long_cross_price = self.bar.low_price
short_cross_price = self.bar.high_price
long_best_price = self.bar.open_price
short_best_price = self.bar.open_price
else:
long_cross_price = self.tick.ask_price_1
short_cross_price = self.tick.bid_price_1
long_best_price = long_cross_price
short_best_price = short_cross_price
for order in list(self.active_limit_orders.values()):
# Push order update with status "not traded" (pending).
if order.status == Status.SUBMITTING:
order.status = Status.NOTTRADED
self.strategy.on_order(order)
# Check whether limit orders can be filled.
long_cross = (
order.direction == Direction.LONG
and order.price >= long_cross_price
and long_cross_price > 0
)
short_cross = (
order.direction == Direction.SHORT
and order.price <= short_cross_price
and short_cross_price > 0
)
if not long_cross and not short_cross:
continue
# Push order udpate with status "all traded" (filled).
order.traded = order.volume
order.status = Status.ALLTRADED
self.strategy.on_order(order)
if order.vt_orderid in self.active_limit_orders:
self.active_limit_orders.pop(order.vt_orderid)
# Push trade update
self.trade_count += 1
if long_cross:
trade_price = min(order.price, long_best_price)
pos_change = order.volume
else:
trade_price = max(order.price, short_best_price)
pos_change = -order.volume
trade = TradeData(
symbol=order.symbol,
exchange=order.exchange,
orderid=order.orderid,
tradeid=str(self.trade_count),
direction=order.direction,
offset=order.offset,
price=trade_price,
volume=order.volume,
datetime=self.datetime,
gateway_name=self.gateway_name,
)
self.strategy.pos += pos_change
self.strategy.on_trade(trade)
self.trades[trade.vt_tradeid] = trade
def cross_stop_order(self):
"""
Cross stop order with last bar/tick data.
"""
if self.mode == BacktestingMode.BAR:
long_cross_price = self.bar.high_price
short_cross_price = self.bar.low_price
long_best_price = self.bar.open_price
short_best_price = self.bar.open_price
else:
long_cross_price = self.tick.last_price
short_cross_price = self.tick.last_price
long_best_price = long_cross_price
short_best_price = short_cross_price
for stop_order in list(self.active_stop_orders.values()):
# Check whether stop order can be triggered.
long_cross = (
stop_order.direction == Direction.LONG
and stop_order.price <= long_cross_price
)
short_cross = (
stop_order.direction == Direction.SHORT
and stop_order.price >= short_cross_price
)
if not long_cross and not short_cross:
continue
# Create order data.
self.limit_order_count += 1
order = OrderData(
symbol=self.symbol,
exchange=self.exchange,
orderid=str(self.limit_order_count),
direction=stop_order.direction,
offset=stop_order.offset,
price=stop_order.price,
volume=stop_order.volume,
traded=stop_order.volume,
status=Status.ALLTRADED,
gateway_name=self.gateway_name,
datetime=self.datetime
)
self.limit_orders[order.vt_orderid] = order
# Create trade data.
if long_cross:
trade_price = max(stop_order.price, long_best_price)
pos_change = order.volume
else:
trade_price = min(stop_order.price, short_best_price)
pos_change = -order.volume
self.trade_count += 1
trade = TradeData(
symbol=order.symbol,
exchange=order.exchange,
orderid=order.orderid,
tradeid=str(self.trade_count),
direction=order.direction,
offset=order.offset,
price=trade_price,
volume=order.volume,
datetime=self.datetime,
gateway_name=self.gateway_name,
)
self.trades[trade.vt_tradeid] = trade
# Update stop order.
stop_order.vt_orderids.append(order.vt_orderid)
stop_order.status = StopOrderStatus.TRIGGERED
if stop_order.stop_orderid in self.active_stop_orders:
self.active_stop_orders.pop(stop_order.stop_orderid)
# Push update to strategy.
self.strategy.on_stop_order(stop_order)
self.strategy.on_order(order)
self.strategy.pos += pos_change
self.strategy.on_trade(trade)
def load_bar(
self,
vt_symbol: str,
days: int,
interval: Interval,
callback: Callable,
use_database: bool
) -> List[BarData]:
""""""
self.days = days
self.callback = callback
return []
def load_tick(self, vt_symbol: str, days: int, callback: Callable) -> List[TickData]:
""""""
self.days = days
self.callback = callback
return []
def send_order(
self,
strategy: CtaTemplate,
direction: Direction,
offset: Offset,
price: float,
volume: float,
stop: bool,
lock: bool,
net: bool
):
""""""
price = round_to(price, self.pricetick)
if stop:
vt_orderid = self.send_stop_order(direction, offset, price, volume)
else:
vt_orderid = self.send_limit_order(direction, offset, price, volume)
return [vt_orderid]
def send_stop_order(
self,
direction: Direction,
offset: Offset,
price: float,
volume: float
):
""""""
self.stop_order_count += 1
stop_order = StopOrder(
vt_symbol=self.vt_symbol,
direction=direction,
offset=offset,
price=price,
volume=volume,
datetime=self.datetime,
stop_orderid=f"{STOPORDER_PREFIX}.{self.stop_order_count}",
strategy_name=self.strategy.strategy_name,
)
self.active_stop_orders[stop_order.stop_orderid] = stop_order
self.stop_orders[stop_order.stop_orderid] = stop_order
return stop_order.stop_orderid
def send_limit_order(
self,
direction: Direction,
offset: Offset,
price: float,
volume: float
):
""""""
self.limit_order_count += 1
order = OrderData(
symbol=self.symbol,
exchange=self.exchange,
orderid=str(self.limit_order_count),
direction=direction,
offset=offset,
price=price,
volume=volume,
status=Status.SUBMITTING,
gateway_name=self.gateway_name,
datetime=self.datetime
)
self.active_limit_orders[order.vt_orderid] = order
self.limit_orders[order.vt_orderid] = order
return order.vt_orderid
def cancel_order(self, strategy: CtaTemplate, vt_orderid: str):
"""
Cancel order by vt_orderid.
"""
if vt_orderid.startswith(STOPORDER_PREFIX):
self.cancel_stop_order(strategy, vt_orderid)
else:
self.cancel_limit_order(strategy, vt_orderid)
def cancel_stop_order(self, strategy: CtaTemplate, vt_orderid: str):
""""""
if vt_orderid not in self.active_stop_orders:
return
stop_order = self.active_stop_orders.pop(vt_orderid)
stop_order.status = StopOrderStatus.CANCELLED
self.strategy.on_stop_order(stop_order)
def cancel_limit_order(self, strategy: CtaTemplate, vt_orderid: str):
""""""
if vt_orderid not in self.active_limit_orders:
return
order = self.active_limit_orders.pop(vt_orderid)
order.status = Status.CANCELLED
self.strategy.on_order(order)
def cancel_all(self, strategy: CtaTemplate):
"""
Cancel all orders, both limit and stop.
"""
vt_orderids = list(self.active_limit_orders.keys())
for vt_orderid in vt_orderids:
self.cancel_limit_order(strategy, vt_orderid)
stop_orderids = list(self.active_stop_orders.keys())
for vt_orderid in stop_orderids:
self.cancel_stop_order(strategy, vt_orderid)
def write_log(self, msg: str, strategy: CtaTemplate = None):
"""
Write log message.
"""
msg = f"{self.datetime}\t{msg}"
self.logs.append(msg)
def send_email(self, msg: str, strategy: CtaTemplate = None):
"""
Send email to default receiver.
"""
pass
def sync_strategy_data(self, strategy: CtaTemplate):
"""
Sync strategy data into json file.
"""
pass
def get_engine_type(self):
"""
Return engine type.
"""
return self.engine_type
def get_pricetick(self, strategy: CtaTemplate):
"""
Return contract pricetick data.
"""
return self.pricetick
def put_strategy_event(self, strategy: CtaTemplate):
"""
Put an event to update strategy status.
"""
pass
def output(self, msg):
"""
Output message of backtesting engine.
"""
print(f"{datetime.now()}\t{msg}")
def get_all_trades(self):
"""
Return all trade data of current backtesting result.
"""
return list(self.trades.values())
def get_all_orders(self):
"""
Return all limit order data of current backtesting result.
"""
return list(self.limit_orders.values())
def get_all_daily_results(self):
"""
Return all daily result data.
"""
return list(self.daily_results.values())
class DailyResult:
""""""
def __init__(self, date: date, close_price: float):
""""""
self.date = date
self.close_price = close_price
self.pre_close = 0
self.trades = []
self.trade_count = 0
self.start_pos = 0
self.end_pos = 0
self.turnover = 0
self.commission = 0
self.slippage = 0
self.trading_pnl = 0
self.holding_pnl = 0
self.total_pnl = 0
self.net_pnl = 0
def add_trade(self, trade: TradeData):
""""""
self.trades.append(trade)
def calculate_pnl(
self,
pre_close: float,
start_pos: float,
size: int,
rate: float,
slippage: float,
inverse: bool
):
""""""
# If no pre_close provided on the first day,
# use value 1 to avoid zero division error
if pre_close:
self.pre_close = pre_close
else:
self.pre_close = 1
# Holding pnl is the pnl from holding position at day start
self.start_pos = start_pos
self.end_pos = start_pos
if not inverse: # For normal contract
self.holding_pnl = self.start_pos * \
(self.close_price - self.pre_close) * size
else: # For crypto currency inverse contract
self.holding_pnl = self.start_pos * \
(1 / self.pre_close - 1 / self.close_price) * size
# Trading pnl is the pnl from new trade during the day
self.trade_count = len(self.trades)
for trade in self.trades:
if trade.direction == Direction.LONG:
pos_change = trade.volume
else:
pos_change = -trade.volume
self.end_pos += pos_change
# For normal contract
if not inverse:
turnover = trade.volume * size * trade.price
self.trading_pnl += pos_change * \
(self.close_price - trade.price) * size
self.slippage += trade.volume * size * slippage
# For crypto currency inverse contract
else:
turnover = trade.volume * size / trade.price
self.trading_pnl += pos_change * \
(1 / trade.price - 1 / self.close_price) * size
self.slippage += trade.volume * size * slippage / (trade.price ** 2)
self.turnover += turnover
self.commission += turnover * rate
# Net pnl takes account of commission and slippage cost
self.total_pnl = self.trading_pnl + self.holding_pnl
self.net_pnl = self.total_pnl - self.commission - self.slippage
@lru_cache(maxsize=999)
def load_bar_data(
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime
):
""""""
database = get_database()
return database.load_bar_data(
symbol, exchange, interval, start, end
)
@lru_cache(maxsize=999)
def load_tick_data(
symbol: str,
exchange: Exchange,
start: datetime,
end: datetime
):
""""""
database = get_database()
return database.load_tick_data(
symbol, exchange, start, end
)
def evaluate(
target_name: str,
strategy_class: CtaTemplate,
vt_symbol: str,
interval: Interval,
start: datetime,
rate: float,
slippage: float,
size: float,
pricetick: float,
capital: int,
end: datetime,
mode: BacktestingMode,
inverse: bool,
setting: dict
):
"""
Function for running in multiprocessing.pool
"""
engine = BacktestingEngine()
engine.set_parameters(
vt_symbol=vt_symbol,
interval=interval,
start=start,
rate=rate,
slippage=slippage,
size=size,
pricetick=pricetick,
capital=capital,
end=end,
mode=mode,
inverse=inverse
)
engine.add_strategy(strategy_class, setting)
engine.load_data()
engine.run_backtesting()
engine.calculate_result()
statistics = engine.calculate_statistics(output=False)
target_value = statistics[target_name]
return (str(setting), target_value, statistics)
def wrap_evaluate(engine: BacktestingEngine, target_name: str) -> callable:
"""
Wrap evaluate function with given setting from backtesting engine.
"""
func: callable = partial(
evaluate,
target_name,
engine.strategy_class,
engine.vt_symbol,
engine.interval,
engine.start,
engine.rate,
engine.slippage,
engine.size,
engine.pricetick,
engine.capital,
engine.end,
engine.mode,
engine.inverse
)
return func
def get_target_value(result: list) -> float:
"""
Get target value for sorting optimization results.
"""
return result[1]
先厘清大思路,后面逐步完成。
vnpy系统自带了一个BarGenerator,它可以帮助我们生成1分钟,n分钟,n小时,日周期的K线,也叫bar。可是除了1分钟比较完美之外,有很多问题。它在读取历史数据、回测的时候多K线的处理和实盘却有不一样的效果。具体的问题我已经在解决vnpy 2.9.0版本的BarGenerator产生30分钟Bar的错误!这个帖子中做过尝试,但也不是很成功。因为系统的BarGenerator靠时间窗口与1分钟bar的时间分钟关系来决定是否该新建和结束一个bar,这个有问题。于是我改用对1分钟bar进行计数来决定是否该新建和结束一个bar,这也是有不可靠的问题,遇到行情比较清淡的时候,可能有的分钟就没有1分钟bar产生,这是完全有可能的!
K线几乎是绝大部分交易策略分析的基础,除非你从事的是极高频交易,否则你就得用它。可是如果你连生成一个稳健可靠的K线都不能够保证,那么运行在K线基础上的指标及由此产生的交易信号就无从谈起,K线错了,它们就是错误的,以此为依据做出点交易指令有可能是南辕北辙,所以必须解决它!
K线不是交易所发布的,它有很多种产生机制。其对齐方式、表现形式多种多样。关于K线的分类本人在以往的帖子中做出过比较详细的说明,有兴趣的读者可以去我以往的帖子中查看,这里就不再赘述。
市面上的绝大部分软件如通达信、大智慧、文华财经等软件,除非用户特别设定,他们最常提供给用户的K线多是日内对齐等交易时长K线。常用是一定是有道理的,因为它们已经为广大用户和投资者所接受。
1)什么是日内对齐等交易时长K线?
它具有这些特点:以每日开盘为起点,每根K线都包含相同交易时间的数据,跳过中间的休市时间,直至当前交易日的收盘,收盘不足n分钟也就是K线。实盘中,每日的第一个n分钟K线含集合竞价的那个tick数据。
2)为什么这种K线能够被普遍接受?
为它尽可能地保证一个交易日内的所有K线所表达的意义内容上是一致的,它们包含相等的交易时长。这非常重要,因为你把一个5分钟时长的K线与一个30分钟时长的K线放在一起谈论是没有意义的。但是如果为了保证K线在交易时长上的一致性,让n分钟K线跨日的话也是不太合理,因为这跨日,跨周末时间太长,这中间会发生什么意外事情,可能会产生出非常巨大的幅度大K线,掩盖了隔日跳空的行情变化,这对解读行情是不利的。当然n日的K线日跨日的,但是它是n个交易日的K线融合而成的,不过其融合的每个日K线也是对齐各自的日开盘的。
另外日内对齐等交易时长K线还有一个好处,那就是你以任何之前的时间为起点,在读取历史数据重新生成该日的n分钟K线的时候,得到的改日的K线是一致的。举个例子,如果我们的CTA策略在init()中常常是这么一句:
self.load_bar(20) # 先加载20日历史1分钟数据
这么简单的一句,包含着很多你意识不到的变化——你今天运行策略和明天运行你的策略,其中的历史数据的范围发生了变化,也就是说加载数据的起点变了。如果我们合成的K线的对齐方式不采用日内对齐的话,而采用对齐加载时间起点的话,你今天、明天加载出来之前的某日的K线就可能完全是不同的。而采用日内对齐等交易时长的K线则不存在这个问题。
3)需要知道合约的交易时间段
既然要对齐每日开盘,还有跳过各个休市时间,还要知道收市时间,那么我们就知道生成这种K线必须知道其所表达合约或对象的交易时间段,交易时间段中包含了这些信息,不知道这些信息,BarGenerator就不知道如何生成这种bar。这是必须的!
目前vnpy系统中的是没有合约的交易时间段的。到哪里获取合约的交易时间段的呢?
1) 它与合约相关,应该到保存合约的数据类ContractData中去找,没有找到。
2) 是否可以提供接口,从交易所获得,这个也是比较基础的数据。于是到CTP接口中(我使用的是CTP接口,您也许不一样) ,在最新版本的CTP接口文档中也没有找到任何与交易时间段相关的信息,绝望!
解决方法:
打开vnpy.trader.datafeed.py文件为Datafeed的基类BaseDatafeed扩展下面的接口
class BaseDatafeed(ABC):
"""
Abstract datafeed class for connecting to different datafeed.
"""
def init(self) -> bool:
"""
Initialize datafeed service connection.
"""
pass
def update_all_trading_hours(self) -> bool: # hxxjava add
""" 更新所有合约的交易时间段到trading_hours.json文件中 """
pass
def load_all_trading_hours(self) -> dict: # hxxjava add
""" 从trading_hours.json文件中读取所有合约的交易时间段 """
pass
def query_bar_history(self, req: HistoryRequest) -> Optional[List[BarData]]:
"""
Query history bar data.
"""
pass
def query_tick_history(self, req: HistoryRequest) -> Optional[List[TickData]]:
"""
Query history tick data.
"""
pass
其中的trading_hours.json文件我会在后面的文章中做详细的介绍。有了它我们才能展开其他的设计。
在vnpy_rqdata\rqdata_datafeed.py中增加下面的代码
from datetime import timedelta,date # hxxjava add
def update_all_trading_hours(self) -> bool: # hxxjava add
""" 更新所有合约的交易时间段到trading_hours.json文件中 """
if not self.inited:
self.init()
if not self.inited:
return False
ths_dict = load_json(self.trading_hours_file)
# instruments = all_instruments(type=['Future','Stock','Index','Spot'])
trade_hours = {}
for stype in ['Future','Stock','Index','Fund','Spot']:
instruments = all_instruments(type=[stype])
# print(f"{stype} instruments count={len(instruments)}")
for idx,inst in instruments.iterrows():
# 获取每个最新发布的合约的建议时间段
if ('trading_hours' not in inst) or not(isinstance(inst.trading_hours,str)):
# 跳过没有交易时间段或者交易时间段无效的合约
continue
inst_name = inst.trading_code if stype == 'Future' else inst.order_book_id
inst_name = inst_name.upper()
if inst_name.find('.') < 0:
inst_name += '.' + inst.exchange
if inst_name not in ths_dict:
str_trading_hours = inst.trading_hours
# 把'01-'或'31-'者替换成'00-'或'30-'
suffix_pair = [('1-','0-'),('6-','5-')]
for s1,s2 in suffix_pair:
str_trading_hours = str_trading_hours.replace(s1,s2)
# 如果原来没有,提取出来
trade_hours[inst_name] = {"name": inst.symbol,"trading_hours": str_trading_hours}
# print(f"trade_hours old count {len(ths_dict)},append count={len(trade_hours)}")
if trade_hours:
ths_dict.update(trade_hours)
save_json(self.trading_hours_file,ths_dict)
return True
def load_all_trading_hours(self) -> dict: # hxxjava add
""" 从trading_hours.json文件中读取所有合约的交易时间段 """
json_file = get_file_path(self.trading_hours_file)
if not json_file.exists():
return {}
else:
return load_json(self.trading_hours_file)
在vnpy\trader\engine.py中:
from .datafeed import get_datafeed # hxxjava add
def get_trading_hours(self,vt_symbol:str) -> str: # hxxjava add
""" get vt_symbol's trading hours """
ths = self.all_trading_hours.get(vt_symbol.upper(),"")
return ths["trading_hours"] if ths else ""
因为无论你运行vnpy中的哪个app,你都会启动main_engine,无需绕弯子就可以得到这些信息,而我们的用户策略中都包含各自策略的引擎,这样就方便获取交易时间段信息。
如CTA策略中包含cta_engine,而cta_engine它的成员就包含main_engine。那么在策略中执行类似下面的语句就可以获取您交易品种的交易时间段信息:
trading_hours = self.cta_engine.main_engine.get_trading_hours(selt.vt_symbol)
如PortFolioStrategy策略中包含strategy_engine,而strategy_engine它的成员就包含main_engine。那么在策略中执行类似下面的语句就可以获取多个交易品种的交易时间段信息:
trading_hours_list = [self.cta_engine.main_engine.get_trading_hours(vt_symbol) for vt_symbol in self.vt_symbols]
是不是很方便呢?
vnpy 3.0的启动界面中已经集成了一个叫“投研”的功能,其实它是jupyter lab,启动之后输入下面的代码:
# 测试update_all_trading_hours()函数和load_all_trading_hours()
from vnpy.trader.datafeed import get_datafeed
df = get_datafeed()
df.init()
df.update_all_trading_hours() # 更新所有合约的交易时间段到本地文件中
ths = df.load_all_trading_hours() # 从本地文件中读取所有合约的交易时间段
当然您可以在vnpy的trader中主界面的菜单中增加一项,方便您在需要的时候执行下面语句。不过这个更新交易时间段的功能并不需要频繁执行,手动也就够了,记得就好。
经过上面步骤3.4.4,您就在本地得到了一个trading_hours.json文件,该文件在您的用户目录下的.vntrader\中,其内容如下:
{
"A0303.DCE": {
"name": "豆一0303",
"trading_hours": "21:00-23:00,09:00-10:15,10:30-11:30,13:30-15:00"
},
"A0305.DCE": {
"name": "豆一0305",
"trading_hours": "21:00-23:00,09:00-10:15,10:30-11:30,13:30-15:00"
},
"A0307.DCE": {
"name": "豆一0307",
"trading_hours": "21:00-23:00,09:00-10:15,10:30-11:30,13:30-15:00"
},
"A0309.DCE": {
"name": "豆一0309",
"trading_hours": "21:00-23:00,09:00-10:15,10:30-11:30,13:30-15:00"
},
"A0311.DCE": {
"name": "豆一0311",
"trading_hours": "21:00-23:00,09:00-10:15,10:30-11:30,13:30-15:00"
},
"A0401.DCE": {
"name": "豆一0401",
"trading_hours": "21:00-23:00,09:00-10:15,10:30-11:30,13:30-15:00"
},
... ...
}
观察其格式,在你没有米筐数据接口或者这里没有的合约,您也可以手动输入合约交易时间段信息。
按照程序中算法,这个文件文件中一共包含约16500多个合约的交易时间段信息。可以覆盖国内金融市场几乎全部都产品,但是不包括金融二次衍生品期权。
为什么没有期权交易时间段信息,因为不需要。期权合约有其对应的标的物,从其名称和编号就可以解析出来。期权合约的交易时间段其和标的物的交易时间段是完全相同的,因此不需要保存到该文件中。
ck wrote:
MarginPriceType类里的成交均价是不是$ERAGE_PRICE,看代码里是$ERAGE_PRICE,是打错了吗
是的,但是这是Markdown语法导致,错误地把A V 当成了$了,搞笑的很,没有办法我只能把这两个怎么分开才能够正确显示。应该是:
A V ERAGE_PRICE = '3'
上周升级到了vnpy 2.9.0版本,编写了个策略,用到了30分钟Bar。
self.dir_bg = BarGenerator(on_bar = self.on_bar,window = 30,
on_window_bar = self.on_30m_bar,interval = Interval.MINUTE)
那个意思就是创建一个30分钟bar合成器。
策略的on_30m_bar()是这样的,先打印出来看看:
def on_30m_bar(self, bar: BarData):
"""
收到方向周期的K线
"""
print(f"{self.strategy_name}收到30分钟周期K线{bar}")
结果杯具了:
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 23, 21, 0, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=225506.0, turnover=10790057010.0, open_interest=1921172.0, open_price=4788.0, high_price=4808.0, low_price=4762.0, close_price=4777.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 23, 21, 30, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=179987.0, turnover=8570390020.0, open_interest=1903437.0, open_price=4778.0, high_price=4778.0, low_price=4751.0, close_price=4760.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 23, 22, 0, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=99381.0, turnover=4723786710.0, open_interest=1905948.0, open_price=4760.0, high_price=4766.0, low_price=4743.0, close_price=4746.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 23, 22, 30, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=83782.0, turnover=3985661600.0, open_interest=1904511.0, open_price=4745.0, high_price=4767.0, low_price=4744.0, close_price=4763.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 24, 9, 0, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=99253.0, turnover=4714720470.0, open_interest=1916969.0, open_price=4763.0, high_price=4766.0, low_price=4738.0, close_price=4748.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 24, 9, 30, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=85886.0, turnover=4076563050.0, open_interest=1930796.0, open_price=4750.0, high_price=4760.0, low_price=4735.0, close_price=4736.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 24, 10, 0, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=491666.0, turnover=23080991050.0, open_interest=1982231.0, open_price=4736.0, high_price=4741.0, low_price=4660.0, close_price=4665.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 24, 11, 0, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=279000.0, turnover=13059711390.0, open_interest=2005223.0, open_price=4666.0, high_price=4713.0, low_price=4654.0, close_price=4676.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 24, 13, 30, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=271595.0, turnover=12701729900.0, open_interest=2021599.0, open_price=4664.0, high_price=4709.0, low_price=4648.0, close_price=4673.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 24, 14, 0, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=266018.0, turnover=12358486960.0, open_interest=2082998.0, open_price=4673.0, high_price=4674.0, low_price=4622.0, close_price=4623.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 24, 14, 30, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=175873.0, turnover=8162467860.0, open_interest=2081475.0, open_price=4624.0, high_price=4654.0, low_price=4624.0, close_price=4637.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 24, 21, 0, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=312119.0, turnover=14487421790.0, open_interest=2043795.0, open_price=4635.0, high_price=4664.0, low_price=4613.0, close_price=4655.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 24, 21, 30, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=106717.0, turnover=4970359120.0, open_interest=2025130.0, open_price=4656.0, high_price=4666.0, low_price=4648.0, close_price=4655.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 24, 22, 0, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=60255.0, turnover=2807077210.0, open_interest=2011503.0, open_price=4657.0, high_price=4665.0, low_price=4652.0, close_price=4659.0)
GsjyDemo2收到30分钟周期K线BarData(gateway_name='RQ', symbol='rb2205', exchange=<Exchange.SHFE: 'SHFE'>, datetime=datetime.datetime(2022, 2, 24, 22, 30, tzinfo=<DstTzInfo 'Asia/Shanghai' CST+8:00:00 STD>), interval=None, volume=184047.0, turnover=8524883500.0, open_interest=1989335.0, open_price=4660.0, high_price=4661.0, low_price=4608.0, close_price=4614.0)
找到BarGenerator的错误了:
def update_bar_minute_window(self, bar: BarData) -> None:
""""""
# If not inited, create window bar object
if not self.window_bar:
dt = bar.datetime.replace(second=0, microsecond=0)
self.window_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
datetime=dt,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price
)
# Otherwise, update high/low price into window bar
else:
self.window_bar.high_price = max(
self.window_bar.high_price,
bar.high_price
)
self.window_bar.low_price = min(
self.window_bar.low_price,
bar.low_price
)
# Update close price/volume/turnover into window bar
self.window_bar.close_price = bar.close_price
self.window_bar.volume += bar.volume
self.window_bar.turnover += bar.turnover
self.window_bar.open_interest = bar.open_interest
Check if window bar completed
# 这里错误了,用当前1分钟到分钟数+1与30取模来决定一个30分钟K线是否结束,
# 先推送已合成bar,在生成下一个新的30分钟bar。
# 可是10:15-10:30是休市时间段,永远也等不到10:29分钟到那个1分钟bar,所以只能在10:59符合条件,
# 因此这个30分钟bar实际上包含了45分钟到交易数据,错误!!!
if not (bar.datetime.minute + 1) % self.window:
self.on_window_bar(self.window_bar)
self.window_bar = None
问题分析清楚了,就不再解释怎么修改了,直接上修改的BarGenerator完整代码吧。
BarGenerator在vnpy.trader.utility中,拷贝过去替换就OK了。
测试过了,和文华6产生的30分钟K线一模一样。
如果想知道哪里修改了,查找 # hxxjava就可以找到修改处。
vnpy/trader/utility.py的前面添加引用:
from datetime import timedelta
BarGenerator的修改如下:
class BarGenerator:
"""
For:
1. generating 1 minute bar data from tick data
2. generating x minute bar/x hour bar data from 1 minute data
Notice:
1. for x minute bar, x must be able to divide 60: 2, 3, 5, 6, 10, 15, 20, 30
2. for x hour bar, x can be any number
"""
def __init__(
self,
on_bar: Callable,
window: int = 0,
on_window_bar: Callable = None,
interval: Interval = Interval.MINUTE,
daily_close_time:str = "15:00"
):
"""Constructor"""
self.bar: BarData = None
self.on_bar: Callable = on_bar
self.interval: Interval = interval
self.interval_count: int = 0
self.hour_bar: BarData = None
self.window: int = window
self.count_for_window : int = 0 # hxxjava add
self.window_bar: BarData = None
self.on_window_bar: Callable = on_window_bar
self.last_tick: TickData = None
self.daily_close_time = daily_close_time
def update_tick(self, tick: TickData) -> None:
"""
Update new tick data into generator.
"""
new_minute = False
# Filter tick data with 0 last price
if not tick.last_price:
return
# Filter tick data with older timestamp
if self.last_tick and tick.datetime < self.last_tick.datetime:
return
if not self.bar:
new_minute = True
elif (
(self.bar.datetime.minute != tick.datetime.minute)
or (self.bar.datetime.hour != tick.datetime.hour)
):
self.bar.datetime = self.bar.datetime.replace(
second=0, microsecond=0
)
self.on_bar(self.bar)
new_minute = True
if new_minute:
self.bar = BarData(
symbol=tick.symbol,
exchange=tick.exchange,
interval=Interval.MINUTE,
datetime=tick.datetime,
gateway_name=tick.gateway_name,
open_price=tick.last_price,
high_price=tick.last_price,
low_price=tick.last_price,
close_price=tick.last_price,
open_interest=tick.open_interest
)
else:
self.bar.high_price = max(self.bar.high_price, tick.last_price)
if tick.high_price > self.last_tick.high_price:
self.bar.high_price = max(self.bar.high_price, tick.high_price)
self.bar.low_price = min(self.bar.low_price, tick.last_price)
if tick.low_price < self.last_tick.low_price:
self.bar.low_price = min(self.bar.low_price, tick.low_price)
self.bar.close_price = tick.last_price
self.bar.open_interest = tick.open_interest
self.bar.datetime = tick.datetime
if self.last_tick:
volume_change = tick.volume - self.last_tick.volume
self.bar.volume += max(volume_change, 0)
turnover_change = tick.turnover - self.last_tick.turnover
self.bar.turnover += max(turnover_change, 0)
self.last_tick = tick
def update_bar(self, bar: BarData) -> None:
"""
Update 1 minute bar into generator
"""
if self.interval == Interval.MINUTE:
self.update_bar_minute_window(bar)
else:
self.update_bar_hour_window(bar)
def update_bar_minute_window(self, bar: BarData) -> None:
""""""
# If not inited, create window bar object
if not self.window_bar:
dt = bar.datetime.replace(second=0, microsecond=0)
self.window_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
datetime=dt,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price
)
# Otherwise, update high/low price into window bar
else:
self.window_bar.high_price = max(
self.window_bar.high_price,
bar.high_price
)
self.window_bar.low_price = min(
self.window_bar.low_price,
bar.low_price
)
# Update close price/volume/turnover into window bar
self.window_bar.close_price = bar.close_price
self.window_bar.volume += bar.volume
self.window_bar.turnover += bar.turnover
self.window_bar.open_interest = bar.open_interest
# Check if window bar completed
# if not (bar.datetime.minute + 1) % self.window:
# self.on_window_bar(self.window_bar)
# self.window_bar = None
# hxxjava add start
h,m = self.daily_close_time.split(':')
today_close_time = bar.datetime.replace(hour=int(h),minute=int(m),second=0,microsecond=0)
enter_next_day = bar.datetime + timedelta(minutes=1) == today_close_time
if self.count_for_window + 1 == self.window or enter_next_day:
self.on_window_bar(self.window_bar)
self.window_bar = None
if enter_next_day:
self.count_for_window = 0
else:
self.count_for_window += 1
self.count_for_window %= self.window
# hxxjava add end
def update_bar_hour_window(self, bar: BarData) -> None:
""""""
# If not inited, create window bar object
if not self.hour_bar:
dt = bar.datetime.replace(minute=0, second=0, microsecond=0)
self.hour_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
datetime=dt,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
close_price=bar.close_price,
volume=bar.volume,
turnover=bar.turnover,
open_interest=bar.open_interest
)
return
finished_bar = None
# If minute is 59, update minute bar into window bar and push
if bar.datetime.minute == 59:
self.hour_bar.high_price = max(
self.hour_bar.high_price,
bar.high_price
)
self.hour_bar.low_price = min(
self.hour_bar.low_price,
bar.low_price
)
self.hour_bar.close_price = bar.close_price
self.hour_bar.volume += bar.volume
self.hour_bar.turnover += bar.turnover
self.hour_bar.open_interest = bar.open_interest
finished_bar = self.hour_bar
self.hour_bar = None
# If minute bar of new hour, then push existing window bar
elif bar.datetime.hour != self.hour_bar.datetime.hour:
finished_bar = self.hour_bar
dt = bar.datetime.replace(minute=0, second=0, microsecond=0)
self.hour_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
datetime=dt,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
close_price=bar.close_price,
volume=bar.volume,
turnover=bar.turnover,
open_interest=bar.open_interest
)
# Otherwise only update minute bar
else:
self.hour_bar.high_price = max(
self.hour_bar.high_price,
bar.high_price
)
self.hour_bar.low_price = min(
self.hour_bar.low_price,
bar.low_price
)
self.hour_bar.close_price = bar.close_price
self.hour_bar.volume += bar.volume
self.hour_bar.turnover += bar.turnover
self.hour_bar.open_interest = bar.open_interest
# Push finished window bar
if finished_bar:
self.on_hour_bar(finished_bar)
def on_hour_bar(self, bar: BarData) -> None:
""""""
if self.window == 1:
self.on_window_bar(bar)
else:
if not self.window_bar:
self.window_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price
)
else:
self.window_bar.high_price = max(
self.window_bar.high_price,
bar.high_price
)
self.window_bar.low_price = min(
self.window_bar.low_price,
bar.low_price
)
self.window_bar.close_price = bar.close_price
self.window_bar.volume += bar.volume
self.window_bar.turnover += bar.turnover
self.window_bar.open_interest = bar.open_interest
self.interval_count += 1
if not self.interval_count % self.window:
self.interval_count = 0
self.on_window_bar(self.window_bar)
self.window_bar = None
def generate(self) -> Optional[BarData]:
"""
Generate the bar data and call callback immediately.
"""
bar = self.bar
if self.bar:
bar.datetime = bar.datetime.replace(second=0, microsecond=0)
self.on_bar(bar)
self.bar = None
return bar
本人修改原则是n分钟bar按照日内对齐的原则,即:
注意到BarGenerator的构造函数多了个daily_close_time参数,字符串类型,默认值为"15:00"。
例如:
self.bg30m = BarGenerator(on_bar = self.on_bar,window = 30,on_window_bar = self.on_30m_bar,interval = Interval.MINUTE) # 默认15:00收市
但是如果有些例如国债等品种,它的收市时间不是15:00,则需要在特别传参,在写作交易策略的时候,可以给出代表收市时间的字符串参数,供创建实例的时候传递给该参数。虽然麻烦了一丢丢,但是已经可以算得上是够方便的啦!
例如:
self.bg30m = BarGenerator(on_bar = self.on_bar,window = 30,
on_window_bar = self.on_30m_bar,
interval = Interval.MINUTE,
daily_close_time= "16:00" )
self.bg30m = BarGenerator(on_bar = self.on_bar,window = 30,
on_window_bar = self.on_30m_bar,
interval = Interval.MINUTE,
daily_close_time= "5:00" )
陈慧 wrote:
感谢大神细致准确的分析,应该是这个问题,但是我加了这两个变量后,策略初始化就一直有问题,没有办法计算变量值,1h线走完了也还是显示值为0,但回测还可以正常出结果。v_v
"RuStrategy_xxx": {
"pos": ??,
"entry_up": ??,
... ...
"long_stop":??,
"short_stop:??"
},
其中long_stop和short_stop的值??,按照你当前策略事件运行的情况计算一下,替代其中的??。
class RuStrategy(CtaTemplate):
""""""
entry_window = 100
exit_window = 85
atr_window = 80
fixed_size =1
entry_dev = 1 # 入场通道宽度
exit_dev = 1 # 出场通道宽度
rsi_window=80
rsi_signal=15
entry_up = 0
entry_down = 0
exit_up = 0
exit_down = 0
atr_value = 0
rsi_value=0
long_entry = 0
short_entry = 0
long_stop = 0
short_stop = 0
parameters = ["entry_window", "exit_window", "atr_window","fixed_size","entry_dev","exit_dev","rsi_window","rsi_signal"]
variables =["entry_up", "entry_down", "exit_up", "exit_down", "atr_value"]
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
""""""
super().__init__(cta_engine, strategy_name, vt_symbol, setting)
self.bg = BarGenerator(
self.on_bar,
window=1,
on_window_bar=self.on_hour_bar,
interval=Interval.HOUR)
self.am = ArrayManager(120)
def on_init(self):
"""
Callback when strategy is inited.
"""
self.write_log("策略初始化")
self.load_bar(30)
def on_start(self):
"""
Callback when strategy is started.
"""
self.write_log("策略启动")
def on_stop(self):
"""
Callback when strategy is stopped.
"""
self.write_log("策略停止")
def on_tick(self, tick: TickData):
"""
Callback of new tick data update.
"""
self.bg.update_tick(tick)
def on_bar(self, bar: BarData):
self.bg.update_bar(bar)
def on_hour_bar(self, bar: BarData):
"""
Callback of new bar data update.
"""
self.cancel_all() # 放弃之前所有未成交委托单
self.am.update_bar(bar)
if not self.am.inited:
return
# Only calculates new entry channel when no position holding
keltnerEntryUp, keltnerEntryDown = self.am.keltner(self.entry_window,self.entry_dev)
keltnerExitUp, keltnerExitDown = self.am.keltner(self.exit_window,self.exit_dev)
donchianEntryUp, donchianEntryDown = self.am.donchian(self.entry_window)
donchianExitUp, donchianExitDown = self.am.donchian(self.exit_window)
self.rsi_value=self.am.rsi(self.rsi_window)
if not self.pos: # 如果没有持仓
# 做多开仓价=100日keltner上沿与100日donchian上沿的最高价
self.entry_up = max(donchianEntryUp, keltnerEntryUp)
# 做空开仓价=100日keltner下沿与100日donchian下沿的最低价
self.entry_down = min(donchianEntryDown, keltnerEntryDown)
# 多单止盈价=85日keltner上沿与85日donchian上沿的最低价
self.exit_up = min(keltnerExitUp, donchianExitUp)
# 空单止盈价=85日keltner下沿与85日donchian下沿的最高价
self.exit_down = max(keltnerExitDown, donchianExitDown)
if not self.pos: # 如果没有持仓
self.atr_value = self.am.atr(self.atr_window) # 80日平均涨幅
self.long_entry = 0
self.short_entry = 0
self.long_stop = 0
self.short_stop = 0
self.send_buy_orders(self.entry_up) # 发出做多停止单,价格为做多开仓价
self.send_short_orders(self.entry_down) # 发出做空停止单,价格为做空开仓价
elif self.pos > 0: # 如果持多仓
self.send_buy_orders(self.entry_up) # 发出做多停止单,价格为做多开仓价 (补全不足的多仓)
sell_price = max(self.long_stop, self.exit_down) # 多单平仓价=max(多单止盈价,多单止损价格)
self.sell(sell_price, abs(self.pos), True) # 以多单平仓价发出平全部仓停止单
elif self.pos < 0: # 如果持空仓
self.send_short_orders(self.entry_down) # 发出做空停止单,价格为做空开仓价(补全不足的空仓)
cover_price = min(self.short_stop, self.exit_up) # 空单平仓价=max(空单止盈价,空单止损价格)
self.cover(cover_price, abs(self.pos), True) # 以空单平仓价发出平全部空仓停止单
self.put_event()
def on_trade(self, trade: TradeData):
"""
Callback of new trade data update.
"""
if trade.direction == Direction.LONG:
# 开多仓成功后,计算并且记录止损价格=开仓价-2倍平均真实涨幅
self.long_entry = trade.price
self.long_stop = self.long_entry - 2* self.atr_value
else:
# 开空仓成功后,计算并且记录止损价格=开仓价+2倍平均真实涨幅
self.short_entry = trade.price
self.short_stop = self.short_entry + 2 * self.atr_value
def on_order(self, order: OrderData):
"""
Callback of new order data update.
"""
pass
def on_stop_order(self, stop_order: StopOrder):
"""
Callback of stop order update.
"""
pass
def send_buy_orders(self, price):
""""""
t = self.pos / self.fixed_size
if t < 1 and self.rsi_value<=(50+self.rsi_signal) :
# 持多仓不足1份且self.rsi_value在65之下,发出2份做多委托停止单,委托价格为price
self.buy(price, self.fixed_size*2, True)
if t < 2 and self.rsi_value<=(50+self.rsi_signal):
# 持多仓不足2份且self.rsi_value在65之下,发出2份做多委托停止单,委托价格为price+self.atr_value*0.5
self.buy(price + self.atr_value*0.5 , self.fixed_size*2, True)
# if t < 3:
# self.buy(price + self.atr_value, self.fixed_size, True)
# if t < 4:
# self.buy(price + self.atr_value * 1.5, self.fixed_size, True)
def send_short_orders(self, price):
""""""
t = self.pos / self.fixed_size
if t > -1 and self.rsi_value>=(50-self.rsi_signal) :
# 持空仓不足1份且self.rsi_value在35之上,发出4份做空委托停止单,委托价格为price
self.short(price, self.fixed_size*4, True)
执行了:
def on_trade(self, trade: TradeData):
"""
Callback of new trade data update.
"""
if trade.direction == Direction.LONG:
self.long_entry = trade.price
self.long_stop = self.long_entry - 2* self.atr_value
else:
self.short_entry = trade.price
self.short_stop = self.short_entry + 2 * self.atr_value
elif self.pos < 0:
self.send_short_orders(self.entry_down)
cover_price = min(self.short_stop, self.exit_up) # cover_price ,平空仓的价格为self.short_stop和self.exit_up的小者
self.cover(cover_price, abs(self.pos), True)
以你现在的代码,只是发现平空仓不对,其实策略持有多仓也可能遇到同样的问题。
把"long_stop","short_stop"放入variables 列表,variables 修改成这样:
variables =["entry_up", "entry_down", "exit_up", "exit_down", "atr_value","long_stop","short_stop"]
这样无论开仓后,策略是否被重新启动过,self.long_stop,self.short_stop都记住多仓或空仓的止损价格。
陈慧 wrote:
手上持有空单,计算的空平触发价格应该是4744(前一天结算价为4800),即高于4744时就会止损。当天开盘价4711,未达到止损价,但是开盘系统直接触发了价格为0的多单进行平仓,也就是无条件平仓?麻烦问一下时哪里出了问题呢?是因为收盘价高于止损价的原因吗?(没有把开盘价格推送进去?)
备注:策略里用的停止单。
1n wrote:
get_tick_status也有问题,tick_time = tick.datetime.strftime("%H:%M:%S")直接把毫秒数据忽略了,这怎么能行?我的解决方案是tick_time = str(tick.datetime.time())
是的,这样更好。不过tick_time = tick.datetime.strftime("%H:%M:%S")也是没有问题的。
钢 wrote:
又查了一些资料,有几个理解不知道是否正确,还望指点:
1、报单流程是基于品种来的,而且是账户级别的,即所有session共享的?
2、查询流程和FTD报文流控是基于session来控制的,是基于session来的?
如果上述两点成立,那是不是可以做一个gateway实例下多个td_api 连接的方案来扩展查询和FTD报文的流控限额(比如做主和次 td_api实例)。
可以在gateway实例上做主动流控,默认发主td,如果在发撤单时候自己统计出来频率过高(只能自己统计,因为FTD流程没有错误通知),则自动切成次td来发,这样是不是就可以避开FTD的流控?反正无论是哪个td_api最终执行发单,vt_orderid都是通过gateway实例返回给上层的,上层应用并不会感知到这个事情。
另外,我看你的代码方案里面,是不是只针对你说的第1种流控(报单流控)有效,没考虑如何去避免FTD报文流控的情况?FTP报文流控的返回值是0?
盼复,谢谢!
答复:
我的代码已经考虑到CTP接口的报单和FTD报文流控了。
可以顺着class CtpTdApi(TdApi)的: