internal/counter: fix windows mapped file extension

On Windows, unmapping the previous mappedFile in mmap.Mmap caused
panics when counter pointers were read concurrently with the remapping.

Upon investigation, it appears that the unmapping was only necessary for
tests, to ensure that previous mappings were cleaned up (and therefore
that test files can be deleted). A call to runtime.SetFinalizer(...,
mappedFile.close) appeared to serve a similar purpose, yet for an
unknown reason finalizers were never run.

Deeper investigation revealed that there was simply one bug in file
cleanup (coincidentally already noted in a TODO): the previous
mappedFile was never closed. After storing the newly mapped file in
file.newCounter1 and invalidating counters, it is safe to close the
previous mappedFile.

Therefore:
- fix the cleanup in file.newCounter1
- remove the unmap in mmap.Mmap on windows
- remove the now unnecessary 'existing' parameter in mmap APIs
- remove the SetFinalizer call
- add a test for multiple concurrent mappings of a file
- add an end-to-end test for concurrent file extension
- change ReadFile to read by memory mapping the file, in an attempt
  to avoid msync issues

For golang/go#68311
Fixes golang/go#68358

Change-Id: I27b6f4f4939e93f7c76f920d553848bf014be236
Reviewed-on: https://go-review.googlesource.com/c/telemetry/+/597278
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Hyang-Ah Hana Kim <hyangah@gmail.com>
diff --git a/internal/counter/concurrent_test.go b/internal/counter/concurrent_test.go
new file mode 100644
index 0000000..dfd306a
--- /dev/null
+++ b/internal/counter/concurrent_test.go
@@ -0,0 +1,108 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package counter_test
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"golang.org/x/telemetry/counter/countertest"
+	"golang.org/x/telemetry/internal/counter"
+	"golang.org/x/telemetry/internal/regtest"
+	"golang.org/x/telemetry/internal/telemetry"
+	"golang.org/x/telemetry/internal/testenv"
+)
+
+func TestConcurrentExtension(t *testing.T) {
+	testenv.SkipIfUnsupportedPlatform(t)
+
+	// This test verifies that files may be concurrently extended: when one file
+	// discovers that its entries exceed the mapped data, it remaps the data.
+
+	// Both programs populate enough new records to extend the file multiple
+	// times.
+	const numCounters = 50000
+	prog1 := regtest.NewProgram(t, "inc1", func() int {
+		for i := 0; i < numCounters; i++ {
+			counter.New(fmt.Sprint("gophers", i)).Inc()
+		}
+		return 0
+	})
+	prog2 := regtest.NewProgram(t, "inc2", func() int {
+		for i := numCounters; i < 2*numCounters; i++ {
+			counter.New(fmt.Sprint("gophers", i)).Inc()
+		}
+		return 0
+	})
+
+	dir := t.TempDir()
+	now := time.Now().UTC()
+
+	// Run a no-op program in the telemetry dir to ensure that the weekends file
+	// exists, and avoid the race described in golang/go#68390.
+	// (We could also call countertest.Open here, but better to avoid mutating
+	// state in the current process for a test that is otherwise hermetic)
+	prog0 := regtest.NewProgram(t, "init", func() int { return 0 })
+	if _, err := regtest.RunProgAsOf(t, dir, now, prog0); err != nil {
+		t.Fatal(err)
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	// Run the programs concurrently.
+	go func() {
+		defer wg.Done()
+		if out, err := regtest.RunProgAsOf(t, dir, now, prog1); err != nil {
+			t.Errorf("prog1 failed: %v; output:\n%s", err, out)
+		}
+	}()
+	go func() {
+		defer wg.Done()
+		if out, err := regtest.RunProgAsOf(t, dir, now, prog2); err != nil {
+			t.Errorf("prog2 failed: %v; output:\n%s", err, out)
+		}
+	}()
+
+	wg.Wait()
+
+	counts := readCountsForDir(t, telemetry.NewDir(dir).LocalDir())
+	if got, want := len(counts), 2*numCounters; got != want {
+		t.Errorf("Got %d counters, want %d", got, want)
+	}
+
+	for name, value := range counts {
+		if value != 1 {
+			t.Errorf("count(%s) = %d, want 1", name, value)
+		}
+	}
+}
+
+func readCountsForDir(t *testing.T, dir string) map[string]uint64 {
+	entries, err := os.ReadDir(dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+	var countFiles []string
+	for _, entry := range entries {
+		if strings.HasSuffix(entry.Name(), ".count") {
+			countFiles = append(countFiles, filepath.Join(dir, entry.Name()))
+		}
+	}
+	if len(countFiles) != 1 {
+		t.Fatalf("found %d count files, want 1; directory contents: %v", len(countFiles), entries)
+	}
+
+	counters, _, err := countertest.ReadFile(countFiles[0])
+	if err != nil {
+		t.Fatal(err)
+	}
+	return counters
+}
diff --git a/internal/counter/counter.go b/internal/counter/counter.go
index cc562bc..ae7b8ff 100644
--- a/internal/counter/counter.go
+++ b/internal/counter/counter.go
@@ -246,6 +246,7 @@
 	}
 }
 
