diff --git a/protocol/wireguard/endpoint_warp.go b/protocol/wireguard/endpoint_warp.go index 9773cd01..24e129cc 100644 --- a/protocol/wireguard/endpoint_warp.go +++ b/protocol/wireguard/endpoint_warp.go @@ -185,31 +185,28 @@ func (w *WARPEndpoint) Start(stage adapter.StartStage) error { } func (w *WARPEndpoint) Close() error { - if err := w.isEndpointInitialized(); err != nil { - return err + if ok := w.isEndpointInitialized(); !ok { + return E.New("endpoint not initialized") } return w.endpoint.Close() } func (w *WARPEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - if err := w.isEndpointInitialized(); err != nil { - return nil, err + if ok := w.isEndpointInitialized(); !ok { + return nil, E.New("endpoint not initialized") } return w.endpoint.DialContext(ctx, network, destination) } func (w *WARPEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - if err := w.isEndpointInitialized(); err != nil { - return nil, err + if ok := w.isEndpointInitialized(); !ok { + return nil, E.New("endpoint not initialized") } return w.endpoint.ListenPacket(ctx, destination) } -func (w *WARPEndpoint) isEndpointInitialized() error { +func (w *WARPEndpoint) isEndpointInitialized() bool { w.mtx.Lock() defer w.mtx.Unlock() - if w.endpoint == nil { - return E.New("endpoint not initialized") - } - return nil + return w.endpoint != nil }