diff --git a/streamcommon.go b/streamcommon.go index f37258b..2798c7d 100644 --- a/streamcommon.go +++ b/streamcommon.go @@ -14,12 +14,13 @@ import ( const expectTimeoutDuration = time.Second type streamCommon struct { - name string - conn *net.UDPConn - localSID uint32 - remoteSID uint32 - gotRemoteSID bool - readChan chan []byte + name string + conn *net.UDPConn + localSID uint32 + remoteSID uint32 + gotRemoteSID bool + readChan chan []byte + readerClosedChan chan bool pkt7 pkt7Type } @@ -31,51 +32,40 @@ func (s *streamCommon) send(d []byte) { } } -func (s *streamCommon) read() []byte { +func (s *streamCommon) read() ([]byte, error) { b := make([]byte, 1500) n, _, err := s.conn.ReadFromUDP(b) - if err != nil { - // Ignoring timeout errors. - if err, ok := err.(net.Error); ok && !err.Timeout() { - log.Fatal(err) - } - } - return b[:n] + return b[:n], err } func (s *streamCommon) reader() { for { - r := s.read() + r, err := s.read() + if err != nil { + break + } if s.pkt7.isPkt7(r) { s.pkt7.handle(s, r) } s.readChan <- r } + s.readerClosedChan <- true } func (s *streamCommon) tryReceivePacket(timeout time.Duration, packetLength, matchStartByte int, b []byte) []byte { var r []byte - expectStart := time.Now() + timer := time.NewTimer(timeout) for { - err := s.conn.SetReadDeadline(time.Now().Add(timeout - time.Since(expectStart))) - if err != nil { - log.Fatal(err) - } - - r = <-s.readChan - - err = s.conn.SetReadDeadline(time.Time{}) - if err != nil { - log.Fatal(err) + select { + case r = <-s.readChan: + case <-timer.C: + return nil } if len(r) == packetLength && bytes.Equal(r[matchStartByte:len(b)+matchStartByte], b) { break } - if time.Since(expectStart) > timeout { - return nil - } } return r } @@ -117,6 +107,7 @@ func (s *streamCommon) open(name string, portNumber int) { } s.readChan = make(chan []byte) + s.readerClosedChan = make(chan bool) go s.reader() if r := s.pkt7.tryReceive(300*time.Millisecond, s); s.pkt7.isPkt7(r) { @@ -125,10 +116,33 @@ func (s *streamCommon) open(name string, portNumber int) { log.Print(s.name + "/closing running stream") s.sendDisconnect() time.Sleep(time.Second) + + s.close() + s.remoteSID = 0 s.gotRemoteSID = false + s.pkt7.sendSeq = 0 + s.pkt7.lastConfirmedSeq = 0 + s.open(name, portNumber) } } +func (s *streamCommon) close() { + s.conn.Close() + + // Depleting the read channel. + var finished bool + for !finished { + select { + case <-s.readChan: + default: + finished = true + } + } + + // Waiting for the reader to finish. + <-s.readerClosedChan +} + func (s *streamCommon) sendPkt3() { s.send([]byte{0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, byte(s.localSID >> 24), byte(s.localSID >> 16), byte(s.localSID >> 8), byte(s.localSID),