+// add wraps the atomic.Uint64.Add operation to handle integer overflow.
 func (c *Counter) add(n uint64) uint64 {
 	count := c.ptr.count
 	for {
@@ -340,7 +341,7 @@
 func ReadFile(name string) (counters, stackCounters map[string]uint64, _ error) {
 	// TODO: Document the format of the stackCounters names.
 
-	data, err := os.ReadFile(name)
+	data, err := readMapped(name)
 	if err != nil {
 		return nil, nil, fmt.Errorf("failed to read from file: %v", err)
 	}
@@ -359,3 +360,26 @@
 	}
 	return counters, stackCounters, nil
 }
+
+// readMapped reads the contents of the given file by memory mapping.
+//
+// This avoids file synchronization issues.
+func readMapped(name string) ([]byte, error) {
+	f, err := os.OpenFile(name, os.O_RDWR, 0666)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+	fi, err := f.Stat()
+	if err != nil {
+		return nil, err
+	}
+	mapping, err := memmap(f)
+	if err != nil {
+		return nil, err
+	}
+	data := make([]byte, fi.Size())
+	copy(data, mapping.Data)
+	munmap(mapping)
+	return data, nil
+}
diff --git a/internal/counter/counter_test.go b/internal/counter/counter_test.go
index 4e1a71b..4b8806f 100644
--- a/internal/counter/counter_test.go
+++ b/internal/counter/counter_test.go
@@ -35,7 +35,6 @@
 
 	t.Logf("GOOS %s GOARCH %s", runtime.GOOS, runtime.GOARCH)
 	setup(t)
-	defer restore()
 	var f file
 	defer close(&f)
 	c := f.New("gophers")
@@ -82,9 +81,9 @@
 
 	t.Logf("GOOS %s GOARCH %s", runtime.GOOS, runtime.GOARCH)
 	setup(t)
-	defer restore()
 	var f file
 	defer close(&f)
+
 	c := f.New("manygophers")
 
 	var wg sync.WaitGroup
@@ -122,8 +121,9 @@
 	}
 }
 
