internal/database: wrap OpenCensus sql driver

The datbase/sql package provides a way to get at the underlying driver's
connection. We will need that feature to access pgx's fast CopyFrom
method.

Unfortunately, we do not use the pgx driver directly. Instead, we wrap
it in an OpenCensus driver for tracing and metrics. And that
OpenCensus driver's connection does not allow us to dig into _it_ to
retrieve the pgx connection that it wraps.

Hence this dance: we must write our own driver and connection
implementations.  The driver return the connection, and the connection
holds both the underlying (pgx) connection and the OpenCensus
connection, to which it delegates.

This CL is a no-op; everything behaves as before. But it lays the
groundwork for a subsequent CL that will use pgx's CopyFrom.

Change-Id: I5bf308aa23f07f20d1f6410ebb04cd6b9a5e0922
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/304630
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/cmd/internal/cmdconfig/cmdconfig.go b/cmd/internal/cmdconfig/cmdconfig.go
index e9d915b..9600f92 100644
--- a/cmd/internal/cmdconfig/cmdconfig.go
+++ b/cmd/internal/cmdconfig/cmdconfig.go
@@ -98,8 +98,8 @@
 func OpenDB(ctx context.Context, cfg *config.Config, bypassLicenseCheck bool) (_ *postgres.DB, err error) {
 	defer derrors.Wrap(&err, "cmdconfig.OpenDB(ctx, cfg)")
 
-	// Wrap the postgres driver with OpenCensus instrumentation.
-	ocDriver, err := ocsql.Register(cfg.DBDriver, ocsql.WithAllTraceOptions())
+	// Wrap the postgres driver with our own wrapper, which adds OpenCensus instrumentation.
+	ocDriver, err := database.RegisterOCWrapper(cfg.DBDriver, ocsql.WithAllTraceOptions())
 	if err != nil {
 		return nil, fmt.Errorf("unable to register the ocsql driver: %v", err)
 	}
diff --git a/internal/database/driver.go b/internal/database/driver.go
new file mode 100644
index 0000000..4415842
--- /dev/null
+++ b/internal/database/driver.go
@@ -0,0 +1,87 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package database
+
+import (
+	"context"
+	"database/sql"
+	"database/sql/driver"
+	"errors"
+
+	"contrib.go.opencensus.io/integrations/ocsql"
+)
+
+// RegisterOCWrapper registers a driver that wraps the OpenCensus driver, which in
+// turn wraps the driver named as the first argument.
+func RegisterOCWrapper(driverName string, opts ...ocsql.TraceOption) (string, error) {
+	// Get the driver to wrap.
+	db, err := sql.Open(driverName, "")
+	if err != nil {
+		return "", err
+	}
+	dri := db.Driver()
+	if err := db.Close(); err != nil {
+		return "", err
+	}
+	name := "ocWrapper-" + driverName
+	sql.Register(name, &wrapOCDriver{dri, opts})
+	return name, nil
+}
+
+type wrapOCDriver struct {
+	underlying driver.Driver
+	opts       []ocsql.TraceOption
+}
+
+// Open implements database/sql/driver.Driver.
+func (d *wrapOCDriver) Open(name string) (driver.Conn, error) {
+	conn, err := d.underlying.Open(name)
+	if err != nil {
+		return nil, err
+	}
+	oc := ocsql.WrapConn(conn, d.opts...)
+	return &wrapConn{conn, oc.(iconn)}, nil
+}
+
+type iconn interface {
+	driver.Pinger
+	driver.ExecerContext
+	driver.QueryerContext
+	driver.Conn
+	driver.ConnPrepareContext
+	driver.ConnBeginTx
+}
+
+// A wrapConn knows about both the underlying Conn, and the OpenCensus Conn that wraps it.
+// It delegates all calls to the OpenCensus Conn, but the underlying conn is available
+// to this package.
+type wrapConn struct {
+	underlying driver.Conn
+	oc         iconn
+}
+
+// Ping and all the following methods implement driver.Conn and related interfaces,
+// listed in the iconn interface above.
+func (c *wrapConn) Ping(ctx context.Context) error { return c.oc.Ping(ctx) }
+
+func (c *wrapConn) Prepare(query string) (driver.Stmt, error) { return c.oc.Prepare(query) }
+func (c *wrapConn) Close() error                              { return c.oc.Close() }
+func (c *wrapConn) Begin() (driver.Tx, error)                 { return nil, errors.New("unimplmented") }
+
+func (c *wrapConn) ExecContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Result, error) {
+	return c.oc.ExecContext(ctx, q, args)
+}
+
+func (c *wrapConn) QueryContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Rows, error) {
+	return c.oc.QueryContext(ctx, q, args)
+}
+
+func (c *wrapConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
+	return c.oc.PrepareContext(ctx, query)
+}
+
+func (c *wrapConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
+	return c.oc.BeginTx(ctx, opts)
+}