-// this is needed in Windows so that the generated testing.go file
-// can clean up the temporary test directory
+// close ensures that the given mapped file is closed. On Windows, this is
+// necessary prior to test cleanup.
+// TODO(rfindley): rename.
 func close(f *file) {
 	mf := f.current.Load()
 	if mf == nil {
@@ -137,7 +137,7 @@
 	testenv.SkipIfUnsupportedPlatform(t)
 	t.Logf("GOOS %s GOARCH %s", runtime.GOOS, runtime.GOARCH)
 	setup(t)
-	defer restore()
+
 	var f file
 	defer close(&f)
 	f.rotate()
@@ -184,7 +184,6 @@
 
 	t.Logf("GOOS %s GOARCH %s", runtime.GOOS, runtime.GOARCH)
 	setup(t)
-	defer restore()
 	var f file
 	defer close(&f)
 	f.rotate()
@@ -224,7 +223,7 @@
 
 	t.Logf("GOOS %s GOARCH %s", runtime.GOOS, runtime.GOARCH)
 	setup(t)
-	defer restore()
+
 	now := CounterTime().UTC()
 	year, month, day := now.Date()
 	// preserve time location as done in (*file).filename.
@@ -357,8 +356,9 @@
 			if weekends != ends.Weekday() {
 				t.Errorf("weekends %s unexpecteledy not end day %s", weekends, ends.Weekday())
 			}
-			// needed for Windows
+			// On Windows, we must unmap f.current before removing files below.
 			close(&f)
+
 			// remove files for the next iteration of the loop
 			for _, f := range fis {
 				os.Remove(filepath.Join(telemetry.Default.LocalDir(), f.Name()))
@@ -377,7 +377,6 @@
 	testenv.SkipIfUnsupportedPlatform(t)
 	t.Logf("GOOS %s GOARCH %s", runtime.GOOS, runtime.GOARCH)
 	setup(t)
-	defer restore()
 	var f file
 	defer close(&f)
 	f.rotate()
@@ -508,10 +507,9 @@
 	telemetry.Default = telemetry.NewDir(t.TempDir()) // new dir for each test
 	os.MkdirAll(telemetry.Default.LocalDir(), 0777)
 	os.MkdirAll(telemetry.Default.UploadDir(), 0777)
-}
-
-func restore() {
-	CounterTime = func() time.Time { return time.Now().UTC() }
+	t.Cleanup(func() {
+		CounterTime = func() time.Time { return time.Now().UTC() }
+	})
 }
 
 func (f *file) New(name string) *Counter {
diff --git a/internal/counter/file.go b/internal/counter/file.go
index e3c574d..7d3d189 100644
--- a/internal/counter/file.go
+++ b/internal/counter/file.go
@@ -36,7 +36,22 @@
 	buildInfo          *debug.BuildInfo
 	timeBegin, timeEnd time.Time
 	err                error
-	current            atomic.Pointer[mappedFile] // may be read without holding mu, but may be nil
+	// current holds the current file mapping, which may change when the file is
+	// rotated or extended.
+	//
+	// current may be read without holding mu, but may be nil.
+	//
+	// The cleanup logic for file mappings is complicated, because invalidating
+	// counter pointers is reentrant: [file.invalidateCounters] may call
+	// [file.lookup], which acquires mu. Therefore, writing current must be done
+	// as follows:
+	//  1. record the previous value of current
+	//  2. Store a new value in current
+	//  3. unlock mu
+	//  4. call invalidateCounters
+	//  5. close the previous mapped value from (1)
+	// TODO(rfindley): simplify
+	current atomic.Pointer[mappedFile]
 }
 
 var defaultFile file
@@ -292,7 +307,7 @@
 	}
 	name := filepath.Join(dir, baseName)
 
-	m, err := openMapped(name, meta, nil)
+	m, err := openMapped(name, meta)
 	if err != nil {
 		// Mapping failed:
 		// If there used to be a mapped file, after cleanup
@@ -334,8 +349,10 @@
 	cleanup = nop
 	if newM != nil {
 		f.current.Store(newM)
-		// TODO(rfindley): shouldn't this close f.current?
-		cleanup = f.invalidateCounters
+		cleanup = func() {
+			f.invalidateCounters()
+			current.close()
+		}
 	}
 	return v, cleanup
 }
@@ -386,7 +403,7 @@
 
 // existing should be nil the first time this is called for a file,
 // and when remapping, should be the previous mappedFile.
-func openMapped(name string, meta string, existing *mappedFile) (_ *mappedFile, err error) {
+func openMapped(name string, meta string) (_ *mappedFile, err error) {
 	hdr, err := mappedHeader(meta)
 	if err != nil {
 		return nil, err
@@ -402,13 +419,13 @@
 		f:    f,
 		meta: meta,
 	}
-	// without this files cannot be cleanedup on Windows (affects tests)
-	runtime.SetFinalizer(m, (*mappedFile).close)
+
 	defer func() {
 		if err != nil {
 			m.close()
 		}
 	}()
+
 	info, err := f.Stat()
 	if err != nil {
 		return nil, err
@@ -433,16 +450,11 @@
 	}
 
 	// Map into memory.
-	var mapping mmap.Data
-	if existing != nil {
-		mapping, err = memmap(f, existing.mapping)
-	} else {
-		mapping, err = memmap(f, nil)
-	}
+	mapping, err := memmap(f)
 	if err != nil {
 		return nil, err
 	}
-	m.mapping = &mapping
+	m.mapping = mapping
 	if !bytes.HasPrefix(m.mapping.Data, hdr) {
 		return nil, fmt.Errorf("counter: header mismatch")
 	}
@@ -597,7 +609,11 @@
 	}()
 
 	v, headOff, head, ok := m.lookup(name)
-	for !ok {
+	for tries := 0; !ok; tries++ {
+		if tries >= 10 {
+			debugPrintf("corrupt: failed to remap after 10 tries")
+			return nil, nil, errCorrupt
+		}
 		// Lookup found an invalid pointer,
 		// perhaps because the file has grown larger than the mapping.
 		limit := m.load32(m.hdrLen + limitOff)
@@ -606,10 +622,12 @@
 			debugPrintf("corrupt1\n")
 			return nil, nil, errCorrupt
 		}
-		newM, err := openMapped(m.f.Name(), m.meta, m)
+		newM, err := openMapped(m.f.Name(), m.meta)
 		if err != nil {
 			return nil, nil, err
 		}
+		// If m != orig, this is at least the second time around the loop
+		// trying to open the mapping. Close the previous attempt.
 		if m != orig {
 			m.close()
 		}
@@ -690,8 +708,16 @@
 			return nil, err
 		}
 	}
-	newM, err := openMapped(m.f.Name(), m.meta, m)
-	m.f.Close()
+	newM, err := openMapped(m.f.Name(), m.meta)
+	if err != nil {
+		return nil, err
+	}
+	if int64(len(newM.mapping.Data)) < int64(end) {
+		// File system or logic bug: new file is somehow not extended.
+		// See go.dev/issue/68311, where this appears to have been happening.
+		newM.close()
+		return nil, errCorrupt
+	}
 	return newM, err
 }
 
diff --git a/internal/counter/rotate_test.go b/internal/counter/rotate_test.go
index 725d260..4cc1542 100644
--- a/internal/counter/rotate_test.go
+++ b/internal/counter/rotate_test.go
@@ -23,7 +23,6 @@
 	testenv.SkipIfUnsupportedPlatform(t)
 	t.Logf("GOOS %s GOARCH %s", runtime.GOOS, runtime.GOARCH)
 	setup(t)
-	defer restore()
 
 	now := getnow()
 	CounterTime = func() time.Time { return now }
@@ -109,7 +108,7 @@
 	// simulate failure to remap
 	oldmap := memmap
 	now = now.Add(7 * 24 * time.Hour)
-	memmap = func(*os.File, *mmap.Data) (mmap.Data, error) { return mmap.Data{}, fmt.Errorf("too bad") }
+	memmap = func(*os.File) (*mmap.Data, error) { return nil, fmt.Errorf("too bad") }
 	f.rotate()
 	memmap = oldmap
 
@@ -153,7 +152,6 @@
 	t.Logf("GOOS %s GOARCH %s", runtime.GOOS, runtime.GOARCH)
 	now := getnow()
 	setup(t)
-	defer restore()
 	// pretend something was uploaded
 	os.WriteFile(filepath.Join(telemetry.Default.UploadDir(), "anything"), []byte{}, 0666)
 	var f file
diff --git a/internal/mmap/mmap.go b/internal/mmap/mmap.go
index fb3ca96..2febe3e 100644
--- a/internal/mmap/mmap.go
+++ b/internal/mmap/mmap.go
@@ -26,12 +26,11 @@
 
 // Mmap maps the given file into memory.
 // When remapping a file, pass the most recently returned Data.
-func Mmap(f *os.File, data *Data) (Data, error) {
-	return mmapFile(f, data)
+func Mmap(f *os.File) (*Data, error) {
+	return mmapFile(f)
 }
 
 // Munmap unmaps the given file from memory.
 func Munmap(d *Data) error {
-	// d.f.Close() on Windows still gets an error
-	return munmapFile(*d)
+	return munmapFile(d)
 }
diff --git a/internal/mmap/mmap_other.go b/internal/mmap/mmap_other.go
index 361ca8b..190afd8 100644
--- a/internal/mmap/mmap_other.go
+++ b/internal/mmap/mmap_other.go
@@ -12,14 +12,14 @@
 )
 
 // mmapFile on other systems doesn't mmap the file. It just reads everything.
-func mmapFile(f *os.File, _ *Data) (Data, error) {
+func mmapFile(f *os.File) (*Data, error) {
 	b, err := io.ReadAll(f)
 	if err != nil {
-		return Data{}, err
+		return nil, err
 	}
-	return Data{f, b, nil}, nil
+	return &Data{f, b, nil}, nil
 }
 
-func munmapFile(d Data) error {
+func munmapFile(_ *Data) error {
 	return nil
 }
diff --git a/internal/mmap/mmap_test.go b/internal/mmap/mmap_test.go
index 3e4cd9f..add6bf2 100644
--- a/internal/mmap/mmap_test.go
+++ b/internal/mmap/mmap_test.go
@@ -39,14 +39,14 @@
 	os.Exit(m.Run())
 }
 
-func openMapped(name string) (*os.File, mmap.Data, error) {
+func openMapped(name string) (*os.File, *mmap.Data, error) {
 	f, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE, 0666)
 	if err != nil {
-		return nil, mmap.Data{}, fmt.Errorf("open failed: %v", err)
+		return nil, nil, fmt.Errorf("open failed: %v", err)
 	}
-	data, err := mmap.Mmap(f, nil)
+	data, err := mmap.Mmap(f)
 	if err != nil {
-		return nil, mmap.Data{}, fmt.Errorf("Mmap failed: %v", err)
+		return nil, nil, fmt.Errorf("Mmap failed: %v", err)
 	}
 	return f, data, nil
 }
@@ -101,3 +101,66 @@
 		t.Errorf("incremented %d times, want %d", got, concurrency)
 	}
 }
+
+func TestMultipleMaps(t *testing.T) {
+	testenv.SkipIfUnsupportedPlatform(t)
+
+	// This test verifies that multiple views of an mmap'ed file may
+	// simultaneously exist for the current process. This is relied upon by
+	// counter concurrency logic.
+
+	dir := t.TempDir()
+	name := filepath.Join(dir, "shared.count")
+
+	var zero [8]byte
+	if err := os.WriteFile(name, zero[:], 0666); err != nil {
+		t.Fatal(err)
+	}
+
+	var (
+		mappings []*mmap.Data
+		values   []*atomic.Uint64 // mapped counts
+	)
+
+	const nMaps = 3
+	for i := 0; i < nMaps; i++ {
+		f, mapping, err := openMapped(name)
+		if err != nil {
+			t.Fatal(err)
+		}
+		mappings = append(mappings, mapping)
+		i := i
+		defer func() {
+			if i > 0 {
+				mmap.Munmap(mapping)
+			}
+			f.Close()
+		}()
+		values = append(values, (*atomic.Uint64)(unsafe.Pointer(&mapping.Data[0])))
+	}
+
+	var wg sync.WaitGroup
+	const nAdds = 100
+	for _, v := range values {
+		v := v
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for i := 0; i < 100; i++ {
+				v.Add(1)
+			}
+		}()
+	}
+	wg.Wait()
+	for i, v := range values {
+		if got, want := v.Load(), uint64(nMaps*nAdds); got != want {
+			t.Errorf("counter %d has value %d, want %d", i, got, want)
+		}
+	}
+	mmap.Munmap(mappings[0]) // other mappings should remain valid
+	for i, v := range values[1:] {
+		if got, want := v.Load(), uint64(nMaps*nAdds); got != want {
+			t.Errorf("counter %d has value %d, want %d", i, got, want)
+		}
+	}
+}
diff --git a/internal/mmap/mmap_unix.go b/internal/mmap/mmap_unix.go
index af462ff..f15ac61 100644
--- a/internal/mmap/mmap_unix.go
+++ b/internal/mmap/mmap_unix.go
@@ -13,29 +13,29 @@
 	"syscall"
 )
 
-func mmapFile(f *os.File, _ *Data) (Data, error) {
+func mmapFile(f *os.File) (*Data, error) {
 	st, err := f.Stat()
 	if err != nil {
-		return Data{}, err
+		return nil, err
 	}
 	size := st.Size()
 	pagesize := int64(os.Getpagesize())
 	if int64(int(size+(pagesize-1))) != size+(pagesize-1) {
-		return Data{}, fmt.Errorf("%s: too large for mmap", f.Name())
+		return nil, fmt.Errorf("%s: too large for mmap", f.Name())
 	}
 	n := int(size)
 	if n == 0 {
-		return Data{f, nil, nil}, nil
+		return &Data{f, nil, nil}, nil
 	}
 	mmapLength := int(((size + pagesize - 1) / pagesize) * pagesize) // round up to page size
 	data, err := syscall.Mmap(int(f.Fd()), 0, mmapLength, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
 	if err != nil {
-		return Data{}, &fs.PathError{Op: "mmap", Path: f.Name(), Err: err}
+		return nil, &fs.PathError{Op: "mmap", Path: f.Name(), Err: err}
 	}
-	return Data{f, data[:n], nil}, nil
+	return &Data{f, data[:n], nil}, nil
 }
 
-func munmapFile(d Data) error {
+func munmapFile(d *Data) error {
 	if len(d.Data) == 0 {
 		return nil
 	}
diff --git a/internal/mmap/mmap_windows.go b/internal/mmap/mmap_windows.go
index e70e7c7..2e8dfbe 100644
--- a/internal/mmap/mmap_windows.go
+++ b/internal/mmap/mmap_windows.go
@@ -13,35 +13,35 @@
 	"golang.org/x/sys/windows"
 )
 
-func mmapFile(f *os.File, previous *Data) (Data, error) {
-	if previous != nil {
-		munmapFile(*previous)
-	}
+func mmapFile(f *os.File) (*Data, error) {
 	st, err := f.Stat()
 	if err != nil {
-		return Data{}, err
+		return nil, err
 	}
 	size := st.Size()
 	if size == 0 {
-		return Data{f, nil, nil}, nil
+		return &Data{f, nil, nil}, nil
 	}
 	// set the min and max sizes to zero to map the whole file, as described in
 	// https://learn.microsoft.com/en-us/windows/win32/memory/creating-a-file-mapping-object#file-mapping-size
 	h, err := windows.CreateFileMapping(windows.Handle(f.Fd()), nil, syscall.PAGE_READWRITE, 0, 0, nil)
 	if err != nil {
-		return Data{}, fmt.Errorf("CreateFileMapping %s: %w", f.Name(), err)
+		return nil, fmt.Errorf("CreateFileMapping %s: %w", f.Name(), err)
 	}
 	// the mapping extends from zero to the end of the file mapping
 	// https://learn.microsoft.com/en-us/windows/win32/api/memoryapi/nf-memoryapi-mapviewoffile
 	addr, err := windows.MapViewOfFile(h, syscall.FILE_MAP_READ|syscall.FILE_MAP_WRITE, 0, 0, 0)
 	if err != nil {
-		return Data{}, fmt.Errorf("MapViewOfFile %s: %w", f.Name(), err)
+		return nil, fmt.Errorf("MapViewOfFile %s: %w", f.Name(), err)
 	}
-	// need to remember addr and h for unmapping
-	return Data{f, unsafe.Slice((*byte)(unsafe.Pointer(addr)), size), h}, nil
+	// Note: previously, we called windows.VirtualQuery here to get the exact
+	// size of the memory mapped region, but VirtualQuery reported sizes smaller
+	// than the actual file size (hypothesis: VirtualQuery only reports pages in
+	// a certain state, and newly written pages may not be counted).
+	return &Data{f, unsafe.Slice((*byte)(unsafe.Pointer(addr)), size), h}, nil
 }
 
-func munmapFile(d Data) error {
+func munmapFile(d *Data) error {
 	err := windows.UnmapViewOfFile(uintptr(unsafe.Pointer(&d.Data[0])))
 	x, ok := d.Windows.(windows.Handle)
 	if ok {